In [3]:
# %% [markdown]
# # Notebook 25: Masked RAovSeg + Attention U-Net + Focal Tversky
#
# New experimental setup:
# - Use **ovary-side masked** T2FS volumes (non-ovary half blacked out).
# - RAovSeg-style preprocessing and augmentation (from `UterusDatasetWithPreprocessing`).
# - Attention U-Net segmentation model.
# - Focal Tversky loss (alpha=0.7, beta=0.3, gamma=4/3).
# - Train/val/test splits are patient-based, from
#   `d2_manifest_t2fs_ov_final_with_split_and_masked.csv`.
#
# This notebook's goal: get a fully working training/eval loop on the new masked
# dataset. Transfer learning (ResNet34 encoder) will be added once this pipeline
# is stable.


In [4]:
# %%
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Make sure we can import from src/
project_root = Path("..").resolve()
if (project_root / "src").exists() and str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from src.data_loader import UterusDatasetWithPreprocessing
from src.models import AttentionUNet
from src.losses import FocalTverskyLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Project root:", project_root)
print("Device:", device)


Project root: C:\Users\lytten\programming\dlvr-project
Device: cuda


In [5]:
# %% [markdown]
# ## Build RAovSeg manifests from the final masked manifest

# %%
data_dir = project_root / "data"
final_masked_manifest_path = data_dir / "d2_manifest_t2fs_ov_final_with_split_and_masked.csv"

final_df = pd.read_csv(final_masked_manifest_path)
print("Loaded final masked manifest:", final_masked_manifest_path)
print("Columns:", final_df.columns.tolist())
print("Num patients:", len(final_df))
print(final_df["split"].value_counts())

# We want RAovSeg-style manifests with columns: patient_id, mri_path, mask_path
def make_raovseg_manifest(df_split, out_path):
    df_local = df_split.copy()
    # Use masked T2FS as the MRI image
    df_local["mri_path"] = df_local["t2fs_masked_path"]
    df_local["mask_path"] = df_local["ov_mask_path"]

    # Keep only needed columns for the dataset class
    out_df = df_local[["patient_id", "mri_path", "mask_path"]].copy()

    # Ensure paths are relative to project_root (if they aren't already)
    def normalize_path(p):
        p = Path(p)
        if p.is_absolute():
            try:
                return str(p.relative_to(project_root))
            except ValueError:
                return str(p)  # fallback
        else:
            return str(p)

    out_df["mri_path"] = out_df["mri_path"].apply(normalize_path)
    out_df["mask_path"] = out_df["mask_path"].apply(normalize_path)

    out_df.to_csv(out_path, index=False)
    print(f"Saved RAovSeg manifest: {out_path} (n={len(out_df)})")

# Split by patient
train_df = final_df[final_df["split"] == "train"]
val_df   = final_df[final_df["split"] == "val"]
test_df  = final_df[final_df["split"] == "test"]

print("Train patients:", len(train_df))
print("Val patients:",   len(val_df))
print("Test patients:",  len(test_df))

# Output paths
train_manifest_raovseg = data_dir / "d2_manifest_masked_raovseg_train.csv"
val_manifest_raovseg   = data_dir / "d2_manifest_masked_raovseg_val.csv"
test_manifest_raovseg  = data_dir / "d2_manifest_masked_raovseg_test.csv"

make_raovseg_manifest(train_df, train_manifest_raovseg)
make_raovseg_manifest(val_df,   val_manifest_raovseg)
make_raovseg_manifest(test_df,  test_manifest_raovseg)


Loaded final masked manifest: C:\Users\lytten\programming\dlvr-project\data\d2_manifest_t2fs_ov_final_with_split_and_masked.csv
Columns: ['patient_id', 't2fs_path', 'ov_mask_path', 'split', 't2fs_masked_path']
Num patients: 37
split
train    26
val       6
test      5
Name: count, dtype: int64
Train patients: 26
Val patients: 6
Test patients: 5
Saved RAovSeg manifest: C:\Users\lytten\programming\dlvr-project\data\d2_manifest_masked_raovseg_train.csv (n=26)
Saved RAovSeg manifest: C:\Users\lytten\programming\dlvr-project\data\d2_manifest_masked_raovseg_val.csv (n=6)
Saved RAovSeg manifest: C:\Users\lytten\programming\dlvr-project\data\d2_manifest_masked_raovseg_test.csv (n=5)


In [6]:
# %% [markdown]
# ## Datasets and DataLoaders (masked + RAovSeg preprocessing)

# %%
image_size = 256
batch_size = 1   # 1070 GPU constraint
num_workers = 0  # keep it simple / Windows-friendly

print("--- Loading masked RAovSeg datasets ---")
train_dataset = UterusDatasetWithPreprocessing(
    manifest_path=str(train_manifest_raovseg),
    image_size=image_size,
    augment=True,
)

val_dataset = UterusDatasetWithPreprocessing(
    manifest_path=str(val_manifest_raovseg),
    image_size=image_size,
    augment=False,
)

