# Preparation of the qualitative evaluation of the stain transfer's plausibility by human pathologists

This notebook prepares the test for the direction from real terminal to synthetic serial H&E. Modify it accordingly for the other direction.

In [None]:
import csv
from itertools import product
from logging import info
from random import random, sample, shuffle

import numpy as np
from PIL import Image
from skimage.util import img_as_ubyte

import constants as c
import crossvalidation as cv
from experiment import Experiment
from tiling import TiledImage
from utils import fuse_sparse_tiles


def get_candidates(
    terminal_tiling,
    serial_tiling_dir,
    label_tiling_dir,
    threshold_tissue_ratio,
    threshold_tissue_overlap,
):
    """Select the 2x2-tile frames that pass the tissue and overlap thresholds.

    Frames are anchored at every second tile row and column of
    ``terminal_tiling``. For each frame the per-tile lookup tables are averaged
    with ``get_ratio``, and the frame is kept only if the terminal tissue ratio
    and the serial tissue ratio both exceed ``threshold_tissue_ratio`` and the
    serial/terminal tissue overlap exceeds ``threshold_tissue_overlap``.

    Args:
        terminal_tiling: Tiling of the terminal image; provides ``num_tiles``,
            ``path`` and the flat tile indexing used by ``get_ratio``.
        serial_tiling_dir: Directory holding the serial tissue and overlap
            lookup tables.
        label_tiling_dir: Directory holding ``label_fg_ratios.lut``
            (presumably the epithelium label foreground ratios -- confirm).
        threshold_tissue_ratio: Strict lower bound for both tissue ratios.
        threshold_tissue_overlap: Strict lower bound for the overlap ratio.

    Returns:
        A list of ``(y, x, terminal_tissue_ratio, terminal_epithelium_ratio)``
        tuples, one per accepted frame.
    """
    rows, cols = terminal_tiling.num_tiles

    # Lookup tables are loaded in a fixed order: terminal tissue, serial
    # tissue, tissue overlap, label foreground.
    terminal_tissue_lut = np.load(terminal_tiling.path / "tissue_fg_ratios.lut")
    serial_tissue_lut = np.load(serial_tiling_dir / "tissue_fg_ratios.lut")
    overlap_lut = np.load(serial_tiling_dir / "serial_terminal_tissue_fg_overlaps.lut")
    label_lut = np.load(label_tiling_dir / "label_fg_ratios.lut")

    def frame_mean(lut, y, x):
        # All tables are indexed through the terminal tiling.
        return get_ratio(lut, terminal_tiling, y, x)

    # Stopping at ``count - 1`` guarantees every frame has a full 2x2 block.
    frame_origins = list(product(range(0, rows - 1, 2), range(0, cols - 1, 2)))
    num_total = len(frame_origins)

    selected = []
    for y, x in frame_origins:
        tissue_ratio = frame_mean(terminal_tissue_lut, y, x)
        # ``not (a > b)`` rather than ``a <= b`` so NaN ratios are rejected too.
        if not (tissue_ratio > threshold_tissue_ratio):
            continue
        if not (frame_mean(serial_tissue_lut, y, x) > threshold_tissue_ratio):
            continue
        if not (frame_mean(overlap_lut, y, x) > threshold_tissue_overlap):
            continue
        epithelium_ratio = frame_mean(label_lut, y, x)
        selected.append((y, x, tissue_ratio, epithelium_ratio))

    info(
        f"Filtered {len(selected)} candidate frames from {num_total} frames in total."
    )
    return selected


def get_ratio(ratios, tiling, y, x):
    """Return the mean of ``ratios`` over the 2x2 tile block anchored at (y, x).

    Args:
        ratios: Per-tile values, indexable by the flat tile index.
        tiling: Object whose ``get_flat_index(y, x)`` maps tile coordinates to
            an index into ``ratios``.
        y: Tile row of the block's top-left tile.
        x: Tile column of the block's top-left tile.

    Returns:
        The arithmetic mean of the four tiles' values.
    """
    # Row-major offsets of the four tiles relative to the anchor.
    offsets = ((0, 0), (0, 1), (1, 0), (1, 1))
    total = sum(
        ratios[tiling.get_flat_index(y + dy, x + dx)] for dy, dx in offsets
    )
    return total / len(offsets)


def sample_candidates(candidates, n):
    """Draw up to ``n`` distinct candidates uniformly at random.

    If fewer than ``n`` candidates are available, all of them are returned in
    random order and a warning is logged, instead of ``random.sample`` raising
    ``ValueError`` and aborting the run.

    Args:
        candidates: Sequence of candidate frames to draw from.
        n: Desired number of samples.

    Returns:
        A new list with ``min(n, len(candidates))`` distinct candidates.

    Raises:
        ValueError: If ``n`` is negative (raised by ``random.sample``).
    """
    # TODO: Weight by epithelium and FG/BG ratios?
    available = len(candidates)
    if n > available:
        # Imported locally so this function does not depend on the file's import block.
        from logging import warning

        warning(
            f"Requested {n} samples but only {available} candidates are available; "
            "using all of them."
        )
        n = available
    return sample(candidates, n)


