In [None]:
# Reload all imported modules before each cell executes, so edits to the
# project packages imported below (scale_cell_transport, rotir) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Call the project helper `read.phase_videos()`. Its return value is the last
# expression of the cell, so Jupyter only displays it as cell output.
# NOTE(review): neither `read` nor this result is used by any later cell
# shown here -- confirm whether this cell is still needed.
from scale_cell_transport import read
read.phase_videos()

In [None]:
"""
Read in the videos
"""

from scale_cell_transport import files

video_dir = files.incucyte_video_dir_1()

paths = list(video_dir.glob("*.mp4"))

In [None]:
import cv2
import numpy as np
from tqdm.auto import tqdm


def read_video_rgb(path):
    """Decode the frames of a video file into an RGB ``uint8`` array.

    Parameters
    ----------
    path : pathlib.Path
        Video file to read.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_frames_read, height, width, 3)`` in RGB channel
        order. At most ``CAP_PROP_FRAME_COUNT`` frames are read. A video that
        cannot be decoded yields an array with zero frames, so it can be
        filtered out downstream.
    """
    cap = cv2.VideoCapture(str(path))
    try:
        # The container's frame count is only an estimate and may be
        # non-positive for unreadable files; never allocate a negative length.
        n_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        buffer = np.empty((n_frames, height, width, 3), dtype=np.uint8)

        n_read = 0
        while n_read < n_frames:
            ok, frame = cap.read()
            if not ok:
                # Decoding stopped early: the frame is None here, and the
                # rest of the buffer is uninitialised, so stop and truncate.
                break
            # OpenCV decodes to BGR; flip the channel axis to store RGB.
            buffer[n_read] = frame[:, :, ::-1]
            n_read += 1
    finally:
        cap.release()

    return buffer[:n_read]


arrays = [read_video_rgb(path) for path in tqdm(paths)]

In [None]:
"""
Some of the videos are broken, so let's remove them from the list of videos and also the list of paths
"""

for i, (path, video) in enumerate(zip(paths, arrays)):
    if video.shape[:3] == (0, 0, 0):
        print(f"Removing {path}")
        paths.remove(path)
        arrays.pop(i)

In [None]:
"""Display two frames from each video"""

import matplotlib.pyplot as plt

before_frame_idx, after_frame_idx = 0, 0

before_images = [video[before_frame_idx] for video in arrays]
after_images = [video[after_frame_idx] for video in arrays]

plot_kw = {"cmap": "gray", "interpolation": "nearest"}
for before, after, path in zip(before_images, after_images, paths):
    fig, axes = plt.subplots(1, 2, figsize=(4, 2))
    fig.subplots_adjust(wspace=0.1)

    axes[0].imshow(before, **plot_kw)
    axes[1].imshow(after, **plot_kw)

    fig.suptitle(path.name)
    for i in range(2):
        axes[i].axis("off")

In [None]:
# Collapse each selected RGB frame to one channel. This is the mean of the red
# (index 0) and blue (index 2) channels with green left out, not a
# luminance-weighted grayscale: the result is float64 with values in
# [0, 255]. (A luminance alternative would be
# `cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)`, which returns uint8.)
# NOTE(review): the reason for excluding the green channel is not stated
# here -- confirm it is intentional.


def _red_blue_mean(frame):
    """Return the per-pixel mean of the red and blue channels of an RGB frame."""
    red_and_blue = frame[:, :, (0, 2)]
    return red_and_blue.mean(axis=-1)


before_images = [_red_blue_mean(video[before_frame_idx]) for video in arrays]
after_images = [_red_blue_mean(video[after_frame_idx]) for video in arrays]

In [None]:
plot_kw = {"cmap": "gray", "interpolation": "nearest"}
for before, after in zip(before_images, after_images):
    fig, axes = plt.subplots(1, 2, figsize=(4, 2))
    fig.subplots_adjust(wspace=0.1)
    axes[0].imshow(before, **plot_kw)
    axes[1].imshow(after, **plot_kw)
    for i in range(2):
        axes[i].axis("off")

In [None]:
"""
Get the test data in the right format

This should be a torch tensor of shape (n_imgs, 1, height, width) for both the target and the source images

"""

import torch

target_images = torch.stack([torch.tensor(x, dtype=torch.float32) for x in before_images]).unsqueeze(1)
template_images = torch.stack([torch.tensor(x, dtype=torch.float32) for x in after_images]).unsqueeze(1)

data_dict = {"Template_image": template_images, "Target_image": target_images}

In [None]:
"""
Load the rotir model and weights
"""

from rotir.model import ImageRegistration
import torch

model_file = torch.load(files.model_path())

model = ImageRegistration(model_file["Parameter"]["model"])
model.load_state_dict(model_file["Model_state"])

In [None]:
"""
Run the model on the test data
"""

from rotir.utils import affine_transform, matrix_calculation_function

model.eval()
with torch.no_grad():
    output = model(data_dict)

score_thr = []
for op in output["score_map"][:, :-1, :-1]:
    t = torch.minimum(op.flatten().sort()[0][-3], torch.tensor(0.4))
    score_thr.append(t)

score_thr = torch.Tensor(score_thr).view(-1, 1, 1)

affine_matirx, matches, num = matrix_calculation_function(
    output,
    "Auto",
    score_thr,
    not model_file["Parameter"]["model"]["Apply_scale"],
    True,
    coordinate=True,
)
matches = matches.mul(32).add(16)

out_total_image = affine_transform(data_dict["Template_image"], affine_matirx)

In [None]:
from rotir.plotting import plot_matches

# For every template/target pair, draw the first num[idx] matches returned by
# matrix_calculation_function, connected by lines.
templates = data_dict["Template_image"]
targets = data_dict["Target_image"]
for idx, template in enumerate(templates):
    pair_matches = matches[idx][: num[idx]]
    plot_matches(template, targets[idx], pair_matches, lines=True)