# 🔬 Breast Cancer Histopathology Classification
## Complete Training Pipeline with ResNet50 + GradCAM

This notebook provides:
1. **Data Loading** - BreakHis dataset (benign vs malignant)
2. **Proper Data Split** - Train/Val/Test with stratification
3. **Data Augmentation** - Prevent overfitting
4. **Transfer Learning** - Fine-tune ResNet50
5. **Training Loop** - With early stopping
6. **Evaluation** - Accuracy, AUC, Confusion Matrix, Classification Report
7. **GradCAM Visualization** - Explainability
8. **Model Export** - Save for production use

## 1. Setup & Imports

In [None]:
# Install requirements (run once)
!pip install torch torchvision scikit-learn matplotlib seaborn pillow tqdm --quiet

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
import torch.nn.functional as F

# Sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, roc_auc_score, confusion_matrix, 
    classification_report, roc_curve, precision_recall_curve
)

# Pick the compute device once; later cells move the model and batches onto it.
_cuda_ready = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ready else 'cpu')
print(f"üñ•Ô∏è Using device: {device}")
if _cuda_ready:
    # Report the first visible GPU and its total memory in gigabytes.
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 2. Configuration

In [None]:
# ============================================================
# CONFIGURATION - Modify these paths for your setup
# ============================================================

# Data paths - UPDATE THESE!
# 'data_dir' must contain the 'benign' and 'malignant' folders.
_path_settings = {
    'data_dir': './data/Breakhis-400x',
    'model_save_path': './models/imaging_model_trained.pth',
}

# Optimisation hyper-parameters.
_training_settings = {
    'batch_size': 32,
    'num_epochs': 25,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
}

# Fractions of the full dataset held out for validation and testing.
_split_settings = {
    'val_split': 0.15,
    'test_split': 0.15,
}

# Input resolution, early-stopping patience (in epochs) and the random seed.
_misc_settings = {
    'image_size': 224,
    'patience': 5,
    'seed': 42,
}

# Merged in this order so the printed listing keeps the same key order.
CONFIG = {
    **_path_settings,
    **_training_settings,
    **_split_settings,
    **_misc_settings,
}

# Set seed for reproducibility (PyTorch first, then NumPy).
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(CONFIG['seed'])

print("üìã Configuration:")
print("\n".join(f"   {k}: {v}" for k, v in CONFIG.items()))

## 3. Data Loading & Exploration

In [None]:
# ============================================================
# OPTION A: Load from local BreakHis folder
# ============================================================

def load_breakhis_data(data_dir):
    """Load image paths and labels from a BreakHis-style directory.

    The directory is expected to contain a ``benign`` and/or a ``malignant``
    sub-folder holding ``.png`` / ``.jpg`` files directly inside it (the
    search is not recursive).

    Args:
        data_dir: Path (str or Path) of the dataset root.

    Returns:
        A tuple ``(image_paths, labels)`` of two equally long lists.
        ``image_paths`` holds the files as strings, ``labels`` holds 0 for
        benign and 1 for malignant. Benign entries come first; within each
        class PNG files precede JPG files, and each group is sorted by path.

    Raises:
        FileNotFoundError: If neither class folder exists under ``data_dir``.
    """
    data_dir = Path(data_dir)
    class_folders = (('benign', 0), ('malignant', 1))
    patterns = ('*.png', '*.jpg')

    # Fail loudly when the layout is wrong instead of returning an empty
    # dataset that only breaks much later, during the train/test split.
    if not any((data_dir / folder).is_dir() for folder, _ in class_folders):
        raise FileNotFoundError(
            f"No 'benign' or 'malignant' folder found under {data_dir}"
        )

    image_paths = []
    labels = []
    for folder, label in class_folders:
        class_dir = data_dir / folder
        if not class_dir.exists():
            # A single missing class folder is tolerated, as before.
            continue
        for pattern in patterns:
            # glob() yields files in filesystem order; sorting makes the
            # listing (and therefore the seeded split) reproducible.
            found = sorted(str(path) for path in class_dir.glob(pattern))
            image_paths.extend(found)
            labels.extend([label] * len(found))

    return image_paths, labels

# Try to load data
try:
    image_paths, labels = load_breakhis_data(CONFIG['data_dir'])
except Exception as e:
    # Keep both names defined so the later cells guarded by
    # `len(image_paths) > 0` / `len(labels) > 0` skip instead of raising NameError.
    image_paths, labels = [], []
    print(f"‚ùå Error loading data: {e}")
    print("\nüìå Please ensure your data directory has this structure:")
    print("   data_dir/")
    print("   ‚îú‚îÄ‚îÄ benign/")
    print("   ‚îÇ   ‚îú‚îÄ‚îÄ image1.png")
    print("   ‚îÇ   ‚îî‚îÄ‚îÄ ...")
    print("   ‚îî‚îÄ‚îÄ malignant/")
    print("       ‚îú‚îÄ‚îÄ image1.png")
    print("       ‚îî‚îÄ‚îÄ ...")
