Notebook to evaluate the multi-head EfficientUNet-B0 model (overall, lymphocyte, monocyte and contour predictions) on one validation fold

In [None]:
import os
import sys
from collections.abc import Sequence
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm.autonotebook import tqdm

# Make the parent directory importable so the `monkey` package resolves.
sys.path.append("../")

from monkey.config import TrainingIOConfig
from monkey.data.data_utils import imagenet_denormalise
from monkey.data.dataset import get_detection_dataloaders
from monkey.model.efficientunetb0.architecture import (
    get_multihead_efficientunet,
)
from monkey.model.loss_functions import dice_coeff
from monkey.model.loss_functions import get_loss_function
from monkey.model.utils import get_activation_function
from monkey.model.utils import (
    get_multiclass_patch_F1_score_batch,
    get_patch_F1_score_batch,
)

# Settings of the run whose checkpoint is evaluated in this notebook.
run_config = {
    # Run identity.
    "project_name": "Monkey_Multiclass_Detection",
    "model_name": "multihead_unet",
    # Output channels requested for each head when the model is built.
    "out_channels": [2, 1, 1],
    # Fold used for validation; also selects the checkpoint directory.
    "val_fold": 1,  # [1-5]
    "batch_size": 16,
    # Training hyper-parameters, kept as a record of the run; the cells
    # visible in this notebook do not read them.
    "optimizer": "AdamW",
    "learning_rate": 0.0004,
    "weight_decay": 0.01,
    "epochs": 50,
    # Every head uses the same loss.
    "loss_function": dict.fromkeys(
        ("head_1", "head_2", "head_3"), "BCE_Dice"
    ),
    "do_augmentation": False,
    # Every head uses the same activation.
    "activation_function": dict.fromkeys(
        ("head_1", "head_2", "head_3"), "sigmoid"
    ),
    "use_nuclick_masks": True,
}

# Dataset root, plus the directory holding the processed NuClick masks.
IOconfig = TrainingIOConfig(
    dataset_dir="/home/u1910100/Documents/Monkey/patches_256",
)

IOconfig.set_mask_dir(
    "/home/u1910100/Documents/Monkey/patches_256/annotations/nuclick_masks_processed"
)


# Create model
model = get_multihead_efficientunet(
    out_channels=run_config["out_channels"], pretrained=False
)
checkpoint_path = f"/home/u1910100/Documents/Monkey/runs/cell_multiclass_det/multihead_unet_experiment/fold_{run_config['val_fold']}/epoch_50.pth"
# Restore the checkpoint onto the CPU first. Without `map_location`,
# torch.load puts every tensor back on the device it was saved from,
# which fails if that device does not exist here and keeps a second copy
# of the weights on the GPU. The model is moved to the GPU once, below.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
# Inference mode for layers such as dropout / batch norm.
model.eval()
model.to("cuda")


# Loaders for the configured validation fold, at the evaluation batch size.
train_loader, val_loader = get_detection_dataloaders(
    IOconfig,
    val_fold=run_config["val_fold"],
    dataset_name="multitask",
    batch_size=run_config["batch_size"],
    do_augmentation=run_config["do_augmentation"],
    use_nuclick_masks=run_config["use_nuclick_masks"],
)


# One activation per output head, looked up by the name stored in the config.
activation_dict = {
    head_name: get_activation_function(
        run_config["activation_function"][head_name]
    )
    for head_name in ("head_1", "head_2", "head_3")
}

