# Odelia Breast MRI Challenge - ResNet50

## <a id='imports'></a>1. Imports & Setup
_Standard libraries, PyTorch, MONAI, scikit-learn, and utilities._

In [None]:
# Cell 1: Import Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from pathlib import Path
from tqdm.auto import tqdm
from collections import Counter
from sklearn.metrics import roc_auc_score, roc_curve
import warnings
warnings.filterwarnings('ignore')

# MONAI imports
from monai.data import CacheDataset, list_data_collate
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Orientationd,
    Spacingd, ScaleIntensityRanged, ResizeWithPadOrCropd,
    EnsureTyped, ConcatItemsd,
    RandRotate90d, RandFlipd, RandZoomd, RandGaussianNoised
)
from monai.networks.nets import resnet50


# Report the runtime environment: PyTorch version, CUDA availability and,
# when a GPU is visible, the name of device 0.
_cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_ready}")
if _cuda_ready:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## <a id='config'></a>2. Configuration
_Hyperparameters, paths, and class mapping._

- `NUM_CLASSES = 3` (0: no lesion, 1: benign, 2: malignant)
- `TARGET_CLASS_FOR_OPS = 2` (malignant for operating-point metrics)

In [None]:
# Cell 2: Configuration
from pathlib import Path

# Paths (matching working notebook)
ROOT_DIR = r"C:\Users\NK\PycharmProjects\DeepLearningProject\dataset_downloaded"
CSV_ANNOT = r"C:\Users\NK\PycharmProjects\DeepLearningProject\dataset_downloaded\UMCU\metadata_unilateral\annotation.csv"
CSV_SPLIT = r"C:\Users\NK\PycharmProjects\DeepLearningProject\dataset_downloaded\UMCU\metadata_unilateral\split.csv"
IMAGES_ROOT = r"C:\Users\NK\PycharmProjects\DeepLearningProject\dataset_downloaded\UMCU\data_unilateral"

# Training config
BATCH_SIZE = 1
EPOCHS = 20
LR = 5e-5
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 0
PIN_MEMORY = True
# NOTE(review): the bare Ellipsis below is a no-op expression statement; it
# looks like a marker for settings elided from this notebook -- confirm.
...
# Modality file stems; their order is the channel order of the model input.
CHANNEL_KEYS = ["T2", "Pre", "Post_1", "Post_2", "Post_3"]

# 3-class setup
# 0 = no lesion, 1 = benign, 2 = malignant
NUM_CLASSES = 3
# For operating point metrics (Spec@90%Sens, Sens@90%Spec), target class = malignant
TARGET_CLASS_FOR_OPS = 2

TARGET_SPACING = (0.7, 0.7, 3.0)
TARGET_SHAPE = (96, 224, 224)

# Verify paths.  An explicit exception is raised instead of using `assert`,
# because assertions are stripped when Python runs with -O and the check
# would then silently disappear.
for _required_path in (CSV_ANNOT, CSV_SPLIT, IMAGES_ROOT):
    if not Path(_required_path).exists():
        raise FileNotFoundError(f"Missing: {_required_path}")

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print("All paths verified")

## <a id='data'></a>3. Data Loading & Preprocessing
_Read CSVs, build datasets/dataloaders, apply transforms._

In [None]:
# Cell 3: Load CSV Files
import pandas as pd

# Read the annotation table first, then the split table.
df_annot, df_split = (pd.read_csv(csv_path) for csv_path in (CSV_ANNOT, CSV_SPLIT))

# Show the column names of both tables before joining them.
print("Annotation columns:", df_annot.columns.tolist())
print("Split columns:", df_split.columns.tolist())

# Join on UID (the column shared by both CSVs) and expose the split column
# under the lower-case name the rest of the notebook uses.
df = df_annot.merge(df_split, on="UID").rename(columns={"Split": "split"})

print(f"\nLoaded {len(df)} samples")
print(f"Merged columns: {df.columns.tolist()}")
for _heading, _column in (("\nLesion distribution:", "Lesion"), ("\nSplit distribution:", "split")):
    print(_heading)
    print(df[_column].value_counts())

In [None]:
# Cell 4: Build Sample Dictionaries
def _parse_lesion_label(raw):
    """Return the lesion class encoded by *raw* (0, 1 or 2), or ``None`` if it is not one."""
    try:
        numeric = float(raw)
    except (TypeError, ValueError, OverflowError):
        return None
    # NaN, infinities and fractional values (e.g. 1.5) are not class labels;
    # truncating them with int() would silently mislabel the sample.
    if not numeric.is_integer():
        return None
    label = int(numeric)
    return label if label in (0, 1, 2) else None


