In [None]:
# System
import sys
import os
import json

# Misc
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch
from scipy import ndimage


# ML
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR

from shapely.geometry import shape

# Augmentations
import albumentations as A

# Custom
sys.path.append(os.path.abspath(".."))
from hover_net.dataloader.dataset import get_dataloader
from hover_net.datasets.puma_dataset import PumaDataset
from hover_net.models import HoVerNetExt

In [None]:
# --- Dataset configuration ---
IMAGE_PATH = '../data/01_training_dataset_tif_ROIs'
GEOJSON_PATH = '../data/01_training_dataset_geojson_nuclei'
PATCH_SIZE = 512
BATCH_SIZE = 1

patch_shape = (PATCH_SIZE, PATCH_SIZE)
dataset = PumaDataset(
    image_path=IMAGE_PATH,
    geojson_path=GEOJSON_PATH,
    input_shape=patch_shape,
    mask_shape=patch_shape,
    run_mode="test",
    augment=True,
)

# --- Load sample 7 and centre-crop image + annotation to one patch ---
img, ann = PumaDataset.load_data(dataset, 7)
print("Pre-augmentation")
print(f"Image shape: {img.shape}")
print(f"Annotation shape: {ann.shape}")
aug = A.Compose([A.CenterCrop(height=PATCH_SIZE, width=PATCH_SIZE)])
augmented = aug(image=img, mask=ann)
img, ann = augmented["image"], augmented["mask"]

# One colour per class id 0..3 (id 0 is drawn white and gets no legend entry)
colors = ['white', 'teal', 'lightblue', 'lightgreen']
class_labels = ['Tumor', 'TILs', 'Other']
cmap = ListedColormap(colors)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
image_ax, truth_ax = ax
image_ax.imshow(img, interpolation='none')

# Channel 1 of the annotation is rendered as the class map.
ann_image = ann[:, :, 1]
# Level boundaries sit halfway between the integer class ids.
contour_kwargs = dict(
    levels=[-0.5, 0.5, 1.5, 2.5, 3.5],
    cmap=cmap,
    extent=(0, ann.shape[1], 0, ann.shape[0]),
)
contour_fill = truth_ax.contourf(ann_image, alpha=0.8, **contour_kwargs)
contour_lines = truth_ax.contour(ann_image, linewidths=2, alpha=1, **contour_kwargs)
# Contours are drawn with y pointing up; flip to match the image orientation.
truth_ax.invert_yaxis()

image_ax.set_title("Image", fontsize=16)
truth_ax.set_title("Ground Truth", fontsize=16)