In [None]:
def multihead_unet_post_process(
    logits_pred: dict[str, torch.Tensor],
    activation_dict: dict[str, torch.nn.Module],
    thresholds: Sequence[float] = (0.3, 0.3, 0.3, 0.5),
) -> dict[str, np.ndarray]:
    """Convert raw multi-head logits into binary masks and probability maps.

    Head 1 is read as two channels (channel 0: overall inflammatory cells,
    channel 1: cell contours). Heads 2 and 3 are read as lymphocyte and
    monocyte maps respectively. Pixels predicted as contour are cleared
    from the three cell masks.

    Args:
        logits_pred: Mapping with keys "head_1", "head_2" and "head_3"
            holding the raw outputs of each head. Tensors are indexed as
            4-D (assumed batch, channel, height, width); heads 2 and 3 are
            expected to have a single channel.
        activation_dict: Mapping with the same keys; each value is a
            callable applied to the matching logits to obtain probabilities.
        thresholds: Four cut-offs in the order
            [overall, lymph, mono, contour]. A probability strictly greater
            than its cut-off counts as foreground. Lists and tuples are
            both accepted.

    Returns:
        Dict of NumPy arrays with the channel axis removed:
        "inflamm_mask", "contour_mask", "lymph_mask", "mono_mask"
        (float32 arrays of 0/1) and "inflamm_prob", "contour_prob",
        "lymph_prob", "mono_prob" (the activated probabilities).
    """

    def _binarise(probs: torch.Tensor, threshold: float) -> np.ndarray:
        # Strict comparison, then a float32 0/1 array on the host.
        return (probs > threshold).float().numpy(force=True)

    pred_probs_1 = activation_dict["head_1"](logits_pred["head_1"])
    pred_probs_2 = activation_dict["head_2"](logits_pred["head_2"])
    pred_probs_3 = activation_dict["head_3"](logits_pred["head_3"])

    # Slices (0:1 / 1:2) keep the channel axis so all four masks share the
    # same rank and the contour mask can index the others directly.
    contour_pred_binary = _binarise(
        pred_probs_1[:, 1:2, :, :], thresholds[3]
    )
    overall_pred_binary = _binarise(
        pred_probs_1[:, 0:1, :, :], thresholds[0]
    )
    lymph_pred_binary = _binarise(pred_probs_2, thresholds[1])
    mono_pred_binary = _binarise(pred_probs_3, thresholds[2])

    # Remove contour pixels from every cell mask.
    on_contour = contour_pred_binary == 1
    for cell_mask in (
        overall_pred_binary,
        lymph_pred_binary,
        mono_pred_binary,
    ):
        cell_mask[on_contour] = 0

    processed_masks = {
        "inflamm_mask": overall_pred_binary[:, 0, :, :],
        "contour_mask": contour_pred_binary[:, 0, :, :],
        "lymph_mask": lymph_pred_binary[:, 0, :, :],
        "mono_mask": mono_pred_binary[:, 0, :, :],
        "inflamm_prob": pred_probs_1[:, 0, :, :].numpy(force=True),
        "contour_prob": pred_probs_1[:, 1, :, :].numpy(force=True),
        "lymph_prob": pred_probs_2[:, 0, :, :].numpy(force=True),
        "mono_prob": pred_probs_3[:, 0, :, :].numpy(force=True),
    }
    return processed_masks

In [None]:
# Batch-size-weighted sums of the per-batch scores.
running_overall_score = 0.0
running_lymph_score = 0.0
running_mono_score = 0.0
running_contour_score = 0.0
# Number of patches actually scored. Used as the denominator so the mean
# stays correct even if the loader yields fewer items than
# len(val_loader.sampler), e.g. when incomplete batches are dropped.
num_scored = 0

for data in tqdm(val_loader, desc="validation", leave=False):
    images = data["image"].cuda().float()
    samples_in_batch = images.size(0)

    # Ground truth: channel 0 of the binary/contour masks, and channels
    # 0 / 1 of the class mask for lymphocytes / monocytes.
    inflamm_true_masks = data["binary_mask"][:, 0, :, :].numpy(
        force=True
    )
    contour_true_masks = (
        data["contour_mask"][:, 0, :, :].cpu().float()
    )
    lymph_true_masks = data["class_mask"][:, 0, :, :].numpy(
        force=True
    )
    mono_true_masks = data["class_mask"][:, 1, :, :].numpy(force=True)

    with torch.no_grad():
        logits_pred = model(images)
        processed_output = multihead_unet_post_process(
            logits_pred,
            activation_dict,
            thresholds=[0.3, 0.3, 0.3, 0.5],
        )

    # Compute detection F1 score
    overall_metrics = get_patch_F1_score_batch(
        processed_output["inflamm_mask"],
        inflamm_true_masks,
        processed_output["inflamm_prob"],
    )
    lymph_metrics = get_patch_F1_score_batch(
        processed_output["lymph_mask"],
        lymph_true_masks,
        processed_output["lymph_prob"],
    )
    mono_metrics = get_patch_F1_score_batch(
        processed_output["mono_mask"],
        mono_true_masks,
        processed_output["mono_prob"],
    )
    # Contours are scored with Dice on the binary masks instead of F1.
    contour_dice = dice_coeff(
        torch.from_numpy(processed_output["contour_mask"]),
        contour_true_masks,
        reduce_batch_first=True,
    )

    running_overall_score += overall_metrics["F1"] * samples_in_batch
    running_lymph_score += lymph_metrics["F1"] * samples_in_batch
    running_mono_score += mono_metrics["F1"] * samples_in_batch
    running_contour_score += contour_dice.item() * samples_in_batch
    num_scored += samples_in_batch

