# 06 - CrossViT Training (Phase 2)

**Author:** Tan Ming Kai (24PMR12003)  
**Date:** 2025-11-12  
**Purpose:** Train CrossViT-Tiny with 5 random seeds for statistical validation

**Project:** Multi-Scale Vision Transformer (CrossViT) for COVID-19 Chest X-ray Classification  
**Academic Year:** 2025/26

---

## Objectives
1. ✅ Train CrossViT-Tiny with 5 different random seeds (42, 123, 456, 789, 101112)
2. ✅ Log all runs to MLflow for experiment tracking
3. ✅ Save model checkpoints and confusion matrices
4. ✅ Calculate mean ± std accuracy across seeds
5. ✅ Generate results table for thesis Chapter 5

---

## Phase 2: Systematic Experimentation

This notebook is part of Phase 2, where we train ALL 6 models with 5 seeds each (30 total runs).

## 1. Reproducibility Setup & Imports

In [None]:
"""
CrossViT Training Notebook for COVID-19 FYP
Author: Tan Ming Kai (24PMR12003)
Purpose: Train CrossViT-Tiny with multiple random seeds for statistical validation
"""

# ============================================================================
# 1. STANDARD LIBRARY IMPORTS
# ============================================================================
import os
import sys
from pathlib import Path
import warnings
import time
from datetime import datetime
import random
warnings.filterwarnings('ignore')  # NOTE: silences ALL Python warnings for the session, deprecation warnings included

# ============================================================================
# 2. DATA SCIENCE LIBRARIES
# ============================================================================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Configure display
pd.set_option('display.max_columns', None)  # print every DataFrame column (no "..." truncation)
plt.style.use('seaborn-v0_8-darkgrid')  # global matplotlib style for all figures below
sns.set_palette('husl')  # default seaborn colour palette; applied after the style, order matters

# ============================================================================
# 3. PYTORCH & DEEP LEARNING
# ============================================================================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# ============================================================================
# 4. TIMM (PyTorch Image Models) for CrossViT
# ============================================================================
import timm

# ============================================================================
# 5. COMPUTER VISION
# ============================================================================
import cv2
from PIL import Image

# ============================================================================
# 6. MLFLOW (Experiment Tracking)
# ============================================================================
# MLflow is optional: every later MLflow call is guarded by MLFLOW_AVAILABLE,
# so the notebook still trains (just without tracking) when it is missing.
try:
    import mlflow
    import mlflow.pytorch
    MLFLOW_AVAILABLE = True
    print("✅ MLflow available for experiment tracking")
except ImportError:
    MLFLOW_AVAILABLE = False
    print("⚠️  MLflow not installed. Install with: pip install mlflow")
    print("   Continuing without experiment tracking...")

# ============================================================================
# 7. SKLEARN (Metrics)
# ============================================================================
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, classification_report
)