# Hide every tick and tick label on both panels.
for panel in (image_ax, truth_ax):
    panel.tick_params(
        axis='both',
        which='both',
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

truth_ax.set_aspect('equal', 'box')

# Legend entries pair the non-background colours with their class names.
legend_elements = [
    Patch(facecolor=face, label=name)
    for face, name in zip(colors[1:], class_labels)
]
fig.legend(
    handles=legend_elements,
    loc='lower center',
    ncol=3,
    fontsize=16,
    bbox_to_anchor=(0.5, -0.1),  # negative y places the legend below the panels
    frameon=True,
)

fig.tight_layout()

fig.savefig("../figures/dataset_example.pdf", bbox_inches='tight')

In [None]:
GEOJSON_PATH = '../data/01_training_dataset_geojson_nuclei'
geojsons = os.listdir(GEOJSON_PATH)

# Annotation labels that are pooled into the "tils" category.
TIL_LABELS = ("nuclei_lymphocyte", "nuclei_plasma_cell")

# nuclei_count[cohort][category] -> number of polygon annotations
nuclei_count = {
    cohort: {"tumor": 0, "tils": 0, "other": 0}
    for cohort in ("primary", "metastatic")
}

for file in geojsons:
    # The cohort is encoded in the file name; anything without "metastatic"
    # in its name is counted as primary.
    cohort = "metastatic" if "metastatic" in file else "primary"

    with open(os.path.join(GEOJSON_PATH, file), encoding="utf-8") as f:
        geojson = json.load(f)

    for feature in geojson["features"]:
        geometry = shape(feature["geometry"])
        # Only simple polygons are counted; other geometry types are skipped
        # before their properties are touched.
        if geometry.geom_type != "Polygon":
            continue

        label = feature["properties"]["classification"]["name"]
        if label == "nuclei_tumor":
            category = "tumor"
        elif label in TIL_LABELS:
            category = "tils"
        else:
            category = "other"

        nuclei_count[cohort][category] += 1

In [None]:
# Data preparation: category keys of `nuclei_count` and their display names
categories = ['tumor', 'tils', 'other']
titles = ['Tumor', 'TILs', 'Other']


def _cohort_counts(cohort):
    """Return the counts of `cohort` in the order given by `categories`."""
    return [nuclei_count[cohort][category] for category in categories]


primary_counts = _cohort_counts("primary")
metastatic_counts = _cohort_counts("metastatic")

# The pie charts use the same numbers as wedge sizes
primary_sizes = _cohort_counts("primary")
metastatic_sizes = _cohort_counts("metastatic")

# Same colour per category in both charts
colors = ['teal', 'lightblue', 'lightgreen']

fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# One pie per cohort: primary on the left, metastatic on the right
pie_panels = [
    (ax[0], primary_sizes, 'Primary Nuclei Distribution'),
    (ax[1], metastatic_sizes, 'Metastatic Nuclei Distribution'),
]
for panel, sizes, heading in pie_panels:
    panel.pie(sizes, autopct='%1.1f%%', startangle=90, colors=colors)
    panel.set_title(heading, fontsize=16, y=0.92)

# A single shared legend at the bottom centre of the figure
fig.legend(titles, loc="lower center", fontsize=16, ncol=3, bbox_to_anchor=(0.5, 0))

fig.tight_layout()

fig.savefig("../figures/dataset_distribution.pdf", bbox_inches='tight')

In [None]:
# Totals per cohort (summed over all categories of `nuclei_count`)
cohort_totals = {
    cohort: sum(per_category.values())
    for cohort, per_category in nuclei_count.items()
}
total_primary_nuclei = cohort_totals['primary']
total_metastatic_nuclei = cohort_totals['metastatic']
print(f"Total primary nuclei: {total_primary_nuclei}")
print(f"Total metastatic nuclei: {total_metastatic_nuclei}")
print(f"Total: {total_primary_nuclei + total_metastatic_nuclei}")

# Totals per category (summed over both cohorts)
total_tumor, total_tils, total_other = (
    nuclei_count["primary"][category] + nuclei_count["metastatic"][category]
    for category in ("tumor", "tils", "other")
)

print(f"Total tumor: {total_tumor}")
print(f"Total TILs: {total_tils}")
print(f"Total other: {total_other}")

In [None]:
# Packages
import numpy as np
from scipy.ndimage import measurements

def get_bounding_box(img):
    """
    Get the tight bounding box around the non-zero region of a mask.

    Args:
        - img: mask array (intended for 2-D input); any truthy element
          counts as foreground
    Returns:
        - [rmin, rmax, cmin, cmax] where the max values are exclusive, so
          img[rmin:rmax, cmin:cmax] contains the whole foreground
    Raises:
        - IndexError: if the mask contains no foreground element
    """
    occupied_rows = np.where(np.any(img, axis=1))[0]
    occupied_cols = np.where(np.any(img, axis=0))[0]
    # +1 turns the last occupied index into an exclusive slice bound
    return [
        occupied_rows[0],
        occupied_rows[-1] + 1,
        occupied_cols[0],
        occupied_cols[-1] + 1,
    ]


def fix_mirror_padding(ann):
    """
    Deal with duplicated instances due to mirroring in interpolation
    during shape augmentation (scale, rotation etc.).

    Every instance id that covers more than one connected component keeps
    its id on the first component; each further component receives a new
    id above the current maximum.

    Args:
        - ann: integer instance ID map, 0 is background. Modified IN PLACE;
          pass a copy if the original must be preserved.
    Returns:
        - ann: the same array object, relabelled
    """
    inst_ids = np.unique(ann)
    inst_ids = inst_ids[inst_ids != 0]  # 0 is background (and may be absent)
    if inst_ids.size == 0:
        return ann

    current_max_id = int(ann.max())
    for inst_id in inst_ids:
        components, n_components = ndimage.label(ann == inst_id)
        if n_components < 2:
            continue  # instance is already a single connected blob
        extra = components > 1
        # Component k (k >= 2) becomes id current_max_id + k
        ann[extra] = components[extra] + current_max_id
        # The largest id just written is current_max_id + n_components
        current_max_id += n_components
    return ann


def gen_instance_hv_map(ann):
    """
    Generate the HoVer maps for each nuclear instance.

    For every instance, each pixel stores its horizontal / vertical offset
    from the instance's centre of mass, scaled to [-1, 1] separately on the
    negative and positive side.

    Args:
        - ann: instance ID map (H, W), 0 is background. Not modified.
    Returns:
        - hv_map: float32 array (H, W, 2); channel 0 is the horizontal map,
          channel 1 the vertical map
    """
    # fix_mirror_padding relabels in place, so hand it a private copy to keep
    # the caller's annotation untouched.
    fixed_ann = fix_mirror_padding(ann.copy())
    height, width = fixed_ann.shape[:2]

    x_map = np.zeros((height, width), dtype=np.float32)
    y_map = np.zeros((height, width), dtype=np.float32)

    inst_ids = np.unique(fixed_ann)
    inst_ids = inst_ids[inst_ids != 0]  # 0 is background (and may be absent)
    for inst_id in inst_ids:
        inst_map = np.array(fixed_ann == inst_id, np.uint8)
        rmin, rmax, cmin, cmax = get_bounding_box(inst_map)

        # Expand the box by 2px, clamped to the array. A negative start
        # would be read by slicing as "count from the end" and yield an
        # empty or wrong crop for instances near the top/left border.
        rmin = max(rmin - 2, 0)
        cmin = max(cmin - 2, 0)
        rmax = min(rmax + 2, height)
        cmax = min(cmax + 2, width)

        inst_map = inst_map[rmin:rmax, cmin:cmax]

        if inst_map.shape[0] < 2 or inst_map.shape[1] < 2:
            continue

        # instance center of mass, rounded to nearest pixel
        com_row, com_col = ndimage.center_of_mass(inst_map)
        com_row = int(com_row + 0.5)
        com_col = int(com_col + 0.5)

        # 1-based pixel grid, shifted so the centre of mass sits at 0
        inst_x_range = np.arange(1, inst_map.shape[1] + 1) - com_col
        inst_y_range = np.arange(1, inst_map.shape[0] + 1) - com_row

        inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range)

        # remove coord outside of instance
        inst_x[inst_map == 0] = 0
        inst_y[inst_map == 0] = 0
        inst_x = inst_x.astype("float32")
        inst_y = inst_y.astype("float32")

        # normalize min into -1 scale
        if np.min(inst_x) < 0:
            inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0])
        if np.min(inst_y) < 0:
            inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0])
        # normalize max into +1 scale
        if np.max(inst_x) > 0:
            inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0])
        if np.max(inst_y) > 0:
            inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0])

        # Write the instance's values into the full-size maps. The slices are
        # views, so assigning through them updates x_map / y_map.
        inside = inst_map > 0
        x_map_box = x_map[rmin:rmax, cmin:cmax]
        x_map_box[inside] = inst_x[inside]

        y_map_box = y_map[rmin:rmax, cmin:cmax]
        y_map_box[inside] = inst_y[inside]

    hv_map = np.dstack([x_map, y_map])
    return hv_map