def extract_frames(terminal_tiling, staintrans_tiling, serial_tiling, candidates):
    """Extract the terminal, stain-transferred and serial frame of each candidate.

    Args:
        terminal_tiling: Tiling of the terminal image; its ``path.name`` labels
            every returned record.
        staintrans_tiling: Tiling of the stain-transferred image.
        serial_tiling: Tiling of the serial image.
        candidates: Iterable of tuples whose first two items are the frame's
            ``y`` and ``x`` tile coordinates; further items are ignored.

    Returns:
        A list of ``(image_name, y, x, terminal_frame, staintrans_frame,
        serial_frame)`` tuples, in the order of ``candidates``.
    """
    records = []
    for y, x, *_ in candidates:
        image_name = terminal_tiling.path.name
        terminal_frame = extract_frame(terminal_tiling, y, x)
        staintrans_frame = extract_frame(staintrans_tiling, y, x)
        serial_frame = extract_frame(serial_tiling, y, x)
        records.append(
            (image_name, y, x, terminal_frame, staintrans_frame, serial_frame)
        )
    return records


def extract_frame(tiling, y, x):
    """Fuse the 2x2 block of tiles anchored at tile (y, x) into one frame.

    Args:
        tiling: Tiled image indexable as ``tiling[y, x]``.
        y: Tile row of the block's top-left tile.
        x: Tile column of the block's top-left tile.

    Returns:
        The result of ``fuse_sparse_tiles`` for the four tiles, requested with
        twice the height and width of a single tile (taken from the last two
        dimensions of the top-left tile).
    """
    # Grid positions of the four tiles inside the fused frame, row-major.
    grid_positions = [(0, 0), (0, 1), (1, 0), (1, 1)]
    tiles = [tiling[y + dy, x + dx] for dy, dx in grid_positions]

    tile_height, tile_width = tiles[0].shape[-2], tiles[0].shape[-1]
    frame_shape = (2 * tile_height, 2 * tile_width)
    return fuse_sparse_tiles(tiles, grid_positions, frame_shape)


def save_frame(frame, save_path):
    """Write ``frame`` to ``save_path`` as an 8-bit RGB image.

    Args:
        frame: Image array whose first axis is moved to the last position
            before saving (assumes a channel-first RGB array -- confirm).
        save_path: Destination path passed to ``PIL.Image.Image.save``.
    """
    channels_last = np.moveaxis(frame, 0, 2)
    pixels = img_as_ubyte(channels_last)
    image = Image.fromarray(pixels, "RGB")
    image.save(save_path)


# Name of the stain transfer experiment whose results are evaluated.
staintrans_name = "feature_loss_lr_0.0004"
# Checkpoint epoch to take the stain-transferred frames from, keyed by
# cross-validation fold index.
# NOTE(review): fold 3 maps to None -- confirm that
# cv.get_staintrans_terminal_sample_dirs handles a missing epoch as intended.
checkpoint_epochs = {
    0: 64,
    1: 50,
    2: 41,
    3: None,
    4: 30,
}
# Number of frames sampled per image.
n = 11
# Candidate frames must exceed this FG/BG tissue ratio in both images ...
threshold_tissue_ratio = 0.8
# ... and this pairwise tissue foreground overlap.
threshold_tissue_overlap = 0.8

# Only the first of the directories returned by the cross-validation helpers
# is used for the terminal and serial samples.
terminal_samples_dir = c.scratch_dir / cv.get_terminal_sample_dirs()[0]
terminal_targets_dir = c.scratch_dir / cv.get_terminal_targets_dir()
serial_samples_dir = c.scratch_dir / cv.get_serial_sample_dirs()[0]

