# fMRI Learning Stage Classification using Vision Transformers

This notebook trains a 3D Vision Transformer (ViT) to classify the learning stage of an fMRI run: the first run of a task (early learning) versus later runs (late learning).

The data come from the "Classification learning" dataset (ds000002), originally released on OpenfMRI and downloaded here from OpenNeuro.

## Setup and Dependencies

In [None]:
!pip install einops nibabel seaborn tqdm

In [None]:
import os
import urllib.request
import zipfile
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
from einops import rearrange
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from google.colab import drive

## Mount Google Drive and Download Dataset

Mount Google Drive

In [None]:
# Mount Google Drive so the project folder under /content/drive/MyDrive is readable and writable.
drive.mount('/content/drive')

Set up paths

In [None]:
# All dataset artefacts live in one project folder on the mounted Drive.
base_path = '/content/drive/MyDrive/learnedSpectrum'
zip_path, extract_path = (
    os.path.join(base_path, name)
    for name in ("ds000002_R2.0.5_raw.zip", "fmri_data")
)

Download dataset if not already present

In [None]:
if not os.path.exists(zip_path):
    print("Downloading dataset...")
    url = "https://s3.amazonaws.com/openneuro/ds000002/ds000002_R2.0.5/compressed/ds000002_R2.0.5_raw.zip"
    # The project folder may not exist yet on a fresh Drive.
    os.makedirs(base_path, exist_ok=True)
    # Download under a temporary name and rename only on success, so an interrupted
    # download is not mistaken for a complete archive on the next run.
    partial_path = zip_path + ".part"
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, zip_path)
    print("Download complete!")

Extract dataset if not already extracted

In [None]:
if not os.path.exists(extract_path):
    print("Extracting dataset...")
    # Extract into a temporary directory and rename when done, so a partially
    # extracted tree is never treated as complete by the existence check above.
    partial_dir = extract_path + ".partial"
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(partial_dir)
    os.replace(partial_dir, extract_path)
    print("Extraction complete!")

## Model Components

### 3D Patch Embedding