else:
    # Report only after a successful load; the try body stays minimal.
    print(f"‚úÖ Loaded {len(image_paths)} images from {CONFIG['data_dir']}")
    print(f"   Benign: {labels.count(0)}")
    print(f"   Malignant: {labels.count(1)}")

In [None]:
# ============================================================
# OPTION B: Download BreakHis from Kaggle (if not available locally)
# ============================================================

# Uncomment and run if you need to download the dataset
# NOTE: everything between the triple quotes below is a plain string literal,
# so none of these shell commands run until the surrounding ''' markers are
# removed.
'''
# Install kaggle API
!pip install kaggle --quiet

# Upload your kaggle.json API key first, then:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download BreakHis dataset
!kaggle datasets download -d ambarish/breakhis
!unzip -q breakhis.zip -d ./data/
'''

In [None]:
# Visualize sample images
def show_sample_images(image_paths, labels, n_samples=8):
    """Display sample images: benign on the top row, malignant on the bottom.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels aligned with ``image_paths``
            (0 = benign, 1 = malignant).
        n_samples: Total number of images to show; half come from each class.
            Classes with fewer images leave the remaining grid cells blank.
    """
    per_class = max(n_samples // 2, 1)
    # squeeze=False keeps `axes` two-dimensional even with a single column,
    # so axes[row, col] indexing is always valid.
    fig, axes = plt.subplots(2, per_class, figsize=(16, 8), squeeze=False)

    # Hide every frame up front so unused cells do not show empty axes.
    for ax in axes.ravel():
        ax.axis('off')

    row_specs = (
        (0, 'Benign', 'green'),
        (1, 'Malignant', 'red'),
    )
    for class_label, title, title_color in row_specs:
        chosen = [i for i, lab in enumerate(labels) if lab == class_label][:per_class]
        for col, idx in enumerate(chosen):
            ax = axes[class_label, col]
            # The context manager closes the file as soon as it is drawn.
            with Image.open(image_paths[idx]) as img:
                ax.imshow(img)
            ax.set_title(title, color=title_color, fontsize=12)

    plt.suptitle('Sample Histopathology Images', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

if len(image_paths) > 0:
    show_sample_images(image_paths, labels)

In [None]:
# Class distribution
def plot_class_distribution(labels):
    """Draw a bar chart of the benign/malignant counts and print a summary.

    Args:
        labels: Iterable of class labels (0 = benign, 1 = malignant). Any
            iterable is accepted; it is copied into a list for counting.
    """
    labels = list(labels)
    classes = ['Benign', 'Malignant']
    counts = [labels.count(0), labels.count(1)]
    colors = ['#4CAF50', '#f44336']
    total = sum(counts)

    # Nothing to plot; this also avoids dividing by zero for the percentages.
    if total == 0:
        print("No benign or malignant labels to plot.")
        return

    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(classes, counts, color=colors, edgecolor='black', linewidth=1.5)

    # Place the count just above each bar. The gap scales with the tallest
    # bar so the label stays inside the axes for small and large datasets.
    label_gap = max(counts) * 0.02
    for bar, count in zip(bars, counts):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + label_gap,
                f'{count}', ha='center', fontsize=14, fontweight='bold')

    ax.set_ylabel('Number of Images', fontsize=12)
    ax.set_title('Class Distribution in Dataset', fontsize=14, fontweight='bold')
    ax.set_ylim(0, max(counts) * 1.15)

    # Percentage of the dataset, centred inside each bar.
    for bar, count in zip(bars, counts):
        pct = count / total * 100
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height()/2,
                f'{pct:.1f}%', ha='center', va='center', fontsize=12, color='white', fontweight='bold')

    plt.tight_layout()
    plt.show()

    print(f"\nüìä Dataset Statistics:")
    print(f"   Total: {total} images")
    print(f"   Benign: {counts[0]} ({counts[0]/total*100:.1f}%)")
    print(f"   Malignant: {counts[1]} ({counts[1]/total*100:.1f}%)")

if len(labels) > 0:
    plot_class_distribution(labels)

## 4. Data Preparation (Train/Val/Test Split)

In [None]:
# ============================================================
# CRITICAL: Proper data split to avoid data leakage
# ============================================================

# First split: separate test set
# NOTE(review): the split is made per image. If several images originate from
# the same patient/slide they can end up in different splits - confirm whether
# a group-aware split is needed for this dataset.
X_trainval, X_test, y_trainval, y_test = train_test_split(
    image_paths, labels,
    test_size=CONFIG['test_split'],
    random_state=CONFIG['seed'],
    stratify=labels  # Maintain class balance
)

# Second split: separate validation from training
# val_split is expressed as a fraction of the FULL dataset, so it is rescaled
# to a fraction of the remaining train+val pool.
val_ratio = CONFIG['val_split'] / (1 - CONFIG['test_split'])
X_train, X_val, y_train, y_val = train_test_split(
    X_trainval, y_trainval,
    test_size=val_ratio,
    random_state=CONFIG['seed'],
    stratify=y_trainval
)

