# fMRI Learning Stage Classification with Vision Transformers

This notebook demonstrates the use of Vision Transformers for classifying different stages of learning from fMRI data.

## Setup and Imports

In [1]:
import sys
from pathlib import Path

# Add project root to path
project_root = Path().absolute().parent
sys.path.append(str(project_root))

# Install package in editable mode if not already installed
!pip install -e {project_root}

Obtaining file:///C:/Users/twarn/Repositories/learnedSpectrum
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Checking if build backend supports build_editable: started
  Checking if build backend supports build_editable: finished with status 'done'
  Getting requirements to build editable: started
  Getting requirements to build editable: finished with status 'done'
  Preparing editable metadata (pyproject.toml): started
  Preparing editable metadata (pyproject.toml): finished with status 'done'
Building wheels for collected packages: learnedSpectrum
  Building editable for learnedSpectrum (pyproject.toml): started
  Building editable for learnedSpectrum (pyproject.toml): finished with status 'done'
  Created wheel for learnedSpectrum: filename=learnedSpectrum-0.1.0-0.editable-py3-none-any.whl size=7719 sha256=74979da240ef5436af38c61aaba046c3950fcfb2b351075cf949b4706c4448bb
  Stored in directory: C:\Users\twarn\AppData\Local\Temp

In [23]:
import os
import math
import logging
import torch
import wandb
import numpy as np
from pathlib import Path
from torch.cuda.amp import GradScaler

