# BBBC039 Segmentation Training Pipeline


In [None]:
import wandb

# Public W&B API client, used to look up runs by their
# "entity/project/run_id" path (see the commented-out examples).
api = wandb.Api()
# run = api.run(f"joshbercich/BBBC039-Segmentation/jrhkyfjn")

In [None]:
# run = api.run("joshbercich/BBBC039-Segmentation/qjikuwvn")
# run.config["rand_augment"] = False
# run.config["rand_augment_n"] = None
# run.config["rand_augment_m"] = None
# run.update()

In [None]:
import itertools
from datetime import datetime

import lightning.pytorch as pl
import torch
import torchvision
from lightning.pytorch import LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers.wandb import WandbLogger
from torch.utils.data import DataLoader
from torchmetrics.classification import (
    Dice,
    MulticlassAccuracy,
    MulticlassF1Score,
    MultilabelCoverageError,
    MulticlassJaccardIndex,
)
from torchvision.transforms import InterpolationMode, Resize

import wandb
from dataset import BBBC039Segmentation

api = wandb.Api()  # W&B public API client; the commented-out evaluation loop below uses it to update run summaries


class UNet(LightningModule):
    """U-Net segmentation model wrapped as a Lightning module.

    The network itself is fetched from the ``mateuszbuda/brain-segmentation-pytorch``
    torch.hub repository. Training minimises cross-entropy between the network
    logits and the label tensor. For each stage (train / validation / test) it
    logs F1, coverage error, accuracy and Dice. The test stage additionally logs
    the Jaccard index and ``*_binary`` variants of each metric, which are
    constructed with ``ignore_index=2``.

    Labels are passed straight to ``cross_entropy`` as class-probability targets
    and are argmax'd over dim 1. They are therefore expected to be per-class
    maps with the same shape as the logits (e.g. one-hot).

    Args:
        learning_rate: Adam learning rate.
        in_channels: Number of input channels of the hub U-Net.
        out_channels: Number of output channels (one logit per class) of the hub
            U-Net. This is also used as the number of classes for every metric.
        init_features: Width of the first U-Net stage (passed to the hub model).
        pretrained: Whether the hub model loads pretrained weights.
        rand_augment: Stored and saved as a hyperparameter only; this class does
            not apply any augmentation itself.
        rand_augment_n: RandAugment "N" hyperparameter (stored only). Required
            when ``rand_augment`` is True.
        rand_augment_m: RandAugment "M" hyperparameter (stored only). Required
            when ``rand_augment`` is True.

    Raises:
        ValueError: If ``rand_augment`` is True but ``rand_augment_n`` or
            ``rand_augment_m`` is None.
    """

    def __init__(
        self,
        learning_rate: float = 1e-4,
        in_channels: int = 3,
        out_channels: int = 3,
        init_features: int = 32,
        pretrained: bool = False,
        rand_augment: bool = False,
        rand_augment_n: int = None,
        rand_augment_m: float = None,
    ):
        super().__init__()
        # Explicit check instead of ``assert`` so it is not stripped by ``python -O``.
        if rand_augment and (rand_augment_n is None or rand_augment_m is None):
            raise ValueError(
                "rand_augment_n and rand_augment_m must both be set when "
                "rand_augment is True"
            )
        self.rand_augment = rand_augment
        self.rand_augment_n = rand_augment_n
        self.rand_augment_m = rand_augment_m

        # Set hyperparameters
        self.learning_rate = learning_rate

        # HACK: disable torch.hub's "is this repo a fork?" validation. This is a
        # common workaround for GitHub API failures during torch.hub.load;
        # NOTE(review): confirm it is still needed for the installed torch version.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.net = torch.hub.load(
            "mateuszbuda/brain-segmentation-pytorch",
            "unet",
            in_channels=in_channels,
            out_channels=out_channels,
            init_features=init_features,
            pretrained=pretrained,
        )

        # Logits (N, C, ...) vs. same-shaped class-probability targets.
        self.loss = torch.nn.functional.cross_entropy

        # Total number of trainable-tensor elements in the network. It is logged
        # in ``on_train_start``, because ``self.log`` is a no-op before a
        # Trainer is attached.
        self.num_parameters = sum(p.numel() for p in self.net.parameters())

        # cross_entropy requires targets with one channel per output logit, so
        # the number of classes always equals ``out_channels``.
        num_classes = out_channels

        # Define modular metrics
        self.train_f1 = MulticlassF1Score(num_classes=num_classes)
        self.train_coverage = MultilabelCoverageError(num_labels=num_classes)
        self.train_accuracy = MulticlassAccuracy(num_classes=num_classes)
        self.train_dice = Dice(num_classes=num_classes)
        self.val_f1 = MulticlassF1Score(num_classes=num_classes)
        self.val_coverage = MultilabelCoverageError(num_labels=num_classes)
        self.val_accuracy = MulticlassAccuracy(num_classes=num_classes)
        self.val_dice = Dice(num_classes=num_classes)
        self.test_f1 = MulticlassF1Score(num_classes=num_classes)
        self.test_f1_binary = MulticlassF1Score(num_classes=num_classes, ignore_index=2)
        self.test_coverage = MultilabelCoverageError(num_labels=num_classes)
        # NOTE(review): for multilabel metrics ignore_index applies to 0/1
        # target values, so ignore_index=2 may not change this metric — confirm.
        self.test_coverage_binary = MultilabelCoverageError(
            num_labels=num_classes, ignore_index=2
        )
        self.test_accuracy = MulticlassAccuracy(num_classes=num_classes)
        self.test_accuracy_binary = MulticlassAccuracy(
            num_classes=num_classes, ignore_index=2
        )
        self.test_dice = Dice(num_classes=num_classes)
        self.test_dice_binary = Dice(num_classes=num_classes, ignore_index=2)
        self.test_jid = MulticlassJaccardIndex(num_classes=num_classes)
        self.test_jid_binary = MulticlassJaccardIndex(
            num_classes=num_classes, ignore_index=2
        )

        # Shorthand for logger options
        self.train_log_opts = {"on_step": True, "on_epoch": True, "logger": True}
        self.val_log_opts = {"on_step": False, "on_epoch": True, "logger": True}
        self.test_log_opts = {"on_step": False, "on_epoch": True, "logger": True}

        # Save the __init__ arguments as hyperparameters (stored in checkpoints).
        self.save_hyperparameters()

    def _shared_step(self, batch: tuple):
        """Run the network on ``(images, labels)`` and compute the loss.

        Returns:
            ``(output, labels, pred, labs, loss)`` where ``output`` are the raw
            logits, ``pred`` / ``labs`` are per-pixel class indices (argmax over
            dim 1) of the logits and the labels, and ``loss`` is cross-entropy.
        """
        images, labels = batch
        output = self.net(images)
        pred, labs = output.argmax(1), labels.int().argmax(1)
        loss = self.loss(output, labels)
        return output, labels, pred, labs, loss

    def on_train_start(self):
        # Logged here rather than in __init__: self.log needs an attached Trainer.
        self.log("num_parameters", float(self.num_parameters), logger=True)

    def training_step(self, batch: tuple, batch_idx: int):
        output, labels, pred, labs, loss = self._shared_step(batch)

        # Update metrics
        self.train_f1(pred, labs)
        self.train_coverage(output, labels)
        self.train_accuracy(pred, labs)
        self.train_dice(pred, labs)
        self.log("train_loss", loss, prog_bar=True, **self.train_log_opts)
        self.log("train_f1", self.train_f1, prog_bar=True, **self.train_log_opts)
        self.log("train_coverage", self.train_coverage, **self.train_log_opts)
        self.log("train_accuracy", self.train_accuracy, **self.train_log_opts)
        self.log("train_dice", self.train_dice, **self.train_log_opts)

        return loss

    def validation_step(self, batch: tuple, batch_idx: int):
        output, labels, pred, labs, loss = self._shared_step(batch)

        # Update metrics
        self.val_f1(pred, labs)
        self.val_coverage(output, labels)
        self.val_accuracy(pred, labs)
        self.val_dice(pred, labs)
        self.log("val_loss", loss, **self.val_log_opts)
        self.log("val_f1", self.val_f1, prog_bar=True, **self.val_log_opts)
        self.log("val_coverage", self.val_coverage, **self.val_log_opts)
        self.log("val_accuracy", self.val_accuracy, **self.val_log_opts)
        self.log("val_dice", self.val_dice, **self.val_log_opts)
        # Consumed by on_validation_batch_end for image logging.
        return (output, labels)

    def on_validation_batch_end(
        self, output, batch: tuple, batch_idx: int, dataloader_idx: int = 0
    ):
        """Log the first sample of the first validation batch as W&B images.

        ``output`` is the ``(logits, labels)`` tuple returned by
        ``validation_step``. Nothing is logged unless the attached logger is a
        ``WandbLogger``, since ``log_image`` is specific to it.
        """
        if batch_idx != 0 or not isinstance(self.logger, WandbLogger):
            return
        pred, labs = output
        self.logger.log_image("prediction", [wandb.Image(pred[0])])
        self.logger.log_image("true_label", [wandb.Image(labs[0])])

    def test_step(self, batch: tuple, batch_idx: int):
        output, labels, pred, labs, loss = self._shared_step(batch)

        # Update metrics
        self.test_f1(pred, labs)
        self.test_coverage(output, labels)
        self.test_accuracy(pred, labs)
        self.test_dice(pred, labs)
        self.test_f1_binary(pred, labs)
        self.test_coverage_binary(output, labels)
        self.test_accuracy_binary(pred, labs)
        self.test_dice_binary(pred, labs)
        self.test_jid(pred, labs)
        self.test_jid_binary(pred, labs)
        self.log("test_loss", loss, **self.test_log_opts)
        self.log("test_f1", self.test_f1, **self.test_log_opts)
        self.log("test_f1_binary", self.test_f1_binary, **self.test_log_opts)
        self.log("test_coverage", self.test_coverage, **self.test_log_opts)
        self.log(
            "test_coverage_binary", self.test_coverage_binary, **self.test_log_opts
        )
        self.log("test_accuracy", self.test_accuracy, **self.test_log_opts)
        self.log(
            "test_accuracy_binary", self.test_accuracy_binary, **self.test_log_opts
        )
        self.log("test_dice", self.test_dice, **self.test_log_opts)
        self.log("test_dice_binary", self.test_dice_binary, **self.test_log_opts)
        self.log("test_jid", self.test_jid, **self.test_log_opts)
        self.log("test_jid_binary", self.test_jid_binary, **self.test_log_opts)

    def predict_step(self, batch: tuple, batch_idx: int):
        """Return per-pixel class indices (argmax over dim 1) for ``batch[0]``."""
        return self.net(batch[0]).argmax(1)

    def forward(self, x):
        """Return the raw network logits for ``x``."""
        return self.net(x)

    def configure_optimizers(self):
        optimiser = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
        return optimiser