# Report split sizes as absolute counts and as a share of all images.
print("üìä Data Split (Stratified):")
print(f"   Training:   {len(X_train)} images ({len(X_train)/len(image_paths)*100:.1f}%)")
print(f"   Validation: {len(X_val)} images ({len(X_val)/len(image_paths)*100:.1f}%)")
print(f"   Test:       {len(X_test)} images ({len(X_test)/len(image_paths)*100:.1f}%)")
print()
# .count() relies on the label collections being plain Python lists.
print("   Class balance in splits:")
print(f"   Train - Benign: {y_train.count(0)}, Malignant: {y_train.count(1)}")
print(f"   Val   - Benign: {y_val.count(0)}, Malignant: {y_val.count(1)}")
print(f"   Test  - Benign: {y_test.count(0)}, Malignant: {y_test.count(1)}")

## 5. Dataset & DataLoaders with Augmentation

In [None]:
# ============================================================
# Custom Dataset Class
# ============================================================

class BreastCancerDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from image files.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels aligned index-by-index with ``image_paths``.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # A length mismatch would otherwise surface as an IndexError in the
        # middle of an epoch, or silently ignore trailing labels.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is opened, converted to RGB and passed through
        ``transform`` when one was given; the label is returned as stored.
        """
        # convert() produces an independent in-memory copy, so the file can be
        # closed right away instead of waiting for garbage collection.
        with Image.open(self.image_paths[idx]) as handle:
            image = handle.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[idx]

In [None]:
# ============================================================
# Data Augmentation Transforms
# ============================================================

# Channel statistics of ImageNet; the pretrained ResNet expects inputs
# normalised with these values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _resize_step():
    """Return a resize to the square input resolution set in CONFIG."""
    side = CONFIG['image_size']
    return transforms.Resize((side, side))


def _tensor_steps():
    """Return the final PIL -> normalised tensor steps shared by all pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]


# Random augmentations applied to training images only.
_augmentation_steps = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
]

# Training pipeline: resize, augment, then convert and normalise.
train_transform = transforms.Compose(
    [_resize_step(), *_augmentation_steps, *_tensor_steps()]
)

# Validation/test pipeline: deterministic, no augmentation.
val_transform = transforms.Compose([_resize_step(), *_tensor_steps()])

print("‚úÖ Transforms defined")
print("   Training: Resize ‚Üí Flip ‚Üí Rotate ‚Üí ColorJitter ‚Üí Normalize")
print("   Val/Test: Resize ‚Üí Normalize")

In [None]:
# Create datasets (augmentation on the training split only)
train_dataset = BreastCancerDataset(X_train, y_train, transform=train_transform)
val_dataset = BreastCancerDataset(X_val, y_val, transform=val_transform)
test_dataset = BreastCancerDataset(X_test, y_test, transform=val_transform)

# Settings shared by all three loaders. Page-locked host memory is requested
# only when a CUDA device is in use, and now consistently for every split
# (previously only the training loader pinned memory).
_loader_kwargs = {
    'batch_size': CONFIG['batch_size'],
    'num_workers': 2,
    'pin_memory': device.type == 'cuda',
}

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

print(f"‚úÖ DataLoaders created")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")
print(f"   Test batches: {len(test_loader)}")

