# 04 - Baseline Model Test (ResNet-50)

**Author:** Tan Ming Kai (24PMR12003)  
**Date:** 2025-11-11  
**Purpose:** Get ONE baseline model working to verify training pipeline

**Project:** Multi-Scale Vision Transformer (CrossViT) for COVID-19 Chest X-ray Classification  
**Academic Year:** 2025/26

---

## Objectives
1. ✅ Create PyTorch Dataset for loading CLAHE-enhanced images
2. ✅ Implement DataLoader with memory-safe settings
3. ✅ Load ResNet-50 baseline model
4. ✅ Train on small subset (1000 images) first
5. ✅ Verify GPU memory usage (<7GB)
6. ✅ Train on full dataset
7. ✅ Achieve >70% accuracy (confirm pipeline works)
8. ✅ Log results to MLflow
9. ✅ Save best model checkpoint

---

## Phase 1: Exploration - Final Step

This notebook completes Phase 1 by verifying you can successfully train a deep learning model on the preprocessed dataset.

## 1. Reproducibility Setup & Imports

**CRITICAL:** Load reproducibility seeds and required libraries.

In [None]:
"""
Baseline Model Test Notebook for CrossViT COVID-19 FYP
Author: Tan Ming Kai (24PMR12003)
Purpose: Verify training pipeline works with ResNet-50 baseline
"""

# ============================================================================
# 1. REPRODUCIBILITY SETUP (ALWAYS FIRST!)
# ============================================================================
import random
import numpy as np
import torch

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

print("‚úÖ Random seeds set to 42 for reproducibility")

# ============================================================================
# 2. STANDARD LIBRARY IMPORTS
# ============================================================================
import os
import sys
from pathlib import Path
import warnings
import time
from datetime import datetime
warnings.filterwarnings('ignore')

# ============================================================================
# 3. DATA SCIENCE LIBRARIES
# ============================================================================
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Configure display
pd.set_option('display.max_columns', None)
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# ============================================================================
# 4. PYTORCH & DEEP LEARNING
# ============================================================================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models

# ============================================================================
# 5. COMPUTER VISION
# ============================================================================
import cv2
from PIL import Image

# ============================================================================
# 6. MLFLOW (Experiment Tracking)
# ============================================================================
try:
    import mlflow
    import mlflow.pytorch
    MLFLOW_AVAILABLE = True
    print("‚úÖ MLflow available for experiment tracking")
except ImportError:
    MLFLOW_AVAILABLE = False
    print("‚ö†Ô∏è  MLflow not installed. Install with: pip install mlflow")
    print("   Continuing without experiment tracking...")

# ============================================================================
# 7. SKLEARN (Metrics)
# ============================================================================
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, classification_report
)

print("\n‚úÖ All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")

## 2. Hardware Verification

Verify GPU is available and check VRAM.

In [None]:
print("=" * 70)
print("HARDWARE VERIFICATION")
print("=" * 70)

# Check CUDA; `device` is the torch.device every later cell moves tensors to.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')

print(f"\n‚úì CUDA Available: {cuda_available}")
print(f"‚úì Using Device: {device}")

if cuda_available:
    gpu_name = torch.cuda.get_device_name(0)
    # Total VRAM of device 0 in GB (1e9 bytes)
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    
    print(f"‚úì GPU: {gpu_name}")
    print(f"‚úì Total VRAM: {total_memory:.2f} GB")
    print(f"‚úì CUDA Version: {torch.version.cuda}")
    
    # Memory monitoring function
    def print_gpu_memory(prefix=""):
        """Print allocated / reserved / "free" CUDA memory for device 0 in GB.

        Only defined when CUDA is available. "Free" is computed as total VRAM
        minus memory reserved by this process's caching allocator, so it does
        not account for memory used by other processes.

        Args:
            prefix: text printed before the memory summary (e.g. indentation).
        """
        allocated = torch.cuda.memory_allocated(0) / 1e9
        reserved = torch.cuda.memory_reserved(0) / 1e9
        free = total_memory - reserved
        print(f"{prefix}GPU Memory: Allocated={allocated:.3f}GB | Reserved={reserved:.3f}GB | Free={free:.3f}GB")
    
    print_gpu_memory("\n  ")
    
    # Expected target hardware: RTX 4060 with ~8 GB VRAM
    if "4060" in gpu_name and 7.0 <= total_memory <= 9.0:
        print("\n‚úÖ CONFIRMED: RTX 4060 8GB detected - Ready for training!")
    else:
        print(f"\n‚ö†Ô∏è  Different GPU detected: {gpu_name}")
        print("   Adjust batch size if needed based on VRAM.")
else:
    print("\n‚ùå WARNING: No GPU detected! Training will be VERY slow.")
    print("   Please ensure CUDA drivers and PyTorch with CUDA are installed.")

print("\n" + "=" * 70)

## 3. Configuration

Define all training parameters and paths.

In [None]:
# Paths (relative to the current working directory)
CSV_DIR = Path("../data/processed")
PROCESSED_IMG_DIR = Path("../data/processed/clahe_enhanced")
MODELS_DIR = Path("../models")
RESULTS_DIR = Path("../results")