def build_samples(df, images_root, channel_keys=None):
    """Turn the merged metadata table into a list of per-subject sample dicts.

    Each row of *df* must provide ``UID``, ``Lesion`` and ``split``; an
    optional ``Laterality`` column selects a sub-directory below the UID
    directory.  A row is skipped when its image directory does not exist or
    when ``Lesion`` is not a whole number in {0, 1, 2}.

    Args:
        df: DataFrame with the columns described above.
        images_root: Directory that contains one sub-directory per UID.
        channel_keys: Modality file stems to look for (``<key>.nii.gz``).
            Defaults to the module-level ``CHANNEL_KEYS``.

    Returns:
        A list of dicts with ``label`` (int), ``subject_id`` (str), ``split``
        (str) and one ``<key> -> file path`` entry for every modality file
        that exists on disk.  Missing modalities are simply left out.
    """
    if channel_keys is None:
        channel_keys = CHANNEL_KEYS

    samples = []
    skipped = 0

    images_root_path = Path(images_root)
    has_laterality = 'Laterality' in df.columns

    # Check raw Lesion values before processing
    print("Checking Lesion column values:")
    print(f"Unique values: {df['Lesion'].unique()}")
    print(f"Value counts:\n{df['Lesion'].value_counts()}")
    print(f"\nFirst 10 raw Lesion values:")
    for i, val in enumerate(df['Lesion'].head(10)):
        print(f"  {i}: '{val}' (type: {type(val).__name__})")

    label_names = {0: "0_no_lesion", 1: "1_benign", 2: "2_malignant"}
    label_debug = {"0_no_lesion": 0, "1_benign": 0, "2_malignant": 0, "invalid": 0}

    for _, row in df.iterrows():
        uid = row["UID"]
        subject_dir = images_root_path / str(uid)

        if has_laterality:
            lat = str(row["Laterality"]).strip()
            patient_id = f"{uid}_{lat}"
            subject_dir = subject_dir / lat
        else:
            patient_id = str(uid)

        if not subject_dir.exists():
            skipped += 1
            continue

        lesion_int = _parse_lesion_label(row["Lesion"])
        if lesion_int is None:
            label_debug["invalid"] += 1
            skipped += 1
            continue
        label_debug[label_names[lesion_int]] += 1

        sample = {
            "label": lesion_int,                 # 0/1/2 (no lesion / benign / malignant)
            "subject_id": str(patient_id),
            "split": str(row["split"]).strip(),
        }

        # Record only the modality files that are actually present.
        for key in channel_keys:
            nii_path = subject_dir / f"{key}.nii.gz"
            if nii_path.exists():
                sample[key] = str(nii_path)

        samples.append(sample)

    print("\nLabel parse summary:", label_debug)
    print(f"Skipped {skipped} samples (missing dir or invalid label).")
    return samples


all_samples = build_samples(df, IMAGES_ROOT)
print(f"\nBuilt {len(all_samples)} samples (allowing missing modalities)")

# Summarise the labels overall and per (split, label) pair; nothing is
# printed when no sample could be built.
if len(all_samples) > 0:
    import collections

    built_labels = collections.Counter(sample['label'] for sample in all_samples)
    print(f"\n✓ Final label distribution: {dict(built_labels)}")

    by_split_label = dict(
        collections.Counter((sample['split'], sample['label']) for sample in all_samples)
    )
    print(f"Per-split label counts: {by_split_label}")

In [None]:
# Cell 5a: Split Data Using Provided Splits

# Accepted spellings for each canonical split name.  "training" is included
# for symmetry with "validation" and "testing".
_SPLIT_ALIASES = {
    "0": "train", "train": "train", "training": "train",
    "1": "val", "val": "val", "valid": "val", "validation": "val",
    "2": "test", "test": "test", "testing": "test",
}


def normalize_split(value) -> str:
    """Map a raw split marker to ``"train"``, ``"val"`` or ``"test"``.

    The value is stringified, stripped and lower-cased before the lookup, so
    integers (0/1/2) and differently-cased names are accepted.  Anything that
    is not a known alias is returned in its cleaned (stripped, lower-case)
    form so the caller can decide what to do with it.
    """
    cleaned = str(value).strip().lower()
    return _SPLIT_ALIASES.get(cleaned, cleaned)

# Show which raw split markers occur before they are normalised.
raw_splits = sorted(set(str(sample.get("split")).strip() for sample in all_samples))
print(f"Raw split values found: {raw_splits}")

# Rewrite every sample's split to its canonical name and file the sample
# under the matching bucket; samples with any other split name are not
# placed in a bucket.
buckets = {split_name: [] for split_name in ("train", "val", "test")}
for sample in all_samples:
    canonical = normalize_split(sample.get("split", ""))
    sample["split"] = canonical
    bucket = buckets.get(canonical)
    if bucket is not None:
        bucket.append(sample)

