In [None]:
import os
import random

import numpy as np
import pandas as pd
from torch.utils.data import Subset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

import torch
import torch.nn as nn
from torchvision import models

import matplotlib.pyplot as plt
from sklearn.metrics import brier_score_loss, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, roc_curve, auc
from sklearn.calibration import calibration_curve
import seaborn as sns

from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV

import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

In [None]:
# Params
SEED = 42

# Dataset root (read with torchvision ImageFolder: one sub-folder per class).
# Set the BUSI_DATA_DIR environment variable to point elsewhere instead of
# editing this absolute local path.
DATA_DIR = os.environ.get(
    'BUSI_DATA_DIR',
    r'D:\datasets_in_D\Data_Science_for_Digital_Health\BUSI_denoised\processed',
)

# below are the paths for cross-dataset training and testing
# DATA_DIR_TRAIN_A = r'D:\datasets_in_D\Data_Science_for_Digital_Health\BUS-NoCLAHE\processed'
# DATA_DIR_TRAIN_B = r'D:\datasets_in_D\Data_Science_for_Digital_Health\QAMEBI_NoCLAHE\processed'
# DATA_DIR_TEST = r'D:\datasets_in_D\Data_Science_for_Digital_Health\BUSI_pred_classify_ours\processed'

TRAIN_RATIO = 0.7   # fraction of images used for training
VAL_RATIO = 0.15    # fraction used for validation; the remainder is the test split
BATCH_SIZE = 32

# Fail fast on ratios that would leave an empty (or negative-size) test split.
assert TRAIN_RATIO > 0 and VAL_RATIO > 0 and TRAIN_RATIO + VAL_RATIO < 1, \
    "TRAIN_RATIO + VAL_RATIO must leave a non-empty test split"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

LR = 1e-4     # Adam learning rate for the classifier head
EPOCH = 10    # number of fine-tuning epochs

### Fix random seeds for reproducibility

In [None]:
def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


_seed_everything(SEED)

### Load and transform data

In [None]:
# Define transformations.
# Training adds light augmentation (flip + small rotation); evaluation is deterministic.
# Grayscale(3) replicates the single channel so the 3-channel VGG16 stem can consume it.
train_tfm = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert grayscale to 3-channel
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1]
])
test_tfm = transforms.Compose([
    transforms.Grayscale(3),
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5],[0.5]),
])

# Load dataset (single dataset). `full_ds` has no transform: it is only used to
# draw the reproducible train/val/test index split below.
full_ds = datasets.ImageFolder(DATA_DIR, transform=None)

# === Uncomment below lines if you want to use cross-dataset training and testing ===
# ds_a_raw = datasets.ImageFolder(DATA_DIR_TRAIN_A, transform=None)
# ds_b_raw = datasets.ImageFolder(DATA_DIR_TRAIN_B, transform=None)
# ds_test_raw = datasets.ImageFolder(DATA_DIR_TEST, transform=test_tfm)

# NOTE: ConcatDataset assumes both folders map class names to the same indices — confirm.
# full_train_val = torch.utils.data.ConcatDataset([ds_a_raw, ds_b_raw])
# ===================================================================================

# Single dataset
# Split dataset
g = torch.Generator().manual_seed(SEED)         # reproducible

n = len(full_ds)
n_train = int(TRAIN_RATIO * n)
n_val   = int(VAL_RATIO * n)
n_test  = n - n_train - n_val
train_split, val_split, test_split = random_split(full_ds, [n_train, n_val, n_test], generator=g)

# The three Subsets returned by random_split share ONE underlying dataset object,
# so setting `split.dataset.transform` per split just overwrites the same attribute
# (the last assignment, test_tfm, would win and training would silently run with no
# augmentation). Instead, reuse the split indices over separate ImageFolder
# instances that each own their transform.
train_base = datasets.ImageFolder(DATA_DIR, transform=train_tfm)
eval_base  = datasets.ImageFolder(DATA_DIR, transform=test_tfm)
# ImageFolder sorts classes and file names, so the same index refers to the same
# image in every instance; verify rather than assume.
assert train_base.samples == full_ds.samples == eval_base.samples, \
    "ImageFolder instances disagree on sample order"

train_ds = Subset(train_base, train_split.indices)
val_ds   = Subset(eval_base,  val_split.indices)
test_ds  = Subset(eval_base,  test_split.indices)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE)

