In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
"""
Read in the phase contrast videos
"""

import numpy as np
from scale_cell_transport import read

# Mapping of video name -> (video array, list of timestamp strings), as the
# annotation spells out; the loading itself is delegated to read.phase_videos().
video_data: dict[str, tuple[np.ndarray, list[str]]] = read.phase_videos()

In [None]:
"""
Normalise contrast of the videos
"""

from tqdm import tqdm
from skimage import exposure

# Despite the name this is a quantile fraction (0-1), as np.quantile expects.
norm_percentile = 0.80

# Rebuild the {name: (video, timestamps)} mapping, rescaling each video's
# intensities with in_range = (its minimum, its norm_percentile quantile).
# Timestamps are carried over untouched.
video_data_norm = {}
for name, (frames, stamps) in tqdm(video_data.items()):
    lower = np.min(frames)
    upper = np.quantile(frames, norm_percentile)
    stretched = exposure.rescale_intensity(frames, in_range=(lower, upper))
    video_data_norm[name] = (stretched, stamps)

In [None]:
"""
Pick a before/after pair of consecutive frames from each normalised video
"""

import matplotlib.pyplot as plt

# Names in dict order, so they line up index-for-index with the frame lists below.
video_names = [*video_data_norm]

# The two frame indices that are compared for every video.
before_frame_idx = 5
after_frame_idx = 6

# Each dict value is a (video, timestamps) pair; only the video is indexed here.
before_images = [
    frames[before_frame_idx] for frames, _timestamps in video_data_norm.values()
]
after_images = [
    frames[after_frame_idx] for frames, _timestamps in video_data_norm.values()
]

In [None]:
"""
Resize the selected frames - at full resolution they are too big for the model
"""

from skimage.transform import resize

# Output shape handed to skimage.transform.resize for every frame.
# Presumably (height, width) - TODO confirm this is the input size the model expects.
target_size = (1024, 1024)

def resize_images(
    image_list: list[np.ndarray], target_size: tuple[int, int]
) -> list[np.ndarray]:
    """
    Resize every image to ``target_size``, keeping its value range and dtype.

    Each image is resized with anti-aliasing and ``preserve_range=True`` and
    then cast back to the dtype of the corresponding input image.

    :param image_list: images to resize; the list itself is not modified.
    :param target_size: output shape passed to ``skimage.transform.resize``.
    :returns: a new list holding one resized image per input image.
    """
    resized_images = []
    for img in image_list:
        resized = resize(img, target_size, anti_aliasing=True, preserve_range=True)
        if np.issubdtype(img.dtype, np.integer):
            # resize hands back floats; a bare astype truncates toward zero,
            # which darkens integer images by up to one grey level. Round to
            # the nearest integer before casting instead.
            resized = np.rint(resized)
        resized_images.append(resized.astype(img.dtype))
    return resized_images


# Run both frame sets through the same helper so they end up at target_size.
before_images, after_images = (
    resize_images(frames, target_size) for frames in (before_images, after_images)
)

In [None]:
# One small two-panel figure per video: before frame on the left, after on the right.
plot_kw = dict(cmap="gray", interpolation="nearest")
for video_name, before, after in zip(video_names, before_images, after_images):
    fig, axes = plt.subplots(1, 2, figsize=(4, 2))
    fig.subplots_adjust(wspace=0.1)

    for ax, frame in zip(axes, (before, after)):
        ax.imshow(frame, **plot_kw)

    fig.suptitle(video_name)
    # Hide ticks and frames on both panels.
    for ax in axes:
        ax.axis("off")

In [None]:
"""
Get the test data in the right format

Both the target and the template images should be torch tensors of shape (n_imgs, 1, height, width)

"""

import torch

def _as_model_batch(images):
    """Stack the images into one float32 tensor and insert a new axis at position 1."""
    per_image = [torch.tensor(img, dtype=torch.float32) for img in images]
    return torch.stack(per_image).unsqueeze(1)


# The "before" frames are the targets; the "after" frames are the templates.
target_images = _as_model_batch(before_images)
template_images = _as_model_batch(after_images)

data_dict = {"Template_image": template_images, "Target_image": target_images}

In [None]:
"""
Load the rotir model and weights
"""

from scale_cell_transport import files
from rotir.model import ImageRegistration
import torch

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location="cpu" lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine; the input tensors in this notebook are created on the CPU too.
model_file = torch.load(files.model_path(), map_location="cpu")

# The checkpoint holds the model configuration under "Parameter" -> "model"
# and the trained weights under "Model_state".
model = ImageRegistration(model_file["Parameter"]["model"])
model.load_state_dict(model_file["Model_state"])

In [None]:
"""
Run the model on the test data
"""

from rotir.utils import affine_transform, matrix_calculation_function

# Inference only: evaluation mode, no gradient tracking.
model.eval()
with torch.no_grad():
    output = model(data_dict)

# One threshold per image: the third-largest value of its score map (with the
# last index along each of the two trailing axes dropped), capped at 0.4.
score_cap = torch.tensor(0.4)
score_thr = torch.Tensor(
    [
        torch.minimum(score_map.flatten().sort().values[-3], score_cap)
        for score_map in output["score_map"][:, :-1, :-1]
    ]
).view(-1, 1, 1)

# NOTE(review): the meaning of the positional flags is defined by
# rotir.utils.matrix_calculation_function - confirm against its signature.
# The misspelt name `affine_matirx` is kept in case other cells refer to it.
affine_matirx, matches, num = matrix_calculation_function(
    output,
    "Auto",
    score_thr,
    not model_file["Parameter"]["model"]["Apply_scale"],
    True,
    coordinate=True,
)
# Presumably maps grid indices to pixel coordinates (32 px cells, centre offset 16)
# - TODO confirm against the model's patch size.
matches = matches * 32 + 16

out_total_image = affine_transform(data_dict["Template_image"], affine_matirx)

In [112]:
from rotir.plotting import plot_matches

# Draw the matches for each template/target pair, keeping only the first
# num[i] rows of that pair's match array.
templates = data_dict["Template_image"]
targets = data_dict["Target_image"]
for i, template in enumerate(templates):
    plot_matches(template, targets[i], matches[i][: num[i]], lines=True)