In [None]:
# System
import sys
import os

# Misc
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

# ML
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR

# Augmentations
import albumentations as A

# Custom
sys.path.append(os.path.abspath(".."))
from hover_net.dataloader.dataset import get_dataloader
from hover_net.datasets.puma_dataset import PumaDataset

In [None]:
# Packages
import numpy as np
from scipy.ndimage import measurements

def get_bounding_box(img):
    """
    Return the tight bounding box of the non-zero pixels in ``img``.

    The result is ``[row_start, row_stop, col_start, col_stop]``; the stop
    values are exclusive (one past the last non-zero row / column), so
    ``img[row_start:row_stop, col_start:col_stop]`` contains every non-zero
    pixel. Raises IndexError when ``img`` has no non-zero pixel.
    """
    occupied_rows = np.nonzero(np.any(img, axis=1))[0]
    occupied_cols = np.nonzero(np.any(img, axis=0))[0]
    # half-open bounds: the stop is one past the last occupied index
    return [
        occupied_rows[0],
        occupied_rows[-1] + 1,
        occupied_cols[0],
        occupied_cols[-1] + 1,
    ]


def fix_mirror_padding(ann):
    """
    Deal with duplicated instances due to mirroring in interpolation
    during shape augmentation (scale, rotation etc.).

    Every instance id whose pixels form more than one connected component
    is split: the component that ``measurements.label`` numbers 1 keeps the
    original id, and every other component receives a fresh id above the
    current maximum id in the map.

    Args:
        - ann: integer instance ID map (0 = background). Modified IN PLACE.
    Returns:
        - ann: the same array object, with split instances relabelled.
    """
    current_max_id = np.amax(ann)
    # Skip background without assuming it is present: a patch fully covered
    # by nuclei has no 0 and list.remove(0) would raise ValueError.
    inst_list = [inst_id for inst_id in np.unique(ann) if inst_id != 0]
    for inst_id in inst_list:
        inst_map = np.array(ann == inst_id, np.uint8)
        remapped_ids = measurements.label(inst_map)[0]
        extra_components = remapped_ids > 1
        if not extra_components.any():
            continue  # single connected component: nothing to relabel
        # NOTE(review): new ids are written back with ann's dtype; a narrow
        # dtype (e.g. uint8) could overflow when many instances are split.
        remapped_ids[extra_components] += current_max_id
        ann[extra_components] = remapped_ids[extra_components]
        current_max_id = np.amax(ann)
    return ann


def gen_instance_hv_map(ann):
    """
    Generate the HoVer maps for each nuclear instance.

    Inside each instance, every pixel holds its horizontal (x) and vertical
    (y) offset from the instance's centre of mass. Negative and positive
    offsets are scaled independently so each side spans [-1, 0) and (0, 1];
    pixels outside every instance stay 0.

    Args:
        - ann: instance ID map (0 = background). It is NOT modified; a copy
          is relabelled by ``fix_mirror_padding``.
    Returns:
        - hv_map: float32 array of shape ann.shape[:2] + (2,); channel 0 is
          the horizontal map, channel 1 the vertical map.
    """
    # fix_mirror_padding relabels in place. Work on a copy so the caller's
    # array (e.g. the view ``mask[..., 1]``) is left untouched.
    fixed_ann = fix_mirror_padding(ann.copy())
    height, width = fixed_ann.shape[:2]

    x_map = np.zeros((height, width), dtype=np.float32)
    y_map = np.zeros((height, width), dtype=np.float32)

    # Iterate over the relabelled ids so split-off duplicate components get
    # their own maps; skip background without assuming it is present.
    inst_ids = [inst_id for inst_id in np.unique(fixed_ann) if inst_id != 0]
    for inst_id in inst_ids:
        inst_map = np.array(fixed_ann == inst_id, np.uint8)
        rmin, rmax, cmin, cmax = get_bounding_box(inst_map)

        # Expand the box by 2px, clamped to the image. Nothing pads ``ann``
        # beforehand, so an unclamped negative start would wrap around
        # (negative slice index) and give an empty or wrong crop for nuclei
        # within 2px of the top/left border.
        r0, r1 = max(rmin - 2, 0), min(rmax + 2, height)
        c0, c1 = max(cmin - 2, 0), min(cmax + 2, width)

        inst_map = inst_map[r0:r1, c0:c1]
        if inst_map.shape[0] < 2 or inst_map.shape[1] < 2:
            continue

        # instance centre of mass, rounded to nearest pixel
        com_y, com_x = measurements.center_of_mass(inst_map)
        com_y = int(com_y + 0.5)
        com_x = int(com_x + 0.5)

        # NOTE(review): 1-based pixel coordinates minus a 0-based centre of
        # mass put the zero crossing one pixel before the centre; kept
        # unchanged so existing targets are reproduced - confirm intent.
        inst_x_range = np.arange(1, inst_map.shape[1] + 1) - com_x
        inst_y_range = np.arange(1, inst_map.shape[0] + 1) - com_y
        inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range)

        # remove coords outside of the instance
        outside = inst_map == 0
        inst_x[outside] = 0
        inst_y[outside] = 0
        inst_x = inst_x.astype("float32")
        inst_y = inst_y.astype("float32")

        # scale the negative side into [-1, 0) and the positive side into (0, 1]
        for coords in (inst_x, inst_y):
            neg = coords < 0
            if neg.any():
                coords[neg] /= -coords[neg].min()
            pos = coords > 0
            if pos.any():
                coords[pos] /= coords[pos].max()

        # basic slices are views, so these assignments write into x_map / y_map
        inside = inst_map > 0
        x_map[r0:r1, c0:c1][inside] = inst_x[inside]
        y_map[r0:r1, c0:c1][inside] = inst_y[inside]

    hv_map = np.dstack([x_map, y_map])
    return hv_map