In [None]:
import os


class ModelEvaluation:
    """Helpers for locating trained model checkpoints."""

    @staticmethod
    def collect_checkpoint_files(root: str = ".") -> list[str]:
        """Return every ``*.ckpt`` file under ``root`` (recursively), sorted.

        ``os.walk`` yields directories and files in a filesystem-dependent
        order. The paths are sorted so that indexing the result (e.g. ``[0]``)
        picks the same checkpoint on every run.

        Args:
            root: Directory to search. Defaults to the current working directory.

        Returns:
            Paths of the form ``os.path.join(dirpath, filename)``, sorted
            lexicographically. The list is empty if no checkpoint exists.
        """
        return sorted(
            os.path.join(dirpath, filename)
            for dirpath, _dirnames, filenames in os.walk(root)
            for filename in filenames
            if filename.endswith(".ckpt")
        )


# Load the first checkpoint returned by collect_checkpoint_files() (searched
# from the current directory); fail with a clear message if there is none.
checkpoint_files = ModelEvaluation.collect_checkpoint_files()
if not checkpoint_files:
    raise FileNotFoundError(
        "No .ckpt checkpoint files found under the current directory"
    )
model = UNet.load_from_checkpoint(checkpoint_files[0])

In [None]:
# transform = Resize((256, 256), InterpolationMode.NEAREST)
# test_dataset = BBBC039Segmentation("../datasets/", subset="test", transform=transform)
# test_loader = DataLoader(test_dataset, shuffle=False, batch_size=8)
# trainer = Trainer()