def gen_targets(ann):
    """
    Build the training targets derived from an instance map.

    Args:
        - ann: instance ID map, 0 is background
    Returns:
        - dict with two entries:
            "hv_map": output of gen_instance_hv_map(ann)
            "np_map": copy of ann in which every value > 0 is set to 1
    """
    targets = {"hv_map": gen_instance_hv_map(ann)}

    # Binary nuclei mask: collapse every positive instance id to 1
    nuclei_mask = ann.copy()
    nuclei_mask[nuclei_mask > 0] = 1
    targets["np_map"] = nuclei_mask

    return targets

In [None]:
IMAGE_PATH      = '../data/01_training_dataset_tif_ROIs'
GEOJSON_PATH    = '../data/01_training_dataset_geojson_nuclei'
PATCH_SIZE      = 512
BATCH_SIZE      = 1

dataset = PumaDataset(
    image_path=IMAGE_PATH,
    geojson_path=GEOJSON_PATH,
    input_shape=(PATCH_SIZE, PATCH_SIZE),
    mask_shape=(PATCH_SIZE, PATCH_SIZE),
    run_mode="test",
    augment=True
)

# Load one sample and centre-crop it to a single patch
img, ann = PumaDataset.load_data(dataset, 7)
print("Pre-augmentation")
print(f"Image shape: {img.shape}")
print(f"Annotation shape: {ann.shape}")

aug = A.Compose([A.CenterCrop(height=PATCH_SIZE, width=PATCH_SIZE)])
augmented = aug(image=img, mask=ann)
img = augmented["image"]
ann = augmented["mask"]
print("Post-augmentation")
print(f"Image shape: {img.shape}")
print(f"Annotation shape: {ann.shape}")

# The other figure cells in this notebook render ann[..., 1] as the class
# (type) map, so channel 0 is taken as the instance ID map here.
# NOTE(review): confirm the channel order against PumaDataset.load_data.
# A copy is passed because target generation may relabel its input in place.
inst_ann = ann[..., 0].copy()
type_ann = ann[..., 1]

targets = gen_targets(inst_ann)
print(f"Target keys: {targets.keys()}")
print(f"Horizontal and vertical map shape: {targets['hv_map'].shape}")
print(f"Nuclei pixel map shape: {targets['np_map'].shape}")

colors = ['white', 'teal', 'lightblue', 'lightgreen']  # colors for class ids 0, 1, 2, 3
cmap = ListedColormap(colors)

fig, ax = plt.subplots(1, 5, figsize=(15, 5))
ax[0].imshow(img)
ax[0].set_title("Image")
ax[1].imshow(targets["hv_map"][..., 0])
ax[1].set_title("Horizontal map")
ax[2].imshow(targets["hv_map"][..., 1])
ax[2].set_title("Vertical map")
ax[3].imshow(targets["np_map"])
ax[3].set_title("Nuclei pixel map")
# Fix the colour range so class id k always maps to colors[k], even when a
# class is absent from this patch.
ax[4].imshow(type_ann, cmap=cmap, vmin=0, vmax=len(colors) - 1)
ax[4].set_title("Type Map")

# Hide every tick and tick label on all panels
for a in ax:
    a.tick_params(
        axis='both',
        which='both',
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False
    )

fig.tight_layout()
fig.savefig("../figures/targets_example.pdf", bbox_inches='tight')

In [None]:
# --- Dataset configuration ---
IMAGE_PATH = '../data/01_training_dataset_tif_ROIs'
GEOJSON_PATH = '../data/01_training_dataset_geojson_nuclei'
PATCH_SIZE = 512
BATCH_SIZE = 1

patch_shape = (PATCH_SIZE, PATCH_SIZE)
dataset = PumaDataset(
    image_path=IMAGE_PATH,
    geojson_path=GEOJSON_PATH,
    input_shape=patch_shape,
    mask_shape=patch_shape,
    run_mode="test",
    augment=True,
)

# --- Load sample 7 and centre-crop image + annotation to one patch ---
img, ann = PumaDataset.load_data(dataset, 7)
print("Pre-augmentation")
print(f"Image shape: {img.shape}")
print(f"Annotation shape: {ann.shape}")
aug = A.Compose([A.CenterCrop(height=PATCH_SIZE, width=PATCH_SIZE)])
augmented = aug(image=img, mask=ann)
img, ann = augmented["image"], augmented["mask"]

# One colour per class id 0..3
colors = ['white', 'teal', 'lightblue', 'lightgreen']
cmap = ListedColormap(colors)
class_labels = ['Tumor', 'TILs', 'Other']

fig, ax = plt.subplots(1, 1, figsize=(5, 5))