train_list = buckets["train"]
val_list = buckets["val"]
test_list = buckets["test"]

def label_counts(samples):
    """Return a ``{label: occurrences}`` dict for *samples*, keyed in first-seen order."""
    tally = {}
    for sample in samples:
        label = sample["label"]
        tally[label] = tally.get(label, 0) + 1
    return tally

print(f"\nTrain: {len(train_list)} | Val: {len(val_list)} | Test: {len(test_list)}")

# One distribution line per non-empty split; only the test split gets an
# explicit notice when it is empty.
for _heading, _subset in (
    ("Train distribution:", train_list),
    ("Val distribution:  ", val_list),
    ("Test distribution: ", test_list),
):
    if _subset:
        print(_heading, label_counts(_subset))
if not test_list:
    print("No test samples")

In [None]:
# Cell 5b: Balanced Sampling Setup
from collections import Counter
from torch.utils.data import WeightedRandomSampler

# Per-class frequencies of the training labels.
train_labels = [sample["label"] for sample in train_list]
class_counts = Counter(train_labels)

# Every sample is weighted by the inverse of its class frequency, so each
# class contributes the same total weight to the sampler.
_inverse_frequency = {label: 1.0 / count for label, count in class_counts.items()}
sample_weights = [_inverse_frequency[label] for label in train_labels]

# Draw (with replacement) as many samples per epoch as the training list holds.
sampler = WeightedRandomSampler(
    sample_weights,
    len(train_list),
    replacement=True,
)

print(f"Original class counts: {dict(class_counts)}")

In [None]:
# Cell 6: Missing Modality Handler
class FillMissingModalitiesd:
    """Dictionary transform that zero-fills modality entries that are absent or ``None``.

    The fill tensor takes its shape from ``ref_key`` when that entry is
    available; otherwise the first available entry among ``keys`` is used, so
    a sample that lacks the reference modality itself can still be completed.

    Args:
        keys: Modality keys that must be present after the transform.
        ref_key: Preferred key whose ``.shape`` defines the fill shape.

    Raises:
        KeyError: If neither ``ref_key`` nor any of ``keys`` holds a value,
            i.e. there is nothing to take a shape from.
    """

    def __init__(self, keys, ref_key="Pre"):
        self.keys = keys
        self.ref_key = ref_key

    def _reference_shape(self, data):
        """Return the shape of the first available entry, preferring ``ref_key``."""
        for key in (self.ref_key, *self.keys):
            if key in data and data[key] is not None:
                return data[key].shape
        raise KeyError(
            "Cannot fill missing modalities: none of "
            f"{[self.ref_key, *self.keys]} is present in the sample"
        )

    def __call__(self, data):
        ref_shape = self._reference_shape(data)
        for key in self.keys:
            if key not in data or data[key] is None:
                # The input dict is updated in place and returned.
                data[key] = torch.zeros(ref_shape, dtype=torch.float32)
        return data

In [None]:
# Cell 7: Data Transforms
def make_transforms(is_train: bool) -> Compose:
    """Build the MONAI preprocessing pipeline for one dataset split.

    Every step up to the zero-fill tolerates samples that lack some of the
    ``CHANNEL_KEYS`` entries (``allow_missing_keys=True``), including the
    training-only augmentations.  Absent modalities are zero-filled only
    after all spatial processing, then every channel is passed through
    ``EnsureTyped`` together with the label and the channels are concatenated
    along dim 0 into ``"image"``.

    Args:
        is_train: When true, random rotation/flip/zoom/noise augmentations
            are inserted between intensity scaling and the final resize.

    Returns:
        A ``Compose`` that maps a sample dict of file paths to a dict that
        additionally holds the stacked ``"image"`` entry.
    """
    pipeline = [
        LoadImaged(keys=CHANNEL_KEYS, image_only=False, ensure_channel_first=False, allow_missing_keys=True),
        EnsureChannelFirstd(keys=CHANNEL_KEYS, allow_missing_keys=True),
        Orientationd(keys=CHANNEL_KEYS, axcodes="RAS", allow_missing_keys=True),
        Spacingd(keys=CHANNEL_KEYS, pixdim=TARGET_SPACING, mode=("bilinear",) * len(CHANNEL_KEYS),
                 allow_missing_keys=True),
        ScaleIntensityRanged(
            keys=CHANNEL_KEYS, a_min=-100, a_max=1000, b_min=0.0, b_max=1.0, clip=True, allow_missing_keys=True
        ),
    ]

    # Add augmentation for training only.  These must tolerate missing keys
    # as well, otherwise a training sample without one modality fails here
    # before it can be zero-filled.
    if is_train:
        pipeline.extend([
            RandRotate90d(keys=CHANNEL_KEYS, prob=0.7, spatial_axes=(1, 2), allow_missing_keys=True),
            RandFlipd(keys=CHANNEL_KEYS, prob=0.5, spatial_axis=1, allow_missing_keys=True),
            RandFlipd(keys=CHANNEL_KEYS, prob=0.5, spatial_axis=2, allow_missing_keys=True),
            RandZoomd(keys=CHANNEL_KEYS, prob=0.3, min_zoom=0.9, max_zoom=1.1, allow_missing_keys=True),
            RandGaussianNoised(keys=CHANNEL_KEYS, prob=0.2, mean=0.0, std=0.1, allow_missing_keys=True),
        ])

    pipeline.extend([
        ResizeWithPadOrCropd(keys=CHANNEL_KEYS, spatial_size=TARGET_SHAPE, allow_missing_keys=True),
        # Zero-fill first so EnsureTyped sees every key and the filled
        # tensors get the same container type as the loaded ones before
        # they are concatenated.
        FillMissingModalitiesd(keys=CHANNEL_KEYS, ref_key="Pre"),
        EnsureTyped(keys=CHANNEL_KEYS + ["label"]),
        ConcatItemsd(keys=CHANNEL_KEYS, name="image", dim=0),
    ])

    return Compose(pipeline)


