# fMRI Learning Stage Classification with Vision Transformers

This notebook demonstrates the use of Vision Transformers for classifying different stages of learning from fMRI data.

## Setup and Imports

In [None]:
import sys
from pathlib import Path

# Add project root to path
project_root = Path().absolute().parent
sys.path.append(str(project_root))

# Install package in editable mode if not already installed
!pip install -e {project_root}

In [None]:
import logging

import numpy as np
import torch
import wandb
from torch.cuda.amp import GradScaler

from learnedSpectrum.config import Config, DataConfig
from learnedSpectrum.data import DatasetManager, create_dataloaders
from learnedSpectrum.scripts.train import VisionTransformerModel, train_one_epoch, evaluate
from learnedSpectrum.scripts.visualization import VisualizationManager
# NOTE(review): save_checkpoint / load_checkpoint are called later in this
# notebook but were never imported. They are assumed to live in
# learnedSpectrum.utils alongside the other helpers — confirm the module.
from learnedSpectrum.utils import (
    seed_everything,
    get_optimizer,
    get_cosine_schedule_with_warmup,
    verify_model_devices,
    save_checkpoint,
    load_checkpoint
)

# Seed used for every run of this notebook.
SEED = 42

# Route INFO-and-above records through the root handler and create the
# logger the remaining cells report progress with.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# Seed before any data loading or model construction happens.
seed_everything(SEED)

## Configuration

In [None]:
# Instantiate the experiment and data configuration objects.
config = Config()
data_config = DataConfig()

# Figures produced by this notebook are written under <ROOT>/visualizations.
viz = VisualizationManager(save_dir=Path(config.ROOT) / "visualizations")

# Start the Weights & Biases run. The run directory is created up front:
# W&B is understood to fall back to a temporary location when `dir` is not
# usable, which would scatter run files outside the project tree.
wandb_dir = Path(config.ROOT) / "wandb"
wandb_dir.mkdir(parents=True, exist_ok=True)
wandb.init(
    project='fmri-learning-stages',
    # NOTE(review): vars() only captures instance attributes; if Config
    # declares its settings as plain class attributes this dict is empty.
    config=vars(config),
    # Passed as str so the call does not depend on PathLike support.
    dir=str(wandb_dir)
)

# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

## Data Preparation

In [None]:
# The dataset manager is built from both configuration objects.
dataset_manager = DatasetManager(config, data_config)

# Build the three splits (returned in train / val / test order).
logger.info("Preparing datasets...")
(
    train_dataset,
    val_dataset,
    test_dataset,
) = dataset_manager.prepare_datasets()

# Wrap each split in a loader, again in train / val / test order.
(
    train_loader,
    val_loader,
    test_loader,
) = create_dataloaders(train_dataset, val_dataset, test_dataset, config)

# Report how many samples ended up in each split.
logger.info(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

## Visualize Sample Data

In [None]:
# Pull the first training example; each item unpacks to (volume, label).
first_sample = train_dataset[0]
sample_volume, sample_label = first_sample

# Render it with the label embedded in the title and save the figure under
# the name 'sample_slice'.
slice_title = f'Sample Brain Slice (Learning Stage: {sample_label})'
viz.plot_brain_slice(
    volume=sample_volume.numpy(),
    title=slice_title,
    save_name='sample_slice',
)

## Model Setup

In [None]:
# Build the network, move it to the selected device and run the project's
# device check on it.
model = VisionTransformerModel(config)
model = model.to(device)
verify_model_devices(model)

# Optimiser and gradient scaler; the scaler is only active when the config
# enables AMP.
optimizer = get_optimizer(model, config)
scaler = GradScaler(enabled=config.USE_AMP)

# The schedule is expressed in optimiser steps, so epoch counts from the
# config are converted using the number of batches per epoch.
steps_per_epoch = len(train_loader)
warmup_steps = config.WARMUP_EPOCHS * steps_per_epoch
total_steps = config.NUM_EPOCHS * steps_per_epoch
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

## Training Loop

In [None]:
# Per-epoch curves, plus the lowest validation loss seen so far (used to
# decide when to write a checkpoint).
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
best_val_loss = float('inf')

for epoch in range(config.NUM_EPOCHS):
    logger.info(f"\nEpoch {epoch + 1}/{config.NUM_EPOCHS}")

    # Optimisation pass over the training set. Its return value is not
    # bound: the training loss and metrics reported for the epoch come from
    # the evaluate() pass on the training loader that follows.
    train_one_epoch(model, train_loader, optimizer, scheduler, scaler, config)
    train_loss, train_metrics = evaluate(model, train_loader, config)

    # Validation pass.
    val_loss, val_metrics = evaluate(model, val_loader, config)

    # Append this epoch's numbers to the running curves.
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_metrics['accuracy'])
    history['val_acc'].append(val_metrics['accuracy'])

    # Re-plot the curves; one figure file is saved per epoch.
    viz.plot_training_history(history, save_name=f'training_history_epoch_{epoch}')

    # Send the epoch summary to wandb, including the current learning rate
    # of the first parameter group.
    viz.log_to_wandb({
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_metrics': train_metrics,
        'val_metrics': val_metrics,
        'learning_rate': optimizer.param_groups[0]['lr']
    }, epoch)

    # Checkpoint whenever validation loss improves. The file name is fixed
    # so that it matches the "best_model.pth" the final-evaluation cell
    # loads; each improvement overwrites the previous best.
    # NOTE(review): presumably save_checkpoint writes into config.CKPT_DIR,
    # which is where the evaluation cell looks — confirm.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(
            model, optimizer, epoch, val_loss, config,
            filename="best_model.pth"
        )

    logger.info(
        f"Epoch {epoch + 1} - "
        f"Train Loss: {train_loss:.4f}, "
        f"Train Acc: {train_metrics['accuracy']:.4f}, "
        f"Val Loss: {val_loss:.4f}, "
        f"Val Acc: {val_metrics['accuracy']:.4f}"
    )