In [None]:
class PatchEmbed3D(nn.Module):
    """Split a cubic 3-D volume into non-overlapping cubic patches and embed each one.

    A ``Conv3d`` whose kernel size and stride both equal ``patch_size`` projects every
    patch to ``embed_dim`` features. The resulting grid is flattened into a token
    sequence and layer-normalised.

    Args:
        img_size (int): Edge length of the (cubic) input volume.
        patch_size (int): Edge length of each cubic patch.
        in_channels (int): Number of input channels.
        embed_dim (int): Size of each patch embedding.
    """

    def __init__(self, img_size=64, patch_size=8, in_channels=1, embed_dim=256):
        super().__init__()
        self.img_size = (img_size, img_size, img_size)
        self.patch_size = (patch_size, patch_size, patch_size)
        # Voxels beyond the last full patch along an axis are dropped by the strided conv.
        self.n_patches = (img_size // patch_size) ** 3

        # Per-patch linear projection implemented as a strided convolution
        self.proj = nn.Conv3d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Embed ``x`` of shape ``(B, C, H, W, D)`` into ``(B, n_tokens, embed_dim)``.

        Raises:
            ValueError: if ``x`` is not a 5-D tensor.
        """
        if x.dim() != 5:
            raise ValueError(
                f"PatchEmbed3D expects a 5-D (B, C, H, W, D) tensor, got shape {tuple(x.shape)}"
            )
        x = self.proj(x)                  # (B, E, h, w, d)
        x = x.flatten(2).transpose(1, 2)  # (B, h*w*d, E); same order as 'b e h w d -> b (h w d) e'
        return self.norm(x)

### Attention Mechanism

In [None]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        dim (int): Token width; must be divisible by ``num_heads``.
        num_heads (int): Number of attention heads.
        qkv_bias (bool): Whether the fused Q/K/V projection has a bias.
        attn_drop (float): Dropout probability on the attention weights.
        proj_drop (float): Dropout probability after the output projection.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        if dim % num_heads != 0:
            # Otherwise the per-head reshape in forward() fails with an opaque error.
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, N, C)`` and return a tensor of the same shape."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

### Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``y = x + Attn(LN1(x))`` followed by ``y + MLP(LN2(y))``, where the MLP
    is Linear -> GELU -> Dropout -> Linear -> Dropout with hidden width
    ``int(dim * mlp_ratio)``.

    Args:
        dim (int): Token width.
        num_heads (int): Attention heads.
        mlp_ratio (float): Hidden-width multiplier of the MLP.
        qkv_bias (bool): Bias on the fused Q/K/V projection.
        drop (float): Dropout used in the MLP and after the attention projection.
        attn_drop (float): Dropout on attention weights.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = nn.LayerNorm(dim)

        # Built as a list first; module order (and thus state_dict keys mlp.0 / mlp.3) is unchanged.
        mlp_layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(drop),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))

### Complete Vision Transformer Model

In [None]:
class ImprovedVisionTransformer3D(nn.Module):
    """3-D Vision Transformer that classifies a volume from its [CLS] token.

    Pipeline: PatchEmbed3D -> prepend learnable CLS token -> add learnable position
    embedding -> dropout -> LayerNorm -> ``depth`` TransformerBlocks -> LayerNorm ->
    CLS token -> two-layer GELU MLP (``pre_head``) -> linear ``head`` (logits).

    Args:
        img_size (int): Edge length of the cubic input volume.
        patch_size (int): Edge length of each cubic patch.
        in_channels (int): Input channels.
        num_classes (int): Number of output logits.
        embed_dim (int): Token width.
        depth (int): Number of transformer blocks.
        num_heads (int): Attention heads per block.
        mlp_ratio (float): MLP hidden-width multiplier inside each block.
        qkv_bias (bool): Bias on the Q/K/V projections.
        drop_rate (float): Dropout after the position embedding, inside blocks and in ``pre_head``.
        attn_drop_rate (float): Dropout on attention weights.
        verbose (bool): Print a short configuration summary after construction.
    """

    def __init__(self, img_size=64, patch_size=8, in_channels=1, num_classes=2,
                 embed_dim=256, depth=6, num_heads=8, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0.1, attn_drop_rate=0.1, verbose=True):
        super().__init__()

        # Volume -> sequence of patch tokens of width embed_dim
        self.patch_embed = PatchEmbed3D(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim
        )

        num_patches = self.patch_embed.n_patches

        # Learnable [CLS] token and one position embedding per token (patches + CLS)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        self.pos_drop = nn.Dropout(p=drop_rate)
        self.norm_pre = nn.LayerNorm(embed_dim)

        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate
            ) for _ in range(depth)
        ])

        # Output head: final norm, MLP bottleneck to embed_dim // 2, then logits
        self.norm = nn.LayerNorm(embed_dim)
        self.pre_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(p=drop_rate),
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Dropout(p=drop_rate)
        )
        self.head = nn.Linear(embed_dim // 2, num_classes)

        self._init_weights()

        if verbose:
            print("Model initialized with:")
            print(f"- Patch size: {patch_size}x{patch_size}x{patch_size}")
            print(f"- Number of patches: {num_patches}")
            print(f"- Embedding dimension: {embed_dim}")
            print(f"- Number of transformer blocks: {depth}")
            print(f"- Number of attention heads: {num_heads}")

    def _init_weights(self):
        """Initialise CLS/position embeddings with N(0, 0.02) and sub-modules via ``_init_fn``."""
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)
        self.apply(self._init_fn)

    def _init_fn(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm to identity.

        Other module types (e.g. the patch-embedding Conv3d) keep PyTorch's default init.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)`` for input ``x``.

        ``x`` is passed straight to PatchEmbed3D; the token count it yields must match
        ``pos_embed`` (i.e. the volume must have edge length ``img_size``).
        """
        x = self.patch_embed(x)

        # Prepend the [CLS] token to every sequence in the batch
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        x = self.norm_pre(x)
        for block in self.blocks:
            x = block(x)

        # Classify from the [CLS] token only
        x = self.norm(x)
        x = x[:, 0]
        x = self.pre_head(x)
        x = self.head(x)

        return x

## Dataset Implementation

In [None]:
class fMRIDataset(Dataset):
    """Whole-run BOLD volumes from OpenNeuro ds000002, labelled by learning stage.

    Every ``*_bold.nii.gz`` file under ``<root_dir>/ds002_R2.0.5/sub-*/func`` is one
    sample. Run 1 of a task is labelled stage 0 ("early"); any other run number is
    labelled stage 1 ("late"). The task name is parsed from the ``task-`` entity.

    Args:
        root_dir (str): Directory that contains the extracted ``ds002_R2.0.5`` folder.
        transform (callable, optional): Applied to the ``(1, 64, 64, 64)`` float tensor.
        verbose (bool): If True, print the raw shape of every volume as it is loaded.
    """

    # Label returned, together with an all-zero volume, when a file fails to load.
    ERROR_LABEL = -1
    # Spatial size every volume is centre-cropped / zero-padded to.
    TARGET_SHAPE = (64, 64, 64)

    def __init__(self, root_dir, transform=None, verbose=False):
        self.root_dir = os.path.join(root_dir, 'ds002_R2.0.5')
        self.transform = transform
        self.verbose = verbose
        self.samples = []
        self._load_samples()

    @staticmethod
    def _parse_bids_name(file_name):
        """Return ``(task, run)`` parsed from a BIDS file name, or None if either is missing/invalid."""
        try:
            task = file_name.split('task-')[1].split('_')[0]
            run = int(file_name.split('run-')[1].split('_')[0])
        except (IndexError, ValueError):
            return None
        return task, run

    def _load_samples(self):
        """Populate ``self.samples`` with ``(file_path, stage, task)`` tuples.

        Directory listings are sorted so the sample order, and therefore any seeded
        ``random_split`` over this dataset, does not depend on the filesystem.
        Files whose names lack a parsable ``task-``/``run-`` entity are skipped.
        """
        print("Loading samples...")
        for subject in sorted(os.listdir(self.root_dir)):
            if not subject.startswith('sub-'):
                continue

            func_path = os.path.join(self.root_dir, subject, 'func')
            if not os.path.exists(func_path):
                continue

            for file in sorted(os.listdir(func_path)):
                if not file.endswith('_bold.nii.gz'):
                    continue

                parsed = self._parse_bids_name(file)
                if parsed is None:
                    print(f"Skipping {file}: no task-/run- entity in file name")
                    continue
                task, run = parsed
                stage = 0 if run == 1 else 1

                self.samples.append((os.path.join(func_path, file), stage, task))
        print(f"Loaded {len(self.samples)} samples")

    def _pad_or_crop(self, volume, target_shape):
        """
        Centre-crop or zero-pad the first three axes of a volume to target shape.

        Along each axis that is too large the central ``target`` voxels are kept;
        along each axis that is too small the volume is centred in zeros (any odd
        remainder goes after the data).

        Args:
            volume (np.ndarray): Input volume of shape (H, W, D)
            target_shape (tuple): Desired output shape

        Returns:
            np.ndarray: float32 volume of shape target_shape
        """
        result = np.zeros(target_shape, dtype=np.float32)
        src, dst = [], []
        for size, target in zip(volume.shape[:3], target_shape):
            if size > target:
                start = (size - target) // 2
                src.append(slice(start, start + target))
                dst.append(slice(None))
            else:
                start = (target - size) // 2
                src.append(slice(None))
                dst.append(slice(start, start + size))
        result[tuple(dst)] = volume[tuple(src)]
        return result

    def _preprocess_volume(self, img_data):
        """
        Preprocess an fMRI volume.

        4-D input is averaged over its last (time) axis. The volume is z-scored
        using its global mean/std and then cropped/padded to ``TARGET_SHAPE``.

        Args:
            img_data (np.ndarray): Input fMRI data, 3-D or 4-D

        Returns:
            np.ndarray: float32 volume of shape TARGET_SHAPE
        """
        if len(img_data.shape) == 4:
            img_data = np.mean(img_data, axis=-1)

        img_data = img_data.astype(np.float32)

        # Z-score normalization (epsilon guards against a constant volume)
        mean = np.mean(img_data)
        std = np.std(img_data)
        img_data = (img_data - mean) / (std + 1e-8)

        img_data = self._pad_or_crop(img_data, self.TARGET_SHAPE)

        return img_data

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, preprocess and (optionally) transform sample ``idx``.

        Returns:
            tuple: ``(volume, stage, task)``. ``volume`` is a float32 tensor of shape
            ``(1, 64, 64, 64)`` before ``transform``; ``stage`` is a long scalar
            tensor; ``task`` is the task name. If loading fails, the error is printed
            and ``(zeros, tensor(-1), "error")`` is returned instead. Label -1 is not
            a valid class index for ``CrossEntropyLoss``, so callers should drop it.
        """
        if torch.is_tensor(idx):
            idx = idx.item()

        file_path, stage, task = self.samples[idx]

        try:
            img_data = nib.load(file_path).get_fdata()

            if self.verbose:
                print(f"Original shape for {file_path}: {img_data.shape}")

            img_data = self._preprocess_volume(img_data)

            # Add a leading channel axis -> (1, H, W, D)
            img_tensor = torch.from_numpy(img_data[np.newaxis, ...].astype(np.float32))

            if self.transform:
                img_tensor = self.transform(img_tensor)

            return img_tensor, torch.tensor(stage, dtype=torch.long), task

        except Exception as e:
            # Best-effort: a single unreadable file should not abort a whole epoch.
            print(f"Error loading {file_path}: {str(e)}")
            return (
                torch.zeros((1, *self.TARGET_SHAPE), dtype=torch.float32),
                torch.tensor(self.ERROR_LABEL, dtype=torch.long),
                "error"
            )

def visualize_sample(self, idx):
    """
    Visualize a sample from all three anatomical perspectives.

    Args:
        self: The dataset to read from (the parameter keeps its historical name).
            If it provides ``visualize_slice(idx, axis=...)``, that method is called
            once per axis. Otherwise ``self[idx]`` must return ``(volume, stage, task)``
            with ``volume`` shaped ``(C, H, W, D)``, as fMRIDataset does, and the
            centre slice along each spatial axis of channel 0 is plotted here.
        idx (int): Index of the sample to visualize
    """
    print(f"\nVisualizing sample {idx} from all perspectives:")
    if hasattr(self, 'visualize_slice'):
        for axis in range(3):
            self.visualize_slice(idx, axis=axis)
        return

    volume, stage, task = self[idx]
    volume = np.asarray(volume)[0]  # drop the channel axis -> (H, W, D)
    stage_name = {0: "Early", 1: "Late"}.get(int(stage), "Error")

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for axis, ax in enumerate(axes):
        centre = volume.shape[axis] // 2
        ax.imshow(np.take(volume, centre, axis=axis), cmap='gray')
        ax.set_title(f'Axis {axis}, slice {centre}')
        ax.axis('off')
    fig.suptitle(f'Sample {idx} - Task: {task}, Stage: {stage_name}')
    plt.show()

In [None]:
def test_dataset(dataset, num_samples=3):
    print("\nTesting dataset access...")
    for i in range(min(num_samples, len(dataset))):
        img, stage, task = dataset[i]
        print(f"\nSample {i}:")
        print(f"Image shape: {img.shape}")
        print(f"Image stats - Min: {img.min():.4f}, Max: {img.max():.4f}, Mean: {img.mean():.4f}, Std: {img.std():.4f}")
        print(f"Stage: {stage}")
        print(f"Task: {task}")

        # Visualize middle slice
        plt.figure(figsize=(8, 8))
        plt.imshow(img[0, :, :, 32], cmap='gray')
        plt.colorbar()
        plt.title(f'Sample {i} - Middle Slice\nTask: {task}, Stage: {"Early" if stage == 0 else "Late"}')
        plt.show()

## Training Functions

In [None]:
def _warmup_lr(epoch, warmup_epochs, warmup_lr_init, base_lr):
    """Linear-warmup learning rate for 0-based ``epoch``; equals ``base_lr`` on the last warmup epoch."""
    if warmup_epochs <= 0:
        return base_lr
    progress = min(epoch + 1, warmup_epochs) / warmup_epochs
    return warmup_lr_init + (base_lr - warmup_lr_init) * progress


def _run_epoch(model, loader, criterion, device, desc, optimizer=None, scaler=None,
               max_grad_norm=1.0, use_amp=False):
    """One pass over ``loader``: trains when ``optimizer`` is given, otherwise evaluates without grads.

    Batches are ``(inputs, labels, extra)``. Samples with a negative label (the
    dataset's load-error sentinel) are dropped, because ``CrossEntropyLoss`` rejects them.

    Returns:
        tuple: (mean loss over processed batches, accuracy in percent)
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, seen, batches = 0.0, 0, 0, 0

    bar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(training):
        for inputs, labels, _ in bar:
            inputs, labels = inputs.to(device), labels.to(device)
            valid = labels >= 0
            if not bool(valid.all()):
                if not bool(valid.any()):
                    continue
                inputs, labels = inputs[valid], labels[valid]

            if training:
                optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            if training:
                scaler.scale(loss).backward()
                # Unscale before clipping so max_grad_norm applies to the true gradients,
                # not to gradients multiplied by the loss scale.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
                scaler.step(optimizer)
                scaler.update()

            total_loss += loss.item()
            batches += 1
            _, predicted = outputs.max(1)
            seen += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            postfix = {
                'loss': f'{loss.item():.4f}',
                'acc': f'{100. * correct / seen:.2f}%',
            }
            if training:
                postfix['lr'] = f'{optimizer.param_groups[0]["lr"]:.6f}'
            bar.set_postfix(postfix)

    avg_loss = total_loss / max(batches, 1)
    accuracy = 100. * correct / max(seen, 1)
    return avg_loss, accuracy


def train_model_improved(model, train_loader, val_loader, criterion, optimizer, scheduler,
                         num_epochs, device, warmup_epochs=5, warmup_lr_init=1e-6,
                         max_grad_norm=1.0, checkpoint_path='best_model.pth'):
    """Train with linear LR warmup, gradient clipping and (on CUDA) mixed precision.

    After each epoch the model is evaluated on ``val_loader``. The checkpoint with the
    best validation accuracy is saved to ``checkpoint_path`` as a dict with keys
    'epoch', 'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict'
    and 'best_val_acc'.

    Args:
        model, train_loader, val_loader, criterion, optimizer: usual training objects;
            loaders yield ``(inputs, labels, extra)``.
        scheduler: LR scheduler or None. ``ReduceLROnPlateau`` is stepped with the
            validation loss; any other scheduler is stepped without arguments. It is
            not stepped before the last warmup epoch, whose LR would be overwritten.
        num_epochs (int): Number of epochs.
        device: torch device (or device string) to run on.
        warmup_epochs (int): Epochs of linear warmup from ``warmup_lr_init`` to the
            optimizer's initial LR.
        warmup_lr_init (float): LR at the first warmup epoch's starting point.
        max_grad_norm (float): Gradient-norm clipping threshold.
        checkpoint_path (str): Where to save the best checkpoint.

    Returns:
        dict: Per-epoch lists under 'train_loss', 'train_acc', 'val_loss', 'val_acc'.
    """
    # AMP only makes sense on CUDA; on CPU the scaler and autocast become no-ops.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    # -inf guarantees at least one checkpoint is written, even if val accuracy stays 0.
    best_val_acc = float('-inf')
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    base_lr = optimizer.param_groups[0]['lr']

    for epoch in range(num_epochs):
        if epoch < warmup_epochs:
            lr = _warmup_lr(epoch, warmup_epochs, warmup_lr_init, base_lr)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        avg_train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device,
            desc=f'Training Epoch {epoch+1}/{num_epochs}',
            optimizer=optimizer, scaler=scaler,
            max_grad_norm=max_grad_norm, use_amp=use_amp,
        )
        avg_val_loss, val_acc = _run_epoch(
            model, val_loader, criterion, device, desc='Validation', use_amp=use_amp,
        )

        if scheduler is not None and epoch >= warmup_epochs - 1:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
                'best_val_acc': best_val_acc,
            }, checkpoint_path)

        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_acc)

        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}\n')

    return history


