In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
"""
Read in the videos
"""

from scale_cell_transport import files

video_dir = files.incucyte_video_dir_1()

paths = list(video_dir.glob("*.mp4"))

In [None]:
import cv2
import numpy as np
from tqdm import tqdm


def _read_video_rgb(path):
    """Decode the video at ``path`` into a single RGB uint8 array.

    Returns an array of shape (n_frames_read, height, width, 3). A video that
    OpenCV cannot open or decode gives an array with zero frames instead of
    raising, so it can be filtered out afterwards.
    """
    cap = cv2.VideoCapture(str(path))
    try:
        # CAP_PROP_FRAME_COUNT is the container's estimate, so it is only used
        # to preallocate; some backends report a negative count, hence max(..., 0).
        n_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        buffer = np.empty((n_frames, height, width, 3), dtype=np.uint8)

        n_read = 0
        while n_read < n_frames:
            ok, frame = cap.read()
            if not ok:
                # read() returns (False, None) on failure: writing None into the
                # uint8 buffer would raise, so stop at the first failed read.
                break
            buffer[n_read] = frame
            n_read += 1
    finally:
        # Release the capture even if decoding raised
        cap.release()

    # Drop slots that were preallocated but never filled (np.empty leaves them
    # uninitialised), then convert OpenCV's BGR channel order to RGB.
    return buffer[:n_read, :, :, ::-1]


arrays = [_read_video_rgb(path) for path in tqdm(paths)]

In [None]:
"""
Some of the videos are broken, so let's remove them from the list of videos and also the list of paths
"""

for i, (path, video) in enumerate(zip(paths, arrays)):
    if video.shape[:3] == (0, 0, 0):
        print(f"Removing {path}")
        paths.remove(path)
        arrays.pop(i)

In [None]:
"""Display two frames from each video"""

import matplotlib.pyplot as plt

before_frame_idx, after_frame_idx = 0, 10

# Convert to grayscale
before_images = [
    cv2.cvtColor(video[before_frame_idx], cv2.COLOR_RGB2GRAY) for video in arrays
]
after_images = [
    cv2.cvtColor(video[after_frame_idx], cv2.COLOR_RGB2GRAY) for video in arrays
]

In [None]:
"""Show the chosen before/after frames of every video side by side"""

plot_kw = {"cmap": "gray", "interpolation": "nearest"}

if not arrays:
    # plt.subplots(0, 2) would raise; there is simply nothing to show
    print("No readable videos to display")
else:
    n_rows = len(arrays)
    # squeeze=False keeps `axes` 2-D with shape (n_rows, 2) even for a single
    # video; with the default squeeze=True a one-row grid comes back 1-D and
    # the per-row `axs[0]` / `axs[1]` indexing below would fail.
    # The figure height scales with the number of rows (2.5 in per row).
    fig, axes = plt.subplots(n_rows, 2, figsize=(5, 2.5 * n_rows), squeeze=False)

    for axs, before, after in zip(axes, before_images, after_images):
        for ax in axs:
            ax.axis("off")
        axs[0].imshow(before, **plot_kw)
        axs[1].imshow(after, **plot_kw)

    axes[0, 0].set_title(f"Frame {before_frame_idx}")
    axes[0, 1].set_title(f"Frame {after_frame_idx}")
    fig.tight_layout()

In [None]:
"""
Get the test data in the right format

This should be a torch tensor of shape (n_imgs, 1, height, width) for both the target and the source images

"""

import torch

target_images = torch.stack([torch.tensor(x, dtype=torch.float32) for x in before_images]).unsqueeze(1)
template_images = torch.stack([torch.tensor(x, dtype=torch.float32) for x in after_images]).unsqueeze(1)

data_dict = {"Template_image": template_images, "Target_image": target_images}

In [None]:
"""
Instead read in the images from the test repo
"""
import os

from PIL import Image
import torchvision.transforms as T

image_list1 = [
    "2021-09-27_RenataPlate1_A5_1.tif_03.jpg",
    "2021-09-27_RenataPlate1_H8_1.tif_06.jpg",
    "2022-04-11_QiaoPlate1_B2_1.tif_07.jpg",
    "2022-04-11_QiaoPlate1_C3_1.tif_15.jpg",
    "2022-04-11_QiaoPlate1_E9_1.tif_09.jpg",
    "2021-09-27_RenataPlate1_C7_1.tif_02.jpg",
    "2022-04-11_QiaoPlate1_E5_1.tif_01.jpg",
]

image_list2 = [
    "2021-09-27_RenataPlate1_C7_1.tif_09.jpg",
    "2021-09-27_RenataPlate1_H8_1.tif_08.jpg",
    "2022-04-11_QiaoPlate1_B2_1.tif_64.jpg",
    "2022-04-11_QiaoPlate1_C3_1.tif_19.jpg",
    "2022-04-11_QiaoPlate1_E9_1.tif_72.jpg",
    "2021-09-27_RenataPlate1_A5_1.tif_08.jpg",
    "2022-04-11_QiaoPlate1_E5_1.tif_78.jpg",
]

spath = "../../rotir_test_images"
test_dict = {
    "Template_image": [],
    "Target_image": [],
}

for i, (item1, item2) in enumerate(zip(image_list1, image_list2)):

    im1 = Image.open(os.path.join(spath, item1)).resize((512, 512))
    im2 = Image.open(os.path.join(spath, item2)).resize((512, 512))
    image1 = T.ToTensor()(im1).mul(2).add(-1)
    image2 = T.ToTensor()(im2).mul(2).add(-1)

    test_dict["Template_image"].append(image1)
    test_dict["Target_image"].append(image2)

for k, v in test_dict.items():
    test_dict[k] = torch.stack(v, dim=0)

In [None]:
"""
Load the rotir model and weights
"""
import torch
from rotir.model import ImageRegistration

model_file = torch.load(files.model_path())

model = ImageRegistration(model_file["Parameter"]["model"])

In [None]:
"""
Run the model on the test data
"""

from rotir.utils import matrix_calculation_function, affine_transform

model.eval()
with torch.no_grad():
    output = model(data_dict)

score_thr = []
for op in output["score_map"][:, :-1, :-1]:
    t = torch.minimum(op.flatten().sort()[0][-3], torch.tensor(0.2))
    score_thr.append(t)
    print(t)

score_thr = torch.Tensor(score_thr).view(-1, 1, 1)

affine_matirx, matches, num = matrix_calculation_function(
    output,
    "Auto",
    score_thr,
    not model_file["Parameter"]["model"]["Apply_scale"],
    True,
    coordinate=True,
)
matches = matches.mul(32).add(16)

out_total_image = affine_transform(template_images, affine_matirx)

In [None]:
from rotir.plotting import plot_matches

# One plot per image pair; matches[k] is truncated to its first num[k] entries,
# with num as returned by matrix_calculation_function above.
for k, (template, target) in enumerate(zip(template_images, target_images)):
    n_kept = num[k]
    plot_matches(template, target, matches[k][:n_kept], lines=True)