test_dataset = UterusDatasetWithPreprocessing(
    manifest_path=str(test_manifest_raovseg),
    image_size=image_size,
    augment=False,
)

print("\nDataset sizes (ovary-positive slices):")
print("Train slices:", len(train_dataset))
print("Val slices:  ", len(val_dataset))
print("Test slices: ", len(test_dataset))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=num_workers)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=num_workers)


--- Loading masked RAovSeg datasets ---
Loading manifest from C:\Users\lytten\programming\dlvr-project\data\d2_manifest_masked_raovseg_train.csv and creating slice map...
Slice map created. Found 0 slices containing the ovary.
Loading manifest from C:\Users\lytten\programming\dlvr-project\data\d2_manifest_masked_raovseg_val.csv and creating slice map...
Slice map created. Found 0 slices containing the ovary.
Loading manifest from C:\Users\lytten\programming\dlvr-project\data\d2_manifest_masked_raovseg_test.csv and creating slice map...
Slice map created. Found 0 slices containing the ovary.

Dataset sizes (ovary-positive slices):
Train slices: 0
Val slices:   0
Test slices:  0


ValueError: num_samples should be a positive integer value, but got num_samples=0

In [None]:
from pathlib import Path
import pandas as pd
import nibabel as nib
import numpy as np

project_root = Path("..").resolve()
data_dir = project_root / "data"

train_manifest_raovseg = data_dir / "d2_manifest_masked_raovseg_train.csv"
print("Train RAovSeg manifest:", train_manifest_raovseg)

df_train = pd.read_csv(train_manifest_raovseg)
print("Columns:", df_train.columns.tolist())
print("Num rows:", len(df_train))
print(df_train.head())


Train RAovSeg manifest: C:\Users\lytten\programming\dlvr-project\data\d2_manifest_masked_raovseg_train.csv
Columns: ['patient_id', 'mri_path', 'mask_path']
Num rows: 26
  patient_id                                           mri_path  \
0     D2-051  data\UT-EndoMRI\D2_TCPW_masked\D2-051\D2-051_T...   
1     D2-027  data\UT-EndoMRI\D2_TCPW_masked\D2-027\D2-027_T...   
2     D2-048  data\UT-EndoMRI\D2_TCPW_masked\D2-048\D2-048_T...   
3     D2-014  data\UT-EndoMRI\D2_TCPW_masked\D2-014\D2-014_T...   
4     D2-012  data\UT-EndoMRI\D2_TCPW_masked\D2-012\D2-012_T...   

                                         mask_path  
0  data\UT-EndoMRI\D2_TCPW\D2-001\D2-001_ov.nii.gz  
1  data\UT-EndoMRI\D2_TCPW\D2-005\D2-005_ov.nii.gz  
2  data\UT-EndoMRI\D2_TCPW\D2-007\D2-007_ov.nii.gz  
3  data\UT-EndoMRI\D2_TCPW\D2-010\D2-010_ov.nii.gz  
4  data\UT-EndoMRI\D2_TCPW\D2-012\D2-012_ov.nii.gz  


In [None]:
row0 = df_train.iloc[0]
print("\nRow 0:", row0)

mask_path = Path(row0["mask_path"])
if not mask_path.is_absolute():
    mask_path = project_root / mask_path

print("Resolved mask path:", mask_path)
assert mask_path.exists(), "Mask file does not exist!"

msk = nib.load(str(mask_path)).get_fdata()
print("Mask shape:", msk.shape)
print("Mask min / max:", float(msk.min()), float(msk.max()))
print("Num voxels > 0:", int((msk > 0).sum()))



Row 0: patient_id                                               D2-051
mri_path      data\UT-EndoMRI\D2_TCPW_masked\D2-051\D2-051_T...
mask_path       data\UT-EndoMRI\D2_TCPW\D2-001\D2-001_ov.nii.gz
Name: 0, dtype: object
Resolved mask path: C:\Users\lytten\programming\dlvr-project\data\UT-EndoMRI\D2_TCPW\D2-001\D2-001_ov.nii.gz
Mask shape: (320, 320, 34)
Mask min / max: 0.0 1.0
Num voxels > 0: 1039


In [7]:
# %% [markdown]
# ## Datasets and DataLoaders (masked + RAovSeg preprocessing)

# %%
image_size = 256
batch_size = 1   # 1070 GPU constraint
num_workers = 0  # keep it simple / Windows-friendly

print("--- Loading masked RAovSeg datasets ---")
train_dataset = UterusDatasetWithPreprocessing(
    manifest_path=str(train_manifest_raovseg),
    image_size=image_size,
    augment=True,
)

val_dataset = UterusDatasetWithPreprocessing(
    manifest_path=str(val_manifest_raovseg),
    image_size=image_size,
    augment=False,
)

test_dataset = UterusDatasetWithPreprocessing(
    manifest_path=str(test_manifest_raovseg),
    image_size=image_size,
    augment=False,
)