# Record library versions in the notebook output for reproducibility.
print("\n✅ All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"Timm version: {timm.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")

## 2. Hardware Verification

In [None]:
print("=" * 70)
print("HARDWARE VERIFICATION")
print("=" * 70)

# Use the first CUDA device when available, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')

print(f"\n✓ CUDA Available: {cuda_available}")
print(f"✓ Using Device: {device}")

if cuda_available:
    gpu_name = torch.cuda.get_device_name(0)
    # Bytes -> decimal GB, the unit used by every figure printed below.
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9

    print(f"✓ GPU: {gpu_name}")
    print(f"✓ Total VRAM: {total_memory:.2f} GB")
    print(f"✓ CUDA Version: {torch.version.cuda}")

    def print_gpu_memory(prefix=""):
        """Print allocated / reserved / free memory (GB) for GPU 0.

        "Free" is total VRAM minus what PyTorch's caching allocator has
        reserved, so memory used by other processes is not accounted for.
        """
        allocated = torch.cuda.memory_allocated(0) / 1e9
        reserved = torch.cuda.memory_reserved(0) / 1e9
        free = total_memory - reserved
        print(f"{prefix}GPU Memory: Allocated={allocated:.3f}GB | Reserved={reserved:.3f}GB | Free={free:.3f}GB")

    print_gpu_memory("\n  ")

    # The batch size in the configuration was tuned for an 8 GB RTX 4060.
    if "4060" in gpu_name and 7.0 <= total_memory <= 9.0:
        print("\n✅ CONFIRMED: RTX 4060 8GB detected - Ready for CrossViT training!")
    else:
        print(f"\n⚠️  Different GPU detected: {gpu_name}")
        print("   Adjust batch size if needed based on VRAM.")
else:
    print("\n❌ WARNING: No GPU detected! Training will be VERY slow.")
    print("   Please ensure CUDA drivers and PyTorch with CUDA are installed.")

print("\n" + "=" * 70)

## 3. Configuration

**CRITICAL:** These hyperparameters are FIXED per CLAUDE.md specifications.

In [None]:
# Paths (relative to the current working directory)
CSV_DIR = Path("../data/processed")
PROCESSED_IMG_DIR = Path("../data/processed/clahe_enhanced")
MODELS_DIR = Path("../models")
RESULTS_DIR = Path("../results")

# Create output directories up front so checkpoint/plot saves cannot fail later.
MODELS_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Training configuration (FIXED per CLAUDE.md)
BASE_CONFIG = {
    # Device (chosen in the hardware-verification cell)
    'device': device,

    # Model
    'model_name': 'CrossViT-Tiny',
    'timm_model': 'crossvit_tiny_240',
    'num_classes': 4,
    'pretrained': True,

    # Data
    'image_size': 240,
    # NOTE(review): order must match the integer 'label' encoding in the CSVs - confirm.
    'class_names': ['COVID', 'Normal', 'Lung_Opacity', 'Viral Pneumonia'],
    'class_weights': [1.47, 0.52, 0.88, 3.95],  # From EDA; same order as class_names

    # Training hyperparameters - FIXED per CLAUDE.md
    'batch_size': 8,  # Reduced for CrossViT (safer on 8GB VRAM)
    'gradient_accumulation_steps': 4,  # Effective batch size = 32
    'num_workers': 0,  # Must be 0 on Windows
    'pin_memory': False,  # Disabled on Windows
    'persistent_workers': False,

    # Optimizer - FIXED per CLAUDE.md
    'learning_rate': 5e-5,  # CrossViT-specific
    'weight_decay': 0.05,  # CrossViT-specific
    'max_epochs': 50,
    'early_stopping_patience': 15,

    # ImageNet normalization (matches the pretrained backbone)
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],

    # Memory management
    'mixed_precision': True,

    # Multi-seed experiment
    'seeds': [42, 123, 456, 789, 101112],  # 5 seeds for statistical validation
}

print("=" * 70)
print("CROSSVIT TRAINING CONFIGURATION (FIXED HYPERPARAMETERS)")
print("=" * 70)
print(f"\n✓ Model: {BASE_CONFIG['model_name']}")
print(f"✓ Timm Model: {BASE_CONFIG['timm_model']}")
print(f"✓ Device: {BASE_CONFIG['device']}")
print(f"✓ Batch Size: {BASE_CONFIG['batch_size']} (gradient accumulation: {BASE_CONFIG['gradient_accumulation_steps']})")
print(f"✓ Effective Batch Size: {BASE_CONFIG['batch_size'] * BASE_CONFIG['gradient_accumulation_steps']}")
print(f"✓ Learning Rate: {BASE_CONFIG['learning_rate']}")
print(f"✓ Weight Decay: {BASE_CONFIG['weight_decay']}")
print(f"✓ Max Epochs: {BASE_CONFIG['max_epochs']}")
print(f"✓ Early Stopping Patience: {BASE_CONFIG['early_stopping_patience']}")
print(f"✓ Image Size: {BASE_CONFIG['image_size']}×{BASE_CONFIG['image_size']}")
print(f"✓ Mixed Precision: {BASE_CONFIG['mixed_precision']}")
print(f"\n✓ Random Seeds: {BASE_CONFIG['seeds']}")
print(f"  → Will train {len(BASE_CONFIG['seeds'])} times for statistical validation")
print("\n⚠️  IMPORTANT: These hyperparameters are FIXED per CLAUDE.md")
print("   Do not modify unless explicitly required.")
print("\n" + "=" * 70)

## 4. MLflow Setup

In [None]:
print("=" * 70)
print("MLFLOW EXPERIMENT TRACKING SETUP")
print("=" * 70)

# Single source of truth for the experiment name used below.
MLFLOW_EXPERIMENT_NAME = "crossvit-covid19-classification"

if MLFLOW_AVAILABLE:
    # The tracking URI must be set FIRST: set_experiment() creates/looks up the
    # experiment in whichever tracking store is active when it is called, so the
    # reverse order would register the experiment in a different store.
    mlflow.set_tracking_uri("file:./mlruns")
    mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

    print("\n✅ MLflow configured:")
    print(f"   - Experiment: {MLFLOW_EXPERIMENT_NAME}")
    print(f"   - Tracking URI: {mlflow.get_tracking_uri()}")
    print("\n💡 View results: Run 'mlflow ui' in terminal, then open http://localhost:5000")
else:
    print("\n⚠️  MLflow not available. Results will not be logged.")
    print("   Install with: pip install mlflow")

print("\n" + "=" * 70)

## 5. Load Data Splits

In [None]:
print("=" * 70)
print("LOADING DATA SPLITS")
print("=" * 70)

# Load processed CSV files
train_df = pd.read_csv(CSV_DIR / "train_processed.csv")
val_df = pd.read_csv(CSV_DIR / "val_processed.csv")
test_df = pd.read_csv(CSV_DIR / "test_processed.csv")

# Fail fast here, rather than deep inside training, if a split lacks a column
# that COVID19Dataset ('processed_path', 'label') or the summary below needs.
for _split_name, _split_df, _required in (
    ('train', train_df, ('processed_path', 'label', 'class_name')),
    ('val', val_df, ('processed_path', 'label')),
    ('test', test_df, ('processed_path', 'label')),
):
    _missing = [col for col in _required if col not in _split_df.columns]
    if _missing:
        raise KeyError(f"{_split_name}_processed.csv is missing column(s): {_missing}")

print("\n✅ CSV files loaded:")
print(f"   - Train: {len(train_df):,} images")
print(f"   - Val:   {len(val_df):,} images")
print(f"   - Test:  {len(test_df):,} images")

print("\n📊 Class Distribution in Training Set:")
class_counts = train_df['class_name'].value_counts()
for class_name, count in class_counts.items():
    pct = count / len(train_df) * 100
    print(f"   {class_name:20s}: {count:5d} ({pct:5.2f}%)")

print("\n" + "=" * 70)

## 6. Create PyTorch Dataset

In [None]:
class COVID19Dataset(Dataset):
    """
    PyTorch Dataset for COVID-19 chest X-ray classification.

    Reads each image from the path in the dataframe's 'processed_path' column,
    converts it from OpenCV's BGR order to an RGB PIL image, applies the
    optional transform and returns ``(image, label)`` with ``label`` as a
    ``torch.long`` scalar tensor. The preprocessing step is expected to have
    produced CLAHE-enhanced 240×240 images (assumption; not checked here).
    """

    def __init__(self, dataframe, transform=None):
        """
        Args:
            dataframe (pd.DataFrame): DataFrame with 'processed_path' and 'label' columns.
                Its index is reset, so positional indices 0..len-1 are used.
            transform (callable, optional): Transformations to apply to each PIL image.
        """
        self.dataframe = dataframe.reset_index(drop=True)
        self.transform = transform

        # Cache as arrays so __getitem__ avoids per-item pandas lookups.
        self.image_paths = self.dataframe['processed_path'].values
        self.labels = self.dataframe['label'].values

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """
        Load and return ``(image, label)`` at position ``idx``.

        Raises:
            FileNotFoundError: if the image file is missing or cannot be decoded.
        """
        img_path = self.image_paths[idx]
        # cv2.imread expects a str path. It returns None instead of raising when
        # the file is missing OR undecodable, so the message covers both cases.
        image = cv2.imread(str(img_path))

        if image is None:
            raise FileNotFoundError(f"Image not found or unreadable: {img_path}")

        # OpenCV loads BGR; torchvision/ImageNet normalization expects RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Convert to PIL Image for torchvision transforms
        image = Image.fromarray(image)

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return image, label


print("✅ COVID19Dataset class defined")

## 7. Define Data Transforms

Using **conservative augmentation** as per CLAUDE.md.

In [None]:
# Training transforms (conservative augmentation per CLAUDE.md)
train_transform = transforms.Compose([
    transforms.Resize((BASE_CONFIG['image_size'], BASE_CONFIG['image_size'])),
    transforms.RandomRotation(10),  # ±10° only
    transforms.RandomHorizontalFlip(0.5),  # NO vertical flip
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=BASE_CONFIG['mean'], std=BASE_CONFIG['std'])
])

