# Notebook 22 – Unified Model Evaluation (original vs masked validation)

In this notebook we:
- Load all trained ovary segmentation models (baseline, attention, RAovSeg, Focal Tversky, TL-5, masked model).
- Evaluate each model on the **same** validation set (positive ovary slices only).
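- Evaluate each model again on **ovary-side masked** validation images, where the image half opposite the ground-truth ovary is zeroed (this setting therefore uses label information and is an oracle-style check, not a deployable test condition).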
- Summarise results in a table: Val Dice (orig), Val Dice (masked), and ΔDice = masked − orig.

This is a sanity check and will give us the final numbers for Table 1 in the report.


In [19]:
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd

# Make the project root importable so that `src.*` resolves.
# Walk up from the working directory to the first folder containing a `src/`
# directory, so the notebook works whether Jupyter was started in notebooks/ or
# in the repository root. If none is found, fall back to the parent of the
# working directory (the original "this notebook lives in notebooks/" assumption).
_cwd = Path.cwd()
project_root = next(
    (candidate for candidate in (_cwd, *_cwd.parents) if (candidate / "src").is_dir()),
    _cwd.parent,
)
if str(project_root) not in sys.path:
    # Prepend so the project's `src` package wins over any same-named installed module.
    sys.path.insert(0, str(project_root))

print("Project root:", project_root)

from src.data_loader import UterusDataset, UterusDatasetWithPreprocessing
from src.models import UNet, AttentionUNet, DoubleConv, AttentionGate


Project root: c:\Users\lytten\programming\dlvr-project


In [20]:
# Prefer the GPU when PyTorch can see one; every model and batch is moved here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

def dice_score(preds, targets, epsilon=1e-6):
    """
    Compute the Dice coefficient between binary predictions and binary targets.

    All elements of the inputs are pooled into a single score: for a batch this
    is one Dice over every pixel of every sample, not a mean of per-sample Dice
    values. Call it per sample if a per-slice score is wanted.

    Args:
        preds: tensor of {0, 1} predictions, e.g. (B, 1, H, W).
        targets: tensor of {0, 1} ground truth with the same number of elements.
        epsilon: smoothing term added to numerator and denominator; it also makes
            the score 1.0 (instead of 0/0) when both inputs are empty.

    Returns:
        0-dim tensor holding the Dice score.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed/permuted
    # tensors, are accepted instead of raising.
    preds_flat = preds.reshape(-1)
    targets_flat = targets.reshape(-1)
    intersection = (preds_flat * targets_flat).sum()
    return (2.0 * intersection + epsilon) / (preds_flat.sum() + targets_flat.sum() + epsilon)

def evaluate_model(model, dataloader, threshold=0.5):
    """
    Mean per-slice Dice of a binary segmentation model over a dataloader.

    Each sample in a batch is scored separately with dice_score and the scores
    are averaged over all samples, so the result does not depend on the loader's
    batch size. (Scoring a whole batch with a single dice_score call pools pixels
    across samples and only equals the per-slice mean when batch_size == 1.)

    Args:
        model: network whose output is treated as logits (sigmoid + threshold).
            Switched to eval mode and run under torch.no_grad().
        dataloader: yields (images, masks) batches; both are moved to the
            module-level ``device``.
        threshold: probability cut-off applied after the sigmoid.

    Returns:
        Mean Dice as a Python float; 0.0 if the loader yields no samples.
    """
    model.eval()
    dice_sum = 0.0
    n = 0
    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)

            probs = torch.sigmoid(model(images))
            preds = (probs > threshold).float()

            # Per-sample Dice so that every slice carries equal weight.
            for pred, mask in zip(preds, masks):
                dice_sum += dice_score(pred, mask).item()
            n += images.size(0)

    return dice_sum / max(n, 1)


Using device: cuda


In [21]:
from torch.utils.data import Dataset

class OvarySideMaskedDataset(Dataset):
    """
    Dataset wrapper that blanks the image half lying opposite the annotated ovary.

    ``base_dataset`` is expected to return ``(image, mask)`` tensors shaped
    (1, H, W) (e.g. UterusDataset or UterusDatasetWithPreprocessing). For each
    sample, the mean column index of the positive mask pixels decides the side:
    if it is left of W / 2, columns W // 2 onward are set to 0.0; otherwise the
    columns before W // 2 are. Samples whose mask does not sum to a positive
    value are returned with an unmodified copy of the image. The mask itself is
    returned as-is.

    The side is chosen from the ground-truth mask, so evaluation through this
    wrapper uses label information. It mimics the masking used during training
    of Model 7.

    NOTE(review): 0.0 is used as the fill value for every preprocessing
    pipeline; confirm it matches the background level after RAovSeg preprocessing.
    """

    def __init__(self, base_dataset):
        self.base = base_dataset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, mask = self.base[idx]
        # Edit a copy so the base dataset's tensor is never modified in place.
        masked = image.clone()

        plane = mask[0].numpy()  # (H, W); needs a CPU tensor
        if not (plane.sum() > 0):
            return masked, mask

        _, width = plane.shape
        split = width // 2
        # Column (x) indices of annotated pixels -> horizontal centroid.
        centroid_col = np.nonzero(plane > 0)[1].mean()
        # Ovary left of centre -> blank the right half, otherwise the left half.
        blank = slice(split, None) if centroid_col < width / 2 else slice(None, split)
        masked[:, :, blank] = 0.0
        return masked, mask


In [22]:
manifest_path = project_root / "data" / "d2_manifest_t2fs_ovary_eligible.csv"
print("Manifest path:", manifest_path)

# NOTE(review): no train/val split argument is passed to either dataset, so they
# presumably index every slice listed in this manifest. Confirm the manifest (or
# the dataset classes) restrict it to held-out validation cases; otherwise these
# "val" numbers include training slices.

# Plain preprocessing (min-max normalization)
val_dataset_plain = UterusDataset(
    manifest_path=str(manifest_path),
    image_size=256,
    augment=False
)

# RAovSeg-style preprocessing
val_dataset_raovseg = UterusDatasetWithPreprocessing(
    manifest_path=str(manifest_path),
    image_size=256,
    augment=False
)

# Both pipelines must score the same slices, otherwise the plain- and
# RAovSeg-preprocessed rows of Table 1 are not comparable.
# NOTE(review): the dataset logs report "slices containing the uterus" for the
# plain dataset but "... the ovary" for the RAovSeg one. Equal counts do not prove
# both return ovary masks; confirm in src/data_loader.py.
assert len(val_dataset_plain) == len(val_dataset_raovseg), (
    f"Validation sets differ: plain={len(val_dataset_plain)} slices, "
    f"raovseg={len(val_dataset_raovseg)} slices"
)

batch_size = 1


def make_val_loader(dataset):
    """Deterministic (unshuffled) evaluation DataLoader over ``dataset``."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)


val_loader_plain = make_val_loader(val_dataset_plain)
val_loader_plain_masked = make_val_loader(OvarySideMaskedDataset(val_dataset_plain))

val_loader_raovseg = make_val_loader(val_dataset_raovseg)
val_loader_raovseg_masked = make_val_loader(OvarySideMaskedDataset(val_dataset_raovseg))

print("Plain val slices:", len(val_dataset_plain))
print("RAovSeg val slices:", len(val_dataset_raovseg))


Manifest path: c:\Users\lytten\programming\dlvr-project\data\d2_manifest_t2fs_ovary_eligible.csv
Loading manifest from c:\Users\lytten\programming\dlvr-project\data\d2_manifest_t2fs_ovary_eligible.csv and creating slice map...
Slice map created. Found 278 slices containing the uterus.
Loading manifest from c:\Users\lytten\programming\dlvr-project\data\d2_manifest_t2fs_ovary_eligible.csv and creating slice map...
Slice map created. Found 278 slices containing the ovary.
Plain val slices: 278
RAovSeg val slices: 278


In [23]:
# --- Transfer-learning Attention U-Net with ResNet34 encoder ---

import torchvision
from torchvision.models import resnet34, ResNet34_Weights

class TLAttentionUNetResNet34(nn.Module):
    """
    Attention U-Net style decoder with a ResNet34 encoder.
    - Encoder: torchvision ResNet34; ImageNet weights when use_pretrained=True
      (and loading them succeeds), otherwise random initialisation.
    - Decoder: 4 upsampling stages with attention gates and DoubleConv blocks.
    - Input: (B, 1, H, W); the channel is repeated to 3 for ResNet. No ImageNet
      mean/std normalisation is applied inside the model.
    - Output: logits (B, n_classes, H, W). The decoder ends at H/2 x W/2 and the
      result is bilinearly resized to the input resolution when it differs.

    Args:
        n_classes: number of output channels of the final 1x1 conv.
        in_channels: when 1, single-channel inputs are repeated to 3 channels.
        use_pretrained: try to load ImageNet weights for the encoder.

    NOTE: attribute names define the state_dict keys. ``self.encoder`` keeps the
    whole torchvision model (its avgpool/fc are never used in forward), and
    conv1/bn1/relu/maxpool/layer1-4 are the same module objects registered again
    under top-level names. Renaming or removing any attribute breaks loading
    checkpoints saved from this class.
    NOTE(review): H and W presumably need to be multiples of 32 so each upsampled
    decoder map matches its skip connection (256 is) — confirm.
    """
    def __init__(self, n_classes: int = 1, in_channels: int = 1, use_pretrained: bool = True) -> None:
        super().__init__()
        self.in_channels = in_channels

        # Try to load pretrained weights; fall back to random init if download fails.
        # The except is broad: any exception here prints a warning and the
        # encoder continues with random weights.
        if use_pretrained:
            try:
                encoder = resnet34(weights=ResNet34_Weights.DEFAULT)
                print("Loaded ResNet34 with ImageNet pretrained weights.")
            except Exception as e:
                print(
                    "WARNING: Could not load pretrained ResNet34 weights. "
                    "Falling back to randomly initialized encoder.\nError:", e
                )
                encoder = resnet34(weights=None)
        else:
            encoder = resnet34(weights=None)
            print("Using ResNet34 without pretrained weights.")

        # Full torchvision model stays registered; freeze_encoder /
        # unfreeze_encoder toggle gradients through this attribute.
        self.encoder = encoder

        # Encoder parts (aliases of the modules inside self.encoder)
        self.conv1 = encoder.conv1
        self.bn1 = encoder.bn1
        self.relu = encoder.relu
        self.maxpool = encoder.maxpool

        self.layer1 = encoder.layer1  # output: 64 ch, /4
        self.layer2 = encoder.layer2  # 128 ch, /8
        self.layer3 = encoder.layer3  # 256 ch, /16
        self.layer4 = encoder.layer4  # 512 ch, /32

        # Decoder with attention gates (mirror the last 4 scales)
        # shapes: x1 (64, H/4), x2 (128, H/8), x3 (256, H/16), x4 (512, H/32)
        # Each stage: transposed conv doubles resolution, the attention gate
        # weights the matching skip feature, DoubleConv fuses the concatenation
        # (its input channels = upsampled channels + skip channels).

        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att1 = AttentionGate(F_g=256, F_l=256, F_int=128)
        self.conv1_up = DoubleConv(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att2 = AttentionGate(F_g=128, F_l=128, F_int=64)
        self.conv2_up = DoubleConv(256, 128)

        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att3 = AttentionGate(F_g=64, F_l=64, F_int=32)
        self.conv3_up = DoubleConv(128, 64)

        # Last stage reuses the stem output x0 (64 ch, H/2) as its skip.
        self.up4 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.att4 = AttentionGate(F_g=64, F_l=64, F_int=32)
        self.conv4_up = DoubleConv(128, 64)

        # 1x1 head producing n_classes logit channels.
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits at the input's spatial size for an input batch ``x``."""
        # Remember input spatial size (e.g. 256x256)
        input_size = x.size()[2:]

        # Repeat channel to 3 for ResNet (if input is 1-channel)
        if self.in_channels == 1 and x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)

        # Encoder
        x0 = self.conv1(x)            # -> (64, H/2, W/2)
        x0 = self.bn1(x0)
        x0 = self.relu(x0)            # we'll use this as the highest-res skip connection

        x1 = self.maxpool(x0)         # -> (64, H/4, W/4)
        x1 = self.layer1(x1)          # -> (64, H/4, W/4)
        x2 = self.layer2(x1)          # -> (128, H/8, W/8)
        x3 = self.layer3(x2)          # -> (256, H/16, W/16)
        x4 = self.layer4(x3)          # -> (512, H/32, W/32)

        # Decoder with attention
        d4 = self.up1(x4)             # (256, H/16, W/16)
        x3_att = self.att1(g=d4, x=x3)
        d4 = self.conv1_up(torch.cat([d4, x3_att], dim=1))

        d3 = self.up2(d4)             # (128, H/8, W/8)
        x2_att = self.att2(g=d3, x=x2)
        d3 = self.conv2_up(torch.cat([d3, x2_att], dim=1))

        d2 = self.up3(d3)             # (64, H/4, W/4)
        x1_att = self.att3(g=d2, x=x1)
        d2 = self.conv3_up(torch.cat([d2, x1_att], dim=1))

        d1 = self.up4(d2)             # (64, H/2, W/2)
        x0_att = self.att4(g=d1, x=x0)
        d1 = self.conv4_up(torch.cat([d1, x0_att], dim=1))  # (64, H/2, W/2)

        out = self.outc(d1)

        # Upsample back to original input resolution (256x256); the decoder
        # stops at half resolution, so this normally runs.
        if out.shape[2:] != input_size:
            out = F.interpolate(out, size=input_size, mode="bilinear", align_corners=False)

        return out


def freeze_encoder(model: TLAttentionUNetResNet34):
    """Turn off gradients for every parameter reachable through ``model.encoder``.

    The model's conv1/bn1/layer1-4 attributes are the same modules as inside the
    encoder, so they are frozen as well; other submodules are left untouched.
    """
    model.encoder.requires_grad_(False)
    print("Encoder frozen (no gradient).")


def unfreeze_encoder(model: TLAttentionUNetResNet34):
    """Turn gradients back on for every parameter reachable through ``model.encoder``.

    Counterpart of freeze_encoder; other submodules are left untouched.
    """
    model.encoder.requires_grad_(True)
    print("Encoder unfrozen (trainable).")


In [24]:
models_dir = project_root / "models"
print("Models dir:", models_dir)

include_tl5 = True  # set to False if you don't want to deal with the TL architecture here

# One entry per trained model. Keys:
#   name    - label used in printouts and the results table
#   short   - stable id used to order the table rows
#   ckpt    - checkpoint filename inside models_dir
#   arch    - architecture key understood by create_model
#   preproc - "plain" or "raovseg": selects which validation loaders are used
MODEL_CONFIGS = [
    dict(
        name="Model 1: U-Net (baseline)",
        short="M1_baseline",
        ckpt="07_ovary_baseline_best.pth",
        arch="unet",
        preproc="plain",
    ),
    dict(
        name="Model 2: Attention U-Net",
        short="M2_attn",
        ckpt="09_attention_unet_best.pth",
        arch="attn",
        preproc="plain",
    ),
    dict(
        name="Model 3: Attn U-Net + RAovSeg (20 ep)",
        short="M3_attn_raovseg",
        ckpt="11_attention_unet_preprocessed_best.pth",
        arch="attn",
        preproc="raovseg",
    ),
    dict(
        name="Model 4: Attn U-Net + RAovSeg (50 ep)",
        short="M4_attn_raovseg_long",
        ckpt="13_attn_unet_prep_long_best.pth",
        arch="attn",
        preproc="raovseg",
    ),
    dict(
        name="Model 5: Attn U-Net + RAovSeg + FTL",
        short="M5_attn_raovseg_ftl",
        ckpt="15_attn_unet_focal_tversky_best.pth",
        arch="attn",
        preproc="raovseg",
    ),
    # Optional transfer-learning model, spliced in only when include_tl5 is set.
    *(
        [
            dict(
                name="TL-5: ResNet34 Attn U-Net + RAovSeg + FTL",
                short="TL5_resnet34_attn",
                ckpt="20_tl_attn_unet_resnet34_ft_best.pth",
                arch="resnet34_attn",
                preproc="raovseg",
            )
        ]
        if include_tl5
        else []
    ),
    dict(
        name="Model 7: Attn U-Net + RAovSeg + FTL (masked train)",
        short="M7_attn_raovseg_ftl_masked",
        ckpt="21_attn_unet_prep_ft_masked_best.pth",
        arch="attn",
        preproc="raovseg",
    ),
]

def create_model(cfg):
    """
    Build the architecture named by ``cfg["arch"]`` and load its trained weights.

    Args:
        cfg: one MODEL_CONFIGS entry; reads the keys "arch", "ckpt" and "name".

    Returns:
        The model with checkpoint weights loaded, moved to the module-level
        ``device`` and switched to eval mode (it is only used for inference here).

    Raises:
        ValueError: ``cfg["arch"]`` is not a known architecture.
        FileNotFoundError: the checkpoint file is missing from ``models_dir``.
    """
    arch = cfg["arch"]
    if arch == "unet":
        model = UNet(n_channels=1, n_classes=1)
    elif arch == "attn":
        model = AttentionUNet(n_channels=1, n_classes=1)
    elif arch == "resnet34_attn":
        # use_pretrained=False because the checkpoint already has the trained weights
        model = TLAttentionUNetResNet34(n_classes=1, in_channels=1, use_pretrained=False)
    else:
        raise ValueError(f"Unknown arch: {arch}")

    ckpt_path = models_dir / cfg["ckpt"]
    print(f"  Loading {cfg['name']} from {ckpt_path}")
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Checkpoint for {cfg['name']} not found: {ckpt_path}")
    # The checkpoint is passed straight to load_state_dict, so it only needs
    # tensors: weights_only=True avoids unpickling arbitrary objects and the
    # FutureWarning newer PyTorch versions emit for the implicit default.
    state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model



Models dir: c:\Users\lytten\programming\dlvr-project\models


In [25]:
# Each preprocessing pipeline maps to its (original, ovary-side masked) loaders.
loaders_by_preproc = {
    "plain": (val_loader_plain, val_loader_plain_masked),
    "raovseg": (val_loader_raovseg, val_loader_raovseg_masked),
}

results = []
for cfg in MODEL_CONFIGS:
    print("=" * 80)
    print(cfg["name"])
    model = create_model(cfg)

    preproc = cfg["preproc"]
    if preproc not in loaders_by_preproc:
        raise ValueError(f"Unknown preproc: {cfg['preproc']}")
    dl_orig, dl_masked = loaders_by_preproc[preproc]

    # Same model, same slices: only the input masking differs between the two runs.
    dice_orig = evaluate_model(model, dl_orig)
    dice_masked = evaluate_model(model, dl_masked)
    delta = dice_masked - dice_orig

    print(f"  Val Dice (orig):   {dice_orig:.4f}")
    print(f"  Val Dice (masked): {dice_masked:.4f}")
    print(f"  ΔDice (masked-orig): {delta:+.4f}")

    row = {
        "Model": cfg["name"],
        "Short": cfg["short"],
        "Preproc": cfg["preproc"],
        "Val Dice (orig)": dice_orig,
        "Val Dice (masked)": dice_masked,
        "ΔDice (masked - orig)": delta,
    }
    results.append(row)

results_df = pd.DataFrame(results)
results_df


Model 1: U-Net (baseline)
  Loading Model 1: U-Net (baseline) from c:\Users\lytten\programming\dlvr-project\models\07_ovary_baseline_best.pth


  state_dict = torch.load(ckpt_path, map_location=device)


  Val Dice (orig):   0.2755
  Val Dice (masked): 0.3916
  ΔDice (masked-orig): +0.1162
Model 2: Attention U-Net
  Loading Model 2: Attention U-Net from c:\Users\lytten\programming\dlvr-project\models\09_attention_unet_best.pth
  Val Dice (orig):   0.2470
  Val Dice (masked): 0.3805
  ΔDice (masked-orig): +0.1335
Model 3: Attn U-Net + RAovSeg (20 ep)
  Loading Model 3: Attn U-Net + RAovSeg (20 ep) from c:\Users\lytten\programming\dlvr-project\models\11_attention_unet_preprocessed_best.pth
  Val Dice (orig):   0.3061
  Val Dice (masked): 0.3481
  ΔDice (masked-orig): +0.0420
Model 4: Attn U-Net + RAovSeg (50 ep)
  Loading Model 4: Attn U-Net + RAovSeg (50 ep) from c:\Users\lytten\programming\dlvr-project\models\13_attn_unet_prep_long_best.pth
  Val Dice (orig):   0.2644
  Val Dice (masked): 0.2954
  ΔDice (masked-orig): +0.0309
Model 5: Attn U-Net + RAovSeg + FTL
  Loading Model 5: Attn U-Net + RAovSeg + FTL from c:\Users\lytten\programming\dlvr-project\models\15_attn_unet_focal_tversky_

Unnamed: 0,Model,Short,Preproc,Val Dice (orig),Val Dice (masked),ΔDice (masked - orig)
0,Model 1: U-Net (baseline),M1_baseline,plain,0.275465,0.391648,0.116183
1,Model 2: Attention U-Net,M2_attn,plain,0.247044,0.380542,0.133497
2,Model 3: Attn U-Net + RAovSeg (20 ep),M3_attn_raovseg,raovseg,0.306095,0.3481,0.042005
3,Model 4: Attn U-Net + RAovSeg (50 ep),M4_attn_raovseg_long,raovseg,0.264445,0.295393,0.030947
4,Model 5: Attn U-Net + RAovSeg + FTL,M5_attn_raovseg_ftl,raovseg,0.31283,0.343682,0.030853
5,TL-5: ResNet34 Attn U-Net + RAovSeg + FTL,TL5_resnet34_attn,raovseg,0.364519,0.420374,0.055855
6,Model 7: Attn U-Net + RAovSeg + FTL (masked tr...,M7_attn_raovseg_ftl_masked,raovseg,0.26426,0.353516,0.089256


In [26]:
# Sort roughly by "pipeline complexity"
order = [
    "M1_baseline",
    "M2_attn",
    "M3_attn_raovseg",
    "M4_attn_raovseg_long",
    "M5_attn_raovseg_ftl",
    "TL5_resnet34_attn" if include_tl5 else None,
    "M7_attn_raovseg_ftl_masked",
]
order = [s for s in order if s is not None]

# `.loc[order]` keeps only the listed rows, so a model added to MODEL_CONFIGS but
# forgotten here would silently vanish from Table 1. Require an exact match.
assert sorted(order) == sorted(results_df["Short"]), (
    f"Row order {order} does not match evaluated models {results_df['Short'].tolist()}"
)

results_df_sorted = results_df.set_index("Short").loc[order].reset_index()

# Round for nice printing
display_cols = ["Model", "Val Dice (orig)", "Val Dice (masked)", "ΔDice (masked - orig)"]
print(results_df_sorted[display_cols].round(4))


def latex_escape(text):
    """Escape LaTeX special characters so ``text`` is safe inside a table cell."""
    specials = {
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    }
    # Character-wise mapping, so already-inserted escapes are never re-escaped.
    return "".join(specials.get(ch, ch) for ch in str(text))


# If you want LaTeX:
print("\nLaTeX table rows:")
for _, row in results_df_sorted.iterrows():
    model_name = latex_escape(row["Model"])
    d_orig = row["Val Dice (orig)"]
    d_mask = row["Val Dice (masked)"]
    delta = row["ΔDice (masked - orig)"]
    print(f"{model_name} & {d_orig:.4f} & {d_mask:.4f} & {delta:+.4f} \\\\")


                                               Model  Val Dice (orig)  \
0                          Model 1: U-Net (baseline)           0.2755   
1                           Model 2: Attention U-Net           0.2470   
2              Model 3: Attn U-Net + RAovSeg (20 ep)           0.3061   
3              Model 4: Attn U-Net + RAovSeg (50 ep)           0.2644   
4                Model 5: Attn U-Net + RAovSeg + FTL           0.3128   
5          TL-5: ResNet34 Attn U-Net + RAovSeg + FTL           0.3645   
6  Model 7: Attn U-Net + RAovSeg + FTL (masked tr...           0.2643   

   Val Dice (masked)  ΔDice (masked - orig)  
0             0.3916                 0.1162  
1             0.3805                 0.1335  
2             0.3481                 0.0420  
3             0.2954                 0.0309  
4             0.3437                 0.0309  
5             0.4204                 0.0559  
6             0.3535                 0.0893  

LaTeX table rows:
Model 1: U-Net (baseline) & 