# Create output directories (no-op if they already exist)
MODELS_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Training configuration (Memory-safe for RTX 4060 8GB)
CONFIG = {
    # Reproducibility
    'seed': 42,
    'device': device,
    
    # Model
    'model_name': 'ResNet-50',
    'num_classes': 4,
    'pretrained': True,
    
    # Data
    'image_size': 240,
    'class_names': ['COVID', 'Normal', 'Lung_Opacity', 'Viral Pneumonia'],
    'class_weights': [1.47, 0.52, 0.88, 3.95],  # From EDA (same order as class_names)
    
    # Training hyperparameters
    'batch_size': 16,  # Safe for 8GB VRAM
    'num_workers': 0,  # WINDOWS FIX: Use 0 for Windows (4 for Linux/Mac)
    # Pinned host memory only speeds up host->GPU copies; without CUDA it is
    # useless and PyTorch merely warns, so enable it only when CUDA exists.
    'pin_memory': torch.cuda.is_available(),
    'persistent_workers': False,  # WINDOWS FIX: Must be False when num_workers=0
    
    # Optimizer
    'learning_rate': 1e-4,  # ResNet-50 default
    'weight_decay': 1e-4,
    'max_epochs': 30,  # Start with fewer epochs for testing
    'early_stopping_patience': 10,
    
    # ImageNet normalization (matches the pretrained weights)
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],
    
    # Memory management
    'mixed_precision': True,
    'memory_check_interval': 50,  # Check GPU memory every N batches
    
    # Testing
    'test_on_subset': True,  # Start with small subset
    'subset_size': 1000,  # Total train images sampled; val/test each get subset_size // 5
}

print("=" * 70)
print("CONFIGURATION")
print("=" * 70)
print(f"\n‚úì Model: {CONFIG['model_name']}")
print(f"‚úì Device: {CONFIG['device']}")
print(f"‚úì Batch Size: {CONFIG['batch_size']}")
print(f"‚úì Learning Rate: {CONFIG['learning_rate']}")
print(f"‚úì Max Epochs: {CONFIG['max_epochs']}")
print(f"‚úì Image Size: {CONFIG['image_size']}√ó{CONFIG['image_size']}")
print(f"‚úì Mixed Precision: {CONFIG['mixed_precision']}")
print(f"\n‚úì Test on Subset: {CONFIG['test_on_subset']}")
if CONFIG['test_on_subset']:
    # The subset cell samples subset_size train images and subset_size // 5
    # images for each of val and test, so report both numbers.
    print(f"  ‚Üí Subset Size: {CONFIG['subset_size']} train images, "
          f"{CONFIG['subset_size'] // 5} val/test images each")
print(f"\n‚ö†Ô∏è  Windows Mode: num_workers=0 (single-threaded loading)")
print(f"   This is slower but stable on Windows")
print("\n" + "=" * 70)

## 4. MLflow Setup

Initialize experiment tracking.

In [None]:
print("=" * 70)
print("MLFLOW EXPERIMENT TRACKING SETUP")
print("=" * 70)

if MLFLOW_AVAILABLE:
    # Set tracking URI (local directory) FIRST: set_experiment() looks up or
    # creates the experiment in whichever tracking store is active, so calling
    # it before the URI is set would register the experiment in the default
    # store rather than in ./mlruns.
    mlflow.set_tracking_uri("file:./mlruns")
    
    # Set experiment name (created on first use)
    mlflow.set_experiment("crossvit-covid19-classification")
    
    print("\n‚úÖ MLflow configured:")
    print(f"   - Experiment: crossvit-covid19-classification")
    print(f"   - Tracking URI: {mlflow.get_tracking_uri()}")
    print(f"\nüí° View results: Run 'mlflow ui' in terminal, then open http://localhost:5000")
else:
    print("\n‚ö†Ô∏è  MLflow not available. Results will not be logged.")
    print("   Install with: pip install mlflow")

print("\n" + "=" * 70)

## 5. Load Data Splits

Load CSV files with paths to preprocessed images.

In [None]:
print("=" * 70)
print("LOADING DATA SPLITS")
print("=" * 70)

# Load processed CSV files (one per split)
train_df = pd.read_csv(CSV_DIR / "train_processed.csv")
val_df = pd.read_csv(CSV_DIR / "val_processed.csv")
test_df = pd.read_csv(CSV_DIR / "test_processed.csv")

print(f"\n‚úÖ CSV files loaded:")
print(f"   - Train: {len(train_df):,} images")
print(f"   - Val:   {len(val_df):,} images")
print(f"   - Test:  {len(test_df):,} images")