# Channel 1 of the annotation is rendered as the class map.
ann_image = ann[:, :, 1]
# Level boundaries sit halfway between the integer class ids.
contour_kwargs = dict(
    levels=[-0.5, 0.5, 1.5, 2.5, 3.5],
    cmap=cmap,
    extent=(0, ann.shape[1], 0, ann.shape[0]),
)
contour_fill = ax.contourf(ann_image, alpha=0.8, **contour_kwargs)
contour_lines = ax.contour(ann_image, linewidths=2, alpha=1, **contour_kwargs)
# Contours are drawn with y pointing up; flip to match the image orientation.
ax.invert_yaxis()
ax.set_title("Type Map", fontsize=16)

# Hide every tick and tick label.
ax.tick_params(
    axis='both',
    which='both',
    bottom=False,
    left=False,
    labelbottom=False,
    labelleft=False,
)

ax.set_aspect('equal', 'box')
fig.tight_layout()

fig.savefig("../figures/data_tp.png", bbox_inches='tight', dpi=300)

In [None]:
# Stand-alone figure of the image patch (`img` is set by an earlier cell).
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(img)
ax.set_title("Image", fontsize=16)

# Hide every tick and tick label.
ticks_off = dict(bottom=False, left=False, labelbottom=False, labelleft=False)
ax.tick_params(axis='both', which='both', **ticks_off)

ax.set_aspect('equal', 'box')
fig.tight_layout()

fig.savefig("../figures/data_img.png", bbox_inches='tight', dpi=300)

In [None]:
# Stand-alone figure of the horizontal map, i.e. channel 0 of
# targets["hv_map"] (`targets` is set by an earlier cell).
horizontal_map = targets["hv_map"][..., 0]

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(horizontal_map, cmap='gist_rainbow')
ax.set_title("Horizontal Map", fontsize=16)

# Hide every tick and tick label.
ticks_off = dict(bottom=False, left=False, labelbottom=False, labelleft=False)
ax.tick_params(axis='both', which='both', **ticks_off)

ax.set_aspect('equal', 'box')
fig.tight_layout()

fig.savefig("../figures/data_h.png", bbox_inches='tight', dpi=300)

In [None]:
# Stand-alone figure of the vertical map, i.e. channel 1 of
# targets["hv_map"] (`targets` is set by an earlier cell).
vertical_map = targets["hv_map"][..., 1]

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(vertical_map, cmap='gist_rainbow')
ax.set_title("Vertical Map", fontsize=16)

# Hide every tick and tick label.
ticks_off = dict(bottom=False, left=False, labelbottom=False, labelleft=False)
ax.tick_params(axis='both', which='both', **ticks_off)

ax.set_aspect('equal', 'box')
fig.tight_layout()

fig.savefig("../figures/data_v.png", bbox_inches='tight', dpi=300)

In [None]:
# Stand-alone figure of the binary nuclei mask targets["np_map"]
# (`targets` is set by an earlier cell).
fig, ax = plt.subplots(1, 1, figsize=(5, 5))

# Two-colour map: first colour for value 0, second for value 1
colors = ['white', 'red']
cmap = ListedColormap(colors)

ax.imshow(targets["np_map"], cmap=cmap)
ax.set_title("Nuclear Pixel Map", fontsize=16)

# Hide every tick and tick label.
ticks_off = dict(bottom=False, left=False, labelbottom=False, labelleft=False)
ax.tick_params(axis='both', which='both', **ticks_off)

ax.set_aspect('equal', 'box')
fig.tight_layout()

fig.savefig("../figures/data_np.png", bbox_inches='tight', dpi=300)

In [None]:
import pandas as pd
from scipy.interpolate import interp1d
from scipy.ndimage import gaussian_filter1d

default_path = '../data/validation_dice_default.csv'
optimized_path = '../data/validation_dice_optim.csv'

# Each CSV is read for its 'step' and 'val' columns.
default = pd.read_csv(default_path)
optimized = pd.read_csv(optimized_path)


def _resample_and_smooth(curve):
    """Cubic-interpolate `curve` onto 300 even steps, then Gaussian-smooth it.

    Returns the interpolator, the resampled x grid and the smoothed y values.
    """
    spline = interp1d(curve['step'], curve['val'], kind="cubic")
    grid = np.linspace(curve['step'].min(), curve['step'].max(), 300)
    smoothed = gaussian_filter1d(spline(grid), sigma=10)
    return spline, grid, smoothed


default_spline, default_x, default_y = _resample_and_smooth(default)
optimized_spline, optimized_x, optimized_y = _resample_and_smooth(optimized)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
for xs, ys, name in (
    (default_x, default_y, 'Default'),
    (optimized_x, optimized_y, 'Optimized'),
):
    ax.plot(xs, ys, label=name)
ax.set_xlabel('Epoch', fontsize=16)
ax.set_ylabel('Dice Coefficient', fontsize=16)
ax.set_title('Validation Dice Coefficient', fontsize=16)
ax.legend(fontsize=16)

fig.tight_layout()
fig.savefig("../figures/validation_dice.png", bbox_inches='tight', dpi=300)

