In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to local packages are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
"""
Read in the phase contrast videos
"""

import numpy as np
from scale_cell_transport import read

# Load the phase-contrast videos through the project's reader. Per the
# annotation each entry maps a name to (array, list of strings); later cells
# treat these as {name: (video frames, timestamps)}.
video_data: dict[str, tuple[np.ndarray, list[str]]] = read.phase_videos()

In [None]:
"""
Binarise the videos: keep only the brightest pixels of each frame (0/255 mask)
"""

from tqdm import tqdm

# Fraction of the brightest pixels kept in each frame. Despite the name this is
# a quantile fraction (0-1), not a percentile. In practice it segments out the
# scale, pretty much.
thresh_percentile = 0.0385

thresholded_videos = {}
for name, entry in tqdm(video_data.items()):
    frames, timestamps = entry[0], entry[1]

    # Videos containing any NaN pixel are left out of the result entirely.
    if np.isnan(frames).any():
        continue

    # Each frame is compared against its own (1 - fraction) quantile.
    masks = np.stack(
        [frame > np.quantile(frame, 1 - thresh_percentile) for frame in frames]
    )

    # Encode the boolean mask as an 8-bit image: 0 = background, 255 = kept.
    binary_frames = np.where(masks, 255, 0).astype(np.uint8)
    thresholded_videos[name] = (binary_frames, timestamps)

In [None]:
"""
Pick one "before" and one "after" frame from each thresholded video
"""
# Names in dict order; the image lists built below follow the same order.
video_names = list(thresholded_videos)

# Frame indices used as the "before" and "after" snapshot of every video.
before_frame_idx, after_frame_idx = 5, 15

# Our video data is {name: (video, timestamps)}; only the video part is used.
before_images = []
after_images = []
for frames, _timestamps in thresholded_videos.values():
    before_images.append(frames[before_frame_idx])
    after_images.append(frames[after_frame_idx])

In [None]:
"""
Resize them - they're currently too big for the model
"""

from skimage.transform import resize

# Spatial size (rows, cols) every frame is brought to before registration.
target_size = (512, 512)


def resize_images(
    image_list: list[np.ndarray], target_size: tuple[int, int]
) -> list[np.ndarray]:
    """Resize every image to ``target_size`` and return it in its original dtype.

    Resizing uses anti-aliasing and ``preserve_range=True``, so intensities
    stay on the input scale rather than being rescaled.

    Parameters
    ----------
    image_list
        Images to resize.
    target_size
        Output shape passed straight to ``skimage.transform.resize``.

    Returns
    -------
    list[np.ndarray]
        One resized image per input, in the same order and with the same dtype.
    """
    resized_images = []
    for img in image_list:
        resized = resize(img, target_size, anti_aliasing=True, preserve_range=True)

        # The resized image is floating point. Casting straight to an integer
        # dtype would truncate towards zero and bias intensities downwards, so
        # round to the nearest integer first.
        if np.issubdtype(img.dtype, np.integer):
            resized = np.rint(resized)

        resized_images.append(resized.astype(img.dtype))
    return resized_images


# Rebind the names to the resized copies so later cells see the smaller images.
before_images = resize_images(before_images, target_size)
after_images = resize_images(after_images, target_size)

In [None]:
"""
Apply a Gaussian blur to the images

"""

from skimage.filters import gaussian

def _blur_all(images, sigma=1.0):
    """Return a new list with a Gaussian blur of width ``sigma`` applied to each image."""
    # NOTE(review): with skimage's default preserve_range=False, integer inputs
    # are presumably converted to floats in [0, 1] here - confirm that this is
    # the intensity scale the model expects.
    return [gaussian(image, sigma=sigma) for image in images]


# NOTE: this cell reads and rebinds the same names, so re-running it blurs the
# already-blurred images again.
before_images = _blur_all(before_images)
after_images = _blur_all(after_images)

In [None]:
import matplotlib.pyplot as plt

# Shared imshow settings: greyscale, no smoothing between pixels.
plot_kw = dict(cmap="gray", interpolation="nearest")

# One figure per video: "before" frame on the left, "after" frame on the right.
for before, after, video_name in zip(before_images, after_images, video_names):
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    fig.subplots_adjust(wspace=0.1)
    fig.suptitle(video_name)

    for ax, image in zip(axes, (before, after)):
        ax.imshow(image, **plot_kw)
        ax.axis("off")

In [None]:
"""
Get the test data in the right format

This should be a torch tensor of shape (n_imgs, 1, height, width) for both the target ("before") and the template ("after") images

"""

import torch

def _to_batch(images):
    """Stack images into one float32 tensor and insert a singleton axis at dim 1.

    For 2-D images this gives shape (n_imgs, 1, height, width).
    """
    per_image = [torch.tensor(image, dtype=torch.float32) for image in images]
    return torch.stack(per_image).unsqueeze(1)