## Final Evaluation

In [None]:
# Locate the best checkpoint. The canonical name is "best_model.pth"; if it
# is absent, fall back to the most recently written per-epoch file
# ("best_model_epoch_<n>.pth"), and fail with a clear message when neither
# exists instead of erroring somewhere inside the loader.
ckpt_dir = Path(config.CKPT_DIR)
best_model_path = ckpt_dir / "best_model.pth"
if not best_model_path.exists():
    per_epoch_ckpts = list(ckpt_dir.glob("best_model_epoch_*.pth"))
    if not per_epoch_ckpts:
        raise FileNotFoundError(
            f"No best-model checkpoint found in {ckpt_dir}; run the training cell first."
        )
    # A per-epoch file is only written when validation loss improves, so the
    # newest one is the best of the latest run.
    best_model_path = max(per_epoch_ckpts, key=lambda ckpt: ckpt.stat().st_mtime)
    logger.info(f"Falling back to checkpoint: {best_model_path}")

# No optimizer is passed (None); only the restored model is kept.
model, _, _ = load_checkpoint(model, None, best_model_path)

# Aggregate loss and metrics on the held-out test set.
test_loss, test_metrics = evaluate(model, test_loader, config)
logger.info(f"\nTest Results - Loss: {test_loss:.4f}, Accuracy: {test_metrics['accuracy']:.4f}, AUC: {test_metrics['auc']:.4f}")

# Second pass over the test set to collect per-sample outputs for the plots
# in the next cell.
all_preds = []
all_labels = []
all_probs = []

model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        # Class probabilities and hard predictions along dim 1.
        probs = torch.softmax(outputs, dim=1)
        preds = torch.argmax(outputs, dim=1)

        all_preds.extend(preds.cpu().numpy())
        # labels are converted without a .cpu() call, i.e. the loader is
        # expected to yield them on the CPU.
        all_labels.extend(labels.numpy())
        all_probs.extend(probs.cpu().numpy())

# Stack the per-sample values into arrays.
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)
all_probs = np.array(all_probs)

## Results Visualization

In [None]:
# Display names of the four learning stages, in label-index order.
CLASS_NAMES = ('Early', 'Middle', 'Late', 'Mastery')

# Confusion matrix of hard predictions against the true labels.
viz.plot_confusion_matrix(
    y_true=all_labels,
    y_pred=all_preds,
    classes=list(CLASS_NAMES),
    save_name='confusion_matrix',
)

# ROC curves from the class probabilities.
viz.plot_roc_curves(
    y_true=all_labels,
    y_scores=all_probs,
    classes=list(CLASS_NAMES),
    save_name='roc_curves',
)

# Attention map for a single example: the first item of the first test batch.
first_batch = next(iter(test_loader))
sample_input = first_batch[0][:1].to(device)
with torch.no_grad():
    attention_weights = model.vit.get_attention_weights(sample_input)

# NOTE(review): index 0 is presumed to select the first attention head —
# confirm against get_attention_weights.
first_attention = attention_weights[0].cpu()
viz.plot_attention_map(
    attention_weights=first_attention,
    volume_shape=config.VOLUME_SIZE,
    save_name='attention_map',
)

# Final wandb record: scalar test metrics first, then the three figures
# read back from the visualisation directory.
final_summary = {
    'final_test_loss': test_loss,
    'final_test_accuracy': test_metrics['accuracy'],
    'final_test_auc': test_metrics['auc'],
}
figure_files = {
    'confusion_matrix': 'confusion_matrix.png',
    'roc_curves': 'roc_curves.png',
    'attention_map': 'attention_map.png',
}
for log_key, file_name in figure_files.items():
    final_summary[log_key] = wandb.Image(str(viz.save_dir / file_name))
wandb.log(final_summary)

# Close the run.
wandb.finish()