# Build Stack – Aligning a Stack of Images Based On a Fixed Target Outline

Authors: Daniel Sieber, Samuel John

#### Abstract

Images that have been cut or ground from a block are often not aligned with each other. This IPython notebook uses a fixed target structure that is visible in all images of the stack (in our case the outline of an overmold) to find the affine transform that best aligns each image to the given target. The target is created from one image of the stack in which only the fixed structure remains visible and the remaining area is made transparent.


*TODO*:

- Write better Abstract
- Link to GitHub repository here
- Add "How to cite" statement and link to paper (DOI) here


#### [The MIT License (MIT)](http://opensource.org/licenses/MIT)

Copyright (c) 2015 Daniel Sieber, Samuel John


<div style="font-size:7pt;">
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
</div>

### Imports & Set-Up

In [None]:
# Plot in this IPython Notebook instead of opening separate windows
%matplotlib inline
# Render inline figures at double resolution for high-DPI ("retina") displays
%config InlineBackend.figure_format = 'retina'

In [None]:
import os
import time

# Write Python objects to disk.
# TODO: This should be replaced by some HDF5 files that store the transformation matrix
import pickle

# Parsing svg files and accessing paths in there
from xml.dom import minidom

# Typed array/matrix support. NumPy is imported directly because the NumPy
# aliases that SciPy used to re-export (sp.array, sp.linspace, ...) are
# deprecated and no longer available in recent SciPy versions.
import numpy as np

# Scientific Python
import scipy as sp

# Plotting
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib

# Import external modules used by this script
from skimage import img_as_float, io, transform
import svg.path  # you might need to `pip install svg.path`

Our own modules:

In [None]:
import pattern_finder_gpu
from pattern_finder_gpu import center_roi_around

### Definition of functions used by this script:

In [None]:
def rotation_transform_center(image, angle, center_xy=None):
    """
    Return the transform that rotates `image` by `angle` degrees around a center.

    The rotation is composed of three steps to avoid unwanted translational
    side effects:
    1.) Translate the rotation center to the origin
    2.) Rotate by the given angle
    3.) Translate the origin back to the rotation center

    Parameters
    ----------
    image : array
        Only its shape is used, to compute the default rotation center.
    angle : float
        Rotation angle in degrees.
    center_xy : sequence of two floats, optional
        Rotation center in (x, y) = (column, row) order. Defaults to the
        center of `image`.

    Returns
    -------
    skimage.transform.SimilarityTransform
        The composition ``to_origin + rotate + back``.
    """
    if center_xy is None:
        # image.shape is (rows, cols, ...), but the transform needs (x, y).
        rows, cols = image.shape[:2]
        center_xy = np.array((cols, rows)) / 2. - 0.5
    else:
        # Accept lists/tuples as well; `-center_xy` requires an array.
        center_xy = np.asarray(center_xy, dtype=np.float64)

    to_origin = transform.SimilarityTransform(translation=-center_xy)
    rotate = transform.SimilarityTransform(rotation=np.deg2rad(angle))
    back = transform.SimilarityTransform(translation=center_xy)
    return to_origin + rotate + back

In [None]:
def plot_overlay(image, svg_path, ax=None, figsize=(15,15), n_samples=10000):
    """
    Show `image` and draw the SVG path `svg_path` on top of it in magenta.

    Parameters
    ----------
    image : array
        Image passed to ``ax.imshow``.
    svg_path : svg.path path
        Object with a ``point(t)`` method returning a complex number
        (x = real part, y = imaginary part) for t in [0, 1].
    ax : matplotlib Axes, optional
        Axes to draw into. If omitted, a new figure of size `figsize` is
        created with a single axes filling it.
    figsize : tuple
        Figure size used only when `ax` is None.
    n_samples : int
        Number of points sampled along the path.
    """
    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_axes([0, 0, 1, 1])
    ax.imshow(image, interpolation='nearest')

    # Sample the path parameter over [0, 1] *including* the end point so that
    # closed outlines are drawn without a gap between the last and first sample.
    positions = np.linspace(0.0, 1.0, n_samples)
    overlay_coords = np.array([svg_path.point(p) for p in positions])
    ax.plot(overlay_coords.real, overlay_coords.imag, color='magenta')

In [None]:
class PatternAtROIBorderWarning(UserWarning):
    """Signal that the best match of find_pattern_rotated lies on the ROI border.

    This usually indicates that the true target pattern is located outside of
    the ROI. The complete list of results is attached as ``result`` so that a
    caller catching this exception can still use the best match found.
    """

    def __init__(self, message="Pattern found at the border of the ROI", result=None):
        super().__init__(message)
        self.result = result


