# 02 — Train & Explain

In [7]:
import os, json
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

import shap  # KernelExplainer later (robust for small tabular)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# paths
ART = Path("../artifacts")
ASSETS = Path("../assets")
MODELS = Path("../models")
ASSETS.joinpath("gradcam").mkdir(parents=True, exist_ok=True)
ASSETS.joinpath("shap").mkdir(parents=True, exist_ok=True)
MODELS.mkdir(parents=True, exist_ok=True)

# Load splits from Notebook 1
train_df = pd.read_parquet(ART / "train.parquet")
val_df   = pd.read_parquet(ART / "val.parquet")
test_df  = pd.read_parquet(ART / "test.parquet")
len(train_df), len(val_df), len(test_df)


Using device: cpu


(1000, 235, 234)

In [8]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]),
])

class OvarianDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def _resolve_path(self, p):
        p = Path(p)
        # If relative like "data/images/..", prepend project root ("..") when running from notebooks/
        if not p.is_absolute() and not str(p).startswith(".."):
            p = Path("..") / p
        return p

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = self._resolve_path(row["image_path"])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        tabular = torch.tensor([row["age"], row["ca125"], row["brca"]], dtype=torch.float32)
        label = torch.tensor(row["label"], dtype=torch.float32)
        return image, tabular, label

BATCH = 32
train_loader = DataLoader(OvarianDataset(train_df, transform), batch_size=BATCH, shuffle=True, num_workers=2 if torch.cuda.is_available() else 0)
val_loader   = DataLoader(OvarianDataset(val_df, transform), batch_size=BATCH, shuffle=False, num_workers=2 if torch.cuda.is_available() else 0)
test_loader  = DataLoader(OvarianDataset(test_df, transform), batch_size=BATCH, shuffle=False, num_workers=2 if torch.cuda.is_available() else 0)


In [9]:
class TabularMLP(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1)
        )
    def forward(self, x):
        return self.net(x)

def get_img_encoder():
    base = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)  # patched
    for p in base.parameters():
        p.requires_grad = False
    in_feat = base.fc.in_features
    base.fc = nn.Identity()
    return base, in_feat