print("Transforms created (with augmentation for training)")

In [None]:
# Cell 8: Build Cached DataLoaders
# Fraction of the training set kept in the in-memory cache; validation is
# cached completely (cache_rate=1.0 below).
CACHE_RATE = 0.5

print("Building cached datasets")

train_ds = CacheDataset(
    train_list,
    transform=make_transforms(is_train=True),
    cache_rate=CACHE_RATE,
    num_workers=4
)

val_ds = CacheDataset(
    val_list,
    transform=make_transforms(is_train=False),
    cache_rate=1.0,
    num_workers=4
)

# `sampler` and `shuffle=True` are mutually exclusive in DataLoader (passing
# both raises ValueError).  The weighted sampler already draws indices at
# random, so `shuffle` is left at its default.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=list_data_collate
)

# Validation keeps the dataset order.
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=list_data_collate
)

print(f" DataLoaders ready:")
print(f"  Train: {len(train_loader)} batches")
print(f"  Val:   {len(val_loader)} batches")

In [None]:
# Cell 9: Test DataLoader
# Pull a single training batch and show what the model will receive.
print("Testing data loading...")
batch = next(iter(train_loader))
_first_images, _first_labels = batch['image'], batch['label']
print(f"Batch image shape: {_first_images.shape}")  # (B, C, D, H, W)
print(f"Batch label shape: {_first_labels.shape}")  # (B,)
print(f"Label values: {_first_labels.tolist()}")

In [None]:
# Cell 10: Adjust class weights for extreme imbalance
from collections import Counter

train_labels = [sample["label"] for sample in train_list]
class_counts = Counter(train_labels)
total = len(train_labels)

# Inverse-frequency weight total / (NUM_CLASSES * count) per class; a class
# with no training samples gets the neutral weight 1.0.
weights = [
    total / (NUM_CLASSES * class_counts.get(cls, 0)) if class_counts.get(cls, 0) != 0 else 1.0
    for cls in range(NUM_CLASSES)
]

# Class index 2 is weighted 1.5x more on top of the inverse frequency.
if NUM_CLASSES >= 3:
    weights[2] = weights[2] * 1.5

class_weights_tensor = torch.tensor(weights, dtype=torch.float32, device=device)
print(f"Class counts: {dict(class_counts)}")
print(f"Adjusted class weights: {weights}")

## <a id='model'></a>4. Model (ResNet50)
_Build a 3D ResNet50 backbone (MONAI, no pretrained weights) that outputs 512 features, followed by a small MLP head producing the 3-class logits._

