# Baseline Model Development

This notebook develops and trains baseline models for cancer detection from chest X-rays:
- Custom CNN architecture
- Transfer learning with pre-trained models
- Model comparison and analysis
- Performance evaluation

**Authors:** Sneh Gupta and Arpit Bhardwaj  
**Course:** CSET211 - Statistical Machine Learning

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import warnings
from tqdm.notebook import tqdm

# Add src to path
sys.path.append('../src')

from data_loader import DataManager, get_transforms
from models import get_model, CustomCNN, ResNetModel
from utils import seed_everything, print_device_info, AverageMeter

# Notebook-wide configuration: silence warnings, pick a plot style and fix
# the random seeds so runs are repeatable.
warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8')
seed_everything(42)

# Compute device: CUDA when a GPU is visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
print_device_info()

## 1. Data Preparation

In [None]:
# Configuration for data loading
config = {
    'data': {
        'dataset_path': '../data/raw/images',
        'labels_file': '../data/raw/Data_Entry_2017_v2020.csv',
        'image_size': 224,
        'batch_size': 16,  # Smaller for notebook
        'num_workers': 0,   # For Windows compatibility
        'train_split': 0.7,
        'val_split': 0.15,
        'test_split': 0.15
    }
}

# Initialize data manager
print("Initializing data manager...")
try:
    data_manager = DataManager(config)

    # Get data loaders
    train_loader, val_loader, test_loader = data_manager.get_data_loaders()

    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Test samples: {len(test_loader.dataset)}")

    # Calculate positive weight for class imbalance
    pos_weight = data_manager.calculate_pos_weight()
    print(f"Positive weight for loss function: {pos_weight:.2f}")