# Verify processed_path column exists
if 'processed_path' in train_df.columns:
    print(f"\n‚úÖ Using preprocessed images from: processed_path column")
    
    if train_df.empty:
        # iloc[0] below would raise IndexError on an empty CSV
        print("   ERROR: train_processed.csv contains no rows.")
    else:
        # Smoke-test loading one image. cv2.imread does not raise on failure -
        # it returns None - so a file that exists but cannot be decoded must be
        # checked explicitly before touching `.shape`.
        test_path = train_df.iloc[0]['processed_path']
        if Path(test_path).exists():
            test_img = cv2.imread(str(test_path))
            if test_img is None:
                print(f"   ERROR: Sample image exists but could not be decoded: {test_path}")
            else:
                print(f"   ‚úì Sample image loaded successfully: {test_img.shape}")
        else:
            print(f"   ‚ùå ERROR: Sample image not found at {test_path}")
            print(f"   Please verify processed images exist.")
else:
    print(f"\n‚ùå ERROR: 'processed_path' column not found in CSV")
    print(f"   Please run 02_data_cleaning.ipynb first.")

print("\nüìä Class Distribution in Training Set:")
class_counts = train_df['class_name'].value_counts()
for class_name, count in class_counts.items():
    pct = count / len(train_df) * 100
    print(f"   {class_name:20s}: {count:5d} ({pct:5.2f}%)")

print("\n" + "=" * 70)

## 6. Create PyTorch Dataset

Custom Dataset class for loading CLAHE-enhanced images.

In [None]:
class COVID19Dataset(Dataset):
    """
    PyTorch Dataset for COVID-19 chest X-ray classification.

    For each row, reads the image at ``processed_path`` with OpenCV, converts
    it from BGR to an RGB PIL image, applies the optional transform, and
    returns it together with the row's integer ``label``.

    NOTE(review): files are expected to be the CLAHE-enhanced 240x240 RGB
    images written by the preprocessing notebook - confirm there.
    """
    
    def __init__(self, dataframe, transform=None):
        """
        Args:
            dataframe (pd.DataFrame): DataFrame with 'processed_path' and 'label' columns
            transform (callable, optional): Transformation applied to each PIL image
        """
        # reset_index so positional idx 0..len-1 matches the cached arrays below
        self.dataframe = dataframe.reset_index(drop=True)
        self.transform = transform
        
        # Cache paths and labels as arrays to avoid per-item DataFrame lookups
        self.image_paths = self.dataframe['processed_path'].values
        self.labels = self.dataframe['label'].values
        
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        """
        Return ``(image, label)`` for row ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (missing or undecodable).
        """
        img_path = self.image_paths[idx]
        # cv2.imread expects a str path and returns None (no exception) on
        # failure; IMREAD_COLOR (the default) always yields 3 BGR channels.
        image = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        
        if image is None:
            raise FileNotFoundError(f"Image not found: {img_path}")
        
        # OpenCV decodes to BGR; torchvision transforms / ImageNet stats assume RGB
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        
        # `is not None` so a transform object is never skipped due to falsiness
        if self.transform is not None:
            image = self.transform(image)
        
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        
        return image, label


print("‚úÖ COVID19Dataset class defined")
print("\nüìù Features:")
print("   - Loads CLAHE-enhanced images (240√ó240√ó3)")
print("   - Converts BGR ‚Üí RGB")
print("   - Applies torchvision transforms")
print("   - Returns (image, label) tensors")

## 7. Define Data Transforms

Create transforms for training and validation.

**Important:** No aggressive augmentation yet — only resizing, a random horizontal flip (training set only), and ImageNet normalization for this baseline test.

In [None]:
# Training transforms (minimal augmentation for now)
# NOTE(review): a horizontal flip mirrors left/right anatomy on a chest X-ray;
# confirm this augmentation is acceptable for the task.
train_transform = transforms.Compose([
    transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
    transforms.RandomHorizontalFlip(p=0.5),  # Only horizontal flip
    transforms.ToTensor(),  # PIL image -> float CHW tensor in [0, 1]
    transforms.Normalize(mean=CONFIG['mean'], std=CONFIG['std'])  # ImageNet stats from CONFIG
])

# Validation/Test transforms (no augmentation, deterministic)
val_transform = transforms.Compose([
    transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=CONFIG['mean'], std=CONFIG['std'])
])

print("=" * 70)
print("DATA TRANSFORMS")
print("=" * 70)
print("\n‚úÖ Training transforms:")
print("   1. Resize to 240√ó240")
print("   2. Random horizontal flip (50%)")
print("   3. ToTensor()")
print("   4. Normalize (ImageNet stats)")

print("\n‚úÖ Validation/Test transforms:")
print("   1. Resize to 240√ó240")
print("   2. ToTensor()")
print("   3. Normalize (ImageNet stats)")

print("\nüí° Note: Using minimal augmentation for baseline test.")
print("   More augmentation will be tested in 05_augmentation_test.ipynb")
print("\n" + "=" * 70)

## 8. Create Datasets and DataLoaders

Instantiate Dataset objects and DataLoaders.

In [None]:
print("=" * 70)
print("CREATING DATASETS AND DATALOADERS")
print("=" * 70)

# Create datasets: augmentation on train only, deterministic transform for val/test
train_dataset = COVID19Dataset(train_df, transform=train_transform)
val_dataset = COVID19Dataset(val_df, transform=val_transform)
test_dataset = COVID19Dataset(test_df, transform=val_transform)

