Script used to run elastix and transformix on the whole dataset.  
#### Note:  
This script is meant to run over the entire dataset or at least over all cycles of one point.
If you want to re-run specific cycles of a point, simply delete the corresponding folder in the aligned-image directory, e.g. delete ".../Point0006/cycle12" to re-run cycle 12 of Point0006. Cycles whose output already exists are skipped, so you don't have to specify which cycles you want to run. Restricting the run to specific cycles works as well, but the reference cycles might not be correct in that case. You can change the `first_ref` variable and the `reference_cycles` list; just make sure you know what you're doing, and check the console output to see whether the correct reference cycle was picked.

In [1]:
import SimpleITK as sitk
from pathlib import Path
from skimage import io
import numpy as np
from datetime import datetime
from skimage import img_as_uint
# NOTE:
# scikit-image has a hard time dealing with pathlib Path objects, so every time an image
# is loaded or saved (io.imread() or io.imsave()), the path is converted to a string
# with str(path) instead of passing the Path object itself.

In [2]:
# Input plate folder (stitched raw images) and output folder for the aligned results
path_plate = Path(r"/links/groups/treutlein/DATA/imaging/PW/4i/plate14")
path_results = Path(r"/links/groups/treutlein/USERS/pascal_noser/plate14_results/alignment/alignment_name")

# Every cycle folder, in the order in which the cycles are aligned.
# Sub-cycles ("cycle5_0", "cycle10_0", ...) directly follow their parent cycle;
# "cycle1_2" and "cycle1_3" are processed last.
cycles = [
    "cycle0", "cycle1", "cycle2", "cycle3", "cycle4",
    "cycle5", "cycle5_0",
    "cycle6", "cycle7", "cycle8", "cycle9",
    "cycle10", "cycle10_0",
    "cycle11", "cycle12", "cycle13", "cycle14",
    "cycle15", "cycle15_0",
    "cycle16", "cycle17", "cycle18", "cycle19",
    "cycle20", "cycle20_0",
    "cycle21",
    "cycle1_2", "cycle1_3",
]

# All imaged points, zero-padded to four digits: Point0000 ... Point0073
points = [f"Point{idx:04d}" for idx in range(74)]

# Points with incomplete acquisitions -> the cycles that are missing for them
points_special = {
    "Point0000": ["cycle15"],
    "Point0001": ["cycle15"],
    "Point0042": ["cycle18"],
    "Point0065": ["cycle0", "cycle1", "cycle2", "cycle5_0"],
    "Point0066": ["cycle0", "cycle1"],
    "Point0067": ["cycle0", "cycle1"],
    "Point0070": ["cycle1"],
    "Point0071": ["cycle1"],
    "Point0072": ["cycle1"],
    "Point0073": [
        "cycle1", "cycle16", "cycle17", "cycle18", "cycle19",
        "cycle20", "cycle20_0", "cycle21", "cycle1_2", "cycle1_3",
    ],
}

# Points that are excluded from the analysis altogether
points_bad = [
    "Point0047", "Point0052", "Point0053", "Point0058", "Point0059",
    "Point0062", "Point0063", "Point0064", "Point0068", "Point0069",
]

points = list(filter(lambda name: name not in points_bad, points))

The most important part is the set of parameter maps used for the alignment. They are stored in 'Elastix/param_maps/' and can be adjusted depending on the kind of alignment you want to run. The code below shows how to load a parameter map and how to chain further ones (here: a translation map followed by an affine map).

In [None]:
# Initialise elastix and transformix.
# The parameter-map paths are relative, so they resolve against the current
# working directory of the notebook.
elastix_filter = sitk.ElastixImageFilter()
elastix_filter.SetParameterMap(sitk.ReadParameterFile('Elastix/param_maps/translation.txt'))
elastix_filter.AddParameterMap(sitk.ReadParameterFile('Elastix/param_maps/affine.txt'))
#elastix_filter.AddParameterMap(sitk.ReadParameterFile('Elastix/param_maps/bspline.txt'))

transformix_filter = sitk.TransformixImageFilter()

# Channel on which the transform is estimated (and from which the sampling mask
# is built). The resulting transform is then applied to every channel.
# NOTE(review): images are indexed as (rows, cols, channels); index 2 is the
# channel the script has always registered on -- confirm which stain it is.
REGISTRATION_CHANNEL = 2


def _first_reference(available_cycles):
    """Return the first reference cycle: cycle1 if present, else cycle2, else cycle3.

    Raises ValueError if none of them is available for the point.
    """
    for candidate in ("cycle1", "cycle2", "cycle3"):
        if candidate in available_cycles:
            return candidate
    raise ValueError("No usable first reference (cycle1/cycle2/cycle3) in " + str(available_cycles))