except Exception as e:
    # Notebook-level fallback: any failure above (missing files, bad paths,
    # ...) switches the rest of the notebook to random demonstration data.
    print(f"Error loading data: {e}")
    print("Please ensure the dataset is available in the correct directory.")
    # Create dummy data for demonstration
    print("Creating dummy data for demonstration...")

    from torch.utils.data import Subset, TensorDataset

    # 100 random 3x224x224 "images" with random 0/1 float labels
    dummy_images = torch.randn(100, 3, 224, 224)
    dummy_labels = torch.randint(0, 2, (100,)).float()
    dataset = TensorDataset(dummy_images, dummy_labels)

    # Slicing a TensorDataset (dataset[:60]) returns a tuple of tensors, not a
    # dataset, so index-based Subsets are used for the 60/20/20 split.
    batch_size = config['data']['batch_size']
    train_loader = DataLoader(Subset(dataset, range(0, 60)), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(Subset(dataset, range(60, 80)), batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(Subset(dataset, range(80, 100)), batch_size=batch_size, shuffle=False)

    pos_weight = 2.0
    print("Using dummy data for demonstration")

In [None]:
# Visualize a batch of training data
def visualize_batch(data_loader, title="Training Batch"):
    """Plot up to eight images from the first batch of ``data_loader``.

    Images are drawn on a 2x4 grid titled with "Cancer" (label == 1) or
    "Normal". 3-channel images are de-normalised with the ImageNet mean/std
    (assumes the loader's transforms used those statistics - TODO confirm
    against ``get_transforms``) and shown in colour. Single-channel images
    (1xHxW or HxW) are shown in grayscale. Panels beyond the batch size are
    left blank.

    Args:
        data_loader: iterable yielding ``(images, labels)`` batches.
        title: figure title.
    """
    images, labels = next(iter(data_loader))

    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    fig.suptitle(title, fontsize=14)

    # ImageNet normalisation statistics, shaped to broadcast over CxHxW
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx >= len(images):
            continue  # batch smaller than the grid: leave this panel empty

        image = images[idx].detach().cpu()

        if image.dim() == 3 and image.shape[0] == 3:  # RGB, CxHxW
            image = torch.clamp(image * std + mean, 0, 1).permute(1, 2, 0)
            cmap = None
        else:
            # Grayscale: imshow needs (H, W), so drop a leading singleton channel
            if image.dim() == 3 and image.shape[0] == 1:
                image = image.squeeze(0)
            cmap = 'gray'

        ax.imshow(image.numpy(), cmap=cmap)
        ax.set_title(f'Label: {"Cancer" if labels[idx] == 1 else "Normal"}')

    plt.tight_layout()
    plt.show()

# Visualize training batch
if 'train_loader' in locals():
    visualize_batch(train_loader, "Training Data Sample")

## 2. Model Definition and Architecture Comparison

In [None]:
# Define model configurations
model_configs = {
    'custom_cnn': {
        'model': {
            'architecture': 'custom_cnn',
            'num_classes': 1,
            'dropout': 0.5,
            'pretrained': False
        }
    },
    'resnet18': {
        'model': {
            'architecture': 'resnet18',
            'num_classes': 1,
            'dropout': 0.5,
            'pretrained': True
        }
    },
    'resnet50': {
        'model': {
            'architecture': 'resnet50',
            'num_classes': 1,
            'dropout': 0.5,
            'pretrained': True
        }
    }
}

# Create models and analyze their properties
models = {}
model_info = []

# The loop variable is deliberately NOT named ``config``: that notebook-global
# name holds the data-loading configuration and must not be overwritten.
for name, model_cfg in model_configs.items():
    try:
        model = get_model(model_cfg)
        models[name] = model

        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # Parameter memory from each tensor's real element size (no fp32 assumption)
        param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

        model_info.append({
            'Model': name,
            'Total Parameters': f"{total_params:,}",
            'Trainable Parameters': f"{trainable_params:,}",
            'Memory (MB)': f"{param_bytes / 1024 / 1024:.1f}"
        })

        print(f"✓ {name} loaded successfully")
    except Exception as e:
        print(f"✗ Error loading {name}: {e}")

# Display model comparison
if model_info:
    import pandas as pd
    model_df = pd.DataFrame(model_info)
    print("\nModel Architecture Comparison:")
    print(model_df.to_string(index=False))

In [None]:
# Smoke test: push one random 3x224x224 input through every instantiated model
# and report the resulting output shape (or the error it raised).
print("Testing forward pass for all models:")
print("=" * 40)

dummy_input = torch.randn(1, 3, 224, 224).to(device)

for name, model in models.items():
    try:
        model.to(device)
        model.eval()

        with torch.no_grad():
            output = model(dummy_input)

        in_shape = list(dummy_input.shape)
        out_shape = list(output.shape)
        print(f"✓ {name}: Input {in_shape} → Output {out_shape}")
    except Exception as e:
        print(f"✗ {name}: Error in forward pass - {e}")

## 3. Training Functions

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: network producing one logit per sample; its output is reshaped
            to the target's shape before the loss.
        train_loader: iterable of ``(data, target)`` batches.
        criterion: loss called as ``criterion(logits, targets)`` (the notebook
            passes ``nn.BCEWithLogitsLoss``).
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy, auc)``. Accuracy thresholds the sigmoid
        probabilities at 0.5. AUC is 0.0 when only one class was seen.
    """
    model.train()
    losses = AverageMeter()

    all_preds = []
    all_targets = []

    progress_bar = tqdm(train_loader, desc='Training', leave=False)

    for data, target in progress_bar:
        data = data.to(device)
        # BCE-style losses need float targets; labels may arrive as ints.
        target = target.to(device).float()

        optimizer.zero_grad()
        # reshape_as instead of squeeze(): squeeze() turns a final batch of
        # one sample into a 0-d tensor, mismatching the (1,) target.
        output = model(data).reshape_as(target)
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        losses.update(loss.item(), data.size(0))

        # Store predictions for metrics
        with torch.no_grad():
            all_preds.extend(torch.sigmoid(output).cpu().numpy().ravel())
            all_targets.extend(target.cpu().numpy().ravel())

        progress_bar.set_postfix({'Loss': f'{losses.avg:.4f}'})

    # Calculate metrics
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)

    train_acc = accuracy_score(all_targets, all_preds > 0.5)
    train_auc = roc_auc_score(all_targets, all_preds) if len(np.unique(all_targets)) > 1 else 0.0

    return losses.avg, train_acc, train_auc

def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating weights.

    Args:
        model: network producing one logit per sample; its output is reshaped
            to the target's shape before the loss.
        val_loader: iterable of ``(data, target)`` batches.
        criterion: loss called as ``criterion(logits, targets)``.
        device: device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1, auc)``. The threshold
        metrics use sigmoid probability > 0.5, and precision/recall/F1 are 0
        when undefined. AUC is 0.0 when only one class was seen.
    """
    model.eval()
    losses = AverageMeter()

    all_preds = []
    all_targets = []

    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc='Validation', leave=False)

        for data, target in progress_bar:
            data = data.to(device)
            # BCE-style losses need float targets; labels may arrive as ints.
            target = target.to(device).float()
            # reshape_as keeps a batch dimension even for a size-1 batch
            # (squeeze() would produce a 0-d tensor there).
            output = model(data).reshape_as(target)
            loss = criterion(output, target)

            losses.update(loss.item(), data.size(0))

            all_preds.extend(torch.sigmoid(output).cpu().numpy().ravel())
            all_targets.extend(target.cpu().numpy().ravel())

            progress_bar.set_postfix({'Loss': f'{losses.avg:.4f}'})

    # Calculate metrics
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)
    predicted = all_preds > 0.5

    val_acc = accuracy_score(all_targets, predicted)
    val_precision = precision_score(all_targets, predicted, zero_division=0)
    val_recall = recall_score(all_targets, predicted, zero_division=0)
    val_f1 = f1_score(all_targets, predicted, zero_division=0)
    val_auc = roc_auc_score(all_targets, all_preds) if len(np.unique(all_targets)) > 1 else 0.0

    return losses.avg, val_acc, val_precision, val_recall, val_f1, val_auc