from learnedSpectrum.config import Config, DataConfig
from learnedSpectrum.data import BIDSManager, NiftiLoader, DatasetManager, create_dataloaders
from learnedSpectrum.train import VisionTransformerModel, train_one_epoch, evaluate, LabelSmoothingLoss, get_scheduler, load_best_model
from learnedSpectrum.visualization import VisualizationManager
from learnedSpectrum.utils import (
    seed_everything,
    get_optimizer,
    get_cosine_schedule_with_warmup,
    save_checkpoint,
    load_checkpoint,
    calculate_metrics,
    verify_model_devices
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set random seed for reproducibility
seed_everything(42)

## Configuration

In [3]:
# Initialize configurations
config = Config()
data_config = DataConfig()

os.makedirs(config.CKPT_DIR, exist_ok=True)

# Set up visualization
viz = VisualizationManager(save_dir=Path(config.ROOT) / "visualizations")

# Initialize wandb
wandb.init(
    project='fmri-learning-stages',
    config=vars(config),
    dir=Path(config.ROOT) / "wandb"
)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: tawarner (tawarner-usc). Use `wandb login --relogin` to force relogin


INFO:__main__:Using device: cuda


## Data Preparation

In [4]:
# Initialize managers
bids_manager = BIDSManager(data_config)
nifti_loader = NiftiLoader(data_config, bids_manager)

# Download and preprocess datasets
for dataset_id in data_config.DATASET_URLS.keys():
    try:
        # Find dataset root
        dataset_root = bids_manager._find_dataset_root(dataset_id)
        
        # Find all valid NIFTI files
        nifti_files = []
        for path in dataset_root.rglob("*bold.nii.gz"):
            if bids_manager.validate_nifti(path):
                nifti_files.append(path)
        
        if not nifti_files:
            logger.warning(f"No valid NIFTI files found for {dataset_id}")
            continue
            
        logger.info(f"Processing {len(nifti_files)} files from {dataset_id}")
        
        # Preprocess files
        nifti_loader._parallel_preprocess(
            paths=nifti_files,
            max_workers=min(4, os.cpu_count() or 1)
        )
        
    except FileNotFoundError:
        logger.warning(f"Dataset {dataset_id} not found, skipping")
        continue
    except Exception as e:
        logger.error(f"Failed to process {dataset_id}: {str(e)}")
        continue
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

logger.info("Preprocessing complete")

# Now initialize dataset manager
dataset_manager = DatasetManager(config, data_config)

INFO:__main__:Processing 102 files from ds000002
INFO:learnedSpectrum.data:All files already processed
INFO:__main__:Processing 70 files from ds000011
INFO:learnedSpectrum.data:All files already processed
INFO:__main__:Processing 74 files from ds000017
INFO:learnedSpectrum.data:All files already processed
INFO:__main__:Processing 52 files from ds000052
INFO:learnedSpectrum.data:All files already processed
INFO:__main__:Preprocessing complete
INFO:learnedSpectrum.data:Found 298 cached samples in C:\Users\twarn\Repositories\learnedSpectrum\data\processed


## Dataset Creation

In [5]:
# Prepare datasets
logger.info("Preparing datasets...")
dataset_manager = DatasetManager(config, data_config)
train_ds, val_ds, test_ds = dataset_manager.prepare_datasets()

# Create dataloaders with memory-efficient settings
train_loader, val_loader, test_loader = create_dataloaders(
    train_ds, val_ds, test_ds, config
)

# Verify data was loaded
logger.info(f"Dataset sizes: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")

INFO:__main__:Preparing datasets...
INFO:learnedSpectrum.data:Found 298 cached samples in C:\Users\twarn\Repositories\learnedSpectrum\data\processed
INFO:learnedSpectrum.data:Label distribution:
INFO:learnedSpectrum.data:Label 0: 109 samples
INFO:learnedSpectrum.data:Label 1: 92 samples
INFO:learnedSpectrum.data:Label 2: 71 samples
INFO:learnedSpectrum.data:Label 3: 26 samples
INFO:learnedSpectrum.data:Creating dataset with 208 paths and 208 labels
INFO:learnedSpectrum.data:Maximum timepoints found: 237
INFO:learnedSpectrum.data:Dataset initialized with 208 samples
INFO:learnedSpectrum.data:Max timepoints: 237
INFO:learnedSpectrum.data:Creating dataset with 45 paths and 45 labels
INFO:learnedSpectrum.data:Maximum timepoints found: 237
INFO:learnedSpectrum.data:Dataset initialized with 45 samples
INFO:learnedSpectrum.data:Max timepoints: 237
INFO:learnedSpectrum.data:Creating dataset with 45 paths and 45 labels
INFO:learnedSpectrum.data:Maximum timepoints found: 237
INFO:learnedSpectrum

## Visualize Sample Data

In [32]:
# Get and visualize a sample
sample_volume, sample_label = train_ds[90]
viz.plot_brain_slice(
    volume=sample_volume.numpy(),
    time_idx=0,  # View first timepoint
    title=f'Sample Brain Slice (Learning Stage: {sample_label})',
    save_name='sample_slice'
)

## Model Setup

In [7]:
# Initialize model
model = VisionTransformerModel(config)
verify_model_devices(model)

# Setup training components
optimizer = get_optimizer(model, config)

# Ensure optimizer params are in FP32
for param_group in optimizer.param_groups:
    for param in param_group['params']:
        if param.requires_grad:
            param.data = param.data.float()
            
scaler = torch.amp.GradScaler('cuda', enabled=config.USE_AMP)

criterion = LabelSmoothingLoss(classes=9, smoothing=0.1)  # 9 classes from your dataset

# Update the scheduler initialization
num_training_steps = config.NUM_EPOCHS * len(train_loader)
num_warmup_steps = num_training_steps // 10  # 10% warmup

scheduler = get_scheduler(
    optimizer,
    num_training_steps=num_training_steps,
    num_warmup_steps=num_warmup_steps
)

  pe[:, 0::2] = torch.sin(pos * omega.T)
INFO:learnedSpectrum.utils:model on: cuda:0


## Training Loop

In [8]:
# Training history
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
best_val_loss = float('inf')

print(f"loader lens: train={len(train_loader)}, val={len(val_loader)}")

all_labels = []
for _, labels in train_loader:
    all_labels.extend(labels.tolist())
unique_labels = sorted(set(all_labels))
print(f"Unique labels in dataset: {unique_labels}")
print(f"Number of classes: {len(unique_labels)}")

# Safe batch peek without timeout
try:
    batch = next(iter(train_loader))
    print(f"batch peek: {batch[0].shape}")
except Exception as e:
    print(f"Batch peek failed (this is ok): {str(e)}")

# Training loop
for epoch in range(config.NUM_EPOCHS):
    logger.info(f"\nEpoch {epoch + 1}/{config.NUM_EPOCHS}")
    
    # Training phase
    train_loss = train_one_epoch(model, train_loader, optimizer, scheduler, scaler, config)
    train_loss, train_metrics = evaluate(model, train_loader, config)
    
    # Validation phase
    val_loss, val_metrics = evaluate(model, val_loader, config)
    
    # Update history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_metrics['accuracy'])
    history['val_acc'].append(val_metrics['accuracy'])
    
    # Plot training progress
    viz.plot_training_history(history, save_name=f'training_history_epoch_{epoch}')
    
    # Log to wandb
    viz.log_to_wandb({
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_metrics': train_metrics,
        'val_metrics': val_metrics,
        'learning_rate': optimizer.param_groups[0]['lr']
    }, epoch)
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(
            model, optimizer, epoch, val_loss, config,
            filename=f"best_model_epoch_{epoch}.pth"
        )
        
    logger.info(
        f"Epoch {epoch + 1} - "
        f"Train Loss: {train_loss:.4f}, "
        f"Train Acc: {train_metrics['accuracy']:.4f}, "
        f"Val Loss: {val_loss:.4f}, "
        f"Val Acc: {val_metrics['accuracy']:.4f}"
    )

loader lens: train=13, val=3
Unique labels in dataset: [0, 1, 2, 3]
Number of classes: 4
batch peek: torch.Size([16, 64, 64, 30, 237])


INFO:__main__:
Epoch 1/50
INFO:learnedSpectrum.utils:checkpoint: C:\Users\twarn\Repositories\learnedSpectrum\notebooks\models\best_model_epoch_0.pth
INFO:__main__:Epoch 1 - Train Loss: 1.6950, Train Acc: 0.2356, Val Loss: 1.6341, Val Acc: 0.3556
INFO:__main__:
Epoch 2/50
INFO:learnedSpectrum.utils:checkpoint: C:\Users\twarn\Repositories\learnedSpectrum\notebooks\models\best_model_epoch_1.pth
INFO:__main__:Epoch 2 - Train Loss: 1.5688, Train Acc: 0.2019, Val Loss: 1.5254, Val Acc: 0.2667
INFO:__main__:
Epoch 3/50
INFO:learnedSpectrum.utils:checkpoint: C:\Users\twarn\Repositories\learnedSpectrum\notebooks\models\best_model_epoch_2.pth
INFO:__main__:Epoch 3 - Train Loss: 1.4699, Train Acc: 0.2788, Val Loss: 1.4692, Val Acc: 0.3111
INFO:__main__:
Epoch 4/50
INFO:learnedSpectrum.utils:checkpoint: C:\Users\twarn\Repositories\learnedSpectrum\notebooks\models\best_model_epoch_3.pth
INFO:__main__:Epoch 4 - Train Loss: 1.4321, Train Acc: 0.2885, Val Loss: 1.4283, Val Acc: 0.3333
INFO:__main__:
E

## Final Evaluation

In [15]:
import torch
import torch.nn.functional as F
import numpy as np
torch.serialization.add_safe_globals(['Config'])
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from sklearn.metrics import roc_auc_score, balanced_accuracy_score, cohen_kappa_score
import logging

logger = logging.getLogger(__name__)

# Load best model checkpoint
checkpoints = list(Path(config.CKPT_DIR).glob("best_model_epoch_*.pth"))
latest = max(checkpoints, key=lambda x: int(str(x).split("_")[-1].split(".")[0]))
model, _, _ = load_checkpoint(model, None, latest, weights_only=False)
logger.info(f"Loaded checkpoint: {latest}")

model.eval()
all_preds = []
all_labels = []
all_probs = []
all_losses = []

# Evaluation with mixed precision
with torch.no_grad(), torch.amp.autocast('cuda'):
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels, reduction='none')
        probs = torch.softmax(outputs, dim=1)
        preds = torch.argmax(outputs, dim=1)
        
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())
        all_losses.extend(loss.cpu().numpy())