if num_scored == 0:
    raise ValueError(
        "val_loader yielded no samples; there is nothing to evaluate"
    )

results = {
    "overall_F1": running_overall_score / num_scored,
    "lymph_F1": running_lymph_score / num_scored,
    "mono_F1": running_mono_score / num_scored,
    "contour_dice": running_contour_score / num_scored,
}

pprint(results)

Visualization: ground-truth vs. predicted masks for a few validation patches

In [None]:
# Dedicated single-patch loader for plotting. It is bound to its own name
# so the batched `train_loader` / `val_loader` used for evaluation are not
# overwritten; re-running the evaluation cell after this one would
# otherwise score with batch_size=1 and report different numbers.
_unused_train_loader, vis_loader = get_detection_dataloaders(
    IOconfig,
    val_fold=run_config["val_fold"],
    dataset_name="multitask",
    batch_size=1,
    do_augmentation=run_config["do_augmentation"],
    use_nuclick_masks=run_config["use_nuclick_masks"],
)

# Number of validation patches to plot.
num_examples = 7

for i, data in enumerate(
    tqdm(vis_loader, desc="validation", leave=False)
):
    images = data["image"].cuda().float()
    inflamm_true_masks = data["binary_mask"][:, 0, :, :].numpy(
        force=True
    )
    contour_true_masks = (
        data["contour_mask"][:, 0, :, :].cpu().float()
    )
    lymph_true_masks = data["class_mask"][:, 0, :, :].numpy(
        force=True
    )
    mono_true_masks = data["class_mask"][:, 1, :, :].numpy(force=True)

    with torch.no_grad():
        logits_pred = model(images)
        processed_output = multihead_unet_post_process(
            logits_pred,
            activation_dict,
            thresholds=[0.3, 0.3, 0.3, 0.5],
        )

    # First (only) patch of the batch, first axis moved last for imshow,
    # then the ImageNet normalisation is undone.
    display_image = images.numpy(force=True)[0]
    display_image = np.moveaxis(display_image, 0, 2)
    display_image = imagenet_denormalise(display_image)

    # (title, array, colormap) in row-major order of the 3x3 grid.
    panels = [
        ("Image", display_image, None),
        ("True Overall", inflamm_true_masks[0], "gray"),
        ("True Contour", contour_true_masks[0], "gray"),
        ("True Lymph", lymph_true_masks[0], "gray"),
        ("True Mono", mono_true_masks[0], "gray"),
        ("Pred Contour", processed_output["contour_mask"][0], "gray"),
        ("Pred Lymph", processed_output["lymph_mask"][0], "gray"),
        ("Pred Mono", processed_output["mono_mask"][0], "gray"),
        ("Pred Overall", processed_output["inflamm_mask"][0], "gray"),
    ]

    fig, axes = plt.subplots(3, 3, figsize=(10, 10))
    for ax, (title, panel, cmap) in zip(axes.ravel(), panels):
        ax.imshow(panel, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    plt.show()
    # pyplot keeps figures alive until they are closed; release each one
    # so memory does not grow with the number of plotted patches.
    plt.close(fig)

    if i + 1 >= num_examples:
        break