# The "before" frames are the registration target and the "after" frames are
# the template that gets aligned onto them.
target_images = _to_batch(before_images)
template_images = _to_batch(after_images)

data_dict = {
    "Template_image": template_images,
    "Target_image": target_images,
}

In [None]:
"""
Load the rotir model and weights
"""

from scale_cell_transport import files
from rotir.model import ImageRegistration
import torch

# Load the checkpoint onto the CPU. Everything in this notebook runs on CPU
# tensors, and without map_location a checkpoint saved from a GPU cannot be
# loaded on a machine that has no CUDA device.
# NOTE(review): torch.load unpickles the file, so only point this at a trusted
# checkpoint.
model_file = torch.load(files.model_path(), map_location="cpu")

# The checkpoint stores the model configuration under "Parameter" -> "model"
# and the trained weights under "Model_state".
model = ImageRegistration(model_file["Parameter"]["model"])
model.load_state_dict(model_file["Model_state"])

In [None]:
"""
Run the model on the test data
"""

from rotir.utils import affine_transform, matrix_calculation_function

# Inference only: evaluation mode and no gradient tracking.
model.eval()
with torch.no_grad():
    output = model(data_dict)

# Per-image score threshold: the third-highest score in the map (with the last
# row and column dropped), capped at 0.4. Reshaped to (n_imgs, 1, 1).
score_thr = torch.Tensor(
    [
        torch.minimum(torch.sort(op.flatten()).values[-3], torch.tensor(0.4))
        for op in output["score_map"][:, :-1, :-1]
    ]
).view(-1, 1, 1)

# Fourth positional argument of matrix_calculation_function: the negation of
# the checkpoint's "Apply_scale" setting.
# NOTE(review): the meaning of this and the following positional flag is
# defined by rotir - confirm against its documentation.
scale_not_applied = not model_file["Parameter"]["model"]["Apply_scale"]

# NOTE: "affine_matirx" is misspelled, but the plotting cell further down
# refers to it by this name, so it is kept.
affine_matirx, matches, num = matrix_calculation_function(
    output,
    "Auto",
    score_thr,
    scale_not_applied,
    True,
    coordinate=True,
)

# Presumably converts grid-cell coordinates to pixel coordinates at the centre
# of 32-pixel cells - TODO confirm against rotir.
matches = matches * 32 + 16

In [None]:
from rotir.plotting import plot_matches

# Draw the matches of every template/target pair, connected by lines.
pair_count = len(data_dict["Template_image"])
for i in range(pair_count):
    # Only the first num[i] rows of matches[i] are plotted; presumably the
    # remaining rows are padding - verify against rotir.
    valid_matches = matches[i][: num[i]]
    plot_matches(
        data_dict["Template_image"][i],
        data_dict["Target_image"][i],
        valid_matches,
        lines=True,
    )

In [None]:
"""
Plot the aligned images
"""

def _unit_scale(image: np.ndarray) -> np.ndarray:
    """Scale ``image`` into [0, 1] by dividing by its maximum.

    An image whose maximum is not positive (for example an all-zero frame) is
    returned as zeros instead of dividing by zero and producing NaNs.
    """
    peak = image.max()
    if not peak > 0:
        return np.zeros(image.shape)
    # imshow clips RGB floats to [0, 1] anyway; clipping here avoids its warning.
    return np.clip(image / peak, 0.0, 1.0)


def _tint(image: np.ndarray, channels: tuple[int, ...]) -> np.ndarray:
    """Return an RGB image with ``image`` copied into the given channel indices."""
    rgb = np.zeros((image.shape[0], image.shape[1], 3))
    for channel in channels:
        rgb[:, :, channel] = image
    return rgb


# Warp every template image with the affine matrix estimated for it.
out_total_image = affine_transform(data_dict["Template_image"], affine_matirx)

for i, video_name in enumerate(video_names):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    target = _unit_scale(data_dict["Target_image"][i][0].cpu().numpy())
    transformed = _unit_scale(out_total_image[i][0].cpu().numpy())

    # Target in red, warped template in cyan (green + blue).
    target_red = _tint(target, (0,))
    transformed_cyan = _tint(transformed, (1, 2))

    # The two tints use disjoint channels, so their sum is the overlay:
    # well-aligned structures appear white.
    overlay = target_red + transformed_cyan

    panels = (
        (target_red, f"Target (Frame {before_frame_idx})"),
        (overlay, "Registration Overlay\nRed=Target, Cyan=Transformed"),
        (transformed_cyan, f"Transformed Template (Frame {after_frame_idx})"),
    )
    for ax, (rgb, title) in zip(axes, panels):
        ax.imshow(rgb)
        ax.set_title(title)
        ax.axis("off")

    fig.suptitle(f"Registration Results: {video_name}", fontsize=14)
    fig.tight_layout()
    plt.show()
    # Release the figure so looping over many videos does not accumulate them.
    plt.close(fig)