def _plot_search_results(result, rotate, ellipsecorr, plot, vmin, vmax):
    """Plot the search results of find_pattern_rotated according to `plot`."""
    if plot and rotate[2] > 1 and ellipsecorr[2] > 1:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        # Axes3D has no `surface_plot` method; scatter the samples instead.
        ax.scatter([a[0] for a in result], [a[2] for a in result], [a[4] for a in result])
        ax.set_xlabel('Angle (rotation)')
        ax.set_ylabel('Ellipse correction')
        ax.set_zlabel('difference image-target')
        plt.show()

    elif plot and rotate[2] > 1:
        fig, ax = plt.subplots(1)
        ax.plot([a[0] for a in result], [a[4] for a in result])
        ax.set_xlabel('Angle (rotation)')
        ax.set_ylabel('difference image-target')
        plt.show()

    elif plot and ellipsecorr[2] > 1:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter([a[2] for a in result], [a[1] for a in result], [a[4] for a in result])
        plt.show()

    if plot == 'all':
        # One correlation map per tested parameter combination, on a near-square grid.
        n_rows = int(np.sqrt(len(result)))
        n_cols = int(np.ceil(len(result) / n_rows))
        fig, ax = plt.subplots(n_rows, n_cols, squeeze=False, figsize=(2 * n_cols, 2 * n_rows))
        fig.tight_layout(rect=[0, 0.03, 1, 0.97])
        fig.suptitle("Correlation map of where target is in image\n", size=16)
        for n, cell in enumerate(ax.flat):
            cell.axis("off")
            if n < len(result):
                cell.imshow(result[n][5], interpolation="nearest", cmap='cubehelix', vmin=vmin, vmax=vmax)
                cell.annotate('A:{0:.2f};EA:{1:.2f};ES{2:.3f}'.format(result[n][0], result[n][1], result[n][2]), [0, 0])
                cell.annotate('Value:{0:.2f}'.format(result[n][4]), [0, 5])
        plt.show()


