<a href="https://colab.research.google.com/github/Mehedi16009/BreastCancer-ViTRegNet-XAI/blob/main/12th_August_Breast_Cancer_ViT%2BRegNet_50.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Cell 1: Install dependencies
!pip install -q timm grad-cam albumentations opencv-python-headless scikit-learn torchmetrics

# (torch & torchvision are already present in Colab typically)


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m63.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.0/983.0 kB[0m [31m61.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m99.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m30.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m58.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.2 MB/s[0m eta 

In [2]:
# Cell 2: imports and seeds
import os, random, time, math
from glob import glob
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score

import timm
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cuda


In [3]:
# Cell 3: Mount Google Drive and read CSVs
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

CSV_DIR = '/content/drive/MyDrive/archive/csv'
BASE_DIR = '/content/drive/MyDrive/archive/jpeg'   # where the jpeg/<patient_id>/*jpg files live

mass_csv = os.path.join(CSV_DIR, 'mass_case_description_train_set.csv')
calc_csv = os.path.join(CSV_DIR, 'calc_case_description_train_set.csv')
meta_csv = os.path.join(CSV_DIR, 'meta.csv')

print("CSV paths exist:",
      os.path.exists(mass_csv),
      os.path.exists(calc_csv),
      os.path.exists(meta_csv))

mass_df = pd.read_csv(mass_csv)
calc_df = pd.read_csv(calc_csv)
meta_df = pd.read_csv(meta_csv)

print("mass rows:", len(mass_df), "calc rows:", len(calc_df), "meta rows:", len(meta_df))
print("mass columns:", mass_df.columns.tolist())
print("calc columns:", calc_df.columns.tolist())
print(mass_df.head(3))


Mounted at /content/drive
CSV paths exist: True True True
mass rows: 1318 calc rows: 1546 meta rows: 6775
mass columns: ['patient_id', 'breast_density', 'left or right breast', 'image view', 'abnormality id', 'abnormality type', 'mass shape', 'mass margins', 'assessment', 'pathology', 'subtlety', 'image file path', 'cropped image file path', 'ROI mask file path']
calc columns: ['patient_id', 'breast density', 'left or right breast', 'image view', 'abnormality id', 'abnormality type', 'calc type', 'calc distribution', 'assessment', 'pathology', 'subtlety', 'image file path', 'cropped image file path', 'ROI mask file path']
  patient_id  breast_density left or right breast image view  abnormality id  \
0    P_00001               3                 LEFT         CC               1   
1    P_00001               3                 LEFT        MLO               1   
2    P_00004               3                 LEFT         CC               1   

  abnormality type                          mass 

In [4]:
# Build mapping from UID folder -> label
uid_to_label = {}

def add_uid_mapping(df, label_col='pathology', path_col='image file path'):
    for _, row in df.iterrows():
        path_str = str(row[path_col]).strip()
        if not path_str or path_str.lower() == 'nan':
            continue
        parts = path_str.split('/')
        if len(parts) < 3:
            continue
        # second-to-last folder in CSV path is the JPEG folder name
        uid_folder = parts[-2]
        label = str(row[label_col]).strip().lower()
        uid_to_label[uid_folder] = 1 if label == 'malignant' else 0

add_uid_mapping(mass_df)
add_uid_mapping(calc_df)

print(f"Total UID mappings: {len(uid_to_label)}")

# Match JPEG files
image_records = []
missed = 0

for patient_uid in os.listdir(BASE_DIR):
    patient_dir = os.path.join(BASE_DIR, patient_uid)
    if not os.path.isdir(patient_dir):
        continue
    if patient_uid in uid_to_label:
        label_int = uid_to_label[patient_uid]
        for fn in os.listdir(patient_dir):
            if fn.lower().endswith('.jpg'):
                image_records.append((os.path.join(patient_dir, fn), label_int))
    else:
        missed += 1

print(f"Matched JPEG images: {len(image_records)}")
print(f"Unmatched UID folders: {missed}")
print("Sample matches:", image_records[:5])


Total UID mappings: 2458
Matched JPEG images: 2458
Unmatched UID folders: 4316
Sample matches: [('/content/drive/MyDrive/archive/jpeg/1.3.6.1.4.1.9590.100.1.2.71928301212219609314072252602966948877/1-081.jpg', 0), ('/content/drive/MyDrive/archive/jpeg/1.3.6.1.4.1.9590.100.1.2.182177787312682499236296341930939113234/1-060.jpg', 0), ('/content/drive/MyDrive/archive/jpeg/1.3.6.1.4.1.9590.100.1.2.412935186912567939930947831442475559499/1-093.jpg', 1), ('/content/drive/MyDrive/archive/jpeg/1.3.6.1.4.1.9590.100.1.2.75074886911317428336210351841116845645/1-123.jpg', 1), ('/content/drive/MyDrive/archive/jpeg/1.3.6.1.4.1.9590.100.1.2.398436206112075190335811098902545958874/1-243.jpg', 0)]


In [5]:
import os
import pandas as pd

# === Paths ===
jpeg_root = "/content/drive/MyDrive/archive/jpeg"
mass_csv_path = "/content/drive/MyDrive/archive/csv/mass_case_description_train_set.csv"
calc_csv_path = "/content/drive/MyDrive/archive/csv/calc_case_description_train_set.csv"

# === Load CSVs ===
mass_df = pd.read_csv(mass_csv_path)
calc_df = pd.read_csv(calc_csv_path)

# Combine CSVs
all_csv_paths = pd.concat([mass_df['image file path'], calc_df['image file path']]).tolist()

# === Build UID mappings from CSVs (CORRECT LEVEL) ===
uid_mappings = {}
for path in all_csv_paths:
    parts = path.split('/')
    if len(parts) >= 2:
        uid = parts[-2]  # folder before the DICOM file
        label = 0 if 'calc' in path.lower() else 1
        uid_mappings[uid] = label

print("Total UID mappings:", len(uid_mappings))

# === Compare with JPEG folder ===
jpeg_uids = set(os.listdir(jpeg_root))
matched_uids = set(uid_mappings.keys()) & jpeg_uids
unmatched_uids = jpeg_uids - matched_uids

print("Matched JPEG images:", len(matched_uids))
print("Unmatched UID folders:", len(unmatched_uids))
print("Sample matched:", list(matched_uids)[:5])
print("Sample unmatched:", list(unmatched_uids)[:5])


Total UID mappings: 2458
Matched JPEG images: 2458
Unmatched UID folders: 4317
Sample matched: ['1.3.6.1.4.1.9590.100.1.2.106917101213622946607766232832878006143', '1.3.6.1.4.1.9590.100.1.2.247139233113856247217565715372804656493', '1.3.6.1.4.1.9590.100.1.2.214159660811754034822231515062693367987', '1.3.6.1.4.1.9590.100.1.2.93262706812708085503412661002234752608', '1.3.6.1.4.1.9590.100.1.2.117311526211769399912345133122658800626']
Sample unmatched: ['1.3.6.1.4.1.9590.100.1.2.304902787311311619535166611670640395210', '1.3.6.1.4.1.9590.100.1.2.318836487510466400123723183690561506321', '1.3.6.1.4.1.9590.100.1.2.323441279313490933607793228002205134456', '1.3.6.1.4.1.9590.100.1.2.136770106912670810407098991750285575490', '1.3.6.1.4.1.9590.100.1.2.142970723512367328629685306071766834363']


In [6]:
import os
import pandas as pd

# === Paths ===
jpeg_root = "/content/drive/MyDrive/archive/jpeg"
csv_dir = "/content/drive/MyDrive/archive/csv"

# === Load ALL CSVs ===
mass_train = pd.read_csv(os.path.join(csv_dir, "mass_case_description_train_set.csv"))
mass_test  = pd.read_csv(os.path.join(csv_dir, "mass_case_description_test_set.csv"))
calc_train = pd.read_csv(os.path.join(csv_dir, "calc_case_description_train_set.csv"))
calc_test  = pd.read_csv(os.path.join(csv_dir, "calc_case_description_test_set.csv"))

# Combine them all
all_df = pd.concat([mass_train, mass_test, calc_train, calc_test], ignore_index=True)

# === Build UID mappings from CSVs ===
uid_mappings = {}
for path in all_df['image file path']:
    parts = path.split('/')
    if len(parts) >= 2:
        uid = parts[-2]  # final UID folder before the DICOM filename
        label_str = path.lower()
        # label 1 for malignant, 0 for benign
        if 'malignant' in label_str:
            label = 1
        else:
            label = 0
        uid_mappings[uid] = label

print("Total UID mappings:", len(uid_mappings))

# === Compare with JPEG folder ===
jpeg_uids = set(os.listdir(jpeg_root))
matched_uids = set(uid_mappings.keys()) & jpeg_uids
unmatched_uids = jpeg_uids - matched_uids

print("Matched JPEG images:", len(matched_uids))
print("Unmatched UID folders:", len(unmatched_uids))
print("Sample matched:", list(matched_uids)[:5])
print("Sample unmatched:", list(unmatched_uids)[:5])


Total UID mappings: 3103
Matched JPEG images: 3103
Unmatched UID folders: 3672
Sample matched: ['1.3.6.1.4.1.9590.100.1.2.318836487510466400123723183690561506321', '1.3.6.1.4.1.9590.100.1.2.106917101213622946607766232832878006143', '1.3.6.1.4.1.9590.100.1.2.247139233113856247217565715372804656493', '1.3.6.1.4.1.9590.100.1.2.272917492411709393015036949944104292812', '1.3.6.1.4.1.9590.100.1.2.214159660811754034822231515062693367987']
Sample unmatched: ['1.3.6.1.4.1.9590.100.1.2.304902787311311619535166611670640395210', '1.3.6.1.4.1.9590.100.1.2.323441279313490933607793228002205134456', '1.3.6.1.4.1.9590.100.1.2.136770106912670810407098991750285575490', '1.3.6.1.4.1.9590.100.1.2.142970723512367328629685306071766834363', '1.3.6.1.4.1.9590.100.1.2.195924545611755277629845871882119525289']


In [7]:
import shutil

# Directory for filtered dataset
filtered_dir = "/content/drive/MyDrive/archive/filtered_jpeg"
os.makedirs(filtered_dir, exist_ok=True)

for uid in matched_uids:
    src_dir = os.path.join(jpeg_root, uid)
    dst_dir = os.path.join(filtered_dir, uid)
    if not os.path.exists(dst_dir):
        shutil.copytree(src_dir, dst_dir)

print("Filtered dataset ready at:", filtered_dir)


Filtered dataset ready at: /content/drive/MyDrive/archive/filtered_jpeg


In [8]:
import os
import shutil

jpeg_root = "/content/drive/MyDrive/archive/jpeg"
matched_uids = set(uid_mappings.keys())  # from your matching script

removed_count = 0

for folder in os.listdir(jpeg_root):
    folder_path = os.path.join(jpeg_root, folder)
    if os.path.isdir(folder_path) and folder not in matched_uids:
        shutil.rmtree(folder_path)
        removed_count += 1

print(f"Removed {removed_count} unmatched UID folders.")


Removed 3671 unmatched UID folders.


In [9]:
# Cell 6: Create stratified train/val/test split
paths = [r[0] for r in image_records]
labels = [r[1] for r in image_records]

if len(paths) < 10:
    raise RuntimeError("Too few images matched. Check CSV mapping and BASE_DIR paths.")

train_paths, temp_paths, train_labels, temp_labels = train_test_split(
    paths, labels, test_size=0.2, random_state=42, stratify=labels)

val_paths, test_paths, val_labels, test_labels = train_test_split(
    temp_paths, temp_labels, test_size=0.5, random_state=42, stratify=temp_labels)

train_records = list(zip(train_paths, train_labels))
val_records = list(zip(val_paths, val_labels))
test_records = list(zip(test_paths, test_labels))

print("Splits -> train:", len(train_records), "val:", len(val_records), "test:", len(test_records))
print("Train positive ratio:", np.mean([l for _,l in train_records]))


Splits -> train: 1966 val: 246 test: 246
Train positive ratio: 0.4496439471007121


In [10]:
# --- Cell 6: Create stratified train/val/test split ---

import numpy as np
from sklearn.model_selection import train_test_split

# Ensure we have data
if not image_records or len(image_records) < 10:
    raise RuntimeError(f"Too few images matched ({len(image_records)}). "
                       "Check CSV mapping and filtered dataset path.")

# Separate paths and labels
paths  = np.array([rec[0] for rec in image_records])
labels = np.array([rec[1] for rec in image_records])

# Stratified Train/Test split (80% train, 20% temp)
train_paths, temp_paths, train_labels, temp_labels = train_test_split(
    paths, labels, test_size=0.2, random_state=42, stratify=labels
)

# Stratified Temp -> Val/Test split (50% val, 50% test from temp)
val_paths, test_paths, val_labels, test_labels = train_test_split(
    temp_paths, temp_labels, test_size=0.5, random_state=42, stratify=temp_labels
)

# Combine into records
train_records = list(zip(train_paths, train_labels))
val_records   = list(zip(val_paths, val_labels))
test_records  = list(zip(test_paths, test_labels))

# Print summary
print(f"Splits -> train: {len(train_records)}, val: {len(val_records)}, test: {len(test_records)}")
print(f"Train positive ratio: {np.mean([lbl for _, lbl in train_records]):.3f}")
print(f"Val   positive ratio: {np.mean([lbl for _, lbl in val_records]):.3f}")
print(f"Test  positive ratio: {np.mean([lbl for _, lbl in test_records]):.3f}")


Splits -> train: 1966, val: 246, test: 246
Train positive ratio: 0.450
Val   positive ratio: 0.451
Test  positive ratio: 0.447


In [11]:
# ---------------- Cell 7 ----------------
# Dataset class, transforms & DataLoaders (improved)
import os, random
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
import numpy as np

# Output folder for figures
OUT_DIR = "/content/drive/MyDrive/cbis_ddsm_results"
os.makedirs(OUT_DIR, exist_ok=True)

class MammogramDataset(Dataset):
    def __init__(self, records, transform=None):
        """
        records: list of (path, label_int)
        transform: torchvision transforms to apply
        """
        self.records = list(records)
        self.transform = transform

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        path, label = self.records[idx]
        try:
            img = Image.open(path).convert('RGB')
        except Exception as e:
            # fallback: black image
            print(f"Warning: couldn't open {path!r}: {e}")
            img = Image.new('RGB', (IMG_SIZE, IMG_SIZE), (0,0,0))
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(label, dtype=torch.float32)

# Image size (keep 224 for ViT-base patch16_224)
IMG_SIZE = 224

# stronger augmentation for training
train_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(0.5),
    T.RandomVerticalFlip(0.1),
    T.RandomRotation(10),
    T.ColorJitter(brightness=0.12, contrast=0.12, saturation=0.06),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

# minimal transforms for val/test
eval_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

# Create dataset objects (expects train_records / val_records / test_records defined)
train_ds = MammogramDataset(train_records, transform=train_transform)
val_ds   = MammogramDataset(val_records, transform=eval_transform)
test_ds  = MammogramDataset(test_records, transform=eval_transform)

BATCH_SIZE = 16   # reduce if OOM
NUM_WORKERS = 2

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

print("Dataset sizes: train", len(train_ds), "val", len(val_ds), "test", len(test_ds))


Dataset sizes: train 1966 val 246 test 246


In [12]:
# ---------------- Cell 8 ----------------
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model

class FusionModel(nn.Module):
    def __init__(self, vit_name='vit_base_patch16_224', regnet_name='regnety_032', pretrained=True, dropout=0.3):
        super().__init__()
        self.vit = create_model(vit_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        self.regnet = create_model(regnet_name, pretrained=pretrained, num_classes=0, global_pool='avg')

        # try to detect num_features
        self.vit_feat = getattr(self.vit, 'num_features', None) or getattr(self.vit, 'embed_dim', None) or 768
        self.regnet_feat = getattr(self.regnet, 'num_features', None) or getattr(self.regnet, 'head', None) and getattr(self.regnet.head, 'in_features', None) or 1024

        fused_dim = int(self.vit_feat + self.regnet_feat)
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, 1)  # binary logit
        )

    def forward(self, x):
        v = self.vit.forward_features(x)
        r = self.regnet.forward_features(x)

        # ensure global pooled vectors
        if v.dim() > 2:
            if v.dim() == 3:  # [B, N, C] -> avg over tokens
                v = v.mean(dim=1)
            else:
                v = F.adaptive_avg_pool2d(v, 1).reshape(v.size(0), -1)
        if r.dim() > 2:
            if r.dim() == 3:
                r = r.mean(dim=1)
            else:
                r = F.adaptive_avg_pool2d(r, 1).reshape(r.size(0), -1)

        x = torch.cat([v, r], dim=1)
        logits = self.classifier(x).squeeze(1)
        return logits

# build and move to device
model = FusionModel(vit_name='vit_base_patch16_224', regnet_name='regnety_032', pretrained=True, dropout=0.3)
model = model.to(device)
print("Model built. vit_feat, regnet_feat:", model.vit_feat, model.regnet_feat)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/78.1M [00:00<?, ?B/s]

Model built. vit_feat, regnet_feat: 768 1512


In [13]:
# ---------------- Cell 9 ----------------
from torch.optim import AdamW
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, precision_recall_curve, auc

# compute pos_weight (safe)
train_pos = sum([lbl for _,lbl in train_records])
train_neg = len(train_records) - train_pos
if train_pos == 0:
    raise RuntimeError("No positive samples in training set!")
pos_weight = torch.tensor([(train_neg / train_pos) if train_pos > 0 else 1.0], dtype=torch.float32).to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# training bookkeeping
history = {
    'train_loss': [], 'train_auc': [], 'train_acc': [],
    'val_loss': [], 'val_auc': [], 'val_acc': [],
    'val_precision': [], 'val_recall': [], 'val_f1': []
}

def sigmoid_np(x):
    return 1.0/(1.0+np.exp(-x))

def train_one_epoch(model, loader, optimizer, criterion, device, scaler):
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    for imgs, labels in loader:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            logits = model(imgs)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * imgs.size(0)
        probs = torch.sigmoid(logits).detach().cpu().numpy()
        all_preds.extend(probs.tolist())
        all_labels.extend(labels.detach().cpu().numpy().tolist())
    epoch_loss = running_loss / len(loader.dataset)
    try:
        epoch_auc = roc_auc_score(all_labels, all_preds)
    except Exception:
        epoch_auc = float('nan')
    preds_bin = [1 if p>0.5 else 0 for p in all_preds]
    epoch_acc = accuracy_score(all_labels, preds_bin)
    return epoch_loss, epoch_auc, epoch_acc

def validate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            logits = model(imgs)
            loss = criterion(logits, labels)
            running_loss += loss.item() * imgs.size(0)
            probs = torch.sigmoid(logits).detach().cpu().numpy()
            all_preds.extend(probs.tolist())
            all_labels.extend(labels.detach().cpu().numpy().tolist())
    val_loss = running_loss / len(loader.dataset)
    try:
        val_auc = roc_auc_score(all_labels, all_preds)
    except Exception:
        val_auc = float('nan')
    preds_bin = [1 if p>0.5 else 0 for p in all_preds]
    try:
        val_acc = accuracy_score(all_labels, preds_bin)
        val_precision = precision_score(all_labels, preds_bin, zero_division=0)
        val_recall = recall_score(all_labels, preds_bin, zero_division=0)
        val_f1 = f1_score(all_labels, preds_bin, zero_division=0)
    except Exception:
        val_acc = val_precision = val_recall = val_f1 = float('nan')
    return val_loss, val_auc, val_acc, val_precision, val_recall, val_f1, all_labels, all_preds


  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


In [14]:
# ---------------- Cell 10 ----------------
import time, math
EPOCHS = 10
best_val_auc = -1.0
save_path = os.path.join(OUT_DIR, 'cbis_ddsm_fusion_best.pth')

for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    train_loss, train_auc, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
    val_loss, val_auc, val_acc, val_prec, val_rec, val_f1, _, _ = validate(model, val_loader, criterion, device)
    scheduler.step(val_auc if not math.isnan(val_auc) else val_loss)

    # record to history
    history['train_loss'].append(train_loss); history['train_auc'].append(train_auc); history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss); history['val_auc'].append(val_auc); history['val_acc'].append(val_acc)
    history['val_precision'].append(val_prec); history['val_recall'].append(val_rec); history['val_f1'].append(val_f1)

    t1 = time.time()
    print(f"Epoch {epoch}/{EPOCHS} time={t1-t0:.1f}s train_loss={train_loss:.4f} train_auc={train_auc:.4f} train_acc={train_acc:.4f} val_loss={val_loss:.4f} val_auc={val_auc:.4f} val_acc={val_acc:.4f} val_f1={val_f1:.4f}")

    # Save best
    if not math.isnan(val_auc) and val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'val_auc': val_auc, 'history': history}, save_path)
        print("Saved best model to:", save_path)


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Epoch 1/10 time=863.0s train_loss=0.7831 train_auc=0.5959 train_acc=0.5605 val_loss=0.8178 val_auc=0.6289 val_acc=0.5203 val_f1=0.6467
Saved best model to: /content/drive/MyDrive/cbis_ddsm_results/cbis_ddsm_fusion_best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Epoch 2/10 time=381.8s train_loss=0.6966 train_auc=0.6945 train_acc=0.6236 val_loss=0.6940 val_auc=0.7311 val_acc=0.6341 val_f1=0.6831
Saved best model to: /content/drive/MyDrive/cbis_ddsm_results/cbis_ddsm_fusion_best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Epoch 3/10 time=408.2s train_loss=0.6422 train_auc=0.7569 train_acc=0.6780 val_loss=0.6312 val_auc=0.7734 val_acc=0.6951 val_f1=0.6964
Saved best model to: /content/drive/MyDrive/cbis_ddsm_results/cbis_ddsm_fusion_best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Epoch 4/10 time=411.4s train_loss=0.6073 train_auc=0.7830 train_acc=0.6882 val_loss=0.6412 val_auc=0.7888 val_acc=0.6870 val_f1=0.6351
Saved best model to: /content/drive/MyDrive/cbis_ddsm_results/cbis_ddsm_fusion_best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Epoch 5/10 time=406.8s train_loss=0.5658 train_auc=0.8228 train_acc=0.7309 val_loss=0.8374 val_auc=0.7622 val_acc=0.6870 val_f1=0.5882


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Epoch 6/10 time=375.6s train_loss=0.5526 train_auc=0.8339 train_acc=0.7467 val_loss=0.6305 val_auc=0.7841 val_acc=0.6789 val_f1=0.6973


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Epoch 7/10 time=376.7s train_loss=0.5198 train_auc=0.8554 train_acc=0.7675 val_loss=0.6485 val_auc=0.7767 val_acc=0.7033 val_f1=0.7068


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Epoch 8/10 time=374.2s train_loss=0.4414 train_auc=0.8989 train_acc=0.8133 val_loss=0.6953 val_auc=0.7915 val_acc=0.6911 val_f1=0.6162
Saved best model to: /content/drive/MyDrive/cbis_ddsm_results/cbis_ddsm_fusion_best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Epoch 9/10 time=404.4s train_loss=0.3966 train_auc=0.9188 train_acc=0.8276 val_loss=0.7795 val_auc=0.7893 val_acc=0.7073 val_f1=0.6842


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Epoch 10/10 time=379.2s train_loss=0.3655 train_auc=0.9315 train_acc=0.8505 val_loss=0.7548 val_auc=0.7760 val_acc=0.6748 val_f1=0.6154


