# fMRI Learning Stage Classification using Vision Transformers

This notebook trains a 3D Vision Transformer (ViT) to classify the stage of learning from fMRI BOLD data. The two stages are early (the first run of a task) and late (the last run of that task).

The data come from the "Classification learning" dataset from OpenfMRI (ds000002, now hosted on OpenNeuro). It is downloaded automatically below.

## Setup and Dependencies

In [None]:
!pip install einops nibabel seaborn tqdm

In [None]:
import os
import random
import re
import urllib.request
import zipfile

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from einops import rearrange, repeat
from google.colab import drive
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from tqdm.notebook import tqdm

## Mount Google Drive and Download Dataset

Mount Google Drive, where the dataset archive and its extracted files are stored.

In [None]:
# Colab-only (google.colab import): expose Google Drive under /content/drive,
# which is where `base_path` below points.
drive.mount('/content/drive')

Set up paths

In [None]:
base_path = '/content/drive/MyDrive/learnedSpectrum'
# Raw OpenNeuro archive, and the folder it gets unpacked into.
zip_path, extract_path = (
    os.path.join(base_path, "ds000002_R2.0.5_raw.zip"),
    os.path.join(base_path, "fmri_data"),
)

Download dataset if not already present

In [None]:
if not os.path.exists(zip_path):
    print("Downloading dataset...")
    url = "https://s3.amazonaws.com/openneuro/ds000002/ds000002_R2.0.5/compressed/ds000002_R2.0.5_raw.zip"
    # The Drive folder may not exist yet on a fresh run.
    os.makedirs(os.path.dirname(zip_path), exist_ok=True)
    # Download under a temporary name and rename only on success, so an
    # interrupted download is not mistaken for a complete archive next time.
    partial_path = zip_path + ".part"
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, zip_path)
    print("Download complete!")

Extract dataset if not already extracted

In [None]:
if not os.path.exists(extract_path):
    print("Extracting dataset...")
    # Unpack only once; an existing folder is taken to be a finished extraction.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(path=extract_path)
    print("Extraction complete!")

## Data Augmentation

In [None]:
class AdvancedFMRIAugmentation:
    """Random intensity/geometry augmentation for a channel-first fMRI volume.

    With probability ``p`` a call applies, in order:

    * additive Gaussian noise (std drawn uniformly from [0, 0.1));
    * an intensity scale drawn uniformly from [0.8, 1.2);
    * with prob. 0.5, a 90/180/270-degree rotation in the plane of dims (1, 2);
    * with prob. 0.5, a flip along one of dims 1-3;
    * with prob. 0.3, the extra noise step of ``_elastic_transform``.

    Dim 0 is never rotated or flipped. It is presumably the channel axis; the
    dataset in this notebook yields ``(1, 64, 64, 64)`` tensors.

    Args:
        p: Probability that any augmentation is applied to a given sample.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, x):
        # Accept numpy arrays as well as tensors.
        x = torch.as_tensor(x)

        if torch.rand(1) < self.p:
            # Gaussian noise with random intensity
            noise_level = torch.rand(1) * 0.1
            x = x + torch.randn_like(x) * noise_level

            # Random intensity scaling, factor in [0.8, 1.2)
            scale = 0.8 + torch.rand(1) * 0.4
            x = x * scale

            # Random rotations (90 degree increments) in the (1, 2) plane
            if torch.rand(1) < 0.5:
                x = torch.rot90(x, k=torch.randint(1, 4, (1,)).item(), dims=(1, 2))

            # Random flip along one of dims 1..3
            if torch.rand(1) < 0.5:
                x = torch.flip(x, dims=[torch.randint(1, 4, (1,)).item()])

            # Extra additive noise (see _elastic_transform: not a true deformation)
            if torch.rand(1) < 0.3:
                sigma = torch.rand(1) * 3
                x = self._elastic_transform(x, sigma=sigma.item())

        return x

    def _elastic_transform(self, x, sigma=3):
        """Add per-voxel Gaussian noise. Despite the name, no voxels are displaced.

        Averaging three independently perturbed copies of ``x`` equals
        ``x + N(0, sigma**2 / 3)`` per voxel. With ``sigma`` up to 3 (as drawn in
        ``__call__``) the noise std reaches about 1.7, which is large for
        z-scored data.

        TODO: replace with a real displacement-field warp (e.g. ``grid_sample``)
        if geometric elastic deformation is what was intended.
        """
        # randn_like keeps the noise on x's device/dtype; torch.randn(*shape)
        # always produces CPU float32.
        dx = torch.randn_like(x) * sigma
        dy = torch.randn_like(x) * sigma
        dz = torch.randn_like(x) * sigma

        x_new = x + dx
        y_new = x + dy
        z_new = x + dz

        return (x_new + y_new + z_new) / 3.0

In [None]:
class SelfAttentionPooling(nn.Module):
    """Collapse a token sequence into one vector using learned attention weights.

    A small MLP scores every token. The scores are softmax-normalised over
    dim 1, and the result is the weighted sum of the tokens along that dim.
    This assumes x is (batch, tokens, in_dim); TODO confirm with callers.

    Args:
        in_dim: Feature size per token. The scorer's hidden width is in_dim // 2.
    """

    def __init__(self, in_dim):
        super().__init__()
        hidden = in_dim // 2
        self.attention = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        token_weights = self.attention(x)
        return (token_weights * x).sum(dim=1)

## Dataset

In [None]:
class ImprovedfMRIDataset(Dataset):
    """Early- vs late-learning-stage dataset built from ds000002 BOLD runs.

    Labelling:

    * Each subject and task needs at least two ``*_bold.nii.gz`` runs;
      ``mixedeventrelatedprobe`` runs are ignored.
    * The first run, in natural run order, is labelled 0 ("early").
    * The last run is labelled 1 ("late").

    Splitting is deterministic and per class. The path-sorted samples of each
    class are cut into the first 70 % (train), the next 15 % (val) and the
    remainder (test).

    Each item is a ``(1, 64, 64, 64)`` float tensor plus a ``long`` label. The
    tensor is the mean over ``temporal_window`` evenly spaced, individually
    normalised timepoints when the run is 4-D.

    Args:
        root_dir: Directory searched recursively for the ``ds002_R2.0.5`` folder.
        transform: Optional callable applied to each volume tensor, only when
            ``phase == 'train'``.
        phase: ``'train'`` or ``'val'``; any other value selects the test split.
    """

    def __init__(self, root_dir, transform=None, phase='train'):
        self.root_dir = root_dir
        self.transform = transform
        self.phase = phase
        self.samples = []
        # Number of evenly spaced timepoints kept from each 4-D run.
        self.temporal_window = 8
        self._load_samples()

    @staticmethod
    def _natural_sort_key(filename):
        """Sort key that compares digit runs numerically, so ``run-2`` < ``run-10``."""
        return [int(token) if token.isdigit() else token
                for token in re.split(r'(\d+)', filename)]

    def _find_dataset_root(self):
        """Recursively find the directory containing subject folders.

        Returns:
            ``(path, sorted_subject_dirs)`` for the first ``ds002_R2.0.5``
            directory under ``root_dir`` that contains ``sub-*`` entries,
            otherwise ``(None, [])``.
        """
        print(f"\nSearching for dataset root starting from: {self.root_dir}")

        for root, dirs, _ in os.walk(self.root_dir):
            if 'ds002_R2.0.5' in dirs:
                ds_path = os.path.join(root, 'ds002_R2.0.5')
                sub_dirs = [d for d in os.listdir(ds_path) if d.startswith('sub-')]
                if sub_dirs:
                    print(f"Found subject directories in: {ds_path}")
                    return ds_path, sorted(sub_dirs)

        return None, []

    def _normalize_volume(self, volume):
        """Robust normalization for fMRI data.

        The mask is the non-zero voxels, computed per timepoint for 4-D input.
        Within it:

        1. Rescale the 1st-99th percentile range to [0, 1] and clip.
        2. Z-score the masked voxels.

        4-D input is modified in place. The normalized array is returned.
        """
        if volume.ndim == 4:
            # Normalize each timepoint independently
            for t in range(volume.shape[-1]):
                vol_t = volume[..., t]
                mask = vol_t != 0
                if mask.any():
                    p1, p99 = np.percentile(vol_t[mask], (1, 99))
                    volume[..., t] = np.clip((vol_t - p1) / (p99 - p1 + 1e-8), 0, 1)

                    # Z-score normalization within mask
                    mean = np.mean(volume[..., t][mask])
                    std = np.std(volume[..., t][mask])
                    volume[..., t][mask] = (volume[..., t][mask] - mean) / (std + 1e-8)
        else:
            # Single volume normalization
            mask = volume != 0
            if mask.any():
                p1, p99 = np.percentile(volume[mask], (1, 99))
                volume = np.clip((volume - p1) / (p99 - p1 + 1e-8), 0, 1)

                # Z-score normalization within mask
                mean = np.mean(volume[mask])
                std = np.std(volume[mask])
                volume[mask] = (volume[mask] - mean) / (std + 1e-8)

        return volume

    def _pad_or_crop(self, volume, target_shape):
        """Center-crop or zero-pad the first three axes of ``volume`` to ``target_shape``.

        Trailing axes (e.g. time) are left unchanged. If every axis is cropped,
        the result is a view of the input; otherwise it is a new array.
        """
        for i in range(3):
            if volume.shape[i] > target_shape[i]:
                # Center crop
                start = (volume.shape[i] - target_shape[i]) // 2
                end = start + target_shape[i]
                slices = [slice(None)] * volume.ndim
                slices[i] = slice(start, end)
                volume = volume[tuple(slices)]
            else:
                # Center pad (the extra voxel, if any, goes after)
                pad_before = (target_shape[i] - volume.shape[i]) // 2
                pad_after = target_shape[i] - volume.shape[i] - pad_before
                pad_width = [(0, 0)] * volume.ndim
                pad_width[i] = (pad_before, pad_after)
                volume = np.pad(volume, pad_width, mode='constant')

        return volume

    def _preprocess_volume(self, img_data):
        """Sample timepoints, fit to 64^3 and normalize; returns a float32 array.

        4-D input keeps ``temporal_window`` evenly spaced timepoints and is
        returned as ``(64, 64, 64, temporal_window)``. Other input is fitted to
        ``(64, 64, 64)``.
        """
        if len(img_data.shape) == 4:
            # Select evenly spaced timepoints
            total_timepoints = img_data.shape[-1]
            indices = np.linspace(0, total_timepoints-1, self.temporal_window, dtype=int)
            img_data = img_data[..., indices]

            # Ensure consistent spatial dimensions
            target_size = (64, 64, 64)
            temp_data = np.zeros((*target_size, len(indices)), dtype=np.float32)

            # Process each timepoint
            for t in range(len(indices)):
                vol = self._pad_or_crop(img_data[..., t], target_size)
                temp_data[..., t] = vol

            img_data = temp_data
        else:
            # Single volume case
            target_size = (64, 64, 64)
            img_data = self._pad_or_crop(img_data, target_size)

        img_data = self._normalize_volume(img_data)

        return img_data.astype(np.float32)

    def _load_samples(self):
        """Discover runs, label them early/late and keep this phase's split in ``self.samples``.

        Raises:
            ValueError: If no subject folder is found, no valid samples exist,
                or the requested split is empty.
        """
        print(f"\nLoading {self.phase} samples...")

        dataset_root, subjects = self._find_dataset_root()
        if not dataset_root:
            raise ValueError(f"Could not find subject directories in {self.root_dir}")

        print(f"Dataset root: {dataset_root}")
        print(f"Found {len(subjects)} subjects: {subjects}")

        all_samples = []

        # First, collect all valid samples
        for subject in subjects:
            func_path = os.path.join(dataset_root, subject, 'func')
            print(f"\nChecking subject {subject} func path: {func_path}")

            if not os.path.exists(func_path):
                print(f"No func directory found for subject {subject}")
                continue

            # Collect BOLD files by task name (the token after 'task-')
            subject_files = {}
            for file in os.listdir(func_path):
                if file.endswith('_bold.nii.gz') and 'task-' in file and 'mixedeventrelatedprobe' not in file:
                    task = file.split('task-')[1].split('_')[0]
                    if task not in subject_files:
                        subject_files[task] = []
                    subject_files[task].append(file)

            print(f"Found files by task: {subject_files.keys()}")

            subject_samples = []
            for task, files in subject_files.items():
                # Natural order: plain sorted() would place run-10 before run-2
                # and pick the wrong "last" run.
                sorted_files = sorted(files, key=self._natural_sort_key)
                if len(sorted_files) >= 2:
                    # First run is early stage
                    early_file = sorted_files[0]
                    file_path = os.path.join(func_path, early_file)
                    subject_samples.append((file_path, 0))
                    print(f"Added early stage sample: {early_file}")

                    # Last run is late stage
                    late_file = sorted_files[-1]
                    file_path = os.path.join(func_path, late_file)
                    subject_samples.append((file_path, 1))
                    print(f"Added late stage sample: {late_file}")

            if len(subject_samples) >= 2:
                all_samples.extend(subject_samples)
                print(f"Added {len(subject_samples)} samples from subject {subject}")
            else:
                print(f"Skipped subject {subject} - insufficient samples")

        print(f"\nTotal collected samples: {len(all_samples)}")

        if len(all_samples) == 0:
            raise ValueError("No valid samples found in the dataset!")

        # Sort for consistent splits
        all_samples.sort(key=lambda x: x[0])

        # Split each class separately so every split keeps both labels
        early_samples = [s for s in all_samples if s[1] == 0]
        late_samples = [s for s in all_samples if s[1] == 1]

        print(f"\nTotal early samples: {len(early_samples)}")
        print(f"Total late samples: {len(late_samples)}")

        if self.phase == 'train':
            early_split = early_samples[:int(len(early_samples) * 0.7)]
            late_split = late_samples[:int(len(late_samples) * 0.7)]
        elif self.phase == 'val':
            early_split = early_samples[int(len(early_samples) * 0.7):int(len(early_samples) * 0.85)]
            late_split = late_samples[int(len(late_samples) * 0.7):int(len(late_samples) * 0.85)]
        else:  # test
            early_split = early_samples[int(len(early_samples) * 0.85):]
            late_split = late_samples[int(len(late_samples) * 0.85):]

        self.samples = early_split + late_split

        early_count = sum(1 for s in self.samples if s[1] == 0)
        late_count = sum(1 for s in self.samples if s[1] == 1)

        print(f"\nFinal class distribution in {self.phase} set:")
        print(f"Early stage (0): {early_count} samples")
        print(f"Late stage (1): {late_count} samples")
        print(f"Total: {len(self.samples)} samples")

        if len(self.samples) == 0:
            raise ValueError(f"No samples found for {self.phase} set after splitting!")

    def __getitem__(self, idx):
        """Load, preprocess and (for training) augment one sample.

        Returns:
            ``(volume, label)`` with ``volume`` shaped ``(1, 64, 64, 64)``.

            If the file cannot be read or processed, the error is printed and a
            zero volume with label ``-1`` is returned instead. Callers must skip
            such batches: ``nn.CrossEntropyLoss`` with its default
            ``ignore_index=-100`` treats ``-1`` as an out-of-range target.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Resolve the sample before the try block. An out-of-range index then
        # raises IndexError, and `file_path` is always bound for the handler.
        file_path, stage = self.samples[idx]
        try:
            img_data = nib.load(file_path).get_fdata()

            img_data = self._preprocess_volume(img_data)
            img_tensor = torch.from_numpy(img_data).float()

            if len(img_tensor.shape) == 4:  # temporal dimension present
                # [H, W, D, T] -> [T, H, W, D] -> mean over T -> [1, H, W, D]
                img_tensor = img_tensor.permute(3, 0, 1, 2)
                img_tensor = img_tensor.mean(dim=0, keepdim=True)
            else:
                # Add channel dimension for 3D volume
                img_tensor = img_tensor.unsqueeze(0)

            if self.transform and self.phase == 'train':
                img_tensor = self.transform(img_tensor)

            return img_tensor, torch.tensor(stage, dtype=torch.long)

        except Exception as e:
            print(f"Error loading {file_path}: {str(e)}")
            return torch.zeros((1, 64, 64, 64)), torch.tensor(-1, dtype=torch.long)

    def __len__(self):
        return len(self.samples)

## Model Architecture

3D Patch Embedding (a standalone module: `ImprovedVisionTransformer3D` below uses its own convolutional stem instead)

In [None]:
class PatchEmbed3D(nn.Module):
    """Split a cubic volume into non-overlapping cubic patches and embed each one.

    A Conv3d with kernel = stride = ``patch_size`` acts as a per-patch linear
    projection. The resulting grid is flattened into a token sequence and
    layer-normalised.

    Args:
        img_size: Edge length of the cubic input volume.
        patch_size: Edge length of each cubic patch.
        in_channels: Number of input channels.
        embed_dim: Embedding size of each patch token.
    """

    def __init__(self, img_size=64, patch_size=8, in_channels=1, embed_dim=256):
        super().__init__()
        self.img_size = (img_size,) * 3
        self.patch_size = (patch_size,) * 3
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 3

        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # The unpacking doubles as a check that x is 5-D (batch, channels, H, W, D).
        _batch, _channels, _height, _width, _depth = x.shape
        grid = self.proj(x)                        # (B, embed_dim, h, w, d)
        tokens = grid.flatten(2).transpose(1, 2)   # (B, h*w*d, embed_dim)
        return self.norm(tokens)

Attention Mechanism

In [None]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        dim: Token embedding size (must be divisible by ``num_heads``).
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim), then split into q, k, v
        q, k, v = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        mixed = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))

Transformer Block

In [None]:
class DropPath(nn.Module):
    """Stochastic depth: drop a whole residual branch per sample while training.

    Surviving samples are rescaled by ``1 / (1 - drop_prob)`` so the expected
    output equals the input. In eval mode, or with ``drop_prob == 0``, this is
    the identity.

    Args:
        drop_prob: Probability of zeroing the branch for a given sample.
    """

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample (dim 0), broadcast over the other dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.0:
            mask = mask / keep_prob
        return x * mask


class TransformerBlockWithResidual(nn.Module):
    """Pre-norm transformer block with stochastic depth on both residual branches.

    ``x = x + DropPath(Attn(LN(x)))`` followed by ``x = x + DropPath(MLP(LN(x)))``.

    Args:
        dim: Token embedding size (must be divisible by ``num_heads``).
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        dropout: Dropout used on attention weights, attention output and in the MLP.
        drop_path: Stochastic-depth probability; 0 disables it.
    """
    def __init__(self, dim, num_heads, mlp_ratio=2., dropout=0.1, drop_path=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        # Attention exposes attn_drop / proj_drop. Passing `dropout=` raised a TypeError.
        self.attn = Attention(dim, num_heads=num_heads, attn_drop=dropout, proj_drop=dropout)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

Complete Vision Transformer Model

In [None]:
class ImprovedVisionTransformer3D(nn.Module):
    """3-D vision-transformer classifier.

    Pipeline:

    1. A two-layer strided Conv3d stem turns the input volume into a grid of
       ``embed_dim``-wide tokens.
    2. A learnable CLS token and positional embedding are added.
    3. ``depth`` ``TransformerBlockWithResidual`` layers follow.
    4. All final tokens (CLS included) are combined by ``SelfAttentionPooling``
       and classified by a small MLP head.

    Args:
        img_size: Edge length of the cubic input. The stem's LayerNorm shapes
            and ``pos_embed`` are sized from it, so inputs must presumably be
            ``(B, in_channels, img_size, img_size, img_size)``.
        patch_size: Accepted but NOT used. The stem always reduces each
            spatial axis with two stride-2 convolutions.
        in_channels: Channels of the input volume.
        num_classes: Number of output logits.
        embed_dim: Token width; the first conv outputs ``embed_dim // 2`` channels.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability for the stem, the blocks and the head.
    """
    def __init__(self, img_size=64, patch_size=8, in_channels=1, num_classes=2,
                 embed_dim=128, depth=4, num_heads=4, mlp_ratio=2., dropout=0.2):
        super().__init__()

        # Convolutional stem: two stride-2 convs. Each LayerNorm normalises over
        # (channels, H, W, D), and its shape hard-codes the spatial size
        # (img_size//2, then img_size//4).
        self.patch_embed = nn.Sequential(
            nn.Conv3d(in_channels, embed_dim//2, kernel_size=3, stride=2, padding=1),
            nn.LayerNorm([embed_dim//2, img_size//2, img_size//2, img_size//2]),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv3d(embed_dim//2, embed_dim, kernel_size=3, stride=2, padding=1),
            nn.LayerNorm([embed_dim, img_size//4, img_size//4, img_size//4]),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # One token per voxel of the (img_size//4)^3 stem output, plus one CLS token.
        num_patches = (img_size // 4) ** 3
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Transformer blocks; drop_path is not passed, so each block uses its default.
        self.blocks = nn.ModuleList([
            TransformerBlockWithResidual(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout
            ) for _ in range(depth)
        ])

        # Classification head: final LayerNorm -> attention pooling over tokens -> MLP
        self.norm = nn.LayerNorm(embed_dim)
        self.attention_pool = SelfAttentionPooling(embed_dim)
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, embed_dim//2),
            nn.LayerNorm(embed_dim//2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim//2, num_classes)
        )

        # Weight initialization (Linear/LayerNorm via _init_weights; Conv3d keeps
        # PyTorch's default init). Module creation order above fixes the RNG
        # sequence and the state_dict keys, so do not reorder it.
        self.apply(self._init_weights)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def _init_weights(self, m):
        """Truncated-normal Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, return_features=False):
        """Classify a batch of volumes.

        Args:
            x: Input volumes (see ``img_size``).
            return_features: If True, also return the CLS-token activations
                after every transformer block.

        Returns:
            ``logits`` of shape ``(B, num_classes)``, or ``(logits, features)``,
            where ``features`` is a list of ``depth`` tensors of shape
            ``(B, embed_dim)``.
        """
        x = self.patch_embed(x)
        # (B, C, h, w, d) grid -> (B, h*w*d, C) token sequence
        x = rearrange(x, 'b c h w d -> b (h w d) c')

        # Prepend the CLS token, then add positional embeddings
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.pos_embed[:, :x.size(1)]

        features = []
        for block in self.blocks:
            x = block(x)
            features.append(x[:, 0])  # CLS token after this block

        # Pool over all tokens (CLS included) rather than reading only the CLS token
        x = self.norm(x)
        x = self.attention_pool(x)
        logits = self.fc(x)

        if return_features:
            return logits, features
        return logits

In [None]:
def train_with_curriculum(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device):
    """Train with a dropout/augmentation curriculum and a feature-consistency penalty.

    The curriculum lowers every ``nn.Dropout`` probability, and the training
    transform's ``p`` (when it has one), in phases starting at epochs 0, 5
    and 10.

    The loss is ``criterion(logits, labels)`` plus 0.1 times
    ``feature_consistency_loss`` over the per-block CLS features from
    ``model(x, return_features=True)``. The checkpoint with the best
    validation accuracy is written to ``best_model.pth``.

    Args:
        model: Network supporting ``model(x, return_features=True)``.
        train_loader: Loader yielding ``(inputs, labels)`` for training.
        val_loader: Loader yielding ``(inputs, labels)`` for validation.
        criterion: Classification loss.
        optimizer: Optimizer over ``model.parameters()``.
        scheduler: LR scheduler. ``OneCycleLR`` is stepped once per batch
            (it is sized in optimizer steps); any other scheduler once per epoch.
        num_epochs: Number of epochs to run.
        device: Device (or device string). Mixed precision is enabled only for CUDA.

    Returns:
        The trained ``model``; the best checkpoint is saved separately.
    """
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'
    # With enabled=False the scaler is a pass-through, so the same code runs on CPU.
    scaler = GradScaler(enabled=use_amp)
    step_per_batch = isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR)
    best_val_acc = 0.0
    curriculum_phases = [
        {'epoch': 0, 'dropout': 0.4, 'aug_prob': 0.7},
        {'epoch': 5, 'dropout': 0.3, 'aug_prob': 0.5},
        {'epoch': 10, 'dropout': 0.2, 'aug_prob': 0.3}
    ]

    for epoch in range(num_epochs):
        # Latest phase whose start epoch has been reached
        current_phase = next(phase for phase in reversed(curriculum_phases)
                             if epoch >= phase['epoch'])

        for module in model.modules():
            if isinstance(module, nn.Dropout):
                module.p = current_phase['dropout']

        # Apply the phase's augmentation probability to the training transform.
        # Non-persistent DataLoader workers are re-created each epoch and pick it up.
        train_transform = getattr(getattr(train_loader, 'dataset', None), 'transform', None)
        if train_transform is not None and hasattr(train_transform, 'p'):
            train_transform.p = current_phase['aug_prob']

        # Training
        model.train()
        train_loss = 0.0
        train_correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            # torch.amp.autocast requires an explicit device_type
            with autocast(device_type=device_type, enabled=use_amp):
                outputs, features = model(inputs, return_features=True)
                cls_loss = criterion(outputs, labels)
                consistency_loss = feature_consistency_loss(features)
                loss = cls_loss + 0.1 * consistency_loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            if step_per_batch:
                scheduler.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()

        # Validation
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

        # max(..., 1) guards against empty loaders
        val_acc = 100. * val_correct / max(val_total, 1)
        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss/max(len(train_loader), 1):.4f}, '
              f'Train Acc: {100.*train_correct/max(total, 1):.2f}%')
        print(f'Val Loss: {val_loss/max(len(val_loader), 1):.4f}, '
              f'Val Acc: {val_acc:.2f}%')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'epoch': epoch
            }, 'best_model.pth')

        if not step_per_batch:
            scheduler.step()

    return model

In [None]:
def feature_consistency_loss(features):
    """Mean squared distance between L2-normalised features of consecutive layers.

    Args:
        features: Sequence of tensors, e.g. the per-block CLS features returned
            by ``model(x, return_features=True)``. Each is normalised along
            dim 1 by ``F.normalize``.

    Returns:
        Scalar tensor: the mean of ``F.mse_loss`` over all adjacent pairs, or 0
        when fewer than two tensors are given (there is no pair to compare).
    """
    if len(features) < 2:
        # The old code divided by len(features) - 1, i.e. 0/0 for a single layer.
        return features[0].new_zeros(()) if len(features) else torch.zeros(())
    pair_losses = [
        F.mse_loss(F.normalize(prev), F.normalize(nxt))
        for prev, nxt in zip(features[:-1], features[1:])
    ]
    return torch.stack(pair_losses).mean()

## Training Configuration

In [None]:
# Architecture hyper-parameters for ImprovedVisionTransformer3D(**model_config).
# `patch_size` is accepted by the model but ignored by its conv stem.
model_config = dict(
    img_size=64,
    patch_size=8,
    in_channels=1,
    embed_dim=128,
    depth=4,
    num_heads=4,
    mlp_ratio=2.,
    num_classes=2,
    dropout=0.2,
)

In [None]:
# Optimisation hyper-parameters. NOTE(review): setup_training() hard-codes its
# own values (batch size 2, lr 1e-4, weight decay 0.05), and the epoch count
# actually used is the separate `num_epochs` variable set later.
training_config = dict(
    learning_rate=1e-4,
    weight_decay=0.05,
    batch_size=4,
    num_epochs=30,
)

In [None]:
def setup_training():
    """Build the datasets, loaders, model, loss, optimizer and LR schedule.

    Reads the notebook globals ``extract_path``, ``device`` and ``num_epochs``.
    ``num_epochs`` sizes the OneCycleLR schedule, so the training loop must run
    for the same number of epochs and step the scheduler once per batch.

    Returns:
        ``(model, criterion, optimizer, scheduler, train_loader, val_loader)``
    """
    # Augmentation is applied to the training split only
    augmentation = AdvancedFMRIAugmentation(p=0.5)
    train_dataset = ImprovedfMRIDataset(root_dir=extract_path, transform=augmentation, phase='train')
    val_dataset = ImprovedfMRIDataset(root_dir=extract_path, transform=None, phase='val')

    # "Balanced" class weights from the training labels
    labels = np.array([label for _, label in train_dataset.samples], dtype=np.int64)
    all_classes = np.array([0, 1])
    class_weights_np = compute_class_weight(
        class_weight='balanced',
        classes=all_classes,
        y=labels
    )
    class_weights = torch.tensor(class_weights_np, dtype=torch.float32, device=device)

    # Per-sample sampler weights, built on the CPU. Indexing the device tensor
    # label by label would hand the sampler a list of (possibly CUDA) 0-d tensors.
    sample_weights = torch.as_tensor(class_weights_np[labels], dtype=torch.double)
    # NOTE(review): classes are re-balanced twice (weighted sampler AND weighted
    # loss); consider keeping only one of the two.

    pin_memory = torch.cuda.is_available()  # pinning only helps host->GPU copies

    train_loader = DataLoader(
        train_dataset,
        batch_size=2,  # small batches: 3-D volumes are memory-heavy
        sampler=WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True
        ),
        num_workers=2,
        pin_memory=pin_memory
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=2,
        shuffle=False,
        num_workers=2,
        pin_memory=pin_memory
    )

    model = ImprovedVisionTransformer3D(
        img_size=64,
        patch_size=8,   # accepted but unused by the model's conv stem
        in_channels=1,
        num_classes=2,
        embed_dim=128,
        depth=6,
        num_heads=4,
        mlp_ratio=2.,
        # The constructor takes a single `dropout`. The former qkv_bias /
        # drop_rate / attn_drop_rate keywords do not exist and raised TypeError.
        dropout=0.3
    ).to(device)

    # Weighted cross-entropy with label smoothing
    criterion = nn.CrossEntropyLoss(
        weight=class_weights,
        label_smoothing=0.1
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=1e-4,
        weight_decay=0.05,
        betas=(0.9, 0.999)
    )

    # Sized in optimizer steps: must be stepped once per batch for num_epochs epochs
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-4,
        epochs=num_epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.1,
        anneal_strategy='cos'
    )

    return model, criterion, optimizer, scheduler, train_loader, val_loader

## Training Loop

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device):
    """Supervised training loop with optional AMP, gradient clipping and checkpointing.

    Batches that raise are reported and skipped. Whenever validation accuracy
    improves, ``best_model.pth`` is written with:

    * the model weights, moved to CPU;
    * the optimizer and scheduler state;
    * the best accuracy and epoch.

    Args:
        model: Network returning logits for ``model(inputs)``.
        train_loader: Loader yielding ``(inputs, labels)`` for training.
        val_loader: Loader yielding ``(inputs, labels)`` for validation.
        criterion: Classification loss.
        optimizer: Optimizer over ``model.parameters()``.
        scheduler: LR scheduler (not None; its state is checkpointed). Stepping:
            ``OneCycleLR`` after every optimizer step (it is sized in batches);
            ``ReduceLROnPlateau`` once per epoch on the validation loss; any
            other scheduler once per epoch.
        num_epochs: Number of epochs to run.
        device: Device (or device string). float16 AMP is used only on CUDA.

    Returns:
        ``history`` dict with per-epoch lists ``'train_loss'``, ``'train_acc'``,
        ``'val_loss'`` and ``'val_acc'`` (accuracies in percent).
    """
    use_amp = torch.cuda.is_available() and torch.device(device).type == 'cuda'
    scaler = GradScaler() if use_amp else None
    step_per_batch = isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR)
    best_val_acc = 0.0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    # Release cached GPU memory left by earlier cells. cudnn.benchmark is not
    # forced on here: the seeding cell disables it for reproducibility.
    torch.cuda.empty_cache()

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        train_bar = tqdm(train_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}')
        for batch_idx, (inputs, labels) in enumerate(train_bar):
            try:
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                optimizer.zero_grad(set_to_none=True)

                if scaler is not None:
                    with autocast(device_type='cuda', dtype=torch.float16):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    scaler.scale(loss).backward()
                    # Unscale first so clipping sees the true gradient norm
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()

                if step_per_batch:
                    scheduler.step()

                train_loss += loss.item()
                _, predicted = outputs.max(1)
                train_total += labels.size(0)
                train_correct += predicted.eq(labels).sum().item()

                train_bar.set_postfix({
                    'batch': f'{batch_idx+1}/{len(train_loader)}',
                    'loss': f'{loss.item():.4f}',
                    'acc': f'{100.*train_correct/train_total:.2f}%'
                })

                # Clear GPU cache periodically
                if batch_idx % 5 == 0:
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error in training batch {batch_idx}: {str(e)}")
                torch.cuda.empty_cache()
                continue

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            val_bar = tqdm(val_loader, desc='Validation')
            for inputs, labels in val_bar:
                try:
                    inputs = inputs.to(device, non_blocking=True)
                    labels = labels.to(device, non_blocking=True)

                    # Disabled off-CUDA, instead of requesting a CUDA autocast on CPU
                    with autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    val_loss += loss.item()
                    _, predicted = outputs.max(1)
                    val_total += labels.size(0)
                    val_correct += predicted.eq(labels).sum().item()

                    val_bar.set_postfix({
                        'loss': f'{loss.item():.4f}',
                        'acc': f'{100.*val_correct/val_total:.2f}%'
                    })

                except Exception as e:
                    print(f"Error in validation batch: {str(e)}")
                    torch.cuda.empty_cache()
                    continue

        # Epoch metrics; max(..., 1) guards against empty loaders or all-failed batches
        avg_train_loss = train_loss / max(len(train_loader), 1)
        train_acc = 100. * train_correct / max(train_total, 1)
        avg_val_loss = val_loss / max(len(val_loader), 1)
        val_acc = 100. * val_correct / max(val_total, 1)

        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_acc)

        # Save best model (weights moved to CPU)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            model_state = {k: v.cpu() for k, v in model.state_dict().items()}
            torch.save({
                'epoch': epoch,
                'model_state_dict': model_state,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_acc': best_val_acc,
            }, 'best_model.pth')

        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')

        # Epoch-level scheduler step (OneCycleLR was already stepped per batch)
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif not step_per_batch:
            scheduler.step()

        torch.cuda.empty_cache()

    return history

## Main Execution

Set device

In [None]:
# Prefer the GPU whenever one is visible to PyTorch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Set random seeds for reproducibility

In [None]:
# Seed every RNG used in the notebook: torch, numpy's legacy global, and `random`.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

Setup training components

In [None]:
# Instantiate the architecture from model_config, then move it to `device`.
model = ImprovedVisionTransformer3D(**model_config)
model = model.to(device)

Print model summary

In [None]:
# Count all parameters, and separately those the optimizer will update.
total_params = trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()
print("\nModel Parameters:")
print(f"Total: {total_params:,}")
print(f"Trainable: {trainable_params:,}")

Training configuration

In [None]:
# Epoch budget; setup_training() also sizes its OneCycleLR schedule from it.
num_epochs = 18
# Early-stopping bookkeeping; neither training loop currently reads these.
early_stopping_patience, early_stopping_counter = 15, 0
best_val_acc = 0.0

Train the model

In [None]:
# Build loaders, model, loss, optimizer and the OneCycleLR schedule. The schedule
# is sized from the global `num_epochs`, so the same value is passed to the loop.
model, criterion, optimizer, scheduler, train_loader, val_loader = setup_training()

# train_model returns the per-epoch `history` plotted below.
# (train_with_curriculum is an alternative loop, but it returns only the model.)
history = train_model(model, train_loader, val_loader,
                      criterion, optimizer, scheduler,
                      num_epochs, device)

Plot training history

In [None]:
# Side-by-side loss and accuracy curves from the `history` dict.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

panels = (
    (loss_ax, 'train_loss', 'Train Loss', 'val_loss', 'Val Loss', 'Loss History', 'Loss'),
    (acc_ax, 'train_acc', 'Train Acc', 'val_acc', 'Val Acc', 'Accuracy History', 'Accuracy (%)'),
)
for ax, train_key, train_label, val_key, val_label, title, ylabel in panels:
    ax.plot(history[train_key], label=train_label)
    ax.plot(history[val_key], label=val_label)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)

fig.tight_layout()
fig.savefig('training_history.png')
plt.show()

Load best model for evaluation

In [None]:
print("\nLoading best model for evaluation...")
# The checkpoint's optimizer state can hold tensors on the training device, so
# remap everything onto the current device. This also allows a CPU-only restore.
checkpoint = torch.load('best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

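Prepare test dataset and loader

**TODO:** no cell currently creates `test_loader`. Build one before running the evaluation below, for example `ImprovedfMRIDataset(root_dir=extract_path, transform=None, phase='test')` wrapped in a non-shuffled `DataLoader`. Also note that `evaluate_model` unpacks three values per batch, while the dataset yields `(volume, label)` pairs.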

Evaluate on test set

In [None]:
print("\nEvaluating model on test set...")
# Accumulators for a manual test pass; evaluate_model() keeps its own local copies.
test_loss, test_correct, test_total = 0.0, 0, 0
all_preds, all_labels = [], []

In [None]:
def _collect_test_predictions(model, test_loader, criterion, device):
    """Run ``model`` over ``test_loader`` without gradients and gather outputs.

    Returns:
        ``(loss_sum, num_batches, correct, total, preds, labels)`` where
        ``preds`` and ``labels`` are 1-D numpy arrays of class ids.
    """
    loss_sum = 0.0
    num_batches = 0
    correct = 0
    total = 0
    preds, labels_seen = [], []
    # autocast takes a device *type* string; derive it from `device` so the
    # autocast context matches where the batches are actually placed.
    device_type = torch.device(device).type

    with torch.no_grad():
        for inputs, labels, _ in tqdm(test_loader, desc="Testing"):
            inputs, labels = inputs.to(device), labels.to(device)

            with autocast(device_type=device_type):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            loss_sum += loss.item()
            num_batches += 1
            # argmax over dim 1; assumes outputs are (batch, num_classes) — TODO confirm
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            preds.extend(predicted.cpu().numpy())
            labels_seen.extend(labels.cpu().numpy())

    return loss_sum, num_batches, correct, total, np.array(preds), np.array(labels_seen)


def _plot_confusion_matrix(cm, class_names, path='confusion_matrix.png'):
    """Draw ``cm`` as an annotated heatmap, save it to ``path`` and show it."""
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names
    )
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.savefig(path)
    plt.show()


def evaluate_model(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` and report binary Early/Late metrics.

    Prints accuracy, the true-label class distribution, the confusion matrix, a
    per-class classification report, and weighted precision/recall/F1. Saves the
    confusion-matrix plot to ``confusion_matrix.png`` and the returned dict to
    ``test_results.npy``. Reload that file with
    ``np.load(..., allow_pickle=True).item()``.

    Args:
        model: network evaluated in ``eval()`` mode.
        test_loader: iterable yielding ``(inputs, labels, _)`` triples.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        device: torch device (or device string) that batches are moved to.

    Returns:
        dict with keys 'test_accuracy' (percent), 'test_loss' (mean batch loss),
        'predictions', 'true_labels', 'confusion_matrix' (2x2) and 'metrics'.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    # These metric functions are not in the notebook's import cell; without this
    # import the aggregated metrics failed with a NameError.
    from sklearn.metrics import f1_score, precision_score, recall_score

    class_names = ['Early', 'Late']
    class_ids = list(range(len(class_names)))

    model.eval()
    print("\nEvaluating model on test set...")
    test_loss, num_batches, test_correct, test_total, all_preds, all_labels = (
        _collect_test_predictions(model, test_loader, criterion, device)
    )
    if test_total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute test metrics")

    # Calculate metrics
    test_acc = 100. * test_correct / test_total
    print(f"\nTest Results:")
    print(f"Test Accuracy: {test_acc:.2f}%")

    # Print class distribution of the true labels
    print("\nClass Distribution:")
    for label, name in enumerate(class_names):
        count = np.sum(all_labels == label)
        print(f"Class {name}: {count} samples")

    # Pin the label order so the matrix is always 2x2 and lines up with the
    # tick labels, even if a class is missing from both labels and predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    print("\nConfusion Matrix:")
    print(cm)
    _plot_confusion_matrix(cm, class_names)

    # Compute and print detailed metrics for the classes actually present
    print("\nDetailed Metrics:")
    try:
        present_classes = [int(c) for c in np.unique(np.concatenate([all_labels, all_preds]))]
        report = classification_report(
            all_labels,
            all_preds,
            # labels and target_names must correspond one-to-one
            labels=present_classes,
            target_names=[class_names[c] for c in present_classes],
            digits=4,
            zero_division=0
        )
        print(report)

        metrics = {
            'Accuracy': test_acc / 100,
            'Precision': precision_score(all_labels, all_preds, average='weighted', zero_division=0),
            'Recall': recall_score(all_labels, all_preds, average='weighted', zero_division=0),
            'F1-Score': f1_score(all_labels, all_preds, average='weighted', zero_division=0)
        }

        print("\nAggregated Metrics:")
        for metric_name, value in metrics.items():
            print(f"{metric_name}: {value:.4f}")

    except (ValueError, IndexError) as e:
        # Best effort: an unexpected class id (IndexError) or an sklearn input
        # error (ValueError) should not discard the accuracy already computed.
        print(f"Warning: Could not compute some metrics: {str(e)}")
        metrics = {'Accuracy': test_acc / 100}

    results = {
        'test_accuracy': test_acc,
        'test_loss': test_loss / num_batches,
        'predictions': all_preds,
        'true_labels': all_labels,
        'confusion_matrix': cm,
        'metrics': metrics
    }

    # np.save pickles the dict into a 0-d object array
    np.save('test_results.npy', results)
    print("\nResults saved to 'test_results.npy'")

    return results

In [None]:
print("Loading best model for evaluation...")
# map_location lets a checkpoint written on a GPU runtime load onto `device`
# (including CPU-only sessions).
checkpoint = torch.load('best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

In [None]:
def prepare_test_dataset():
    """Build the test-phase ImprovedfMRIDataset and print its class counts.

    Labels are read as the second element of each entry in ``dataset.samples``;
    label 0 is reported as 'Early' and anything else as 'Late'.

    Returns:
        The constructed test dataset.
    """
    dataset = ImprovedfMRIDataset(root_dir=extract_path, transform=None, phase='test')

    label_list = [entry[1] for entry in dataset.samples]
    print("\nTest Set Class Distribution:")
    values, frequencies = np.unique(label_list, return_counts=True)
    for value, frequency in zip(values, frequencies):
        name = 'Early' if value == 0 else 'Late'
        print(f"Class {name}: {frequency} samples")

    return dataset

# Create and print test dataset distribution
test_dataset = prepare_test_dataset()
test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,  # fixed order so predictions line up with dataset samples
    num_workers=2,
    # Pinned host memory only speeds up host->GPU copies; on CPU-only runs
    # PyTorch warns about it and ignores it.
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Evaluate the reloaded best model; returns the metrics/predictions dict.
results = evaluate_model(
    model=model,
    test_loader=test_loader,
    criterion=criterion,
    device=device,
)

Save final results

In [None]:
# `results` on the right-hand side is the dict returned by evaluate_model in the
# previous cell. The module-level test_acc/all_preds/all_labels are unusable
# here: test_acc is never assigned in any visible cell and the lists are still
# the empty placeholders, so the real values are taken from that dict. Run this
# cell once, directly after evaluation — it rebinds `results`.
results = {
    'history': history,
    'test_accuracy': results['test_accuracy'],
    'test_predictions': results['predictions'],
    'test_labels': results['true_labels'],
    'test_loss': results['test_loss'],
    'test_metrics': results['metrics'],
    'best_val_accuracy': checkpoint['best_val_acc']
}

In [None]:
# Bundle final weights, the results dict and the hyper-parameters. Values are
# read from the live objects instead of being hard-coded so the record cannot
# drift from what the run actually used.
_opt_group = optimizer.param_groups[0]
torch.save({
    'model_state_dict': model.state_dict(),
    'results': results,
    'training_args': {
        # training ran for config['num_epochs'] (see the training cell)
        'num_epochs': config['num_epochs'],
        'batch_size': train_loader.batch_size,
        # current lr, i.e. after any scheduler updates
        'learning_rate': _opt_group['lr'],
        # lr-schedulers store the starting lr as 'initial_lr'; None if absent
        'initial_learning_rate': _opt_group.get('initial_lr'),
        'weight_decay': _opt_group.get('weight_decay'),
        'architecture': str(model)
    }
}, 'final_model_and_results.pth')

In [None]:
print("\nTraining complete! Model and results saved.")
print(f"Best validation accuracy: {checkpoint['best_val_acc']:.2f}%")
print(f"Final test accuracy: {test_acc:.2f}%")