def train_model(model, train_loader, val_loader, epochs=10, learning_rate=0.001, pos_weight=1.0):
    """Train ``model`` and return it holding its best-validation-AUC weights.

    Setup: ``BCEWithLogitsLoss`` weighted by ``pos_weight``, Adam with
    ``weight_decay=1e-4``, and ``ReduceLROnPlateau`` on validation loss
    (patience 3, factor 0.5). Each epoch calls ``train_epoch`` and
    ``validate_epoch``. Whenever validation AUC improves, the weights are
    snapshotted (on CPU), and the best snapshot is loaded back into ``model``
    before returning. Uses the notebook-global ``device``.

    Args:
        model: network to train; moved to ``device`` and trained in place.
        train_loader: training batches.
        val_loader: validation batches.
        epochs: number of training epochs.
        learning_rate: initial Adam learning rate.
        pos_weight: positive-class loss weight, as a float or a tensor.

    Returns:
        ``(model, history, best_val_auc)``. ``model`` holds the weights of the
        best-AUC epoch. ``history`` maps each metric name to its per-epoch
        values, so its last entries describe the final epoch, not the best one.
    """
    model = model.to(device)

    # Loss function with class weighting. as_tensor accepts a float or a
    # tensor, and reshape(-1) gives the 1-D pos_weight the loss expects.
    pos_weight_tensor = torch.as_tensor(pos_weight, dtype=torch.float32, device=device).reshape(-1)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    # Training history
    history = {
        'train_loss': [], 'train_acc': [], 'train_auc': [],
        'val_loss': [], 'val_acc': [], 'val_precision': [],
        'val_recall': [], 'val_f1': [], 'val_auc': []
    }

    best_val_auc = 0.0
    best_state = None

    for epoch in range(epochs):
        print(f'\nEpoch {epoch+1}/{epochs}')
        print('-' * 50)

        # Train
        train_loss, train_acc, train_auc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, val_precision, val_recall, val_f1, val_auc = validate_epoch(model, val_loader, criterion, device)

        # Update scheduler
        scheduler.step(val_loss)

        # Store history
        epoch_metrics = {
            'train_loss': train_loss, 'train_acc': train_acc, 'train_auc': train_auc,
            'val_loss': val_loss, 'val_acc': val_acc, 'val_precision': val_precision,
            'val_recall': val_recall, 'val_f1': val_f1, 'val_auc': val_auc,
        }
        for key, value in epoch_metrics.items():
            history[key].append(value)

        # Print metrics
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train AUC: {train_auc:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}')
        print(f'Val Precision: {val_precision:.4f}, Val Recall: {val_recall:.4f}, Val F1: {val_f1:.4f}')

        # Save best model. The first epoch is always snapshotted so a
        # checkpoint exists even if AUC stays at its 0.0 placeholder.
        if best_state is None or val_auc > best_val_auc:
            best_val_auc = val_auc
            best_state = {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}
            print(f'New best model! Val AUC: {val_auc:.4f}')

    # Restore the best checkpoint so downstream evaluation uses the weights
    # that best_val_auc describes, not just the last epoch's.
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, history, best_val_auc