In [15]:
# ---------------- Cell 11 ----------------
import torch
from sklearn.metrics import precision_recall_curve, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

ckpt = torch.load(save_path, map_location=device, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
print("Loaded checkpoint epoch", ckpt.get('epoch'), "val_auc", ckpt.get('val_auc'))

# Evaluate on test set and collect preds
model.eval()
test_loss = 0.0
all_preds = []
all_labels = []
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        logits = model(imgs)
        loss = criterion(logits, labels)
        test_loss += loss.item() * imgs.size(0)
        probs = torch.sigmoid(logits).cpu().numpy().tolist()
        all_preds.extend(probs)
        all_labels.extend(labels.cpu().numpy().tolist())

avg_test_loss = test_loss / len(test_loader.dataset)
test_auc = roc_auc_score(all_labels, all_preds)
preds_bin = [1 if p>0.5 else 0 for p in all_preds]
test_acc = accuracy_score(all_labels, preds_bin)
test_prec = precision_score(all_labels, preds_bin, zero_division=0)
test_rec = recall_score(all_labels, preds_bin, zero_division=0)
test_f1 = f1_score(all_labels, preds_bin, zero_division=0)

print(f"Test loss: {avg_test_loss:.4f}  AUC: {test_auc:.4f}  Acc: {test_acc:.4f}  Prec: {test_prec:.4f}  Rec: {test_rec:.4f}  F1: {test_f1:.4f}")

# --- Precision-Recall curve (HD) ---
precision, recall, thresholds = precision_recall_curve(all_labels, all_preds)
pr_auc = auc(recall, precision)

plt.figure(figsize=(8,6), dpi=300)
sns.set(style="whitegrid")
plt.plot(recall, precision, color='#1f77b4', linewidth=2)
plt.fill_between(recall, precision, alpha=0.15, color='#1f77b4')
plt.xlabel("Recall", fontsize=14)
plt.ylabel("Precision", fontsize=14)
plt.title(f"Precision-Recall curve (AUC={pr_auc:.4f})", fontsize=16)
plt.xlim(0,1); plt.ylim(0,1)
plt.tight_layout()
pr_path = os.path.join(OUT_DIR, "precision_recall_curve.png")
plt.savefig(pr_path, dpi=300, bbox_inches='tight')
plt.close()
print("Saved PR curve:", pr_path)

# --- Confusion Matrix (HD) ---
cm = confusion_matrix(all_labels, preds_bin)
plt.figure(figsize=(6,5), dpi=300)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False, annot_kws={"size":14})
plt.xlabel("Predicted", fontsize=12); plt.ylabel("Actual", fontsize=12)
plt.title(f"Confusion Matrix (Acc={test_acc:.3f}, F1={test_f1:.3f})", fontsize=14)
cm_path = os.path.join(OUT_DIR, "confusion_matrix.png")
plt.savefig(cm_path, dpi=300, bbox_inches='tight')
plt.close()
print("Saved Confusion Matrix:", cm_path)


