In [None]:
import numpy as np
import matplotlib.pyplot as plt
from monkey.config import TrainingIOConfig
from monkey.data.dataset import get_classification_dataloaders
from monkey.model.classification_model.efficientnet_b0 import (
    EfficientNet_B0,
)
from sklearn import metrics
from pprint import pprint
from monkey.model.utils import get_classification_metrics
import torch
from tqdm.autonotebook import tqdm
from monkey.data.data_utils import imagenet_denormalise

In [None]:
# Evaluation settings: which fold / checkpoint to evaluate and where the data lives.
# NOTE(review): these are absolute, user-specific paths — adjust for your machine.
MONKEY_ROOT = "/home/u1910100/Documents/Monkey"
DATASET_DIR = f"{MONKEY_ROOT}/classification"
PATCH_DIR = f"{DATASET_DIR}/patches"
CHECKPOINT_EPOCH = 75
BATCH_SIZE = 32

val_fold = 4

# Single-logit binary classifier; weights come from the checkpoint below,
# so no ImageNet-pretrained initialisation is needed.
model = EfficientNet_B0(
    input_channels=3, num_classes=1, pretrained=False
)

checkpoint_path = (
    f"{MONKEY_ROOT}/runs/cls/efficientnetb0/"
    f"fold_{val_fold}/epoch_{CHECKPOINT_EPOCH}.pth"
)
# Load onto the CPU so the checkpoint opens regardless of the device it was
# saved from, and so the `checkpoint` dict kept in the notebook namespace does
# not pin GPU memory; the model itself is moved to the GPU afterwards.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()
model.to("cuda")

IOconfig = TrainingIOConfig(
    dataset_dir=DATASET_DIR,
    save_dir="./",
)
# Images and masks are read from the same patch directory.
IOconfig.set_image_dir(PATCH_DIR)
IOconfig.set_mask_dir(PATCH_DIR)

# Get dataloaders for task (no augmentation: this is evaluation only).
train_loader, val_loader = get_classification_dataloaders(
    IOconfig,
    val_fold=val_fold,
    batch_size=BATCH_SIZE,
    do_augmentation=False,
    stack_mask=False,
)

In [None]:
# Run the classifier over the validation fold, collecting one sigmoid
# probability per sample alongside its ground-truth label, then plot the ROC.

# Set to True to show the first image of every batch with its label and
# predicted probability (one figure per batch).
visualization = False
pred_probs_list = []
true_labels_list = []

with torch.no_grad():
    for data in tqdm(val_loader, desc="validation"):
        file_ids = data["id"]
        images = data["image"].cuda().float()
        true_labels = data["label"].cpu().tolist()
        true_labels_list.extend(true_labels)

        logits_pred = model(images)
        # Flatten instead of squeeze(): squeeze() turns a final batch of size 1
        # into a 0-d tensor whose .tolist() is a bare float, which list.extend
        # cannot consume.
        pred_probs = torch.sigmoid(logits_pred).reshape(-1).cpu().tolist()
        pred_probs_list.extend(pred_probs)

        if visualization:
            # CHW tensor -> HWC array, then undo the ImageNet normalisation.
            # NOTE(review): assumes imagenet_denormalise returns something
            # plt.imshow can render (uint8 or float in [0, 1]) — confirm.
            image_np = np.moveaxis(images[0].cpu().numpy(), 0, 2)
            image_np = imagenet_denormalise(image_np)
            plt.imshow(image_np)
            plt.title(
                f"id={file_ids[0]}  label={true_labels[0]}  "
                f"prob={pred_probs[0]:.3f}"
            )
            plt.axis("off")
            plt.show()

pred_probs_list = np.array(pred_probs_list)
true_labels_list = np.array(true_labels_list)
fpr, tpr, thresholds = metrics.roc_curve(
    true_labels_list, pred_probs_list
)
roc_auc = metrics.auc(fpr, tpr)

# Named roc_display (not `display`) so IPython's built-in display() is not
# shadowed for the rest of the session.
roc_display = metrics.RocCurveDisplay(
    fpr=fpr,
    tpr=tpr,
    roc_auc=roc_auc,
    estimator_name="cell classifier",
)
fig, ax = plt.subplots()
roc_display.plot(ax=ax)
ax.set_title(f"Validation ROC (fold {val_fold})")
plt.show()

In [None]:
# Binarise the validation probabilities at a fixed threshold and report
# classification metrics plus a confusion matrix.
thresh = 0.5
pred_labels_list = np.where(pred_probs_list > thresh, 1, 0)
scores = get_classification_metrics(
    true_labels_list, pred_labels_list
)
pprint(scores)

# labels=[0, 1] keeps the matrix 2x2, matching the two display labels, even
# when a class is absent from both y_true and y_pred; without it the matrix
# shrinks to 1x1 and setting two tick labels raises.
# NOTE(review): assumes label 0 = lymphocyte and 1 = monocyte — confirm
# against the dataset's label encoding.
fig, ax = plt.subplots()
metrics.ConfusionMatrixDisplay.from_predictions(
    true_labels_list,
    pred_labels_list,
    labels=[0, 1],
    display_labels=["lymphocyte", "monocyte"],
    ax=ax,
)
ax.set_title(f"Confusion matrix (threshold = {thresh})")
plt.show()