In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch

from data.config import (
    MODELS_DIR, 
    DATA_ADAPTATION_DIR, 
    TRANSFORM_MODE, 
    TRAIN_VALID_SPLIT, 
    SEED, 
    IMG_GRADIENT, 
    IMG_MODE
)
from dataset import CoarseMaskDataset
from u_net import Coarse2FineTiny
from u_net_res_attention import Coarse2FineUNetAttention
from u_net_residual import Coarse2FineTinyRes
from torch.utils.data import random_split

In [None]:
# Run directories that currently exist under MODELS_DIR.
models = os.listdir(MODELS_DIR)

# One row per model to compare:
#   (run directory, model class, display name, training losses, `features` constructor argument)
# Keeping each model on a single row prevents the parallel lists below from drifting apart.
MODEL_SPECS = [
    # FIXME(review): same directory as the next row, but loaded with the plain Coarse2FineTiny class.
    # By analogy with the other runs this was presumably meant to be the plain U-Net 16-32 run
    # (e.g. 'unet16-32_bce-dice_18052025') -- confirm the directory name.
    ("unetres16-32_bce-dice_18052025", Coarse2FineTiny, "U-Net 16-32 BCE-Dice",
     ["bce", "dice"], [16, 32]),
    ("unetres16-32_bce-dice_18052025", Coarse2FineTinyRes, "U-Net Res 16-32 BCE-Dice",
     ["bce", "dice"], [16, 32]),
    ("u_net_16_128_bce_dice", Coarse2FineTiny, "U-Net 16-128 BCE-Dice",
     ["bce", "dice"], [16, 32, 64, 128]),
    ("u_net_res_16_128_bce_dice", Coarse2FineTinyRes, "U-Net Res 16-128 BCE-Dice",
     ["bce", "dice"], [16, 32, 64, 128]),
    ("unet16-128_bce-dice-bound_19052025", Coarse2FineTiny, "U-Net 16-128 BCE-Dice+Bound",
     ["bce", "dice", "boundary"], [16, 32, 64, 128]),
    ("unetres16-128_bce-dice-bound_19052025", Coarse2FineTinyRes, "U-Net Res 16-128 BCE-Dice+Bound",
     ["bce", "dice", "boundary"], [16, 32, 64, 128]),
    ("unetres32-256_bce-dice-bound_21052025", Coarse2FineTinyRes, "U-Net Res 32-256 BCE-Dice+Bound",
     ["bce", "dice", "boundary"], [32, 64, 128, 256]),
]

# Parallel per-field lists, consumed by the evaluation loop.
model_dirs, model_classes, model_names, losses, features = map(list, zip(*MODEL_SPECS))

# Fail early with a clear message instead of deep inside checkpoint loading.
missing_dirs = sorted(set(model_dirs) - set(models))
assert not missing_dirs, f"Run directories not found in {MODELS_DIR}: {missing_dirs}"

In [None]:
def plot_result(dataset, model, num_samples=5, threshold=0.5, common_title="Model Predictions", column_titles=None, idx_img=None):
    """
    Plots a grid of model predictions with a common title and column-specific titles.

    Row ``i`` shows sample ``dataset[idx_img + i]``: channel 0 of the input (image),
    channel 1 of the input (coarse mask), channel 0 of the target (ground truth) and
    the thresholded prediction. The model is switched to eval mode for inference and
    its previous train/eval mode is restored afterwards.

    Args:
        dataset: PyTorch dataset returning (input, target) tuples of tensors.
        model: Trained PyTorch model returning logits for a batched input.
        num_samples (int): Maximum number of samples (rows) to show; clipped so no
            index past the end of ``dataset`` is requested.
        threshold (float): Probability threshold (applied after a sigmoid) for the
            binary predicted mask.
        common_title (str): Title for the whole figure.
        column_titles (list[str]): Titles for each column (4 expected).
        idx_img (int | None): Index of the first sample to show; defaults to 89.

    Raises:
        ValueError: If no sample can be shown (``num_samples`` < 1 or ``idx_img``
            at/after the end of ``dataset``).
    """
    if idx_img is None:
        idx_img = 89  # FIXME: arbitrary hard-coded default start index

    # Only as many rows as there are samples left after the start offset.
    n_rows = min(num_samples, len(dataset) - idx_img)
    if n_rows < 1:
        raise ValueError(
            f"No samples to plot: num_samples={num_samples}, idx_img={idx_img}, len(dataset)={len(dataset)}"
        )
    n_cols = 4  # input image, input mask, target mask, predicted mask

    if column_titles is None:
        column_titles = [
            "Input image",
            "Input coarse mask",
            "Ground truth mask",
            "Predicted mask",
        ]

    device = next(model.parameters()).device  # inputs must live on the model's device

    # Predictions need inference behaviour for BatchNorm/Dropout; restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        panels = []
        for sample_idx in range(idx_img, idx_img + n_rows):
            item, target = dataset[sample_idx]
            item = item.to(device)

            with torch.no_grad():
                predicted_probs = torch.sigmoid(model(item.unsqueeze(0)))
            predicted_mask = (predicted_probs > threshold).float()

            panels.append([
                item[0].cpu(),                # input image
                item[1].cpu(),                # input coarse mask
                target[0].cpu(),              # ground truth mask
                predicted_mask[0][0].cpu(),   # predicted mask (first batch item, first channel)
            ])
    finally:
        model.train(was_training)

    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 3), squeeze=False)
    fig.suptitle(common_title, fontsize=16)

    for i, images in enumerate(panels):
        for j, (ax, image) in enumerate(zip(axes[i], images)):
            ax.imshow(image, cmap="gray")
            ax.axis("off")
            if i == 0:
                ax.set_title(column_titles[j], fontsize=12)

    fig.tight_layout(rect=(0, 0, 1, 0.95))  # leave space at the top for the suptitle
    plt.show()


