# Hover-Net Example

## Imports

In [None]:
# System
import sys
import os

# Misc
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

# ML
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR

# Augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2
import random

# Custom
sys.path.append(os.path.abspath(".."))
from hover_net.dataloader.dataset import get_dataloader
from hover_net.tools.utils import (dump_yaml, read_yaml)
from hover_net.datasets.puma_dataset import PumaDataset
from hover_net.dataloader.preprocessing import cropping_center, gen_targets

In [None]:
# Point the base config at the PUMA training ROIs / nuclei annotations and
# use full 1024x1024 patches with one sample per batch, for visual inspection.
config = read_yaml('../configs/config.yaml')
data_cfg = config['DATA']
data_cfg['IMAGE_PATH'] = '../data/01_training_dataset_tif_ROIs'
data_cfg['GEOJSON_PATH'] = '../data/01_training_dataset_geojson_nuclei'
data_cfg['PATCH_SIZE'] = 1024
config['TRAIN']['BATCH_SIZE'] = 1

# Input and mask use the same square patch shape.
patch_shape = (data_cfg["PATCH_SIZE"], data_cfg["PATCH_SIZE"])

val_dataloader = get_dataloader(
    dataset_type="puma",
    image_path=data_cfg["IMAGE_PATH"],
    geojson_path=data_cfg["GEOJSON_PATH"],
    with_type=True,
    input_shape=patch_shape,
    mask_shape=patch_shape,
    batch_size=config["TRAIN"]["BATCH_SIZE"],
    run_mode="test",
    augment=False,
)

# Pull the first batch and report which targets the dataloader yields.
feed_dict = next(iter(val_dataloader))
print(f"Dataloader returns: {feed_dict.keys()}")
img, np_map = feed_dict["img"], feed_dict["np_map"]
print(f"Image shape: {img.shape}")
print(f"NP Map shape: {np_map.shape}")

# Show the first sample of the batch next to its nuclear-pixel map.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(img[0])
axs[1].imshow(np_map[0])

In [None]:
# Build the PUMA dataset directly (no DataLoader) to inspect one raw sample
# and the HoVer-Net targets generated from it.
patch_size = config["DATA"]["PATCH_SIZE"]
patch_shape = (patch_size, patch_size)

dataset = PumaDataset(
    image_path=config["DATA"]["IMAGE_PATH"],
    geojson_path=config["DATA"]["GEOJSON_PATH"],
    with_type=True,
    input_shape=patch_shape,
    mask_shape=patch_shape,
    run_mode="test",
    augment=False,
)

# (image, annotation) for sample 0; channel 0 of the annotation is used as
# the instance map below.
img, ann = PumaDataset.load_data(dataset, 0)
print(f"Image shape: {img.shape}")
print(f"Annotation shape: {ann.shape}")

inst_map = ann[..., 0]

fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(img)
axs[1].imshow(inst_map)

# Use the configured mask shape rather than a hard-coded (1024, 1024) so the
# generated targets stay consistent with PATCH_SIZE / the dataset's mask_shape.
target_dict = gen_targets(inst_map, patch_shape)
print(f"Target dict: {target_dict.keys()}")
print(f"hv_map shape: {target_dict['hv_map'].shape}")
print(f"np_map shape: {target_dict['np_map'].shape}")

In [None]:
from collections import OrderedDict

import torch
import torch.nn.functional as F

from hover_net.models.loss import dice_loss, mse_loss, msge_loss, xentropy_loss

