In [1]:
"""ViT Baseline Network (NO CNN-CBAM, NO QUANTUM)
ERBMAHE Dataset Classification
Binary Classification: Abnormal vs Normal

ABLATION STUDY: Only ViT + Focal Loss (baseline for comparison)
METHOD 2 K-FOLD WITH SEPARATE TEST SET:
- 10% holdout test set (never used during training)
- K-Fold cross-validation on remaining 90%
- Validation-based early stopping
- Best fold model evaluated on test set
"""

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Import PyTorch & torchvision
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms
    print("‚úì PyTorch & torchvision imported successfully")
except Exception as e:
    print(f"‚ùå Error importing PyTorch/torchvision: {e}")
    sys.exit(1)

# Transformers (ViT)
try:
    from transformers import ViTImageProcessor, ViTModel
    print("‚úì Transformers imported successfully")
except ImportError:
    print("üì¶ Installing Transformers...")
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers"])
    from transformers import ViTImageProcessor, ViTModel
    print("‚úì Transformers installed successfully")

# Device selection: this script is GPU-only, so abort early when CUDA is absent.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    raise RuntimeError("CUDA is not available. Please run this script on a machine with CUDA-enabled GPU and proper drivers.")
print(f"üîß Using device: {device}")
print(f"üîß PyTorch version: {torch.__version__}")

# Experiment configuration shared by every definition below.
MODEL_NAME = "google/vit-base-patch16-224"  # Hugging Face checkpoint id of the backbone
CLASSES = ['Abnormal', 'Normal']            # list position doubles as the integer class id
N_FOLDS = 5                                 # number of stratified cross-validation folds

# ============================================================================
# CNN-CBAM REMOVED (Ablation Study)
# QUANTUM LAYER REMOVED (Ablation Study)
# Pure ViT Baseline
# ============================================================================