print("\nDataset sizes (ovary-positive slices):")
print("Train slices:", len(train_dataset))
print("Val slices:  ", len(val_dataset))
print("Test slices: ", len(test_dataset))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=num_workers)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=num_workers)


--- Loading masked RAovSeg datasets ---
Loading manifest from C:\Users\lytten\programming\dlvr-project\data\d2_manifest_masked_raovseg_train.csv and creating slice map...
Slice map created. Found 0 slices containing the ovary.
Loading manifest from C:\Users\lytten\programming\dlvr-project\data\d2_manifest_masked_raovseg_val.csv and creating slice map...
Slice map created. Found 0 slices containing the ovary.
Loading manifest from C:\Users\lytten\programming\dlvr-project\data\d2_manifest_masked_raovseg_test.csv and creating slice map...
Slice map created. Found 0 slices containing the ovary.

Dataset sizes (ovary-positive slices):
Train slices: 0
Val slices:   0
Test slices:  0


ValueError: num_samples should be a positive integer value, but got num_samples=0

In [None]:
# %% [markdown]
# ## Model, loss, optimizer

# %%
n_channels = 1  # single-channel T2FS
n_classes  = 1  # ovary vs background

model = AttentionUNet(n_channels=n_channels, n_classes=n_classes).to(device)

criterion = FocalTverskyLoss(alpha=0.7, beta=0.3, gamma=4/3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

print(model)


In [None]:
# %% [markdown]
# ## Training and validation loops (Dice metric)

# %%
def dice_score(preds, targets, epsilon=1e-6):
    """
    preds, targets: tensors of shape [B, 1, H, W], values in [0,1].
    """
    preds_flat = preds.view(-1)
    targets_flat = targets.view(-1)

    intersection = (preds_flat * targets_flat).sum()
    return (2. * intersection + epsilon) / (preds_flat.sum() + targets_flat.sum() + epsilon)


def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0

    for images, masks in loader:
        images = images.to(device)  # [B, 1, H, W]
        masks  = masks.to(device)   # [B, 1, H, W]

        optimizer.zero_grad()
        logits = model(images)              # [B, 1, H, W]
        probs  = torch.sigmoid(logits)
        loss   = criterion(probs, masks)

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

    return running_loss / len(loader.dataset)


def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    running_dice = 0.0

    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks  = masks.to(device)

            logits = model(images)
            probs  = torch.sigmoid(logits)
            loss   = criterion(probs, masks)

            running_loss += loss.item() * images.size(0)
            running_dice += dice_score(probs, masks).item() * images.size(0)

    avg_loss = running_loss / len(loader.dataset)
    avg_dice = running_dice / len(loader.dataset)
    return avg_loss, avg_dice


In [None]:
# %% [markdown]
# ## Train the masked RAovSeg + Attention U-Net + Focal Tversky model

# %%
num_epochs = 10  # start with 10 for sanity; bump to 50 once you're happy

best_val_dice = -1.0
best_epoch = -1

model_dir = project_root / "models"
model_dir.mkdir(exist_ok=True)
best_model_path = model_dir / "25_masked_attn_unet_raovseg_ftl_best.pth"

train_loss_history = []
val_loss_history = []
val_dice_history = []

print(f"Starting training for {num_epochs} epochs...")
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_dice = evaluate(model, val_loader, criterion, device)

    train_loss_history.append(train_loss)
    val_loss_history.append(val_loss)
    val_dice_history.append(val_dice)

    if val_dice > best_val_dice:
        best_val_dice = val_dice
        best_epoch = epoch
        torch.save(model.state_dict(), best_model_path)
        improved = "*"
    else:
        improved = " "

    print(
        f"Epoch {epoch+1:02d}/{num_epochs:02d} "
        f"TrainLoss={train_loss:.4f}  ValLoss={val_loss:.4f}  ValDice={val_dice:.4f}  {improved}"
    )

print(f"\nBest val Dice: {best_val_dice:.4f} at epoch {best_epoch+1}")
print("Saved best model to:", best_model_path)


In [None]:
# %% [markdown]
# ## Training curves

# %%
epochs = range(1, len(train_loss_history) + 1)

plt.figure(figsize=(10,4))
plt.subplot(1,2,1)
plt.plot(epochs, train_loss_history, label="Train loss")
plt.plot(epochs, val_loss_history, label="Val loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Focal Tversky loss")

plt.subplot(1,2,2)
plt.plot(epochs, val_dice_history, label="Val Dice")
plt.xlabel("Epoch")
plt.ylabel("Dice")
plt.legend()
plt.title("Validation Dice")

plt.tight_layout()
plt.show()


In [None]:
# %% [markdown]
# ## Final evaluation on the test set (masked, RAovSeg-preprocessed)

# %%
# Reload best model
best_model = AttentionUNet(n_channels=n_channels, n_classes=n_classes).to(device)
best_model.load_state_dict(torch.load(best_model_path, map_location=device))

test_loss, test_dice = evaluate(best_model, test_loader, criterion, device)
print(f"Test loss: {test_loss:.4f}")
print(f"Test Dice: {test_dice:.4f}")