In [None]:
def __proc_np_hv(pred):
    """
    Process Nuclei Prediction with XY Coordinate Map.

    Args:
        - pred: np.array(H, W, C)
            C=0: nuclear pixel map,
            C=1: horizontal map,
            C=2: vertical map
    Returns:
        - proced_pred: np.array(H, W)
            Numbered map of all nuclear instances.

    Source: https://github.com/vqdang/hover_net
    """
    # NOTE(review): `cv2` and `watershed` are used below but are not imported
    # in the visible part of this notebook; they must be in scope (presumably
    # OpenCV and skimage.segmentation.watershed) before this runs. Confirm.

    def _minmax_01(arr):
        # Min-max rescale to [0, 1] as float32
        return cv2.normalize(
            arr,
            None,
            alpha=0,
            beta=1,
            norm_type=cv2.NORM_MINMAX,
            dtype=cv2.CV_32F
        )

    pred = np.array(pred, dtype=np.float32)

    blb_raw = pred[..., 0]      # Probability map
    h_dir_raw = pred[..., 1]    # x-map
    v_dir_raw = pred[..., 2]    # y-map

    # Binary nuclei mask (1 = nucleus, 0 = background). The original code
    # labelled this mask and immediately collapsed the labels back to 1,
    # which is a no-op and has been dropped.
    blb = np.array(blb_raw >= 0.5, dtype=np.int32)

    # Normalize direction maps
    h_dir = _minmax_01(h_dir_raw)
    v_dir = _minmax_01(v_dir_raw)

    # Sobel calculates the derivatives of the image
    # The derivatives will be high when there is a high change in intensity
    # i.e. when going from one nuclei to another
    # https://docs.opencv.org/4.x/d2/d2c/tutorial_sobel_derivatives.html
    sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=21)
    sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=21)

    # Normalize the sobel maps and invert them
    sobelh = 1 - _minmax_01(sobelh)
    sobelv = 1 - _minmax_01(sobelv)

    # Combined boundary response, zeroed outside the nuclei mask
    overall = np.maximum(sobelh, sobelv)
    overall = overall - (1 - blb)
    overall[overall < 0] = 0

    dist = (1.0 - overall) * blb
    # Nuclei values form mountains so inverse to get basins
    dist = -cv2.GaussianBlur(dist, (3, 3), 0)

    # Pixels with a strong boundary response
    overall = np.array(overall >= 0.4, dtype=np.int32)

    # Markers: nuclei pixels that are not boundary pixels, hole-filled,
    # opened with a 5x5 ellipse, then labelled per connected component
    marker = blb - overall
    marker[marker < 0] = 0
    marker = ndimage.binary_fill_holes(marker).astype("uint8")
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
    marker = ndimage.label(marker)[0]

    proced_pred = watershed(dist, markers=marker, mask=blb)

    return proced_pred

def process(pred_map):
    """
    Post processing of the output of the HoVer-Net model.

    Args:
        - pred_map: np.array(H, W, C)
            Combined output from all three branches of the HoVer-Net model.
            C=0: type map,
            C=1: nuclear pixel map,
            C=2: horizontal map,
            C=3: vertical map

    Returns:
        - pred_inst:     pixel-wise nuclear instance prediction
        - inst_info_dict: dictionary keyed by instance id, each value holding
            bbox: bounding box of the instance, [[rmin, cmin], [rmax, cmax]]
            centroid: centroid of the instance as (x, y)
            contour: contour of the instance as (x, y) points
            type_prob: fraction of the instance's pixels that carry `type`
            type: type of the instance

    Based on: https://github.com/vqdang/hover_net
    """
    # NOTE(review): `cv2` is used below but is not imported in the visible
    # part of this notebook; it must be in scope before this runs. Confirm.

    # Extract type and instance maps
    # pred_type: np.array(H, W, 1)
    # pred_inst: np.array(H, W, 3) => np, horizontal, vertical
    pred_type = pred_map[..., :1].astype(np.int32)
    pred_inst = np.squeeze(pred_map[..., 1:])
    pred_inst = __proc_np_hv(pred_inst)

    # Unique instance ids without background. Filter on the value 0 rather
    # than dropping the first entry, which would discard a real instance
    # whenever the prediction contains no background pixel.
    inst_id_list = np.unique(pred_inst)
    inst_id_list = inst_id_list[inst_id_list != 0]
    inst_info_dict = {}

    # Loop over each instance id
    for inst_id in inst_id_list:
        # Create map with only the current instance, cropped to its bbox
        inst_map = pred_inst == inst_id

        rmin, rmax, cmin, cmax = get_bounding_box(inst_map)
        inst_bbox = np.array([[rmin, cmin], [rmax, cmax]])
        inst_map = inst_map[rmin:rmax, cmin:cmax].astype(np.uint8)

        # Get the contour of the instance. Index -2 selects the contour list
        # both for OpenCV 4 (contours, hierarchy) and OpenCV 3
        # (image, contours, hierarchy).
        contours = cv2.findContours(
            inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        if len(contours) == 0:
            continue
        inst_contour = np.squeeze(contours[0].astype("int32"))

        # Skip a contour if it has less than 3 points
        # Likely an artifact
        if inst_contour.shape[0] < 3:
            continue
        if len(inst_contour.shape) != 2:
            continue

        # Get the moment of the nuclei instance
        # Moment is the "center of mass" of the instance
        # https://docs.opencv.org/2.4/modules/imgproc/doc/structural_analysis_and_shape_descriptors.html
        inst_moment = cv2.moments(inst_map)

        # Create centroid of the instance from the moment
        inst_centroid = np.array([
            (inst_moment["m10"] / inst_moment["m00"]),
            (inst_moment["m01"] / inst_moment["m00"]),
        ])

        # Shift crop-local coordinates back into full-image coordinates
        inst_contour[:, 0] += inst_bbox[0][1]   # X
        inst_contour[:, 1] += inst_bbox[0][0]   # Y
        inst_centroid[0] += inst_bbox[0][1]     # X
        inst_centroid[1] += inst_bbox[0][0]     # Y
        inst_info_dict[inst_id] = {
            "bbox": inst_bbox,
            "centroid": inst_centroid,
            "contour": inst_contour,
            "type_prob": None,
            "type": None,
        }

    # Assign a class to every kept instance by majority vote over its pixels
    for inst_id, inst_info in inst_info_dict.items():
        rmin, cmin, rmax, cmax = inst_info["bbox"].flatten()
        inst_map_crop = pred_inst[rmin:rmax, cmin:cmax] == inst_id
        inst_type_crop = pred_type[rmin:rmax, cmin:cmax]

        inst_type = inst_type_crop[inst_map_crop]
        type_list, type_pixels = np.unique(inst_type, return_counts=True)
        type_list = list(zip(type_list, type_pixels))
        type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
        inst_type = type_list[0][0]
        if inst_type == 0:  # ! pick the 2nd most dominant if exist
            if len(type_list) > 1:
                inst_type = type_list[1][0]
        type_dict = {v[0]: v[1] for v in type_list}
        # Small epsilon avoids a division by zero
        type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6)
        inst_info["type"] = int(inst_type)
        inst_info["type_prob"] = float(type_prob)

    return pred_inst, inst_info_dict

