In [None]:
# SETUP & MOUNT DRIVE

from google.colab import files, drive
import zipfile
import os

# Create directories
!mkdir -p /content/datasets

print(" Directories created!")
print("Now upload your 2 ZIP files manually...")

 Directories created!
Now upload your 2 ZIP files manually...


In [1]:
import os
import random
import numpy as np
import torch

def set_seed(seed=42):
    """Sets the seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  #
    os.environ['PYTHONHASHSEED'] = str(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [None]:
import torch
torch.multiprocessing.set_start_method("spawn", force=True)

In [None]:
from tqdm import tqdm

def tqdm_colab(iterable, **kwargs):
    return tqdm(iterable, leave=False, dynamic_ncols=True)

In [None]:
# EXTRACT DATASETS

import zipfile
from pathlib import Path

# Find uploaded files
uploaded_files = list(Path('/content').glob('*.zip'))
print(f"Found {len(uploaded_files)} ZIP files:")
for f in uploaded_files:
    print(f"  - {f.name}")

# Extract patient-level split
patient_zip = [f for f in uploaded_files if 'Patient_Level' in f.name][0]
print(f"\n Extracting {patient_zip.name}")
with zipfile.ZipFile(patient_zip, 'r') as zip_ref:
    zip_ref.extractall('/content/datasets/patient_split')

# Extract image-level split
image_zip = [f for f in uploaded_files if 'ImageLevel' in f.name][0]
print(f" Extracting {image_zip.name}")
with zipfile.ZipFile(image_zip, 'r') as zip_ref:
    zip_ref.extractall('/content/datasets/normal_split')

print("\n Extraction complete!")

# Verify structure
print("\n Dataset structure:")
!ls -lh /content/datasets/patient_split
!ls -lh /content/datasets/normal_split

Found 2 ZIP files:
  - Figshare_Dataset_Patient_Level_Split.zip
  - Figshare_ImageLevel_Split_PNGs.zip

 Extracting Figshare_Dataset_Patient_Level_Split.zip
 Extracting Figshare_ImageLevel_Split_PNGs.zip

 Extraction complete!

 Dataset structure:
total 4.0K
drwxr-xr-x 5 root root 4.0K Jan 26 12:03 Figshare_Dataset
total 4.0K
drwxr-xr-x 5 root root 4.0K Jan 26 12:03 Figshare_ImageLevel


In [None]:
# INSTALL LIBRARIES

!pip install -q timm torchmetrics pillow matplotlib seaborn scikit-learn pandas tqdm

print(" All libraries installed!")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/983.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m58.2 MB/s[0m eta [36m0:00:00[0m
[?25h All libraries installed!


In [None]:
# IMPORTS & CONFIG

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import timm
import warnings
warnings.filterwarnings('ignore')

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

# CONFIGURATION
class Config:
    # Dataset paths (will auto-detect structure)
    PATIENT_SPLIT_ROOT = '/content/datasets/patient_split'
    NORMAL_SPLIT_ROOT = '/content/datasets/normal_split'

    # Classes
    CLASSES = ['glioma', 'meningioma', 'pituitary']
    NUM_CLASSES = 3

    # Training hyperparameters
    IMG_SIZE = 224
    BATCH_SIZE = 32
    EPOCHS = 10
    LR = 1e-4
    WEIGHT_DECAY = 1e-4

    # Device
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Experiment tracking
    RESULTS_DIR = '/content/results'

config = Config()

# Create results directory
Path(config.RESULTS_DIR).mkdir(exist_ok=True)

print(f" Configuration:")
print(f"  Device: {config.DEVICE}")
print(f"  Batch Size: {config.BATCH_SIZE}")
print(f"  Epochs: {config.EPOCHS}")
print(f"  Image Size: {config.IMG_SIZE}x{config.IMG_SIZE}")
print(f"  Classes: {config.CLASSES}")

 Configuration:
  Device: cuda
  Batch Size: 32
  Epochs: 10
  Image Size: 224x224
  Classes: ['glioma', 'meningioma', 'pituitary']


In [None]:
# DATASET CLASS (FIXED FOR 3-WAY SPLIT)

class BrainTumorDataset(Dataset):
    """
    Flexible dataset loader that auto-detects folder structure
    Supports Training/Validation/Testing splits
    """
    def __init__(self, root_dir, split='Training', transform=None):
        self.transform = transform
        self.images = []
        self.labels = []

        # Auto-detect structure
        root_path = Path(root_dir)

        # Possible split names
        split_aliases = {
            'Training': ['Training', 'train', 'Train'],
            'Validation': ['Validation', 'val', 'Val', 'valid', 'Valid'],
            'Testing': ['Testing', 'test', 'Test']
        }

        # Find the correct split folder
        base_path = None

        # Check direct path
        for alias in split_aliases.get(split, [split]):
            if (root_path / alias).exists():
                base_path = root_path / alias
                break

        # Check inside Figshare_Dataset folder
        if base_path is None:
            for alias in split_aliases.get(split, [split]):
                if (root_path / 'Figshare_Dataset' / alias).exists():
                    base_path = root_path / 'Figshare_Dataset' / alias
                    break

        # Search recursively
        if base_path is None:
            for alias in split_aliases.get(split, [split]):
                possible_paths = list(root_path.rglob(alias))
                if possible_paths:
                    base_path = possible_paths[0]
                    break

        if base_path is None:
            raise FileNotFoundError(
                f"Cannot find {split} folder in {root_dir}\n"
                f"Available folders: {[str(p) for p in root_path.rglob('*') if p.is_dir()]}"
            )

        print(f"Loading {split} from: {base_path}")

        # Load images
        for idx, class_name in enumerate(config.CLASSES):
            class_dir = base_path / class_name
            if not class_dir.exists():
                print(f"  Warning: {class_dir} not found, skipping...")
                continue

            # Support both .jpg and .png
            img_files = list(class_dir.glob('*.jpg')) + list(class_dir.glob('*.png'))

            for img_path in img_files:
                self.images.append(str(img_path))
                self.labels.append(idx)

            print(f"  {class_name}: {len(img_files)} images")

        print(f"Total: {len(self.images)} images\n")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = Image.open(self.images[idx]).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            img = self.transform(img)

        return img, label

# Test dataset loading
print(" Testing dataset loading...")
try:
    train_ds = BrainTumorDataset(config.NORMAL_SPLIT_ROOT, 'Training')
    val_ds = BrainTumorDataset(config.NORMAL_SPLIT_ROOT, 'Validation')
    test_ds = BrainTumorDataset(config.NORMAL_SPLIT_ROOT, 'Testing')
    print(f" Normal split loaded successfully!")
    print(f"   Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")
except Exception as e:
    print(f" Error: {e}")

 Testing dataset loading...
Loading Training from: /content/datasets/normal_split/Figshare_ImageLevel/Training
  glioma: 467 images
  meningioma: 1021 images
  pituitary: 656 images
Total: 2144 images

Loading Validation from: /content/datasets/normal_split/Figshare_ImageLevel/Validation
  glioma: 131 images
  meningioma: 207 images
  pituitary: 121 images
Total: 459 images

Loading Testing from: /content/datasets/normal_split/Figshare_ImageLevel/Testing
  glioma: 110 images
  meningioma: 198 images
  pituitary: 153 images
Total: 461 images

 Normal split loaded successfully!
   Train: 2144, Val: 459, Test: 461


In [None]:
# CREATE DATALOADERS

from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_ds,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)

val_loader = DataLoader(
    val_ds,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

test_loader = DataLoader(
    test_ds,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

print(" DataLoaders created successfully")

 DataLoaders created successfully


In [None]:
# DATA AUGMENTATION

train_transform = transforms.Compose([
    transforms.Resize((config.IMG_SIZE, config.IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize((config.IMG_SIZE, config.IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

print(" Transforms defined!")

 Transforms defined!


In [None]:
# ATTENTION MECHANISMS
class ChannelAttention(nn.Module):
    """Channel Attention Module (from CBAM)"""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()

        # Channel attention
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))

        channel_att = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        return x * channel_att

class SpatialAttention(nn.Module):
    """Spatial Attention Module (from CBAM)"""
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Spatial attention
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)

        spatial_att = torch.cat([avg_out, max_out], dim=1)
        spatial_att = self.sigmoid(self.conv(spatial_att))

        return x * spatial_att

class CBAM(nn.Module):
    """Convolutional Block Attention Module"""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.channel_att = ChannelAttention(channels, reduction)
        self.spatial_att = SpatialAttention()

    def forward(self, x):
        x = self.channel_att(x)
        x = self.spatial_att(x)
        return x

print(" Attention modules defined!")

 Attention modules defined!


In [6]:
# MODEL ARCHITECTURES

import torch
import torch.nn as nn
import timm
from torchvision import models

# -------------------------------------------------------------------------
# ResNet-50 Baseline
# -------------------------------------------------------------------------
class ResNet50Classifier(nn.Module):
    def __init__(self, num_classes=3, pretrained=True):
        super().__init__()
        self.model = models.resnet50(pretrained=pretrained)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        return self.model(x)


# -------------------------------------------------------------------------
# DenseNet-121
# -------------------------------------------------------------------------
class DenseNet121Classifier(nn.Module):
    def __init__(self, num_classes=3, pretrained=True):
        super().__init__()
        self.model = models.densenet121(pretrained=pretrained)
        self.model.classifier = nn.Linear(
            self.model.classifier.in_features, num_classes
        )

    def forward(self, x):
        return self.model(x)


# -------------------------------------------------------------------------
# EfficientNet-B0
# -------------------------------------------------------------------------
class EfficientNetB0Classifier(nn.Module):
    def __init__(self, num_classes=3, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            'efficientnet_b0',
            pretrained=pretrained,
            num_classes=num_classes
        )

    def forward(self, x):
        return self.model(x)


# -------------------------------------------------------------------------
# Attention-Enhanced Swin (CBAM + SAFE FEATURE HANDLING)
# -------------------------------------------------------------------------
class AttentionEnhancedSwin(nn.Module):
    """
    Swin Transformer + CBAM + custom classifier
    SAFE across timm versions
    """
    def __init__(self, num_classes=3, pretrained=True):
        super().__init__()

        # Swin backbone WITHOUT classifier or pooling
        self.swin_backbone = timm.create_model(
            'swin_tiny_patch4_window7_224',
            pretrained=pretrained,
            num_classes=0,
            global_pool=''
        )

        self.feature_dim = self.swin_backbone.num_features  # 768

        # CBAM attention
        self.cbam = CBAM(self.feature_dim, reduction=16)

        # Global pooling
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),

            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        # -------- SAFE feature extraction --------
        features = self.swin_backbone.forward_features(x)

        # Case 1: [B, N, C] → reshape to [B, C, H, W]
        if features.dim() == 3:
            B, N, C = features.shape
            H = W = int(N ** 0.5)
            features = features.transpose(1, 2).contiguous().view(B, C, H, W)

        # Case 2: [B, H, W, C] → NCHW
        elif features.dim() == 4 and features.shape[-1] == self.feature_dim:
            features = features.permute(0, 3, 1, 2).contiguous()

        # -------- Attention + pooling --------
        features = self.cbam(features)
        features = self.global_pool(features)
        features = features.flatten(1)

        # -------- Classification --------
        return self.classifier(features)


print(" All 4 models defined successfully!")
print(" Models:")
print("  1. ResNet-50")
print("  2. DenseNet-121")
print("  3. EfficientNet-B0")
print("  5. Attention-Enhanced Swin (CBAM)")


 All 4 models defined successfully!
 Models:
  1. ResNet-50
  2. DenseNet-121
  3. EfficientNet-B0
  5. Attention-Enhanced Swin (CBAM)


In [None]:
from tqdm import tqdm as tqdm_bar
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# TRAIN ONE EPOCH
def train_one_epoch(model, loader, criterion, optimizer, device, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm_bar(
        loader,
        desc=f"Epoch {epoch+1} [Train]",
        ascii=True,
        leave=True,
        mininterval=1.0
    )

    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_postfix_str(
            f"loss={running_loss/total:.4f}, acc={100.*correct/total:.2f}%"
        )

    return running_loss / total, 100. * correct / total

# EVALUATION (WITH TQDM)
def evaluate(model, loader, criterion, device, desc="Eval"):
    model.eval()
    running_loss = 0.0
    all_preds, all_labels = [], []

    pbar = tqdm_bar(
        loader,
        desc=desc,
        ascii=True,
        leave=False,
        mininterval=1.0
    )

    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            pbar.set_postfix_str(
                f"loss={running_loss/len(all_labels):.4f}"
            )

    acc = 100. * accuracy_score(all_labels, all_preds)
    return running_loss / len(all_labels), acc, np.array(all_preds), np.array(all_labels)

In [None]:
def plot_confusion_matrix(cm, class_names, model_name, save_path):
    """Plot and save confusion matrix"""
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Count'})
    plt.title(f'{model_name}\nConfusion Matrix', fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_training_history(history, model_name, save_path):
    """Plot training history"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Loss plot
    ax1.plot(history['train_loss'], label='Train Loss', marker='o')
    ax1.plot(history['test_loss'], label='Test Loss', marker='s')
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title(f'{model_name} - Loss Curves', fontsize=13, fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Accuracy plot
    ax2.plot(history['train_acc'], label='Train Accuracy', marker='o')
    ax2.plot(history['test_acc'], label='Test Accuracy', marker='s')
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Accuracy (%)', fontsize=12)
    ax2.set_title(f'{model_name} - Accuracy Curves', fontsize=13, fontweight='bold')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

print(" Training functions ready!")

 Training functions ready!


In [None]:
def train_model(model_name, model, train_loader, val_loader, test_loader,
                epochs, device, save_dir):

    print(f"\n{'='*70}")
    print(f"TRAINING: {model_name}")
    print(f"{'='*70}")

    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.LR,
        weight_decay=config.WEIGHT_DECAY
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    history = {
        "train_loss": [], "train_acc": [],
        "val_loss": [], "val_acc": [],
        "test_loss": [], "test_acc": []
    }

    best_val_acc = 0.0
    patience = 5
    wait = 0

    for epoch in range(epochs):

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch
        )

        val_loss, val_acc, _, _ = evaluate(
            model, val_loader, criterion, device, f"Epoch {epoch+1} [Val]"
        )

        test_loss, test_acc, _, _ = evaluate(
            model, test_loader, criterion, device, f"Epoch {epoch+1} [Test]"
        )

        scheduler.step()

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        history["test_loss"].append(test_loss)
        history["test_acc"].append(test_acc)

        print(
            f"\nEpoch {epoch+1}/{epochs} | "
            f"Train: {train_acc:.2f}% | "
            f"Val: {val_acc:.2f}% | "
            f"Test: {test_acc:.2f}%"
        )

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            wait = 0
            torch.save(
                model.state_dict(),
                save_dir / f"{model_name}_best.pth"
            )
            print( "Best model saved")
        else:
            wait += 1
            if wait >= patience:
                print(" Early stopping triggered")
                break

    model.load_state_dict(torch.load(save_dir / f"{model_name}_best.pth"))

    _, final_acc, preds, labels = evaluate(
        model, test_loader, criterion, device, "Final Test"
    )

    print("\nFinal Test Accuracy:", final_acc)

    return {
        "history": history,
        "best_val_acc": best_val_acc,
        "final_test_acc": final_acc
    }

