## Templating

In [None]:
import ants
from liom_toolkit.registration import create_template
from liom_toolkit.utils import load_allen_template, load_zarr, segment_3d_brain, load_zarr_image_from_node
from tqdm import tqdm

In [ ]:
def build_template_for_resolution(template_name: str, brains: list, brain_names: list,
                                  resolution_level: int = 3, template_resolution: int = 50, iterations: int = 15,
                                  init_with_template: bool = False,
                                  register_to_template: bool = False, flipped_brains: bool = False,
                                  wavelength: str = "647nm"):
    """
    Build a brain template at one resolution and write the masked result to disk.

    The brains are loaded from ``/data/LSFM/<brain>/<wavelength>.zarr``, optionally together with a mirrored
    copy of each brain, and handed to ``create_template`` alongside the Allen average template of the matching
    resolution. The resulting template is masked with ``segment_3d_brain`` and saved as
    ``templates/<template_name>_<template_resolution>_<iterations>.nii``.

    :param template_name: Base name used for the saved template file.
    :param brains: Identifiers of the brains to load (used as directory names under ``/data/LSFM``).
    :param brain_names: Names forwarded to ``create_template``. The order should match the order in which the
                        volumes are loaded: each brain, directly followed by its mirrored copy when
                        ``flipped_brains`` is set.
    :param resolution_level: Resolution level (zarr index) to load the brains at.
    :param template_resolution: Resolution of the Allen template in microns; converted to millimetres for
                                ``create_template``.
    :param iterations: Number of iterations forwarded to ``create_template``.
    :param init_with_template: Forwarded to ``create_template`` as ``init_with_template``.
    :param register_to_template: If True, the new template is registered (SyN) to the Allen template and
                                 resampled into its space before masking.
    :param flipped_brains: If True, a mirrored copy of every brain is added to the templating input.
    :param wavelength: Name of the zarr file (without extension) to load for every brain.
    :return: Tuple of the (unmasked) template and the Allen atlas volume.
    """
    # Function-scope import keeps this cell self-contained.
    import os

    resolution_mm = template_resolution / 1000
    atlas_file = f"/data/templates/allen/average_template_{template_resolution}.nrrd"
    atlas_volume = load_allen_template(atlas_file, template_resolution, padding=False)
    atlas_volume = ants.reorient_image2(atlas_volume, "RAS")

    # Every brain is loaded unflipped; the mirrored copy (if requested) follows directly after it.
    flip_states = (False, True) if flipped_brains else (False,)

    brain_volumes = []
    masks = []
    for brain in tqdm(brains, desc="Loading brains", leave=False, total=len(brains), unit="brain", position=1):
        zarr_file = f"/data/LSFM/{brain}/{wavelength}.zarr"
        nodes = load_zarr(zarr_file)
        # Assumes node 0 holds the image and node 2 holds the mask -- confirm against the zarr layout.
        image_node = nodes[0]
        mask_node = nodes[2]

        for flipped in flip_states:
            brain_volume, mask = load_volume(image_node, mask_node, resolution_level, flipped=flipped)
            brain_volumes.append(brain_volume)
            masks.append(mask)

    # Always forward init_with_template so the caller's choice is honoured regardless of the
    # default value used by create_template.
    template = create_template(brain_volumes, masks, brain_names, atlas_volume,
                               template_resolution=resolution_mm, iterations=iterations,
                               init_with_template=init_with_template, pre_registration_type="Rigid")

    if register_to_template:
        template_transform = ants.registration(fixed=atlas_volume, moving=template, type_of_transform="SyN")
        template = ants.apply_transforms(fixed=atlas_volume, moving=template,
                                         transformlist=template_transform["fwdtransforms"])

    # Mask template to remove noise
    template_mask = segment_3d_brain(template)
    new_template = template * template_mask
    # Apply properties after multiplication
    new_template.set_direction(template.direction)
    new_template.set_spacing(template.spacing)
    new_template.set_origin(template.origin)

    # Make sure the output directory exists so the result of a long run is not lost on write.
    os.makedirs("templates", exist_ok=True)
    ants.image_write(new_template, f"templates/{template_name}_{template_resolution}_{iterations}.nii")

    # NOTE(review): the unmasked template is returned while the masked one is written to disk -- confirm
    # this is intended.
    return template, atlas_volume