Loaded checkpoint epoch 8 val_auc 0.791524858191525
Test loss: 0.6548  AUC: 0.8074  Acc: 0.6951  Prec: 0.7011  Rec: 0.5545  F1: 0.6193
Saved PR curve: /content/drive/MyDrive/cbis_ddsm_results/precision_recall_curve.png
Saved Confusion Matrix: /content/drive/MyDrive/cbis_ddsm_results/confusion_matrix.png


In [16]:
# ---------------- Cell 12 ----------------
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="whitegrid")

epochs = np.arange(1, len(history['train_loss'])+1)

# Loss curve
plt.figure(figsize=(10,6), dpi=300)
plt.plot(epochs, history['train_loss'], label='Train Loss', color='#d62728', linewidth=2)
plt.plot(epochs, history['val_loss'], label='Val Loss', color='#2ca02c', linewidth=2)
plt.scatter(epochs, history['train_loss'], color='#d62728', s=20)
plt.scatter(epochs, history['val_loss'], color='#2ca02c', s=20)
plt.xlabel("Epoch", fontsize=14); plt.ylabel("Loss", fontsize=14)
plt.title("Train & Validation Loss", fontsize=16)
plt.legend(fontsize=12)
plt.tight_layout()
loss_path = os.path.join(OUT_DIR, "loss_curve.png")
plt.savefig(loss_path, dpi=300, bbox_inches='tight')
plt.close()
print("Saved Loss curve:", loss_path)

