# Image registration

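To create a mask for each phantom, we co-registered the images of each phantom with a template image, on which a circular mask is defined at the centre of the annotated phantom region, and then transferred that mask to every frame of each scan.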

Make sure to download the data in advance and update `root_path` in the second code block to point to your local copy.

In [None]:
import patato as pat
import numpy as np
import matplotlib.pyplot as plt
from registering import coregister
import SimpleITK as sitk

from pathlib import Path

from scipy.ndimage import median_filter

In [None]:
# Root folder of the downloaded dataset; every *.hdf5 file below it is processed.
# This is a machine-specific location: change it to your local copy.
root_path: Path = Path("/media/telse/Extreme SSD/Papers/IPASCMultiCentre/Fixed")

In [None]:
def rgb_visualise(image_a, image_b):
    """Overlay two equally-shaped 2-D images as a false-colour RGB array.

    ``image_a`` is placed in the red channel and ``image_b`` in both the green
    and blue channels. Regions bright only in ``image_a`` therefore appear red,
    regions bright only in ``image_b`` appear cyan, and overlap appears grey/white.

    Parameters
    ----------
    image_a, image_b : array_like
        2-D images of identical shape.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(*image_a.shape, 3)`` jointly rescaled to [0, 1].
        Integer inputs are promoted to float; floating inputs keep their dtype.
        If the stacked data has zero dynamic range, the result is all zeros
        rather than NaNs.
    """
    pa_image = np.stack([image_a, image_b, image_b], axis=2)
    # In-place true division is impossible on integer arrays, so promote first.
    if not np.issubdtype(pa_image.dtype, np.floating):
        pa_image = pa_image.astype(float)
    pa_image -= np.min(pa_image)
    peak = np.max(pa_image)
    # Avoid 0/0 -> NaN for a constant input.
    if peak > 0:
        pa_image /= peak
    return pa_image

In [None]:
def sort_key(s):
    """Return the integer index at the end of a name such as ``"Scan_12.hdf5"``.

    The text after the last underscore is taken, everything from its first
    ``"."`` onward is dropped, and the remainder is parsed with ``int``.
    """
    tail = s.rsplit("_", 1)[-1]
    number, _, _ = tail.partition(".")
    return int(number)

def apply_transformation(resampler, mask):
    """Warp a binary mask with a SimpleITK resampler and re-binarise it.

    Parameters
    ----------
    resampler
        Object whose ``Execute`` method maps a SimpleITK image to a resampled
        image (here, the resampler returned by ``coregister``).
    mask : array_like
        Mask array; singleton dimensions are squeezed away before warping.

    Returns
    -------
    numpy.ndarray
        Boolean array that is True where the warped mask value exceeds 0.5.
    """
    # Convert to double so the resampler can produce fractional values.
    mask_image = sitk.GetImageFromArray(np.squeeze(mask).astype(np.double))
    warped = sitk.GetArrayFromImage(resampler.Execute(mask_image))
    # Resampling may yield non-binary values; threshold to recover a boolean mask.
    return warped > 0.5

In [None]:
# Lazily yields every HDF5 scan anywhere below root_path (a generator: single use).
files = root_path.rglob("*.hdf5")
# Scan whose reconstruction and phantom ROI define the reference mask.
template_file = root_path.joinpath("Data", "Scan_2.hdf5")

In [None]:
# Maps each scan's file path to a list with one boolean mask per frame.
file_masks = {}

data_files = files

# Wavelength index (counted from the end) used for every registration.
wavelength_to_register_on = -4
# Reconstruction used for the template and for every scan.
recon_key = ("Reference Backprojection", "0")


def _preprocess(image):
    """Apply a size-5 median filter to ``image`` and rescale it to [0, 1].

    A constant image (zero dynamic range) is returned as all zeros instead of NaNs.
    """
    image = median_filter(image, 5)  # returns a new array, so in-place ops are safe
    image -= np.min(image)
    peak = np.max(image)
    if peak > 0:
        image /= peak
    return image


template_pa = pat.PAData.from_hdf5(template_file)  # type: ignore
template_pa.set_default_recon(recon_key)  # type: ignore
print(template_pa.get_scan_name())
# Fetch the template reconstructions once: used for the image and for the mask grid.
rec = template_pa.get_scan_reconstructions()
template_reconstruction = _preprocess(np.squeeze(rec.raw_data[0, wavelength_to_register_on]))  # type: ignore

phantom_roi = template_pa.get_rois()[('phantom_', '0')]
points = phantom_roi.points

xpoints, ypoints = points.T
# Centroid of the annotated phantom outline.
(x,), (y,) = phantom_roi.get_polygon().centroid.xy

# Radius of the circular template mask, in the same units as rec.extent.
r = 0.015

x0, x1, y0, y1 = rec.extent  # type: ignore

# Disc of radius r around the ROI centroid, sampled on a grid spanning rec.extent
# with rec.shape[2] columns (x) and rec.shape[4] rows (y).
mask = (np.linspace(x0, x1, rec.shape[2])[None, :] - x) ** 2 + (np.linspace(y0, y1, rec.shape[4])[:, None] - y) ** 2 < r ** 2  # type: ignore

for file in data_files:
    pa_moving = pat.PAData.from_hdf5(file)  # type: ignore
    # Skip "Clear" scans and scans that lack the required reconstruction.
    if "Clear" in pa_moving.get_scan_name() or recon_key not in pa_moving.get_scan_reconstructions():  # type: ignore
        print(file, pa_moving.get_scan_name())
        continue
    pa_moving.set_default_recon(recon_key)  # type: ignore
    rec1 = pa_moving.get_scan_reconstructions()

    # Step 1: coregister the middle frame with the template and warp the template mask.
    pa_image1 = _preprocess(np.squeeze(rec1.raw_data[rec1.shape[0] // 2, wavelength_to_register_on]))  # type: ignore
    _, _, resampler = coregister(pa_image1, template_reconstruction, metric="lstsq", verbose=0)
    new_mask = apply_transformation(resampler, mask)

    # Step 2: coregister every frame with the middle frame and warp the mask again,
    # giving one mask per frame.
    masks = []
    for i in range(rec1.raw_data.shape[0]):  # type: ignore
        pa_image2 = _preprocess(np.squeeze(rec1.raw_data[i, wavelength_to_register_on]))  # type: ignore
        _, _, resampler2 = coregister(pa_image2, pa_image1, metric="lstsq", verbose=0)
        masks.append(apply_transformation(resampler2, new_mask))

    file_masks[file] = masks

    # Left: unregistered overlay (red = this scan, cyan = template).
    # Right: middle frame with the outline of the warped mask.
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.imshow(rgb_visualise(pa_image1, template_reconstruction), extent=rec1.extent, origin="lower")  # type: ignore
    ax1.set_title("Not registered")
    ax1.axis("off")
    ax2.imshow(pa_image1, extent=rec1.extent, origin="lower", cmap="gray")  # type: ignore
    ax2.contour(new_mask, extent=rec1.extent, origin="lower", colors="red")  # type: ignore
    ax2.set_title("Registered")
    ax2.axis("off")
    fig.suptitle(pa_moving.get_scan_name() + ": " + file.parent.stem + "/" + file.stem)
    plt.tight_layout()
    plt.show()
    # Release the figure so memory does not grow with the number of scans.
    plt.close(fig)


In [None]:
# Display the radius used for the circular template mask (shown as the cell output).
r

In [None]:
# np.savez_compressed receives the array names as keyword arguments (**file_masks),
# so the Path keys must first be converted to strings.
file_masks = {str(path): masks for path, masks in file_masks.items()}

In [None]:
# Save each scan's per-frame masks as one array, named by the scan's file path.
# np.savez_compressed does not create missing folders, so create the output folder first.
output_path = Path("intermediate_results") / "translated_masks.npz"
output_path.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(output_path, **file_masks)