# === Uncomment below lines if you want to use cross-dataset training and testing ===
# n = len(full_train_val)
# n_train = int(TRAIN_RATIO * n)
# n_val   = n - n_train
# train_ds, val_ds = random_split(full_train_val, [n_train, n_val], generator=g)

# class SubsetWithTransform(torch.utils.data.Dataset):
#     def __init__(self, subset, transform):
#         self.subset = subset
#         self.transform = transform
#     def __len__(self):
#         return len(self.subset)
#     def __getitem__(self, idx):
#         img, target = self.subset[idx]   # img is still PIL
#         if self.transform is not None:
#             img = self.transform(img)
#         return img, target

# train_set = SubsetWithTransform(train_ds, train_tfm)
# val_set   = SubsetWithTransform(val_ds,   test_tfm)   # deterministic
# (ds_test_raw already carries test_tfm from when it was loaded)

# Create data loaders
# train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# val_loader   = DataLoader(val_set,   batch_size=BATCH_SIZE)
# test_loader  = DataLoader(ds_test_raw,  batch_size=BATCH_SIZE)

### Load pre-trained model

In [None]:
# Backbone: ImageNet-pretrained VGG16 from torchvision.
model = models.vgg16(pretrained=True)

# Only the classifier head is trained: switch off gradients for every
# parameter of the convolutional feature extractor.
model.features.requires_grad_(False)

# Replace the last classifier layer with a single sigmoid unit whose output
# is the probability of class index 1 (labels are fed to BCELoss as floats).
model.classifier[6] = nn.Sequential(nn.Linear(4096, 1), nn.Sigmoid())

model = model.to(DEVICE)

# The head already applies a sigmoid, so use BCELoss on probabilities;
# Adam updates only the classifier parameters.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(params=model.classifier.parameters(), lr=LR)

### Fine-tune and validate the model

In [None]:
train_losses, val_losses = [], []   # per-sample mean loss for each epoch
# Validation probabilities/labels from the FINAL epoch only. They are used later
# to fit the post-hoc calibrator, so they must come from the same model that is
# evaluated on the test set.
val_preds_all, val_labels_all = [], []

for epoch in range(EPOCH):
    # ---- training pass ----
    model.train()
    running_train_loss = 0.0
    for images, labels in train_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE).float().unsqueeze(1)   # (B,) -> (B, 1) to match the head

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # BCELoss returns the batch mean; weight by batch size for a per-sample epoch mean.
        running_train_loss += loss.item() * images.size(0)

    epoch_train_loss = running_train_loss / len(train_loader.dataset)
    train_losses.append(epoch_train_loss)

    # ---- validation pass ----
    model.eval()
    running_val_loss = 0.0
    # Reset every epoch: appending across epochs would mix predictions from
    # earlier (stale) models into the calibration data.
    val_preds_all, val_labels_all = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE).float().unsqueeze(1)
            outputs = model(images)
            running_val_loss += criterion(outputs, labels).item() * images.size(0)

            # Collect for calibration
            val_preds_all.extend(outputs.cpu().numpy())
            val_labels_all.extend(labels.cpu().numpy())

    epoch_val_loss = running_val_loss / len(val_loader.dataset)
    val_losses.append(epoch_val_loss)
    # Report the same per-sample means that are recorded and plotted.
    print(f'Epoch [{epoch+1}/{EPOCH}], Train Loss: {epoch_train_loss:.4f}, Validation Loss: {epoch_val_loss:.4f}')

### Loss curves for train/val

In [None]:
# Plot per-epoch mean losses. The x-axis is 1-based to match the "Epoch [k/EPOCH]"
# log, and is derived from the recorded history so it cannot disagree with it.
epochs = range(1, len(train_losses) + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs, train_losses, label='Training Loss')
ax.plot(epochs, val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss Curves')
ax.legend()
plt.show()

### Evaluate model on test set

In [None]:
model.eval()
all_preds = []    # predicted probability of class 1 for each test image (floats)
all_labels = []   # ground-truth class indices (ints)

with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(DEVICE))
        # (B, 1) -> B floats so every sklearn call below gets a flat 1-D score vector.
        all_preds.extend(outputs.reshape(-1).cpu().tolist())
        all_labels.extend(labels.tolist())