In [None]:
def plot_training_history(history):
    """Plot per-epoch loss (left) and accuracy (right) curves for train and validation.

    Args:
        history (dict): Lists under 'train_loss', 'val_loss', 'train_acc', 'val_acc'.
    """
    panels = [
        ('Loss History', 'Loss',
         [('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')]),
        ('Accuracy History', 'Accuracy (%)',
         [('train_acc', 'Train Acc'), ('val_acc', 'Val Acc')]),
    ]

    plt.figure(figsize=(12, 4))
    for position, (title, y_label, series) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for key, legend_label in series:
            plt.plot(history[key], label=legend_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
def plot_confusion_matrix(true_labels, pred_labels, class_names=None):
    """Draw an annotated confusion-matrix heatmap (rows = true, columns = predicted).

    Args:
        true_labels: Ground-truth class indices.
        pred_labels: Predicted class indices.
        class_names (sequence of str, optional): Names for classes ``0..len-1``. When
            given, the matrix always has that many rows/columns, even if a class never
            occurs, and the axes are labelled with these names. When omitted,
            behaviour is unchanged: only classes present in the data are shown.
    """
    if class_names is None:
        cm = confusion_matrix(true_labels, pred_labels)
        tick_labels = 'auto'
    else:
        cm = confusion_matrix(true_labels, pred_labels,
                              labels=list(range(len(class_names))))
        tick_labels = list(class_names)

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels, yticklabels=tick_labels)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