In [None]:
# Cell 11: Model Definition (3-Class Classification)
class ResNet3DClassifier(nn.Module):
    """3D ResNet50 feature extractor followed by a two-layer MLP head.

    The backbone is created without pretrained weights and configured to
    emit 512 values per sample; the head maps 512 -> 256 -> ``num_classes``
    with dropout before each linear layer.

    Args:
        in_channels: Number of input channels fed to the backbone.
        num_classes: Number of logits produced by the head.
    """

    def __init__(self, in_channels=5, num_classes=3):  # 5 channels, 3 classes
        super().__init__()
        feature_dim, hidden_dim = 512, 256

        self.backbone = resnet50(
            pretrained=False,
            spatial_dims=3,
            n_input_channels=in_channels,
            num_classes=feature_dim,
        )

        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the class logits for input batch ``x``."""
        return self.classifier(self.backbone(x))


# One input channel per modality key; the model is moved to the chosen device.
model = ResNet3DClassifier(in_channels=len(CHANNEL_KEYS), num_classes=NUM_CLASSES)
model = model.to(device)

# Count all parameters and, separately, those that receive gradients.
total_params = 0
trainable_params = 0
for _param in model.parameters():
    total_params += _param.numel()
    if _param.requires_grad:
        trainable_params += _param.numel()

print(f"Model created:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

## <a id='loss'></a>5. Loss & Optimizer
_Define loss, optimizer, scheduler._

In [None]:
# Cell 12: Optimizer and Scheduler
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

# Length of the linear warm-up, in epochs.  Kept in one place because the
# warm-up length, the cosine horizon and the switch-over milestone must agree.
WARMUP_EPOCHS = 3

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    betas=(0.9, 0.999)
)

# Warmup (from 10% of LR) followed by cosine decay over the remaining epochs.
# T_max is clamped to at least 1 so a run with EPOCHS <= WARMUP_EPOCHS does
# not hand the cosine schedule a zero or negative period.
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=WARMUP_EPOCHS)
cosine = CosineAnnealingLR(optimizer, T_max=max(1, EPOCHS - WARMUP_EPOCHS), eta_min=1e-6)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[WARMUP_EPOCHS])

criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
# Loss scaling is only enabled on CUDA, matching the condition the training
# step uses for autocast.
scaler = GradScaler(enabled=(device.type == "cuda"))

print("Optimizer, scheduler, and loss function ready")

## <a id='eval'></a>7. Evaluation & Metrics (ROC/AUC)
_Compute ROC, AUC-macro, accuracy, and operating-point metrics._

In [None]:
# Cell 13: Evaluation Function
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

@torch.no_grad()
def evaluate(loader, epoch, tag="Val"):
    """Run the global ``model`` over *loader* in eval mode and compute metrics.

    Uses the module-level ``model``, ``criterion`` and ``device``.  Prints a
    one-line summary and returns a dict with the sample-weighted mean loss,
    accuracy, macro one-vs-rest AUC, the softmax probabilities, labels,
    argmax predictions, and two operating points for the
    ``TARGET_CLASS_FOR_OPS``-vs-rest problem (best specificity at >= 90%
    sensitivity, best sensitivity at >= 90% specificity).  Metrics that
    cannot be computed are returned as NaN.

    Args:
        loader: Iterable of batches with ``"image"`` and ``"label"`` entries.
        epoch: Epoch number, used only in the progress-bar description.
        tag: Prefix for the progress bar and the printed summary.
    """
    model.eval()
    loss_sum = 0.0
    seen = 0
    logit_chunks, label_chunks, pred_chunks = [], [], []

    progress = tqdm(loader, desc=f"{tag} {epoch}", leave=False)
    for batch in progress:
        x = batch["image"].to(device, non_blocking=True)
        y = batch["label"].to(device, non_blocking=True).long()

        logits = model(x)
        batch_size = x.size(0)

        # Weight each batch loss by its size so the final value is a per-sample mean.
        loss_sum += criterion(logits, y).item() * batch_size
        seen += batch_size

        pred_chunks.append(logits.argmax(dim=1).detach().cpu().numpy())
        label_chunks.append(y.detach().cpu().numpy())
        logit_chunks.append(logits.detach().cpu().numpy())

        progress.set_postfix({'loss': loss_sum / max(seen, 1)})

    mean_loss = loss_sum / max(seen, 1)

    # Empty loaders yield empty arrays instead of failing in np.concatenate.
    all_labels = np.concatenate(label_chunks) if len(label_chunks) else np.array([])
    all_logits = np.concatenate(logit_chunks) if len(logit_chunks) else np.zeros((0, NUM_CLASSES))
    all_preds = np.concatenate(pred_chunks) if len(pred_chunks) else np.array([])

    # Softmax probabilities for multi-class
    probs = torch.softmax(torch.from_numpy(all_logits), dim=1).numpy()

    # Accuracy
    if len(all_labels):
        accuracy = float((all_preds == all_labels).mean())
    else:
        accuracy = 0.0

    # Macro AUC (OvR); any failure is reported and turned into NaN.
    try:
        auc_macro = float(roc_auc_score(all_labels, probs, multi_class="ovr", average="macro"))
    except Exception as e:
        print(f"AUC-macro calculation issue: {e}")
        auc_macro = float("nan")

    # Operating points for the target class vs the rest
    spec_at_90sens = float("nan")
    sens_at_90spec = float("nan")
    if len(all_labels) and NUM_CLASSES >= 3:
        is_target = (all_labels == TARGET_CLASS_FOR_OPS).astype(np.uint8)
        target_score = probs[:, TARGET_CLASS_FOR_OPS]
        # A ROC curve needs both positives and negatives.
        if len(np.unique(is_target)) == 2:
            fpr, tpr, _ = roc_curve(is_target, target_score)
            specificity = 1.0 - fpr
            # Spec@90% Sens
            high_sens = tpr >= 0.90
            if np.any(high_sens):
                spec_at_90sens = float(np.max(specificity[high_sens]))
            # Sens@90% Spec
            high_spec = specificity >= 0.90
            if np.any(high_spec):
                sens_at_90spec = float(np.max(tpr[high_spec]))

    def _shown(metric):
        # NaN metrics are printed as the plain text 'nan'.
        return metric if not np.isnan(metric) else 'nan'

    print(f"{tag}: loss={mean_loss:.4f}, acc={accuracy:.4f}, "
          f"AUC-macro={_shown(auc_macro)}, "
          f"Spec@90%Sens={_shown(spec_at_90sens)}, "
          f"Sens@90%Spec={_shown(sens_at_90spec)}")

    return {
        "loss": mean_loss,
        "auc_macro": auc_macro,
        "accuracy": accuracy,
        "probs": probs,
        "labels": all_labels,
        "preds": all_preds,
        "spec_at_90sens": spec_at_90sens,
        "sens_at_90spec": sens_at_90spec,
    }

print("Evaluation function ready")

In [None]:
# Cell 14: Training Function
def train_one_epoch(epoch):
    """Train the global ``model`` for one pass over ``train_loader``.

    Uses the module-level ``optimizer``, ``criterion``, ``scaler``,
    ``scheduler`` and ``device``.  Autocast is enabled only on CUDA;
    gradients are unscaled and clipped to a norm of 1.0 before each optimizer
    step.  The scheduler is stepped once, after the last batch.

    Args:
        epoch: Epoch number, used only in the progress-bar description.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen in this epoch.
    """
    model.train()
    use_amp = device.type == "cuda"
    loss_sum = 0.0
    hits = 0
    seen = 0
    progress = tqdm(train_loader, desc=f"Train {epoch}")

    for batch in progress:
        x = batch["image"].to(device, non_blocking=True)
        y = batch["label"].to(device, non_blocking=True).long()

        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=use_amp):
            logits = model(x)
            loss = criterion(logits, y)

        # Scaled backward pass; unscale before clipping so the norm is
        # measured on the true gradients.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        batch_size = x.size(0)
        loss_sum += loss.item() * batch_size
        hits += (logits.argmax(dim=1) == y).sum().item()
        seen += batch_size

        denom = max(seen, 1)
        progress.set_postfix({
            'loss': loss_sum / denom,
            'acc': hits / denom,
            'lr': optimizer.param_groups[0]['lr'],
        })

    scheduler.step()
    denom = max(seen, 1)
    return loss_sum / denom, hits / denom

print("Training function ready")

In [None]:
# Cell 15a: Initialize Training
# Best validation macro AUC seen so far and the checkpoint that produced it.
best_macro_auc = float("-inf")
best_ckpt = None

# One list per tracked quantity; each gets one value appended per epoch.
history = {
    metric_name: []
    for metric_name in (
        'train_loss', 'train_acc',
        'val_loss', 'val_auc_macro', 'val_spec_at_90sens', 'val_sens_at_90spec',
        'lr',
    )
}

print(f"Starting training for {EPOCHS} epochs")

In [None]:
# Cell 15b: Sanity Check Model Outputs
# Forward one validation batch through the untrained model and print what
# comes out, as a quick check that shapes and value ranges are sensible.
print("Checking initial model behavior")
model.eval()
with torch.no_grad():
    sample_batch = next(iter(val_loader))
    x = sample_batch["image"].to(device)
    y = sample_batch["label"].to(device)

    logits = model(x)
    probs = torch.softmax(logits, dim=1)
    preds = logits.argmax(dim=1)

    print(f"Sample batch size: {x.shape[0]}")
    print(f"True labels: {y.cpu().numpy()}")
    print(f"Predictions: {preds.cpu().numpy()}")
    # Report every class column of the softmax output, not just the first two.
    for class_idx in range(probs.shape[1]):
        print(f"Class {class_idx} probs: {probs[:, class_idx].cpu().numpy()}")
    print(f"Logits range: [{logits.min().item():.2f}, {logits.max().item():.2f}]")

# Switch back to training mode after the check.
model.train()

## <a id='train'></a>6. Training Loop
_Epoch loop, logging, checkpointing best val AUC._

In [None]:
# Cell 16: Training Loop
for epoch in range(1, EPOCHS + 1):
    # Read the learning rate before training: train_one_epoch() steps the
    # scheduler at its end, so reading it afterwards would log the next
    # epoch's rate instead of the one actually used.
    epoch_lr = optimizer.param_groups[0]['lr']

    # Train
    train_loss, train_acc = train_one_epoch(epoch)

    # Validate
    val = evaluate(val_loader, epoch, tag="Val")

    # Store history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val['loss'])
    history['val_auc_macro'].append(val['auc_macro'])
    history['val_spec_at_90sens'].append(val['spec_at_90sens'])
    history['val_sens_at_90spec'].append(val['sens_at_90spec'])
    history['lr'].append(epoch_lr)

    # Print summary
    print(f"\nEpoch {epoch}/{EPOCHS}")
    print(f"  Train: loss={train_loss:.4f}, acc={train_acc:.4f}")
    print(f"  Val:   loss={val['loss']:.4f}, acc={val['accuracy']:.4f}, "
          f"AUC-macro={val['auc_macro'] if not np.isnan(val['auc_macro']) else 'nan'}, "
          f"Spec@90%Sens={val['spec_at_90sens'] if not np.isnan(val['spec_at_90sens']) else 'nan'}, "
          f"Sens@90%Spec={val['sens_at_90spec'] if not np.isnan(val['sens_at_90spec']) else 'nan'}")
    print(f"  Val predictions: {np.bincount(val['preds'], minlength=NUM_CLASSES)}")
    print(f"  LR: {epoch_lr:.2e}")

    # Save best model by macro AUC
    if (val["auc_macro"] == val["auc_macro"]) and (val["auc_macro"] > best_macro_auc):  # NaN-safe
        best_macro_auc = val["auc_macro"]
        best_ckpt = {
            # clone() is required: for tensors already on the CPU, .cpu()
            # returns the same tensor, so without a copy the snapshot would
            # keep changing as training continues.
            "state_dict": {k: v.detach().cpu().clone() for k, v in model.state_dict().items()},
            "epoch": epoch,
            "val": val,
        }
        print(f"  ✓ New best model (AUC-macro: {best_macro_auc:.4f})")

print("\n" + "="*60)
print("Training completed!")

In [None]:
# Cell 17: Save Best Model
# Persist the best checkpoint (weights, epoch and validation results) when
# one was recorded during training.
if best_ckpt is None:
    print("No valid checkpoint to save")
else:
    save_path = "resnet50_classifier_best.pth"
    torch.save(best_ckpt, save_path)
    print(f"Saved best model to {save_path}")
    print(f"  Best epoch: {best_ckpt['epoch']}")
    print(f"  Best AUC: {best_macro_auc:.4f}")

In [None]:
# Cell 18: Save Training History
# One row per epoch, one column per tracked metric.
history_df = pd.DataFrame.from_dict(history)
history_df.to_csv("training_history_resnet50.csv", index=False)
print("Training history saved to training_history_resnet50.csv")

# Show the last five recorded epochs.
print("\nFinal 5 epochs:")
print(history_df.tail(5))

## <a id='plots'></a>8. Plots & Visualization
_Training curves and ROC plots._

In [None]:
# Cell 19: Plot Training Curves
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
_loss_ax, _acc_ax = axes[0, 0], axes[0, 1]
_auc_ax, _lr_ax = axes[1, 0], axes[1, 1]

# Draw the curves: loss (train and val), training accuracy, validation
# macro AUC, and the learning rate on a logarithmic axis.
_loss_ax.plot(history['train_loss'], label='Train Loss', marker='o')
_loss_ax.plot(history['val_loss'], label='Val Loss', marker='s')
_acc_ax.plot(history['train_acc'], label='Train Acc', marker='o', color='green')
_auc_ax.plot(history['val_auc_macro'], label='Val AUC-macro', marker='s', color='purple')
_lr_ax.plot(history['lr'], label='Learning Rate', marker='^', color='orange')
_lr_ax.set_yscale('log')

# Shared decoration: title, axis labels, legend and grid for every panel.
for _panel, _title, _ylabel in (
    (_loss_ax, 'Training and Validation Loss', 'Loss'),
    (_acc_ax, 'Training Accuracy', 'Accuracy'),
    (_auc_ax, 'Validation AUC-macro', 'AUC'),
    (_lr_ax, 'Learning Rate Schedule', 'LR'),
):
    _panel.set_title(_title)
    _panel.set_xlabel('Epoch')
    _panel.set_ylabel(_ylabel)
    _panel.legend()
    _panel.grid(True)

plt.tight_layout()
plt.savefig('training_history_resnet50.png', dpi=150, bbox_inches='tight')
plt.show()

print(" Training curves saved to training_history_resnet50.png")

In [None]:
# Cell 20: Training Summary
# NaN is the only value that differs from itself; show it as the text "nan".
print("Best validation macro AUC:", "nan" if best_macro_auc != best_macro_auc else best_macro_auc)

if best_ckpt is None:
    print("No checkpoint recorded.")
else:
    print("Best epoch:", best_ckpt["epoch"])
    print("Best Val metrics:")
    # Keep only the scalar metrics, in the order stored in the checkpoint.
    _scalar_metrics = ("loss", "accuracy", "auc_macro", "spec_at_90sens", "sens_at_90spec")
    print({name: value for name, value in best_ckpt["val"].items() if name in _scalar_metrics})

In [None]:
# Cell 21: ROC Curve Visualization (malignant vs rest)
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

print("Generating ROC curve (malignant vs rest) from best model")

# Everything below needs the best checkpoint; nothing happens without one.
if best_ckpt is not None:
    # Restore the best weights and switch to eval mode.
    model.load_state_dict(best_ckpt['state_dict'])
    model.eval()

    # Collect the target-class probability and the binary target indicator
    # for every validation sample.
    _prob_chunks = []
    _label_chunks = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="ROC Val", leave=False):
            x = batch["image"].to(device, non_blocking=True)
            y = batch["label"].to(device, non_blocking=True).long()
            _target_prob = torch.softmax(model(x), dim=1)[:, TARGET_CLASS_FOR_OPS]  # malignant prob
            _is_target = (y == TARGET_CLASS_FOR_OPS).long()  # malignant=1 else 0
            _prob_chunks.append(_target_prob.detach().cpu().numpy())
            _label_chunks.append(_is_target.detach().cpu().numpy())

    all_probs = np.concatenate(_prob_chunks) if len(_prob_chunks) else np.array([])
    all_labels = np.concatenate(_label_chunks) if len(_label_chunks) else np.array([])

    if len(np.unique(all_labels)) < 2:
        print("Not enough positive/negative samples in Val for ROC.")
    else:
        fpr, tpr, thresholds = roc_curve(all_labels, all_probs)
        roc_auc = auc(fpr, tpr)

        # Threshold by Youden's J statistic (tpr - fpr)
        youden = tpr - fpr
        optimal_idx = np.argmax(youden)
        optimal_threshold = thresholds[optimal_idx]

        plt.figure(figsize=(10, 8))
        plt.plot(fpr, tpr, lw=2, label=f'ROC (AUC = {roc_auc:.3f})')
        plt.plot([0, 1], [0, 1], lw=2, linestyle='--', label='Random')

        # Mark optimal threshold
        plt.scatter(fpr[optimal_idx], tpr[optimal_idx], marker='o', s=100,
                    label=f'Threshold={optimal_threshold:.3f}\nSens={tpr[optimal_idx]:.3f}, Spec={1-fpr[optimal_idx]:.3f}')

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate (1 - Specificity)', fontsize=12)
        plt.ylabel('True Positive Rate (Sensitivity)', fontsize=12)
        plt.title('ROC — Malignant vs Rest (Validation)', fontsize=14)
        plt.legend(loc="lower right", fontsize=10)
        plt.grid(alpha=0.3)
        plt.tight_layout()

        # Save figure
        plt.savefig('roc_curve_malignant_vs_rest.png', dpi=150, bbox_inches='tight')
        plt.show()

        # Confusion-matrix cells at the selected threshold, from boolean masks.
        optimal_preds = (all_probs >= optimal_threshold).astype(int)
        _actual_pos = all_labels == 1
        _actual_neg = all_labels == 0
        _flagged = optimal_preds == 1
        _cleared = optimal_preds == 0
        tn = np.sum(_actual_neg & _cleared)
        fp = np.sum(_actual_neg & _flagged)
        fn = np.sum(_actual_pos & _cleared)
        tp = np.sum(_actual_pos & _flagged)

        # Each ratio falls back to 0 when its denominator is empty.
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        if (precision + sensitivity) > 0:
            f1_score = (2 * precision * sensitivity) / (precision + sensitivity)
        else:
            f1_score = 0

        print("\nDetailed metrics at selected threshold:")
        print(f"  Sensitivity (Recall): {sensitivity:.3f}")
        print(f"  Specificity:          {specificity:.3f}")
        print(f"  Precision:            {precision:.3f}")
        print(f"  F1-Score:             {f1_score:.3f}")
        print(f"\nConfusion Matrix at Selected Threshold:")
        print(f"  True Negatives:  {tn}")
        print(f"  False Positives: {fp}")
        print(f"  False Negatives: {fn}")
        print(f"  True Positives:  {tp}")
        print("="*60)

        print("\n ROC curve saved to roc_curve_malignant_vs_rest.png")