with Experiment(
    name=f"pathology_test_staintrans_{staintrans_name}", seed=c.seed
) as exp:
    # Step 1: log the configuration of this run.
    intro_message = (
        f'Preparing pathology test for results of stain transfer experiment "{staintrans_name}". '
        f"Will sample {n} frames per image. Candidate frames will be filtered by individual FG/BG ratio "
        f"(threshold={threshold_tissue_ratio}) and pairwise FG overlap (threshold={threshold_tissue_overlap}) "
        "beforehand."
    )
    info(intro_message)

    # Step 2: collect sampled frames from the test images of every fold.
    all_frames = []
    for fold, (*_, terminal_test_images, serial_test_images) in cv.get_enumerated_folds():
        checkpoint_epoch = checkpoint_epochs[fold]
        info(f"Frames of cross-validation fold {fold} are taken from checkpoint epoch {checkpoint_epoch}.")
        staintrans_sample_dirs = cv.get_staintrans_terminal_sample_dirs(
            staintrans_name, checkpoint_epoch, fold
        )
        staintrans_samples_dir = c.scratch_dir / staintrans_sample_dirs[0]

        image_pairs = zip(terminal_test_images, serial_test_images)
        for terminal_image_name, serial_image_name in image_pairs:
            info(
                f"Extracting frames from image pair {terminal_image_name} <-> {serial_image_name}."
            )
            terminal_tiling = TiledImage(terminal_samples_dir / terminal_image_name)
            serial_tiling = TiledImage(serial_samples_dir / serial_image_name)
            label_tiling_dir = terminal_targets_dir / terminal_image_name
            # The stain-transferred tiling is stored under the terminal image's name.
            staintrans_tiling = TiledImage(staintrans_samples_dir / terminal_image_name)

            candidates = get_candidates(
                terminal_tiling=terminal_tiling,
                serial_tiling_dir=serial_tiling.path,
                label_tiling_dir=label_tiling_dir,
                threshold_tissue_ratio=threshold_tissue_ratio,
                threshold_tissue_overlap=threshold_tissue_overlap,
            )
            samples = sample_candidates(candidates, n)
            frames = extract_frames(
                terminal_tiling, staintrans_tiling, serial_tiling, samples
            )
            all_frames.extend(frames)
    info(f"Extracted {len(all_frames)} frames in total.")

    # Step 3: create the output directories. These calls fail if the
    # directories already exist, so an earlier result is never overwritten.
    save_dir = exp.working_dir / "data"
    serial_save_dir = save_dir / "serial"
    terminal_save_dir = save_dir / "terminal"
    serial_save_dir.mkdir(parents=True)
    terminal_save_dir.mkdir()
    info(
        f"Shuffling and saving all frames to their respective directories in {save_dir}."
    )

    # Step 4: shuffle, save the images, and record which suffix is which.
    shuffle(all_frames)
    header = ["Frame number", "Image name", "Y", "X", "Real suffix", "Synthetic suffix"]
    with open(save_dir / "solution.csv", "x", newline="") as f:
        solution_writer = csv.writer(f)
        solution_writer.writerow(header)
        for frame_number, frame_record in enumerate(all_frames):
            name, y, x, terminal_frame, staintrans_frame, serial_frame = frame_record
            file_stem = f"{frame_number:03}"

            # Randomly decide which of the pair is shown as "a" and which as
            # "b"; the assignment is written to the solution file below.
            serial_suffix, staintrans_suffix = (
                ("a", "b") if random() < 0.5 else ("b", "a")
            )

            # The real serial and the synthetic frame go into the same directory.
            save_frame(serial_frame, serial_save_dir / f"{file_stem}{serial_suffix}.png")
            save_frame(
                staintrans_frame, serial_save_dir / f"{file_stem}{staintrans_suffix}.png"
            )
            save_frame(terminal_frame, terminal_save_dir / f"{file_stem}terminal.png")

            solution_writer.writerow(
                [frame_number, name, y, x, serial_suffix, staintrans_suffix]
            )
    info("Done preparing pathology test.")


Fuse each extracted pair into a single image with annotations.

In [38]:
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
from utils import list_files


# Fuse every "<stem>a.png" / "<stem>b.png" pair in ``source_dir`` into one
# side-by-side image labelled "A" and "B", written to ``source_dir / "paired"``.
source_dir = Path.home() / "Downloads" / "terminal"
target_dir = source_dir / "paired"
target_dir.mkdir(exist_ok=True)

# Horizontal gap between the two images and height of the label strip below
# them, both in pixels.
gap = 50
label_margin = 100
# The font does not depend on the image, so load it once instead of once per pair.
font = ImageFont.truetype("Arimo-Regular.ttf", 96)

image_paths_a = list_files(source_dir, file_pattern="*a.png")
for image_path_a in image_paths_a:
    # Strip the trailing "a" to get the shared stem, then derive the "b" path.
    image_name = image_path_a.stem[:-1]
    image_path_b = image_path_a.with_stem(image_name + "b")
    with Image.open(image_path_a) as image_a, Image.open(image_path_b) as image_b:
        fused = Image.new(
            "RGBA",
            (image_a.width + image_b.width + gap, image_a.height + label_margin),
        )
        fused.paste(image_a, (0, 0))
        fused.paste(image_b, (image_a.width + gap, 0))
        draw = ImageDraw.Draw(fused)

        # Draw each label directly below its image, horizontally centred on it.
        labelled_images = (
            ("A", image_a, 0),
            ("B", image_b, image_a.width + gap),
        )
        for label, image, x_offset in labelled_images:
            *_, w, _ = draw.textbbox((0, 0), label, font=font)
            draw.text(
                (x_offset + (image.width - w) / 2, image.height),
                label,
                fill="black",
                font=font,
            )
        fused.save(target_dir / (image_name + ".png"))