def find_pattern_rotated(PF, pattern, image, rescale=0.2, rotate=(-60,61,120),
                         ellipsecorr=(1,1,1), ellipseres=1,
                         roi_center=None,
                         roi_size=(41,41),
                         plot=False):
    """
    Brute-force search for `pattern` in `image` over a range of rotation angles.

    Pattern and image are rescaled by `rescale`. For every combination of
    rotation angle, ellipse angle and ellipse correction, the rotated pattern
    is matched against the image inside a ROI using ``PF.find``.
    Note: the ellipse parameters are iterated over and recorded in the
    results, but only the rotation is currently applied to the pattern.

    Parameters
    ----------
    PF : pattern finder
        Object with ``find(pattern, image, roi) -> (out, min_coords, value)``,
        e.g. ``pattern_finder_gpu.PatternFinder``. Lower `value` is better.
    pattern, image : arrays
    rescale : float
        Scale factor applied to both pattern and image before matching.
    rotate : (start, stop, num)
        Arguments for ``linspace`` giving the rotation angles in degrees.
    ellipsecorr : (start, stop, num)
        Arguments for ``linspace`` giving the ellipse corrections.
    ellipseres : int
        Number of ellipse angles between 0 and 180 degrees.
    roi_center : (row, col), optional
        ROI center in unscaled image coordinates; defaults to the image center.
    roi_size : (height, width)
        ROI size in pixels of the scaled image.
    plot : bool or 'all'
        Plot the search results; 'all' additionally shows every correlation map.

    Returns
    -------
    list of [angle, ellipse_angle, ellipse_corr, coords, value, out]
        `coords` are in unscaled image coordinates, `out` is the map
        returned by ``PF.find``.

    Raises
    ------
    PatternAtROIBorderWarning
        If the best match lies on the ROI border; the results are attached
        to the exception as ``.result``.
    ValueError
        If there is no parameter combination to evaluate.
    """
    start_time = time.time()
    print("find_pattern_rotated:")
    print("Rescaling image and target by scale={rescale}.\n"
          "   image {0}x{1} px to {2:.2f}x{3:.02f} px."
          .format(image.shape[0], image.shape[1],
                  image.shape[0]*rescale, image.shape[1]*rescale, rescale=rescale), flush=True)
    pattern_scaled = transform.rescale(pattern, rescale)
    image_scaled = transform.rescale(image, rescale)

    rotations = np.linspace(*rotate)
    if len(rotations) > 1 and rescale < 0.3:
        # At low resolution, angles close to zero may end up in false local
        # minima, so they are skipped -- unless nothing would be left to test.
        non_zero = rotations[~np.isclose(rotations, 0, atol=0.2)]
        if non_zero.size:
            rotations = non_zero
            print("Rotation values close to zero were removed in low res rescale")
    ellipseangles = np.linspace(0, 180, ellipseres)
    ellipsecorrs = np.linspace(*ellipsecorr)

    result = []
    scaled_coords = []  # match positions in scaled-image pixels, parallel to `result`
    vmax = 0.0
    vmin = np.inf
    if roi_center is None:
        roi_center = np.array(image.shape[:2]) / 2.0 - 0.5
    roi_center = np.asarray(roi_center, dtype=np.float64)
    roi = center_roi_around(roi_center * rescale, roi_size)
    print("ROI: center={0}, {1}, in unscaled image.\n"
          "     height={2}, width={3} in scaled image"
          .format(roi_center[0], roi_center[1], roi_size[0], roi_size[1]))
    if ellipsecorr[2] > 1:
        print("Now correlating ellipse correction from {0}x to {1}x in {2} steps at an"
              .format(*ellipsecorr) + " Angular resolution {}º".format(180/ellipseres))
    #TODO: correct number for eliminated values which are close to zero in low resolution pics
    if rotate[2] > 1:
        print("Now correlating rotations from {0}º to {1}º in {2} steps:"
              .format(*rotate))
    else:
        print("Rotation is kept constant at {0}°".format(rotate[0]))

    # The pattern's rotation center (x, y) does not depend on the loop variables.
    rows, cols = pattern_scaled.shape[:2]
    pattern_center_xy = np.array((cols, rows)) / 2. - 0.5
    for r in rotations:
        rotation_matrix = rotation_transform_center(pattern_scaled, r, center_xy=pattern_center_xy)
        for ea in ellipseangles:
            for ec in ellipsecorrs:
                out, min_coords, value = PF.find(transform.warp(pattern_scaled, rotation_matrix),
                                                 image_scaled, roi)
                outmax = out.max()
                outmin = out.min()
                if outmax > vmax:
                    vmax = outmax
                if outmin < vmin:
                    vmin = outmin
                coords = np.array(min_coords, dtype=np.float64)
                scaled_coords.append(coords)
                # undo the rescale for the coordinates
                result.append([r, ea, ec, coords / rescale, value, out])
                print(".", end="", flush=True)
    print("")

    if not result:
        raise ValueError("find_pattern_rotated: no parameter combination to evaluate "
                         "(check `rotate`, `ellipsecorr` and `ellipseres`).")

    best_idx = int(np.argmin([entry[4] for entry in result]))
    best_param_set = result[best_idx]
    print("best_degree= {}°".format(best_param_set[0]))
    print("best_ellipseangle= {}".format(best_param_set[1]))
    print("best_ellipsecorr= {}".format(best_param_set[2]))
    print("coordinates= {}".format(best_param_set[3]))
    print("minimum value= {}".format(best_param_set[4]))
    print("took {0} seconds.".format(time.time()-start_time))

    _plot_search_results(result, rotate, ellipsecorr, plot, vmin, vmax)

    # Compare the *best* match with the ROI border; both are in scaled-image pixels.
    best_row, best_col = scaled_coords[best_idx][:2]
    if (np.isclose(best_row, (roi[0], roi[2])).any() or
            np.isclose(best_col, (roi[1], roi[3])).any()):
        raise PatternAtROIBorderWarning(result=result)
    return result