# Convert to numpy arrays
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)
all_probs = np.array(all_probs)
all_losses = np.array(all_losses)

# Calculate comprehensive metrics
class_names = ['Early', 'Middle', 'Late', 'Mastery']
metrics = {
    'overall': {
        'accuracy': (all_preds == all_labels).mean(),
        'balanced_accuracy': balanced_accuracy_score(all_labels, all_preds),
        'macro_f1': f1_score(all_labels, all_preds, average='macro'),
        'weighted_f1': f1_score(all_labels, all_preds, average='weighted'),
        'cohen_kappa': cohen_kappa_score(all_labels, all_preds),
        'mean_loss': np.mean(all_losses),
        'std_loss': np.std(all_losses)
    }
}

# Per-class metrics
for i, class_name in enumerate(class_names):
    metrics[class_name] = {
        'precision': precision_score(all_labels == i, all_preds == i),
        'recall': recall_score(all_labels == i, all_preds == i),
        'f1': f1_score(all_labels == i, all_preds == i),
        'support': np.sum(all_labels == i),
        'roc_auc': roc_auc_score(all_labels == i, all_probs[:, i])
    }

# Confusion matrix
conf_matrix = confusion_matrix(all_labels, all_preds)
norm_conf_matrix = confusion_matrix(all_labels, all_preds, normalize='true')