# for i, checkpoint in enumerate(ModelEvaluation.collect_checkpoint_files()):
#     if i < 5:
#         continue
#     run_title = checkpoint.split(os.sep)[-3]
#     run = api.run(f"joshbercich/BBBC039-Segmentation/{run_title}")
#     model = UNet.load_from_checkpoint(checkpoint)
#     result = trainer.test(model, test_loader)[0]
#     print(run.name, result)
#     for k, v in result.items():
#         run.summary[k] = v
#     run.update()

In [None]:
from ultralytics import SAM

# Load the MobileSAM checkpoint through ultralytics' SAM wrapper.
# Note: this rebinds the notebook-level name `model`.
model = SAM("mobile_sam.pt")

# Predict a segment based on a point prompt (example left commented out; not run)
# model.predict("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])

In [None]:
# Load BBBC039 from ../datasets/ (default subset) and take its first sample.
# x, y are presumably (image, label), matching how UNet's steps unpack batches.
dataset = BBBC039Segmentation("../datasets/")
x, y = dataset[0]

In [None]:
from ultralytics import YOLO

# Load a COCO-pretrained YOLOv8n segmentation model (rebinds `model`).
model = YOLO("yolov8n-seg.pt")
# results = model.train(data='coco8.yaml', epochs=100, imgsz=640)