def gen_targets(ann):
    """
    Build the network targets from an instance ID map.

    Args:
        - ann: instance ID map (0 = background)
    Returns:
        - target_dict: dictionary with
            "hv_map": result of ``gen_instance_hv_map(ann)`` (H, W, 2)
            "np_map": array like ``ann`` (same dtype) with every positive
                      id replaced by 1
    """
    hv_map = gen_instance_hv_map(ann)

    # binary nuclei-pixel mask; non-positive values are carried over as-is
    np_map = np.where(ann > 0, np.ones_like(ann), ann)

    return {"hv_map": hv_map, "np_map": np_map}

In [None]:
IMAGE_PATH      = '../data/01_training_dataset_tif_ROIs'
GEOJSON_PATH    = '../data/01_training_dataset_geojson_nuclei'
PATCH_SIZE      = 512
BATCH_SIZE      = 1  # NOTE(review): not used in this cell

dataset = PumaDataset(
    image_path=IMAGE_PATH,
    geojson_path=GEOJSON_PATH,
    input_shape=(PATCH_SIZE, PATCH_SIZE),
    mask_shape=(PATCH_SIZE, PATCH_SIZE),
    run_mode="test",
    augment=True,
)

# Load one raw (image, annotation) pair; keep the raw arrays under their own names.
img_raw, ann_raw = PumaDataset.load_data(dataset, 0)
print("Pre-augmentation")
print(f"Image shape: {img_raw.shape}")
print(f"Annotation shape: {ann_raw.shape}")

# Deterministic centre crop to the patch size.
aug = A.Compose([A.CenterCrop(height=PATCH_SIZE, width=PATCH_SIZE)])
augmented = aug(image=img_raw, mask=ann_raw)
img = augmented["image"]
ann = augmented["mask"]
print("Post-augmentation")
print(f"Image shape: {img.shape}")
print(f"Annotation shape: {ann.shape}")

# Channel 1 of the mask is used as the instance ID map
# (NOTE(review): assumed from this call site - confirm against PumaDataset.load_data).
# Pass a copy: ann[..., 1] is a view into `ann`, and the target pipeline may
# relabel instance ids in place, which would silently change `ann`.
inst_ann = ann[..., 1].copy()
targets = gen_targets(inst_ann)
print(f"Target keys: {targets.keys()}")
print(f"Horizontal and vertical map shape: {targets['hv_map'].shape}")
print(f"Nuclei pixel map shape: {targets['np_map'].shape}")

fig, ax = plt.subplots(1, 4, figsize=(15, 5))
ax[0].imshow(img)
ax[0].set_title("Image")
# HV maps are normalised to [-1, 1]: a diverging colormap with a fixed range
# keeps 0 (background) visually neutral and makes panels comparable.
ax[1].imshow(targets["hv_map"][..., 0], cmap="coolwarm", vmin=-1, vmax=1)
ax[1].set_title("Horizontal map")
ax[2].imshow(targets["hv_map"][..., 1], cmap="coolwarm", vmin=-1, vmax=1)
ax[2].set_title("Vertical map")
ax[3].imshow(targets["np_map"], cmap="gray")
ax[3].set_title("Nuclei pixel map")
for panel in ax:
    panel.axis("off")
fig.tight_layout()
plt.show()