# Preparation of the independent data set used for external validation

## Import

### Import serial H&E images into workspace

In [None]:
import constants as c
from pathlib import Path
from datasets.independent import import_images


# Bring the serial H&E images that have a terminal counterpart (the keys of c.serial2terminal)
# from the external source directory into the scratch workspace.
source_dir = c.serial_he_source_dir
target_dir = c.scratch_dir / "serial_he"
image_names = list(c.serial2terminal)

import_images(source_dir, target_dir, image_names=image_names)


## Reading and preprocessing

### Spatially register serial sections with terminal counterparts
The registration process generates a lot of textual output, which tends to overwhelm Jupyter when run directly in the notebook, so it is run in an external script whose standard output is discarded.

In [None]:
import constants as c


# Register each serial section listed in c.serial2terminal against its terminal counterpart by
# running run_registration.py in a subshell (IPython `!` escape; IPython substitutes $serial_img_name).
# Only the Python script's stdout is redirected to /dev/null; the "Registering ..." echo and anything
# the script writes to stderr still appear in the notebook.
# NOTE(review): `source activate torch` assumes a conda env named "torch" and a conda version that
# still supports `source activate` -- confirm on the target machine.
# NOTE(review): the name is substituted unquoted, so names with spaces or shell metacharacters
# (e.g. "&") would break the command. IPython's `!` also does not raise on a non-zero exit status,
# so a failed registration does not stop the loop.
for serial_img_name in c.serial2terminal.keys():
    !source activate torch && echo Registering $serial_img_name && python run_registration.py $serial_img_name > /dev/null


### Read registered H&E channels at full scale

In [None]:
import constants as c
import numpy as np
import utils
from logging import info
from skimage.util import img_as_float32
from tifffile import imread



source_dir = c.scratch_dir / "serial_he_registered"
target_dir = c.scratch_dir / "serial_he_preprocessed" / "level_0"

# Convert every registered image matching "*/*.tif" under source_dir into a channels-first float32
# array and store it as pyramid level 0. The .npy file keeps the full original file name, so
# "<name>.tif" is saved as "<name>.tif.npy".
image_paths = utils.list_files(source_dir, file_pattern="*/*.tif")
# No exist_ok here: this raises FileExistsError if level_0 already exists (e.g. on a re-run).
target_dir.mkdir(parents=True)
for image_path in image_paths:
    save_path = target_dir / f"{image_path.name}.npy"
    info(f"Reading H&E channels from image {image_path} and saving them to {save_path}.")
    # The image is assumed to be channels-last; move channels to the front, then convert to float32.
    image = img_as_float32(np.moveaxis(imread(image_path), -1, 0))
    info(f"Shape and data type of the channels: {image.shape}, {image.dtype}.")
    np.save(save_path, image)
    

### Generate multi-scale versions of registered images

In [None]:
import constants as c
import numpy as np
import utils
from logging import info
from skimage.transform import pyramid_reduce


source_dir = c.scratch_dir / "serial_he_preprocessed" / "level_0"
target_dir = c.scratch_dir / "serial_he_preprocessed"
# Must be ascending: each level is derived from the previously generated one, not from level 0.
pyramid_levels = [2, 3]


# Preprocessed terminal-section images; their per-level shapes are used to crop the serial levels.
terminal_imgs_dir = c.scratch_dir / "dataset_208_preprocessed" / "he"

image_paths = utils.list_files(source_dir, file_extension=".npy")
for image_path in image_paths:
    last_level = 0
    last_image = np.load(image_path)
    for level in pyramid_levels:
        level_dir = target_dir / f"level_{level}"
        level_dir.mkdir(exist_ok=True)
        save_path = level_dir / image_path.name
        info(f"Downsampling image {image_path.name} to level {level} and saving it to {save_path}.")
        # Level k is 2**k times smaller than level 0, so going from last_level to level needs a
        # factor of 2**(level - last_level). (2 * (level - last_level) only coincides with this
        # for level gaps of 1 or 2.)
        downscale = 2 ** (level - last_level)
        image = pyramid_reduce(np.moveaxis(last_image, 0, -1), downscale=downscale, multichannel=True)
        image = np.moveaxis(image, -1, 0)
        # We may need to manually crop the individual levels of the pyramid (by at most 1 pixel) since pyramid_reduce
        # seems to use a different rounding mechanism for the shape dimensions than was used in the terminal images.
        # The lookup key is the file stem, i.e. "<name>.tif" for "<name>.tif.npy".
        terminal_img_path = terminal_imgs_dir / f"level_{level}" / (c.serial2terminal[image_path.stem] + ".npy")
        *_, terminal_size_y, terminal_size_x = np.load(terminal_img_path, mmap_mode="r").shape
        image = image[..., :terminal_size_y, :terminal_size_x]
        info(f"Shape of the level: {image.shape}.")
        np.save(save_path, image)
        last_level = level
        last_image = image
        

## Tiling #1
For segmentation using the model from Schmitz et al. (2021).

### Split the H&E images at the different scale levels into tiles

In [None]:
import constants as c
from preprocessing import tile_images


source_dir = c.scratch_dir / "serial_he_preprocessed"
target_dir = c.scratch_dir / "serial_he_tiled"
pyramid_levels = [0, 2, 3]
tile_shapes = [(512, 512), (512, 512), (512, 512)]
stride = (512, 512)