# Convert predictions to binary labels
threshold = 0.5
binary_preds = [int(pred >= threshold) for pred in all_preds]

# Calculate metrics (zero_division=0 avoids warnings/NaN if a class is never predicted)
accuracy = accuracy_score(all_labels, binary_preds)
precision = precision_score(all_labels, binary_preds, zero_division=0)
recall = recall_score(all_labels, binary_preds, zero_division=0)
f1 = f1_score(all_labels, binary_preds, zero_division=0)
roc_auc = roc_auc_score(all_labels, all_preds)

# labels=[0, 1] keeps the matrix 2x2 even when one class is absent from the
# labels or predictions, so the 4-way unpacking cannot fail.
cm = confusion_matrix(all_labels, binary_preds, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall (Sensitivity): {recall:.4f}')
print(f'Specificity: {specificity:.4f}')
print(f'F1 Score: {f1:.4f}')
print(f'AUROC: {roc_auc:.4f}')

# Plot confusion matrix
# NOTE(review): tick labels assume class index 0 = benign, 1 = malignant
# (ImageFolder's alphabetical order) — confirm against the dataset's `.classes`.
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Benign', 'Malignant'], yticklabels=['Benign', 'Malignant'])
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.show()

### Bootstrapping

In [None]:
def bootstrap_metrics(y_true, y_probs, n_bootstraps=1000, seed=42, threshold=0.5):
    """Bootstrap the sampling distribution of binary-classification metrics.

    Draws ``n_bootstraps`` resamples (with replacement, same size as the input)
    of the paired labels/scores and recomputes every metric on each resample.

    Args:
        y_true: ground-truth labels in {0, 1}; any array-like, flattened.
        y_probs: predicted probability of class 1, same length as ``y_true``;
            column-shaped input such as (n, 1) is flattened.
        n_bootstraps: number of resamples to draw.
        seed: seed for the ``np.random.RandomState`` that draws the resamples.
        threshold: probability at or above which a sample is predicted as 1.

    Returns:
        dict mapping metric name -> list of ``n_bootstraps`` floats. 'AUC' is
        NaN for resamples that contain a single class (AUC is undefined there),
        so every list has the same length and the dict can become a DataFrame.

    Raises:
        ValueError: if the inputs are empty or have different lengths.
    """
    y_true = np.asarray(y_true).ravel()
    y_probs = np.asarray(y_probs, dtype=float).ravel()
    if y_true.size == 0:
        raise ValueError("bootstrap_metrics needs at least one sample")
    if y_true.shape != y_probs.shape:
        raise ValueError(
            f"y_true has {y_true.size} entries but y_probs has {y_probs.size}")

    rng = np.random.RandomState(seed)
    metrics = {
        'Accuracy': [],
        'AUC': [],
        'Sensitivity': [],
        'Specificity': [],
        'F1 score': [],
        'Precision': []
    }

    n = len(y_true)
    for _ in range(n_bootstraps):
        indices = rng.choice(n, size=n, replace=True)
        y_true_sample = y_true[indices]
        y_probs_sample = y_probs[indices]
        y_pred_sample = (y_probs_sample >= threshold).astype(int)

        # labels=[0, 1] forces a 2x2 matrix even when a resample lacks a class.
        tn, fp, fn, tp = confusion_matrix(y_true_sample, y_pred_sample, labels=[0, 1]).ravel()
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0

        # roc_auc_score raises on single-class input; record NaN instead.
        if len(np.unique(y_true_sample)) == 2:
            auc_value = roc_auc_score(y_true_sample, y_probs_sample)
        else:
            auc_value = np.nan

        metrics['Accuracy'].append(accuracy_score(y_true_sample, y_pred_sample))
        metrics['AUC'].append(auc_value)
        metrics['Sensitivity'].append(sensitivity)
        metrics['Specificity'].append(specificity)
        metrics['F1 score'].append(f1_score(y_true_sample, y_pred_sample, zero_division=0))
        metrics['Precision'].append(precision_score(y_true_sample, y_pred_sample, zero_division=0))

    return metrics

In [None]:
# Run bootstrap
boot_results = bootstrap_metrics(all_labels, all_preds)

# One column per metric, one row per bootstrap resample.
df_bootstrap = pd.DataFrame(boot_results)

# Percentile 95% confidence intervals (NaN entries, e.g. AUC on single-class
# resamples, are skipped by mean/quantile).
ci_summary = pd.DataFrame({
    'mean': df_bootstrap.mean(),
    'ci_low (2.5%)': df_bootstrap.quantile(0.025),
    'ci_high (97.5%)': df_bootstrap.quantile(0.975),
})
print(ci_summary.round(4))

# Convert to long-form DataFrame for seaborn boxplot
df_melted = df_bootstrap.melt(var_name="Metric", value_name="Values")

# Plot
plt.figure(figsize=(10, 6))
sns.boxplot(data=df_melted, x="Metric", y="Values", palette="colorblind")
plt.title("Bootstrap Performance Metrics")
plt.grid(True, axis='y')
plt.tight_layout()
plt.show()


### ROC and calibration curves

In [None]:
# ROC curve on the test-set probabilities; AUC is the trapezoidal area under it.
fpr, tpr, thresholds = roc_curve(all_labels, all_preds)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
ax.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')   # chance diagonal
ax.set(
    xlabel='False Positive Rate',
    ylabel='True Positive Rate (Sensitivity)',
    title='Receiver Operating Characteristic (ROC) Curve',
)
ax.legend(loc='lower right')
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# Reliability diagram of the raw (uncalibrated) sigmoid outputs: per-bin mean
# predicted probability vs observed fraction of positives. calibration_curve
# drops empty bins; with a small test set many of the 30 bins hold few images.
raw_probs = np.ravel(all_preds)
prob_true, prob_pred = calibration_curve(all_labels, raw_probs, n_bins=30)

# Baseline Brier score, so the post-calibration score has a reference point.
print(f"Brier Score (uncalibrated): {brier_score_loss(all_labels, raw_probs):.4f}")

# Plot
plt.figure(figsize=(8, 6))
plt.plot(prob_pred, prob_true, marker='o', linewidth=2, label='Model')
plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfectly Calibrated')
plt.xlabel('Mean Predicted Probability')
plt.ylabel('Fraction of Positives')
plt.title('Calibration Curve (Reliability Diagram)')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

### Improve calibration (Platt scaling fit on validation predictions)

In [None]:
# Fit the calibrator only on the FINAL model's validation predictions: test
# probabilities come from the final model, so calibration data must too.
# val_loader is not shuffled and predictions are appended in loader order, so
# the last len(val set) entries are exactly the final epoch's pass. If the
# training loop kept only one epoch, this slice is a no-op.
n_val_samples = len(val_loader.dataset)
val_preds = np.asarray(val_preds_all, dtype=float).reshape(-1, 1)[-n_val_samples:]
val_labels = np.asarray(val_labels_all).reshape(-1)[-n_val_samples:]
assert len(val_preds) == len(val_labels) == n_val_samples, \
    "expected one validation prediction per validation image"

# Logistic Regression (Platt scaling): sigmoid calibration with 5-fold CV over
# the single raw-probability feature.
lr = LogisticRegression()
calibrator = CalibratedClassifierCV(estimator=lr, method='sigmoid', cv=5)
calibrator.fit(val_preds, val_labels)

In [None]:
# The calibrator was fit on a single raw-probability feature, so feed it the
# test probabilities as an (n, 1) column.
test_preds_np = np.array(all_preds, dtype=float).reshape(-1, 1)
# Probability of class 1 (malignant, assuming ImageFolder's alphabetical class order).
calibrated_probs = calibrator.predict_proba(test_preds_np)[:, 1]

# Threshold at 0.5
calibrated_binary = (calibrated_probs >= 0.5).astype(int)

# Print metrics (zero_division=0 guards against a class never being predicted)
print("AUROC (calibrated):", roc_auc_score(all_labels, calibrated_probs))
print("Brier Score:", brier_score_loss(all_labels, calibrated_probs))  # Lower is better
print("Accuracy:", accuracy_score(all_labels, calibrated_binary))
print("Precision:", precision_score(all_labels, calibrated_binary, zero_division=0))
print("Recall:", recall_score(all_labels, calibrated_binary, zero_division=0))
print("F1 Score:", f1_score(all_labels, calibrated_binary, zero_division=0))

# Confusion Matrix (labels=[0, 1] keeps it 2x2 even if a class is never predicted)
cm = confusion_matrix(all_labels, calibrated_binary, labels=[0, 1])
cal_tn, cal_fp = cm[0, 0], cm[0, 1]
print("Specificity:", cal_tn / (cal_tn + cal_fp) if (cal_tn + cal_fp) > 0 else 0.0)

# Same tick labels as the uncalibrated confusion matrix, for side-by-side reading.
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=['Benign', 'Malignant'], yticklabels=['Benign', 'Malignant'])
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix (After Calibration)")
plt.show()