def _registration_mask(img):
    """Return a uint8 mask that is 1 where the registration channel is non-zero.

    elastix only samples pixels inside the fixed mask, so pixels with intensity
    0 are excluded from the metric.
    """
    return (img[..., REGISTRATION_CHANNEL] > 0).astype(np.uint8)


def _align_image(fixed_img, fixed_mask, moving_img, param_dir):
    """Register moving_img onto fixed_img and return the aligned image as uint16.

    elastix estimates one transform on REGISTRATION_CHANNEL, and its transform
    parameter files are written to param_dir. transformix then applies that
    same transform to every channel of moving_img, and the channels are
    stacked back into a (rows, cols, channels) array.
    """
    elastix_filter.SetFixedImage(sitk.GetImageFromArray(fixed_img[..., REGISTRATION_CHANNEL]))
    elastix_filter.SetFixedMask(sitk.GetImageFromArray(fixed_mask))
    elastix_filter.SetMovingImage(sitk.GetImageFromArray(moving_img[..., REGISTRATION_CHANNEL]))
    elastix_filter.SetOutputDirectory(str(param_dir))
    elastix_filter.Execute()

    # The transform is identical for all channels, so set it only once.
    transformix_filter.SetTransformParameterMap(elastix_filter.GetTransformParameterMap())

    channels = []
    for channel in range(moving_img.shape[2]):
        transformix_filter.SetMovingImage(sitk.GetImageFromArray(moving_img[..., channel]))
        channel_aligned = sitk.GetArrayFromImage(transformix_filter.Execute())
        # Resampling can leave values outside the uint16 range; clip before
        # casting so they saturate instead of wrapping around.
        channels.append(np.clip(channel_aligned, 0, 65535).astype(np.uint16))
    return np.dstack(channels)


for point in points:
    print("\n----------------")
    print(point)
    # NOTE: this is deliberately NOT an f-string. "{ii}" is kept literally in
    # the stitched file names that this script reads; do not format it.
    filename = "multichannel_"+point+"_Point00{ii}_ChannelSD 640,SD 488.tif"

    # Local list of cycles, without the ones listed as missing for this point
    if point in points_special:
        print("Removing the following cycles:")
        print(points_special[point])
    missing_cycles = points_special.get(point, [])
    cycles_local = [cycle for cycle in cycles if cycle not in missing_cycles]

    first_ref = _first_reference(cycles_local)
    print("First reference: ", first_ref)
    # Reference cycles to be used during the alignment. Can be changed, increased, decreased etc.
    # Each listed cycle, once aligned, becomes the fixed image for the following cycles.
    reference_cycles = [first_ref, "cycle6", "cycle11", "cycle16", "cycle21"]

    # The first reference is the fixed image and does not need to be aligned
    cycles_local.remove(first_ref)

    # Doesn't overwrite existing directories
    (path_results/point/first_ref).mkdir(parents=True, exist_ok=True)

    # Store an unaltered copy of the first reference in the results tree and
    # keep it loaded as the fixed image for the first cycles.
    fixed_img_name = point+"_"+first_ref+".tif"
    fixed_img = io.imread(str(path_plate/first_ref/"stitched"/filename))
    fixed_mask = _registration_mask(fixed_img)
    fixed_img_path = path_results/point/first_ref/fixed_img_name
    # Check the same location the copy is written to, so it is only saved once.
    if not fixed_img_path.is_file():
        io.imsave(str(fixed_img_path), fixed_img, check_contrast=False)

    for cycle in cycles_local:
        print("Aligning", cycle)
        out_path = path_results/point/cycle
        img_aligned_name = point+"_"+cycle+".tif"
        img_aligned_path = out_path/img_aligned_name

        # Skip based on the aligned image, not the folder: the folder is created
        # before elastix runs, so an interrupted run would otherwise leave an
        # empty folder that is skipped forever (and breaks the reference-image
        # load below if this cycle is a reference cycle).
        if img_aligned_path.is_file():
            print("Aligned image already exists. Skipping", cycle)
        else:
            print("Reference cycle:", fixed_img_name)
            (out_path/"param_maps").mkdir(parents=True, exist_ok=True)

            moving_img = io.imread(str(path_plate/cycle/"stitched"/filename))

            start = datetime.now()
            img_aligned = _align_image(fixed_img, fixed_mask, moving_img, out_path/"param_maps")
            io.imsave(str(img_aligned_path), img_aligned, check_contrast=False)
            print("Alignment took ", datetime.now()-start)

        # Switch to a new fixed image once a reference cycle has been aligned
        if cycle in reference_cycles:
            ref_cycle = cycle
            print("Setting reference cycle to", ref_cycle)

            fixed_img_name = point+"_"+ref_cycle+".tif"
            fixed_img = io.imread(str(path_results/point/ref_cycle/fixed_img_name))
            fixed_mask = _registration_mask(fixed_img)


----------------
Point0043
First reference:  cycle1
Aligning cycle0
Reference cycle: Point0043_cycle1.tif
