# 18 - CrossViT Enhancement (Phase 2)

**Author:** Tan Ming Kai (24PMR12003)  
**Date:** 2025-11-26  
**Purpose:** Train enhanced CrossViT models on a GPU with 24 GB of VRAM (batch sizes are reduced automatically on smaller GPUs)

**Models:**
- CrossViT-Base (105M parameters) - HIGHEST PRIORITY
- CrossViT-Small (26M parameters) - Backup

**Expected Results:**
- CrossViT-Tiny (current): 94.96%
- CrossViT-Base: 96-97% (target)

---

## Objectives
1. Train CrossViT-Base with 5 random seeds
2. Train CrossViT-Small with 5 random seeds (if time permits)
3. Implement ensemble prediction (softmax probabilities of the 5 seed models averaged)
4. Implement Test-Time Augmentation (TTA)
5. Combine Ensemble + TTA for maximum accuracy

---

In [None]:
# Standard imports
import os, sys, random, time, warnings
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import timm
import cv2
from PIL import Image

try:
    import mlflow
    import mlflow.pytorch
    MLFLOW_AVAILABLE = True
except ImportError:
    MLFLOW_AVAILABLE = False

from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report

# Notebook-wide setup: silence warnings, choose the plot theme, report versions.
# NOTE(review): this global "ignore" filter also hides deprecation and runtime
# warnings raised by torch / numpy / sklearn for the rest of the session.
warnings.simplefilter('ignore')
plt.style.use('seaborn-v0_8-darkgrid')

print("PyTorch {} | CUDA: {}".format(torch.__version__, torch.cuda.is_available()))
print("timm {}".format(timm.__version__))

In [None]:
# Hardware verification: pick the device and derive per-model batch sizes from
# the amount of GPU memory (decimal GB, i.e. bytes / 1e9).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"VRAM: {vram_gb:.2f} GB")

    # Determine batch size based on VRAM
    if vram_gb >= 24:
        BASE_BATCH_SIZE = 32
        SMALL_BATCH_SIZE = 64
        print("[OK] 24GB+ VRAM detected - Using optimal batch sizes")
    elif vram_gb >= 16:
        BASE_BATCH_SIZE = 16
        SMALL_BATCH_SIZE = 32
        print("[OK] 16GB+ VRAM detected - Using reduced batch sizes")
    else:
        BASE_BATCH_SIZE = 8
        SMALL_BATCH_SIZE = 16
        print("[WARN] <16GB VRAM - Using minimal batch sizes")
else:
    # Without this branch the batch-size names were never defined on a
    # CPU-only machine and the configuration cell failed with NameError.
    BASE_BATCH_SIZE = 8
    SMALL_BATCH_SIZE = 16
    print("[WARN] No CUDA device - Using minimal batch sizes on CPU")

In [None]:
# Configuration
CSV_DIR = Path("../data/processed")
MODELS_DIR = Path("../models")
RESULTS_DIR = Path("../results")
MODELS_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# CrossViT-Base Configuration (24GB VRAM optimized)
CONFIG_BASE = {
    'device': device,
    'model_name': 'CrossViT-Base',
    'timm_model': 'crossvit_base_240',
    'num_classes': 4,
    'image_size': 240,
    # Display names for the confusion-matrix axes; must follow the integer
    # order of the CSV 'label' column.
    'class_names': ['COVID', 'Normal', 'Lung_Opacity', 'Viral Pneumonia'],
    # Per-class CrossEntropyLoss weights, indexed by label (same order as above).
    'class_weights': [1.47, 0.52, 0.88, 3.95],
    'batch_size': BASE_BATCH_SIZE,
    'num_workers': 0,  # Windows compatibility
    'learning_rate': 5e-5,
    'weight_decay': 0.05,
    'max_epochs': 50,
    'early_stopping_patience': 15,
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],
    'mixed_precision': True,
    'seeds': [42, 123, 456, 789, 101112],
}

# CrossViT-Small Configuration: identical to Base except for the overrides
# below. dict.copy() is shallow and would make the list-valued entries
# (class_names, class_weights, mean, std, seeds) the *same* objects in both
# configs, so copy each list to keep the two configs independent.
CONFIG_SMALL = {
    key: list(value) if isinstance(value, list) else value
    for key, value in CONFIG_BASE.items()
}
CONFIG_SMALL.update({
    'model_name': 'CrossViT-Small',
    'timm_model': 'crossvit_small_240',
    'batch_size': SMALL_BATCH_SIZE,
})

print(f"CrossViT-Base: batch_size={CONFIG_BASE['batch_size']}")
print(f"CrossViT-Small: batch_size={CONFIG_SMALL['batch_size']}")
print(f"Seeds: {CONFIG_BASE['seeds']}")

In [None]:
# MLflow setup
if MLFLOW_AVAILABLE:
    # Set the tracking URI *before* the experiment. set_experiment looks up or
    # creates the experiment in the store that is currently configured, so
    # calling it first would register it against the previous URI.
    mlflow.set_tracking_uri("file:./mlruns")
    mlflow.set_experiment("crossvit-covid19-classification")
    print("[OK] MLflow configured")

In [None]:
# Load data
train_df = pd.read_csv(CSV_DIR / "train.csv")
val_df = pd.read_csv(CSV_DIR / "val.csv")
test_df = pd.read_csv(CSV_DIR / "test.csv")
print(f"Train: {len(train_df):,} | Val: {len(val_df):,} | Test: {len(test_df):,}")

# Verify that every image referenced by any split exists. Relative paths
# resolve against the notebook's working directory. Checking every row up
# front is cheap compared with hitting a missing file mid-training, where the
# per-seed try/except in the training cells would only skip that seed. This
# also avoids the IndexError the old single-sample check raised on an empty CSV.
all_image_paths = pd.concat(
    [train_df['image_path'], val_df['image_path'], test_df['image_path']]
)
missing_paths = [p for p in all_image_paths if not os.path.exists(p)]
if missing_paths:
    print(f"[ERROR] {len(missing_paths):,} image path(s) not found, e.g.: {missing_paths[0]}")
else:
    print("[OK] Path verification passed")

In [None]:
# Dataset class with on-the-fly CLAHE
class COVID19Dataset(Dataset):
    """Chest X-ray dataset that applies CLAHE contrast enhancement per sample.

    Each row of ``dataframe`` provides an ``image_path`` and an integer
    ``label``. An image is read with OpenCV and reduced to one grey channel if
    it has three. It is then equalised with CLAHE (clip limit 2.0, 8x8 tiles),
    replicated back to RGB, and passed to ``transform`` as a PIL image.
    """

    def __init__(self, dataframe, transform=None):
        self.transform = transform
        self.dataframe = dataframe.reset_index(drop=True)
        self.labels = dataframe['label'].values
        self.image_paths = dataframe['image_path'].values

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        raw = cv2.imread(path)
        if raw is None:
            raise FileNotFoundError(f"Could not load: {path}")

        # Collapse to a single channel only when the decoded array has three dims.
        gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY) if raw.ndim == 3 else raw

        equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        rgb = cv2.cvtColor(equalizer.apply(gray), cv2.COLOR_GRAY2RGB)

        sample = Image.fromarray(rgb)
        if self.transform:
            sample = self.transform(sample)

        return sample, torch.tensor(self.labels[idx], dtype=torch.long)

print("[OK] Dataset class defined")

In [None]:
# Transforms for 240x240 input (CrossViT requirement)
def get_transforms(config, is_train=True):
    """Build the torchvision preprocessing pipeline for CrossViT inputs.

    Every pipeline resizes to ``config['image_size']`` squared, converts to a
    tensor, and normalises with ``config['mean']`` / ``config['std']``. When
    ``is_train`` is truthy, light augmentation is inserted after the resize:
    rotation up to 10 degrees, horizontal flip with p=0.5, and a small
    brightness/contrast jitter.
    """
    side = config['image_size']
    pipeline = [transforms.Resize((side, side))]
    if is_train:
        pipeline.extend([
            transforms.RandomRotation(10),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
        ])
    pipeline.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean=config['mean'], std=config['std']),
    ])
    return transforms.Compose(pipeline)

print("[OK] Transforms defined")

In [None]:
# Training functions
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs so a run can be repeated.

    When CUDA is available, this also seeds every GPU and forces deterministic
    cuDNN kernels with the autotuner disabled, trading speed for repeatability.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None, epoch=0):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to train mode.
        loader: Sized iterable of ``(images, labels)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.
        scaler: Optional ``GradScaler``. When given, the forward pass runs under
            CUDA autocast and the backward pass is loss-scaled.
        epoch: Zero-based epoch index; used only for the progress-bar label.

    Returns:
        ``(mean of per-batch losses, training accuracy in percent)``.

    Raises:
        ValueError: If ``loader`` has no batches, e.g. ``drop_last=True`` with
            fewer samples than one batch. Without this check the final averages
            fail with an unhelpful ZeroDivisionError.
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError(
            "Training loader yielded no batches "
            "(dataset smaller than one batch with drop_last=True?)"
        )

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(loader, desc=f"Epoch {epoch+1} [Train]")

    for batch_idx, (images, labels) in enumerate(progress_bar):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        if scaler is not None:
            # torch.autocast('cuda') is the non-deprecated spelling of
            # torch.cuda.amp.autocast().
            with torch.autocast(device_type='cuda'):
                outputs = model(images)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar.set_postfix({'loss': running_loss / (batch_idx + 1), 'acc': 100. * correct / total})

    return running_loss / num_batches, 100. * correct / total

def validate(model, loader, criterion, device, desc="Val"):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Sized iterable of ``(images, labels)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        device: Device the batches are moved to.
        desc: Label shown on the progress bar (e.g. "Val", "Test").

    Returns:
        ``(mean of per-batch losses, accuracy in percent, predictions, labels)``.
        The last two are NumPy arrays in loader order.

    Raises:
        ValueError: If ``loader`` has no batches, since the averages would be
            undefined.
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError(f"[{desc}] loader yielded no batches; nothing to evaluate")

    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc=f"[{desc}]"):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return running_loss / num_batches, 100. * correct / total, np.array(all_preds), np.array(all_labels)

print("[OK] Training functions defined")

In [None]:
# Single seed training function
def train_crossvit_single_seed(seed, config):
    """Train one CrossViT model for a single seed and evaluate it on the test set.

    Builds loaders from the module-level ``train_df`` / ``val_df`` / ``test_df``.
    Fine-tunes ``timm.create_model(config['timm_model'], pretrained=True)`` with
    its classifier sized to ``config['num_classes']``, using class-weighted
    cross-entropy, AdamW and cosine warm restarts. Keeps the checkpoint with the
    lowest validation loss and stops after ``config['early_stopping_patience']``
    epochs without improvement.

    Args:
        seed: Seed passed to ``set_seed``; also embedded in run and file names.
        config: Configuration dict with the keys used in ``CONFIG_BASE``.

    Returns:
        dict with ``seed``, ``test_acc`` (percent), ``test_loss``,
        ``training_time`` (seconds, training loop only) and ``model_path``.

    Side effects:
        Writes the checkpoint to ``MODELS_DIR`` and a confusion-matrix PNG to
        ``RESULTS_DIR``. When MLflow is available, logs one run that is ended as
        FAILED if anything raises.
    """
    print(f"\n{'='*70}")
    print(f"TRAINING {config['model_name']} WITH SEED {seed}")
    print(f"{'='*70}")

    set_seed(seed)

    # Create dataloaders (train is augmented; val/test use deterministic preprocessing)
    train_transform = get_transforms(config, is_train=True)
    val_transform = get_transforms(config, is_train=False)

    train_dataset = COVID19Dataset(train_df, transform=train_transform)
    val_dataset = COVID19Dataset(val_df, transform=val_transform)
    test_dataset = COVID19Dataset(test_df, transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'],
                              shuffle=True, num_workers=config['num_workers'],
                              pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'],
                            shuffle=False, num_workers=config['num_workers'],
                            pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=config['batch_size'],
                             shuffle=False, num_workers=config['num_workers'],
                             pin_memory=True)

    # Load model
    model = timm.create_model(config['timm_model'], pretrained=True, num_classes=config['num_classes'])
    model = model.to(device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f"[OK] {config['model_name']} loaded: {num_params:,} parameters")

    # Loss, optimizer, scheduler
    class_weights = torch.tensor(config['class_weights'], dtype=torch.float32).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
    # Stepped once per epoch below: the first cycle is 10 epochs, each later one doubles.
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
    scaler = torch.cuda.amp.GradScaler() if config['mixed_precision'] else None

    model_filename = f"{config['timm_model'].replace('-', '_')}_best_seed{seed}.pth"
    best_model_path = MODELS_DIR / model_filename

    # Switched to FINISHED only after every step succeeded, so an exception
    # closes the MLflow run as FAILED instead of leaving it open.
    run_status = "FAILED"
    try:
        if MLFLOW_AVAILABLE:
            # Close any run left open by an interrupted earlier call (best effort).
            try:
                mlflow.end_run()
            except Exception:
                pass
            mlflow.start_run(run_name=f"{config['timm_model']}-seed-{seed}")
            mlflow.log_param("model", config['model_name'])
            mlflow.log_param("timm_model", config['timm_model'])
            mlflow.log_param("random_seed", seed)
            mlflow.log_param("batch_size", config['batch_size'])
            mlflow.log_param("learning_rate", config['learning_rate'])
            mlflow.log_param("num_params", num_params)
            mlflow.set_tag("phase", "Phase 2 - CrossViT Enhancement")

        # Training loop
        best_val_loss = float('inf')
        patience_counter = 0
        checkpoint_saved = False
        start_time = time.time()

        for epoch in range(config['max_epochs']):
            train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, epoch)
            val_loss, val_acc, _, _ = validate(model, val_loader, criterion, device)
            scheduler.step()

            if MLFLOW_AVAILABLE:
                mlflow.log_metric("train_loss", train_loss, step=epoch)
                mlflow.log_metric("val_loss", val_loss, step=epoch)
                mlflow.log_metric("val_acc", val_acc, step=epoch)

            print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f} Acc={train_acc:.2f}% | Val Loss={val_loss:.4f} Acc={val_acc:.2f}%")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                torch.save(model.state_dict(), best_model_path)
                checkpoint_saved = True
                print("[OK] Best model saved!")
            else:
                patience_counter += 1
                if patience_counter >= config['early_stopping_patience']:
                    print(f"[STOP] Early stopping at epoch {epoch+1}")
                    break

        training_time = time.time() - start_time

        if not checkpoint_saved:
            # No epoch beat +inf (validation loss NaN throughout, or max_epochs == 0).
            # Save the final weights so we never evaluate a stale checkpoint that an
            # earlier run left at the same path, or fail on a missing file.
            torch.save(model.state_dict(), best_model_path)
            print("[WARN] No validation-loss improvement recorded; saved final weights instead")

        # Test evaluation on the lowest-validation-loss checkpoint
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        test_loss, test_acc, test_preds, test_labels = validate(model, test_loader, criterion, device, desc="Test")

        # Confusion matrix: pin the label set so the matrix stays
        # num_classes x num_classes (matching the tick labels) even if a class
        # never appears in the test labels or predictions.
        cm = confusion_matrix(test_labels, test_preds, labels=list(range(config['num_classes'])))
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=config['class_names'], yticklabels=config['class_names'])
        plt.ylabel('True'); plt.xlabel('Predicted')
        plt.title(f"{config['model_name']} Confusion Matrix (Seed {seed})")
        cm_path = RESULTS_DIR / f"{config['timm_model'].replace('-', '_')}_cm_seed{seed}.png"
        plt.savefig(cm_path, dpi=300, bbox_inches='tight')
        plt.close()

        if MLFLOW_AVAILABLE:
            mlflow.log_metric("test_loss", test_loss)
            mlflow.log_metric("test_accuracy", test_acc)
            mlflow.log_metric("training_time_minutes", training_time / 60)
            mlflow.log_artifact(str(cm_path))
        run_status = "FINISHED"
    finally:
        if MLFLOW_AVAILABLE:
            mlflow.end_run(status=run_status)

    print(f"[OK] Seed {seed} complete: Test Acc = {test_acc:.2f}%")

    # Clear GPU memory
    del model
    torch.cuda.empty_cache()

    return {
        'seed': seed,
        'test_acc': test_acc,
        'test_loss': test_loss,
        'training_time': training_time,
        'model_path': str(best_model_path)
    }

print("[OK] Single seed training function defined")

---

## Part 1: Train CrossViT-Base (5 seeds)

CrossViT-Base is the largest CrossViT variant with ~105M parameters.
Expected accuracy: 96-97%

In [None]:
# Verify CrossViT-Base model
# Build exactly what training will build (values come from CONFIG_BASE, not
# hard-coded copies) and run a 2-image forward pass to check the head size.
print("Verifying CrossViT-Base model...")
test_model = timm.create_model(CONFIG_BASE['timm_model'], pretrained=True,
                               num_classes=CONFIG_BASE['num_classes'])
print(f"Model: {CONFIG_BASE['timm_model']}")
print(f"Parameters: {sum(p.numel() for p in test_model.parameters()):,}")

# Test forward pass in eval mode so train-only layers (e.g. dropout) are disabled
test_model = test_model.to(device).eval()
img_size = CONFIG_BASE['image_size']
test_input = torch.randn(2, 3, img_size, img_size).to(device)
with torch.no_grad():
    test_output = test_model(test_input)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")
if tuple(test_output.shape) != (2, CONFIG_BASE['num_classes']):
    raise RuntimeError(
        f"Unexpected output shape {tuple(test_output.shape)}; "
        f"expected (2, {CONFIG_BASE['num_classes']})"
    )
# memory_allocated() only reports meaningful numbers on a CUDA device
if torch.cuda.is_available():
    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB")

del test_model, test_input, test_output
torch.cuda.empty_cache()
print("[OK] CrossViT-Base verification passed")

In [None]:
# Train CrossViT-Base with all seeds
import traceback

print(f"\n{'='*70}")
print("STARTING CROSSVIT-BASE MULTI-SEED TRAINING")
print(f"{'='*70}")
print(f"Model: {CONFIG_BASE['timm_model']}")
print(f"Batch size: {CONFIG_BASE['batch_size']}")
print(f"Seeds: {CONFIG_BASE['seeds']}")
print()

base_results = []
base_failed_seeds = []
for seed in CONFIG_BASE['seeds']:
    try:
        result = train_crossvit_single_seed(seed, CONFIG_BASE)
    except Exception as e:
        # Keep going with the remaining seeds, but print the full traceback:
        # the one-line message alone hides where the failure happened.
        print(f"[ERROR] Seed {seed}: {e}")
        traceback.print_exc()
        base_failed_seeds.append(seed)
    else:
        base_results.append(result)

print(f"\n{'='*70}")
print("CROSSVIT-BASE TRAINING COMPLETED")
print(f"{'='*70}")
print(f"Succeeded: {len(base_results)}/{len(CONFIG_BASE['seeds'])} seeds")
if base_failed_seeds:
    print(f"[WARN] Failed seeds: {base_failed_seeds}")

In [None]:
# CrossViT-Base statistical analysis
if base_results:
    accuracies = [r['test_acc'] for r in base_results]
    mean_acc = np.mean(accuracies)
    # The sample standard deviation (ddof=1) is undefined for a single seed;
    # NumPy would silently return NaN, so handle that case explicitly.
    std_acc = np.std(accuracies, ddof=1) if len(accuracies) > 1 else float('nan')

    print(f"\n CrossViT-Base Results ({len(base_results)} seeds):")
    if len(accuracies) > 1:
        print(f"   Mean +/- Std: {mean_acc:.2f}% +/- {std_acc:.2f}%")
    else:
        print(f"   Mean: {mean_acc:.2f}% (std undefined for a single seed)")
    print(f"   Range: [{np.min(accuracies):.2f}%, {np.max(accuracies):.2f}%]")

    # Save results
    results_df = pd.DataFrame(base_results)
    results_path = RESULTS_DIR / "crossvit_base_results.csv"
    results_df.to_csv(results_path, index=False)
    print(f"\n[OK] Results saved to {results_path}")
    print(results_df)

---

## Part 2: Train CrossViT-Small (5 seeds) - Optional

CrossViT-Small is a mid-size variant with ~26M parameters.
Expected accuracy: 95.5-96%

In [None]:
# Train CrossViT-Small with all seeds
import traceback

print(f"\n{'='*70}")
print("STARTING CROSSVIT-SMALL MULTI-SEED TRAINING")
print(f"{'='*70}")
print(f"Model: {CONFIG_SMALL['timm_model']}")
print(f"Batch size: {CONFIG_SMALL['batch_size']}")
print(f"Seeds: {CONFIG_SMALL['seeds']}")
print()

small_results = []
small_failed_seeds = []
for seed in CONFIG_SMALL['seeds']:
    try:
        result = train_crossvit_single_seed(seed, CONFIG_SMALL)
    except Exception as e:
        # Keep going with the remaining seeds, but print the full traceback:
        # the one-line message alone hides where the failure happened.
        print(f"[ERROR] Seed {seed}: {e}")
        traceback.print_exc()
        small_failed_seeds.append(seed)
    else:
        small_results.append(result)

print(f"\n{'='*70}")
print("CROSSVIT-SMALL TRAINING COMPLETED")
print(f"{'='*70}")
print(f"Succeeded: {len(small_results)}/{len(CONFIG_SMALL['seeds'])} seeds")
if small_failed_seeds:
    print(f"[WARN] Failed seeds: {small_failed_seeds}")

In [None]:
# CrossViT-Small statistical analysis
if small_results:
    accuracies = [r['test_acc'] for r in small_results]
    mean_acc = np.mean(accuracies)
    # The sample standard deviation (ddof=1) is undefined for a single seed;
    # NumPy would silently return NaN, so handle that case explicitly.
    std_acc = np.std(accuracies, ddof=1) if len(accuracies) > 1 else float('nan')

    print(f"\n CrossViT-Small Results ({len(small_results)} seeds):")
    if len(accuracies) > 1:
        print(f"   Mean +/- Std: {mean_acc:.2f}% +/- {std_acc:.2f}%")
    else:
        print(f"   Mean: {mean_acc:.2f}% (std undefined for a single seed)")
    print(f"   Range: [{np.min(accuracies):.2f}%, {np.max(accuracies):.2f}%]")

    # Save results
    results_df = pd.DataFrame(small_results)
    results_path = RESULTS_DIR / "crossvit_small_results.csv"
    results_df.to_csv(results_path, index=False)
    print(f"\n[OK] Results saved to {results_path}")
    print(results_df)

---

## Part 3: Ensemble Prediction

Average the softmax probabilities of every successfully trained CrossViT-Base model (one per seed, up to 5) to improve accuracy.

In [None]:
# Ensemble prediction functions
def load_ensemble_models(model_paths, timm_model, num_classes=4, device='cuda'):
    """Instantiate one ``timm_model`` per checkpoint in ``model_paths``.

    Each network is created without pretrained weights, filled from its
    state-dict file (mapped onto ``device``), moved to ``device``, and put in
    eval mode. The returned list follows the order of ``model_paths``.
    """
    def _restore(checkpoint):
        net = timm.create_model(timm_model, pretrained=False, num_classes=num_classes)
        net.load_state_dict(torch.load(checkpoint, map_location=device))
        net.to(device)
        net.eval()
        return net

    return [_restore(checkpoint) for checkpoint in model_paths]

def ensemble_predict_batch(models, images, device='cuda'):
    """Average the softmax probabilities of several models for one batch.

    Args:
        models: Iterable of callables mapping an image batch to logits, with
            classes on dim 1 (softmax is taken over dim 1).
        images: Input batch. It is copied to ``device`` once and shared by all
            models, instead of once per model.
        device: Device to run inference on.

    Returns:
        Mean class probabilities, shaped like a single model's output, on ``device``.

    Raises:
        ValueError: If ``models`` is empty.
    """
    batch = images.to(device)
    with torch.no_grad():
        all_probs = [torch.softmax(model(batch), dim=1) for model in models]
    if not all_probs:
        raise ValueError("ensemble_predict_batch needs at least one model")
    return torch.stack(all_probs).mean(dim=0)

def evaluate_ensemble(models, loader, device='cuda'):
    """Score the probability-averaging ensemble over every batch of ``loader``.

    Returns:
        ``(accuracy in percent, predicted labels, true labels)``; the label
        arrays are NumPy arrays in loader order.
    """
    predicted, expected = [], []
    for batch, targets in tqdm(loader, desc="Ensemble Eval"):
        mean_probs = ensemble_predict_batch(models, batch, device)
        predicted.extend(mean_probs.argmax(dim=1).cpu().numpy())
        expected.extend(targets.numpy())

    score = accuracy_score(expected, predicted) * 100
    return score, np.array(predicted), np.array(expected)

print("[OK] Ensemble functions defined")

In [None]:
# Evaluate CrossViT-Base ensemble (if trained)
# Averages the softmax outputs of every per-seed checkpoint in base_results on
# the test set and sets ensemble_acc / ensemble_preds / ensemble_labels.
if base_results:
    print("\n" + "="*70)
    print("CROSSVIT-BASE ENSEMBLE EVALUATION")
    print("="*70)

    model_paths = [r['model_path'] for r in base_results]
    print(f"Loading {len(model_paths)} models...")

    # Head size must match training; take it from the config instead of
    # relying on load_ensemble_models' hard-coded default.
    ensemble_models = load_ensemble_models(model_paths, CONFIG_BASE['timm_model'],
                                           num_classes=CONFIG_BASE['num_classes'], device=device)

    # Deterministic (non-augmenting) preprocessing, as used for val/test in training
    val_transform = get_transforms(CONFIG_BASE, is_train=False)
    test_dataset = COVID19Dataset(test_df, transform=val_transform)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG_BASE['batch_size'],
                             shuffle=False, num_workers=CONFIG_BASE['num_workers'])

    ensemble_acc, ensemble_preds, ensemble_labels = evaluate_ensemble(ensemble_models, test_loader, device)

    print(f"\n CrossViT-Base Ensemble Accuracy: {ensemble_acc:.2f}%")

    # Confusion matrix, pinned to every class so it always matches the tick labels
    cm = confusion_matrix(ensemble_labels, ensemble_preds,
                          labels=list(range(CONFIG_BASE['num_classes'])))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CONFIG_BASE['class_names'], yticklabels=CONFIG_BASE['class_names'])
    plt.ylabel('True'); plt.xlabel('Predicted')
    plt.title(f"CrossViT-Base Ensemble Confusion Matrix (Acc: {ensemble_acc:.2f}%)")
    plt.savefig(RESULTS_DIR / "crossvit_base_ensemble_cm.png", dpi=300, bbox_inches='tight')
    plt.show()

    # Clean up: release all ensemble members before the next section
    del ensemble_models
    torch.cuda.empty_cache()

---

## Part 4: Test-Time Augmentation (TTA)

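Apply five deterministic views at test time (original, horizontal flip, vertical flip, and ±90° rotations) to the best single model and average the resulting softmax probabilities.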

In [None]:
# TTA functions
# Default test-time views: original, horizontal flip, vertical flip, and
# rotations by +90 / -90 degrees over the last two (spatial) dimensions.
_DEFAULT_TTA_VIEWS = (
    lambda x: x,                                    # Original
    lambda x: torch.flip(x, dims=[-1]),             # Horizontal flip
    lambda x: torch.flip(x, dims=[-2]),             # Vertical flip
    lambda x: torch.rot90(x, k=1, dims=[-2, -1]),   # Rotate 90
    lambda x: torch.rot90(x, k=3, dims=[-2, -1]),   # Rotate -90
)


def predict_with_tta(model, image_tensor, device='cuda', augmentations=None):
    """Average the softmax probabilities of ``model`` over several views of a batch.

    NOTE(review): training (``get_transforms``) only uses horizontal flips and
    rotations of up to 10 degrees. The default vertical-flip and 90-degree
    views are therefore outside the training distribution for chest X-rays and
    may lower accuracy. Pass
    ``augmentations=[lambda x: x, lambda x: torch.flip(x, dims=[-1])]``
    to restrict TTA to train-consistent views.

    Args:
        model: Module returning logits, classes assumed on dim 1; switched to
            eval mode.
        image_tensor: Image batch; the last two dims are treated as height and
            width. It is moved to ``device`` once, before any view is built.
        device: Device to run inference on.
        augmentations: Optional sequence of tensor -> tensor callables. ``None``
            selects the five default views above.

    Returns:
        Mean class probabilities over all views, on ``device``.

    Raises:
        ValueError: If ``augmentations`` is an empty sequence.
    """
    if augmentations is None:
        augmentations = _DEFAULT_TTA_VIEWS
    if len(augmentations) == 0:
        raise ValueError("predict_with_tta needs at least one augmentation")

    model.eval()
    batch = image_tensor.to(device)  # one host-to-device copy instead of one per view

    all_probs = []
    with torch.no_grad():
        for aug in augmentations:
            output = model(aug(batch))
            all_probs.append(torch.softmax(output, dim=1))

    return torch.stack(all_probs).mean(dim=0)

def evaluate_tta(model, loader, device='cuda'):
    """Score a single model with test-time augmentation over all of ``loader``.

    Returns:
        ``(accuracy in percent, predicted labels, true labels)``; the label
        arrays are NumPy arrays in loader order.
    """
    predicted, actual = [], []
    for images, labels in tqdm(loader, desc="TTA Eval"):
        averaged = predict_with_tta(model, images, device)
        predicted.extend(averaged.argmax(dim=1).cpu().numpy())
        actual.extend(labels.numpy())

    preds_arr, labels_arr = np.array(predicted), np.array(actual)
    return accuracy_score(actual, predicted) * 100, preds_arr, labels_arr

print("[OK] TTA functions defined")

In [None]:
# Evaluate CrossViT-Base best model with TTA
if base_results:
    print("\n" + "="*70)
    print("CROSSVIT-BASE TTA EVALUATION")
    print("="*70)

    # Load best model (highest accuracy)
    # NOTE(review): the seed is chosen by *test* accuracy and TTA is then scored
    # on the same test set, so the reported improvement is optimistically biased.
    # base_results records no validation metric that could be used instead.
    best_result = max(base_results, key=lambda x: x['test_acc'])
    print(f"Best seed: {best_result['seed']} (Acc: {best_result['test_acc']:.2f}%)")

    # Head size from the config (not a hard-coded 4) so it always matches training
    model = timm.create_model(CONFIG_BASE['timm_model'], pretrained=False,
                              num_classes=CONFIG_BASE['num_classes'])
    # map_location lets a checkpoint written on one device load on the current one
    model.load_state_dict(torch.load(best_result['model_path'], map_location=device))
    model.to(device)
    model.eval()

    val_transform = get_transforms(CONFIG_BASE, is_train=False)
    test_dataset = COVID19Dataset(test_df, transform=val_transform)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG_BASE['batch_size'],
                             shuffle=False, num_workers=0)

    tta_acc, tta_preds, tta_labels = evaluate_tta(model, test_loader, device)

    print(f"\n CrossViT-Base TTA Accuracy: {tta_acc:.2f}%")
    print(f"   Improvement over base: {tta_acc - best_result['test_acc']:.2f}%")

    del model
    torch.cuda.empty_cache()

---

## Part 5: Ensemble + TTA (Maximum Accuracy)

Combine the ensemble with TTA: each seed model (up to 5) is evaluated on all 5 TTA views, so up to 25 predictions are averaged per image.

In [None]:
# Combined Ensemble + TTA
def ensemble_tta_predict_batch(models, images, device='cuda'):
    """Average TTA-averaged probabilities across all ensemble members.

    The batch is moved to ``device`` once up front. Every model's TTA views are
    then built from the on-device copy, instead of copying host-to-device once
    per model per view.

    Args:
        models: Iterable of models accepted by ``predict_with_tta``.
        images: Image batch.
        device: Device to run inference on.

    Returns:
        Mean class probabilities over all models and views, on ``device``.
    """
    batch = images.to(device)
    all_probs = [predict_with_tta(model, batch, device) for model in models]
    return torch.stack(all_probs).mean(dim=0)

def evaluate_ensemble_tta(models, loader, device='cuda'):
    """Score the ensemble-with-TTA predictor over every batch of ``loader``.

    Returns:
        ``(accuracy in percent, predicted labels, true labels)``; the label
        arrays are NumPy arrays in loader order.
    """
    guesses = []
    truths = []
    for images, labels in tqdm(loader, desc="Ensemble+TTA Eval"):
        combined = ensemble_tta_predict_batch(models, images, device)
        guesses.extend(combined.argmax(dim=1).cpu().numpy())
        truths.extend(labels.numpy())

    pct = 100 * accuracy_score(truths, guesses)
    return pct, np.array(guesses), np.array(truths)

print("[OK] Ensemble+TTA functions defined")

In [None]:
# Evaluate CrossViT-Base Ensemble + TTA
if base_results:
    print("\n" + "="*70)
    print("CROSSVIT-BASE ENSEMBLE + TTA EVALUATION")
    print("="*70)

    model_paths = [r['model_path'] for r in base_results]
    # 5 TTA views per model; report the real total rather than assuming that
    # all 5 seeds trained successfully.
    print(f"Loading {len(model_paths)} models for Ensemble+TTA "
          f"({len(model_paths) * 5} predictions per sample)...")

    ensemble_models = load_ensemble_models(model_paths, CONFIG_BASE['timm_model'],
                                           num_classes=CONFIG_BASE['num_classes'], device=device)

    val_transform = get_transforms(CONFIG_BASE, is_train=False)
    test_dataset = COVID19Dataset(test_df, transform=val_transform)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=0)  # Smaller batch for memory

    ensemble_tta_acc, ensemble_tta_preds, ensemble_tta_labels = evaluate_ensemble_tta(
        ensemble_models, test_loader, device)

    print(f"\n CrossViT-Base Ensemble+TTA Accuracy: {ensemble_tta_acc:.2f}%")

    # Confusion matrix, pinned to every class so it always matches the tick labels
    cm = confusion_matrix(ensemble_tta_labels, ensemble_tta_preds,
                          labels=list(range(CONFIG_BASE['num_classes'])))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CONFIG_BASE['class_names'], yticklabels=CONFIG_BASE['class_names'])
    plt.ylabel('True'); plt.xlabel('Predicted')
    plt.title(f"CrossViT-Base Ensemble+TTA Confusion Matrix (Acc: {ensemble_tta_acc:.2f}%)")
    plt.savefig(RESULTS_DIR / "crossvit_base_ensemble_tta_cm.png", dpi=300, bbox_inches='tight')
    plt.show()

    del ensemble_models
    torch.cuda.empty_cache()

---

## Summary

In [None]:
# Final summary
print("\n" + "="*70)
print("CROSSVIT ENHANCEMENT SUMMARY")
print("="*70)

# Part 2 (CrossViT-Small) is optional and any evaluation cell may have been
# skipped, so look upstream results up in the notebook namespace instead of
# referencing them directly; a skipped cell would otherwise raise NameError.
_ns = globals()
_base_runs = _ns.get('base_results') or []
_small_runs = _ns.get('small_results') or []
_n_base = len(_base_runs)

summary_data = []

# CrossViT-Tiny baseline (from previous experiments)
summary_data.append({
    'Method': 'CrossViT-Tiny (Baseline)',
    'Accuracy': 94.96,
    'Notes': 'Previous result'
})

if _base_runs:
    base_accs = [r['test_acc'] for r in _base_runs]
    # Sample std (ddof=1) is undefined for a single seed.
    base_std_note = f"std={np.std(base_accs, ddof=1):.2f}" if len(base_accs) > 1 else "std=n/a"
    summary_data.append({
        'Method': 'CrossViT-Base (Single)',
        'Accuracy': np.mean(base_accs),
        'Notes': f'{_n_base} seeds, {base_std_note}'
    })

if 'ensemble_acc' in _ns:
    summary_data.append({
        'Method': 'CrossViT-Base Ensemble',
        'Accuracy': ensemble_acc,
        'Notes': f'{_n_base} models averaged'
    })

if 'tta_acc' in _ns:
    summary_data.append({
        'Method': 'CrossViT-Base + TTA',
        'Accuracy': tta_acc,
        'Notes': 'Best seed + 5-fold TTA'
    })

if 'ensemble_tta_acc' in _ns:
    summary_data.append({
        'Method': 'CrossViT-Base Ensemble+TTA',
        'Accuracy': ensemble_tta_acc,
        'Notes': f'{_n_base} models x 5 TTA = {_n_base * 5} predictions'
    })

if _small_runs:
    small_accs = [r['test_acc'] for r in _small_runs]
    summary_data.append({
        'Method': 'CrossViT-Small (Single)',
        'Accuracy': np.mean(small_accs),
        'Notes': f'{len(_small_runs)} seeds'
    })

summary_df = pd.DataFrame(summary_data)
summary_df = summary_df.sort_values('Accuracy', ascending=False)
print(summary_df.to_string(index=False))

# Save summary
summary_df.to_csv(RESULTS_DIR / "crossvit_enhancement_summary.csv", index=False)
print(f"\n[OK] Summary saved to {RESULTS_DIR / 'crossvit_enhancement_summary.csv'}")

In [None]:
print("\n" + "="*70)
print("CROSSVIT ENHANCEMENT COMPLETE!")
print("="*70)