# Accuracy curve
plt.figure(figsize=(10,6), dpi=300)
plt.plot(epochs, history['train_acc'], label='Train Accuracy', color='#9467bd', linewidth=2)
plt.plot(epochs, history['val_acc'], label='Val Accuracy', color='#8c564b', linewidth=2)
plt.xlabel("Epoch", fontsize=14); plt.ylabel("Accuracy", fontsize=14)
plt.title("Train & Validation Accuracy", fontsize=16)
plt.legend(fontsize=12)
plt.tight_layout()
acc_path = os.path.join(OUT_DIR, "accuracy_curve.png")
plt.savefig(acc_path, dpi=300, bbox_inches='tight')
plt.close()
print("Saved Accuracy curve:", acc_path)

# F1 curve (val)
plt.figure(figsize=(10,6), dpi=300)
plt.plot(epochs, history['val_f1'], label='Val F1', color='#17becf', linewidth=2)
plt.scatter(epochs, history['val_f1'], color='#17becf', s=20)
plt.xlabel("Epoch", fontsize=14); plt.ylabel("F1 Score", fontsize=14)
plt.title("Validation F1 Score", fontsize=16)
plt.tight_layout()
f1_path = os.path.join(OUT_DIR, "f1_curve.png")
plt.savefig(f1_path, dpi=300, bbox_inches='tight')
plt.close()
print("Saved F1 curve:", f1_path)


