# Process Electron Microscopy (EM) Dataset

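Processes a dataset of electron microscopy (EM) image grids: the tiles of each grid are stitched into a single image according to the selected configuration, and the results are saved to the configured output directory.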

In [None]:
%load_ext autoreload
%autoreload 2

import cv2
import matplotlib.pyplot as plt
import os
import sys

# Add project directories.
sys.path.append(os.path.abspath(".."))
sys.path.append(os.path.abspath("../src/dataset/"))

# Import project definitions.
from src.config.config import get_cfg_defaults
from src.dataset.dataset_loader import DatasetLoader
from src.dataset.demis_loader import DemisLoader
from src.pipeline.demis_stitcher import DemisStitcher
from src.pipeline.image_cache import ImageCache

In [None]:
# Prepare configuration. Select exactly one configuration file; other
# available configurations are:
#   "../configs/tescan_2x2.yaml"
#   "../configs/tescan_8x3.yaml"
config = "../configs/demis.yaml"
cfg = get_cfg_defaults()
cfg.merge_from_file(config)
cfg.freeze()

# Load image paths. The DEMIS branch also loads `labels`, which the DEMIS
# stitching cell further below needs.
# NOTE(review): the loader is chosen by whether the dataset path contains an
# "x" anywhere (presumably to match grid-size names such as "2x2" / "8x3").
# Any other "x" in the path (e.g. in a directory name) also selects
# DatasetLoader — TODO confirm the dataset path naming convention.
if "x" in cfg.DATASET.PATH.lower():
    loader = DatasetLoader(cfg.DATASET.PATH)
    image_paths = loader.load_paths()
else:
    loader = DemisLoader(cfg.DATASET.PATH)
    labels = loader.load_labels()
    image_paths = loader.load_paths(labels)

# Display the number of loaded grids (image path entries).
len(image_paths)

In [None]:
# Setup directories.
os.makedirs(cfg.STITCHER.OUTPUT_PATH, exist_ok=True)

# Stitch image tiles in each grid. A minimum spanning tree (MST) is constructed using
# Prim-Jarník's algorithm to estimate the best stitching order for each grid.
cache = ImageCache(cfg)
stitcher = DemisStitcher(cfg, cache)
for path_key, tile_paths in image_paths.items():
    stitched_image, _ = stitcher.stitch_grid_mst(tile_paths)

    # Save the result. The key is split into "<grid>_<slice>" integer indices.
    grid_index, slice_index = path_key.split("_")
    out_path = os.path.join(
        cfg.STITCHER.OUTPUT_PATH,
        f"g{int(grid_index):05d}_s{int(slice_index):05d}.png",
    )
    # cv2.imwrite reports failure by returning False instead of raising.
    if not cv2.imwrite(out_path, stitched_image):
        raise OSError(f"Failed to write stitched image to {out_path}")

    # Plot the result.
    if cfg.STITCHER.PLOT_OUTPUT:
        fig = plt.figure(figsize=(50, 50))
        if cfg.STITCHER.COLORED_OUTPUT:
            # The same array is saved with cv2.imwrite (BGR channel order), while
            # matplotlib expects RGB; convert 3-channel images for display only.
            # NOTE(review): assumes the stitcher returns BGR images — TODO confirm.
            if stitched_image.ndim == 3 and stitched_image.shape[2] == 3:
                plt.imshow(cv2.cvtColor(stitched_image, cv2.COLOR_BGR2RGB))
            else:
                plt.imshow(stitched_image)
        else:
            plt.imshow(stitched_image, cmap="gray")
        # Render and release each large figure now instead of keeping one open
        # figure per grid alive until the loop finishes.
        plt.show()
        plt.close(fig)

In [None]:
# Stitch each DEMIS grid from its label data and save it with a "_demis"
# suffix. Only the DEMIS loading branch defines `labels`, so this cell is a
# no-op for other datasets.
if isinstance(loader, DemisLoader):
    # Each grid's output name must come from that grid's own key; a key left
    # over from an earlier loop would make every grid overwrite the same file.
    # NOTE(review): assumes load_paths(labels) returns one entry per grid in the
    # same order as `labels` — TODO confirm against DemisLoader.load_paths.
    if len(labels) != len(image_paths):
        raise ValueError(
            f"Expected one image path entry per grid label set, got "
            f"{len(image_paths)} entries for {len(labels)} label sets"
        )
    for path_key, grid_labels in zip(image_paths, labels):
        stitched_image, _ = stitcher.stitch_demis_grid_mst(grid_labels)

        # Save the result. The key is split into "<grid>_<slice>" integer indices.
        grid_index, slice_index = path_key.split("_")
        out_path = os.path.join(
            cfg.STITCHER.OUTPUT_PATH,
            f"g{int(grid_index):05d}_s{int(slice_index):05d}_demis.png",
        )
        # cv2.imwrite reports failure by returning False instead of raising.
        if not cv2.imwrite(out_path, stitched_image):
            raise OSError(f"Failed to write stitched image to {out_path}")

        # Plot the result, honoring the same switch as the MST stitching cell.
        if cfg.STITCHER.PLOT_OUTPUT:
            fig = plt.figure(figsize=(50, 50))
            plt.imshow(stitched_image, cmap="gray")
            # Render and release each large figure right away.
            plt.show()
            plt.close(fig)