# Print results in paper-ready format
print("\nModel Performance Metrics")
print("========================")
print(f"Overall Accuracy: {metrics['overall']['accuracy']:.3f}")
print(f"Balanced Accuracy: {metrics['overall']['balanced_accuracy']:.3f}")
print(f"Macro F1: {metrics['overall']['macro_f1']:.3f}")
print(f"Cohen's Kappa: {metrics['overall']['cohen_kappa']:.3f}")
print(f"Mean Loss: {metrics['overall']['mean_loss']:.3f} (±{metrics['overall']['std_loss']:.3f})")

print("\nPer-class Performance")
print("====================")
for class_name in class_names:
    m = metrics[class_name]
    print(f"\n{class_name}:")
    print(f"Precision: {m['precision']:.3f}")
    print(f"Recall: {m['recall']:.3f}")
    print(f"F1: {m['f1']:.3f}")
    print(f"ROC AUC: {m['roc_auc']:.3f}")
    print(f"Support: {m['support']}")

print("\nConfusion Matrix (Normalized)")
print("============================")
print(norm_conf_matrix)

# Calculate reliability metrics
correct_probs = np.array([all_probs[i, all_labels[i]] for i in range(len(all_labels))])
confidence = np.max(all_probs, axis=1)
metrics['reliability'] = {
    'mean_confidence': np.mean(confidence),
    'overconfidence': np.mean(confidence - correct_probs),
    'expected_calibration_error': np.mean(np.abs(confidence - (all_preds == all_labels)))
}

print("\nReliability Metrics")
print("==================")
print(f"Mean Confidence: {metrics['reliability']['mean_confidence']:.3f}")
print(f"Overconfidence: {metrics['reliability']['overconfidence']:.3f}")
print(f"Expected Calibration Error: {metrics['reliability']['expected_calibration_error']:.3f}")

INFO:__main__:Loaded checkpoint: C:\Users\twarn\Repositories\learnedSpectrum\notebooks\models\best_model_epoch_43.pth



Model Performance Metrics
Overall Accuracy: 0.356
Balanced Accuracy: 0.428
Macro F1: 0.407
Cohen's Kappa: 0.093
Mean Loss: 1.082 (±0.257)

Per-class Performance

Early:
Precision: 0.286
Recall: 0.235
F1: 0.258
ROC AUC: 0.368
Support: 17

Middle:
Precision: 0.353
Recall: 0.429
F1: 0.387
ROC AUC: 0.555
Support: 14

Late:
Precision: 0.333
Recall: 0.300
F1: 0.316
ROC AUC: 0.740
Support: 10

Mastery:
Precision: 0.600
Recall: 0.750
F1: 0.667
ROC AUC: 0.945
Support: 4

Confusion Matrix (Normalized)
[[0.23529412 0.47058824 0.29411765 0.        ]
 [0.35714286 0.42857143 0.07142857 0.14285714]
 [0.5        0.2        0.3        0.        ]
 [0.         0.25       0.         0.75      ]]

Reliability Metrics
Mean Confidence: 0.437
Overconfidence: 0.088
Expected Calibration Error: 0.491


## Results Visualization

In [17]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import confusion_matrix, roc_curve, auc
from pathlib import Path