def valid_step(
    batch_data,
    model,
    loss_opts,
    device="cuda",
):
    """
    Run one HoVer-Net validation step and collect the per-term losses.

    The model is put in eval mode and the forward pass runs under
    ``torch.no_grad()``. Only branches present in the model's output dict are
    scored, each with the losses listed for it in ``loss_opts``.

    Args:
        batch_data: Mapping with ``"img"``, ``"np_map"``, ``"hv_map"`` and,
            when ``model.num_types`` is not None, ``"tp_map"``. ``"img"`` is
            permuted with ``(0, 3, 1, 2)`` before the forward pass (presumably
            channels-last -> channels-first; confirm against the dataloader).
        model: Module returning a dict of branch outputs (e.g. ``"np"``,
            ``"hv"``, ``"tp"``) and exposing a ``num_types`` attribute.
        loss_opts: ``{branch_name: {loss_name: weight}}`` where ``loss_name``
            is one of ``"bce"``, ``"dice"``, ``"mse"``, ``"msge"``.
        device: Device the batch tensors are moved to; should match the
            device the model lives on.

    Returns:
        ``{"EMA": {...}}`` mapping ``"loss_<branch>_<loss>"`` for each term and
        ``"overall_loss"`` (weighted sum) to Python floats.

    Raises:
        KeyError: if the model returns a branch that has no entry in
            ``loss_opts``, or ``loss_opts`` names an unknown loss.
    """
    model.eval()

    loss_func_dict = {
        "bce": xentropy_loss,
        "dice": dice_loss,
        "mse": mse_loss,
        "msge": msge_loss,
    }

    result_dict = {"EMA": {}}

    def track_value(name, value):
        result_dict["EMA"].update({name: value})

    imgs = batch_data["img"]
    true_np = batch_data["np_map"]
    true_hv = batch_data["hv_map"]

    imgs = imgs.to(device).type(torch.float32).permute(0, 3, 1, 2).contiguous()

    # one_hot requires int64 class indices.
    true_np = true_np.to(device).type(torch.int64)
    true_hv = true_hv.to(device).type(torch.float32)

    true_np_onehot = F.one_hot(true_np, num_classes=2).type(torch.float32)
    true_dict = {
        "np": true_np_onehot,
        "hv": true_hv,
    }

    if model.num_types is not None:
        true_tp = batch_data["tp_map"].to(device).type(torch.int64)
        true_tp_onehot = F.one_hot(true_tp, num_classes=model.num_types).type(torch.float32)
        true_dict["tp"] = true_tp_onehot

    # --------------------------------------------------------------
    with torch.no_grad():  # validation only: no gradients needed
        pred_dict = model(imgs)
        # Move channels last so predictions line up with the one-hot targets.
        pred_dict = OrderedDict(
            [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()]
        )
        pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)
        if model.num_types is not None:
            pred_dict["tp"] = F.softmax(pred_dict["tp"], dim=-1)

        loss = 0
        for branch_name in pred_dict.keys():
            for loss_name, loss_weight in loss_opts[branch_name].items():
                loss_func = loss_func_dict[loss_name]
                loss_args = [true_dict[branch_name], pred_dict[branch_name]]
                if loss_name == "msge":
                    # MSGE additionally takes the foreground mask and device.
                    loss_args.extend([true_np_onehot[..., 1], device])
                term_loss = loss_func(*loss_args)
                track_value(f"loss_{branch_name}_{loss_name}", term_loss.item())
                loss += loss_weight * term_loss

        # float() works both for a scalar tensor and for the int 0 left over
        # when no loss terms were configured (where ``.cpu()`` would fail).
        track_value("overall_loss", float(loss))

    return result_dict

In [None]:
from hover_net.models import HoVerNetExt

# Per-branch loss weights: {branch: {loss_name: weight}}.
loss_opts = {
    "np": {"bce": 1, "dice": 1},
    "hv": {"mse": 1, "msge": 1},
    "tp": {"bce": 1, "dice": 1},
}

config = read_yaml('../configs/config.yaml')
config['DATA']['IMAGE_PATH'] = '../data/01_training_dataset_tif_ROIs'
config['DATA']['GEOJSON_PATH'] = '../data/01_training_dataset_geojson_nuclei'
config['DATA']['PATCH_SIZE'] = 256
config['TRAIN']['BATCH_SIZE'] = 2

# Single source of truth for the device: used for loading, the model, and
# every validation step so tensors and weights never end up on different devices.
device = config["TRAIN"]["DEVICE"]

model = HoVerNetExt(
    backbone_name=config["MODEL"]["BACKBONE"],
    pretrained_backbone=config["MODEL"]["PRETRAINED"],
    num_types=config["MODEL"]["NUM_TYPES"]
)
# map_location lets a checkpoint saved on GPU load on the configured device
# (e.g. a CPU-only machine).
state_dict = torch.load('../pretrained/latest.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

val_dataloader = get_dataloader(
    dataset_type="puma",
    image_path=config["DATA"]["IMAGE_PATH"],
    geojson_path=config["DATA"]["GEOJSON_PATH"],
    with_type=True,
    input_shape=(
        config["DATA"]["PATCH_SIZE"],
        config["DATA"]["PATCH_SIZE"]
    ),
    mask_shape=(
        config["DATA"]["PATCH_SIZE"],
        config["DATA"]["PATCH_SIZE"]
    ),
    batch_size=config["TRAIN"]["BATCH_SIZE"],
    run_mode="test",
    # No augmentation for validation: losses must be deterministic and
    # comparable across runs.
    augment=False
)

# Get a batch of validation data
feed_dict = next(iter(val_dataloader))
print(f"Dataloader returns: {feed_dict.keys()}")

for i, batch_data in enumerate(val_dataloader):
    # Pass the configured device explicitly; valid_step defaults to "cuda".
    result_dict = valid_step(batch_data, model, loss_opts=loss_opts, device=device)
    print(f"Batch {i} - Result dict: {result_dict}")