print(f"\n‚úÖ Datasets created:")
print(f"   - Train: {len(train_dataset):,} images")
print(f"   - Val:   {len(val_dataset):,} images")
print(f"   - Test:  {len(test_dataset):,} images")

# Create subset for quick testing (if enabled)
if CONFIG['test_on_subset']:
    subset_size = CONFIG['subset_size']
    
    # Dedicated RNG seeded from CONFIG: the sampled subset is identical on every
    # run, no matter how often earlier cells consumed the global NumPy RNG.
    subset_rng = np.random.default_rng(CONFIG['seed'])
    
    def _sample_indices(n_total, n_wanted):
        """Draw min(n_wanted, n_total) distinct random indices in [0, n_total)."""
        return subset_rng.choice(n_total, min(n_wanted, n_total), replace=False).tolist()
    
    # Val/test subsets are one fifth of the train subset size
    train_indices = _sample_indices(len(train_dataset), subset_size)
    val_indices = _sample_indices(len(val_dataset), subset_size // 5)
    test_indices = _sample_indices(len(test_dataset), subset_size // 5)
    
    train_dataset_use = Subset(train_dataset, train_indices)
    val_dataset_use = Subset(val_dataset, val_indices)
    test_dataset_use = Subset(test_dataset, test_indices)
    
    print(f"\n‚ö†Ô∏è  Using SUBSET for quick testing:")
    print(f"   - Train: {len(train_dataset_use):,} images (sampled)")
    print(f"   - Val:   {len(val_dataset_use):,} images (sampled)")
    print(f"   - Test:  {len(test_dataset_use):,} images (sampled)")
    print(f"\n   üí° To use FULL dataset, set CONFIG['test_on_subset'] = False")
else:
    train_dataset_use = train_dataset
    val_dataset_use = val_dataset
    test_dataset_use = test_dataset
    print(f"\n‚úÖ Using FULL dataset for training")

# DataLoader raises ValueError for persistent_workers=True with num_workers=0,
# so only honour the flag when worker processes actually exist.
loader_persistent_workers = CONFIG['persistent_workers'] and CONFIG['num_workers'] > 0

# Seeded generator for the train shuffle order, independent of the global RNG.
train_generator = torch.Generator().manual_seed(CONFIG['seed'])


def _make_loader(dataset, shuffle, drop_last=False, generator=None):
    """Build a DataLoader with the shared memory-safe settings from CONFIG."""
    return DataLoader(
        dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=shuffle,
        num_workers=CONFIG['num_workers'],
        pin_memory=CONFIG['pin_memory'],
        persistent_workers=loader_persistent_workers,
        drop_last=drop_last,
        generator=generator,
    )


# Train drops the incomplete final batch; val/test keep every sample.
train_loader = _make_loader(train_dataset_use, shuffle=True, drop_last=True,
                            generator=train_generator)
val_loader = _make_loader(val_dataset_use, shuffle=False)
test_loader = _make_loader(test_dataset_use, shuffle=False)

print(f"\n‚úÖ DataLoaders created:")
print(f"   - Train batches: {len(train_loader)}")
print(f"   - Val batches:   {len(val_loader)}")
print(f"   - Test batches:  {len(test_loader)}")
print(f"   - Batch size:    {CONFIG['batch_size']}")
print(f"   - Num workers:   {CONFIG['num_workers']}")

# With drop_last=True, fewer samples than batch_size yields zero batches; fail
# with a clear message instead of an opaque StopIteration below.
if len(train_loader) == 0:
    raise ValueError(
        f"Training set has {len(train_dataset_use)} samples, fewer than "
        f"batch_size={CONFIG['batch_size']}; no training batches can be formed."
    )

# Test loading one batch
print(f"\nüß™ Testing DataLoader...")
sample_images, sample_labels = next(iter(train_loader))
print(f"   ‚úì Sample batch shape: {sample_images.shape}")
print(f"   ‚úì Sample labels shape: {sample_labels.shape}")
print(f"   ‚úì Image value range: [{sample_images.min():.3f}, {sample_images.max():.3f}]")
print(f"   ‚úì Unique labels in batch: {sample_labels.unique().tolist()}")

print("\n" + "=" * 70)

## 9. Load ResNet-50 Model

Load pretrained ResNet-50 and modify for 4-class classification.

In [None]:
print("=" * 70)
print("LOADING RESNET-50 MODEL")
print("=" * 70)

# Load ResNet-50 through torchvision's multi-weight API; the `pretrained=` flag
# is deprecated since torchvision 0.13. IMAGENET1K_V1 is the checkpoint that
# `pretrained=True` used to load, so results stay comparable.
resnet_weights = models.ResNet50_Weights.IMAGENET1K_V1 if CONFIG['pretrained'] else None
model = models.resnet50(weights=resnet_weights)

# Replace the 1000-way ImageNet head with a CONFIG['num_classes']-way classifier
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, CONFIG['num_classes'])

# Move to device
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n‚úÖ ResNet-50 loaded successfully")
print(f"   - Pretrained: {CONFIG['pretrained']}")
print(f"   - Total parameters: {total_params:,}")
print(f"   - Trainable parameters: {trainable_params:,}")
print(f"   - Model size: ~{total_params * 4 / 1e6:.2f} MB (FP32)")
print(f"   - Device: {device}")

# Smoke-test a single forward pass on random input of the configured size
print(f"\nüß™ Testing forward pass...")
expected_shape = torch.Size([1, CONFIG['num_classes']])
model.eval()
with torch.no_grad():
    test_input = torch.randn(1, 3, CONFIG['image_size'], CONFIG['image_size']).to(device)
    test_output = model(test_input)

print(f"   ‚úì Input shape: {test_input.shape}")
print(f"   ‚úì Output shape: {test_output.shape}")
print(f"   ‚úì Expected: torch.Size([1, {CONFIG['num_classes']}])")

if test_output.shape == expected_shape:
    print(f"\n‚úÖ Model configuration CORRECT for 4-class classification!")
else:
    print(f"\n‚ùå ERROR: Output shape mismatch!")

# Check GPU memory after loading model
if cuda_available:
    print_gpu_memory("\n  ")

print("\n" + "=" * 70)

## 10. Define Loss Function and Optimizer

Use weighted CrossEntropyLoss for class imbalance.

In [None]:
# Class weights for imbalanced dataset.
# NOTE(review): assumed to follow the same label order as CONFIG['class_names'].
class_weights = torch.tensor(CONFIG['class_weights'], dtype=torch.float32).to(device)

# Loss function: each class's loss term is scaled by its weight
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Optimizer
optimizer = optim.Adam(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay']
)

# Learning rate scheduler: halve the LR after 5 epochs without val-loss
# improvement. `verbose=True` is omitted because it is deprecated in
# PyTorch >= 2.2; the training loop already prints and logs the LR every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5
)