In [None]:
# Visualize augmentations
def show_augmentations(original_path):
    """Show the original image next to nine randomly augmented versions.

    The preview pipeline is defined locally (resize, flips, rotation, colour
    jitter) and works on PIL images, so the results can be plotted directly.

    Args:
        original_path: Path of the image to augment.
    """
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))

    # convert() returns an in-memory copy, so the file can be closed at once.
    with Image.open(original_path) as handle:
        original = handle.convert('RGB')

    # Built once: the pipeline is the same for every preview, and its random
    # parameters are re-drawn on each call.
    aug_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ])

    # Top-left cell shows the untouched original.
    axes[0, 0].imshow(original)
    axes[0, 0].set_title('Original', fontsize=10)
    axes[0, 0].axis('off')

    # Remaining nine cells, filled row by row.
    for i in range(1, 10):
        row, col = divmod(i, 5)
        axes[row, col].imshow(aug_transform(original))
        axes[row, col].set_title(f'Augmented {i}', fontsize=10)
        axes[row, col].axis('off')

    plt.suptitle('Data Augmentation Examples', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Show augmentation on a sample image
if len(X_train) > 0:
    show_augmentations(X_train[0])

## 6. Model Architecture (ResNet50 Transfer Learning)

In [None]:
# ============================================================
# ResNet50 with Custom Classifier
# ============================================================

class BreastCancerResNet(nn.Module):
    """ResNet50 backbone with a custom classification head.

    The backbone is exposed as ``self.resnet`` so other code (for example the
    GradCAM helper) can reach its internal layers.

    Args:
        pretrained: Load the ``IMAGENET1K_V2`` weights when True, otherwise
            start from random initialisation.
        num_classes: Number of output logits.
    """

    def __init__(self, pretrained=True, num_classes=2):
        super().__init__()

        # Load pretrained ResNet50
        self.resnet = models.resnet50(weights='IMAGENET1K_V2' if pretrained else None)

        # Freeze everything except the last 20 parameter tensors of the
        # original network (can be undone later with unfreeze_all()).
        for param in list(self.resnet.parameters())[:-20]:
            param.requires_grad = False

        # Replace the ImageNet classifier with a small MLP head. The new
        # layers are created after the freeze, so they are trainable.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.resnet(x)

    def unfreeze_all(self):
        """Unfreeze all layers for full fine-tuning.

        An optimizer built from only the trainable parameters before this
        call does not pick up the newly unfrozen ones; re-create it afterwards.
        """
        for param in self.resnet.parameters():
            param.requires_grad = True


def create_model(pretrained=True, num_classes=2):
    """Build a ``BreastCancerResNet``.

    This name previously wrapped the method definitions by mistake, which
    left ``BreastCancerResNet`` undefined. It is kept as a thin factory so
    existing references to ``create_model`` keep working.
    """
    return BreastCancerResNet(pretrained=pretrained, num_classes=num_classes)

# Build the network and place it on the selected device.
model = BreastCancerResNet(pretrained=True, num_classes=2).to(device)

# Tally parameter counts: record (size, trainable?) once, then sum twice.
_param_sizes = [(_p.numel(), _p.requires_grad) for _p in model.parameters()]
total_params = sum(_size for _size, _ in _param_sizes)
trainable_params = sum(_size for _size, _trainable in _param_sizes if _trainable)

print(f"‚úÖ Model created and moved to {device}")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   Frozen parameters: {total_params - trainable_params:,}")

In [None]:
# Loss function and optimizer

# Class weights for imbalanced data: inverse class frequency on the training
# split, rescaled so the two weights sum to 1.
class_counts = [y_train.count(0), y_train.count(1)]
class_weights = torch.tensor([1.0 / c for c in class_counts], dtype=torch.float32)
class_weights = class_weights / class_weights.sum()  # Normalize
class_weights = class_weights.to(device)

print(f"Class weights: Benign={class_weights[0]:.3f}, Malignant={class_weights[1]:.3f}")

# Loss function
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Optimizer (only trainable parameters). Parameters unfrozen after this point
# are not tracked; rebuild the optimizer if more layers are unfrozen later.
optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay']
)

# Learning rate scheduler: halve the LR after 3 epochs without improvement in
# the monitored (minimised) metric. The `verbose` argument is omitted because
# it is deprecated since PyTorch 2.2 and newer releases may reject it; read
# the current LR from optimizer.param_groups when it needs to be reported.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

print("\n‚úÖ Optimizer and scheduler configured")

## 7. Training Loop

In [None]:
# ============================================================
# Training Functions
# ============================================================

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to train mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss taking ``(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen in the epoch.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(loader, desc='Training', leave=False)
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: clear grads, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch figure is a
        # per-sample mean even when the last batch is smaller.
        loss_value = batch_loss.item()
        loss_sum += loss_value * batch_images.size(0)
        n_seen += batch_labels.size(0)
        n_correct += (logits.argmax(dim=1) == batch_labels).sum().item()

        progress.set_postfix({'loss': f'{loss_value:.4f}'})

    return loss_sum / n_seen, n_correct / n_seen


def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss taking ``(logits, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy, auc)`` where ``auc`` is the ROC AUC computed
        from the softmax probability of class index 1.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    positive_scores = []
    targets = []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(loader, desc='Validating', leave=False):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)

            # Size-weighted loss and running accuracy counters.
            loss_sum += criterion(logits, batch_labels).item() * batch_images.size(0)
            n_seen += batch_labels.size(0)
            n_correct += (logits.argmax(dim=1) == batch_labels).sum().item()

            # Keep class-1 probabilities and labels on the CPU for the AUC.
            class_one_prob = F.softmax(logits, dim=1)[:, 1]
            positive_scores.extend(class_one_prob.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

    return loss_sum / n_seen, n_correct / n_seen, roc_auc_score(targets, positive_scores)

In [None]:
# ============================================================
# TRAINING
# ============================================================

print("üöÄ Starting Training...")
print("="*60)

history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [], 'val_auc': []
}

best_val_auc = 0.0
best_val_acc = 0.0
patience_counter = 0

for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
    print("-" * 40)
    
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_acc, val_auc = validate_epoch(model, val_loader, criterion, device)
    
    # Update scheduler
    scheduler.step(val_loss)
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_auc'].append(val_auc)
    
    # Print metrics
    print(f"   Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"   Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f} | Val AUC: {val_auc:.4f}")
    
    # Save best model
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        best_val_acc = val_acc
        patience_counter = 0
        
        # Save model
        os.makedirs(os.path.dirname(CONFIG['model_save_path']), exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_auc': val_auc,
            'val_acc': val_acc,
            'config': CONFIG
        }, CONFIG['model_save_path'])
        print(f"   ‚úÖ Best model saved! (AUC: {val_auc:.4f})")
    else:
        patience_counter += 1
        print(f"   ‚è≥ No improvement ({patience_counter}/{CONFIG['patience']})")
    
    # Early stopping
    if patience_counter >= CONFIG['patience']:
        print(f"\n‚ö†Ô∏è Early stopping triggered at epoch {epoch+1}")
        break