In [None]:
def build_stack_roughly(images, target, search_strategy, plot=False, write_files=False, PF=None):
    """
    Roughly align every image of a stack to `target` by brute-force search.

    For each image the phases of `search_strategy` are run in order. Every
    phase calls `find_pattern_rotated` with its own rescale factor, angle
    range, ROI size and ellipse settings, starting from the best angle and
    position found by the previous phase. The final best match is turned into
    an affine transform that warps the image into the target frame.

    Parameters
    ----------
    images : skimage.io.ImageCollection or iterable of images
    target : array
        Target image; its shape defines the output shape of aligned images.
    search_strategy : list of dict
        One dict per phase with keys "rescale", "angle_range" ((start, stop,
        num), offset by the best angle so far), "roi_hw", "ellipsecorr" and
        "ellipseres".
    plot : bool or 'all'
        Passed to find_pattern_rotated. If truthy, each aligned image is also
        plotted with the outline from the notebook-global `svg_path`.
    write_files : bool
        If True, aligned images (and plots) are written to ../EXPORT/ and the
        transforms are pickled to ../EXPORT/transforms.pkl.
    PF : pattern finder, optional
        Defaults to ``pattern_finder_gpu.PatternFinder(partitions=10)``.

    Returns
    -------
    list
        If `write_files`: the list of final transforms. Otherwise a list of
        (transform, aligned image, best value) tuples.
    """
    if PF is None:
        # `PatternFinder` is not imported by name; use the module attribute.
        PF = pattern_finder_gpu.PatternFinder(partitions=10)

    target_center = np.array(target.shape[:2]) / 2. - 0.5

    best_values = []
    final_transforms = []
    final_images = []

    use_ic = isinstance(images, io.ImageCollection)
    if use_ic:
        # Some tif images contain actually two images (a big one and a smaller
        # thumbnail preview). image_collection therefore seems to generate two
        # entries for each of the files. The load_func, however, always loads
        # the big one, which is then actually loaded twice. So we use a `set`
        # to make this unique and drop duplicates.
        imagelist = sorted(set(images.files))
    else:
        imagelist = images

    for im_nr, image_file in enumerate(imagelist):
        if use_ic:
            im = img_as_float(images.load_func(image_file))
            print("\n\nImage Nr. {0} {1}".format(im_nr, image_file))
            export_name = os.path.basename(image_file)[0:3]
        else:
            im = img_as_float(image_file)
            print("\n\nImage Nr. {0}".format(im_nr))
            # In-memory images have no file name; number the exports instead.
            export_name = "{0:03d}".format(im_nr)
        im_center = np.array(im.shape[:2]) / 2. - 0.5

        # Current best guess, refined by every search phase.
        best_angle = 0.0
        best_coord = im_center
        best_value = None
        for nr, search_phase in enumerate(search_strategy):
            print("\nSearch phase {0}".format(nr), flush=True)
            angle_range = (search_phase["angle_range"][0] + best_angle,
                           search_phase["angle_range"][1] + best_angle,
                           search_phase["angle_range"][2])
            try:
                res = find_pattern_rotated(PF, target, im,
                                           rescale=search_phase["rescale"],
                                           rotate=angle_range,
                                           ellipsecorr=search_phase["ellipsecorr"],
                                           ellipseres=search_phase["ellipseres"],
                                           roi_center=best_coord,
                                           roi_size=search_phase["roi_hw"],
                                           plot=plot)
            except PatternAtROIBorderWarning as warning:
                print("\nWARNING! Pattern found at the border of the ROI. \n"
                      "That usually indicates that the "
                      "true target pattern is located outside of the ROI!")
                # The exception may carry the search results; without them this
                # phase cannot update the best guess, so keep the previous one.
                res = getattr(warning, "result", None)
                if not res:
                    continue
            # Each entry is [angle, ellipse angle, ellipse corr, coords, value, out].
            best_res = res[int(np.argmin([r[4] for r in res]))]
            best_angle = best_res[0]
            best_coord = best_res[3]
            best_value = best_res[4]

        best_values.append(best_value)

        # Composed transform used as warp's inverse map (target frame -> image):
        # shift so the target center lands on the best match, then rotate by
        # -best_angle around that point. Coordinates are reversed to (x, y).
        move_to_center = transform.AffineTransform(translation=-best_coord[::-1])
        move_back = transform.AffineTransform(translation=best_coord[::-1])
        rotation = transform.AffineTransform(rotation=-np.deg2rad(best_angle))
        translation = transform.AffineTransform(translation=(best_coord - target_center)[::-1])

        final_trans = translation + move_to_center + rotation + move_back
        final_transforms.append(final_trans)
        im_trans = transform.warp(im, final_trans, output_shape=[target.shape[0], target.shape[1]])

        if write_files:
            io.imsave("../EXPORT/" + export_name + ".PNG", im_trans)
        else:
            final_images.append(im_trans)
        print("translation={0} \n".format((best_coord - target_center)[::-1]), flush=True)
        if plot:
            plot_overlay(im_trans, svg_path)  # `svg_path` is the notebook-global outline
            if write_files:
                plt.savefig("../EXPORT/Plot_" + export_name + ".PNG", dpi=100)
                plt.close()
            else:
                plt.show()

    # Drop this function's reference to the pattern finder before returning.
    del PF
    if write_files:
        # Write final transformations to file
        with open("../EXPORT/transforms.pkl", 'wb') as fh:
            pickle.dump(final_transforms, fh)
        return final_transforms
    return list(zip(final_transforms, final_images, best_values))