print("Training functions defined successfully!")

## 4. Baseline Model Training

In [None]:
# Baseline run: the Custom CNN trained from scratch
if 'custom_cnn' not in models:
    print("Custom CNN model not available")
else:
    print("Training Custom CNN (Baseline Model)")
    print("=" * 50)

    custom_model = models['custom_cnn']

    # Only 5 epochs to keep notebook execution time short
    trained_custom, custom_history, custom_best_auc = train_model(
        custom_model,
        train_loader,
        val_loader,
        epochs=5,
        learning_rate=0.001,
        pos_weight=pos_weight,
    )

    print(f"\nCustom CNN Training Complete!")
    print(f"Best Validation AUC: {custom_best_auc:.4f}")

In [None]:
# Transfer-learning run: fine-tune the pretrained ResNet18
if 'resnet18' not in models:
    print("ResNet18 model not available")
else:
    print("Training ResNet18 (Transfer Learning)")
    print("=" * 50)

    resnet18_model = models['resnet18']

    # A 10x smaller learning rate than the baseline, to avoid wrecking
    # the pretrained features during fine-tuning
    trained_resnet18, resnet18_history, resnet18_best_auc = train_model(
        resnet18_model,
        train_loader,
        val_loader,
        epochs=5,
        learning_rate=0.0001,
        pos_weight=pos_weight,
    )

    print(f"\nResNet18 Training Complete!")
    print(f"Best Validation AUC: {resnet18_best_auc:.4f}")

## 5. Results Analysis and Visualization