Saved Loss curve: /content/drive/MyDrive/cbis_ddsm_results/loss_curve.png
Saved Accuracy curve: /content/drive/MyDrive/cbis_ddsm_results/accuracy_curve.png
Saved F1 curve: /content/drive/MyDrive/cbis_ddsm_results/f1_curve.png


In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [4]:
# ---------------- Cell 13 (REPLACEMENT) ----------------
# Manual Grad-CAM hooks -> robust across pytorch-grad-cam versions
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os


# Output file
OUT_DIR = "/content/drive/MyDrive/cbis_ddsm_results"
os.makedirs(OUT_DIR, exist_ok=True)
# Define image size before Grad-CAM cell
IMG_SIZE = 224
# Helpers
def preprocess_pil_for_model(pil_img, size=IMG_SIZE):
    pil_resized = pil_img.resize((size, size))
    # return both normalized tensor (for model) and original RGB array (for overlay)
    t = eval_transform(pil_resized)  # normalized tensor (C,H,W)
    rgb = np.asarray(pil_resized).astype(np.float32) / 255.0  # HWC 0..1
    return t.unsqueeze(0).to(device), rgb

def compute_gradcam_single_layer(model, input_tensor, target_layer, target_label=1):
    """
    Compute Grad-CAM heatmap (H x W numpy, values 0..1) for a single Conv2d target_layer.
    target_label: 1 -> positive logit, 0 -> complementary (we take -logit)
    """
    model.eval()
    activations = []
    gradients = []

    # forward hook
    def forward_hook(module, inp, out):
        activations.append(out)
    # backward hook
    def backward_hook(module, grad_in, grad_out):
        # grad_out is a tuple; grad_out[0] is gradient wrt module output
        gradients.append(grad_out[0])

    h_fwd = target_layer.register_forward_hook(forward_hook)
    # register full backward hook if available, otherwise fallback
    if hasattr(target_layer, "register_full_backward_hook"):
        h_bwd = target_layer.register_full_backward_hook(lambda m, gi, go: gradients.append(go[0]))
    else:
        h_bwd = target_layer.register_backward_hook(lambda m, gi, go: gradients.append(go[0]))

    # Forward
    model.zero_grad()
    out = model(input_tensor)  # expected shape [1] (single logit)
    if out is None:
        h_fwd.remove(); h_bwd.remove()
        raise RuntimeError("Model returned None. Check forward pass.")
    # out shape might be [1] or [1,] or [1,1]; get scalar logit
    logit = out.reshape(-1)[0]

    # For binary single-logit model:
    # - to get gradients for class 1 (malignant) differentiate logit
    # - to get gradients for class 0 differentiate -logit
    if int(target_label) == 1:
        score = logit
    else:
        score = -logit

    score.backward(retain_graph=True)

    # Ensure we captured something
    if len(activations) == 0 or len(gradients) == 0:
        h_fwd.remove(); h_bwd.remove()
        raise RuntimeError("Failed to capture activations or gradients. Hook did not fire.")

    act = activations[-1].detach()   # tensor [1, C, H, W]
    grad = gradients[-1].detach()    # tensor [1, C, H, W]

    # Global average pooling of gradients -> weights
    weights = torch.mean(grad, dim=(2,3), keepdim=True)   # [1,C,1,1]
    # Weighted combination
    cam = torch.sum(weights * act, dim=1).squeeze(0)  # [H, W]
    cam = F.relu(cam)

    # Normalize to 0..1
    cam -= cam.min()
    if cam.max() > 0:
        cam = cam / cam.max()
    cam_np = cam.cpu().numpy()

    # remove hooks
    h_fwd.remove(); h_bwd.remove()
    return cam_np