# Validation/Test transforms: deterministic resize + normalization only
val_transform = transforms.Compose([
    transforms.Resize((BASE_CONFIG['image_size'], BASE_CONFIG['image_size'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=BASE_CONFIG['mean'], std=BASE_CONFIG['std'])
])

print("✅ Data transforms defined (Conservative augmentation per CLAUDE.md)")

## 8. Training and Validation Functions

In [None]:
def set_seed(seed):
    """
    Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG. On a
    CUDA machine it additionally seeds all GPUs, forces cuDNN into
    deterministic mode and turns off its autotuner, trading speed for
    run-to-run repeatability.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None,
                    gradient_accumulation_steps=1, epoch=0):
    """
    Train ``model`` for one epoch with gradient accumulation.

    Each optimizer step covers ``gradient_accumulation_steps`` consecutive
    batches. If the number of batches is not a multiple of that value, the
    trailing partial group is still applied, averaged over its own size, instead
    of being silently discarded by the next epoch's ``zero_grad``.

    Args:
        model: network to train (switched to train mode).
        loader: DataLoader yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer to step.
        device: device the batches are moved to.
        scaler: optional ``torch.cuda.amp.GradScaler``; when given, the forward
            pass runs under CUDA autocast and backward/step go through the scaler.
        gradient_accumulation_steps (int): batches per optimizer step (>= 1).
        epoch (int): zero-based epoch index, used only for the progress-bar label.

    Returns:
        tuple: ``(mean per-batch loss, accuracy in %)``; both NaN if ``loader`` is empty.

    Raises:
        ValueError: if ``gradient_accumulation_steps`` < 1.
    """
    if gradient_accumulation_steps < 1:
        raise ValueError(
            f"gradient_accumulation_steps must be >= 1, got {gradient_accumulation_steps}"
        )

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = len(loader)

    progress_bar = tqdm(loader, desc=f"Epoch {epoch+1} [Train]")

    optimizer.zero_grad(set_to_none=True)

    for batch_idx, (images, labels) in enumerate(progress_bar):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Size of the accumulation group this batch belongs to. Every group is
        # full-sized except possibly the last one of the epoch.
        group_start = batch_idx - batch_idx % gradient_accumulation_steps
        group_size = min(gradient_accumulation_steps, num_batches - group_start)
        # Step after the last batch of each group (incl. a trailing partial group).
        is_update_step = (batch_idx + 1) == group_start + group_size

        if scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(images)
                # Divide so the accumulated gradient is the mean over the group.
                loss = criterion(outputs, labels) / group_size

            scaler.scale(loss).backward()

            if is_update_step:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
        else:
            outputs = model(images)
            loss = criterion(outputs, labels) / group_size
            loss.backward()

            if is_update_step:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        # Undo the group scaling so statistics report the plain batch loss.
        running_loss += loss.item() * group_size
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar.set_postfix({
            'loss': running_loss / (batch_idx + 1),
            'acc': 100. * correct / total
        })

    epoch_loss = running_loss / num_batches if num_batches else float('nan')
    epoch_acc = 100. * correct / total if total else float('nan')

    return epoch_loss, epoch_acc


def validate(model, loader, criterion, device, desc="Val"):
    """
    Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode).
        loader: DataLoader yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: device the batches are moved to.
        desc (str): progress-bar label.

    Returns:
        tuple: ``(mean per-batch loss, accuracy in %, predictions, labels)`` where
        the last two are NumPy arrays in loader order. Loss and accuracy are NaN
        if ``loader`` is empty.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        progress_bar = tqdm(loader, desc=f"[{desc}]")

        # Count batches ourselves: tqdm only syncs ``progress_bar.n`` when it
        # refreshes the display, so it lags behind the real batch index.
        for batch_count, (images, labels) in enumerate(progress_bar, start=1):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Keep raw predictions for confusion-matrix / metric computation.
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            progress_bar.set_postfix({
                'loss': running_loss / batch_count,
                'acc': 100. * correct / total
            })

    num_batches = len(loader)
    epoch_loss = running_loss / num_batches if num_batches else float('nan')
    epoch_acc = 100. * correct / total if total else float('nan')

    return epoch_loss, epoch_acc, np.array(all_preds), np.array(all_labels)


print("✅ Training and validation functions defined")

## 9. Single Seed Training Function

In [None]:
def _make_loader(dataset, base_config, shuffle, drop_last=False):
    """Wrap *dataset* in a DataLoader using the batch/worker settings from *base_config*."""
    return DataLoader(
        dataset,
        batch_size=base_config['batch_size'],
        shuffle=shuffle,
        num_workers=base_config['num_workers'],
        pin_memory=base_config['pin_memory'],
        drop_last=drop_last,
    )


def _save_confusion_matrix_plot(cm, class_names, seed):
    """Render *cm* as an annotated heatmap, save it under RESULTS_DIR and return the file path."""
    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(
            cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names
        )
        plt.ylabel('True Label', fontweight='bold')
        plt.xlabel('Predicted Label', fontweight='bold')
        plt.title(f"CrossViT Confusion Matrix (Seed {seed})", fontweight='bold')
        plt.tight_layout()
        cm_path = RESULTS_DIR / f"crossvit_cm_seed{seed}.png"
        plt.savefig(cm_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure so repeated seeds do not leak memory.
        plt.close(fig)
    return cm_path


def train_crossvit_single_seed(seed, base_config):
    """
    Train, early-stop and test CrossViT for a single random seed.

    Uses the notebook-level ``train_df``/``val_df``/``test_df`` splits and
    ``train_transform``/``val_transform``. The weights with the lowest validation
    loss are saved to ``MODELS_DIR/crossvit_best_seed{seed}.pth``. They are then
    reloaded and evaluated on the test split, and the test confusion matrix is
    saved to ``RESULTS_DIR/crossvit_cm_seed{seed}.png``. When MLflow is available
    the run is logged as ``crossvit-seed-{seed}``. The run is always closed, and
    is marked FAILED if an exception escapes, so the next seed can start its own run.

    Args:
        seed (int): random seed passed to ``set_seed``.
        base_config (dict): hyperparameters (see ``BASE_CONFIG``).

    Returns:
        dict: Results containing ``seed``, ``test_acc`` (%), ``test_loss``,
        ``confusion_matrix``, ``training_time`` (seconds) and per-epoch ``history``.
    """
    print("\n" + "=" * 70)
    print(f"TRAINING CROSSVIT WITH SEED {seed}")
    print("=" * 70)

    set_seed(seed)
    print(f"\n✅ Random seed set to {seed}")

    # Honour the device from base_config rather than the notebook-level global.
    run_device = base_config['device']

    train_loader = _make_loader(
        COVID19Dataset(train_df, transform=train_transform), base_config,
        shuffle=True, drop_last=True
    )
    val_loader = _make_loader(
        COVID19Dataset(val_df, transform=val_transform), base_config, shuffle=False
    )
    test_loader = _make_loader(
        COVID19Dataset(test_df, transform=val_transform), base_config, shuffle=False
    )

    print(f"✅ DataLoaders created: {len(train_loader)} train batches")

    # Pretrained backbone with a fresh num_classes-way classification head.
    model = timm.create_model(
        base_config['timm_model'],
        pretrained=base_config['pretrained'],
        num_classes=base_config['num_classes']
    )
    model = model.to(run_device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"✅ CrossViT loaded: {total_params:,} parameters")

    # Class-weighted loss to counter the class imbalance found in EDA.
    class_weights = torch.tensor(base_config['class_weights'], dtype=torch.float32).to(run_device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    optimizer = optim.AdamW(
        model.parameters(),
        lr=base_config['learning_rate'],
        weight_decay=base_config['weight_decay']
    )

    # Stepped once per epoch: restarts after 10, then 20 more, then 40 more epochs.
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=10,
        T_mult=2
    )

    # CUDA AMP is only meaningful on a GPU; on CPU use the plain fp32 path.
    use_amp = base_config['mixed_precision'] and str(run_device).startswith('cuda')
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    if MLFLOW_AVAILABLE:
        mlflow.start_run(run_name=f"crossvit-seed-{seed}")
    run_status = "FAILED"  # flipped to FINISHED only after the test evaluation succeeds

    try:
        if MLFLOW_AVAILABLE:
            mlflow.log_params({
                "model": base_config['model_name'],
                "random_seed": seed,
                "batch_size": base_config['batch_size'],
                "gradient_accumulation_steps": base_config['gradient_accumulation_steps'],
                "effective_batch_size": base_config['batch_size'] * base_config['gradient_accumulation_steps'],
                "learning_rate": base_config['learning_rate'],
                "weight_decay": base_config['weight_decay'],
                "max_epochs": base_config['max_epochs'],
            })
            mlflow.set_tag("phase", "Phase 2 - Systematic Experimentation")

        best_val_loss = float('inf')
        patience_counter = 0
        checkpoint_saved = False
        best_model_path = MODELS_DIR / f"crossvit_best_seed{seed}.pth"

        history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': []
        }

        start_time = time.time()

        for epoch in range(base_config['max_epochs']):
            train_loss, train_acc = train_one_epoch(
                model, train_loader, criterion, optimizer, run_device, scaler,
                base_config['gradient_accumulation_steps'], epoch
            )

            val_loss, val_acc, _, _ = validate(model, val_loader, criterion, run_device)

            scheduler.step()

            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

            if MLFLOW_AVAILABLE:
                mlflow.log_metric("train_loss", train_loss, step=epoch)
                mlflow.log_metric("train_acc", train_acc, step=epoch)
                mlflow.log_metric("val_loss", val_loss, step=epoch)
                mlflow.log_metric("val_acc", val_acc, step=epoch)

            print(f"\nEpoch {epoch+1}: Train Loss={train_loss:.4f} | Val Loss={val_loss:.4f} | Val Acc={val_acc:.2f}%")

            # Early stopping on validation loss; keep the best weights on disk.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                torch.save(model.state_dict(), best_model_path)
                checkpoint_saved = True
                print("✅ Best model saved!")
            else:
                patience_counter += 1
                if patience_counter >= base_config['early_stopping_patience']:
                    print(f"\n⏹️  Early stopping at epoch {epoch+1}")
                    break

        training_time = time.time() - start_time

        if not checkpoint_saved:
            # Validation loss never beat +inf (e.g. NaN every epoch, or zero epochs).
            # Save the current weights so the test below uses THIS run, not a
            # stale checkpoint left by an earlier run with the same seed.
            print("⚠️  No validation improvement recorded - evaluating final weights.")
            torch.save(model.state_dict(), best_model_path)

        # Reload the best weights onto the run device and evaluate on the test set.
        model.load_state_dict(torch.load(best_model_path, map_location=run_device))
        test_loss, test_acc, test_preds, test_labels = validate(
            model, test_loader, criterion, run_device, desc="Test"
        )

        cm = confusion_matrix(test_labels, test_preds)
        cm_path = _save_confusion_matrix_plot(cm, base_config['class_names'], seed)

        if MLFLOW_AVAILABLE:
            mlflow.log_metric("test_loss", test_loss)
            mlflow.log_metric("test_accuracy", test_acc)
            mlflow.log_metric("training_time_minutes", training_time / 60)
            mlflow.log_artifact(str(cm_path))

        run_status = "FINISHED"
    finally:
        # Always close the run; a dangling active run makes the next start_run() raise.
        if MLFLOW_AVAILABLE:
            mlflow.end_run(status=run_status)

    print(f"\n✅ Seed {seed} complete: Test Acc = {test_acc:.2f}%")

    return {
        'seed': seed,
        'test_acc': test_acc,
        'test_loss': test_loss,
        'confusion_matrix': cm,
        'training_time': training_time,
        'history': history
    }


print("✅ Single seed training function defined")

## 10. Train CrossViT with All Seeds

**This will train 5 times (seeds: 42, 123, 456, 789, 101112)**

In [None]:
import traceback

print("=" * 70)
print("STARTING MULTI-SEED CROSSVIT TRAINING")
print("=" * 70)
print(f"\n📊 Will train CrossViT {len(BASE_CONFIG['seeds'])} times with different seeds")
print(f"   Seeds: {BASE_CONFIG['seeds']}")
print("\n⏱️  Estimated time: ~2-3 hours per seed (~10-15 hours total)")
print("\n🚀 Starting training...\n")

# Train with all seeds; a failing seed is reported and skipped, not fatal.
all_results = []
failed_seeds = []

for seed in BASE_CONFIG['seeds']:
    try:
        result = train_crossvit_single_seed(seed, BASE_CONFIG)
    except Exception as e:
        print(f"\n❌ ERROR training with seed {seed}: {e}")
        traceback.print_exc()
        failed_seeds.append(seed)
        # A run left open by the failed seed would make the next start_run() raise.
        if MLFLOW_AVAILABLE and mlflow.active_run() is not None:
            mlflow.end_run(status="FAILED")
    else:
        all_results.append(result)

print("\n" + "=" * 70)
print("ALL SEEDS TRAINING COMPLETED")
print("=" * 70)
print(f"   Successful seeds: {len(all_results)}/{len(BASE_CONFIG['seeds'])}")
if failed_seeds:
    print(f"   ⚠️  Failed seeds: {failed_seeds}")

## 11. Statistical Analysis

Calculate mean ± std accuracy across all seeds.

In [None]:
print("\n" + "=" * 70)
print("STATISTICAL ANALYSIS")
print("=" * 70)

# Only seeds that finished successfully are present in all_results.
accuracies = [r['test_acc'] for r in all_results]
seeds = [r['seed'] for r in all_results]
n_runs = len(accuracies)

if n_runs == 0:
    # np.min/np.max raise on empty input; report NaN statistics instead.
    mean_acc = std_acc = min_acc = max_acc = float('nan')
    print("\n❌ No successful runs - nothing to summarise.")
else:
    mean_acc = np.mean(accuracies)
    # Sample std (ddof=1) is undefined for a single run; report NaN explicitly.
    std_acc = np.std(accuracies, ddof=1) if n_runs > 1 else float('nan')
    min_acc = np.min(accuracies)
    max_acc = np.max(accuracies)

    print(f"\n📊 CrossViT Test Accuracy ({n_runs} seeds):")
    print(f"   Mean ± Std: {mean_acc:.2f}% ± {std_acc:.2f}%")
    print(f"   Range: [{min_acc:.2f}%, {max_acc:.2f}%]")
    print("\n📋 Individual Results:")
    for seed, acc in zip(seeds, accuracies):
        print(f"   Seed {seed:6d}: {acc:.2f}%")

# One row per successful seed (an empty table if none succeeded).
results_df = pd.DataFrame({
    'Model': ['CrossViT-Tiny'] * n_runs,
    'Seed': seeds,
    'Test Accuracy (%)': accuracies,
    'Test Loss': [r['test_loss'] for r in all_results],
    'Training Time (min)': [r['training_time'] / 60 for r in all_results]
})

results_path = RESULTS_DIR / "crossvit_results.csv"
results_df.to_csv(results_path, index=False)
print(f"\n✅ Results saved to: {results_path}")

print(f"\n{results_df.to_string(index=False)}")

print("\n" + "=" * 70)

## 12. Summary Report

In [None]:
print("\n" + "=" * 70)
print("CROSSVIT TRAINING - SUMMARY REPORT")
print("=" * 70)

# Report what actually happened, not what was planned.
n_completed = len(all_results)
n_planned = len(BASE_CONFIG['seeds'])

print("\n✅ COMPLETED:")
print(f"   [✓] Trained CrossViT-Tiny with {n_completed}/{n_planned} random seeds")
if MLFLOW_AVAILABLE:
    print("   [✓] Logged runs to MLflow")
else:
    print("   [ ] MLflow logging skipped (mlflow not installed)")
print(f"   [✓] Saved {n_completed} model checkpoints")
print(f"   [✓] Generated {n_completed} confusion matrices")
print("   [✓] Calculated statistics (mean ± std)")

print("\n📊 FINAL STATISTICS:")
print(f"   CrossViT-Tiny: {mean_acc:.2f}% ± {std_acc:.2f}%")

print("\n📁 OUTPUT FILES:")
print(f"   - Results CSV: {results_path}")
print(f"   - Model checkpoints: {MODELS_DIR}/crossvit_best_seed*.pth")
print(f"   - Confusion matrices: {RESULTS_DIR}/crossvit_cm_seed*.png")

if MLFLOW_AVAILABLE:
    print("\n📊 MLFLOW:")
    print("   - View results: mlflow ui → http://localhost:5000")
    print("   - Experiment: crossvit-covid19-classification")
    print(f"   - Successful runs logged: {n_completed}")

print("\n🎯 NEXT STEPS:")
print("   1. Train baseline models (notebooks 07-11)")
print("   2. Complete all 30 runs (6 models × 5 seeds)")
print("   3. Move to Phase 3: Statistical validation")
print("   4. Compare CrossViT vs baselines with hypothesis testing")

print("\n✅ CrossViT training complete! 1/6 models done.")
print("=" * 70 + "\n")