In [ ]:
def load_volume(image_node, mask_node, resolution_level, flipped=False):
    """
    Load a masked brain volume and its mask from zarr nodes, both reoriented to RAS.

    :param image_node: Zarr node to load the brain image from.
    :param mask_node: Zarr node to load the mask from.
    :param resolution_level: Resolution level passed to ``load_zarr_image_from_node`` for both nodes.
    :param flipped: If True, the first entry of the direction matrix is set to -1 on both images before
                    they are reoriented (presumably yielding a mirrored copy -- confirm).
    :return: Tuple of the masked brain volume and the mask.
    """
    raw_image = load_zarr_image_from_node(image_node, resolution_level=resolution_level)
    brain_mask = load_zarr_image_from_node(mask_node, resolution_level=resolution_level)
    masked_image = raw_image * brain_mask

    if flipped:
        # Both images receive the same modified direction matrix.
        mirrored_direction = masked_image.direction
        mirrored_direction[0][0] = -1
        for image in (masked_image, brain_mask):
            image.set_direction(mirrored_direction)

    masked_image, brain_mask = (ants.reorient_image2(image, "RAS") for image in (masked_image, brain_mask))

    # The multiplication resets the physical shape, so copy it over from the mask.
    masked_image.physical_shape = brain_mask.physical_shape
    return masked_image, brain_mask

In [ ]:
# Each resolution level (zarr index) is paired with the atlas resolution it is templated against.
_level_to_atlas_resolution = {2: 25, 3: 50}
resolution_levels = list(_level_to_atlas_resolution.keys())
atlas_resolutions = list(_level_to_atlas_resolution.values())

In [ ]:
# Set the brains to use for the template
brains = [
    "S18",
    "S19",
    "S21",
    "S23",
]

# Set the name of the brains (use for saving intermediate results).
# The order of the names should match the order of the brains and if mirrored brains are used these should be added here.
# The names are derived from `brains` (each brain directly followed by its mirrored copy) so the two lists
# cannot drift apart and no entry can be merged with its neighbour by a missing comma.
brain_names = [
    brain_name
    for brain in brains
    for brain_name in (brain, f"{brain}_mirrored")
]

# The code below can be used when no mirrored brains are used
# brain_names = brains

In [ ]:
# Parameters for the templating run

# Base name of the saved template and the wavelength (zarr file name) to load for every brain
name = "p11_lightsheet"
wavelength = "647nm"

# Number of templating iterations
iterations = 15

# Include the mirrored brains in the templating process
flipped_brains = False

# Start the templating process from the template instead of from the first brain in the list
init_with_template = False

# Register the new template to the allen atlas once the templating process has finished
register_to_template = False

In [ ]:
# Build one template per (zarr resolution level, atlas resolution) pair.
# Settings that are identical for every resolution are collected once.
shared_settings = dict(
    brains=brains,
    brain_names=brain_names,
    iterations=iterations,
    init_with_template=init_with_template,
    register_to_template=register_to_template,
    flipped_brains=flipped_brains,
    wavelength=wavelength,
)

pbar = tqdm(zip(resolution_levels, atlas_resolutions), desc="Building templates", leave=True,
            total=len(resolution_levels), unit="template", position=0)
for resolution_level, atlas_resolution in pbar:
    # Show the resolution currently being processed on the progress bar.
    pbar.set_description(f"Building template at {atlas_resolution} microns")
    build_template_for_resolution(name,
                                  resolution_level=resolution_level,
                                  template_resolution=atlas_resolution,
                                  **shared_settings)