# `stride` is given in level-0 pixels. At level k, the step is stride // 2**k and the anchor is
# stride // 2**(k + 1), presumably so the tile grids coincide across levels. tile_shape is NOT
# scaled. The output directory is named after the level-0 stride and the overlap
# (tile_shape - stride) // 2.
for level, tile_shape in zip(pyramid_levels, tile_shapes):
    scale = 2**level
    overlap = tuple((extent - step) // 2 for extent, step in zip(tile_shape, stride))
    level_dir = target_dir / f"level_{level}"
    level_dir.mkdir(parents=True, exist_ok=True)
    tiling_dir = level_dir / f"shape_{stride[0]}_{stride[1]}_overlap_{overlap[0]}_{overlap[1]}"
    tile_images(
        source_dir / f"level_{level}",
        tiling_dir,
        tile_shape,
        tuple(step // (2 * scale) for step in stride),  # anchor (y, x)
        tuple(step // scale for step in stride),  # stride (y, x) at this level
    )


## Index structures for guidance of tile sampling at training time

### Create lookup tables for tissue foreground-to-background ratio

In [None]:
import constants as c
import numpy as np
from datasets.independent import extract_tissue_fg
from preprocessing import compute_tile_statistics


tilings_dir = c.scratch_dir / "serial_he_tiled" / "level_0" / "shape_512_512_overlap_0_0"


def compute_tissue_fg_ratio(tile: np.ndarray) -> float:
    """Return sum(mask) / mask.size for the mask that extract_tissue_fg computes from ``tile``.

    Presumably the tissue-foreground fraction, assuming extract_tissue_fg returns a boolean/0-1
    mask. The assert checks the pipeline invariant that tiles are floating point (the level-0
    arrays saved earlier in this notebook are float32).
    """
    assert np.issubdtype(tile.dtype, np.floating)
    foreground = extract_tissue_fg(tile)
    n_pixels = float(foreground.size)
    return foreground.sum() / n_pixels


# Compute the statistic named "tissue_fg_ratios" for every tile of the 512x512 level-0 tiling.
compute_tile_statistics(tilings_dir, "tissue_fg_ratios", compute_tissue_fg_ratio)


## Tiling #2
For stain transfer (training).

In [None]:
import constants as c
from preprocessing import tile_images


source_dir = c.scratch_dir / "serial_he_preprocessed"
target_dir = c.scratch_dir / "serial_he_tiled"
pyramid_levels = [0]
# zip() below stops at the shorter list, so only the first tile shape is actually used.
tile_shapes = [(256, 256), (256, 256), (256, 256)]
stride = (256, 256)

# `stride` is given in level-0 pixels. At level k, the step is stride // 2**k and the anchor is
# stride // 2**(k + 1). tile_shape is NOT scaled. The output directory is named after the level-0
# stride and the overlap (tile_shape - stride) // 2.
for level, tile_shape in zip(pyramid_levels, tile_shapes):
    scale = 2**level
    overlap = tuple((extent - step) // 2 for extent, step in zip(tile_shape, stride))
    level_dir = target_dir / f"level_{level}"
    level_dir.mkdir(parents=True, exist_ok=True)
    tiling_dir = level_dir / f"shape_{stride[0]}_{stride[1]}_overlap_{overlap[0]}_{overlap[1]}"
    tile_images(
        source_dir / f"level_{level}",
        tiling_dir,
        tile_shape,
        tuple(step // (2 * scale) for step in stride),  # anchor (y, x)
        tuple(step // scale for step in stride),  # stride (y, x) at this level
    )


In [None]:
import constants as c
import numpy as np
from datasets.independent import extract_tissue_fg
from preprocessing import compute_tile_statistics


tilings_dir = c.scratch_dir / "serial_he_tiled" / "level_0" / "shape_256_256_overlap_0_0"


def compute_tissue_fg_ratio(tile: np.ndarray) -> float:
    """Return sum(mask) / mask.size for the mask that extract_tissue_fg computes from ``tile``.

    Presumably the tissue-foreground fraction, assuming extract_tissue_fg returns a boolean/0-1
    mask. The assert checks the pipeline invariant that tiles are floating point (the level-0
    arrays saved earlier in this notebook are float32).
    """
    assert np.issubdtype(tile.dtype, np.floating)
    foreground = extract_tissue_fg(tile)
    n_pixels = float(foreground.size)
    return foreground.sum() / n_pixels


# Compute the statistic named "tissue_fg_ratios" for every tile of the 256x256 level-0 tiling.
compute_tile_statistics(tilings_dir, "tissue_fg_ratios", compute_tissue_fg_ratio)


## Tiling #3
For stain transfer (inference).

In [None]:
import constants as c
from preprocessing import tile_images


source_dir = c.scratch_dir / "serial_he_preprocessed"
target_dir = c.scratch_dir / "serial_he_tiled"
pyramid_levels = [0]
# zip() below stops at the shorter list, so only the first tile shape is actually used.
tile_shapes = [(2048, 2048), (2048, 2048), (2048, 2048)]
stride = (512, 512)

# 2048x2048 tiles on a 512-pixel grid, i.e. overlapping tiles; the directory name records the
# level-0 stride and the overlap (tile_shape - stride) // 2 = 768 per side. At level k, the step is
# stride // 2**k and the anchor is stride // 2**(k + 1); tile_shape is NOT scaled.
for level, tile_shape in zip(pyramid_levels, tile_shapes):
    scale = 2**level
    overlap = tuple((extent - step) // 2 for extent, step in zip(tile_shape, stride))
    level_dir = target_dir / f"level_{level}"
    level_dir.mkdir(parents=True, exist_ok=True)
    tiling_dir = level_dir / f"shape_{stride[0]}_{stride[1]}_overlap_{overlap[0]}_{overlap[1]}"
    tile_images(
        source_dir / f"level_{level}",
        tiling_dir,
        tile_shape,
        tuple(step // (2 * scale) for step in stride),  # anchor (y, x)
        tuple(step // scale for step in stride),  # stride (y, x) at this level
    )
