# fMRI Learning Stage Classification with Vision Transformers

This notebook demonstrates the use of Vision Transformers for classifying different stages of learning from fMRI data.

## Setup and Imports

In [1]:
import sys
from pathlib import Path

# Add project root to path
project_root = Path().absolute().parent
sys.path.append(str(project_root))

# Install package in editable mode if not already installed
!pip install -e {project_root}

Obtaining file:///C:/Users/twarn/Repositories/learnedSpectrum
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Checking if build backend supports build_editable: started
  Checking if build backend supports build_editable: finished with status 'done'
  Getting requirements to build editable: started
  Getting requirements to build editable: finished with status 'done'
  Preparing editable metadata (pyproject.toml): started
  Preparing editable metadata (pyproject.toml): finished with status 'done'
Building wheels for collected packages: learnedSpectrum
  Building editable for learnedSpectrum (pyproject.toml): started
  Building editable for learnedSpectrum (pyproject.toml): finished with status 'done'
  Created wheel for learnedSpectrum: filename=learnedSpectrum-0.1.0-0.editable-py3-none-any.whl size=7651 sha256=7578107ea23b08901d895164f60f2bb73a23fa6d94259fc94821a48ddd1145ce
  Stored in directory: C:\Users\twarn\AppData\Local\Temp

In [2]:
import os
import logging
import torch
import wandb
import numpy as np
from torch.cuda.amp import GradScaler

from learnedSpectrum.config import Config, DataConfig
from learnedSpectrum.data import DatasetManager, create_dataloaders
from learnedSpectrum.train import VisionTransformerModel, train_one_epoch, evaluate
from learnedSpectrum.visualization import VisualizationManager
from learnedSpectrum.utils import (
    seed_everything,
    get_optimizer,
    get_cosine_schedule_with_warmup,
    save_checkpoint,
    load_checkpoint,
    calculate_metrics,
    verify_model_devices
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set random seed for reproducibility
seed_everything(42)

  @torch.cuda.amp.autocast()


## Configuration

In [3]:
# Initialize configurations
config = Config()
data_config = DataConfig()

os.makedirs(config.CKPT_DIR, exist_ok=True)

# Set up visualization
viz = VisualizationManager(save_dir=Path(config.ROOT) / "visualizations")

# Initialize wandb
wandb.init(
    project='fmri-learning-stages',
    config=vars(config),
    dir=Path(config.ROOT) / "wandb"
)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: tawarner (tawarner-usc). Use `wandb login --relogin` to force relogin


INFO:__main__:Using device: cuda


## Data Preparation

In [4]:
# Initialize dataset manager
dataset_manager = DatasetManager(config, data_config)

# Prepare datasets
logger.info("Preparing datasets...")
train_ds, val_ds, test_ds = dataset_manager.prepare_datasets()

# Create dataloaders
train_loader, val_loader, test_loader = create_dataloaders(
    train_ds, val_ds, test_ds, config  # Use correct variable names
)

logger.info(f"Dataset sizes - Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

INFO:__main__:Preparing datasets...


loading datasets:   0%|          | 0/4 [00:00<?, ?it/s]

validating:   0%|          | 0/298 [00:00<?, ?it/s]

preprocessing:   0%|          | 0/264 [00:00<?, ?it/s]

validating pairs:   0%|          | 0/264 [00:00<?, ?it/s]

INFO:learnedSpectrum.data:total valid: 264
INFO:__main__:Dataset sizes - Train: 184, Val: 39, Test: 41


## Visualize Sample Data

In [5]:
# Get and visualize a sample
sample_volume, sample_label = train_ds[0]
viz.plot_brain_slice(
    volume=sample_volume.numpy(),
    time_idx=0,  # View first timepoint
    title=f'Sample Brain Slice (Learning Stage: {sample_label})',
    save_name='sample_slice'
)

## Model Setup

In [6]:
# Initialize model
model = VisionTransformerModel(config)
verify_model_devices(model)

# Setup training components
optimizer = get_optimizer(model, config)
scaler = torch.amp.GradScaler('cuda', enabled=config.USE_AMP)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config.WARMUP_EPOCHS * len(train_loader),
    num_training_steps=config.NUM_EPOCHS * len(train_loader)
)

  pe[:, 0::2] = torch.sin(pos * omega.T)
INFO:learnedSpectrum.utils:model on: cuda:0


## Training Loop

In [7]:
# Training history
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
best_val_loss = float('inf')

print(f"loader lens: train={len(train_loader)}, val={len(val_loader)}")

# Safe batch peek without timeout
try:
    batch = next(iter(train_loader))
    print(f"batch peek: {batch[0].shape}")
except Exception as e:
    print(f"Batch peek failed (this is ok): {str(e)}")

# Training loop
for epoch in range(config.NUM_EPOCHS):
    logger.info(f"\nEpoch {epoch + 1}/{config.NUM_EPOCHS}")
    
    # Training phase
    train_loss = train_one_epoch(model, train_loader, optimizer, scheduler, scaler, config)
    train_loss, train_metrics = evaluate(model, train_loader, config)
    
    # Validation phase
    val_loss, val_metrics = evaluate(model, val_loader, config)
    
    # Update history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_metrics['accuracy'])
    history['val_acc'].append(val_metrics['accuracy'])
    
    # Plot training progress
    viz.plot_training_history(history, save_name=f'training_history_epoch_{epoch}')
    
    # Log to wandb
    viz.log_to_wandb({
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_metrics': train_metrics,
        'val_metrics': val_metrics,
        'learning_rate': optimizer.param_groups[0]['lr']
    }, epoch)
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(
            model, optimizer, epoch, val_loss, config,
            filename=f"best_model_epoch_{epoch}.pth"
        )
        
    logger.info(
        f"Epoch {epoch + 1} - "
        f"Train Loss: {train_loss:.4f}, "
        f"Train Acc: {train_metrics['accuracy']:.4f}, "
        f"Val Loss: {val_loss:.4f}, "
        f"Val Acc: {val_metrics['accuracy']:.4f}"
    )

loader lens: train=46, val=39
batch peek: torch.Size([4, 64, 64, 8, 4860])


INFO:__main__:
Epoch 1/20
INFO:learnedSpectrum.utils:checkpoint: C:\Users\twarn\Repositories\learnedSpectrum\notebooks\models\best_model_epoch_0.pth
INFO:__main__:Epoch 1 - Train Loss: 1.3291, Train Acc: 0.0217, Val Loss: 1.3320, Val Acc: 0.0513
INFO:__main__:
Epoch 2/20
INFO:__main__:Epoch 2 - Train Loss: 1.3383, Train Acc: 0.0272, Val Loss: 1.3393, Val Acc: 0.0256
INFO:__main__:
Epoch 3/20
INFO:learnedSpectrum.utils:checkpoint: C:\Users\twarn\Repositories\learnedSpectrum\notebooks\models\best_model_epoch_2.pth
INFO:__main__:Epoch 3 - Train Loss: 1.2473, Train Acc: 0.3152, Val Loss: 1.2466, Val Acc: 0.2564
INFO:__main__:
Epoch 4/20
INFO:learnedSpectrum.utils:checkpoint: C:\Users\twarn\Repositories\learnedSpectrum\notebooks\models\best_model_epoch_3.pth
INFO:__main__:Epoch 4 - Train Loss: 1.2434, Train Acc: 0.2826, Val Loss: 1.2435, Val Acc: 0.2564
INFO:__main__:
Epoch 5/20
INFO:learnedSpectrum.utils:checkpoint: C:\Users\twarn\Repositories\learnedSpectrum\notebooks\models\best_model_ep

## Final Evaluation

In [8]:
# Load best model
best_model_path = Path(config.CKPT_DIR) / "best_model.pth"
model, _, _ = load_checkpoint(model, None, best_model_path)

# Evaluate on test set
test_loss, test_metrics = evaluate(model, test_loader, config)
logger.info(f"\nTest Results - Loss: {test_loss:.4f}, Accuracy: {test_metrics['accuracy']:.4f}, AUC: {test_metrics['auc']:.4f}")

# Get predictions for visualization
all_preds = []
all_labels = []
all_probs = []

model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        probs = torch.softmax(outputs, dim=1)
        preds = torch.argmax(outputs, dim=1)
        
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())
        all_probs.extend(probs.cpu().numpy())

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)
all_probs = np.array(all_probs)