In [None]:
# Reliability diagram after calibration, overlaid on the raw-model curve so the
# effect of Platt scaling is visible in one figure.
prob_true, prob_pred = calibration_curve(all_labels, calibrated_probs, n_bins=30)
raw_true, raw_pred = calibration_curve(all_labels, np.ravel(all_preds), n_bins=30)

plt.figure(figsize=(8, 6))
plt.plot(raw_pred, raw_true, marker='x', alpha=0.5, label='Uncalibrated')
plt.plot(prob_pred, prob_true, marker='o', label='Calibrated')
plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfect')
plt.xlabel('Mean Predicted Probability')
plt.ylabel('Fraction of Positives')
# Only sigmoid (Platt) calibration is fitted above, so the title says exactly that.
plt.title('Calibration Curve (After Platt Scaling)')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

### Explainability

In [None]:
# ------------------------------------------------------------------
# 7.  Grad-CAM
#     * Re-enable grads on the backbone so gradients reach the conv activations
# ------------------------------------------------------------------
for p in model.features.parameters():
    p.requires_grad_(True)

# Single-layer alternatives (coarser, less noisy maps):
# target_layers = [
#     # model.features[16],     # conv3_3  (56×56)
#     # model.features[23],     # conv4_3  (28×28)
#     model.features[28]      # conv5_3  (14×14)
# ]
# Here every Conv2d in the backbone is passed; pytorch_grad_cam aggregates the
# per-layer CAMs into one map.
conv_layers = [m for m in model.features if isinstance(m, nn.Conv2d)]
cam = GradCAM(model=model, target_layers=conv_layers)