import torch.nn.functional as F
from collections import OrderedDict
def infer_step(batch_imgs, model, device="cuda"):
    """
    Run the HoVer-net model on a batch of images without gradients.

    Args:
        - batch_imgs: torch.Tensor(B, H, W, C)
        - model: torch.nn.Module returning a mapping with "tp", "np" and
          further branch outputs, each shaped (B, C, H, W); it is switched
          to eval mode here
        - device: device the images are moved to before the forward pass

    Returns:
        - pred_output: np.array(B, H, W, C), the branch outputs concatenated
          along the last axis in the order the model returns them. For a
          model returning tp, np, hv this gives
            C=0: type map (argmax class index, as float),
            C=1: nuclear pixel map (foreground probability),
            C=2: horizontal map,
            C=3: vertical map

    Based on: https://github.com/Kaminyou/HoVer-Net-PyTorch
    """
    # Channels-last float input -> channels-first tensor on the target device
    inputs = (
        batch_imgs.to(device)
        .type(torch.float32)
        .permute(0, 3, 1, 2)
        .contiguous()
    )

    model.eval()

    with torch.no_grad():
        raw_outputs = model(inputs)

        # Bring every branch back to channels-last (B, H, W, C)
        outputs = OrderedDict(
            (branch, tensor.permute(0, 2, 3, 1).contiguous())
            for branch, tensor in raw_outputs.items()
        )

        # Nuclear pixel branch: keep everything after channel 0 of the softmax
        outputs["np"] = F.softmax(outputs["np"], dim=-1)[..., 1:]

        # Type branch: most probable class per pixel, stored as float32
        type_prob = F.softmax(outputs["tp"], dim=-1)
        outputs["tp"] = torch.argmax(
            type_prob, dim=-1, keepdim=True
        ).type(torch.float32)

        merged = torch.cat(list(outputs.values()), -1)

    return merged.cpu().numpy()

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# cell also runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = HoVerNetExt(
    backbone_name="resnext",
    pretrained_backbone=True,
    num_types=4
)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
model.load_state_dict(
    torch.load('../pretrained/latest.pth', map_location=device, weights_only=True)
)
model.to(device)
model.eval()

IMAGE_PATH      = '../data/01_training_dataset_tif_ROIs'
GEOJSON_PATH    = '../data/01_training_dataset_geojson_nuclei'
PATCH_SIZE      = 512
BATCH_SIZE      = 1

dataset = PumaDataset(
    image_path=IMAGE_PATH,
    geojson_path=GEOJSON_PATH,
    input_shape=(PATCH_SIZE, PATCH_SIZE),
    mask_shape=(PATCH_SIZE, PATCH_SIZE),
    run_mode="test",
    augment=True
)

# Load one sample, centre-crop the image and add a batch dimension
img, ann = PumaDataset.load_data(dataset, 7)
aug = A.Compose([A.CenterCrop(height=PATCH_SIZE, width=PATCH_SIZE)])
augmented = aug(image=img, mask=ann)
img = augmented["image"]
img = torch.tensor(img).unsqueeze(0)

pred = infer_step(batch_imgs=img, model=model, device=device)
print(f"Pred shape: {pred.shape}")

# Panels in display order; pred channels are
# 0 = type map, 1 = nuclei pixel map, 2 = horizontal map, 3 = vertical map.
panels = [
    (img[0], "Image"),
    (pred[0, :, :, 2], "Horizontal map"),
    (pred[0, :, :, 3], "Vertical map"),
    (pred[0, :, :, 1], "Nuclei pixel map"),
    (pred[0, :, :, 0], "Type Map"),
]

fig, ax = plt.subplots(1, 5, figsize=(25, 5))
for a, (panel_data, panel_title) in zip(ax, panels):
    a.imshow(panel_data)
    a.set_title(panel_title, fontsize=16)
    # Hide every tick and tick label
    a.tick_params(
        axis='both',
        which='both',
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False
    )

fig.tight_layout()
fig.savefig("../figures/preds_example.pdf", bbox_inches='tight')

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

colors = ['white', 'teal', 'lightblue', 'lightgreen']
np_color = ['white', 'red']
np_cmap = ListedColormap(np_color)
cmap = ListedColormap(colors)

pred_inst, inst_info_dict = process(pred[0])

print(f"Instances: {inst_info_dict.keys()}")
# Show the fields of one instance. Instance id 1 is not guaranteed to exist
# (process() can skip instances), so take whichever entry comes first.
if inst_info_dict:
    first_info = next(iter(inst_info_dict.values()))
    print(f"Instance info: {first_info.keys()}")