# Mixed precision scaler (if enabled and a GPU is present)
if CONFIG['mixed_precision'] and cuda_available:
    scaler = torch.cuda.amp.GradScaler()
    print("‚úÖ Using mixed precision training (FP16)")
else:
    scaler = None
    print("‚úÖ Using standard precision training (FP32)")

# Separator line (previously "\n=" * 70 printed seventy separate "=" lines)
print("\n" + "=" * 70)
print("TRAINING SETUP")
print("=" * 70)
print(f"\n‚úì Loss Function: CrossEntropyLoss (weighted)")
print(f"   - Class weights: {CONFIG['class_weights']}")
print(f"\n‚úì Optimizer: Adam")
print(f"   - Learning rate: {CONFIG['learning_rate']}")
print(f"   - Weight decay: {CONFIG['weight_decay']}")
print(f"\n‚úì Scheduler: ReduceLROnPlateau")
print(f"   - Factor: 0.5")
print(f"   - Patience: 5 epochs")
print(f"\n‚úì Mixed Precision: {CONFIG['mixed_precision']}")
print("\n" + "=" * 70)

## 11. Training and Validation Functions

Define training and validation loops with memory monitoring.

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None, epoch=0,
                    memory_check_interval=None):
    """
    Train ``model`` for one pass over ``loader``.

    Args:
        model: network to train (switched to train mode here).
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer that steps ``model``'s parameters.
        device: device the batches are moved to.
        scaler: optional GradScaler; when given, the forward pass runs under
            CUDA autocast and backward/step go through loss scaling.
        epoch: zero-based epoch index, used only for the progress-bar label.
        memory_check_interval: when training on CUDA, print GPU memory every
            this many batches; defaults to ``CONFIG['memory_check_interval']``.

    Returns:
        ``(mean batch loss, accuracy in percent)`` for the epoch. For an empty
        loader the loss is NaN and the accuracy 0.0.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    # Derive CUDA-ness from the device argument instead of a notebook global
    on_cuda = torch.device(device).type == 'cuda'
    if on_cuda and memory_check_interval is None:
        memory_check_interval = CONFIG['memory_check_interval']

    progress_bar = tqdm(loader, desc=f"Epoch {epoch+1} [Train]")

    for batch_idx, (images, labels) in enumerate(progress_bar):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        if scaler is not None:
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast
            with torch.autocast(device_type='cuda'):
                outputs = model(images)
                loss = criterion(outputs, labels)

            # Scaled backward pass; step() skips updates with inf/NaN grads
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Statistics
        running_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar.set_postfix({
            'loss': running_loss / num_batches,
            'acc': 100. * correct / total
        })

        if on_cuda:
            # Report memory every `memory_check_interval` batches (the old code
            # only ever reported batch 0, ignoring the configured interval).
            if memory_check_interval and batch_idx % memory_check_interval == 0:
                print_gpu_memory("\n  ")
            # Release cached allocator blocks periodically to cap reserved VRAM
            if batch_idx % 10 == 0:
                torch.cuda.empty_cache()

    # Guard against an empty loader (e.g. drop_last=True with < batch_size samples)
    epoch_loss = running_loss / num_batches if num_batches else float('nan')
    epoch_acc = 100. * correct / total if total else 0.0

    return epoch_loss, epoch_acc


def validate(model, loader, criterion, device, desc="Val"):
    """
    Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode here).
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        device: device the batches are moved to.
        desc: label shown on the progress bar.

    Returns:
        ``(mean batch loss, accuracy in percent, predictions, labels)`` where
        the last two are NumPy arrays over all samples. For an empty loader the
        loss is NaN and the accuracy 0.0.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        progress_bar = tqdm(loader, desc=f"[{desc}]")

        for images, labels in progress_bar:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            num_batches += 1
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Store for metrics
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            progress_bar.set_postfix({
                'loss': running_loss / num_batches,
                'acc': 100. * correct / total
            })

    epoch_loss = running_loss / num_batches if num_batches else float('nan')
    epoch_acc = 100. * correct / total if total else 0.0

    return epoch_loss, epoch_acc, np.array(all_preds), np.array(all_labels)