# take one test mini-batch (or loop, as you prefer)
model.eval()
imgs, labs = next(iter(test_loader))
imgs = imgs.to(DEVICE)

# Forward once to get P(class 1) per image. No graph is needed here, and
# reshape(-1) keeps a 1-D array even when the batch holds a single image
# (squeeze() would yield a 0-d array and break len()).
with torch.no_grad():
    probs = model(imgs).cpu().numpy().reshape(-1)

# number of examples you want to display
N_SHOW = 4
THRESH = 0.5    # only visualize images scored at or above this probability
showed = 0

for i in range(len(probs)):
    if probs[i] < THRESH:
        continue
    # a single image tensor
    inp = imgs[i:i+1].requires_grad_(True)

    # The head has ONE sigmoid unit, so output index 0 is the predicted probability
    # of class 1 (malignant, assuming ImageFolder's alphabetical class order); it is
    # the only valid target for this model.
    grayscale_cam = cam(
        input_tensor=inp,
        targets=[ClassifierOutputTarget(0)])
    grayscale_cam = grayscale_cam[0]         # (H,W), values in [0,1]

    # --- build a 'jet' heat-map -------------------------------------------
    heatmap = cv2.applyColorMap(
        np.uint8(255 * grayscale_cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0

    # original image back to [0,1] RGB for display
    rgb = inp[0].detach().cpu().numpy().transpose(1, 2, 0)       # C,H,W → H,W,C
    # min-max rescale; the epsilon avoids 0/0 on a constant (blank) image
    rgb = (rgb - rgb.min()) / max(rgb.max() - rgb.min(), 1e-8)

    overlay = 0.4 * heatmap + 0.6 * rgb                  # blend
    overlay = np.clip(overlay, 0., 1.)

    # --- side-by-side plot -------------------------------------------------
    fig, ax = plt.subplots(1, 2, figsize=(6, 3))
    ax[0].imshow(rgb)
    ax[0].set_title("Original")
    ax[0].axis("off")

    ax[1].imshow(overlay)
    ax[1].set_title(f"CAM - jet  |  p={probs[i]:.2f}")
    ax[1].axis("off")

    plt.tight_layout()
    plt.show()
    showed += 1
    if showed >= N_SHOW:
        break

if showed == 0:
    print(f"No image in this test batch scored >= {THRESH}; nothing to display.")