map_height, map_width = pred_inst.shape[:2]
pred_inst = pred_inst > 0
fig, ax = plt.subplots(1, 1, figsize=(5, 5))

# Fill the contour of every instance with the colour of its predicted type
for instance_info in inst_info_dict.values():
    contour = instance_info["contour"]
    color_type = instance_info["type"]

    # Fall back to gray when the type index has no colour assigned
    if color_type < len(colors):
        color = colors[color_type]
    else:
        color = 'gray'

    ax.fill(contour[:, 0], contour[:, 1], color=color, alpha=1.0)

# Contours are in image coordinates (origin top-left, y pointing down).
# Without an imshow on this axis the default y-up axis would draw the
# nuclei vertically flipped, so set the image frame explicitly.
ax.set_xlim(0, map_width)
ax.set_ylim(map_height, 0)
ax.set_aspect('equal', 'box')

plt.tight_layout()
plt.show()


In [None]:
# --- Dataset configuration ---
IMAGE_PATH = '../data/01_training_dataset_tif_ROIs'
GEOJSON_PATH = '../data/01_training_dataset_geojson_nuclei'
PATCH_SIZE = 512
BATCH_SIZE = 1

patch_shape = (PATCH_SIZE, PATCH_SIZE)
dataset = PumaDataset(
    image_path=IMAGE_PATH,
    geojson_path=GEOJSON_PATH,
    input_shape=patch_shape,
    mask_shape=patch_shape,
    run_mode="test",
    augment=True,
)

# --- Load sample 7 and centre-crop image + annotation to one patch ---
img, ann = PumaDataset.load_data(dataset, 7)
print("Pre-augmentation")
print(f"Image shape: {img.shape}")
print(f"Annotation shape: {ann.shape}")
aug = A.Compose([A.CenterCrop(height=PATCH_SIZE, width=PATCH_SIZE)])
augmented = aug(image=img, mask=ann)
img, ann = augmented["image"], augmented["mask"]

# One colour per class id 0..3 (id 0 is drawn white and gets no legend entry)
colors = ['white', 'teal', 'lightblue', 'lightgreen']
class_labels = ['Tumor', 'TILs', 'Other']
cmap = ListedColormap(colors)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
pred_ax, truth_ax = ax

# Left panel: predicted instances (`inst_info_dict` comes from an earlier
# cell), filled over an all-zero background image.
background = np.zeros_like(ann[:, :, 1])
pred_ax.imshow(background, cmap=cmap)
for instance in inst_info_dict:
    info = inst_info_dict[instance]
    contour = info["contour"]
    color_type = info["type"]

    # Fall back to gray when the type index has no colour assigned
    color = colors[color_type] if color_type < len(colors) else 'gray'

    pred_ax.fill(contour[:, 0], contour[:, 1], color=color, alpha=1.0)

# Right panel: channel 1 of the annotation rendered as the class map.
ann_image = ann[:, :, 1]
# Level boundaries sit halfway between the integer class ids.
contour_kwargs = dict(
    levels=[-0.5, 0.5, 1.5, 2.5, 3.5],
    cmap=cmap,
    extent=(0, ann.shape[1], 0, ann.shape[0]),
)
contour_fill = truth_ax.contourf(ann_image, alpha=0.8, **contour_kwargs)
contour_lines = truth_ax.contour(ann_image, linewidths=0, alpha=1, **contour_kwargs)
# Contours are drawn with y pointing up; flip to match the image orientation.
truth_ax.invert_yaxis()

pred_ax.set_title("Prediction", fontsize=16)
truth_ax.set_title("Ground Truth", fontsize=16)

# Hide every tick and tick label on both panels.
for panel in (pred_ax, truth_ax):
    panel.tick_params(
        axis='both',
        which='both',
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

truth_ax.set_aspect('equal', 'box')

# Legend entries pair the non-background colours with their class names.
legend_elements = [
    Patch(facecolor=face, label=name)
    for face, name in zip(colors[1:], class_labels)
]
fig.legend(
    handles=legend_elements,
    loc='lower center',
    ncol=3,
    fontsize=16,
    bbox_to_anchor=(0.5, -0.1),  # negative y places the legend below the panels
    frameon=True,
)

fig.tight_layout()

fig.savefig("../figures/prediction.pdf", bbox_inches='tight')

In [None]:
import pandas as pd
from scipy.interpolate import interp1d
from scipy.ndimage import gaussian_filter1d

DEFAULT_DIR = '../data/resnext_optim_hyperparam'
OPTIM_DIR = '../data/resnext_optim_lossweight'

# The file names are listed once from DEFAULT_DIR and then read from both
# directories, so OPTIM_DIR must contain files with the same names.
# NOTE(review): os.listdir() gives no ordering guarantee, yet each file is
# paired with titles[i] / colors[i] by position — confirm the listing order
# really matches `titles` on the machine that renders the figure.
files = os.listdir(DEFAULT_DIR)
titles = ["NP Dice", "Tumor Dice", "TILs Dice", "Other Dice"]
colors = ['blue', 'red', 'green', 'purple']

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

# (directory, legend prefix, extra line kwargs) for each run, in drawing order.
runs = (
    (DEFAULT_DIR, '', {}),
    (OPTIM_DIR, 'Opt ', {'linestyle': 'dashed'}),
)

for run_dir, prefix, line_kwargs in runs:
    for idx, name in enumerate(files):
        curve = pd.read_csv(os.path.join(run_dir, name), header=None)
        # Column 0 is used as the x values and column 2 as the plotted values.
        xs, ys = curve[0], curve[2]
        cubic = interp1d(xs, ys, kind="cubic")
        # Resample on 300 evenly spaced points, then smooth with a Gaussian.
        grid = np.linspace(xs.min(), xs.max(), 300)
        smoothed = gaussian_filter1d(cubic(grid), sigma=6)
        ax.plot(
            grid,
            smoothed,
            label=prefix + titles[idx],
            linewidth=3,
            color=colors[idx],
            **line_kwargs,
        )

ax.set_title('Before and After Optimization', fontsize=16)
ax.set_xlabel('Epoch', fontsize=16)
ax.set_xlim(0, 50)
ax.set_ylim(0, 1)
# The legend is anchored just outside the right edge of the axes.
ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1), fontsize=12, borderaxespad=0)