INFO:__main__:
Test Results - Loss: 0.0457, Accuracy: 1.0000, AUC: nan


## Results Visualization

In [9]:
# conf matrix + roc
viz.plot_confusion_matrix(
    y_true=all_labels,
    y_pred=all_preds,
    classes=['Early', 'Middle', 'Late', 'Mastery'],
    save_name='confusion_matrix'
)

viz.plot_roc_curves(
    y_true=all_labels,
    y_scores=all_probs,
    classes=['Early', 'Middle', 'Late', 'Mastery'],
    save_name='roc_curves'
)

# attn viz w/ proper extraction
x = next(iter(test_loader))[0][:1].to(device)
with torch.no_grad():
    weights = model.get_attention_weights(x)
    # [layers, heads, seq, seq] -> [seq, seq] for viz
    avg_weights = weights.mean(dim=[0,1])  

viz.plot_attention_map(
    attention_weights=avg_weights.cpu(),
    volume_shape=config.VOLUME_SIZE,
    save_name='attention_map'
)

# wandb artifacts
wandb.log({
    'final_test_loss': test_loss,
    'final_test_accuracy': test_metrics['accuracy'], 
    'final_test_auc': test_metrics['auc'],
    'visualizations': {
        'confusion_matrix': wandb.Image(str(viz.save_dir / 'confusion_matrix.png')),
        'roc_curves': wandb.Image(str(viz.save_dir / 'roc_curves.png')),
        'attention_map': wandb.Image(str(viz.save_dir / 'attention_map.png'))
    }
})
wandb.finish()



AttributeError: 'VisionTransformerModel' object has no attribute 'get_attention_weights'

<Figure size 1000x800 with 0 Axes>