class ImageOnly(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone, feat_dim = get_img_encoder()
        self.fc = nn.Linear(feat_dim, 1)
    def forward(self, img, tab=None):
        feats = self.backbone(img)
        return self.fc(feats)

class FusedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone, feat_dim = get_img_encoder()
        self.tab_mlp = TabularMLP()
        self.fc = nn.Linear(feat_dim + 1, 1)  # tab_mlp returns 1‑dim logit
    def forward(self, img, tab):
        img_feat = self.backbone(img)
        tab_feat = self.tab_mlp(tab)
        x = torch.cat([img_feat, tab_feat], dim=1)
        return self.fc(x)


In [10]:
def forward_any(model, imgs, tabs):
    try:
        return model(imgs, tabs)
    except TypeError:
        return model(tabs)

In [11]:
def train_model(model, train_loader, val_loader, num_epochs=8, lr=1e-3, name="model"):
    model = model.to(device)
    crit = nn.BCEWithLogitsLoss()
    opt  = optim.Adam(model.parameters(), lr=lr)

    best_auc, best_state = 0.0, None
    for epoch in range(1, num_epochs+1):
        model.train()
        for imgs, tabs, labels in train_loader:
            imgs, tabs = imgs.to(device), tabs.to(device)
            labels = labels.to(device).unsqueeze(1)
            opt.zero_grad()
            logits = forward_any(model, imgs, tabs)
            loss = crit(logits, labels)
            loss.backward()
            opt.step()

        # validation (AUC for early stopping)
        model.eval()
        y_true, y_prob = [], []
        with torch.no_grad():
            for imgs, tabs, labels in val_loader:
                imgs, tabs = imgs.to(device), tabs.to(device)
                logits = forward_any(model, imgs, tabs)
                probs = torch.sigmoid(logits).cpu().numpy().ravel()
                y_prob.extend(probs)
                y_true.extend(labels.numpy().ravel())
        auc = roc_auc_score(y_true, y_prob)
        print(f"Epoch {epoch}/{num_epochs}  |  Val AUC: {auc:.4f}")
        if auc > best_auc:
            best_auc = auc
            best_state = model.state_dict()

    if best_state:
        model.load_state_dict(best_state)
    torch.save(model.state_dict(), MODELS / f"{name}.pth")
    print(f"Saved best '{name}' (AUC={best_auc:.4f}) -> {MODELS/f'{name}.pth'}")
    return model

def evaluate(model, loader):
    model.eval()
    y_true, y_prob = [], []
    with torch.no_grad():
        for imgs, tabs, labels in loader:
            imgs, tabs = imgs.to(device), tabs.to(device)
            logits = forward_any(model, imgs, tabs)
            probs = torch.sigmoid(logits).cpu().numpy().ravel()
            y_prob.extend(probs)
            y_true.extend(labels.numpy().ravel())
    y_pred = (np.array(y_prob) > 0.5).astype(int)
    return {
        "accuracy":  float(accuracy_score(y_true, y_pred)),
        "precision": float(precision_score(y_true, y_pred)),
        "recall":    float(recall_score(y_true, y_pred)),
        "f1":        float(f1_score(y_true, y_pred)),
        "roc_auc":   float(roc_auc_score(y_true, y_prob)),
    }

In [12]:
img_model   = train_model(ImageOnly(),   train_loader, val_loader, num_epochs=8, name="image_only")
tab_model   = train_model(TabularMLP(),  train_loader, val_loader, num_epochs=8, name="tabular_only")
fused_model = train_model(FusedModel(),  train_loader, val_loader, num_epochs=8, name="fused")

Epoch 1/8  |  Val AUC: 0.4994
Epoch 2/8  |  Val AUC: 0.4670
Epoch 3/8  |  Val AUC: 0.4637
Epoch 4/8  |  Val AUC: 0.5033
Epoch 5/8  |  Val AUC: 0.5044
Epoch 6/8  |  Val AUC: 0.5281
Epoch 7/8  |  Val AUC: 0.5743
Epoch 8/8  |  Val AUC: 0.5997
Saved best 'image_only' (AUC=0.5997) -> ..\models\image_only.pth
Epoch 1/8  |  Val AUC: 0.2318
Epoch 2/8  |  Val AUC: 0.4020
Epoch 3/8  |  Val AUC: 0.6057
Epoch 4/8  |  Val AUC: 0.8431
Epoch 5/8  |  Val AUC: 0.9681
Epoch 6/8  |  Val AUC: 0.9857
Epoch 7/8  |  Val AUC: 0.9950
Epoch 8/8  |  Val AUC: 0.9994
Saved best 'tabular_only' (AUC=0.9994) -> ..\models\tabular_only.pth
Epoch 1/8  |  Val AUC: 0.4642
Epoch 2/8  |  Val AUC: 0.4466
Epoch 3/8  |  Val AUC: 0.4719
Epoch 4/8  |  Val AUC: 0.5165
Epoch 5/8  |  Val AUC: 0.6085
Epoch 6/8  |  Val AUC: 0.7197
Epoch 7/8  |  Val AUC: 0.8458
Epoch 8/8  |  Val AUC: 0.9202
Saved best 'fused' (AUC=0.9202) -> ..\models\fused.pth


In [13]:
results = {
    "image_only":  evaluate(img_model,   test_loader),
    "tabular_only":evaluate(tab_model,   test_loader),
    "fused":       evaluate(fused_model, test_loader),
}
print(json.dumps(results, indent=2))

with open(ART/"metrics.json","w") as f:
    json.dump(results, f, indent=2)
print("Wrote metrics ->", ART/"metrics.json")

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{
  "image_only": {
    "accuracy": 0.9700854700854701,
    "precision": 0.0,
    "recall": 0.0,
    "f1": 0.0,
    "roc_auc": 0.8023914411579609
  },
  "tabular_only": {
    "accuracy": 0.9700854700854701,
    "precision": 0.0,
    "recall": 0.0,
    "f1": 0.0,
    "roc_auc": 1.0
  },
  "fused": {
    "accuracy": 0.9700854700854701,
    "precision": 0.0,
    "recall": 0.0,
    "f1": 0.0,
    "roc_auc": 0.9559471365638766
  }
}
Wrote metrics -> ..\artifacts\metrics.json


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [19]:
import cv2

def gradcam_single(model, img_tensor, target_layer):
    acts, grads = {}, {}

    def fwd(m, i, o):
        acts["v"] = o
        # attach a tensor-level hook to capture dL/d(acts)
        o.register_hook(lambda g: grads.setdefault("v", g))

    h = target_layer.register_forward_hook(fwd)

    model.eval()
    # tab dummy for image-only

    img_tensor = img_tensor.to(device).requires_grad_(True)
    
    out = model(img_tensor.to(device), torch.zeros((1,3), device=device))
    logit = out.mean()
    model.zero_grad()
    logit.backward()

    h.remove()

    A = acts["v"].detach()                    # [B, C, H, W]
    G = grads["v"].detach()                   # [B, C, H, W]
    weights = G.mean(dim=(2,3), keepdim=True) # [B, C, 1, 1]
    cam = F.relu((weights * A).sum(dim=1)).squeeze().cpu().numpy()
    cam = (cam - cam.min()) / (cam.max() + 1e-8)
    return cam

def overlay_cam(rgb_img_224, cam):
    # cam: HxW (e.g., 7x7); upsample to image size
    img_np = np.asarray(rgb_img_224)              # (224,224,3)
    H, W = img_np.shape[:2]
    cam_up = cv2.resize(cam, (W, H), interpolation=cv2.INTER_CUBIC)
    heatmap = cv2.applyColorMap((cam_up*255).astype(np.uint8), cv2.COLORMAP_JET)[:, :, ::-1] / 255.0
    overlay = 0.4*heatmap + 0.6*(img_np/255.0)
    overlay = np.clip(overlay, 0, 1)
    return (overlay*255).astype(np.uint8)

# pick 2 benign + 2 malignant from test_df
sel_benign = test_df[test_df["label"]==0].sample(min(2, (test_df["label"]==0).sum()), random_state=0)
sel_malign = test_df[test_df["label"]==1].sample(min(2, (test_df["label"]==1).sum()), random_state=0)
sel = pd.concat([sel_benign, sel_malign], ignore_index=True)

target_layer = img_model.backbone.layer4[-1]

for _, row in sel.iterrows():
    p = row["image_path"]
    p = Path(p) if str(p).startswith("..") or Path(p).is_absolute() else Path("..")/p
    img = Image.open(p).convert("RGB").resize((224,224))
    tens = transform(img).unsqueeze(0)

    cam = gradcam_single(img_model, tens, target_layer)
    overlay = overlay_cam(img, cam)
    out_name = ASSETS/"gradcam"/f"{Path(p).stem}_gradcam.png"
    Image.fromarray(overlay).save(out_name)
    print("Saved:", out_name)

Saved: ..\assets\gradcam\1378_gradcam.png
Saved: ..\assets\gradcam\930_gradcam.png
Saved: ..\assets\gradcam\1432_gradcam.png
Saved: ..\assets\gradcam\1430_gradcam.png


In [20]:
tab_model.eval()

def tab_predict(x_np):
    x = torch.tensor(x_np, dtype=torch.float32)
    with torch.no_grad():
        logits = tab_model(x).numpy().ravel()
    return 1.0 / (1.0 + np.exp(-logits))

X_train = train_df[["age","ca125","brca"]].values
X_test  = test_df[["age","ca125","brca"]].values

bg = X_train[np.random.RandomState(0).choice(len(X_train), size=min(100, len(X_train)), replace=False)]
explainer = shap.KernelExplainer(tab_predict, bg)

X_sample = X_test[: min(200, len(X_test))]
shap_values = explainer.shap_values(X_sample, nsamples=100)

# Global
shap.summary_plot(shap_values, X_sample, feature_names=["age","ca125","brca"], show=False)
plt.tight_layout()
plt.savefig(ASSETS/"shap/global.png", dpi=150, bbox_inches="tight")
plt.close()
print("Saved:", ASSETS/"shap/global.png")

# Two locals
for i in range(min(2, len(X_sample))):
    shap.force_plot(explainer.expected_value, shap_values[i], X_sample[i],
                    feature_names=["age","ca125","brca"], matplotlib=True, show=False)
    plt.tight_layout()
    outp = ASSETS/f"shap/local_{i}.png"
    plt.savefig(outp, dpi=150, bbox_inches="tight")
    plt.close()
    print("Saved:", outp)

100%|██████████| 200/200 [00:01<00:00, 129.54it/s]


Saved: ..\assets\shap\global.png
Saved: ..\assets\shap\local_0.png
Saved: ..\assets\shap\local_1.png


In [22]:
from collections import OrderedDict

def strip_all_hooks(model: torch.nn.Module):
    # wipe all hooks on every submodule
    for m in model.modules():
        if hasattr(m, "_forward_hooks") and isinstance(m._forward_hooks, dict):
            m._forward_hooks = OrderedDict()
        if hasattr(m, "_forward_pre_hooks") and isinstance(m._forward_pre_hooks, dict):
            m._forward_pre_hooks = OrderedDict()
        if hasattr(m, "_backward_hooks") and isinstance(m._backward_hooks, dict):
            m._backward_hooks = OrderedDict()

# Clear hooks and re-eval
strip_all_hooks(img_model)
strip_all_hooks(tab_model)
strip_all_hooks(fused_model)

# (optional but clean) reload weights to guarantee a fresh graph
img_model.load_state_dict(torch.load(MODELS/"image_only.pth", map_location=device))
tab_model.load_state_dict(torch.load(MODELS/"tabular_only.pth", map_location=device))
fused_model.load_state_dict(torch.load(MODELS/"fused.pth", map_location=device))

img_model.eval(); tab_model.eval(); fused_model.eval()


FusedModel(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [23]:
from sklearn.metrics import roc_curve, confusion_matrix
import seaborn as sns

def collect_preds(model, loader):
    model.eval()
    y_true, y_prob = [], []
    with torch.no_grad():
        for imgs, tabs, labels in loader:
            imgs, tabs = imgs.to(device), tabs.to(device)
            logits = forward_any(model, imgs, tabs)
            y_prob.extend(torch.sigmoid(logits).cpu().numpy().ravel())
            y_true.extend(labels.numpy().ravel())
    y_true = np.array(y_true); y_prob = np.array(y_prob)
    y_pred = (y_prob > 0.5).astype(int)
    return y_true, y_prob, y_pred

def dump_model_results(name, model):
    y_true, y_prob, y_pred = collect_preds(model, test_loader)
    res = {
        "accuracy":  float(accuracy_score(y_true, y_pred)),
        "precision": float(precision_score(y_true, y_pred)),
        "recall":    float(recall_score(y_true, y_pred)),
        "f1":        float(f1_score(y_true, y_pred)),
        "roc_auc":   float(roc_auc_score(y_true, y_prob)),
    }
    # ROC curve
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure(figsize=(4,4))
    plt.plot(fpr, tpr, label=f"{name} (AUC={res['roc_auc']:.3f})")
    plt.plot([0,1],[0,1],"--",alpha=0.5)
    plt.xlabel("FPR"); plt.ylabel("TPR"); plt.title(f"ROC - {name}"); plt.legend()
    (ASSETS/"metrics").mkdir(parents=True, exist_ok=True)
    plt.savefig(ASSETS/f"metrics/roc_{name}.png", dpi=150, bbox_inches="tight"); plt.close()

    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(3.2,3))
    sns.heatmap(cm, annot=True, fmt="d", cbar=False)
    plt.xlabel("Predicted"); plt.ylabel("True"); plt.title(f"CM - {name}")
    plt.savefig(ASSETS/f"metrics/cm_{name}.png", dpi=150, bbox_inches="tight"); plt.close()

    return res

results = {
    "image_only":  dump_model_results("image_only", img_model),
    "tabular_only":dump_model_results("tabular_only", tab_model),
    "fused":       dump_model_results("fused", fused_model),
}
with open(ART/"metrics.json","w") as f:
    json.dump(results, f, indent=2)
print(json.dumps(results, indent=2))
print("Saved metrics.json and ROC/CM plots in assets/metrics/")


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{
  "image_only": {
    "accuracy": 0.9700854700854701,
    "precision": 0.0,
    "recall": 0.0,
    "f1": 0.0,
    "roc_auc": 0.8023914411579609
  },
  "tabular_only": {
    "accuracy": 0.9700854700854701,
    "precision": 0.0,
    "recall": 0.0,
    "f1": 0.0,
    "roc_auc": 1.0
  },
  "fused": {
    "accuracy": 0.9700854700854701,
    "precision": 0.0,
    "recall": 0.0,
    "f1": 0.0,
    "roc_auc": 0.9559471365638766
  }
}
Saved metrics.json and ROC/CM plots in assets/metrics/