In [None]:
def build_stack_optimized(images, target, transMatrix, PF=None, rescale=1, **kws):
    """
    Refine a rough alignment by local (Nelder-Mead) optimization of the affine parameters.

    Every image starts from the same initial guess, taken from `transMatrix`.
    `loss_fcn` is minimized with ``scipy.optimize.minimize``.

    Parameters
    ----------
    images : skimage.io.ImageCollection or iterable of images
    target : array
        Target image; it is rescaled by `rescale` before matching.
    transMatrix : skimage transform
        Initial guess, e.g. a transform returned by build_stack_roughly. The
        first six entries of its flattened ``params`` are optimized.
    PF : pattern finder, optional
        Defaults to ``pattern_finder_gpu.PatternFinder(partitions=10)``.
    rescale : float
        Scale factor applied to target and images during optimization.
    **kws
        Forwarded to ``scipy.optimize.minimize`` (e.g. ``options={'maxiter': 1}``).

    Returns
    -------
    scipy.optimize.OptimizeResult
        The optimization result of the *last* image.

    Raises
    ------
    ValueError
        If `images` is empty.
    """
    # Import explicitly: `import scipy` alone does not load scipy.optimize on older SciPy.
    from scipy.optimize import minimize

    # Convert initialGuess transformation matrix into an ndarray with six entries for the DOFs
    initialGuess = np.asarray(transMatrix.params).flatten()[0:6]

    target_scaled = transform.rescale(target, rescale)

    if PF is None:
        # `PatternFinder` is not imported by name; use the module attribute.
        PF = pattern_finder_gpu.PatternFinder(partitions=10)

    use_ic = isinstance(images, io.ImageCollection)
    # ImageCollection: enumerate file names and index by position; otherwise
    # iterate the in-memory images directly.
    entries = images.files if use_ic else images

    res = None
    for im_nr, entry in enumerate(entries):
        if use_ic:
            im = img_as_float(images[im_nr])
            print("\n\nImage Nr. {0} {1}".format(im_nr, entry))
        else:
            im = img_as_float(entry)
            print("\n\nImage Nr. {0}".format(im_nr))

        im_scaled = transform.rescale(im, rescale)
        res = minimize(loss_fcn, initialGuess, args=(PF, target_scaled, im_scaled, rescale),
                       method='Nelder-Mead', **kws)
    if res is None:
        raise ValueError("build_stack_optimized: `images` is empty")
    return res

def loss_fcn(guess, PF, target_scaled, image_scaled, rescale):
    """
    Objective function for build_stack_optimized.

    `guess` holds the first six entries (two rows) of a 3x3 affine matrix in
    full-resolution coordinates. The matrix is converted to the rescaled
    coordinate system as S * T * S^-1 with S = scale(rescale). It is used to
    warp the rescaled target into the image frame, and the result is scored
    with ``PF.find`` on a 1x1 ROI at the image center.

    Every evaluation is appended to the notebook-global list ``parameters``
    as [guess, warped target, roi]; the list is created if it does not exist.

    Returns
    -------
    The value returned by ``PF.find`` (minimized by the optimizer).
    """
    # Full-resolution affine transform described by the six parameters.
    T = transform.AffineTransform(np.append(guess, [0, 0, 1]).reshape(3, 3))
    # Express T in the rescaled coordinate system: S * T * S^-1.
    scale_mat = transform.AffineTransform(scale=[rescale, rescale]).params
    combined_transform = scale_mat.dot(T.params).dot(np.linalg.inv(scale_mat))

    # Create "fake" ROI around image center with size one
    roi_center = np.array(image_scaled.shape[:2]) / 2.0 - 0.5
    roi = pattern_finder_gpu.center_roi_around(roi_center, [1, 1])

    # Execute Pattern Finder and calculate best match; warp() expects the
    # inverse map (output -> input coordinates).
    transformed_target = transform.warp(target_scaled,
                                        np.linalg.inv(combined_transform),
                                        output_shape=image_scaled.shape[:2])
    out, min_coords, value = PF.find(transformed_target, image_scaled, roi)

    # Record this evaluation. Copy `guess`: the optimizer may reuse its buffer.
    globals().setdefault("parameters", []).append(
        [np.array(guess, copy=True), transformed_target, roi])

    # Print what is done in current step (ravel works for 1-D and column shapes)
    tx, ty = np.ravel(T.translation)[:2]
    sx, sy = np.ravel(T.scale)[:2]
    print("x={x:.0f} y={y:.0f} r={rot:.3f}º sx={sx:.3f} sy={sy:.3f} shear={shear:.4f} => {value:.3f}".format(
        x=tx,
        y=ty,
        rot=np.rad2deg(T.rotation),
        sx=sx,
        sy=sy,
        shear=T.shear,
        value=value))
    return value