## Data Preparation and Model Training

Data augmentation for fMRI volumes

In [None]:
class fMRIAugmentation:
    """Random Gaussian noise, axis flip and 90-degree rotation for channel-first volumes.

    Inputs are expected as ``(C, H, W, D)`` tensors; flips and rotations act on the
    spatial axes 1-3. Each of the three transforms fires independently with
    probability ``p``. A rotation swaps the sizes of the two rotated axes, so
    shapes are preserved only for cubic volumes.

    Args:
        p (float): Probability of applying each individual transform.
        noise_std (float): Default standard deviation of the additive Gaussian noise.
    """

    def __init__(self, p=0.5, noise_std=0.01):
        self.p = p
        self.noise_std = noise_std

    def _should_apply(self):
        """Draw one uniform sample and compare it with ``p``."""
        return torch.rand(1).item() < self.p

    def gaussian_noise(self, x, std=None):
        """Add N(0, std^2) noise; ``std`` defaults to ``self.noise_std``."""
        if self._should_apply():
            std = self.noise_std if std is None else std
            return x + torch.randn_like(x) * std
        return x

    def random_flip(self, x):
        """Flip along one randomly chosen spatial axis (1, 2 or 3)."""
        if self._should_apply():
            axis = int(torch.randint(1, 4, (1,)).item())
            return torch.flip(x, [axis])
        return x

    def random_rotate(self, x):
        """Rotate by k*90 degrees (k in 1..3) in a random plane of two spatial axes."""
        if self._should_apply():
            k = int(torch.randint(1, 4, (1,)).item())
            # Plain Python ints: rot90 expects a list of int dims.
            dim1, dim2 = (torch.randperm(3)[:2] + 1).tolist()
            return torch.rot90(x, k, [dim1, dim2])
        return x

    def __call__(self, x):
        x = self.gaussian_noise(x)
        x = self.random_flip(x)
        x = self.random_rotate(x)
        return x