def create_publication_plots(metrics, all_preds, all_labels, all_probs, class_names, save_dir='figures/'):
    # Set publication style
    plt.style.use('default')
    plt.rcParams.update({
        'figure.figsize': (10, 8),
        'figure.dpi': 300,
        'font.family': 'Arial',
        'font.size': 10,
        'axes.linewidth': 0.5,
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.grid': True,
        'grid.alpha': 0.3,
        'lines.linewidth': 1.5
    })

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Figure 1: Overall Performance Summary
    fig = plt.figure(figsize=(12, 8))
    gs = plt.GridSpec(2, 2, figure=fig)

    # 1a: Confusion Matrix
    ax1 = fig.add_subplot(gs[0, 0])
    cm = confusion_matrix(all_labels, all_preds, normalize='true')
    sns.heatmap(cm, annot=True, fmt='.2%', cmap='YlOrRd', 
                xticklabels=class_names, yticklabels=class_names,
                ax=ax1)
    ax1.set_title('A. Normalized Confusion Matrix')
    ax1.set_xlabel('Predicted')
    ax1.set_ylabel('True')

    # 1b: ROC Curves
    ax2 = fig.add_subplot(gs[0, 1])
    for i, cls in enumerate(class_names):
        y_true = (all_labels == i).astype(int)
        y_score = all_probs[:, i]
        fpr, tpr, _ = roc_curve(y_true, y_score)
        roc_auc = auc(fpr, tpr)
        ax2.plot(fpr, tpr, label=f'{cls} (AUC={roc_auc:.2f})')
    
    ax2.plot([0, 1], [0, 1], 'k--', alpha=0.5)
    ax2.set_title('B. ROC Curves')
    ax2.set_xlabel('False Positive Rate')
    ax2.set_ylabel('True Positive Rate')
    ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    # 1c: Performance Metrics
    ax3 = fig.add_subplot(gs[1, 0])
    metric_names = ['Precision', 'Recall', 'F1']
    class_metrics = np.array([[metrics[cls][m.lower()] for m in metric_names] 
                            for cls in class_names])
    
    x = np.arange(len(class_names))
    width = 0.25
    multiplier = 0
    
    for i, metric in enumerate(metric_names):
        offset = width * multiplier
        ax3.bar(x + offset, class_metrics[:, i], width, label=metric)
        multiplier += 1

    ax3.set_xticks(x + width, class_names)
    ax3.set_title('C. Per-class Performance')
    ax3.set_ylim(0, 1)
    ax3.legend(loc='upper right')
    ax3.grid(axis='y')

    # 1d: Confidence Distribution
    ax4 = fig.add_subplot(gs[1, 1])
    for i, cls in enumerate(class_names):
        mask = all_labels == i
        if mask.any():
            sns.kdeplot(all_probs[mask, i], label=f'{cls}',
                       ax=ax4)
    
    ax4.axvline(0.5, color='k', linestyle='--', alpha=0.3)
    ax4.set_title('D. Prediction Confidence Distribution')
    ax4.set_xlabel('Model Confidence')
    ax4.set_ylabel('Density')
    ax4.legend()

    plt.suptitle('Model Performance Analysis', y=1.02)
    plt.tight_layout()
    plt.savefig(save_dir / 'performance_summary.png', 
                dpi=300, bbox_inches='tight')
    plt.close()

    # Figure 2: Calibration Analysis
    plt.figure(figsize=(8, 6))
    confidence = np.max(all_probs, axis=1)
    correct = (all_preds == all_labels)
    
    # Create reliability diagram
    n_bins = 10
    bins = np.linspace(0, 1, n_bins + 1)
    bin_accuracies = []
    bin_confidences = []
    bin_counts = []
    
    for i in range(n_bins):
        mask = (confidence >= bins[i]) & (confidence < bins[i + 1])
        if np.sum(mask) > 0:
            bin_accuracies.append(np.mean(correct[mask]))
            bin_confidences.append(np.mean(confidence[mask]))
            bin_counts.append(np.sum(mask))
    
    plt.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Perfect Calibration')
    plt.plot(bin_confidences, bin_accuracies, 'o-', label='Model Calibration')
    
    plt.title('Reliability Diagram')
    plt.xlabel('Confidence')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(save_dir / 'calibration_analysis.png', 
                dpi=300, bbox_inches='tight')
    plt.close()

    # Print LaTeX table for paper
    print("\nLaTeX Table for Paper:")
    print("\\begin{table}[h]")
    print("\\centering")
    print("\\begin{tabular}{lcccc}")
    print("\\toprule")
    print("Class & Precision & Recall & F1 & Support \\\\")
    print("\\midrule")
    
    for cls in class_names:
        m = metrics[cls]
        print(f"{cls} & {m['precision']:.3f} & {m['recall']:.3f} & "
              f"{m['f1']:.3f} & {int(m['support'])} \\\\")
    
    print("\\midrule")
    print(f"Overall & {metrics['overall']['macro_f1']:.3f} & "
          f"{metrics['overall']['balanced_accuracy']:.3f} & "
          f"{metrics['overall']['weighted_f1']:.3f} & "
          f"{sum([metrics[cls]['support'] for cls in class_names])} \\\\")
    print("\\bottomrule")
    print("\\end{tabular}")
    print("\\caption{Model Performance Metrics}")
    print("\\label{tab:model_performance}")
    print("\\end{table}")

create_publication_plots(metrics, all_preds, all_labels, all_probs, class_names)


LaTeX Table for Paper:
\begin{table}[h]
\centering
\begin{tabular}{lcccc}
\toprule
Class & Precision & Recall & F1 & Support \\
\midrule
Early & 0.286 & 0.235 & 0.258 & 17 \\
Middle & 0.353 & 0.429 & 0.387 & 14 \\
Late & 0.333 & 0.300 & 0.316 & 10 \\
Mastery & 0.600 & 0.750 & 0.667 & 4 \\
\midrule
Overall & 0.407 & 0.428 & 0.347 & 45 \\
\bottomrule
\end{tabular}
\caption{Model Performance Metrics}
\label{tab:model_performance}
\end{table}
