# Process Electron Microscopy (EM) Dataset

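Processes the given EM dataset using LoFTR, creating a single, stitched image.

The target directory must be named after its tile grid as `<rows>x<columns>` (e.g. `8x3`). Each image file name must encode its grid, tile and slice indices as `g<grid>`, `t<tile>` and `s<slice>`. Tiles are numbered from 0 in row-major order.

Every tile is paired with its neighbour below and its neighbour to the right. The strips where each pair is expected to overlap are matched with LoFTR, and the match visualisations are written to the output directory.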

In [None]:
# IPython magic: move up one directory. This is presumably the repository root,
# so that the `src.*` imports below and the relative `assets/`, `weights/` and
# `out/` paths used later resolve. NOTE(review): confirm the notebook location.
%cd ..

import re
import os
import torch
import cv2
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from glob import glob
from copy import deepcopy
from src.loftr import LoFTR, default_cfg
from src.utils.plotting import make_matching_figure
import random

In [None]:
overlap = 0.05  # Fraction of a tile's height/width expected to overlap its neighbour.
max_matches_shown = 50  # Maximum number of matches drawn per visualisation.
resolution_ratio = 0.5  # Scale factor applied to the cropped strips before matching.
out_dir = "out/TESCAN/8x3/"
target_dir = "assets/TESCAN/8x3/"

# Check if the target directory exists.
if not os.path.isdir(target_dir):
    raise ValueError(f"Cannot read directory: {target_dir}")

# Read the expected grid size from the directory name, e.g. "8x3".
dir_name = os.path.basename(os.path.normpath(target_dir))
match = re.match(r"(\d+)x(\d+)", dir_name, flags=re.IGNORECASE)
if match is None:
    raise ValueError(f"Invalid grid specification: {dir_name}")
grid_size = [int(x) for x in match.groups()]  # (rows, columns)
n_tiles = grid_size[0] * grid_size[1]

# Grid, tile and slice indices encoded in a file name ("g<n>", "t<n>", "s<n>",
# each separated from the previous index by any single character).
tile_pattern = re.compile(r".*g(\d+).t(\d+).s(\d+).*", flags=re.IGNORECASE)


def _with_tile_index(file_name, name_match, tile_index):
    """Return ``file_name`` with its tile number replaced by ``tile_index``.

    Only the digits captured by group 2 of ``name_match`` are rewritten, zero
    padded to the original width. The same digit run elsewhere in the name,
    or an upper-case ``T`` prefix, therefore cannot break the substitution.
    """
    start, end = name_match.span(2)
    return f"{file_name[:start]}{tile_index:0{end - start}d}{file_name[end:]}"


# Pair adjacent images together. Tiles are treated as numbered from 0 in
# row-major order: the tile below is `grid_size[1]` indices further on, and
# the right-hand neighbour is the next index unless the tile ends its row.
stitched_images = {}
for path in sorted(os.listdir(target_dir)):
    # os.listdir yields bare names, so resolve them against target_dir.
    if os.path.isdir(os.path.join(target_dir, path)):
        continue

    # Read grid, tile and slice info.
    match = tile_pattern.match(path)
    if match is None:
        continue
    grid_index, tile_index, slice_index = [int(x) for x in match.groups()]

    new_pairs = []
    # Pair with the below image if present.
    if tile_index + grid_size[1] < n_tiles:
        paired = _with_tile_index(path, match, tile_index + grid_size[1])
        new_pairs.append({"first": path, "second": paired, "type": "row"})

    # Pair with the right image if present.
    if (tile_index + 1) % grid_size[1] != 0:
        paired = _with_tile_index(path, match, tile_index + 1)
        new_pairs.append({"first": path, "second": paired, "type": "column"})

    # Check that all paired files exist.
    for pair in new_pairs:
        if not os.path.isfile(os.path.join(target_dir, pair["second"])):
            raise ValueError(f"Expected file in grid but none found: {pair['second']}")

    # Save new pairs under the corresponding grid index.
    stitched_images.setdefault(grid_index, []).extend(new_pairs)

stitched_images

In [None]:
# Prepare LoFTR for matching using indoor weights (better than outdoor for EM).
_default_cfg = deepcopy(default_cfg)
_default_cfg["coarse"]["temp_bug_fix"] = True  # Temporary bugfix for the indoor checkpoint.
matcher = LoFTR(config=_default_cfg)
checkpoint_path = "weights/indoor_ds.ckpt"
matcher.load_state_dict(torch.load(checkpoint_path)["state_dict"])
matcher = matcher.eval().cuda()


def _overlap_extent(length, overlap):
    """Return the number of pixels making up ``overlap`` of a side of ``length``.

    Raises:
        ValueError: if the strip would be empty. An extent of 0 would make
            ``img[-0:]`` select the *whole* image while ``img[:0]`` selects
            nothing, which otherwise only fails later inside OpenCV.
    """
    extent = int(length * overlap)
    if extent < 1:
        raise ValueError(
            f"Overlap {overlap} of a {length}px side gives an empty crop")
    return extent


def _resize_multiple_of_8(img, ratio):
    """Scale a 2-D image by ``ratio``, rounding both sides down to a multiple of 8.

    Raises:
        ValueError: if either rounded side would be 0, since OpenCV cannot
            resize to an empty size.
    """
    height, width = img.shape
    # cv2.resize takes the target size as (width, height).
    new_size = (int(width * ratio) // 8 * 8, int(height * ratio) // 8 * 8)
    if min(new_size) == 0:
        raise ValueError(
            f"Resizing a {width}x{height} crop by {ratio} leaves no full 8px "
            "block; increase `overlap` or `resolution_ratio`")
    return cv2.resize(img, new_size)


# Match all paired images and save match visualisations.
os.makedirs(out_dir, exist_ok=True)
for grid_index, stitched_pairs in stitched_images.items():
    # Go over all stitched image pairs for the current grid segment.
    for pair in stitched_pairs:
        # Read both images as single-channel grayscale.
        path1 = os.path.join(target_dir, pair["first"])
        path2 = os.path.join(target_dir, pair["second"])
        img1 = cv2.imread(path1, cv2.IMREAD_GRAYSCALE)
        img2 = cv2.imread(path2, cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals an unreadable file by returning None.
        if img1 is None or img2 is None:
            raise ValueError(f"Failed to read image pair: ({pair['first']}, {pair['second']})")

        # Crop both images to the strips expected to overlap. For "row" pairs
        # the second image is the tile below the first; for "column" pairs it
        # is the tile to its right.
        if pair["type"] == "row":
            height_crop_range = _overlap_extent(img1.shape[0], overlap)
            img1 = img1[-height_crop_range:, :]  # Bottom strip of the upper tile.
            img2 = img2[:height_crop_range, :]  # Top strip of the lower tile.

            # Rotate the strips into the same tall, narrow (vertical) layout
            # as column pairs (required for precise matching).
            img1 = cv2.rotate(img1, cv2.ROTATE_90_COUNTERCLOCKWISE)
            img2 = cv2.rotate(img2, cv2.ROTATE_90_COUNTERCLOCKWISE)
        else:
            width_crop_range = _overlap_extent(img1.shape[1], overlap)
            img1 = img1[:, -width_crop_range:]  # Right strip of the left tile.
            img2 = img2[:, :width_crop_range]  # Left strip of the right tile.

        # Downscale; each side of the final resolution must be a multiple of 8.
        img1 = _resize_multiple_of_8(img1, resolution_ratio)
        img2 = _resize_multiple_of_8(img2, resolution_ratio)

        # Create float batch tensors of shape (1, 1, H, W); dividing by 255
        # maps 8-bit intensities into [0, 1].
        batch_img1 = torch.from_numpy(img1)
        batch_img1 = batch_img1.reshape(1, 1, *batch_img1.shape).cuda() / 255.0
        batch_img2 = torch.from_numpy(img2)
        batch_img2 = batch_img2.reshape(1, 1, *batch_img2.shape).cuda() / 255.0

        # Run inference with LoFTR; the matcher writes its predictions back
        # into `batch`, from which the matched keypoints are read.
        batch = {"image0": batch_img1, "image1": batch_img2}
        with torch.no_grad():
            matcher(batch)
            matches1 = batch["mkpts0_f"].cpu().numpy()
            matches2 = batch["mkpts1_f"].cpu().numpy()
            confidence = batch["mconf"].cpu().numpy()

        # Save the matching figure, drawing a random subset of at most
        # `max_matches_shown` matches coloured by confidence.
        match_samples = random.sample(range(len(matches1)),
                                      min(max_matches_shown, len(matches1)))
        color = cm.jet(confidence[match_samples])
        text = [f"Total Matches: {len(matches1)}"]
        fig = make_matching_figure(img1, img2, matches1[match_samples],
                                   matches2[match_samples], color, text=text)
        out_path = os.path.join(out_dir, f"g{grid_index:04d}",
                                f"{pair['first']}_{pair['second']}_matches.png")
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path)
        plt.close(fig)