Initialize dataset

In [None]:
print("Initializing dataset...")
# Same location the archive was extracted to in the setup cells.
root_dir = extract_path
augmentation = fMRIAugmentation(p=0.5)
# Keep this dataset un-augmented: the validation and test splits used for evaluation
# are carved out of it below, and random noise/flips/rotations there would make the
# evaluation metrics noisy and non-reproducible.
dataset = fMRIDataset(root_dir)
test_dataset(dataset)

Compute balanced class weights and a weighted sampler for class-balanced batches

In [None]:
# One stage label per sample, then scikit-learn's "balanced" weights
# (inversely proportional to class frequency), one per unique label.
labels = np.array([stage for _, stage, _ in dataset.samples])
unique_labels = np.unique(labels)
class_weights = compute_class_weight(
    'balanced',
    classes=unique_labels,
    y=labels,
)

In [None]:
print("\nClass distribution:")
for label, weight in zip(unique_labels, class_weights):
    # Number of samples carrying this label
    count = (labels == label).sum()
    print(f"Class {label}: {count} samples, weight: {weight:.4f}")

In [None]:
# NOTE(review): this sampler indexes the *full* dataset and is not passed to any
# DataLoader in this notebook, so it currently has no effect on training.
sample_weights = [class_weights[stage] for stage in labels]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