In [None]:
# Plot training history for all trained models
def plot_training_history(histories, model_names):
    """Plot per-epoch metric curves for several models on a 2x3 grid.

    Panels: loss, accuracy, AUC, precision, recall, F1. For each model, the
    validation curve is drawn solid. When ``history`` also has a
    ``train_<metric>`` entry (loss, acc, auc), the training curve is drawn
    dashed in the same colour. Epochs are numbered from 1.

    Args:
        histories: list of history dicts as returned by ``train_model``.
        model_names: display names, parallel to ``histories``.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    metrics = ['loss', 'acc', 'auc', 'precision', 'recall', 'f1']
    metric_titles = ['Loss', 'Accuracy', 'AUC', 'Precision', 'Recall', 'F1-Score']

    colors = ['blue', 'red', 'green', 'orange', 'purple']

    # axes.flat walks the grid row-major, matching the metric order above
    for ax, metric, title in zip(axes.flat, metrics, metric_titles):
        for i, (history, name) in enumerate(zip(histories, model_names)):
            color = colors[i % len(colors)]
            train_key = f'train_{metric}'
            val_key = f'val_{metric}'
            has_train = train_key in history

            # Pass colour/linestyle as keywords: full colour names cannot be
            # combined with '--' in a matplotlib fmt string ('blue--' raises).
            if has_train:
                train_vals = history[train_key]
                ax.plot(range(1, len(train_vals) + 1), train_vals,
                        color=color, linestyle='--', label=f'{name} Train', alpha=0.7)
            if val_key in history:
                val_vals = history[val_key]
                ax.plot(range(1, len(val_vals) + 1), val_vals,
                        color=color, label=f'{name} Val' if has_train else name)

        ax.set_title(title)
        ax.set_xlabel('Epoch')
        if ax.get_lines():  # avoid an empty-legend warning on blank panels
            ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Collect histories from trained models
trained_histories = []
trained_names = []

if 'custom_history' in locals():
    trained_histories.append(custom_history)
    trained_names.append('Custom CNN')

if 'resnet18_history' in locals():
    trained_histories.append(resnet18_history)
    trained_names.append('ResNet18')

if trained_histories:
    plot_training_history(trained_histories, trained_names)
else:
    print("No training histories available to plot")

In [None]:
# Create model comparison table
def _summary_row(model_name, history, best_auc):
    """Build one comparison-table row (4-decimal strings) from a training history."""
    return {
        'Model': model_name,
        'Final Val Accuracy': f"{history['val_acc'][-1]:.4f}",
        'Final Val AUC': f"{history['val_auc'][-1]:.4f}",
        'Final Val Precision': f"{history['val_precision'][-1]:.4f}",
        'Final Val Recall': f"{history['val_recall'][-1]:.4f}",
        'Final Val F1': f"{history['val_f1'][-1]:.4f}",
        'Best Val AUC': f"{best_auc:.4f}"
    }

comparison_data = []
# Raw (unrounded) best AUC per model, used to pick the winner without
# re-parsing the formatted table strings.
best_aucs = {}

if 'custom_history' in locals():
    comparison_data.append(_summary_row('Custom CNN', custom_history, custom_best_auc))
    best_aucs['Custom CNN'] = custom_best_auc

if 'resnet18_history' in locals():
    comparison_data.append(_summary_row('ResNet18', resnet18_history, resnet18_best_auc))
    best_aucs['ResNet18'] = resnet18_best_auc

if comparison_data:
    import pandas as pd
    comparison_df = pd.DataFrame(comparison_data)
    print("Model Performance Comparison:")
    print("=" * 80)
    print(comparison_df.to_string(index=False))

    # Determine best model (ties go to the first model added)
    best_model = max(best_aucs, key=best_aucs.get)
    best_auc = best_aucs[best_model]
    print(f"\nBest performing model: {best_model} (AUC: {best_auc:.4f})")
else:
    print("No models trained for comparison")

## 6. Test Set Evaluation

In [None]:
# Evaluate every trained model on the held-out test set
def evaluate_on_test(model, test_loader, model_name):
    """Evaluate ``model`` on ``test_loader``, then print and return its metrics.

    Uses the notebook-global ``device``.

    Args:
        model: network producing one logit per sample; its output is reshaped
            to the target's shape.
        test_loader: iterable of ``(data, target)`` batches.
        model_name: label used in the progress bar and printed report.

    Returns:
        Dict with ``predictions`` (sigmoid probabilities), ``targets``, and
        ``accuracy``/``precision``/``recall``/``f1`` (threshold 0.5; the last
        three are 0 when undefined) plus ``auc`` (0.0 if only one class is
        present).
    """
    model = model.to(device)
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for data, target in tqdm(test_loader, desc=f'Testing {model_name}'):
            data = data.to(device)
            target = target.to(device).float()
            # reshape_as keeps a batch dimension even for a size-1 batch
            # (squeeze() would produce a 0-d tensor there).
            output = model(data).reshape_as(target)

            all_preds.extend(torch.sigmoid(output).cpu().numpy().ravel())
            all_targets.extend(target.cpu().numpy().ravel())

    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)
    predicted = all_preds > 0.5

    # Calculate metrics
    test_acc = accuracy_score(all_targets, predicted)
    test_precision = precision_score(all_targets, predicted, zero_division=0)
    test_recall = recall_score(all_targets, predicted, zero_division=0)
    test_f1 = f1_score(all_targets, predicted, zero_division=0)
    test_auc = roc_auc_score(all_targets, all_preds) if len(np.unique(all_targets)) > 1 else 0.0

    print(f"\n{model_name} Test Results:")
    print(f"Accuracy:  {test_acc:.4f}")
    print(f"Precision: {test_precision:.4f}")
    print(f"Recall:    {test_recall:.4f}")
    print(f"F1-Score:  {test_f1:.4f}")
    print(f"AUC:       {test_auc:.4f}")

    return {
        'predictions': all_preds,
        'targets': all_targets,
        'accuracy': test_acc,
        'precision': test_precision,
        'recall': test_recall,
        'f1': test_f1,
        'auc': test_auc
    }

# Test all trained models
test_results = {}

if 'trained_custom' in locals():
    test_results['Custom CNN'] = evaluate_on_test(trained_custom, test_loader, 'Custom CNN')

if 'trained_resnet18' in locals():
    test_results['ResNet18'] = evaluate_on_test(trained_resnet18, test_loader, 'ResNet18')

if test_results:
    print(f"\nTested {len(test_results)} models on test set")
else:
    print("No trained models available for testing")