print("‚úÖ Training and validation functions defined")
print("\nüìù Features:")
print("   - Mixed precision support (FP16)")
print("   - GPU memory monitoring")
print("   - Progress bars (tqdm)")
print("   - Automatic cache clearing")
print("   - Returns predictions for metrics")

## 12. Training Loop

Train the model with early stopping.

In [None]:
print("=" * 70)
print("STARTING TRAINING")
print("=" * 70)

# Training history (one entry per completed epoch)
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

# Early stopping state
best_val_loss = float('inf')
patience_counter = 0
best_model_path = MODELS_DIR / f"resnet50_best_seed{CONFIG['seed']}.pth"
best_model_saved = False   # stays False if we stop before the first checkpoint
run_status = "completed"   # overwritten on interrupt / failure

# Start MLflow run
if MLFLOW_AVAILABLE:
    run_name = f"resnet50-baseline-seed-{CONFIG['seed']}"
    if CONFIG['test_on_subset']:
        run_name += "-SUBSET"
    
    # A previous execution of this cell that crashed can leave a run open, and
    # start_run() raises while another run is active - close it first.
    if mlflow.active_run() is not None:
        mlflow.end_run(status="KILLED")
    mlflow.start_run(run_name=run_name)
    
    # Log parameters
    mlflow.log_param("model", CONFIG['model_name'])
    mlflow.log_param("random_seed", CONFIG['seed'])
    mlflow.log_param("batch_size", CONFIG['batch_size'])
    mlflow.log_param("learning_rate", CONFIG['learning_rate'])
    mlflow.log_param("weight_decay", CONFIG['weight_decay'])
    mlflow.log_param("max_epochs", CONFIG['max_epochs'])
    mlflow.log_param("pretrained", CONFIG['pretrained'])
    mlflow.log_param("mixed_precision", CONFIG['mixed_precision'])
    mlflow.log_param("test_on_subset", CONFIG['test_on_subset'])
    if CONFIG['test_on_subset']:
        mlflow.log_param("subset_size", CONFIG['subset_size'])
    mlflow.log_param("image_size", CONFIG['image_size'])
    mlflow.log_param("num_classes", CONFIG['num_classes'])
    mlflow.set_tag("phase", "Phase 1 - Baseline Test")
    mlflow.set_tag("status", "training")

print(f"\nüìä Training Configuration:")
print(f"   - Model: {CONFIG['model_name']}")
print(f"   - Max Epochs: {CONFIG['max_epochs']}")
print(f"   - Early Stopping Patience: {CONFIG['early_stopping_patience']}")
print(f"   - Device: {device}")
print(f"\nüöÄ Starting training...\n")

start_time = time.time()

try:
    for epoch in range(CONFIG['max_epochs']):
        print(f"\n{'='*70}")
        print(f"Epoch {epoch+1}/{CONFIG['max_epochs']}")
        print(f"{'='*70}")
        
        # Train
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, scaler, epoch
        )
        
        # Validate
        val_loss, val_acc, _, _ = validate(
            model, val_loader, criterion, device, desc="Val"
        )
        
        # ReduceLROnPlateau tracks the validation loss
        scheduler.step(val_loss)
        
        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        # Log to MLflow
        if MLFLOW_AVAILABLE:
            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("train_acc", train_acc, step=epoch)
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            mlflow.log_metric("val_acc", val_acc, step=epoch)
            mlflow.log_metric("learning_rate", optimizer.param_groups[0]['lr'], step=epoch)
        
        # Print epoch summary
        print(f"\nüìä Epoch {epoch+1} Summary:")
        print(f"   Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"   Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
        print(f"   LR: {optimizer.param_groups[0]['lr']:.2e}")
        
        # Early stopping check (on validation loss)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            
            # Save best model
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_acc': val_acc,
                'config': CONFIG
            }, best_model_path)
            best_model_saved = True
            
            print(f"   ‚úÖ New best model saved! (Val Loss: {val_loss:.4f})")
        else:
            patience_counter += 1
            print(f"   ‚è≥ Patience: {patience_counter}/{CONFIG['early_stopping_patience']}")
            
            if patience_counter >= CONFIG['early_stopping_patience']:
                print(f"\n‚èπÔ∏è  Early stopping triggered!")
                break

except KeyboardInterrupt:
    run_status = "interrupted"
    print("\n‚ö†Ô∏è  Training interrupted by user")