# ============================================================================
# ViT-Only Model
# Pipeline: ViT -> Classifier (no quantum, no CNN-CBAM, no fusion)
# ============================================================================
class ViTBaseline(nn.Module):
    """Pretrained ViT backbone followed by a small two-layer classification head.

    Args:
        model_name: Identifier or path passed to ``ViTModel.from_pretrained``.
        num_classes: Number of output logits produced by the head.
    """

    def __init__(self, model_name, num_classes=2):
        super().__init__()
        # Transformer encoder with pretrained weights.
        self.vit = ViTModel.from_pretrained(model_name)
        embed_dim = self.vit.config.hidden_size  # 768 for the ViT-base checkpoint
        bottleneck_dim = embed_dim // 2

        # Head layers; the order is kept so state_dict keys (classifier.0 / classifier.3)
        # stay compatible with previously saved checkpoints.
        head_layers = [
            nn.Linear(embed_dim, bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(bottleneck_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, pixel_values):
        """Return class logits for a batch of preprocessed images.

        Only the token at sequence position 0 of the backbone's last hidden
        state is fed to the head.
        """
        hidden_states = self.vit(pixel_values=pixel_values).last_hidden_state
        first_token = hidden_states[:, 0]  # (B, hidden_size)
        return self.classifier(first_token)

# ============================================================================
# Focal Loss
# ============================================================================
class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of ``F.cross_entropy``.

    Each sample's cross-entropy is scaled by ``(1 - p_t) ** gamma`` where
    ``p_t`` is the predicted probability of the true class, so well-classified
    samples contribute less to the total.

    Args:
        alpha: Optional weighting. ``None`` disables it, a Python ``float``/``int``
            scales every sample equally, and a per-class sequence (tensor, list,
            tuple or NumPy array) is indexed by the target class of each sample.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-sample loss.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        # Per-class weights given as list/tuple/ndarray are converted once here so
        # they can be indexed by the target tensor in forward().
        if alpha is not None and not isinstance(alpha, (float, int, torch.Tensor)):
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: Raw logits, passed unchanged to ``F.cross_entropy``.
            targets: Integer class indices, one per sample.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, otherwise the per-sample loss.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        p_t = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = (1 - p_t) ** self.gamma * ce_loss
        if self.alpha is not None:
            if isinstance(self.alpha, (float, int)):
                alpha_t = self.alpha
            else:
                # Move the weights next to the loss so CPU-defined weights also
                # work with GPU batches (no copy when already on that device).
                alpha_t = self.alpha.to(focal_loss.device)[targets]
            focal_loss = alpha_t * focal_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# ============================================================================
# Early Stopping
# ============================================================================
class EarlyStopping:
    """Patience-based stopping criterion on a monitored score.

    Call the instance once per epoch with the latest score. It returns ``True``
    once ``patience`` consecutive calls failed to beat the best score by more
    than ``min_delta``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: Margin by which the best score must be beaten.
        mode: ``'max'`` when higher is better; any other value means lower is better.
        verbose: Print progress messages when true.
    """

    def __init__(self, patience=10, min_delta=0, mode='max', verbose=True):
        self.patience, self.min_delta = patience, min_delta
        self.mode, self.verbose = mode, verbose
        self.best_score = None   # no score observed yet
        self.best_epoch = 0
        self.counter = 0         # consecutive calls without improvement
        self.early_stop = False

    def __call__(self, current_score, epoch):
        """Register ``current_score`` for ``epoch``; return whether to stop."""
        # The very first observation always becomes the reference.
        if self.best_score is None:
            self.best_score, self.best_epoch = current_score, epoch
            if self.verbose:
                print(f"  ‚úì Initial best score: {current_score:.4f}")
            return False

        if self.mode == 'max':
            beats_best = current_score > (self.best_score + self.min_delta)
        else:
            beats_best = current_score < (self.best_score - self.min_delta)

        if not beats_best:
            self.counter += 1
            if self.verbose:
                print(f"  No improvement. Patience: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print(f"\n‚ö†Ô∏è Early stopping triggered!")
                    print(f"   Best score: {self.best_score:.4f} at epoch {self.best_epoch}")
            return self.early_stop

        # Improvement: remember it and restart the patience window.
        self.best_score, self.best_epoch = current_score, epoch
        self.counter = 0
        if self.verbose:
            print(f"  ‚úì New best score: {current_score:.4f}")
        return self.early_stop

# ============================================================================
# DATASET
# ============================================================================
class ERBMAHEDataset(Dataset):
    """Image dataset backed by a DataFrame with ``image_path`` and ``class_id`` columns.

    Args:
        dataframe: Table with one row per image; its index is reset so rows can
            be addressed by position.
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``
            whose result provides ``'pixel_values'`` (here a ``ViTImageProcessor``).
        augment: When true, random augmentation is applied before the processor.
    """

    def __init__(self, dataframe, processor, augment=False):
        self.df = dataframe.reset_index(drop=True)
        self.processor = processor
        self.augment = augment
        # Mild geometric/photometric jitter; only used when ``augment`` is true.
        self.aug_transform = transforms.Compose([
            transforms.RandomRotation(10),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.0),
            transforms.RandomHorizontalFlip(p=0.5),
        ])

    def __len__(self):
        """Number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(pixel_values, label)`` for the image at row ``idx``."""
        img_path = self.df.loc[idx, 'image_path']
        label = self.df.loc[idx, 'class_id']
        # Context manager closes the file handle right after decoding instead of
        # leaving it to the garbage collector; convert() returns an independent image.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.augment:
            image = self.aug_transform(image)
        inputs = self.processor(images=image, return_tensors="pt")
        # Drop the batch dimension the processor adds for a single image.
        pixel_values = inputs['pixel_values'].squeeze(0)
        return pixel_values, label

# ============================================================================
# TRAIN / VALIDATION FUNCTIONS
# ============================================================================

def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Network called as ``model(pixel_values)`` and returning logits.
        dataloader: Iterable of ``(pixel_values, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean batch loss, accuracy in percent, weighted F1)``.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    pbar = tqdm(dataloader, desc='Training')
    # Count batches explicitly: tqdm only refreshes ``pbar.n`` when it redraws,
    # so using it as the divisor can overstate the running loss.
    for batch_idx, (pixel_values, labels) in enumerate(pbar, start=1):
        pixel_values, labels = pixel_values.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(pixel_values)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = logits.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        pbar.set_postfix({
            'loss': f'{running_loss / batch_idx:.4f}',
            'acc': f'{100.*correct/total:.2f}%'
        })
    _, _, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted', zero_division=0)
    return running_loss / len(dataloader), 100. * correct / total, f1

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating weights.

    Args:
        model: Network called as ``model(pixel_values)`` and returning logits.
        dataloader: Iterable of ``(pixel_values, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean batch loss, accuracy in percent, weighted F1,
        predictions, labels)``; the last two are flat lists in dataloader order.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        # Count batches explicitly: tqdm only refreshes ``pbar.n`` when it redraws,
        # so using it as the divisor can overstate the running loss.
        for batch_idx, (pixel_values, labels) in enumerate(pbar, start=1):
            pixel_values, labels = pixel_values.to(device), labels.to(device)
            logits = model(pixel_values)
            loss = criterion(logits, labels)
            running_loss += loss.item()
            _, predicted = logits.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            pbar.set_postfix({
                'loss': f'{running_loss / batch_idx:.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })
    _, _, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted', zero_division=0)
    return running_loss / len(dataloader), 100. * correct / total, f1, all_preds, all_labels

def calculate_metrics(y_true, y_pred):
    """Compute binary-classification metrics, including per-class sensitivity/specificity.

    Class ids follow the module-level ``CLASSES`` list (id 0 = ``CLASSES[0]``,
    id 1 = ``CLASSES[1]``). For a given class treated as "positive":

    * sensitivity = correctly predicted members of that class / all true members
    * specificity = correctly predicted members of the other class / all true
      members of the other class

    In a two-class problem the sensitivity of one class therefore equals the
    specificity of the other, and the two averages coincide.

    Args:
        y_true: Sequence of true class ids (0 or 1).
        y_pred: Sequence of predicted class ids (0 or 1).

    Returns:
        Dict with ``accuracy`` (percent), weighted ``precision``/``recall``/``f1``,
        per-class and averaged sensitivity/specificity (fractions), and the
        2x2 ``confusion_matrix`` (rows = true class, columns = predicted class).
    """
    # Fix the label set so the matrix is always 2x2, even when a class is absent
    # from both y_true and y_pred (otherwise the unpacking below would fail).
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    # sklearn layout: cm[0, 0]=true0/pred0, cm[0, 1]=true0/pred1,
    #                 cm[1, 0]=true1/pred0, cm[1, 1]=true1/pred1
    tn, fp, fn, tp = cm.ravel()

    recall_class0 = tn / (tn + fp) if (tn + fp) > 0 else 0  # true class 0 predicted as 0
    recall_class1 = tp / (tp + fn) if (tp + fn) > 0 else 0  # true class 1 predicted as 1

    sensitivity_class0 = recall_class0
    sensitivity_class1 = recall_class1
    # With two classes, the "negatives" of one class are the members of the other.
    specificity_class0 = recall_class1
    specificity_class1 = recall_class0
    avg_sensitivity = (sensitivity_class0 + sensitivity_class1) / 2
    avg_specificity = (specificity_class0 + specificity_class1) / 2

    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division=0
    )
    accuracy = accuracy_score(y_true, y_pred)

    return {
        'accuracy': accuracy * 100,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'sensitivity_class0': sensitivity_class0,
        'sensitivity_class1': sensitivity_class1,
        'specificity_class0': specificity_class0,
        'specificity_class1': specificity_class1,
        'avg_sensitivity': avg_sensitivity,
        'avg_specificity': avg_specificity,
        'confusion_matrix': cm
    }

# ============================================================================
# MAIN
# ============================================================================

def _collect_image_records(data_path):
    """Return one ``{'image_path', 'label', 'class_id'}`` dict per image under ``data_path``.

    Every entry of ``CLASSES`` is looked up as a sub-directory of ``data_path``.
    A missing directory only prints a warning; the caller validates the result.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    records = []
    for class_id, class_name in enumerate(CLASSES):
        class_path = os.path.join(data_path, class_name)
        if not os.path.exists(class_path):
            print(f"‚ö†Ô∏è Warning: {class_path} not found!")
            continue
        # Sorted so the seeded splits do not depend on filesystem listing order.
        for img_file in sorted(os.listdir(class_path)):
            if img_file.lower().endswith(image_exts):
                records.append({
                    'image_path': os.path.join(class_path, img_file),
                    'label': class_name,
                    'class_id': class_id
                })
    return records


def _build_optimizer(model):
    """Create AdamW with a lower learning rate for the backbone than for the head."""
    groups = []
    vit_params = [p for p in model.vit.parameters() if p.requires_grad]
    if vit_params:
        groups.append({'params': vit_params, 'lr': 1e-5})
    clf_params = [p for p in model.classifier.parameters() if p.requires_grad]
    if clf_params:
        groups.append({'params': clf_params, 'lr': 1e-4})
    if not groups:
        groups = [{'params': [p for p in model.parameters() if p.requires_grad], 'lr': 1e-4}]
    return torch.optim.AdamW(groups, weight_decay=0.01)


def _print_metrics(metrics):
    """Print the metric dict produced by ``calculate_metrics`` as an aligned table.

    Per-class rows are labelled from ``CLASSES`` so that ``*_class0`` is shown
    under ``CLASSES[0]`` and ``*_class1`` under ``CLASSES[1]``.
    """
    rows = [
        ('Accuracy:',                     f"{metrics['accuracy']:.2f}%"),
        ('Precision:',                    f"{metrics['precision']:.4f}"),
        ('Recall:',                       f"{metrics['recall']:.4f}"),
        ('F1 Score:',                     f"{metrics['f1']:.4f}"),
        (f'Sensitivity ({CLASSES[0]}):',  f"{metrics['sensitivity_class0']:.4f}"),
        (f'Sensitivity ({CLASSES[1]}):',  f"{metrics['sensitivity_class1']:.4f}"),
        (f'Specificity ({CLASSES[0]}):',  f"{metrics['specificity_class0']:.4f}"),
        (f'Specificity ({CLASSES[1]}):',  f"{metrics['specificity_class1']:.4f}"),
        ('Avg Sensitivity:',              f"{metrics['avg_sensitivity']:.4f}"),
        ('Avg Specificity:',              f"{metrics['avg_specificity']:.4f}"),
    ]
    for label, value in rows:
        print(f"  {label:<24}{value}")


def _count_percent_annotations(cm):
    """Build ``'count\\n(row-percent%)'`` cell labels for a confusion-matrix heatmap."""
    row_totals = cm.sum(axis=1, keepdims=True)
    # A true class with no samples would divide by zero; report 0.0% instead of NaN.
    cm_percent = cm / np.maximum(row_totals, 1) * 100
    return np.array([
        [f'{cm[i, j]}\n({cm_percent[i, j]:.1f}%)' for j in range(cm.shape[1])]
        for i in range(cm.shape[0])
    ])


def _plot_fold1_diagnostics(fold_history, metrics):
    """Save the extra Fold-1 figures: learning curves, detailed confusion matrix, metric bars."""
    epochs_range = range(1, len(fold_history['train_loss']) + 1)

    # 1. Training / Validation curves
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    panels = [
        ('loss', 'Loss',         'Fold 1 - Loss (ViT Only)'),
        ('acc',  'Accuracy (%)', 'Fold 1 - Accuracy (ViT Only)'),
        ('f1',   'F1 Score',     'Fold 1 - F1 Score (ViT Only)'),
    ]
    for ax, (key, ylabel, title) in zip(axes, panels):
        ax.plot(epochs_range, fold_history[f'train_{key}'], 'b-o', label='Train', linewidth=2, markersize=4)
        ax.plot(epochs_range, fold_history[f'val_{key}'],   'r-s', label='Val',   linewidth=2, markersize=4)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('fold1_vitonly_training_curves.png', dpi=150)
    plt.close()
    print("  ‚úì Training curves saved")

    # 2. Detailed confusion matrix
    cm = metrics['confusion_matrix']
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=_count_percent_annotations(cm), fmt='', cmap='Blues',
                xticklabels=CLASSES, yticklabels=CLASSES,
                cbar_kws={'label': 'Count'}, ax=ax)
    ax.set_title(f'Fold 1 - Detailed Confusion Matrix (ViT Only)\nVal Acc: {metrics["accuracy"]:.2f}%',
                 fontsize=14, fontweight='bold')
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()
    plt.savefig('fold1_vitonly_confusion_detailed.png', dpi=150)
    plt.close()
    print("  ‚úì Detailed confusion matrix saved")

    # 3. Metrics summary bar chart
    fig, ax = plt.subplots(figsize=(12, 6))
    metric_names  = ['Accuracy', 'Precision', 'Recall', 'F1 Score', 'Sensitivity', 'Specificity']
    metric_values = [
        metrics['accuracy'],
        metrics['precision']        * 100,
        metrics['recall']           * 100,
        metrics['f1']               * 100,
        metrics['avg_sensitivity']  * 100,
        metrics['avg_specificity']  * 100
    ]
    colors = ['#3498db', '#2ecc71', '#e74c3c', '#f39c12', '#9b59b6', '#1abc9c']
    bars = ax.bar(metric_names, metric_values, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)
    for bar, value in zip(bars, metric_values):
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                f'{value:.2f}%', ha='center', va='bottom', fontsize=11, fontweight='bold')
    ax.set_ylabel('Score (%)', fontsize=12)
    ax.set_title('Fold 1 - Validation Metrics Summary (ViT Only)', fontsize=14, fontweight='bold')
    ax.set_ylim([0, 105])
    ax.grid(True, alpha=0.3, axis='y')
    plt.xticks(rotation=15)
    plt.tight_layout()
    plt.savefig('fold1_vitonly_metrics_summary.png', dpi=150)
    plt.close()
    print("  ‚úì Metrics summary saved")


def main(data_path='D:/training/archive/ICMR_datasets_ERBMAHE'):
    """Run the ViT-only ablation: stratified K-fold training plus a 10% holdout test.

    Steps: scan ``data_path`` for images, hold out 10% as a test set, train one
    model per fold on the remaining 90% with validation-based early stopping,
    then evaluate the checkpoint of the best-validating fold on the holdout set.
    Checkpoints, figures and a CSV summary are written to the working directory.

    Args:
        data_path: Directory containing one sub-directory per entry of ``CLASSES``.

    Raises:
        FileNotFoundError: No image was found under ``data_path``.
        ValueError: At least one class has no images.
    """
    seed = 42
    # Seed the RNGs behind weight init, shuffling and augmentation so the run is
    # as repeatable as the (already seeded) data splits.
    torch.manual_seed(seed)
    np.random.seed(seed)

    print("="*70)
    print("ViT Baseline (NO CNN-CBAM, NO QUANTUM)")
    print("ABLATION STUDY: Pure ViT + Focal Loss Baseline")
    print("ERBMAHE Dataset Classification")
    print("Binary Classification: Abnormal vs Normal")
    print(f"METHOD 2: K-FOLD ({N_FOLDS} folds) WITH SEPARATE TEST SET (10%)")
    print("="*70)

    print(f"\nü§ñ Base Model: {MODEL_NAME}")
    print(f"üö´ CNN-CBAM Branch: REMOVED")
    print(f"üö´ Quantum Layer: REMOVED")
    print(f"‚úÖ Pure ViT Baseline")
    print(f"üìä Classes: {CLASSES}")
    print(f"üîÑ K-Fold: {N_FOLDS} folds on 90% data")
    print(f"üß™ Test Set: 10% holdout")

    print("\nüìÅ Loading dataset...")
    data_list = _collect_image_records(data_path)
    if not data_list:
        # Without this guard the empty frame fails later with an opaque KeyError.
        raise FileNotFoundError(
            f"No images found under '{data_path}'. Expected one sub-folder per class: {CLASSES}"
        )

    df = pd.DataFrame(data_list)
    print(f"üìä Total images: {len(df)}")
    print("\nClass distribution:")
    print(df['label'].value_counts())

    found_labels = set(df['label'])
    missing_classes = [name for name in CLASSES if name not in found_labels]
    if missing_classes:
        raise ValueError(f"No images found for class(es): {missing_classes}")

    # Split: 90% for k-fold, 10% for final test
    kfold_df, test_df = train_test_split(df, test_size=0.10, stratify=df['class_id'], random_state=seed)

    print(f"\nüìä Dataset split:")
    print(f"  K-Fold data: {len(kfold_df)} ({len(kfold_df)/len(df)*100:.1f}%)")
    print(f"  Test data:   {len(test_df)} ({len(test_df)/len(df)*100:.1f}%)")

    # Class weights for focal loss (inverse frequency). Counted per class_id so the
    # weight at position i always belongs to CLASSES[i], whatever the label names are.
    class_counts = kfold_df['class_id'].value_counts().sort_index().values
    total_samples = len(kfold_df)
    class_weights = total_samples / (len(CLASSES) * class_counts)
    class_weights = torch.tensor(class_weights, dtype=torch.float32, device=device)
    print(f"\n‚öñÔ∏è Class weights for Focal Loss:")
    for i, class_name in enumerate(CLASSES):
        print(f"  {class_name}: {class_weights[i]:.4f}")

    print(f"\nüîß Loading ViT processor...")
    processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

    skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=seed)
    fold_results = []
    all_fold_histories = []

    batch_size = 16
    num_epochs = 50
    patience = 10

    print(f"\n‚öôÔ∏è Training configuration:")
    print(f"  Max epochs per fold: {num_epochs}")
    print(f"  Batch size: {batch_size}")
    print(f"  Early stopping patience: {patience} epochs")
    print(f"  Learning rate (ViT): 1e-5")
    print(f"  Learning rate (Classifier): 1e-4")
    print(f"  CNN-CBAM: REMOVED")
    print(f"  Quantum Layer: REMOVED")
    print(f"  Scheduler: CosineAnnealingLR")
    print(f"  Loss: Focal Loss (gamma=2.0)")
    print(f"  Strategy: Method 2 with separate test set")

    # ============================================================================
    # K-FOLD CROSS-VALIDATION
    # ============================================================================

    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(kfold_df, kfold_df['class_id'])):
        fold_no = fold_idx + 1
        ckpt_path = f'fold{fold_no}_vitonly_best.pth'

        print("\n" + "="*70)
        print(f"FOLD {fold_no}/{N_FOLDS}")
        print("="*70)

        train_df = kfold_df.iloc[train_idx].reset_index(drop=True)
        val_df   = kfold_df.iloc[val_idx].reset_index(drop=True)

        print(f"\nüìä Fold {fold_no} split:")
        print(f"  Train:      {len(train_df)} ({len(train_df)/len(kfold_df)*100:.1f}% of k-fold data)")
        print(f"  Validation: {len(val_df)} ({len(val_df)/len(kfold_df)*100:.1f}% of k-fold data)")

        train_dataset = ERBMAHEDataset(train_df, processor=processor, augment=True)
        val_dataset   = ERBMAHEDataset(val_df,   processor=processor, augment=False)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=0, pin_memory=True)
        val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

        print(f"\nü§ñ Creating ViT Baseline (no CNN-CBAM, no Quantum) for Fold {fold_no}...")
        model = ViTBaseline(model_name=MODEL_NAME, num_classes=len(CLASSES)).to(device)

        # Full fine-tuning: backbone and head are both trainable.
        for p in model.parameters():
            p.requires_grad = True

        if fold_idx == 0:
            total_params     = sum(p.numel() for p in model.parameters())
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            print(f"‚úì Total parameters:     {total_params:,}")
            print(f"‚úì Trainable parameters: {trainable_params:,}")
            print(f"‚úì CNN-CBAM parameters:  0 (REMOVED)")
            print(f"‚úì Quantum parameters:   0 (REMOVED)")
            print(f"‚úì Model size: {total_params * 4 / 1024 / 1024:.2f} MB")

        criterion      = FocalLoss(alpha=class_weights, gamma=2.0, reduction='mean')
        optimizer      = _build_optimizer(model)
        scheduler      = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
        early_stopping = EarlyStopping(patience=patience, min_delta=0.001, mode='max', verbose=True)

        # Start below any reachable accuracy so the first epoch always writes a
        # checkpoint; otherwise a 0% run would load a stale file from an earlier run.
        best_val_acc = float('-inf')
        fold_history = {
            'train_loss': [], 'train_acc': [], 'train_f1': [],
            'val_loss':   [], 'val_acc':   [], 'val_f1':   []
        }

        # Training loop
        for epoch in range(num_epochs):
            print(f"\nFold {fold_no} - Epoch {epoch+1}/{num_epochs}")
            print("-" * 70)

            train_loss, train_acc, train_f1 = train_epoch(model, train_loader, criterion, optimizer, device)
            val_loss,   val_acc,   val_f1, _, _ = validate(model, val_loader, criterion, device)

            scheduler.step()

            fold_history['train_loss'].append(train_loss)
            fold_history['train_acc'].append(train_acc)
            fold_history['train_f1'].append(train_f1)
            fold_history['val_loss'].append(val_loss)
            fold_history['val_acc'].append(val_acc)
            fold_history['val_f1'].append(val_f1)

            current_lr = optimizer.param_groups[0]['lr']
            print(f"\nResults:")
            print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%, F1: {train_f1:.4f}")
            print(f"  Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%, F1: {val_f1:.4f}")
            print(f"  Learning Rate (group0): {current_lr:.2e}")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), ckpt_path)
                print(f"  üíæ Best model saved! (Val Acc: {best_val_acc:.2f}%)")

            # Pass the 1-based epoch so the stopper's report matches the epoch
            # numbers printed above.
            if early_stopping(val_acc, epoch + 1):
                print(f"\n‚ö†Ô∏è Early stopping at epoch {epoch+1}")
                break

        # Final validation evaluation
        print(f"\nüìä EVALUATING FOLD {fold_no}")
        print("="*70)
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        val_loss, val_acc, val_f1, y_pred, y_true = validate(model, val_loader, criterion, device)

        metrics = calculate_metrics(y_true, y_pred)

        print(f"\nüìà Fold {fold_no} Validation Results:")
        _print_metrics(metrics)

        fold_results.append({
            'fold': fold_no,
            'val_acc':         metrics['accuracy'],
            'val_precision':   metrics['precision'],
            'val_recall':      metrics['recall'],
            'val_f1':          metrics['f1'],
            'val_sensitivity': metrics['avg_sensitivity'],
            'val_specificity': metrics['avg_specificity'],
            'best_val_acc':    best_val_acc,
            'epochs_trained':  len(fold_history['train_loss']),
            'model_path':      ckpt_path,
            'y_true': y_true,
            'y_pred': y_pred,
            'metrics': metrics
        })

        all_fold_histories.append(fold_history)

        # Per-fold confusion matrix
        cm = metrics['confusion_matrix']
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=CLASSES, yticklabels=CLASSES)
        plt.title(f'Fold {fold_no} - Confusion Matrix (ViT Only)\n(Val Acc: {metrics["accuracy"]:.2f}%)')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.savefig(f'fold{fold_no}_vitonly_confusion_matrix.png', dpi=150)
        plt.close()
        print(f"‚úì Confusion matrix saved")

        # Enhanced visualizations for Fold 1 only
        if fold_idx == 0:
            print(f"\nüìä Creating enhanced visualizations for Fold 1...")
            _plot_fold1_diagnostics(fold_history, metrics)

        # Release this fold's model and optimizer state before the next fold
        # (and the final test model) allocate GPU memory.
        del model, optimizer, scheduler
        torch.cuda.empty_cache()

    # ============================================================================
    # SUMMARY ACROSS FOLDS
    # ============================================================================

    print("\n" + "="*70)
    print("üìä K-FOLD CROSS-VALIDATION SUMMARY")
    print("="*70)

    accuracies    = [r['val_acc']         for r in fold_results]
    precisions    = [r['val_precision']   for r in fold_results]
    recalls       = [r['val_recall']      for r in fold_results]
    f1_scores     = [r['val_f1']          for r in fold_results]
    sensitivities = [r['val_sensitivity'] for r in fold_results]
    specificities = [r['val_specificity'] for r in fold_results]

    print(f"\nüìà Cross-Validation Results ({N_FOLDS} folds):")
    print(f"  Accuracy:     {np.mean(accuracies):.2f}%  ¬± {np.std(accuracies):.2f}%")
    print(f"  Precision:    {np.mean(precisions):.4f} ¬± {np.std(precisions):.4f}")
    print(f"  Recall:       {np.mean(recalls):.4f} ¬± {np.std(recalls):.4f}")
    print(f"  F1 Score:     {np.mean(f1_scores):.4f} ¬± {np.std(f1_scores):.4f}")
    print(f"  Sensitivity:  {np.mean(sensitivities):.4f} ¬± {np.std(sensitivities):.4f}")
    print(f"  Specificity:  {np.mean(specificities):.4f} ¬± {np.std(specificities):.4f}")

    best_fold_idx = np.argmax(accuracies)
    best_fold     = fold_results[best_fold_idx]
    print(f"\nüèÜ Best Fold: {best_fold['fold']} (Val Acc: {best_fold['val_acc']:.2f}%)")

    # ============================================================================
    # EVALUATE BEST MODEL ON HOLDOUT TEST SET
    # ============================================================================

    print("\n" + "="*70)
    print("üß™ EVALUATING BEST MODEL ON HOLDOUT TEST SET")
    print("="*70)

    print(f"\nüìä Loading best model from Fold {best_fold['fold']}...")
    best_model = ViTBaseline(model_name=MODEL_NAME, num_classes=len(CLASSES)).to(device)
    best_model.load_state_dict(torch.load(best_fold['model_path'], map_location=device))

    test_dataset = ERBMAHEDataset(test_df, processor=processor, augment=False)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

    criterion = FocalLoss(alpha=class_weights, gamma=2.0, reduction='mean')
    test_loss, test_acc, test_f1, test_pred, test_true = validate(best_model, test_loader, criterion, device)
    test_metrics = calculate_metrics(test_true, test_pred)

    print(f"\nüìà Test Set Results (Best Model ‚Äî Fold {best_fold['fold']}):")
    _print_metrics(test_metrics)

    # Test confusion matrix
    test_cm = test_metrics['confusion_matrix']
    plt.figure(figsize=(10, 8))
    sns.heatmap(test_cm, annot=_count_percent_annotations(test_cm), fmt='', cmap='Greens',
                xticklabels=CLASSES, yticklabels=CLASSES,
                cbar_kws={'label': 'Count'})
    plt.title(f'Test Set Confusion Matrix (ViT Only)\n'
              f'Best Model Fold {best_fold["fold"]} ‚Äî Acc: {test_metrics["accuracy"]:.2f}%',
              fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()
    plt.savefig('test_vitonly_confusion_matrix.png', dpi=150)
    plt.close()
    print("\n‚úì Test confusion matrix saved")

    # Save results to CSV
    results_df = pd.DataFrame(fold_results)
    results_df.to_csv('kfold_vitonly_results.csv', index=False)
    print("‚úì Results saved to CSV")

    print("\n" + "="*70)
    print("‚úÖ TRAINING COMPLETE")
    print("="*70)
    print(f"\nüìä Final Summary:")
    print(f"  Model: ViT Baseline (NO CNN-CBAM, NO QUANTUM)")
    print(f"  Ablation: Pure ViT + Focal Loss")
    print(f"  Dataset: ERBMAHE (Abnormal vs Normal)")
    print(f"  Strategy: Method 2 K-Fold with 10% Test Set")
    print(f"  K-Fold CV Accuracy: {np.mean(accuracies):.2f}% ¬± {np.std(accuracies):.2f}%")
    print(f"  Test Set Accuracy:  {test_metrics['accuracy']:.2f}%")
    print(f"  Test Set F1 Score:  {test_metrics['f1']:.4f}")
    print("\n" + "="*70)

# Run the full pipeline only when this file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()

‚úì PyTorch & torchvision imported successfully
‚úì Transformers imported successfully
üîß Using device: cuda
üîß PyTorch version: 2.5.1+cu121
ViT Baseline (NO CNN-CBAM, NO QUANTUM)
ABLATION STUDY: Pure ViT + Focal Loss Baseline
ERBMAHE Dataset Classification
Binary Classification: Abnormal vs Normal
METHOD 2: K-FOLD (5 folds) WITH SEPARATE TEST SET (10%)

ü§ñ Base Model: google/vit-base-patch16-224
üö´ CNN-CBAM Branch: REMOVED
üö´ Quantum Layer: REMOVED
‚úÖ Pure ViT Baseline
üìä Classes: ['Abnormal', 'Normal']
üîÑ K-Fold: 5 folds on 90% data
üß™ Test Set: 10% holdout

üìÅ Loading dataset...


KeyboardInterrupt: 