# Run the model on a single sample with a batch dimension added.
# NOTE(review): this feeds `y` (the second element of the dataset sample,
# presumably the label map) rather than `x` (presumably the image) — confirm
# that is intended.
z = model(y.unsqueeze(0))

In [None]:
from torchvision.transforms import *
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Visualise the label: convert it to an H x W x C uint8 array via PIL, blank
# channel 0, and display the per-pixel sum of the remaining channels.
pil_label = ToPILImage()(y)
label = np.array(pil_label)
label[:, :, 0] = 0
channel_sum = label.sum(axis=2)
plt.imshow(channel_sum)

In [None]:
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2

In [None]:
y.shape  # inspect the shape of the sample's second element (shown as cell output)

In [None]:
# `maskrcnn_resnet50_fpn_v2` is a model *builder* whose parameters are all
# keyword-only, so it cannot be called directly on an image tensor. Build the
# model first, then run it. No `weights=` is given, so no pretrained Mask R-CNN
# weights are loaded.
maskrcnn = maskrcnn_resnet50_fpn_v2()
maskrcnn.eval()  # inference mode: takes a list of images, returns predictions
with torch.no_grad():
    # NOTE(review): assumes `x` is a single (C, H, W) float image — confirm.
    maskrcnn_output = maskrcnn([x])
maskrcnn_output

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import torchmetrics.classification as metrics

import warnings

# Silence every warning for the reporting cells that follow.
warnings.filterwarnings("ignore")

# Default plotting environment: seaborn "whitegrid" style at "paper" scale,
# framed legends, and a faint grid drawn on every axes.
sns.set_style("whitegrid")
sns.set_context("paper")
mpl.rcParams.update(
    {
        "legend.frameon": True,
        "axes.grid": True,
        "grid.alpha": 0.5,
    }
)

In [None]:
# Long-format training-loss table: one row per (epoch, run column). Each run
# name is shortened to its prefix before "_unet" and tagged with the ablation
# ("rand" in the name) and learning rate ("fast" in the name).
res = (
    pd.read_csv("results-loss.csv")
    .melt(id_vars=["epoch"], var_name="Model", value_name="Loss")
    # Drop "__"-containing companion columns (presumably W&B's __MIN/__MAX series).
    .loc[lambda frame: ~frame["Model"].str.contains("__")]
    .assign(
        Model=lambda frame: frame["Model"].map(
            lambda name: name.split("_unet")[0]
        )
    )
    .assign(
        Ablation=lambda frame: frame["Model"].map(
            lambda name: "RandAugment" if "rand" in name else "Baseline"
        ),
        **{
            "Learning Rate": lambda frame: frame["Model"].map(
                lambda name: 0.001 if "fast" in name else 0.0001
            )
        },
    )
)
res

In [None]:
def _metric_pair(frame: pd.DataFrame, metric: str) -> pd.Series:
    """Format ``test_<metric>`` and ``test_<metric>_binary`` as ``"xx.xx/yy.yy"``.

    Both values are scaled to percentages and always rendered with two
    decimals (``85.10`` rather than ``85.1``), so the table columns line up.
    """
    full = frame[f"test_{metric}"] * 100
    binary = frame[f"test_{metric}_binary"] * 100
    return pd.Series(
        [f"{a:.2f}/{b:.2f}" for a, b in zip(full, binary)],
        index=frame.index,
        dtype=object,
    )


res = pd.read_csv("test-summary.csv")
# Keep only the part of the run name before the first underscore.
res["Name"] = res["Name"].apply(lambda x: x.split("_")[0])
for column, metric in (
    ("F1", "f1"),
    ("JID", "jid"),
    ("Dice", "dice"),
    ("Accuracy", "accuracy"),
):
    res[column] = _metric_pair(res, metric)
# Coverage error (test_coverage / test_coverage_binary) is not reported here.

# Nullable integers: runs without RandAugment stay missing, so they render
# as na_rep="-" instead of the literal string "nan", and sorting is numeric.
res["rand_augment_n"] = pd.to_numeric(
    res["rand_augment_n"], errors="coerce"
).astype("Int64")
res = res[
    [
        "Name",
        "init_features",
        "learning_rate",
        "rand_augment_n",
        "rand_augment_m",
        "Accuracy",
        "Dice",
        "JID",
        "F1",
    ]
]

table = res.sort_values(
    [
        "rand_augment_n",
        "rand_augment_m",
        "init_features",
        "learning_rate",
    ]
)
# float_format="{:0.2f}" would print both 1e-4 and 1e-3 as "0.00", so the
# learning rates are pre-formatted as strings (after sorting numerically).
table = table.assign(learning_rate=table["learning_rate"].map("{:g}".format))
print(table.to_latex(index=False, na_rep="-", float_format="{:0.2f}".format))