def upsample_cam(cam, size=(IMG_SIZE, IMG_SIZE)):
    # cam is HxW float 0..1 -> resize to size
    cam_resized = cv2.resize(cam, (size[1], size[0]), interpolation=cv2.INTER_LINEAR)
    cam_resized = np.clip(cam_resized, 0, 1)
    return cam_resized

def apply_colormap_on_image(img_rgb, cam, colormap=cv2.COLORMAP_JET, alpha=0.5):
    """
    img_rgb: HxWx3 float 0..1
    cam: HxW float 0..1
    returns overlay_rgb float 0..1
    """
    heatmap_255 = np.uint8(255 * cam)
    heatmap_color = cv2.applyColorMap(heatmap_255, colormap)  # BGR uint8
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    overlay = (1.0 - alpha) * img_rgb + alpha * heatmap_color
    overlay = np.clip(overlay, 0, 1)
    return heatmap_color, overlay

# Determine layers (we already printed them earlier)
print("vit_last_conv:", vit_last_conv)
print("regnet_last_conv:", regnet_last_conv)
if vit_last_conv is None and regnet_last_conv is None:
    raise RuntimeError("No conv layers found for either backbone; cannot compute Grad-CAM.")

# Pick some examples from test set
N = 6
samples = test_records[:N]

# Create figure large + high DPI
fig = plt.figure(figsize=(16, 3 * N), dpi=300)

for i, (path, true_label) in enumerate(samples):
    pil = Image.open(path).convert('RGB')
    input_t, rgb_img = preprocess_pil_for_model(pil, size=IMG_SIZE)  # input_t on device, rgb_img HWC 0..1

    # predict
    model.eval()
    with torch.no_grad():
        logit = model(input_t)
        prob = torch.sigmoid(logit).cpu().item()
        pred_label = 1 if prob > 0.5 else 0

    # Compute cams for available layers and fuse them by averaging
    cams = []
    if vit_last_conv is not None:
        try:
            cam_v = compute_gradcam_single_layer(model, input_t, vit_last_conv, target_label=pred_label)
            cams.append(upsample_cam(cam_v, size=(IMG_SIZE, IMG_SIZE)))
        except Exception as e:
            print("WARNING: vit cam failed for", path, "->", e)

    if regnet_last_conv is not None:
        try:
            cam_r = compute_gradcam_single_layer(model, input_t, regnet_last_conv, target_label=pred_label)
            cams.append(upsample_cam(cam_r, size=(IMG_SIZE, IMG_SIZE)))
        except Exception as e:
            print("WARNING: regnet cam failed for", path, "->", e)

    if len(cams) == 0:
        raise RuntimeError("No CAMs computed for image; both layer CAMs failed.")

    fused_cam = np.mean(np.stack(cams, axis=0), axis=0)  # HxW float 0..1

    # Create overlay
    heatmap_color, overlay = apply_colormap_on_image(rgb_img, fused_cam, alpha=0.5)

    # Plot original and overlay
    ax1 = fig.add_subplot(N, 2, 2*i+1)
    ax1.imshow(rgb_img); ax1.axis('off')
    ax1.set_title(f"Original (true={true_label})", fontsize=10)

    ax2 = fig.add_subplot(N, 2, 2*i+2)
    ax2.imshow(overlay); ax2.axis('off')
    ax2.set_title(f"Pred={pred_label} prob={prob:.3f}", fontsize=10)

plt.tight_layout()
out_path = os.path.join(OUT_DIR, "gradcam_fused_examples.png")
fig.savefig(out_path, dpi=300, bbox_inches='tight')
plt.close(fig)
print("Saved Grad-CAM fused examples to:", out_path)


NameError: name 'vit_last_conv' is not defined

In [None]:
# Cell 7: Dataset class and transforms (torchvision)
class MammogramDataset(Dataset):
    def __init__(self, records, transform=None):
        self.records = records
        self.transform = transform

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        path, label = self.records[idx]
        try:
            img = Image.open(path).convert('RGB')
        except Exception as e:
            # in case of read error, return a black image (should rarely happen)
            print("Warning: couldn't open", path, "->", e)
            img = Image.new('RGB', (224,224), (0,0,0))
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(label, dtype=torch.float32)

# transforms
IMG_SIZE = 224
train_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(0.5),
    T.RandomRotation(10),
    T.ColorJitter(brightness=0.06, contrast=0.06),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

eval_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

# create datasets
train_ds = MammogramDataset(train_records, transform=train_transform)
val_ds = MammogramDataset(val_records, transform=eval_transform)
test_ds = MammogramDataset(test_records, transform=eval_transform)

BATCH_SIZE = 16   # adjust depending on GPU memory
num_workers = 2
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=True)


In [None]:
# Cell 8: FusionModel - ViT + RegNet features concatenation
import torch.nn.functional as F
from timm import create_model