print("\n" + "="*60)
print("üéâ Training Complete!")
print(f"   Best Val Accuracy: {best_val_acc:.4f}")
print(f"   Best Val AUC: {best_val_auc:.4f}")

In [None]:
# Plot training history
def plot_training_history(history):
    """Plot loss, accuracy and validation AUC against the epoch number.

    Args:
        history: Dict with equally long lists under the keys ``train_loss``,
            ``val_loss``, ``train_acc``, ``val_acc`` and ``val_auc``.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    epochs = range(1, len(history['train_loss']) + 1)

    # One entry per panel: (curves, y-label, title); each curve is
    # (history key, line format, legend label).
    panel_specs = (
        ((('train_loss', 'b-', 'Train'), ('val_loss', 'r-', 'Validation')),
         'Loss', 'Loss over Epochs'),
        ((('train_acc', 'b-', 'Train'), ('val_acc', 'r-', 'Validation')),
         'Accuracy', 'Accuracy over Epochs'),
        ((('val_auc', 'g-', 'Val AUC'),),
         'AUC-ROC', 'Validation AUC over Epochs'),
    )

    for ax, (curves, y_label, title) in zip(axes, panel_specs):
        for key, line_format, legend_label in curves:
            ax.plot(epochs, history[key], line_format, label=legend_label, linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

plot_training_history(history)

## 8. Model Evaluation on Test Set

In [None]:
# Load best model for evaluation.
# The checkpoint holds more than tensors (config dict, metric scalars that may
# be NumPy floats), which the weights-only unpickler used by default since
# PyTorch 2.6 can reject. weights_only=False is acceptable here only because
# the file was written by this notebook; never use it on untrusted files.
checkpoint = torch.load(
    CONFIG['model_save_path'], map_location=device, weights_only=False
)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"‚úÖ Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"   Validation AUC: {checkpoint['val_auc']:.4f}")

In [None]:
# ============================================================
# COMPREHENSIVE TEST SET EVALUATION
# ============================================================

def evaluate_model(model, loader, device):
    """Collect predictions, class-1 probabilities and labels for ``loader``.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device the image batches are moved to.

    Returns:
        Three NumPy arrays ``(predictions, probabilities, labels)`` where
        ``probabilities`` is the softmax probability of class index 1.
    """
    model.eval()
    predicted, positive_scores, targets = [], [], []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(loader, desc='Evaluating'):
            logits = model(batch_images.to(device))
            class_probs = F.softmax(logits, dim=1)

            predicted.extend(logits.argmax(dim=1).cpu().numpy())
            positive_scores.extend(class_probs[:, 1].cpu().numpy())
            # Labels are never moved to the device, so no .cpu() is needed.
            targets.extend(batch_labels.numpy())

    return np.array(predicted), np.array(positive_scores), np.array(targets)

test_preds, test_probs, test_labels = evaluate_model(model, test_loader, device)

In [None]:
# ============================================================
# METRICS
# ============================================================

# Headline metrics on the held-out test split.
test_acc = accuracy_score(test_labels, test_preds)
test_auc = roc_auc_score(test_labels, test_probs)

_banner = "="*60
print(_banner)
print("üìä TEST SET RESULTS")
print(_banner)
print(f"\nüéØ Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
print(f"üìà AUC-ROC:  {test_auc:.4f}")
print()

# Per-class precision / recall / F1 table.
_class_report = classification_report(
    test_labels, test_preds, target_names=['Benign', 'Malignant']
)
print("üìã Classification Report:")
print("-"*40)
print(_class_report)

In [None]:
# Confusion Matrix
def plot_confusion_matrix(y_true, y_pred):
    """Plot the 2x2 confusion matrix and print the metrics derived from it.

    Args:
        y_true: True labels (0 = benign, 1 = malignant).
        y_pred: Predicted labels, same length as ``y_true``.
    """
    # Fixing the label set guarantees a 2x2 matrix even when one class is
    # missing from y_true/y_pred; otherwise the unpacking below would fail.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    fig, ax = plt.subplots(figsize=(8, 6))
    # Draw on `ax` explicitly instead of relying on the current axes.
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Benign', 'Malignant'],
                yticklabels=['Benign', 'Malignant'],
                annot_kws={'size': 16}, ax=ax)

    ax.set_xlabel('Predicted', fontsize=12)
    ax.set_ylabel('Actual', fontsize=12)
    ax.set_title('Confusion Matrix', fontsize=14, fontweight='bold')

    def _ratio(numerator, denominator):
        """Return numerator/denominator, or NaN when the denominator is 0."""
        return numerator / denominator if denominator else float('nan')

    # Add each cell's share of all samples just below its count.
    total = cm.sum()
    for i in range(2):
        for j in range(2):
            pct = _ratio(cm[i, j], total) * 100
            ax.text(j + 0.5, i + 0.7, f'({pct:.1f}%)',
                   ha='center', va='center', fontsize=10, color='gray')

    plt.tight_layout()
    plt.show()

    # Rows are actual classes and columns predicted ones, so ravel() yields
    # tn, fp, fn, tp with "malignant" (label 1) as the positive class.
    tn, fp, fn, tp = cm.ravel()
    sensitivity = _ratio(tp, tp + fn)
    specificity = _ratio(tn, tn + fp)
    precision = _ratio(tp, tp + fp)

    print(f"\nüìä Detailed Metrics:")
    print(f"   True Negatives (Benign correct):  {tn}")
    print(f"   True Positives (Malignant correct): {tp}")
    print(f"   False Negatives (Missed cancers): {fn}")
    print(f"   False Positives (False alarms):   {fp}")
    print(f"\n   Sensitivity (Recall): {sensitivity:.4f} - Ability to detect cancer")
    print(f"   Specificity: {specificity:.4f} - Ability to rule out cancer")
    print(f"   Precision: {precision:.4f} - Positive predictive value")

plot_confusion_matrix(test_labels, test_preds)

In [None]:
# ROC Curve and Precision-Recall Curve
def plot_curves(y_true, y_probs):
    """Plot the ROC and precision-recall curves side by side.

    Also marks and prints the ROC threshold that maximises ``tpr - fpr``.

    Args:
        y_true: True binary labels.
        y_probs: Predicted scores for the positive class.
    """
    fig, (roc_ax, pr_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # --- ROC panel ---
    fpr, tpr, thresholds = roc_curve(y_true, y_probs)
    auc_value = roc_auc_score(y_true, y_probs)

    roc_ax.plot(fpr, tpr, 'b-', linewidth=2, label=f'ROC Curve (AUC = {auc_value:.4f})')
    roc_ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random Classifier')
    roc_ax.fill_between(fpr, tpr, alpha=0.3)
    roc_ax.set_xlabel('False Positive Rate', fontsize=12)
    roc_ax.set_ylabel('True Positive Rate', fontsize=12)
    roc_ax.set_title('ROC Curve', fontsize=14, fontweight='bold')
    roc_ax.grid(True, alpha=0.3)

    # Operating point with the largest gap between TPR and FPR.
    best_idx = np.argmax(tpr - fpr)
    optimal_threshold = thresholds[best_idx]
    roc_ax.scatter(fpr[best_idx], tpr[best_idx], color='red', s=100,
                   label=f'Optimal (threshold={optimal_threshold:.3f})')
    # One legend call after all artists exist gives the same final legend.
    roc_ax.legend(loc='lower right')

    # --- Precision-recall panel ---
    precision, recall, _ = precision_recall_curve(y_true, y_probs)

    pr_ax.plot(recall, precision, 'g-', linewidth=2, label='PR Curve')
    pr_ax.fill_between(recall, precision, alpha=0.3, color='green')
    pr_ax.set_xlabel('Recall', fontsize=12)
    pr_ax.set_ylabel('Precision', fontsize=12)
    pr_ax.set_title('Precision-Recall Curve', fontsize=14, fontweight='bold')
    pr_ax.legend(loc='lower left')
    pr_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"\nüìå Optimal Classification Threshold: {optimal_threshold:.4f}")

plot_curves(test_labels, test_probs)

## 9. GradCAM Visualization

In [None]:
# ============================================================
# GradCAM Implementation
# ============================================================

class GradCAM:
    """GradCAM heatmaps for a model exposing ``model.resnet.layer4``.

    Forward and backward hooks on the last block of ``layer4`` capture its
    activations and the gradients flowing into them. Call ``remove_hooks()``
    when the helper is no longer needed, so the hooks stop firing on every
    pass through the model.

    Args:
        model: Network whose ``resnet.layer4[-1]`` module is hooked.
    """

    def __init__(self, model):
        self.model = model
        self.gradients = None
        self.activations = None

        # Hook the last conv block (layer4) and keep the handles so the hooks
        # can be detached again.
        target_layer = model.resnet.layer4[-1]
        self._hook_handles = [
            target_layer.register_forward_hook(self._forward_hook),
            target_layer.register_full_backward_hook(self._backward_hook),
        ]

    def _forward_hook(self, module, input, output):
        """Store the hooked layer's output from the latest forward pass."""
        self.activations = output.detach()

    def _backward_hook(self, module, grad_input, grad_output):
        """Store the gradient w.r.t. the hooked layer's output."""
        self.gradients = grad_output[0].detach()

    def remove_hooks(self):
        """Detach both hooks from the model; safe to call more than once."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate(self, image, target_class=None, output_size=(224, 224)):
        """Generate a GradCAM heatmap for a single image.

        Args:
            image: Input batch holding exactly one image.
            target_class: Class index to explain; defaults to the predicted
                (arg-max) class.
            output_size: ``(height, width)`` the heatmap is resized to.

        Returns:
            2-D NumPy array of shape ``output_size`` scaled to ``[0, 1]``.

        Raises:
            RuntimeError: If the hooks captured nothing (for example after
                ``remove_hooks()``).
        """
        self.model.eval()
        # Drop stale captures so a hook that fails to fire is detected below.
        self.activations = None
        self.gradients = None

        # Gradients are required even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(image)

            if target_class is None:
                target_class = output.argmax(dim=1).item()

            # Back-propagate only the score of the target class.
            self.model.zero_grad()
            one_hot = torch.zeros_like(output)
            one_hot[0, target_class] = 1
            output.backward(gradient=one_hot)

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM hooks did not capture activations/gradients; "
                "were the hooks removed?"
            )

        # Weight each channel by its spatially averaged gradient, sum over
        # channels, and keep only positive evidence.
        pooled_gradients = self.gradients.mean(dim=[2, 3], keepdim=True)
        cam = (self.activations * pooled_gradients).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = F.interpolate(cam, size=tuple(output_size), mode='bilinear', align_corners=False)
        cam = cam[0, 0].cpu().numpy()

        # Min-max normalise; the epsilon guards against an all-constant map.
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

        return cam

# Create GradCAM
# Constructing the helper registers hooks on model.resnet.layer4[-1], so from
# here on every forward pass through `model` also updates its stored activations.
gradcam = GradCAM(model)
print("‚úÖ GradCAM initialized")

In [None]:
# Visualize GradCAM on sample images
def visualize_gradcam(model, gradcam, image_paths, labels, n_samples=4):
    """Show original image, GradCAM heatmap and overlay for sample images.

    Args:
        model: Trained classifier used for the prediction shown in the title.
        gradcam: Object with a ``generate(image, target_class=...)`` method
            returning a 224x224 heatmap in ``[0, 1]``.
        image_paths: Sequence of image file paths.
        labels: Labels aligned with ``image_paths`` (0 = benign, 1 = malignant).
        n_samples: Total number of images; half are taken from each class.
    """
    # Get samples from each class
    per_class = n_samples // 2
    benign_idx = [i for i, l in enumerate(labels) if l == 0][:per_class]
    malig_idx = [i for i, l in enumerate(labels) if l == 1][:per_class]
    indices = benign_idx + malig_idx

    # A figure with zero rows cannot be created.
    if not indices:
        print("No samples available for GradCAM visualization.")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(len(indices), 3, figsize=(12, 4*len(indices)), squeeze=False)

    for row, idx in enumerate(indices):
        true_label = labels[idx]

        # convert() returns an in-memory copy, so the file is closed at once.
        with Image.open(image_paths[idx]) as handle:
            original_img = handle.convert('RGB')

        # Transform for model
        img_tensor = val_transform(original_img).unsqueeze(0).to(device)

        # Get prediction
        with torch.no_grad():
            output = model(img_tensor)
            probs = F.softmax(output, dim=1)
            pred_class = output.argmax(dim=1).item()
            confidence = probs[0, pred_class].item()

        # Generate GradCAM for the predicted class
        cam = gradcam.generate(img_tensor, target_class=pred_class)

        # Original image
        axes[row, 0].imshow(original_img)
        true_str = 'Benign' if true_label == 0 else 'Malignant'
        axes[row, 0].set_title(f'Original\nTrue: {true_str}', fontsize=10)
        axes[row, 0].axis('off')

        # GradCAM heatmap; the title is green for a correct prediction.
        axes[row, 1].imshow(cam, cmap='jet')
        pred_str = 'Benign' if pred_class == 0 else 'Malignant'
        color = 'green' if pred_class == true_label else 'red'
        axes[row, 1].set_title(f'GradCAM\nPred: {pred_str} ({confidence:.1%})',
                              fontsize=10, color=color)
        axes[row, 1].axis('off')

        # Overlay: blend the resized image (60%) with the coloured map (40%).
        img_resized = original_img.resize((224, 224))
        overlay = np.array(img_resized) / 255.0
        heatmap = plt.cm.jet(cam)[:, :, :3]
        combined = overlay * 0.6 + heatmap * 0.4

        axes[row, 2].imshow(combined)
        axes[row, 2].set_title('Overlay', fontsize=10)
        axes[row, 2].axis('off')

    plt.suptitle('GradCAM Attention Visualization', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()

# Visualize on test images
visualize_gradcam(model, gradcam, X_test, y_test, n_samples=6)

## 10. Export Model for Production

In [None]:
# ============================================================
# FINAL MODEL EXPORT
# ============================================================

# Save complete model info.
# Insert '_final' before the file extension. str.replace('.pth', ...) would
# return the path unchanged (and overwrite the best checkpoint) whenever the
# configured path does not contain '.pth'.
_export_root, _export_ext = os.path.splitext(CONFIG['model_save_path'])
final_model_path = f"{_export_root}_final{_export_ext}"

torch.save({
    'model_state_dict': model.state_dict(),
    'config': CONFIG,
    'metrics': {
        # Plain Python floats keep the file loadable with
        # torch.load(weights_only=True), which rejects NumPy scalars.
        'test_accuracy': float(test_acc),
        'test_auc': float(test_auc),
    },
    'class_names': ['Benign', 'Malignant'],
    'input_size': CONFIG['image_size'],
    # Preprocessing statistics the consumer must apply at inference time.
    'normalization': {
        'mean': IMAGENET_MEAN,
        'std': IMAGENET_STD
    }
}, final_model_path)

print(f"‚úÖ Model saved to: {final_model_path}")
print(f"\nüìä Final Test Metrics:")
print(f"   Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
print(f"   AUC-ROC: {test_auc:.4f}")

In [None]:
# ============================================================
# INFERENCE FUNCTION (for use in app)
# ============================================================

def predict_single_image(model, image_path_or_pil, device, image_size=224):
    """
    Predict on a single image.

    Args:
        model: Trained PyTorch model
        image_path_or_pil: Path to image (str or os.PathLike, e.g.
            pathlib.Path) or PIL Image
        device: torch device
        image_size: Side length the image is resized to before inference;
            it should match the resolution used during training.

    Returns:
        prediction: 'Benign' or 'Malignant'
        confidence: Probability of predicted class
        probabilities: [prob_benign, prob_malignant]
    """
    model.eval()

    # Load image. Path objects are treated like string paths; anything else
    # is assumed to be a PIL image.
    if isinstance(image_path_or_pil, (str, os.PathLike)):
        with Image.open(image_path_or_pil) as handle:
            image = handle.convert('RGB')
    else:
        image = image_path_or_pil.convert('RGB')

    # Same deterministic preprocessing as validation: resize, tensor, normalise.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])

    img_tensor = transform(image).unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        output = model(img_tensor)
        probs = F.softmax(output, dim=1)
        pred_class = output.argmax(dim=1).item()

    class_names = ['Benign', 'Malignant']
    prediction = class_names[pred_class]
    confidence = probs[0, pred_class].item()
    probabilities = probs[0].cpu().numpy()

    return prediction, confidence, probabilities

# Test inference function
if len(X_test) > 0:
    test_img = X_test[0]
    pred, conf, probs = predict_single_image(model, test_img, device)
    print(f"\nüß™ Test Inference:")
    print(f"   Image: {test_img}")
    print(f"   Prediction: {pred}")
    print(f"   Confidence: {conf:.4f}")
    print(f"   Benign prob: {probs[0]:.4f}")
    print(f"   Malignant prob: {probs[1]:.4f}")

## 11. Summary & Next Steps

In [None]:
# ============================================================
# TRAINING SUMMARY
# ============================================================

# Final report: a banner followed by titled sections, each ending with a
# blank line.
_summary_rule = "="*60
_summary_sections = (
    ("üî¨ MODEL ARCHITECTURE", (
        "   Base: ResNet50 (ImageNet pretrained)",
        "   Custom classifier with dropout",
        f"   Total params: {total_params:,}",
        f"   Trainable params: {trainable_params:,}",
    )),
    ("üìä DATASET", (
        f"   Total images: {len(image_paths)}",
        f"   Train: {len(X_train)} | Val: {len(X_val)} | Test: {len(X_test)}",
    )),
    ("‚öôÔ∏è TRAINING CONFIG", (
        f"   Epochs: {len(history['train_loss'])}",
        f"   Batch size: {CONFIG['batch_size']}",
        f"   Learning rate: {CONFIG['learning_rate']}",
        "   Augmentation: Flip, Rotate, ColorJitter",
    )),
    ("üéØ FINAL RESULTS (Test Set)", (
        f"   Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)",
        f"   AUC-ROC: {test_auc:.4f}",
    )),
    ("üíæ SAVED FILES", (
        f"   Model: {final_model_path}",
    )),
)

print(_summary_rule)
print("üìã TRAINING SUMMARY")
print(_summary_rule)
print()
for _heading, _section_lines in _summary_sections:
    print(_heading)
    for _section_line in _section_lines:
        print(_section_line)
    print()
print("‚úÖ Ready for deployment!")

---

## 📌 Next Steps

1. **Copy the trained model** (`models/imaging_model_trained_final.pth`) to your breast_cancer_ai project

2. **Update `modules/imaging.py`** to load the trained model instead of generic pretrained ResNet50

3. **Important Notes:**
   - If accuracy is low, try:
     - More epochs
     - Unfreezing more layers (`model.unfreeze_all()`), then re-creating the optimizer so the newly unfrozen parameters are actually updated
     - More data augmentation
     - Lower learning rate
   - For production, consider:
     - Cross-validation
     - Test-time augmentation (TTA)
     - Model ensembling