Create data loaders

In [None]:
# 70/15/15 split with a fixed seed
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

# The test split is named `test_split` (not `test_dataset`) so the `test_dataset()`
# helper defined earlier is not shadowed, and the dataset-check cell can be re-run.
train_dataset, val_dataset, test_split = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

batch_size = 16
loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_split, shuffle=False, **loader_kwargs)

In [None]:
print(*(f"Number of {split_name} batches: {len(split_loader)}"
        for split_name, split_loader in (("training", train_loader),
                                         ("validation", val_loader),
                                         ("test", test_loader))),
      sep="\n")

## Model Initialization and Training

Set device

In [None]:
# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Define `setup_training`, which builds the data loaders, model, loss function, optimizer and learning-rate scheduler

In [None]:
def setup_training(transform=None, batch_size=16):
    """Build data loaders, model, loss, optimizer and LR scheduler for training.

    Uses the notebook-level ``root_dir`` and ``device``. The dataset is split 70/15/15
    with seed 42, and only the train and validation loaders are returned. The loss is
    class-weighted with n_samples / (n_classes * class_count).

    Args:
        transform (callable, optional): Augmentation applied to the training split only;
            validation samples are always un-augmented.
        batch_size (int): Mini-batch size for both loaders.

    Returns:
        tuple: (model, criterion, optimizer, scheduler, train_loader, val_loader)
    """
    import copy

    num_classes = 2
    dataset = fMRIDataset(root_dir)

    # minlength keeps one weight per class even if a class is absent from the data
    # (CrossEntropyLoss needs num_classes weights); clamp avoids dividing by zero then.
    labels = torch.tensor([sample[1] for sample in dataset.samples], dtype=torch.long)
    label_counts = torch.bincount(labels, minlength=num_classes).clamp(min=1)
    class_weights = (len(labels) / (num_classes * label_counts.float())).float()

    train_size = int(0.7 * len(dataset))
    val_size = int(0.15 * len(dataset))
    test_size = len(dataset) - train_size - val_size

    train_dataset, val_dataset, _ = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    if transform is not None:
        # A shallow copy shares `samples`, so the split indices still line up;
        # only the training view gets the augmentation.
        train_view = copy.copy(dataset)
        train_view.transform = transform
        train_dataset = torch.utils.data.Subset(train_view, train_dataset.indices)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    model = ImprovedVisionTransformer3D(
        img_size=64,
        patch_size=8,
        in_channels=1,
        num_classes=num_classes,
        embed_dim=256,
        depth=6,
        num_heads=8,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_rate=0.2,
        attn_drop_rate=0.1
    ).to(device)

    criterion = nn.CrossEntropyLoss(
        weight=class_weights.to(device)
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=2e-4,
        weight_decay=0.01,
        betas=(0.9, 0.999)
    )

    # `verbose` is deprecated on PyTorch LR schedulers; the training loop prints the LR itself.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=5,
        min_lr=1e-6
    )

    return model, criterion, optimizer, scheduler, train_loader, val_loader