class FusionModel(nn.Module):
    def __init__(self, vit_name='vit_base_patch16_224', regnet_name='regnety_032', pretrained=True, dropout=0.3):
        super().__init__()
        # create backbones with no classification head (num_classes=0)
        self.vit = create_model(vit_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        self.regnet = create_model(regnet_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        self.vit_feat = getattr(self.vit, 'num_features', None)
        self.regnet_feat = getattr(self.regnet, 'num_features', None)
        if self.vit_feat is None or self.regnet_feat is None:
            # fallback to attribute names sometimes used
            self.vit_feat = getattr(self.vit, 'head', None) and getattr(self.vit.head, 'in_features', None) or 768
            self.regnet_feat = getattr(self.regnet, 'head', None) and getattr(self.regnet.head, 'in_features', None) or 1024

        fused_dim = int(self.vit_feat + self.regnet_feat)
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, 1)   # binary logit
        )

    def forward(self, x):
        # timm models generally have forward_features
        v = self.vit.forward_features(x)
        r = self.regnet.forward_features(x)

        # ensure vectors (global pooling might still be present)
        if v.dim() > 2:
            # e.g. [B, N, C] or [B, C, H, W]
            v = v.view(v.size(0), v.size(1), -1) if v.dim()==3 else v
            # final reduce
            v = v.mean(dim=1) if v.dim()==3 else F.adaptive_avg_pool2d(v, 1).reshape(v.size(0), -1)

        if r.dim() > 2:
            r = r.view(r.size(0), r.size(1), -1) if r.dim()==3 else r
            r = r.mean(dim=1) if r.dim()==3 else F.adaptive_avg_pool2d(r, 1).reshape(r.size(0), -1)

        x = torch.cat([v, r], dim=1)
        logits = self.classifier(x).squeeze(1)
        return logits

# build model and send to device
model = FusionModel(vit_name='vit_base_patch16_224', regnet_name='regnety_032', pretrained=True, dropout=0.3)
model = model.to(device)
print("Model built. vit_feat, regnet_feat:", model.vit_feat, model.regnet_feat)


In [None]:
# Cell 9: Loss, optimizer, scheduler
from torch.optim import AdamW

# compute pos_weight for BCEWithLogitsLoss to address imbalance
train_pos = sum([l for _,l in train_records])
train_neg = len(train_records) - train_pos
if train_pos == 0:
    raise RuntimeError("No positive samples in training set!")
pos_weight = torch.tensor([(train_neg / train_pos) if train_pos > 0 else 1.0], dtype=torch.float32).to(device)
print("Train pos/neg:", train_pos, train_neg, "pos_weight:", pos_weight.item())

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)

# mixed precision scaler
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# helper train/val loops
def train_one_epoch(model, loader, optimizer, criterion, device, scaler):
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    for imgs, labels in tqdm(loader, desc="Train batches", leave=False):
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            logits = model(imgs)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * imgs.size(0)
        probs = torch.sigmoid(logits).detach().cpu().numpy()
        all_preds.extend(probs.tolist())
        all_labels.extend(labels.detach().cpu().numpy().tolist())
    epoch_loss = running_loss / len(loader.dataset)
    try:
        epoch_auc = roc_auc_score(all_labels, all_preds)
    except Exception:
        epoch_auc = float('nan')
    return epoch_loss, epoch_auc

def validate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="Val batches", leave=False):
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            logits = model(imgs)
            loss = criterion(logits, labels)
            running_loss += loss.item() * imgs.size(0)
            probs = torch.sigmoid(logits).detach().cpu().numpy()
            all_preds.extend(probs.tolist())
            all_labels.extend(labels.detach().cpu().numpy().tolist())
    val_loss = running_loss / len(loader.dataset)
    try:
        val_auc = roc_auc_score(all_labels, all_preds)
    except Exception:
        val_auc = float('nan')
    # binary accuracy at 0.5
    preds_bin = [1 if p>0.5 else 0 for p in all_preds]
    try:
        val_acc = accuracy_score(all_labels, preds_bin)
    except Exception:
        val_acc = float('nan')
    return val_loss, val_auc, val_acc, all_labels, all_preds


In [None]:
# Cell 10: Train the model (adjust EPOCHS)
EPOCHS = 10
best_val_auc = -1.0
save_path = '/content/drive/MyDrive/cbis_ddsm_fusion_best.pth'

for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    train_loss, train_auc = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
    val_loss, val_auc, val_acc, _, _ = validate(model, val_loader, criterion, device)
    scheduler.step(val_auc if not math.isnan(val_auc) else val_loss)
    t1 = time.time()
    print(f"Epoch {epoch}/{EPOCHS}  time={t1-t0:.1f}s  train_loss={train_loss:.4f} train_auc={train_auc:.4f}  val_loss={val_loss:.4f} val_auc={val_auc:.4f} val_acc={val_acc:.4f}")
    # Save best
    if not math.isnan(val_auc) and val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'val_auc': val_auc},
                   save_path)
        print("Saved best model to:", save_path)


In [None]:
# === Cell 3: Load best model checkpoint and evaluate ===

import torch

# Ensure device is set (match your training setup)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path where you saved your best checkpoint during training
save_path = "/content/drive/MyDrive/cbis_ddsm_fusion_best.pth"  # <-- replace with your actual path if different

# Load checkpoint (fix for PyTorch 2.6+ UnpicklingError)
ckpt = torch.load(save_path, map_location=device, weights_only=False)

# Restore model weights
model.load_state_dict(ckpt['model_state_dict'])

# (Optional) Restore optimizer if needed
if 'optimizer_state_dict' in ckpt:
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])

# Print checkpoint info
print("Loaded checkpoint:")
print(f"  Epoch: {ckpt.get('epoch', 'N/A')}")
print(f"  Val AUC: {ckpt.get('val_auc', 'N/A')}")

# Evaluate on test set
model.eval()
test_loss = 0
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:  # assumes test_loader is already defined
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        test_loss += loss.item() * images.size(0)

        probs = torch.sigmoid(outputs)
        all_preds.extend(probs.detach().cpu().numpy().tolist())
        all_labels.extend(labels.detach().cpu().numpy().tolist())


avg_loss = test_loss / len(test_loader.dataset)

# Calculate metrics
test_auc = roc_auc_score(all_labels, all_preds)
preds_bin = [1 if p > 0.5 else 0 for p in all_preds]
test_acc = accuracy_score(all_labels, preds_bin)


print(f"Test Loss: {avg_loss:.4f}")
print(f"Test AUC: {test_auc:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")

In [None]:
# Cell 12: Helper to find last Conv2d in a given submodule (robust for timm models)
import torch.nn as nn

def find_last_conv(module):
    last_conv = None
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            last_conv = m
    return last_conv

# find convs in both backbones inside the fusion model
vit_last_conv = find_last_conv(model.vit)
regnet_last_conv = find_last_conv(model.regnet)
print("vit last conv:", vit_last_conv)
print("regnet last conv:", regnet_last_conv)
if vit_last_conv is None or regnet_last_conv is None:
    print("Warning: could not find last conv in one of the backbones. Grad-CAM may fail.")


In [16]:
# Cell 13: Run Grad-CAM (fused from both backbones) on example images from test set
# We'll compute a cam from the ViT target conv and the RegNet conv and average them.

