# fMRI Learning Stage Classification using Vision Transformers

This notebook implements a Vision Transformer model to classify different stages of learning from fMRI data.

### Abstract

This research project presents an innovative approach to understanding the temporal dynamics of human learning through the analysis of functional Magnetic Resonance Imaging (fMRI) data using Vision Transformers (ViT). By leveraging advanced deep learning architectures, we aim to classify different stages of learning (early, middle, and late) based on neural activation patterns during a classification learning task.

The study uses the "[Classification Learning](https://openfmri.org/dataset/ds000002/)", "[Classification learning and tone-counting](https://openfmri.org/dataset/ds000011/)", "[Classification learning and stop-signal](https://openfmri.org/dataset/ds000017/)", and "[Classification learning and reversal](https://openfmri.org/dataset/ds000052/)" datasets from OpenfMRI, which capture brain activity during a weather prediction task under both probabilistic and deterministic conditions. These datasets provide a unique opportunity to examine how the brain's activity patterns evolve as subjects progress through different learning phases, potentially revealing distinct neural signatures associated with each stage of skill acquisition.

### Introduction

Understanding how the human brain adapts and reorganizes during learning remains a fundamental challenge in cognitive neuroscience. Traditional approaches to analyzing learning-related neural changes often rely on univariate analyses or conventional machine learning methods. However, these approaches may miss complex spatial and temporal patterns that characterize different learning stages.

Our methodology introduces several key innovations:

1. **Vision Transformer Architecture**: By adapting ViT models to process 3D fMRI data, we leverage the transformer's ability to capture long-range dependencies and complex spatial relationships within neural activation patterns. This approach treats brain volumes as sequences of patches, allowing the model to learn hierarchical representations of neural activity patterns.

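2. **Temporal Learning Classification**: The project aims to automatically identify and classify distinct phases of learning (early, middle, and late) based on whole-brain activation patterns. This classification could reveal how neural representations evolve throughout the learning process. The implementation in this notebook uses a two-class simplification: the first run of each task is labeled early and all later runs are labeled late.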

3. **Multi-condition Analysis**: By incorporating both probabilistic and deterministic learning conditions, we can investigate how different types of learning rules affect neural activation patterns and their temporal evolution.

### Expected Outcomes and Impact

This research has the potential to:

- Identify neural markers of learning progression
- Reveal differences in brain activation patterns between probabilistic and deterministic learning
- Provide insights into individual differences in learning trajectories
- Demonstrate the effectiveness of transformer-based architectures in neuroimaging analysis

By successfully classifying learning stages from fMRI data, this work could contribute to our understanding of skill acquisition and learning optimization, with potential applications in educational neuroscience and cognitive rehabilitation.

## Setup and Dependencies

#### Install required packages

In [None]:
!pip install einops nibabel seaborn tqdm monai matplotlib nilearn plotly

#### Import libraries

In [None]:
import os
import re
import numpy as np
import nibabel as nib
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from einops import rearrange, repeat
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import urllib.request
import zipfile
import tarfile
from google.colab import drive

#### Set random seeds

In [None]:
def set_seeds(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

set_seeds()

#### Device configuration

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

#### General configuration

In [None]:
CONFIG = {
    'patch_size': 8,
    'hidden_dim': 512,      # Reduced from 768
    'num_heads': 8,         # Reduced from 12
    'num_layers': 6,        # Reduced from 12
    'mlp_dim': 1024,        # Reduced from 3072
    'dropout': 0.2,         # Increased from 0.1
    'learning_rate': 5e-4,  # Increased from 1e-4
    'weight_decay': 0.05,   # Increased from 0.01
    'batch_size': 16,       # Increased from 8
    'epochs': 50,
    'warmup_steps': 50
}

In [None]:
print(f"Using device: {device}")
print("\nConfiguration:")
for k, v in CONFIG.items():
    print(f"{k}: {v}")

## Data Loading and Preprocessing

#### Download and extract dataset

In [None]:
class DownloadProgressBar(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)

In [None]:
def download_url(url, output_path):
    with DownloadProgressBar(unit='B', unit_scale=True,
                           miniters=1, desc=url.split('/')[-1]) as t:
        urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)

In [None]:
def check_dataset_extracted(extract_path, dataset_id):
    """Check if dataset is properly extracted by looking for expected structure"""
    extract_path = Path(extract_path)

    # Check common indicators of proper extraction
    indicators = {
        'ds000002': ['sub-01', 'sub-02', 'dataset_description.json'],
        'ds000011': ['sub-01', 'sub-02', 'dataset_description.json'],
        'ds000017': ['sub-01', 'sub-02', 'dataset_description.json'],
        'ds000052': ['sub-01', 'sub-02', 'dataset_description.json']
    }

    # First check if the path exists
    if not extract_path.exists():
        return False

    # Look for dataset in subfolders if not found in main directory
    possible_roots = [extract_path] + list(extract_path.glob('*'))

    for root in possible_roots:
        if not root.is_dir():
            continue

        # Check for expected files/folders
        found_indicators = []
        for indicator in indicators[dataset_id]:
            if any(Path(p).name == indicator for p in root.glob('*')):
                found_indicators.append(indicator)

        # If we found at least 2 indicators, consider it properly extracted
        if len(found_indicators) >= 2:
            print(f"Found valid dataset structure in: {root}")
            return True

    return False

In [None]:
def setup_datasets(base_path):
    """Download and extract multiple fMRI datasets"""
    # Create base directory if it doesn't exist
    base_path = Path(base_path)
    base_path.mkdir(parents=True, exist_ok=True)

    # Dataset information
    datasets = {
        'ds000002': {
            'url': 'https://s3.amazonaws.com/openneuro/ds000002/ds000002_R2.0.5/compressed/ds000002_R2.0.5_raw.zip',
            'filename': 'ds000002_R2.0.5_raw.zip',
            'extract_dir': 'ds000002'
        },
        'ds000011': {
            'url': 'https://s3.amazonaws.com/openneuro/ds000011/ds000011_R2.0.1/compressed/ds000011_R2.0.1_raw.zip',
            'filename': 'ds000011_R2.0.1_raw.zip',
            'extract_dir': 'ds000011'
        },
        'ds000017': {
            'url': 'https://s3.amazonaws.com/openneuro/ds000017/ds000017_R2.0.1/compressed/ds000017_R2.0.1.zip',
            'filename': 'ds000017_R2.0.1.zip',
            'extract_dir': 'ds000017'
        },
        'ds000052': {
            'url': 'https://s3.amazonaws.com/openneuro/ds000052/ds000052_R2.0.0/compressed/ds052_R2.0.0_01-14.tgz',
            'filename': 'ds052_R2.0.0_01-14.tgz',
            'extract_dir': 'ds000052'
        }
    }

    # Process each dataset
    dataset_paths = {}
    for dataset_id, info in datasets.items():
        print(f"\nProcessing {dataset_id}...")

        # Setup paths
        zip_path = base_path / info['filename']
        extract_path = base_path / 'fmri_data' / info['extract_dir']
        dataset_paths[dataset_id] = extract_path

        # Download if needed
        if not zip_path.exists():
            print(f"Downloading {dataset_id}...")
            try:
                download_url(info['url'], zip_path)
                print("Download complete!")
            except Exception as e:
                print(f"Error downloading {dataset_id}: {str(e)}")
                continue
        else:
            print(f"Found existing download for {dataset_id}")

        # Extract if needed
        if not extract_path.exists():
            print(f"Extracting {dataset_id}...")
            try:
                extract_path.parent.mkdir(parents=True, exist_ok=True)

                if zip_path.suffix == '.tgz':
                    with tarfile.open(zip_path, 'r:gz') as tar_ref:
                        tar_ref.extractall(extract_path)
                else:
                    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                        zip_ref.extractall(extract_path)

                print("Extraction complete!")
            except Exception as e:
                print(f"Error extracting {dataset_id}: {str(e)}")
                continue
        else:
            print(f"Found existing extracted data for {dataset_id}")

    return dataset_paths

#### Mount Google Drive

In [None]:
drive.mount('/content/drive')

#### Setup base path

In [None]:
base_path = Path('/content/drive/MyDrive/learnedSpectrum')

#### Download and extract all datasets

In [None]:
dataset_paths = setup_datasets(base_path)

In [None]:
print("\nDataset locations:")
for dataset_id, path in dataset_paths.items():
    print(f"{dataset_id}: {path}")

#### FMRI Volume Loader

In [None]:
def load_fmri_volume(file_path):
    """Load and preprocess a single fMRI volume"""
    # Load nifti file
    img = nib.load(file_path)
    data = img.get_fdata()

    # Handle 4D data (average the middle ~20% of timepoints)
    if len(data.shape) == 4:
        mid = data.shape[-1] // 2
        window = max(1, data.shape[-1] // 10)  # keep at least one frame for short runs
        data = data[..., mid-window:mid+window]
        data = np.mean(data, axis=-1)

    # Normalize
    data = (data - np.percentile(data, 5)) / (np.percentile(data, 95) - np.percentile(data, 5) + 1e-8)

    # Basic brain extraction
    mask = data > np.percentile(data, 20)
    data = data * mask

    return data
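The temporal windowing and percentile scaling used by the loader can be checked on synthetic data. The sketch below assumes a made-up 4D array in place of a real NIfTI file (the array shape and random values are illustrative only) and mirrors the two steps: average the middle 20% of frames, then rescale so the 5th percentile maps to roughly 0 and the 95th to roughly 1.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic 4D "BOLD" array shaped (x, y, z, time); real data would come from nibabel
bold = rng.normal(loc=100.0, scale=10.0, size=(4, 4, 4, 50))

n_timepoints = bold.shape[-1]
mid = n_timepoints // 2       # index 25
window = n_timepoints // 10   # 5 frames on each side of the midpoint
mean_volume = bold[..., mid - window:mid + window].mean(axis=-1)  # frames 20 to 29

# Robust scaling: the 5th percentile maps to about 0, the 95th to about 1
p5, p95 = np.percentile(mean_volume, [5, 95])
scaled = (mean_volume - p5) / (p95 - p5 + 1e-8)

print(mean_volume.shape)          # (4, 4, 4)
print(2 * window / n_timepoints)  # 0.2
```

Note that the scaled values are not clipped, so voxels below the 5th or above the 95th percentile fall outside the [0, 1] range.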

In [None]:
print("FMRI loader test:")
test_path = list(Path(dataset_paths['ds000002']).rglob('*bold.nii.gz'))[0]
test_data = load_fmri_volume(test_path)
print(f"Loaded volume shape: {test_data.shape}")

#### Volume Patchification

In [None]:
def create_patches(volume, patch_size):
    """Convert a 3D volume into flattened cubic patches of shape (n_patches, patch_size**3)"""
    # Ensure volume dimensions are divisible by patch_size
    pad_h = (patch_size - volume.shape[0] % patch_size) % patch_size
    pad_w = (patch_size - volume.shape[1] % patch_size) % patch_size
    pad_d = (patch_size - volume.shape[2] % patch_size) % patch_size

    # Pad volume
    volume = np.pad(volume,
                   ((0, pad_h), (0, pad_w), (0, pad_d)),
                   mode='constant')

    # Create patches using einops
    patches = rearrange(volume,
                       '(h p1) (w p2) (d p3) -> (h w d) (p1 p2 p3)',
                       p1=patch_size, p2=patch_size, p3=patch_size)

    return patches

In [None]:
test_patches = create_patches(test_data, CONFIG['patch_size'])
print("\nPatch creation test:")
print(f"Input shape: {test_data.shape}")
print(f"Output patches shape: {test_patches.shape}")
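To make the patch layout explicit, the same rearrangement can be written with plain NumPy reshapes and inverted again. This is a minimal sketch, assuming the volume sides are already multiples of the patch size; `patchify` and `unpatchify` are hypothetical helper names introduced here and are not part of the notebook.

```python
import numpy as np

def patchify(volume, p):
    """Split a 3D volume whose sides are multiples of p into flattened p*p*p patches."""
    h, w, d = (s // p for s in volume.shape)
    # (h p1) (w p2) (d p3) -> h w d p1 p2 p3
    blocks = volume.reshape(h, p, w, p, d, p).transpose(0, 2, 4, 1, 3, 5)
    return blocks.reshape(h * w * d, p ** 3)

def unpatchify(patches, shape, p):
    """Inverse of patchify: rebuild the volume from its flattened patches."""
    h, w, d = (s // p for s in shape)
    # h w d p1 p2 p3 -> (h p1) (w p2) (d p3)
    blocks = patches.reshape(h, w, d, p, p, p).transpose(0, 3, 1, 4, 2, 5)
    return blocks.reshape(shape)

volume = np.arange(8 * 8 * 4, dtype=np.float32).reshape(8, 8, 4)
patches = patchify(volume, 4)
print(patches.shape)  # (4, 64)
```

Each row of `patches` holds one cube of voxels in row-major order, so the first row equals `volume[:4, :4, :4].ravel()`. A round trip through `unpatchify` recovers the original volume, which is a convenient sanity check before feeding patches to the model.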

#### Dataset Creation

In [None]:
def get_task_files(dataset_path):
    """Get all task-related fMRI files with their stages"""
    files = []
    for bold_file in Path(dataset_path).rglob('*bold.nii.gz'):
        # Skip if not a task file
        if 'task-' not in str(bold_file):
            continue

        # Get run number
        run_match = re.search(r'run-(\d+)', str(bold_file))
        if not run_match:
            continue
        run_num = int(run_match.group(1))

        # Assign early/late stage
        if run_num == 1:
            stage = 0  # early
        elif run_num > 1:
            stage = 1  # late
        else:
            continue

        files.append((bold_file, stage))

    return files

In [None]:
test_files = get_task_files(dataset_paths['ds000002'])
print("\nFile collection test:")
print(f"Found {len(test_files)} task files")
print("Sample entries:")
for f, s in test_files[:3]:
    print(f"File: {f.name}, Stage: {s}")
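The labeling rule can be exercised without touching the file system. The sketch below applies the same `run-(\d+)` pattern to a few made-up BIDS-style file names; `stage_from_filename` is a hypothetical helper and the file names are illustrative, not taken from the datasets.

```python
import re

def stage_from_filename(name):
    """Return 0 (early) for run 1, 1 (late) for later runs, None if there is no valid run tag."""
    match = re.search(r'run-(\d+)', name)
    if match is None:
        return None
    run_num = int(match.group(1))
    if run_num < 1:
        return None
    return 0 if run_num == 1 else 1

examples = [
    'sub-01_task-example_run-01_bold.nii.gz',  # first run: early
    'sub-01_task-example_run-02_bold.nii.gz',  # later run: late
    'sub-01_task-example_bold.nii.gz',         # no run tag: skipped
]
for name in examples:
    print(name, '->', stage_from_filename(name))
```

Because every run after the first maps to the same label, tasks with many runs will produce more late samples than early ones, which is why class weighting is applied later in the pipeline.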

## Vision Transformer Components

#### Position Embedding

In [None]:
def create_position_embeddings(n_patches, hidden_dim):
    """Create learnable position embeddings"""
    pos_embeddings = nn.Parameter(torch.randn(1, n_patches + 1, hidden_dim))
    return pos_embeddings

#### CLS Token

In [None]:
def create_cls_token(hidden_dim):
    """Create learnable classification token"""
    cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
    return cls_token

In [None]:
test_n_patches = test_patches.shape[0]  # from previous section
pos_embed = create_position_embeddings(test_n_patches, CONFIG['hidden_dim'])
cls_token = create_cls_token(CONFIG['hidden_dim'])
print(f"Position embedding shape: {pos_embed.shape}")
print(f"CLS token shape: {cls_token.shape}")

#### Multi-Head Attention Block

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(hidden_dim, hidden_dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape

        # Generate Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Compute attention scores
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Apply attention to V
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

In [None]:
attention = MultiHeadAttention(CONFIG['hidden_dim'], CONFIG['num_heads']).to(device)
test_input = torch.randn(2, test_n_patches + 1, CONFIG['hidden_dim']).to(device)
test_output = attention(test_input)
print(f"\nAttention test:")
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")
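For reference, the computation performed inside each attention head can be written in a few lines of NumPy. This is a single-head sketch with no batching, dropout, or learned projections; the token count of 5 is arbitrary, and the head dimension of 64 is chosen to match `hidden_dim / num_heads` in `CONFIG`.

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """Compute softmax(q k^T / sqrt(d)) v for arrays shaped (tokens, head_dim)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # stabilise the softmax numerically
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # each row now sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(5, 64)) for _ in range(3))
out, weights = scaled_dot_product_attention(q, k, v)
print(out.shape, weights.shape)  # (5, 64) (5, 5)
```

Each output token is a weighted average of the value vectors, with weights determined by query-key similarity. This is what lets a patch attend to any other patch in the volume regardless of spatial distance.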

#### MLP Block

In [None]:
class MLPBlock(nn.Module):
    def __init__(self, hidden_dim, mlp_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, hidden_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

In [None]:
mlp = MLPBlock(CONFIG['hidden_dim'], CONFIG['mlp_dim']).to(device)
test_output = mlp(test_input)
print(f"\nMLP test:")
print(f"Output shape: {test_output.shape}")

#### Transformer Encoder Block

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.attn = MultiHeadAttention(hidden_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

In [None]:
transformer = TransformerBlock(
    CONFIG['hidden_dim'],
    CONFIG['num_heads'],
    CONFIG['mlp_dim']
).to(device)
test_output = transformer(test_input)
print(f"\nTransformer block test:")
print(f"Output shape: {test_output.shape}")

#### Patch Embedding

In [None]:
class PatchEmbedding(nn.Module):
    def __init__(self, patch_dim, hidden_dim):
        super().__init__()
        self.projection = nn.Linear(patch_dim, hidden_dim)

    def forward(self, patches):
        # Ensure input is 3D [B, N, D]
        if len(patches.shape) == 2:
            patches = patches.unsqueeze(0)

        return self.projection(patches)

In [None]:
patch_dim = CONFIG['patch_size'] ** 3  # cubic patches
patch_embed = PatchEmbedding(patch_dim, CONFIG['hidden_dim']).to(device)
test_patches_tensor = torch.FloatTensor(test_patches).to(device)
embedded_patches = patch_embed(test_patches_tensor)
print(f"\nPatch embedding test:")
print(f"Input shape: {test_patches_tensor.shape}")
print(f"Embedded shape: {embedded_patches.shape}")

In [None]:
def calculate_n_patches(volume_shape, patch_size):
    """Calculate number of patches based on input volume shape"""
    h, w, d = volume_shape

    # Add padding if needed
    h = h + (patch_size - h % patch_size) % patch_size
    w = w + (patch_size - w % patch_size) % patch_size
    d = d + (patch_size - d % patch_size) % patch_size

    # Calculate number of patches
    n_patches = (h // patch_size) * (w // patch_size) * (d // patch_size)
    return n_patches

#### Complete Vision Transformer

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self, config, volume_shape=(64, 64, 32)):
        super().__init__()
        self.config = config
        patch_dim = config['patch_size'] ** 3

        # Calculate number of patches
        self.n_patches = calculate_n_patches(volume_shape, config['patch_size'])
        print(f"Number of patches: {self.n_patches}")

        # Layers
        self.patch_embed = PatchEmbedding(patch_dim, config['hidden_dim'])
        self.cls_token = create_cls_token(config['hidden_dim'])
        self.pos_embed = create_position_embeddings(
            self.n_patches,
            config['hidden_dim']
        )
        self.dropout = nn.Dropout(config['dropout'])

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(
                config['hidden_dim'],
                config['num_heads'],
                config['mlp_dim'],
                config['dropout']
            )
            for _ in range(config['num_layers'])
        ])

        self.norm = nn.LayerNorm(config['hidden_dim'])
        self.head = nn.Linear(config['hidden_dim'], 2)

    def forward(self, patches):
        # Ensure correct input shape
        if len(patches.shape) == 2:  # [N, D]
            patches = patches.unsqueeze(0)  # Add batch dimension [1, N, D]
        elif len(patches.shape) == 3 and patches.shape[1] == 1:  # [B, 1, N*D]
            patches = patches.squeeze(1)  # Remove singleton dimension

        B = patches.shape[0]  # Batch size

        # Embed patches
        x = self.patch_embed(patches)  # [B, N, hidden_dim]

        # Add CLS token
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=B)  # [B, 1, hidden_dim]
        x = torch.cat([cls_tokens, x], dim=1)  # [B, N+1, hidden_dim]

        # Add position embeddings
        x = x + self.pos_embed
        x = self.dropout(x)

        # Apply transformer blocks
        for block in self.transformer_blocks:
            x = block(x)

        # Classification
        x = self.norm(x)
        x = x[:, 0]  # Take CLS token
        x = self.head(x)

        return x

In [None]:
test_patches = create_patches(test_data, CONFIG['patch_size'])
test_patches_tensor = torch.FloatTensor(test_patches).to(device)
print(f"Input shape (single): {test_patches_tensor.shape}")

In [None]:
model = VisionTransformer(CONFIG, volume_shape=(64, 64, 32)).to(device)
test_output = model(test_patches_tensor)
print(f"Output shape (single): {test_output.shape}")

In [None]:
batch_size = 4
test_batch = test_patches_tensor.unsqueeze(0).repeat(batch_size, 1, 1)
print(f"Input shape (batch): {test_batch.shape}")
test_output = model(test_batch)
print(f"Output shape (batch): {test_output.shape}")

## Dataset and Training Pipeline

#### FMRI Dataset Class

In [None]:
def pad_volume_to_size(volume, target_size=(64, 64, 32)):
    """Pad volume to target size"""
    pad_h = max(0, target_size[0] - volume.shape[0])
    pad_w = max(0, target_size[1] - volume.shape[1])
    pad_d = max(0, target_size[2] - volume.shape[2])

    padded = np.pad(volume,
                    ((0, pad_h), (0, pad_w), (0, pad_d)),
                    mode='constant')

    # If larger than target size, crop
    padded = padded[:target_size[0], :target_size[1], :target_size[2]]

    return padded
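The pad-then-crop behavior is worth verifying on a volume that is smaller than the target on some axes and larger on others. The sketch below uses a hypothetical helper `pad_or_crop` and a made-up input shape of (53, 63, 46); neither comes from the datasets.

```python
import numpy as np

def pad_or_crop(volume, target):
    """Zero-pad at the end of each axis up to target, then crop anything beyond it."""
    pads = [(0, max(0, t - s)) for s, t in zip(volume.shape, target)]
    padded = np.pad(volume, pads, mode='constant')
    return padded[tuple(slice(0, t) for t in target)]

volume = np.ones((53, 63, 46))
out = pad_or_crop(volume, (64, 64, 32))
print(out.shape)       # (64, 64, 32)
print(int(out.sum()))  # 106848, i.e. 53 * 63 * 32 original voxels kept
```

Padding and cropping both act on the trailing end of each axis, so the brain is not re-centered. A volume with more than 32 slices loses its last slices, which may matter if the acquisitions differ in slice count across datasets.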

In [None]:
class FMRIDataset(Dataset):
    def __init__(self, file_paths, labels, patch_size=8, augment=False, target_size=(64, 64, 32)):
        self.file_paths = file_paths
        self.labels = labels
        self.patch_size = patch_size
        self.augment = augment
        self.target_size = target_size

    def __len__(self):
        return len(self.file_paths)

    @staticmethod
    def get_class_weights(labels):
        counts = np.bincount(labels)
        total = len(labels)
        weights = total / (len(counts) * counts)
        weights = torch.FloatTensor(weights)
        print(f"Class weights: {weights}")
        return weights

    @staticmethod
    def apply_augmentation(volume):
        """Apply random flips, an in-plane rotation, and intensity perturbations"""
        from scipy.ndimage import rotate  # local import: scipy is not loaded in the import cell

        # Random flip
        if np.random.random() > 0.5:
            volume = np.flip(volume, axis=0)
        if np.random.random() > 0.5:
            volume = np.flip(volume, axis=1)
        if np.random.random() > 0.5:
            volume = np.flip(volume, axis=2)

        # Random rotation with interpolation
        angle = np.random.uniform(-15, 15)
        volume = rotate(volume, angle, axes=(0, 1), reshape=False)

        # Random scaling
        scale = np.random.uniform(0.8, 1.2)
        volume = volume * scale

        # Add random noise
        noise = np.random.normal(0, 0.05, volume.shape)
        volume = volume + noise

        # Random intensity shift
        shift = np.random.uniform(-0.1, 0.1)
        volume = volume + shift

        # Random contrast
        contrast = np.random.uniform(0.8, 1.2)
        mean = volume.mean()
        volume = (volume - mean) * contrast + mean

        return volume

    def __getitem__(self, idx):
        # Load and preprocess volume
        volume = load_fmri_volume(self.file_paths[idx])
        volume = pad_volume_to_size(volume, self.target_size)

        if self.augment:
            volume = self.apply_augmentation(volume)

        # Create patches
        patches = create_patches(volume, self.patch_size)
        patches = torch.FloatTensor(patches)

        return patches, torch.tensor(self.labels[idx])

In [None]:
def collate_fn(batch):
    """Custom collate function to handle batching properly"""
    patches = [item[0] for item in batch]
    labels = [item[1] for item in batch]

    # Stack patches and labels
    patches = torch.stack(patches)
    labels = torch.stack(labels)

    return patches, labels

#### Data Collection

In [None]:
def collect_dataset_files():
    all_files = []
    all_labels = []

    for dataset_id, path in dataset_paths.items():
        print(f"\nProcessing {dataset_id}...")
        files = get_task_files(path)
        print(f"Found {len(files)} files")

        for file_path, label in files:
            all_files.append(file_path)
            all_labels.append(label)

    return all_files, all_labels

In [None]:
files, labels = collect_dataset_files()
print(f"\nTotal samples: {len(files)}")
print(f"Class distribution: {np.bincount(labels)}")

#### Train-Val Split

In [None]:
def create_data_splits(files, labels, config):
    # Create train/val split
    train_files, val_files, train_labels, val_labels = train_test_split(
        files, labels,
        test_size=0.2,
        stratify=labels,
        random_state=42
    )

    # Target size for all volumes
    target_size = (64, 64, 32)  # This ensures consistent size

    # Create datasets
    train_dataset = FMRIDataset(
        train_files,
        train_labels,
        patch_size=config['patch_size'],
        augment=True,
        target_size=target_size
    )

    val_dataset = FMRIDataset(
        val_files,
        val_labels,
        patch_size=config['patch_size'],
        augment=False,
        target_size=target_size
    )

    return train_dataset, val_dataset

In [None]:
train_dataset, val_dataset = create_data_splits(files, labels, CONFIG)
print(f"\nTrain samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

#### Training Components

In [None]:
def create_dataloaders(train_dataset, val_dataset, batch_size):
    # Load batches in the main process (num_workers=0), so persistent workers are disabled
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # Run in main process
        pin_memory=True,
        collate_fn=collate_fn,
        persistent_workers=False
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,  # Run in main process
        pin_memory=True,
        collate_fn=collate_fn,
        persistent_workers=False
    )

    return train_loader, val_loader

In [None]:
def create_optimizer(model, config):
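    # Split parameters into two groups: weights that receive weight decay,
    # and biases / LayerNorm parameters that do not
    no_decay = ['bias', 'norm']  # LayerNorm modules in this model are named norm, norm1, norm2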
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters()
                      if not any(nd in n for nd in no_decay)],
            'weight_decay': config['weight_decay']
        },
        {
            'params': [p for n, p in model.named_parameters()
                      if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }
    ]

    optimizer = torch.optim.AdamW(
        optimizer_grouped_parameters,
        lr=config['learning_rate'],
        betas=(0.9, 0.999),
        eps=1e-8
    )

    return optimizer

In [None]:
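# get_cosine_schedule_with_warmup is provided by the Hugging Face transformers package
# (install with `pip install transformers` if it is not already available)
from transformers import get_cosine_schedule_with_warmup

def create_scheduler(optimizer, config, num_training_steps):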
    return get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config['warmup_steps'],
        num_training_steps=num_training_steps,
        num_cycles=0.5
    )
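The shape of a cosine schedule with linear warmup can be sketched in pure Python. The function below is an illustration of the usual formula, not the library's source: `lr_multiplier` is a hypothetical name, and the total of 1000 steps is an assumed value (only `warmup_steps=50` comes from `CONFIG`). The returned factor is multiplied by the base learning rate at each optimizer step.

```python
import math

def lr_multiplier(step, warmup_steps, total_steps, num_cycles=0.5):
    """Linear warmup from 0 to 1, then cosine decay from 1 towards 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

for step in (0, 25, 50, 525, 1000):
    print(step, round(lr_multiplier(step, warmup_steps=50, total_steps=1000), 4))
```

With these settings the factor rises linearly to 1.0 at step 50, passes 0.5 halfway through the decay phase (step 525), and reaches 0 at step 1000. Because the schedule is defined in optimizer steps, `scheduler.step()` is expected once per batch, and `num_training_steps` should equal the number of batches per epoch times the number of epochs.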

In [None]:
def get_class_weights(labels):
    counts = np.bincount(labels)
    total = len(labels)
    weights = total / (len(counts) * counts)
    weights = torch.FloatTensor(weights)
    print(f"Class weights: {weights}")
    return weights
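A small worked example shows what the inverse-frequency formula produces. The label counts below (20 early runs, 60 late runs) are hypothetical and chosen only to make the arithmetic easy to follow.

```python
import numpy as np

labels = np.array([0] * 20 + [1] * 60)  # hypothetical: 20 early runs, 60 late runs
counts = np.bincount(labels)            # [20 60]
weights = len(labels) / (len(counts) * counts)

print(weights)           # approximately [2.0, 0.667]
print(counts * weights)  # [40. 40.]: both classes carry equal total weight
```

The minority class receives the larger weight, so that each class contributes the same total weight to the loss. The formula divides by the per-class count, so it assumes every class appears at least once in `labels`.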

In [None]:
class_weights = get_class_weights(labels)
criterion = nn.CrossEntropyLoss(
    weight=class_weights.to(device),
    label_smoothing=0.1
)

#### Create data loaders

In [None]:
train_loader, val_loader = create_dataloaders(
    train_dataset,
    val_dataset,
    CONFIG['batch_size']
)

#### Training Functions

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, scheduler):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='Training')
    for patches, labels in pbar:
        patches = patches.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

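        # Use mixed precision on GPU; autocast is disabled when running on CPU
        with torch.autocast(device_type=device.type, enabled=(device.type == 'cuda')):
            outputs = model(patches)
            loss = criterion(outputs, labels)

        loss.backward()

        # Gradient clipping (must run after backward(), once gradients exist)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step()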

        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        total_loss += loss.item()

        pbar.set_postfix({
            'loss': f'{total_loss/(pbar.n+1):.4f}',
            'acc': f'{100.*correct/total:.2f}%',
            'lr': f'{scheduler.get_last_lr()[0]:.6f}'
        })

    return total_loss / len(train_loader), correct / total

In [None]:
@torch.no_grad()
def validate(model, val_loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    for patches, labels in tqdm(val_loader, desc='Validation'):
        patches = patches.to(device)
        labels = labels.to(device)

        outputs = model(patches)
        loss = criterion(outputs, labels)

        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        total_loss += loss.item()

    return total_loss / len(val_loader), correct / total

## Training Loop and Visualization

#### Training History Tracker

In [None]:
class TrainingHistory:
    def __init__(self):
        self.train_loss = []
        self.train_acc = []
        self.val_loss = []
        self.val_acc = []
        self.best_acc = 0
        self.best_epoch = 0

    def update(self, train_metrics, val_metrics, epoch):
        train_loss, train_acc = train_metrics
        val_loss, val_acc = val_metrics

        self.train_loss.append(train_loss)
        self.train_acc.append(train_acc)
        self.val_loss.append(val_loss)
        self.val_acc.append(val_acc)

        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.best_epoch = epoch
            return True
        return False

history = TrainingHistory()

#### Visualization Functions

In [None]:
def plot_training_history(history):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Plot loss
    ax1.plot(history.train_loss, label='Train')
    ax1.plot(history.val_loss, label='Validation')
    ax1.set_title('Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)

    # Plot accuracy
    ax2.plot(history.train_acc, label='Train')
    ax2.plot(history.val_acc, label='Validation')
    ax2.set_title('Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.show()

In [None]:
def plot_confusion_matrix(true_labels, pred_labels):
    from sklearn.metrics import confusion_matrix
    import seaborn as sns

    cm = confusion_matrix(true_labels, pred_labels)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Early', 'Late'],
                yticklabels=['Early', 'Late'])
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()

#### Training Loop

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs):
    history = TrainingHistory()
    best_val_acc = 0
    patience = 10
    patience_counter = 0

    try:
        for epoch in range(epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")

            # Training phase
            model.train()
            train_loss = 0
            train_correct = 0
            train_total = 0

            # Use tqdm with leave=True to keep the progress bar after the epoch
            train_pbar = tqdm(train_loader, desc='Training', leave=True)
            for batch_idx, (patches, labels) in enumerate(train_pbar, start=1):
                # Move to device
                patches = patches.to(device)
                labels = labels.to(device)

                # Forward pass
                optimizer.zero_grad()
                outputs = model(patches)
                loss = criterion(outputs, labels)

                # Backward pass
                loss.backward()
                optimizer.step()

                # Update metrics
                _, predicted = outputs.max(1)
                train_total += labels.size(0)
                train_correct += predicted.eq(labels).sum().item()
                train_loss += loss.item()

                # Update progress bar: running mean of per-batch losses,
                # consistent with the per-epoch average (sum / number of batches)
                train_pbar.set_postfix({
                    'loss': f'{train_loss/batch_idx:.4f}',
                    'acc': f'{100.*train_correct/train_total:.2f}%'
                })

            # Calculate epoch metrics
            train_loss = train_loss / len(train_loader)
            train_acc = train_correct / train_total

            # Validation phase
            model.eval()
            val_loss = 0
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                val_pbar = tqdm(val_loader, desc='Validation', leave=True)
                for batch_idx, (patches, labels) in enumerate(val_pbar, start=1):
                    patches = patches.to(device)
                    labels = labels.to(device)

                    outputs = model(patches)
                    loss = criterion(outputs, labels)

                    _, predicted = outputs.max(1)
                    val_total += labels.size(0)
                    val_correct += predicted.eq(labels).sum().item()
                    val_loss += loss.item()

                    # Running mean of per-batch losses (sum / batches seen so far)
                    val_pbar.set_postfix({
                        'loss': f'{val_loss/batch_idx:.4f}',
                        'acc': f'{100.*val_correct/val_total:.2f}%'
                    })

            # Calculate validation metrics
            val_loss = val_loss / len(val_loader)
            val_acc = val_correct / val_total

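            # Update learning rate (stepped once per epoch here; if the
            # scheduler was built for per-batch steps, e.g. from
            # num_training_steps, call step() inside the batch loop instead)
            if scheduler is not None:
                scheduler.step()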

            # Update history
            is_best = history.update((train_loss, train_acc), (val_loss, val_acc), epoch)

            # Print epoch summary
            print("\nEpoch Summary:")
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
            print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

            # Save best model
            if is_best:
                print("New best model saved!")
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_acc': val_acc,
                }, 'best_model.pth')
                patience_counter = 0
            else:
                patience_counter += 1

            # Early stopping check
            if patience_counter >= patience:
                print("\nEarly stopping triggered!")
                break

    except KeyboardInterrupt:
        print("\nTraining interrupted by user!")

    return history
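
The patience logic inside `train_model` can be looked at in isolation. The sketch below is a minimal, hypothetical `EarlyStopping` helper (not part of this notebook) that mirrors the `patience_counter` bookkeeping. It assumes that only a strictly higher validation accuracy counts as an improvement, which may differ from what `history.update` does, and the accuracy values are made up for illustration.

```python
class EarlyStopping:
    """Hypothetical helper mirroring the patience counter in train_model."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best = float("-inf")
        self.counter = 0

    def step(self, metric):
        """Record one epoch's metric; return True when training should stop."""
        if metric > self.best:
            # Improvement: remember it and reset the counter
            self.best = metric
            self.counter = 0
        else:
            # No improvement this epoch
            self.counter += 1
        return self.counter >= self.patience


# Toy validation accuracies (illustrative only): progress stalls after epoch 2.
val_accs = [0.50, 0.55, 0.60, 0.59, 0.60, 0.58]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, acc in enumerate(val_accs):
    if stopper.step(acc):
        stopped_at = epoch
        break
```

With `patience=3`, the loop stops after three consecutive epochs without a new best value. A tie with the current best does not reset the counter under this assumption.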

#### Model Evaluation

In [None]:
@torch.no_grad()
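def evaluate_model(model, dataloader):
    """Run inference over `dataloader`; return (predictions, true labels) as NumPy arrays."""
    model.eval()
    all_preds = []
    all_labels = []

    for patches, labels in tqdm(dataloader, desc='Evaluating'):
        patches = patches.to(device)
        outputs = model(patches)
        _, preds = outputs.max(1)

        all_preds.extend(preds.cpu().numpy())
        # .cpu() is a no-op for CPU tensors and keeps this safe if labels are on the GPU
        all_labels.extend(labels.cpu().numpy())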

    return np.array(all_preds), np.array(all_labels)
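
The two arrays returned by `evaluate_model` are enough to compute simple metrics without any plotting. As a rough sketch, the hypothetical helper `per_class_recall` below (not defined elsewhere in this notebook) computes the fraction of correctly classified samples per class, assuming integer labels `0..num_classes-1`. Note that it takes true labels first, whereas `evaluate_model` returns predictions first. The arrays are toy values, not results.

```python
import numpy as np


def per_class_recall(true_labels, pred_labels, num_classes=2):
    """Fraction of samples of each class that were predicted correctly."""
    true_labels = np.asarray(true_labels)
    pred_labels = np.asarray(pred_labels)
    recalls = np.full(num_classes, np.nan)  # NaN marks classes with no samples
    for c in range(num_classes):
        mask = true_labels == c
        if mask.any():
            recalls[c] = (pred_labels[mask] == c).mean()
    return recalls


# Toy labels (illustrative only): 0 = early stage, 1 = late stage
true = np.array([0, 0, 0, 1, 1])
pred = np.array([0, 0, 1, 1, 0])
recalls = per_class_recall(true, pred)
overall_acc = (true == pred).mean()
```

Per-class values are worth checking alongside overall accuracy, since a model that favors one stage can score well overall on an imbalanced split.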

#### Final Debug

In [None]:
print("\nTesting data pipeline:")
test_batch = next(iter(train_loader))
patches, labels = test_batch
print(f"Batch patches shape: {patches.shape}")
print(f"Batch labels shape: {labels.shape}")

In [None]:
print("\nPatches statistics:")
print(f"Min value: {patches.min().item():.4f}")
print(f"Max value: {patches.max().item():.4f}")
print(f"Mean value: {patches.mean().item():.4f}")
print(f"Std value: {patches.std().item():.4f}")

In [None]:
unique_labels, counts = labels.unique(return_counts=True)
print("\nLabel distribution in batch:")
for label, count in zip(unique_labels.tolist(), counts.tolist()):
    print(f"Label {label}: {count}")

In [None]:
print("\nTesting forward pass:")
model = VisionTransformer(CONFIG, volume_shape=(64, 64, 32)).to(device)
optimizer = create_optimizer(model, CONFIG)
num_training_steps = len(train_loader) * CONFIG['epochs']
scheduler = create_scheduler(optimizer, CONFIG, num_training_steps)
# Shape check only, so skip building the autograd graph
with torch.no_grad():
    test_output = model(patches.to(device))
print(f"Model output shape: {test_output.shape}")

#### Run Training

In [None]:
print("\nStarting training...")
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=CONFIG['epochs']
)

Plot training history

In [None]:
plot_training_history(history)

Load best model and evaluate

In [None]:
print("\nEvaluating best model...")
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
checkpoint = torch.load('best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
pred_labels, true_labels = evaluate_model(model, val_loader)

Plot confusion matrix

In [None]:
plot_confusion_matrix(true_labels, pred_labels)

Print final metrics

In [None]:
print("\nClassification Report:")
print(classification_report(true_labels, pred_labels,
                            target_names=['Early Stage', 'Late Stage']))

## Save Results

In [None]:
def save_results(history, model_path='best_model.pth'):
    results = {
        'config': CONFIG,
        'history': {
            'train_loss': history.train_loss,
            'train_acc': history.train_acc,
            'val_loss': history.val_loss,
            'val_acc': history.val_acc,
            'best_acc': history.best_acc,
            'best_epoch': history.best_epoch
        },
        'model_path': model_path
    }

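    # np.save pickles the dict into a 0-d object array; reload it with
    # np.load('training_results.npy', allow_pickle=True).item()
    np.save('training_results.npy', results)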
    print("\nResults saved!")

In [None]:
save_results(history)
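
Because `save_results` passes a Python dict to `np.save`, the file holds a pickled 0-d object array rather than a plain numeric array. A minimal sketch of reading it back is shown below. `load_results` is a hypothetical helper, the dict contains stand-in values rather than real training results, and `allow_pickle=True` should only be used on files you trust.

```python
import os
import tempfile

import numpy as np


def load_results(path):
    """Reload a dict written with np.save (stored as a pickled 0-d object array)."""
    return np.load(path, allow_pickle=True).item()


# Round-trip check with stand-in values (not real training results)
demo = {
    'history': {'val_acc': [0.5, 0.6], 'best_epoch': 1},
    'model_path': 'best_model.pth',
}
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, 'training_results.npy')
    np.save(demo_path, demo)
    restored = load_results(demo_path)
```

Calling `.item()` unwraps the 0-d array and returns the original dict, so `restored['history']['val_acc']` can be used directly.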