In [4]:
# RUN ALL MODELS

def run_experiment(split_name, dataset_root):
    """Run complete experiment with Train/Val/Test split"""

    print(f"\n{'#'*80}")
    print(f"{'#'*80}")
    print(f"###  EXPERIMENT: {split_name.upper()}")
    print(f"{'#'*80}")
    print(f"{'#'*80}\n")

    # Create save directory
    save_dir = Path(config.RESULTS_DIR) / split_name
    save_dir.mkdir(exist_ok=True, parents=True)

    # Load datasets (3-way split)
    print(" Loading datasets...")
    train_dataset = BrainTumorDataset(dataset_root, 'Training', train_transform)
    val_dataset = BrainTumorDataset(dataset_root, 'Validation', test_transform)  # No augmentation for val
    test_dataset = BrainTumorDataset(dataset_root, 'Testing', test_transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )

    print(f" Train: {len(train_dataset)} samples")
    print(f" Val:   {len(val_dataset)} samples")
    print(f" Test:  {len(test_dataset)} samples\n")

    # Define models
    models_dict = {
        'ResNet-50': ResNet50Classifier(config.NUM_CLASSES),
        'DenseNet-121': DenseNet121Classifier(config.NUM_CLASSES),
        'EfficientNet-B0': EfficientNetB0Classifier(config.NUM_CLASSES),
        'Attention-Enhanced-Swin': AttentionEnhancedSwin(config.NUM_CLASSES)
    }

    # Train all models
    all_results = {}

    for model_name, model in models_dict.items():
        results = train_model(
            model_name, model,
            train_loader, val_loader, test_loader,  # ← Now 3 loaders
            config.EPOCHS, config.DEVICE,
            save_dir
        )
        all_results[model_name] = results

        # Save intermediate results
        torch.save(all_results, save_dir / 'all_results.pth')

    # Create results summary
    summary_df = pd.DataFrame({
        'Model': list(all_results.keys()),
        'Best Val Acc (%)': [all_results[m]['best_val_acc'] for m in all_results.keys()],
        'Final Test Acc (%)': [all_results[m]['final_test_acc'] for m in all_results.keys()]
    })

    summary_df = summary_df.sort_values('Final Test Acc (%)', ascending=False)

    print(f"\n{'='*80}")
    print(f" FINAL RESULTS - {split_name.upper()}")
    print(f"{'='*80}\n")
    print(summary_df.to_string(index=False))
    print(f"\n{'='*80}\n")

    # Save results
    summary_df.to_csv(save_dir / 'results_summary.csv', index=False)

    # Val accuracy comparison
    ax1.bar(summary_df['Model'], summary_df['Best Val Acc (%)'],
            color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'])
    ax1.set_xlabel('Model', fontweight='bold', fontsize=11)
    ax1.set_ylabel('Validation Accuracy (%)', fontweight='bold', fontsize=11)
    ax1.set_title(f'Best Validation Accuracy - {split_name.upper()}',
                  fontsize=13, fontweight='bold')
    ax1.tick_params(axis='x', rotation=45)
    ax1.set_ylim([85, 100])
    ax1.grid(axis='y', alpha=0.3)

    for i, (model, acc) in enumerate(zip(summary_df['Model'], summary_df['Best Val Acc (%)'])):
        ax1.text(i, acc + 0.5, f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold', fontsize=9)

    # Test accuracy comparison
    bars = ax2.bar(summary_df['Model'], summary_df['Final Test Acc (%)'],
                   color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'])
    ax2.set_xlabel('Model', fontweight='bold', fontsize=11)
    ax2.set_ylabel('Test Accuracy (%)', fontweight='bold', fontsize=11)
    ax2.set_title(f'Final Test Accuracy - {split_name.upper()}',
                  fontsize=13, fontweight='bold')
    ax2.tick_params(axis='x', rotation=45)
    ax2.set_ylim([85, 100])
    ax2.grid(axis='y', alpha=0.3)

    for i, (model, acc) in enumerate(zip(summary_df['Model'], summary_df['Final Test Acc (%)'])):
        ax2.text(i, acc + 0.5, f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold', fontsize=9)

    plt.tight_layout()
    plt.savefig(save_dir / 'model_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

    return all_results, summary_df

# RUN NORMAL SPLIT EXPERIMENT

print(" Starting NORMAL SPLIT experiment...")
normal_results, normal_summary = run_experiment('normal_split', config.NORMAL_SPLIT_ROOT)

Starting NORMAL SPLIT experiment...

################################################################################
################################################################################
###  EXPERIMENT: NORMAL_SPLIT
################################################################################
################################################################################

Loading datasets...
Loading Training from: /content/datasets/normal_split/Figshare_ImageLevel/Training
 glioma: 467 images
 meningioma: 1021 images
 pituitary: 656 images
Total: 2144 images

Loading Validation from: /content/datasets/normal_split/Figshare_ImageLevel/Validation
 glioma: 131 images
 meningioma: 207 images
 pituitary: 121 images
Total: 459 images

Loading Testing from: /content/datasets/normal_split/Figshare_ImageLevel/Testing
 glioma: 110 images
 meningioma: 198 images
 pituitary: 153 images
Total: 461 images

Train: 2144 samples
Val:   459 samples
Test:  461 samples


TRAINING: ResNet-

In [8]:
# RUN PATIENT-LEVEL SPLIT EXPERIMENT

print("\n\n Starting PATIENT-LEVEL SPLIT experiment...")
patient_results, patient_summary = run_experiment('patient_split', config.PATIENT_SPLIT_ROOT)

 Starting PATIENT-LEVEL SPLIT experiment...

################################################################################
################################################################################
###  EXPERIMENT: PATIENT_SPLIT
################################################################################
################################################################################

 Loading datasets...
Loading Training from: /content/datasets/patient_split/Figshare_Dataset/Training
  glioma: 1002 images
  meningioma: 509 images
  pituitary: 648 images
Total: 2159 images

Loading Validation from: /content/datasets/patient_split/Figshare_Dataset/Validation
  glioma: 223 images
  meningioma: 80 images
  pituitary: 130 images
Total: 433 images

Loading Testing from: /content/datasets/patient_split/Figshare_Dataset/Testing
  glioma: 201 images
  meningioma: 119 images
  pituitary: 152 images
Total: 472 images

 Train: 2159 samples
 Val:   433 samples
 Test:  472 samples


TR

In [None]:
import os
import glob
import torch
import torch.nn as nn
import timm
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from google.colab import drive
from datetime import datetime

# 1. Mount Drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# =============================================================================
# REDEFINE THE MODEL CLASS (To fix the Attribute Error)
# =============================================================================
class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        b, c, _, _ = x.size()
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        out = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        return x * out

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.sigmoid(self.conv(out))
        return x * out

class CBAM(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.ca = ChannelAttention(channels)
        self.sa = SpatialAttention()
    def forward(self, x):
        x = self.ca(x)
        x = self.sa(x)
        return x

class AttentionEnhancedSwin(nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()
        # FORCE name to 'swin' to match the GradCAM code
        self.swin = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True, num_classes=0)
        self.feature_dim = self.swin.num_features
        self.cbam = CBAM(self.feature_dim)
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
    def forward(self, x):
        features = self.swin(x)
        features_4d = features.unsqueeze(-1).unsqueeze(-1)
        features_attn = self.cbam(features_4d)
        features_attn = features_attn.squeeze(-1).squeeze(-1)
        return self.classifier(features_attn)

# =============================================================================
# SMART LOAD & GRAD-CAM
# =============================================================================
print(" Locating model file...")
# Find ANY file ending in .pth that looks like our model
possible_files = glob.glob('**/*Attention*best.pth', recursive=True)

if not possible_files:
    print(" Could not find 'Attention' model. Trying ANY .pth file...")
    possible_files = glob.glob('**/*.pth', recursive=True)

if possible_files:
    weights_path = possible_files[0] # Pick the first one found
    print(f" Found weights: {weights_path}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AttentionEnhancedSwin(num_classes=3).to(device)

    try:
        model.load_state_dict(torch.load(weights_path, map_location=device))
        print(" Weights loaded successfully.")

        # --- Run GradCAM ---
        # Target the last LayerNorm of the Swin backbone
        target_layers = [model.swin.norm]
        cam = GradCAM(model=model, target_layers=target_layers)

        # Find images
        test_images = glob.glob('**/*.jpg', recursive=True)
        if len(test_images) > 0:
            # Pick a random image
            img_path = np.random.choice(test_images)

            rgb_img = cv2.imread(img_path, 1)[:, :, ::-1]
            rgb_img = cv2.resize(rgb_img, (224, 224))
            rgb_img_float = np.float32(rgb_img) / 255
            input_tensor = preprocess_image(rgb_img_float, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

            grayscale_cam = cam(input_tensor=input_tensor.to(device))[0, :]
            visualization = show_cam_on_image(rgb_img_float, grayscale_cam, use_rgb=True)

            # Save
            cv2.imwrite("Final_GradCAM.png", cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR))

            plt.figure(figsize=(10, 5))
            plt.subplot(1,2,1); plt.imshow(rgb_img); plt.title("Original")
            plt.subplot(1,2,2); plt.imshow(visualization); plt.title("Attention Map")
            plt.show()
            print(" Grad-CAM generated and saved as 'Final_GradCAM.png'")
        else:
            print(" No jpg images found to visualize.")

    except Exception as e:
        print(f" Error during visualization (skipping): {e}")
else:
    print(" Critical: No .pth model file found. Cannot generate GradCAM.")

# =============================================================================
# BACKUP EVERYTHING
# =============================================================================
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
dest_dir = f"/content/drive/MyDrive/GCON_FINAL_SUBMISSION_{timestamp}"
os.makedirs(dest_dir, exist_ok=True)

print(f"\n Backing up to: {dest_dir}")
!cp *.pth "{dest_dir}/" 2>/dev/null
!cp *.png "{dest_dir}/" 2>/dev/null
!cp *.csv "{dest_dir}/" 2>/dev/null
!zip -r "{dest_dir}/PROJECT_CODE_DATA.zip" . -i *.py *.ipynb *.pth *.csv *.png

print(f" DONE. You can close the tab.")

 Locating model file...
 Found weights: results/patient_split/Attention-Enhanced-Swin_best.pth
 Error during visualization (skipping): Error(s) in loading state_dict for AttentionEnhancedSwin:
	Missing key(s) in state_dict: "swin.patch_embed.proj.weight", "swin.patch_embed.proj.bias", "swin.patch_embed.norm.weight", "swin.patch_embed.norm.bias", "swin.layers.0.blocks.0.norm1.weight", "swin.layers.0.blocks.0.norm1.bias", "swin.layers.0.blocks.0.attn.relative_position_bias_table", "swin.layers.0.blocks.0.attn.qkv.weight", "swin.layers.0.blocks.0.attn.qkv.bias", "swin.layers.0.blocks.0.attn.proj.weight", "swin.layers.0.blocks.0.attn.proj.bias", "swin.layers.0.blocks.0.norm2.weight", "swin.layers.0.blocks.0.norm2.bias", "swin.layers.0.blocks.0.mlp.fc1.weight", "swin.layers.0.blocks.0.mlp.fc1.bias", "swin.layers.0.blocks.0.mlp.fc2.weight", "swin.layers.0.blocks.0.mlp.fc2.bias", "swin.layers.0.blocks.1.norm1.weight", "swin.layers.0.blocks.1.norm1.bias", "swin.layers.0.blocks.1.attn.relative_


KeyboardInterrupt



In [6]:
import os
import torch
import torch.nn as nn
import timm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


BATCH_SIZE = 32
NUM_EPOCHS = 10
NUM_CLASSES = 3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
save_dir = "/content/drive/MyDrive/GCON_FINAL_COMPARISON_RUN_V2"
os.makedirs(save_dir, exist_ok=True)

# Standard ImageNet Normalization
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def train_model_3split(model_name, split_name, dataset_path):
    print(f"\n{'='*60}")
    print(f" PROCESSING: {model_name} on {split_name} SPLIT")
    print(f"{'='*60}")

    # LOCATE FOLDERS (Handles your 3-folder structure)
    # We look for the folder that contains 'Training', 'Validation', 'Testing'
    target_root = ""
    for root, dirs, files in os.walk(dataset_path):
        if "Training" in dirs and "Validation" in dirs:
            target_root = root
            break

    if not target_root:
        # Fallback: Assume the path provided is the root
        target_root = dataset_path

    train_dir = os.path.join(target_root, "Training")
    val_dir   = os.path.join(target_root, "Validation")
    test_dir  = os.path.join(target_root, "Testing")

    # VERIFY PATHS
    if not os.path.exists(train_dir):
        print(f" ERROR: Could not find Training folder at {train_dir}")
        return 0.0

    print(f" Loading Data from: {target_root}")

    # LOAD DATA
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    val_data   = datasets.ImageFolder(val_dir, transform=transform)
    test_data  = datasets.ImageFolder(test_dir, transform=transform)

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
    test_loader  = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

    print(f" Stats: {len(train_data)} Train | {len(val_data)} Val | {len(test_data)} Test")

    # LOAD MODEL

    model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True, num_classes=NUM_CLASSES)

    model = model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

    best_val_acc = 0.0

    # TRAINING LOOP
    for epoch in range(NUM_EPOCHS):
        # 1. TRAIN
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # 2. VALIDATION
        model.eval()
        correct = 0; total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total
        print(f"   Epoch {epoch+1}: Val Acc: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), f"{save_dir}/{model_name}_{split_name}_best.pth")

    # 3. FINAL TEST
    print(f" Loading Best Model for FINAL TESTING...")
    model.load_state_dict(torch.load(f"{save_dir}/{model_name}_{split_name}_best.pth"))
    model.eval()

    correct = 0; total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    final_test_acc = 100 * correct / total
    print(f" FINAL TEST ACCURACY ({model_name}): {final_test_acc:.2f}%")
    return final_test_acc

# --- RUN EXPERIMENT A: NORMAL SPLIT ---
swin_normal = train_model_3split("Swin-Tiny", "NORMAL", "/content/dataset_normal")

# --- RUN EXPERIMENT B: PATIENT SPLIT ---
swin_patient = train_model_3split("Swin-Tiny", "PATIENT", "/content/dataset_patient")


 PROCESSING: Swin-Tiny on NORMAL SPLIT
 Loading Data from: /content/dataset_normal/Figshare_ImageLevel
 Stats: 2144 Train | 459 Val | 461 Test
   Epoch 1: Val Acc: 91.94%
   Epoch 2: Val Acc: 97.39%
   Epoch 3: Val Acc: 97.60%
   Epoch 4: Val Acc: 96.51%
   Epoch 5: Val Acc: 97.39%
   Epoch 6: Val Acc: 98.04%
   Epoch 7: Val Acc: 96.95%
   Epoch 8: Val Acc: 98.04%
   Epoch 9: Val Acc: 95.86%
   Epoch 10: Val Acc: 96.51%
🏁 Loading Best Model for FINAL TESTING...
 FINAL TEST ACCURACY (Swin-Tiny): 97.83%

 PROCESSING: Swin-Tiny on PATIENT SPLIT
 Loading Data from: /content/dataset_patient/Figshare_Dataset
 Stats: 2159 Train | 433 Val | 472 Test
   Epoch 1: Val Acc: 92.15%
   Epoch 2: Val Acc: 94.46%
   Epoch 3: Val Acc: 94.69%
   Epoch 4: Val Acc: 95.84%
   Epoch 5: Val Acc: 95.61%
   Epoch 6: Val Acc: 90.99%
   Epoch 7: Val Acc: 94.00%
   Epoch 8: Val Acc: 95.38%
   Epoch 9: Val Acc: 94.23%
   Epoch 10: Val Acc: 95.38%
🏁 Loading Best Model for FINAL TESTING...
 FINAL TEST ACCURACY (Swin

In [10]:
import os
import torch
import torch.nn as nn
import timm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

BATCH_SIZE = 32
NUM_EPOCHS = 10
NUM_CLASSES = 3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
save_dir = "/content/drive/MyDrive/GCON_FINAL_COMPARISON_RUN_V2"

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def train_vit_patient(dataset_path):
    print(f"\n{'='*60}")
    print(f"🚀 TRAINING: ViT-Small (Standard) on PATIENT SPLIT")
    print(f"{'='*60}")

    # Locate Folders
    target_root = ""
    for root, dirs, files in os.walk(dataset_path):
        if "Training" in dirs and "Validation" in dirs:
            target_root = root
            break
    if not target_root: target_root = dataset_path

    train_dir = os.path.join(target_root, "Training")
    val_dir   = os.path.join(target_root, "Validation")
    test_dir  = os.path.join(target_root, "Testing")

    # Load Data
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    val_data   = datasets.ImageFolder(val_dir, transform=transform)
    test_data  = datasets.ImageFolder(test_dir, transform=transform)

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
    test_loader  = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

    print(f" Stats: {len(train_data)} Train | {len(val_data)} Val | {len(test_data)} Test")

    # Load ViT Model
    # We use the standard 'vit_small_patch16_224'
    model = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=NUM_CLASSES)
    model = model.to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    # Standard AdamW
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

    best_val_acc = 0.0

    # Train Loop
    for epoch in range(NUM_EPOCHS):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        correct = 0; total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total
        print(f"   Epoch {epoch+1}: Val Acc: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), f"{save_dir}/ViT_Small_PATIENT_best.pth")

    # Final Test
    print(f"🏁 Final Testing (Best ViT Model)...")
    model.load_state_dict(torch.load(f"{save_dir}/ViT_Small_PATIENT_best.pth"))
    model.eval()

    correct = 0; total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    final_acc = 100 * correct / total
    print(f"FINAL ViT ACCURACY (Patient Split): {final_acc:.2f}%")
    return final_acc

# RUN IT
vit_score = train_vit_patient("/content/dataset_patient")


🚀 TRAINING: ViT-Small (Standard) on PATIENT SPLIT
 Stats: 2159 Train | 433 Val | 472 Test
   Epoch 1: Val Acc: 89.38%
   Epoch 2: Val Acc: 92.61%
   Epoch 3: Val Acc: 95.84%
   Epoch 4: Val Acc: 93.07%
   Epoch 5: Val Acc: 94.69%
   Epoch 6: Val Acc: 92.38%
   Epoch 7: Val Acc: 90.76%
   Epoch 8: Val Acc: 94.46%
   Epoch 9: Val Acc: 89.84%
   Epoch 10: Val Acc: 95.15%
🏁 Final Testing (Best ViT Model)...
FINAL ViT ACCURACY (Patient Split): 94.28%


In [11]:
# --- RUN ViT (Standard) on NORMAL SPLIT ---
# This gives you the "Baseline" score to compare against the "Hard" score.

# Re-using the same function, just changing the path and save name.
def train_vit_normal(dataset_path):
    print(f"\n{'='*60}")
    print(f" TRAINING: ViT-Small (Standard) on NORMAL SPLIT")
    print(f"{'='*60}")

    # Locate Folders
    target_root = ""
    for root, dirs, files in os.walk(dataset_path):
        if "Training" in dirs and "Validation" in dirs:
            target_root = root
            break
    if not target_root: target_root = dataset_path

    train_dir = os.path.join(target_root, "Training")
    val_dir   = os.path.join(target_root, "Validation")
    test_dir  = os.path.join(target_root, "Testing")

    # Load Data (Standard Transform)
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    val_data   = datasets.ImageFolder(val_dir, transform=transform)
    test_data  = datasets.ImageFolder(test_dir, transform=transform)

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
    test_loader  = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

    print(f" Stats: {len(train_data)} Train | {len(val_data)} Val | {len(test_data)} Test")

    # Load Model
    model = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=NUM_CLASSES)
    model = model.to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

    best_val_acc = 0.0

    for epoch in range(NUM_EPOCHS):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        correct = 0; total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total
        print(f"   Epoch {epoch+1}: Val Acc: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), f"{save_dir}/ViT_Small_NORMAL_best.pth")

    # Final Test
    print(f"🏁 Final Testing (Normal Split)...")
    model.load_state_dict(torch.load(f"{save_dir}/ViT_Small_NORMAL_best.pth"))
    model.eval()

    correct = 0; total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    final_acc = 100 * correct / total
    print(f" FINAL ViT ACCURACY (Normal Split): {final_acc:.2f}%")
    return final_acc

# RUN IT
vit_normal_score = train_vit_normal("/content/dataset_normal")


 TRAINING: ViT-Small (Standard) on NORMAL SPLIT
 Stats: 2144 Train | 459 Val | 461 Test
   Epoch 1: Val Acc: 84.75%
   Epoch 2: Val Acc: 96.08%
   Epoch 3: Val Acc: 96.73%
   Epoch 4: Val Acc: 96.95%
   Epoch 5: Val Acc: 96.95%
   Epoch 6: Val Acc: 95.86%
   Epoch 7: Val Acc: 96.73%
   Epoch 8: Val Acc: 95.42%
   Epoch 9: Val Acc: 93.03%
   Epoch 10: Val Acc: 91.94%
🏁 Final Testing (Normal Split)...
 FINAL ViT ACCURACY (Normal Split): 97.40%