fig.tight_layout()
fig.savefig("../figures/optimization_comparison.pdf", bbox_inches='tight')

In [None]:
import pandas as pd
from scipy.interpolate import interp1d
from scipy.ndimage import gaussian_filter1d

# Backbone comparison: smoothed curves for the ResNeXt50 run (solid lines)
# against the ResNet50 run (dashed lines), one colour per metric.
# The variable names are kept from the sibling comparison cells; here
# DEFAULT_DIR holds the ResNeXt50 logs and OPTIM_DIR the ResNet50 logs.
DEFAULT_DIR = '../data/resnext'
OPTIM_DIR = '../data/resnet'

# The file names are listed once from DEFAULT_DIR and then read from both
# directories, so OPTIM_DIR must contain files with the same names.
# NOTE(review): os.listdir() gives no ordering guarantee, yet each file is
# paired with titles[i] / colors[i] by position — confirm the listing order
# really matches `titles` on the machine that renders the figure.
files = os.listdir(DEFAULT_DIR)
titles = ["NP Dice", "Tumor Dice", "TILs Dice", "Other Dice"]
colors = ['blue', 'red', 'green', 'purple']

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

# (directory, legend prefix, extra line kwargs) for each backbone, in drawing order.
runs = (
    (DEFAULT_DIR, 'ResNeXt50 ', {}),
    (OPTIM_DIR, 'ResNet50 ', {'linestyle': 'dashed'}),
)

for run_dir, prefix, line_kwargs in runs:
    for i, file in enumerate(files):
        path = os.path.join(run_dir, file)
        df = pd.read_csv(path, header=None)
        # Column 0 is used as the x values and column 2 as the plotted values.
        spline = interp1d(df[0], df[2], kind="cubic")
        # Resample on 300 evenly spaced points, then smooth with a Gaussian.
        x = np.linspace(df[0].min(), df[0].max(), 300)
        y = gaussian_filter1d(spline(x), sigma=6)
        ax.plot(x, y, label=prefix + titles[i], linewidth=3, color=colors[i], **line_kwargs)

ax.set_xlabel('Epoch', fontsize=16)
ax.set_xlim(0, 50)
ax.set_ylim(0, 1)
# The legend is anchored just outside the right edge of the axes.
ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1), fontsize=12, borderaxespad=0)
# The title previously read 'Before and After Optimization', copied from the
# optimisation cell; this figure compares the two backbones.
ax.set_title('ResNeXt50 vs ResNet50', fontsize=16)

fig.tight_layout()
fig.savefig("../figures/resnet_vs_resnext.pdf", bbox_inches='tight')

In [None]:
import pandas as pd
from scipy.interpolate import interp1d
from scipy.ndimage import gaussian_filter1d

DEFAULT_DIR = '../data/logcosh'
OPTIM_DIR = '../data/resnext_optim_lossweight'

# The file names are listed once from DEFAULT_DIR and then read from both
# directories, so OPTIM_DIR must contain files with the same names.
# NOTE(review): os.listdir() gives no ordering guarantee, yet each file is
# paired with titles[i] / colors[i] by position — confirm the listing order
# really matches `titles` on the machine that renders the figure.
files = os.listdir(DEFAULT_DIR)
titles = ["NP Dice", "Tumor Dice", "TILs Dice", "Other Dice"]
colors = ['blue', 'red', 'green', 'purple']

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

# (directory, legend prefix, extra line kwargs) for each run, in drawing order.
runs = (
    (DEFAULT_DIR, '', {}),
    (OPTIM_DIR, 'Opt ', {'linestyle': 'dashed'}),
)

for run_dir, prefix, line_kwargs in runs:
    for idx, name in enumerate(files):
        curve = pd.read_csv(os.path.join(run_dir, name), header=None)
        # Column 0 is used as the x values and column 2 as the plotted values.
        xs, ys = curve[0], curve[2]
        cubic = interp1d(xs, ys, kind="cubic")
        # Resample on 300 evenly spaced points, then smooth with a Gaussian.
        grid = np.linspace(xs.min(), xs.max(), 300)
        smoothed = gaussian_filter1d(cubic(grid), sigma=6)
        ax.plot(
            grid,
            smoothed,
            label=prefix + titles[idx],
            linewidth=3,
            color=colors[idx],
            **line_kwargs,
        )

ax.set_title('Before and After Optimization', fontsize=16)
ax.set_xlabel('Epoch', fontsize=16)
ax.set_xlim(0, 50)
ax.set_ylim(0, 1)
# The legend is anchored just outside the right edge of the axes.
ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1), fontsize=12, borderaxespad=0)

fig.tight_layout()
# Writing the figure to disk is left disabled for this cell.
#fig.savefig("../figures/optimization_comparison.pdf", bbox_inches='tight')