Train the model

In [None]:
print("Setting up training components...")
# setup_training() re-creates the dataset and its split; the train_loader/val_loader
# variables built in the earlier data-loader cell are reassigned here.
model, criterion, optimizer, scheduler, train_loader, val_loader = setup_training()

In [None]:
print("\nStarting training...")
# Argument order: model, train/val loaders, loss, optimizer, scheduler; 50 epochs on `device`.
history = train_model_improved(
    model, train_loader, val_loader, criterion, optimizer, scheduler,
    num_epochs=50, device=device,
)

Plot training history

In [None]:
# Per-epoch loss and accuracy curves for the train and validation splits.
plot_training_history(history)

## Model Evaluation

Load best model

In [None]:
# The training loop saves a checkpoint dict, not a bare state dict; the weights live
# under 'model_state_dict'. A plain state dict is also accepted.
checkpoint = torch.load('best_model.pth', map_location=device)
model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
model.eval()

Evaluate on test set

In [None]:
# Running totals for the test-set evaluation loop
test_loss, test_correct, test_total = 0.0, 0, 0
all_preds, all_labels = [], []

In [None]:
print("\nEvaluating model on test set...")
# Reset the accumulators here so re-running this cell does not double-count.
test_loss, test_correct, test_total = 0.0, 0, 0
all_preds, all_labels = [], []
model.eval()
with torch.no_grad():
    for inputs, labels, _ in tqdm(test_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        # The dataset marks unreadable volumes with label -1, which CrossEntropyLoss rejects.
        valid = labels >= 0
        if not valid.any():
            continue
        inputs, labels = inputs[valid], labels[valid]

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        test_loss += loss.item()
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

Calculate and print test metrics

In [None]:
# max(..., 1) avoids ZeroDivisionError when no valid test samples were evaluated.
test_acc = 100. * test_correct / max(test_total, 1)
print(f"\nTest Accuracy: {test_acc:.2f}%")
print("\nClassification Report:")
# labels=[0, 1] keeps the report aligned with target_names even if one class never
# appears in the test labels or predictions (otherwise classification_report raises).
print(classification_report(all_labels, all_preds, labels=[0, 1],
                            target_names=['Early', 'Late'], zero_division=0))

Plot confusion matrix

In [None]:
print("\nConfusion Matrix:")
# Rows are true labels, columns predicted labels (0 = Early, 1 = Late).
plot_confusion_matrix(all_labels, all_preds)

## Save Model and Results

Save final model and results

In [None]:
results = {
    'history': history,
    'test_accuracy': test_acc,
    # Stored as plain Python ints: NumPy scalars are rejected by
    # torch.load(weights_only=True), which is the default in recent PyTorch.
    'test_predictions': [int(p) for p in all_preds],
    'test_labels': [int(t) for t in all_labels],
}

torch.save({
    'model_state_dict': model.state_dict(),
    'results': results
}, 'fmri_vit_model_results.pth')

print("\nModel and results saved to 'fmri_vit_model_results.pth'")