from matplotlib import pyplot as plt
# Import the new target for binary classification
from pytorch_grad_cam.utils.model_targets import BinaryClassifierOutputTarget


def tensor_to_numpy_img(tensor):
    # expects tensor normalized with ImageNet mean/std
    t = tensor.detach().cpu().numpy()
    t = np.transpose(t, (1,2,0))  # HWC
    mean = np.array([0.485,0.456,0.406])
    std = np.array([0.229,0.224,0.225])
    img = (t * std + mean)
    img = np.clip(img, 0, 1)
    return img

# pick a few examples
n_examples = 4
sample_records = test_records[:n_examples]

# use_cuda = torch.cuda.is_available() # Removed: no longer needed

# instantiate CAM objects separately for the two target layers
cam_vit = None
cam_reg = None
if vit_last_conv is not None:
    cam_vit = GradCAM(model=model, target_layers=[vit_last_conv])
if regnet_last_conv is not None:
    cam_reg = GradCAM(model=model, target_layers=[regnet_last_conv])

fig = plt.figure(figsize=(12, 3*n_examples))
for i,(path,label) in enumerate(sample_records):
    img_pil = Image.open(path).convert('RGB').resize((IMG_SIZE, IMG_SIZE))
    img_t = eval_transform(img_pil).unsqueeze(0).to(device) # Ensure input is on device
    # forward to get predicted class
    model.eval()
    with torch.no_grad():
        logits = model(img_t)
        prob = torch.sigmoid(logits).cpu().item()
        pred_label = 1 if prob>0.5 else 0
    # compute cams
    cams = []
    # Use BinaryClassifierOutputTarget for binary output
    targets = [BinaryClassifierOutputTarget(pred_label)]
    if cam_vit is not None:
        grayscale_cam_v = cam_vit(input_tensor=img_t, targets=targets)[0]
        cams.append(grayscale_cam_v)
    if cam_reg is not None:
        grayscale_cam_r = cam_reg(input_tensor=img_t, targets=targets)[0]
        cams.append(grayscale_cam_r)
    if len(cams) == 0:
        raise RuntimeError("No CAM layers available.")
    fused_cam = np.mean(cams, axis=0)  # average cams
    # overlay
    rgb_img = tensor_to_numpy_img(eval_transform(img_pil))  # this gives normalized->0..1 image
    visualization = show_cam_on_image(rgb_img, fused_cam, use_rgb=True)
    ax = fig.add_subplot(n_examples, 2, 2*i+1)
    ax.imshow(rgb_img)
    ax.axis('off')
    ax.set_title(f"Original: label={label} ")
    ax2 = fig.add_subplot(n_examples, 2, 2*i+2)
    ax2.imshow(visualization)
    ax2.axis('off')
    ax2.set_title(f"Pred={pred_label} prob={prob:.3f}")
plt.tight_layout()
plt.show()

# optionally save figure to Drive
out_fig = '/content/drive/MyDrive/cbis_ddsm_gradcam_examples.png'
fig.savefig(out_fig)
print("Saved Grad-CAM image to:", out_fig)

NameError: name 'test_records' is not defined

In [15]:
# Helper to find last Conv2d in a given submodule (robust for timm models)
import torch.nn as nn

def find_last_conv(module):
    last_conv = None
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            last_conv = m
    return last_conv

# find convs in both backbones inside the fusion model
vit_last_conv = find_last_conv(model.vit)
regnet_last_conv = find_last_conv(model.regnet)
print("vit last conv:", vit_last_conv)
print("regnet last conv:", regnet_last_conv)
if vit_last_conv is None or regnet_last_conv is None:
    print("Warning: could not find last conv in one of the backbones. Grad-CAM may fail.")

vit last conv: Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
regnet last conv: Conv2d(576, 1512, kernel_size=(1, 1), stride=(2, 2), bias=False)


In [14]:
# Cell 8: FusionModel - ViT + RegNet features concatenation
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model

class FusionModel(nn.Module):
    def __init__(self, vit_name='vit_base_patch16_224', regnet_name='regnety_032', pretrained=True, dropout=0.3):
        super().__init__()
        # create backbones with no classification head (num_classes=0)
        self.vit = create_model(vit_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        self.regnet = create_model(regnet_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        self.vit_feat = getattr(self.vit, 'num_features', None)
        self.regnet_feat = getattr(self.regnet, 'num_features', None)
        if self.vit_feat is None or self.regnet_feat is None:
            # fallback to attribute names sometimes used
            self.vit_feat = getattr(self.vit, 'head', None) and getattr(self.vit.head, 'in_features', None) or 768
            self.regnet_feat = getattr(self.regnet, 'head', None) and getattr(self.regnet.head, 'in_features', None) or 1024

        fused_dim = int(self.vit_feat + self.regnet_feat)
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, 1)   # binary logit
        )

    def forward(self, x):
        # timm models generally have forward_features
        v = self.vit.forward_features(x)
        r = self.regnet.forward_features(x)

        # ensure vectors (global pooling might still be present)
        if v.dim() > 2:
            # e.g. [B, N, C] or [B, C, H, W]
            v = v.view(v.size(0), v.size(1), -1) if v.dim()==3 else v
            # final reduce
            v = v.mean(dim=1) if v.dim()==3 else F.adaptive_avg_pool2d(v, 1).reshape(v.size(0), -1)

        if r.dim() > 2:
            r = r.view(r.size(0), r.size(1), -1) if r.dim()==3 else r
            r = r.mean(dim=1) if r.dim()==3 else F.adaptive_avg_pool2d(r, 1).reshape(r.size(0), -1)

        x = torch.cat([v, r], dim=1)
        logits = self.classifier(x).squeeze(1)
        return logits

# build model and send to device
model = FusionModel(vit_name='vit_base_patch16_224', regnet_name='regnety_032', pretrained=True, dropout=0.3)
model = model.to(device)
print("Model built. vit_feat, regnet_feat:", model.vit_feat, model.regnet_feat)



Model built. vit_feat, regnet_feat: 768 1512


In [13]:
# Cell 2: imports and seeds
import os, random, time, math
from glob import glob
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score

import timm
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cpu


In [11]:
!pip install -q pytorch-grad-cam

[31mERROR: Could not find a version that satisfies the requirement pytorch-grad-cam (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for pytorch-grad-cam[0m[31m
[0m

In [12]:
!pip install -q grad-cam

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m38.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m100.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m95.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m49.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m845.8 kB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m7.2 MB/s[0m eta 