except Exception as e:
    # Deliberately best-effort: report and keep the notebook usable, but record
    # the failure instead of tagging the run as completed.
    run_status = "failed"
    print(f"\n‚ùå ERROR during training: {e}")
    import traceback
    traceback.print_exc()

training_time = time.time() - start_time

print(f"\n{'='*70}")
print(f"TRAINING COMPLETED")
print(f"{'='*70}")
print(f"\n‚è±Ô∏è  Total Training Time: {training_time/60:.2f} minutes")
if best_model_saved:
    print(f"‚úÖ Best model saved to: {best_model_path}")
else:
    print(f"WARNING: No checkpoint was saved during this run (expected at {best_model_path})")
print(f"\nüìä Best Validation Loss: {best_val_loss:.4f}")

if MLFLOW_AVAILABLE:
    mlflow.log_metric("training_time_minutes", training_time/60)
    # best_val_loss is still +inf if no epoch finished; don't log a bogus value
    if best_model_saved:
        mlflow.log_metric("best_val_loss", best_val_loss)
    mlflow.set_tag("status", run_status)

## 13. Plot Training History

Visualize training and validation curves.

In [None]:
# One row of two panels: loss on the left, accuracy on the right.
history_plot_path = RESULTS_DIR / 'resnet50_training_history.png'
epochs_range = range(1, len(history['train_loss']) + 1)

# (history key suffix, y-axis label, title, train legend, val legend)
history_panels = [
    ('loss', 'Loss', 'Training and Validation Loss', 'Train Loss', 'Val Loss'),
    ('acc', 'Accuracy (%)', 'Training and Validation Accuracy', 'Train Acc', 'Val Acc'),
]

fig, axes = plt.subplots(1, 2, figsize=(15, 5))
for panel_ax, (metric_key, y_label, panel_title, train_label, val_label) in zip(axes, history_panels):
    panel_ax.plot(epochs_range, history[f'train_{metric_key}'], 'b-', label=train_label, linewidth=2)
    panel_ax.plot(epochs_range, history[f'val_{metric_key}'], 'r-', label=val_label, linewidth=2)
    panel_ax.set_xlabel('Epoch', fontweight='bold')
    panel_ax.set_ylabel(y_label, fontweight='bold')
    panel_ax.set_title(panel_title, fontweight='bold', fontsize=14)
    panel_ax.legend()
    panel_ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(history_plot_path, dpi=300, bbox_inches='tight')
plt.show()

print("‚úÖ Training history plot saved")

# Attach the saved figure to the active MLflow run
if MLFLOW_AVAILABLE:
    mlflow.log_artifact(str(history_plot_path))

## 14. Evaluate on Test Set

Load best model and evaluate on test set.

In [None]:
print("=" * 70)
print("EVALUATING ON TEST SET")
print("=" * 70)

# Fixed label set so every report/matrix has one row/column per class, even
# when a (sub-sampled) test set lacks a class or the model never predicts it;
# otherwise target_names length mismatches and per-class indexing break.
class_labels = list(range(CONFIG['num_classes']))

if not best_model_path.exists():
    raise FileNotFoundError(
        f"No checkpoint found at {best_model_path}; run the training cell first."
    )

# map_location lets a checkpoint saved on GPU be restored on the active device
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"\n‚úÖ Loaded best model from epoch {checkpoint['epoch']+1}")

# Evaluate
test_loss, test_acc, test_preds, test_labels = validate(
    model, test_loader, criterion, device, desc="Test"
)

print(f"\nüìä Test Set Results:")
print(f"   - Test Loss: {test_loss:.4f}")
print(f"   - Test Accuracy: {test_acc:.2f}%")

# Detailed classification report
print(f"\nüìã Classification Report:\n")
print(classification_report(
    test_labels, 
    test_preds, 
    labels=class_labels,
    target_names=CONFIG['class_names'],
    digits=4,
    zero_division=0
))

# Confusion matrix (always num_classes x num_classes)
cm = confusion_matrix(test_labels, test_preds, labels=class_labels)
print(f"\nüìä Confusion Matrix:\n")
print(cm)

# Log to MLflow
if MLFLOW_AVAILABLE:
    mlflow.log_metric("test_loss", test_loss)
    mlflow.log_metric("test_accuracy", test_acc)
    
    # Per-class metrics, aligned with CONFIG['class_names'] via class_labels
    precision, recall, f1, support = precision_recall_fscore_support(
        test_labels, test_preds, labels=class_labels, average=None, zero_division=0
    )
    
    for i, class_name in enumerate(CONFIG['class_names']):
        mlflow.log_metric(f"test_precision_{class_name}", precision[i])
        mlflow.log_metric(f"test_recall_{class_name}", recall[i])
        mlflow.log_metric(f"test_f1_{class_name}", f1[i])

print("\n" + "=" * 70)

## 15. Visualize Confusion Matrix

Create publication-quality confusion matrix.

In [None]:
# Annotated heatmap of the test-set confusion matrix (rows = true, cols = predicted)
cm_plot_path = RESULTS_DIR / 'resnet50_confusion_matrix.png'
tick_names = CONFIG['class_names']
axis_label_style = {'fontweight': 'bold', 'fontsize': 12}

plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=tick_names,
    yticklabels=tick_names,
    cbar_kws={'label': 'Count'},
)
plt.ylabel('True Label', **axis_label_style)
plt.xlabel('Predicted Label', **axis_label_style)
plt.title(f"Confusion Matrix - {CONFIG['model_name']} (Test Set)", fontweight='bold', fontsize=14, pad=20)
plt.tight_layout()
plt.savefig(cm_plot_path, dpi=300, bbox_inches='tight')
plt.show()

print("‚úÖ Confusion matrix saved")

# Attach the saved figure to the active MLflow run
if MLFLOW_AVAILABLE:
    mlflow.log_artifact(str(cm_plot_path))

## 16. Summary Report

Final summary of baseline test results.

In [None]:
print("\n" + "=" * 70)
print("BASELINE MODEL TEST - SUMMARY REPORT")
print("=" * 70)

print("\n‚úÖ COMPLETED TASKS:")
print("   [‚úì] Created PyTorch Dataset and DataLoader")
print("   [‚úì] Loaded ResNet-50 baseline model")
print("   [‚úì] Trained on preprocessed CLAHE-enhanced images")
print("   [‚úì] Applied class weights for imbalance")
# Report the precision actually used (scaler is None when AMP was not enabled)
if scaler is not None:
    print("   [‚úì] Used mixed precision training (FP16)")
else:
    print("   [‚úì] Used standard precision training (FP32)")
print("   [‚úì] Monitored GPU memory usage")
print("   [‚úì] Saved best model checkpoint")
if MLFLOW_AVAILABLE:
    print("   [‚úì] Logged results to MLflow")

print("\nüìä FINAL RESULTS:")
print(f"   - Model: {CONFIG['model_name']}")
print(f"   - Test Accuracy: {test_acc:.2f}%")
print(f"   - Test Loss: {test_loss:.4f}")
print(f"   - Training Time: {training_time/60:.2f} minutes")
print(f"   - Best Val Loss: {best_val_loss:.4f}")

if CONFIG['test_on_subset']:
    print(f"\n‚ö†Ô∏è  SUBSET MODE:")
    print(f"   - Trained on {len(train_dataset_use):,} images (subset)")
    print(f"   - To train on FULL dataset: Set CONFIG['test_on_subset'] = False")
else:
    print(f"\n‚úÖ FULL DATASET MODE:")
    print(f"   - Trained on {len(train_dataset_use):,} images (full dataset)")

print("\nüéØ PHASE 1 STATUS:")
if test_acc >= 70.0:
    print(f"   ‚úÖ SUCCESS: Achieved {test_acc:.2f}% accuracy (>70% target)")
    print(f"   ‚úÖ Training pipeline verified and working!")
    print(f"   ‚úÖ Ready to move to Phase 2 (Systematic Experimentation)")
else:
    print(f"   ‚ö†Ô∏è  WARNING: Achieved {test_acc:.2f}% accuracy (<70% target)")
    print(f"   üí° Suggestions:")
    print(f"      - Check if using subset mode (set test_on_subset=False)")
    print(f"      - Train for more epochs")
    print(f"      - Verify data preprocessing")

print("\nüìÅ OUTPUT FILES:")
print(f"   - Best model: {best_model_path}")
print(f"   - Training history: {RESULTS_DIR / 'resnet50_training_history.png'}")
print(f"   - Confusion matrix: {RESULTS_DIR / 'resnet50_confusion_matrix.png'}")

if MLFLOW_AVAILABLE:
    # Rebuild the run name exactly as the training cell created it, including
    # the "-SUBSET" suffix, so it can be found in the MLflow UI.
    summary_run_name = f"resnet50-baseline-seed-{CONFIG['seed']}"
    if CONFIG['test_on_subset']:
        summary_run_name += "-SUBSET"
    print(f"\nüìä MLFLOW:")
    print(f"   - Experiment: crossvit-covid19-classification")
    print(f"   - Run name: {summary_run_name}")
    print(f"   - View results: mlflow ui ‚Üí http://localhost:5000")

print("\nüéØ NEXT STEPS:")
if test_acc >= 70.0:
    print("   1. Optional: Create 05_augmentation_test.ipynb to test augmentation strategies")
    print("   2. Or skip to Phase 2: Start systematic experiments (notebooks 06-11)")
    print("   3. Train all 6 models with 5 seeds each (30 total runs)")
    print("   4. Use MLflow to track all experiments")
else:
    print("   1. Re-run with CONFIG['test_on_subset'] = False (if using subset)")
    print("   2. Increase max_epochs to 50")
    print("   3. Debug any preprocessing issues")
    print("   4. Achieve >70% before moving to Phase 2")

print("\n‚úÖ Baseline model test complete! Phase 1 finished.")
print("=" * 70 + "\n")

# End MLflow run (closes the run opened by the training cell)
if MLFLOW_AVAILABLE:
    mlflow.end_run()
    print("‚úÖ MLflow run ended successfully")