def evaluate_checkpoint(ckpt_dir: str, extra_metrics: tuple = ("val_iou",)) -> None:
    """Plot the training curves logged in ``<ckpt_dir>/metrics.csv``.

    Draws train vs. validation loss against ``step``. It then draws one figure per
    column in ``extra_metrics`` that exists in the CSV. Train and validation rows are
    selected by which loss column is filled in. This presumably matches a logger that
    writes train and validation metrics on separate rows; confirm for the logger used.

    Args:
        ckpt_dir: Directory containing ``metrics.csv`` (any path ``os.path.join`` accepts).
        extra_metrics: Validation metric columns to plot individually against ``step``.
            Columns missing from the CSV are skipped instead of raising ``KeyError``.
    """
    metrics_csv: str = os.path.join(ckpt_dir, "metrics.csv")
    df: pd.DataFrame = pd.read_csv(metrics_csv)

    train_df = df[df["train_loss"].notnull()]
    val_df = df[df["val_loss"].notnull()]

    # Train and validation loss on a shared axis.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(train_df["step"], train_df["train_loss"], label="Train Loss", marker="o")
    ax.plot(val_df["step"], val_df["val_loss"], label="Validation Loss", marker="s")
    ax.set(xlabel="Step", ylabel="Loss", title="Train and Validation Loss over Steps")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()

    # One explicit figure per additional validation metric that was actually logged.
    metric_labels = {"val_iou": "Validation Intersection over Union"}
    for metric in extra_metrics:
        if metric not in val_df.columns:
            continue
        label = metric_labels.get(metric, metric)
        fig, ax = plt.subplots()
        ax.plot(val_df["step"], val_df[metric])
        ax.set(xlabel="Step", ylabel=label, title=f"{label} over Steps")
        fig.tight_layout()
        plt.show()

In [None]:
# Rebuild the dataset and the train/validation split: same dataset options, same split
# sizes and the same SEED-seeded generator. This presumably reproduces the split used
# at training time, so `val_dataset` would hold held-out samples -- confirm against the
# training script.
full_dataset = CoarseMaskDataset(
    DATA_ADAPTATION_DIR,
    transform_type=TRANSFORM_MODE,
    image_gradient=IMG_GRADIENT,
    mode=IMG_MODE,
)

total_len = len(full_dataset)
val_len = int(total_len * TRAIN_VALID_SPLIT)  # TRAIN_VALID_SPLIT is used here as the validation fraction
train_len = total_len - val_len

split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = random_split(full_dataset, [train_len, val_len], generator=split_generator)

In [None]:
# Qualitative comparison: show the same validation samples for every model.
SAMPLE_START_IDX = 8          # first validation index to show (idx_img=5 is also a nice example)
NUM_SAMPLES = 2               # rows per figure
PRED_THRESHOLD = 0.4          # probability threshold for the predicted mask
SHOW_TRAINING_CURVES = False  # set True to also plot each run's metrics.csv curves

for model_dir, model_class, model_name, loss, feat in zip(model_dirs, model_classes, model_names, losses, features):

    model_path = MODELS_DIR / model_dir / "checkpoints" / "best-checkpoint.ckpt"

    # NOTE(review): assumes the model classes are LightningModules, whose load_from_checkpoint
    # forwards these keyword arguments to the constructor (overriding stored hyperparameters).
    model = model_class.load_from_checkpoint(
        model_path,
        losses=loss,
        features=feat,
        loss_weights=[.4, .4, .2] if len(loss) == 3 else [.5, .5],
    )
    model.eval()  # inference behaviour for BatchNorm/Dropout before predicting

    plot_result(
        val_dataset,
        model,
        num_samples=NUM_SAMPLES,
        common_title=f"{model_name} Predictions",
        idx_img=SAMPLE_START_IDX,
        threshold=PRED_THRESHOLD,
    )
    if SHOW_TRAINING_CURVES:
        evaluate_checkpoint(MODELS_DIR / model_dir)