In [None]:
def plot_parameter(image, guess, svg_path, target, **kws):
    """
    Warp `image` with the affine parameters `guess` and overlay `svg_path` on it.

    Parameters
    ----------
    image : array
        Image to warp.
    guess : sequence of six floats
        First two rows of a 3x3 affine matrix (the parameter vector optimized
        by build_stack_optimized); the row [0, 0, 1] is appended.
    svg_path : svg.path path
        Outline drawn on top of the warped image.
    target : array
        Only its height and width are used, as the output shape of the warp.
    **kws
        Forwarded to plot_overlay (e.g. ``ax``, ``figsize``).
    """
    matrix = np.append(guess, [0, 0, 1]).reshape(3, 3)
    image_out = transform.warp(image, transform.AffineTransform(matrix),
                               output_shape=[target.shape[0], target.shape[1]])
    plot_overlay(image_out, svg_path, **kws)

## Start of main script
### Step 0: Load data from histology

Here the value of $\phi_a$ is blup. *(TODO: placeholder text — replace with a description of the histology input data.)*

1. foo
2. bar
3. baz

In [None]:
# Target image: the template that is searched for in every slice
target = img_as_float(io.imread("./../ZETA/Target_Zeta_110_20_smooth_edges_filled.png"))

# Outline of the same template, taken from the first <path> element of the SVG file
svg_xml = minidom.parse("../ZETA/Target_ZETA_110_20_smooth_filled_edges.SVG")
first_path = svg_xml.getElementsByTagName('path')[0]
svg_path = svg.path.parse_path(first_path.getAttribute('d'))
svg_xml.unlink()

# Image stack that is to be aligned
ic = io.ImageCollection('../ZETA/193.tif')

# Make the outermost pixel frame of the target fully transparent
# (channel index 3 of the target image).
target[[0, -1], :, 3] = 0.0
target[:, [0, -1], 3] = 0.0

In [None]:
# Visual sanity check: the SVG outline should lie exactly on the target image
plot_overlay(image=target, svg_path=svg_path, figsize=(15, 15))

### Step 1: Rough alignment of slices using brute force

In [None]:
# Search strategy for build_stack_roughly, one dict per search phase:
#   rescale     - scale factor applied to image and target
#   angle_range - (start, stop, num) for linspace, offset by the best angle
#                 found in the previous phase
#   roi_hw      - ROI (height, width) in pixels of the rescaled image
#   ellipsecorr - (start, stop, num) ellipse corrections to try
#   ellipseres  - number of ellipse angles between 0 and 180 degrees
# NOTE(review): phase 1 scans from +55 down to +35 degrees relative to the
# phase-0 result; confirm this offset range is intended.
search_strategy = [
    {"rescale": 0.1, "angle_range": (0, 0, 1), "roi_hw": (51, 51),
     "ellipsecorr": (1, 1, 1), "ellipseres": 1},
    {"rescale": 0.1, "angle_range": (55, 35, 101), "roi_hw": (15, 15),
     "ellipsecorr": (1, 1, 1), "ellipseres": 1},
]

In [None]:
# Brute-force (rough) alignment of the first slice of the stack
result = build_stack_roughly(
    ic[0:1],
    target,
    search_strategy,
    PF=pattern_finder_gpu.PatternFinder(partitions=10),
    write_files=False,
    plot='all',
)

### Step 2: Fine adjustments using local optimization

In [None]:
# loss_fcn appends every evaluated parameter set to this list
parameters = []
res = build_stack_optimized(
    ic[0:1],
    target,
    result[0][0],  # rough transform of the first image serves as initial guess
    PF=pattern_finder_gpu.PatternFinder(partitions=1),
    rescale=0.1,
    # options={'maxiter': 1},
)

In [None]:
# Show the first image warped with the most recently evaluated parameters.
# NOTE(review): this is the last evaluated guess, not necessarily the optimum `res.x`.
plot_parameter(image=ic[0], guess=parameters[-1][0], svg_path=svg_path, target=target)