# NeurLLM: Vision Transformer for Neurophysiological Data

This notebook demonstrates how to fine-tune a Vision Transformer (ViT) encoder-decoder architecture on neurophysiological data (EEG) converted to spectrograms.

## Overview
1. Convert EEG time series to spectrograms/PSD images
2. Train a ViT encoder-decoder model for classification and reconstruction
3. Analyze the model's performance and latent space representations
4. Explore applications in driving behavior classification

## 1. Environment Setup

In [None]:
# Install required packages
!pip install -q torch torchvision torchaudio timm einops scipy scikit-learn matplotlib seaborn ipywidgets mne

In [None]:
# Import necessary libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.signal
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import timm
from einops import rearrange
import mne
from mne.time_frequency import psd_array_welch  # psd_welch is no longer available in recent MNE releases
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# For visualizations
import warnings
warnings.filterwarnings('ignore')
plt.style.use('ggplot')
sns.set_style('whitegrid')

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Sample Data Loading

For this demo, we create synthetic EEG data that mimics the structure of the MPDB dataset.

In [None]:
def generate_synthetic_eeg(n_samples=500, n_channels=59, n_timepoints=1000, n_classes=5, seed=42):
    """Generate synthetic EEG data for demonstration purposes"""
    np.random.seed(seed)

    # Create empty arrays for data and labels
    X = np.zeros((n_samples, n_channels, n_timepoints))
    y = np.zeros(n_samples, dtype=int)

    # Generate samples for each class
    samples_per_class = n_samples // n_classes

    for class_idx in range(n_classes):
        # Calculate indices for this class
        start_idx = class_idx * samples_per_class
        end_idx = (class_idx + 1) * samples_per_class

        # Set labels
        y[start_idx:end_idx] = class_idx

        # Generate base signal with class-specific characteristics
        for i in range(start_idx, end_idx):
            # Generate signal with class-specific frequency characteristics
            for ch in range(n_channels):
                # Base signal (common to all classes)
                t = np.arange(n_timepoints)
                base_signal = np.sin(2 * np.pi * 10 * t / n_timepoints)  # 10 cycles per window (10 Hz when the window spans 1 s)

                # Add class-specific features
                if class_idx == 0:  # Smooth driving: stronger alpha (8-12 Hz)
                    alpha = np.sin(2 * np.pi * 10 * t / n_timepoints) * 2.0
                    signal = base_signal + alpha
                elif class_idx == 1:  # Acceleration: increased beta (13-30 Hz)
                    beta = np.sin(2 * np.pi * 20 * t / n_timepoints) * 2.5
                    signal = base_signal + beta
                elif class_idx == 2:  # Deceleration: increased theta (4-7 Hz)
                    theta = np.sin(2 * np.pi * 6 * t / n_timepoints) * 3.0
                    signal = base_signal + theta
                elif class_idx == 3:  # Lane change: mixed alpha-beta
                    mixed = (np.sin(2 * np.pi * 10 * t / n_timepoints) +
                             np.sin(2 * np.pi * 20 * t / n_timepoints)) * 1.5
                    signal = base_signal + mixed
                else:  # Turning: increased delta (1-3 Hz) and beta
                    delta_beta = (np.sin(2 * np.pi * 2 * t / n_timepoints) * 2.0 +
                                  np.sin(2 * np.pi * 20 * t / n_timepoints) * 1.0)
                    signal = base_signal + delta_beta

                # Add channel-specific noise and amplitude variation
                channel_noise = np.random.normal(0, 0.5, n_timepoints)
                channel_amplitude = 0.8 + 0.4 * np.random.random()

                # Add sample-specific variation
                sample_variation = np.random.normal(0, 0.2, n_timepoints)

                # Combine all components
                final_signal = channel_amplitude * signal + channel_noise + sample_variation

                # Store in data array
                X[i, ch, :] = final_signal

    return X, y

# Generate synthetic data
print("Generating synthetic EEG data...")
X, y = generate_synthetic_eeg(n_samples=500, n_channels=5, n_timepoints=1000, n_classes=5)
print(f"Data shape: {X.shape}, Labels shape: {y.shape}")
print(f"Class distribution: {np.bincount(y)}")
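
As a quick sanity check on the synthetic signals, the dominant frequency of a trial can be read off its FFT. The sketch below is illustrative only: it rebuilds a single "acceleration"-style channel (10 Hz base plus a stronger 20 Hz component, mirroring `generate_synthetic_eeg`) and assumes a 1000 Hz sampling rate, so that 1000 time points span one second. The helper name `dominant_frequency` is hypothetical.

```python
import numpy as np

def dominant_frequency(signal, fs):
    """Return the frequency (Hz) of the largest non-DC peak in the amplitude spectrum."""
    spectrum = np.abs(np.fft.rfft(signal))
    freqs = np.fft.rfftfreq(len(signal), d=1.0 / fs)
    spectrum[0] = 0.0  # ignore the DC component
    return float(freqs[np.argmax(spectrum)])

# Assumption: fs = 1000 Hz, so 1000 time points span exactly one second
fs, n_timepoints = 1000, 1000
t = np.arange(n_timepoints)
rng = np.random.default_rng(0)

# 10 Hz base oscillation plus a stronger 20 Hz (beta) component and Gaussian noise
trial = (np.sin(2 * np.pi * 10 * t / n_timepoints)
         + 2.5 * np.sin(2 * np.pi * 20 * t / n_timepoints)
         + rng.normal(0, 0.5, n_timepoints))

print(dominant_frequency(trial, fs))
```

Running the same helper on `X[i, ch]` for a few trials per class is a cheap way to confirm that each class carries the frequency content its comment describes.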

### Visualize Raw EEG Data

In [None]:
def plot_eeg_samples(eeg_data, labels, class_names=None, n_samples=5, n_channels=5):
    """Plot example EEG time series for each class"""
    if class_names is None:
        class_names = [f"Class {i}" for i in range(len(np.unique(labels)))]

    fig, axes = plt.subplots(len(class_names), n_channels, figsize=(15, 3*len(class_names)))

    # Time axis
    time = np.arange(eeg_data.shape[2]) / 1000  # seconds (assumes a 1000 Hz sampling rate)

    for i, class_idx in enumerate(range(len(class_names))):
        # Get indices of samples for this class
        class_indices = np.where(labels == class_idx)[0]

        # Randomly select one sample
        sample_idx = np.random.choice(class_indices)

        for j in range(n_channels):
            ax = axes[i, j]
            ax.plot(time, eeg_data[sample_idx, j, :])

            if j == 0:
                ax.set_ylabel(class_names[class_idx])

            if i == 0:
                ax.set_title(f"Channel {j+1}")

            if i == len(class_names) - 1:
                ax.set_xlabel("Time (s)")

    plt.tight_layout()
    plt.show()

# Define class names
class_names = ["Smooth Driving", "Acceleration", "Deceleration", "Lane Change", "Turning"]

# Plot EEG samples
plot_eeg_samples(X, y, class_names=class_names)

## 3. Transform EEG to Spectrograms

In [None]:
def eeg_to_spectrogram(eeg_data, fs=1000, nperseg=256, noverlap=128):
    """Convert EEG time series to spectrograms

    Args:
        eeg_data: EEG data shape (n_samples, n_channels, n_timepoints)
        fs: Sampling frequency
        nperseg: Length of each segment
        noverlap: Overlap between segments

    Returns:
        spectrogram: Spectrogram data (n_samples, n_channels, n_freqs, n_times)
    """
    from scipy import signal

    n_samples, n_channels, n_timepoints = eeg_data.shape
    spectrograms = []

    for i in tqdm(range(n_samples), desc="Converting to spectrograms"):
        sample_specs = []
        for j in range(n_channels):
            f, t, Sxx = signal.spectrogram(eeg_data[i, j, :], fs=fs, nperseg=nperseg, noverlap=noverlap)
            # Log scale for better visualization
            Sxx = np.log1p(Sxx)
            sample_specs.append(Sxx)

        # Stack channels
        spectrograms.append(np.stack(sample_specs))

    # Return shape: (n_samples, n_channels, n_freqs, n_times)
    return np.stack(spectrograms), f, t

# Convert to spectrograms
X_spec, freqs, times = eeg_to_spectrogram(X)
print(f"Spectrogram shape: {X_spec.shape}")
print(f"Frequency range: {freqs.min():.1f} - {freqs.max():.1f} Hz")
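
The shape of each spectrogram follows directly from the window settings. The sketch below works through that arithmetic for the defaults used here (`fs=1000`, `nperseg=256`, `noverlap=128`, 1000 time points). It assumes a one-sided spectrum and no padding of the last partial segment, which is my understanding of the `scipy.signal.spectrogram` defaults; `spectrogram_grid` is a hypothetical helper name.

```python
def spectrogram_grid(n_timepoints, fs, nperseg, noverlap):
    """Frequency/time grid of a one-sided spectrogram computed without padding."""
    step = nperseg - noverlap
    n_freqs = nperseg // 2 + 1                    # one-sided spectrum
    n_times = (n_timepoints - noverlap) // step   # number of full segments
    freq_resolution = fs / nperseg                # bin spacing in Hz
    return n_freqs, n_times, freq_resolution

n_freqs, n_times, df = spectrogram_grid(n_timepoints=1000, fs=1000, nperseg=256, noverlap=128)

# Count the frequency bins at or below 30 Hz, where the class differences live
bins_up_to_30hz = sum(1 for k in range(n_freqs) if k * df <= 30)

print(n_freqs, n_times, df, bins_up_to_30hz)
```

With these settings only a handful of the frequency rows fall at or below 30 Hz, so most of each image covers frequencies the synthetic classes do not differ in. Cropping the frequency axis or using a longer window may be worth trying, at the cost of fewer time frames.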

### Visualize Spectrograms

In [None]:
def plot_spectrograms(spec_data, freqs, times, labels, class_names=None):
    """Plot example spectrograms for each class"""
    if class_names is None:
        class_names = [f"Class {i}" for i in range(len(np.unique(labels)))]

    fig, axes = plt.subplots(len(class_names), 3, figsize=(15, 4*len(class_names)))

    for i, class_idx in enumerate(range(len(class_names))):
        # Get indices of samples for this class
        class_indices = np.where(labels == class_idx)[0]

        # Randomly select one sample
        sample_idx = np.random.choice(class_indices)

        for j in range(3):  # Plot first 3 channels
            ax = axes[i, j]

            # Plot spectrogram
            im = ax.pcolormesh(times, freqs, spec_data[sample_idx, j], shading='gouraud', cmap='viridis')

            if j == 0:
                ax.set_ylabel(f"{class_names[class_idx]}\nFrequency (Hz)")
            else:
                ax.set_ylabel("Frequency (Hz)")

            if i == 0:
                ax.set_title(f"Channel {j+1}")

            if i == len(class_names) - 1:
                ax.set_xlabel("Time (s)")

            plt.colorbar(im, ax=ax, label='Log Power')

    plt.tight_layout()
    plt.show()

# Plot spectrograms
plot_spectrograms(X_spec, freqs, times, y, class_names=class_names)

## 4. Prepare Images for ViT

In [None]:
def prepare_images_for_vit(X_spec, target_size=(224, 224)):
    """Prepare spectrograms as images for ViT"""
    try:
        import cv2
    except ImportError:
        !pip install -q opencv-python
        import cv2

    n_samples, n_channels, n_freqs, n_times = X_spec.shape
    images = []

    for i in tqdm(range(n_samples), desc="Preparing images"):
        # Take first 3 channels or duplicate if fewer
        if n_channels >= 3:
            channels_to_use = X_spec[i, :3]
        else:
            # Duplicate channels to make 3
            channels_to_use = np.tile(X_spec[i, :1], (3, 1, 1))[:3]

        # Normalize each channel
        normalized_channels = []
        for j in range(3):
            channel = channels_to_use[j]
            # Normalize to 0-1
            norm_channel = (channel - channel.min()) / (channel.max() - channel.min() + 1e-8)
            # Resize to target size
            norm_channel = cv2.resize(norm_channel, target_size)
            normalized_channels.append(norm_channel)

        # Stack as RGB image
        rgb_image = np.stack(normalized_channels, axis=2)
        images.append(rgb_image)

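    # Cast to float32: ToTensor() keeps the NumPy dtype for float arrays, and
    # float64 inputs would not match the model's float32 weights
    return np.array(images, dtype=np.float32)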

# Convert spectrograms to images
X_images = prepare_images_for_vit(X_spec)
print(f"Image data shape: {X_images.shape}")

# Show example images
plt.figure(figsize=(15, 10))
for i in range(5):
    plt.subplot(1, 5, i+1)
    plt.imshow(X_images[i*100])
    plt.title(class_names[y[i*100]])
    plt.axis('off')
plt.tight_layout()
plt.show()

## 5. Define ViT Encoder-Decoder Architecture

In [None]:
class NeurLLM_EncoderDecoder(nn.Module):
    def __init__(self, num_classes=5, img_size=224, patch_size=16, embed_dim=768,
                 depth_encoder=12, depth_decoder=8, num_heads=12):
        super().__init__()

        # Vision Transformer Encoder
        self.encoder = timm.create_model(
            'vit_base_patch16_224',
            pretrained=True,
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            depth=depth_encoder,
            num_heads=num_heads
        )

        # Remove the classification head from the encoder
        self.encoder.head = nn.Identity()

        # Classification head (separate from the main encoder-decoder path)
        self.classifier = nn.Linear(embed_dim, num_classes)

        # Vision Transformer Decoder
        self.decoder = nn.ModuleList([
            # Transformer blocks for decoding
            nn.ModuleList([
                nn.LayerNorm(embed_dim),
                nn.MultiheadAttention(embed_dim, num_heads, batch_first=True),  # tokens arrive as [B, N, D]
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, embed_dim * 4),
                nn.GELU(),
                nn.Linear(embed_dim * 4, embed_dim)
            ]) for _ in range(depth_decoder)
        ])

        # Final layer to reconstruct image from patches
        self.patch_dim = 3 * patch_size * patch_size  # 3 channels * patch dimensions
        self.decoder_pred = nn.Linear(embed_dim, self.patch_dim)

        # Patch unembedding (to reconstruct the image)
        self.patch_height = img_size // patch_size
        self.patch_width = img_size // patch_size

        # Class mappings
        self.idx_to_class = {
            0: "smooth driving",
            1: "acceleration",
            2: "deceleration",
            3: "lane change",
            4: "turning"
        }

    def encode(self, x):
        """Encode image into latent representation"""
        # Forward through the encoder
        x = self.encoder.forward_features(x)
        return x

    def decode(self, x):
        """Decode latent representation back to image space"""
        # Apply decoder transformer blocks
        for norm1, attn, norm2, fc1, gelu, fc2 in self.decoder:
            # Self-attention block
            x_norm = norm1(x)
            x_attn = attn(x_norm, x_norm, x_norm)[0]
            x = x + x_attn

            # MLP block
            x_norm = norm2(x)
            x_mlp = fc1(x_norm)
            x_mlp = gelu(x_mlp)
            x_mlp = fc2(x_mlp)
            x = x + x_mlp

        # Predict patches
        x = self.decoder_pred(x)

        # Drop the CLS token: [B, 1 + N, patch_dim] -> [B, N, patch_dim]
        x = x[:, 1:, :]

        # Unpatchify: each token holds one flattened (c, p, p) patch; place the
        # patches on the (h, w) grid to get a channels-first image [B, 3, H, W]
        patch_size = int(np.sqrt(self.patch_dim // 3))
        x = rearrange(
            x, 'b (h w) (c p1 p2) -> b c (h p1) (w p2)',
            h=self.patch_height, w=self.patch_width, c=3, p1=patch_size, p2=patch_size
        )

        return x

    def forward(self, x, task='both'):
        """
        Forward pass through the encoder-decoder

        Args:
            x: Input image
            task: 'encode', 'decode', 'classify', or 'both'

        Returns:
            Dictionary containing requested outputs
        """
        result = {}

        if task in ['encode', 'both', 'classify']:
            # Encode the input
            latent = self.encode(x)
            result['latent'] = latent

            # Classification from CLS token
            if task in ['classify', 'both']:
                cls_token = latent[:, 0]
                logits = self.classifier(cls_token)
                result['logits'] = logits

        if task in ['decode', 'both']:
            # If we have latent from encoding step, use it
            if 'latent' in result:
                latent = result['latent']
            else:
                # Otherwise, encode first
                latent = self.encode(x)

            # Decode the latent representation
            reconstruction = self.decode(latent)
            result['reconstruction'] = reconstruction

        return result

# Initialize the model
print("Initializing ViT Encoder-Decoder model...")
model = NeurLLM_EncoderDecoder(num_classes=len(class_names)).to(device)
print("Model initialized!")
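
The decoder emits one flattened patch per token, so turning its output into an image means placing each patch back on the patch grid. The NumPy sketch below illustrates that bookkeeping on a tiny array. `patchify` and `unpatchify` are hypothetical helper names, and the (channel, row, column) ordering inside each flattened patch is an assumption chosen for the example rather than something fixed by timm.

```python
import numpy as np

def patchify(img, p):
    """Split a channels-first image [C, H, W] into flattened patches [h*w, C*p*p]."""
    C, H, W = img.shape
    h, w = H // p, W // p
    x = img.reshape(C, h, p, w, p)        # split rows and columns into (grid, within-patch)
    x = x.transpose(1, 3, 0, 2, 4)        # -> [h, w, C, p, p]
    return x.reshape(h * w, C * p * p)

def unpatchify(tokens, h, w, C, p):
    """Inverse of patchify: [h*w, C*p*p] -> [C, h*p, w*p]."""
    x = tokens.reshape(h, w, C, p, p)
    x = x.transpose(2, 0, 3, 1, 4)        # -> [C, h, p, w, p]
    return x.reshape(C, h * p, w * p)

# Toy image: 3 channels, 8x8 pixels, 4x4 patches -> a 2x2 grid of 48-dim tokens
img = np.arange(3 * 8 * 8).reshape(3, 8, 8)
tokens = patchify(img, p=4)
restored = unpatchify(tokens, h=2, w=2, C=3, p=4)

print(tokens.shape, np.array_equal(restored, img))
```

A round-trip check like this is a useful guard: a plain `reshape` of the token tensor to image dimensions has the right number of elements but does not, in general, put pixels back where they came from.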

## 6. Dataset and DataLoader Setup

In [None]:
class EEGSpectrogramDataset(Dataset):
    def __init__(self, images, labels=None, transform=None, target_transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

        # Setup class names mapping
        self.classes = ["smooth driving", "acceleration", "deceleration", "lane change", "turning"]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]

        # Create both input and target images (same for pure reconstruction)
        input_image = image.copy()
        target_image = image.copy()

        if self.transform:
            input_image = self.transform(input_image)

        if self.target_transform:
            target_image = self.target_transform(target_image)

        if self.labels is not None:
            label = self.labels[idx]
            return input_image, target_image, label
        else:
            return input_image, target_image

# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(
    X_images, y, test_size=0.2, random_state=42, stratify=y
)

# Further split training data into train and validation
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.15, random_state=42, stratify=y_train
)

print(f"Train set: {X_train.shape}, Validation set: {X_val.shape}, Test set: {X_test.shape}")

# Transforms
input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

target_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create datasets
train_dataset = EEGSpectrogramDataset(
    X_train, y_train,
    transform=input_transform,
    target_transform=target_transform
)

val_dataset = EEGSpectrogramDataset(
    X_val, y_val,
    transform=input_transform,
    target_transform=target_transform
)

test_dataset = EEGSpectrogramDataset(
    X_test, y_test,
    transform=input_transform,
    target_transform=target_transform
)

# Create dataloaders
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Number of batches - Train: {len(train_loader)}, Val: {len(val_loader)}, Test: {len(test_loader)}")
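
The split sizes and batch counts printed above can be worked out by hand. The sketch below does so for 500 samples, assuming that `train_test_split` rounds the held-out size up and that the `DataLoader` keeps the final, smaller batch (the default `drop_last=False`). The helper names are hypothetical.

```python
import math

def split_sizes(n_samples, test_frac=0.2, val_frac=0.15):
    """Sizes of (train, val, test) after two successive hold-out splits."""
    n_test = math.ceil(test_frac * n_samples)
    n_trainval = n_samples - n_test
    n_val = math.ceil(val_frac * n_trainval)
    return n_trainval - n_val, n_val, n_test

def n_batches(n_items, batch_size):
    """Number of batches when the last, smaller batch is kept."""
    return math.ceil(n_items / batch_size)

n_train, n_val, n_test = split_sizes(500)
print(n_train, n_val, n_test)
print([n_batches(n, 16) for n in (n_train, n_val, n_test)])
```

Note that the validation fraction applies to what is left after the test split, so 15% here means 15% of 400 samples, not of the full 500.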

## 7. Training the Model

In [None]:
def train_encoder_decoder(model, train_loader, val_loader, num_epochs=20, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)

    # Define loss functions
    classification_criterion = nn.CrossEntropyLoss()
    reconstruction_criterion = nn.MSELoss()

    # Set up optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Training metrics
    best_val_loss = float('inf')
    history = {
        'train_class_loss': [], 'train_recon_loss': [], 'train_total_loss': [],
        'val_class_loss': [], 'val_recon_loss': [], 'val_total_loss': [],
        'val_accuracy': []
    }

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_class_loss = 0.0
        train_recon_loss = 0.0
        train_total_loss = 0.0

        for input_imgs, target_imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
            input_imgs = input_imgs.to(device)
            target_imgs = target_imgs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(input_imgs, task='both')

            # Calculate losses
            class_loss = classification_criterion(outputs['logits'], labels)
            recon_loss = reconstruction_criterion(outputs['reconstruction'], target_imgs)

            # Weighted sum of losses (can be adjusted)
            total_loss = class_loss + 0.5 * recon_loss

            # Backward pass
            total_loss.backward()
            optimizer.step()

            # Update metrics
            train_class_loss += class_loss.item() * input_imgs.size(0)
            train_recon_loss += recon_loss.item() * input_imgs.size(0)
            train_total_loss += total_loss.item() * input_imgs.size(0)

        # Normalize losses
        train_class_loss /= len(train_loader.dataset)
        train_recon_loss /= len(train_loader.dataset)
        train_total_loss /= len(train_loader.dataset)

        # Validation phase
        model.eval()
        val_class_loss = 0.0
        val_recon_loss = 0.0
        val_total_loss = 0.0
        val_correct = 0

        with torch.no_grad():
            for input_imgs, target_imgs, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
                input_imgs = input_imgs.to(device)
                target_imgs = target_imgs.to(device)
                labels = labels.to(device)

                # Forward pass
                outputs = model(input_imgs, task='both')

                # Calculate losses
                class_loss = classification_criterion(outputs['logits'], labels)
                recon_loss = reconstruction_criterion(outputs['reconstruction'], target_imgs)
                total_loss = class_loss + 0.5 * recon_loss

                # Update metrics
                val_class_loss += class_loss.item() * input_imgs.size(0)
                val_recon_loss += recon_loss.item() * input_imgs.size(0)
                val_total_loss += total_loss.item() * input_imgs.size(0)

                # Calculate accuracy
                _, predicted = torch.max(outputs['logits'], 1)
                val_correct += (predicted == labels).sum().item()

        # Normalize validation metrics
        val_class_loss /= len(val_loader.dataset)
        val_recon_loss /= len(val_loader.dataset)
        val_total_loss /= len(val_loader.dataset)
        val_accuracy = val_correct / len(val_loader.dataset)

        # Update learning rate
        scheduler.step()

        # Save history
        history['train_class_loss'].append(train_class_loss)
        history['train_recon_loss'].append(train_recon_loss)
        history['train_total_loss'].append(train_total_loss)
        history['val_class_loss'].append(val_class_loss)
        history['val_recon_loss'].append(val_recon_loss)
        history['val_total_loss'].append(val_total_loss)
        history['val_accuracy'].append(val_accuracy)

        # Print progress
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train - Class Loss: {train_class_loss:.4f}, Recon Loss: {train_recon_loss:.4f}')
        print(f'  Val   - Class Loss: {val_class_loss:.4f}, Recon Loss: {val_recon_loss:.4f}, Acc: {val_accuracy:.4f}')

        # Save best model
        if val_total_loss < best_val_loss:
            best_val_loss = val_total_loss
            torch.save(model.state_dict(), 'best_neurllm_model.pth')
            print('  Saved new best model!')

    return model, history

# Train the model with fewer epochs for demonstration
num_epochs = 5  # Use more epochs (20+) for better results
print("Training model...")
model, history = train_encoder_decoder(model, train_loader, val_loader, num_epochs=num_epochs, device=device)

# Plot training history
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(history['train_class_loss'], label='Train')
plt.plot(history['val_class_loss'], label='Validation')
plt.title('Classification Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(history['train_recon_loss'], label='Train')
plt.plot(history['val_recon_loss'], label='Validation')
plt.title('Reconstruction Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(history['val_accuracy'])
plt.title('Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')

plt.tight_layout()
plt.show()
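
The training loop multiplies each batch loss by the batch size before dividing by the dataset size. The sketch below shows why, using made-up numbers: when the last batch is smaller than the others, a plain mean over batches gives it too much weight. `epoch_mean_loss` is a hypothetical helper, not part of the notebook.

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Per-sample mean loss over an epoch, given per-batch mean losses."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

# Made-up example: two full batches and one small, hard final batch
batch_losses = [1.0, 1.0, 4.0]
batch_sizes = [16, 16, 4]

weighted = epoch_mean_loss(batch_losses, batch_sizes)
naive = sum(batch_losses) / len(batch_losses)

print(f"sample-weighted: {weighted:.4f}, mean over batches: {naive:.4f}")
```

The same weighting applies to the classification, reconstruction, and total losses, which keeps the reported values comparable across batch sizes.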

## 8. Model Evaluation

In [None]:
def evaluate_model(model, data_loader, device=None):
    """Evaluate model performance on test set"""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()

    all_preds = []
    all_labels = []
    total_recon_loss = 0.0
    reconstruction_criterion = nn.MSELoss()

    with torch.no_grad():
        for input_imgs, target_imgs, labels in tqdm(data_loader, desc="Evaluating"):
            input_imgs = input_imgs.to(device)
            target_imgs = target_imgs.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(input_imgs, task='both')

            # Get predictions
            _, preds = torch.max(outputs['logits'], 1)

            # Calculate reconstruction loss
            recon_loss = reconstruction_criterion(outputs['reconstruction'], target_imgs)
            total_recon_loss += recon_loss.item() * input_imgs.size(0)

            # Store predictions and labels
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calculate metrics
    avg_recon_loss = total_recon_loss / len(data_loader.dataset)
    accuracy = accuracy_score(all_labels, all_preds)
    conf_mat = confusion_matrix(all_labels, all_preds)
    class_report = classification_report(all_labels, all_preds, target_names=class_names)

    return {
        'accuracy': accuracy,
        'reconstruction_loss': avg_recon_loss,
        'confusion_matrix': conf_mat,
        'classification_report': class_report,
        'predictions': all_preds,
        'true_labels': all_labels
    }

# Evaluate model on test set
print("Evaluating model on test set...")
results = evaluate_model(model, test_loader, device)

# Print results
print(f"\nTest Accuracy: {results['accuracy']:.4f}")
print(f"Test Reconstruction Loss: {results['reconstruction_loss']:.4f}")
print("\nClassification Report:")
print(results['classification_report'])

# Plot confusion matrix
plt.figure(figsize=(10, 8))
conf_mat = results['confusion_matrix']
sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
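
The per-class numbers in the classification report can be read straight off the confusion matrix. The sketch below assumes scikit-learn's layout (rows are true labels, columns are predictions) and uses a made-up 3-class matrix. `per_class_recall_precision` is a hypothetical helper.

```python
import numpy as np

def per_class_recall_precision(conf_mat):
    """Recall and precision per class from a confusion matrix (rows = true labels)."""
    conf_mat = np.asarray(conf_mat, dtype=float)
    true_pos = np.diag(conf_mat)
    support = conf_mat.sum(axis=1)      # samples per true class
    predicted = conf_mat.sum(axis=0)    # samples per predicted class
    # Guard against empty rows/columns by leaving those entries at zero
    recall = np.divide(true_pos, support, out=np.zeros_like(true_pos), where=support > 0)
    precision = np.divide(true_pos, predicted, out=np.zeros_like(true_pos), where=predicted > 0)
    return recall, precision

# Made-up confusion matrix for illustration only
toy_conf_mat = [[8, 2, 0],
                [1, 9, 0],
                [0, 0, 10]]

recall, precision = per_class_recall_precision(toy_conf_mat)
print(np.round(recall, 3), np.round(precision, 3))
```

Applied to `results['confusion_matrix']`, the recall vector shows which driving behaviors are missed most often, which the single accuracy figure hides.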

## 9. Visualize Reconstructions

In [None]:
def visualize_reconstructions(model, data_loader, num_samples=5, device=None):
    """Visualize example reconstructions from the model"""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()

    # Get samples
    all_inputs = []
    all_targets = []
    all_outputs = []
    all_labels = []

    with torch.no_grad():
        for input_imgs, target_imgs, labels in data_loader:
            if len(all_inputs) >= num_samples:
                break

            input_imgs = input_imgs.to(device)
            target_imgs = target_imgs.to(device)

            # Forward pass
            outputs = model(input_imgs, task='both')
            reconstructions = outputs['reconstruction']

            # Get predictions
            _, preds = torch.max(outputs['logits'], 1)

            # Store data
            all_inputs.extend(input_imgs.cpu())
            all_targets.extend(target_imgs.cpu())
            all_outputs.extend(reconstructions.cpu())
            all_labels.extend(labels.cpu())

    # Select samples
    indices = np.arange(len(all_inputs))
    np.random.shuffle(indices)
    indices = indices[:num_samples]

    # Create visualization
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 4*num_samples))

    for i, idx in enumerate(indices):
        # Get data
        input_img = all_inputs[idx]
        target_img = all_targets[idx]
        output_img = all_outputs[idx]
        label = all_labels[idx]

        # Denormalize
        def denormalize(img):
            img = img.clone()
            img = img.permute(1, 2, 0).numpy()
            img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
            img = np.clip(img, 0, 1)
            return img

        input_img = denormalize(input_img)
        target_img = denormalize(target_img)
        output_img = denormalize(output_img)

        # Display images
        axes[i, 0].imshow(input_img)
        axes[i, 0].set_title(f"Input: {class_names[label]}")
        axes[i, 0].axis('off')

        axes[i, 1].imshow(output_img)
        axes[i, 1].set_title("Reconstruction")
        axes[i, 1].axis('off')

        # Display reconstruction error
        error = np.abs(target_img - output_img)
        error_img = np.mean(error, axis=2)  # Average across channels
        axes[i, 2].imshow(error_img, cmap='hot')
        axes[i, 2].set_title(f"Error (MSE: {np.mean(error**2):.4f})")
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()

# Visualize reconstructions
visualize_reconstructions(model, test_loader, num_samples=5, device=device)

## 10. Latent Space Visualization

In [None]:
def visualize_latent_space(model, data_loader, num_samples=200):
    """Visualize the latent space of the model using PCA or t-SNE"""
    from sklearn.decomposition import PCA
    from sklearn.manifold import TSNE

    device = next(model.parameters()).device
    model.eval()

    # Collect latent representations and labels
    latents = []
    labels = []

    with torch.no_grad():
        for inputs, _, label in data_loader:
            if len(latents) * inputs.size(0) >= num_samples:
                break

            inputs = inputs.to(device)

            # Get latent representation
            output = model(inputs, task='encode')

            # Use CLS token as latent vector
            cls_token = output['latent'][:, 0].cpu().numpy()
            latents.append(cls_token)
            labels.extend(label.numpy())

    # Concatenate all latent vectors
    latents = np.vstack(latents)[:num_samples]
    labels = np.array(labels)[:num_samples]

    # Use PCA to reduce dimensions
    pca = PCA(n_components=2)
    latents_2d_pca = pca.fit_transform(latents)

    # Use t-SNE for a non-linear projection
    tsne = TSNE(n_components=2, random_state=42)
    latents_2d_tsne = tsne.fit_transform(latents)

    # Create scatter plots
    fig, axes = plt.subplots(1, 2, figsize=(18, 7))

    # PCA plot
    scatter1 = axes[0].scatter(latents_2d_pca[:, 0], latents_2d_pca[:, 1],
                               c=labels, cmap='viridis', alpha=0.7, s=70)
    axes[0].set_title('Latent Space PCA Projection')
    axes[0].set_xlabel(f'PC1 (Explained Variance: {pca.explained_variance_ratio_[0]:.2f})')
    axes[0].set_ylabel(f'PC2 (Explained Variance: {pca.explained_variance_ratio_[1]:.2f})')
    axes[0].grid(alpha=0.3)

    # t-SNE plot
    scatter2 = axes[1].scatter(latents_2d_tsne[:, 0], latents_2d_tsne[:, 1],
                               c=labels, cmap='viridis', alpha=0.7, s=70)
    axes[1].set_title('Latent Space t-SNE Projection')
    axes[1].set_xlabel('t-SNE Dimension 1')
    axes[1].set_ylabel('t-SNE Dimension 2')
    axes[1].grid(alpha=0.3)

    # Add a class legend to the PCA plot (only for classes present in the sample)
    handles, _ = scatter1.legend_elements()
    present_classes = np.unique(labels)
    axes[0].legend(handles, [class_names[i] for i in present_classes],
                   title="Driving Behaviors", loc="upper right")

    plt.tight_layout()
    plt.show()

    return latents, labels, pca

# Visualize latent space
latents, labels, pca = visualize_latent_space(model, test_loader, num_samples=200)
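
To make the "explained variance" axis labels concrete, the sketch below reproduces the PCA projection with a plain SVD. It runs on random stand-in vectors, not on real model latents, and `pca_project` is a hypothetical helper, so it illustrates the computation rather than replacing `sklearn.decomposition.PCA`.

```python
import numpy as np

def pca_project(latents, n_components=2):
    """Project row vectors onto their top principal components via SVD."""
    centered = latents - latents.mean(axis=0, keepdims=True)
    _, singular_values, components = np.linalg.svd(centered, full_matrices=False)
    # Squared singular values are proportional to the variance along each component
    explained_ratio = singular_values ** 2 / np.sum(singular_values ** 2)
    projected = centered @ components[:n_components].T
    return projected, explained_ratio[:n_components]

# Stand-in data: 100 random 16-dim vectors with decreasing variance per dimension
rng = np.random.default_rng(0)
fake_latents = rng.normal(size=(100, 16)) * np.linspace(3.0, 0.1, 16)

projected, ratio = pca_project(fake_latents)
print(projected.shape, np.round(ratio, 3))
```

If the two ratios sum to a small fraction of the total variance, the 2-D scatter plot shows only a thin slice of the latent space, and apparent overlap between classes there should be read with caution.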

## 11. Interactive Demo

In [None]:
def get_attention_maps(model, input_tensor):
    """Extract attention maps from ViT for visualization"""
    device = next(model.parameters()).device
    attention_maps = []

    def hook_fn(module, input, output):
        attention_maps.append(output.detach().cpu())

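    # Get the last attention layer
    try:
        last_attn = model.encoder.blocks[-1].attn
        # Newer timm versions may use a fused attention kernel that bypasses
        # attn_drop, so the hook would never fire; disable it temporarily if present
        had_fused = getattr(last_attn, 'fused_attn', False)
        if had_fused:
            last_attn.fused_attn = False
        hook = last_attn.attn_drop.register_forward_hook(hook_fn)

        # Forward pass
        with torch.no_grad():
            _ = model(input_tensor.unsqueeze(0).to(device))

        # Remove hook and restore the original attention path
        hook.remove()
        if had_fused:
            last_attn.fused_attn = True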

        # Process attention
        if attention_maps:
            # Average attention across heads
            attention = attention_maps[0].mean(dim=1)[0]  # [batch, heads, seq, seq] -> [seq, seq]

            # Extract attention from CLS token to patches
            cls_attention = attention[0, 1:]  # Skip the CLS token

            # Reshape to spatial grid
            grid_size = int(np.sqrt(cls_attention.shape[0]))
            attention_grid = cls_attention.reshape(grid_size, grid_size).numpy()

            # Upscale to image size
            from scipy.ndimage import zoom
            attention_grid = zoom(attention_grid, 224 // grid_size)

            return attention_grid
        else:
            # Fallback if hook failed
            return np.ones((224, 224)) * 0.5
    except:
        # Fallback for any errors
        return np.ones((224, 224)) * 0.5

from ipywidgets import widgets
from IPython.display import display

def create_interactive_demo(model, test_dataset):
    """Create an interactive demo for exploring the model"""
    model.eval()
    device = next(model.parameters()).device

    # Create widgets
    sample_slider = widgets.IntSlider(
        value=0, min=0, max=len(test_dataset)-1,
        description='Sample:', continuous_update=False
    )

    mode_select = widgets.RadioButtons(
        options=['Classification', 'Reconstruction', 'Attention Map'],
        description='Mode:',
        value='Classification'
    )

    output_area = widgets.Output()

    def on_change(change):
        with output_area:
            output_area.clear_output()

            # Get sample
            sample_idx = sample_slider.value
            input_img, target_img, label = test_dataset[sample_idx]

            # Add batch dimension and move to device
            input_tensor = input_img.unsqueeze(0).to(device)

            # Process based on selected mode
            mode = mode_select.value

            if mode == 'Classification':
                # Run classification
                with torch.no_grad():
                    outputs = model(input_tensor, task='classify')
                    logits = outputs['logits']
                    probs = F.softmax(logits, dim=1)[0]
                    pred_idx = probs.argmax().item()
                    pred_class = model.idx_to_class[pred_idx]
                    true_class = test_dataset.classes[label]

                # Visualize input and class probabilities
                plt.figure(figsize=(12, 5))

                # Show input image
                plt.subplot(1, 2, 1)
                img_np = input_img.permute(1, 2, 0).cpu().numpy()
                img_np = (img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])).clip(0, 1)
                plt.imshow(img_np)
                plt.title(f'EEG Spectrogram\nTrue Class: {true_class}')
                plt.axis('off')

                # Show class probabilities
                plt.subplot(1, 2, 2)
                classes = list(model.idx_to_class.values())
                probs_np = probs.cpu().numpy()
                colors = ['green' if i == label else 'red' if i == pred_idx else 'blue'
                          for i in range(len(classes))]
                plt.bar(classes, probs_np, color=colors)
                plt.ylabel('Probability')
                plt.title(f'Predicted: {pred_class} ({probs_np[pred_idx]:.4f})')
                plt.xticks(rotation=30, ha='right')
                plt.tight_layout()
                plt.show()

            elif mode == 'Reconstruction':
                # Run reconstruction
                with torch.no_grad():
                    outputs = model(input_tensor, task='both')
                    reconstruction = outputs['reconstruction']

                # Visualize input and reconstruction
                plt.figure(figsize=(15, 5))

                # Show input image
                plt.subplot(1, 3, 1)
                img_np = input_img.permute(1, 2, 0).cpu().numpy()
                img_np = (img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])).clip(0, 1)
                plt.imshow(img_np)
                plt.title('Input EEG Spectrogram')
                plt.axis('off')

                # Show reconstruction
                plt.subplot(1, 3, 2)
                recon_np = reconstruction.squeeze(0).permute(1, 2, 0).cpu().numpy()
                recon_np = (recon_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])).clip(0, 1)
                plt.imshow(recon_np)
                plt.title('Reconstructed Spectrogram')
                plt.axis('off')

                # Show error
                plt.subplot(1, 3, 3)
                error = np.abs(img_np - recon_np)
                error_img = np.mean(error, axis=2)  # Average across channels
                plt.imshow(error_img, cmap='hot')
                plt.title(f'Error Map (MSE: {np.mean(error**2):.4f})')
                plt.colorbar(label='Error Magnitude')
                plt.axis('off')

                plt.tight_layout()
                plt.show()

            elif mode == 'Attention Map':
                # Get attention map
                attention_map = get_attention_maps(model, input_img)

                # Visualize attention
                plt.figure(figsize=(15, 5))

                # Show input image
                plt.subplot(1, 3, 1)
                img_np = input_img.permute(1, 2, 0).cpu().numpy()
                img_np = (img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])).clip(0, 1)
                plt.imshow(img_np)
                plt.title(f'Input (Class: {test_dataset.classes[label]})')
                plt.axis('off')

                # Show attention map
                plt.subplot(1, 3, 2)
                plt.imshow(attention_map, cmap='hot')
                plt.title('Attention Map')
                plt.colorbar(label='Attention Weight')
                plt.axis('off')

                # Show overlay
                plt.subplot(1, 3, 3)
                plt.imshow(img_np)
                plt.imshow(attention_map, alpha=0.6, cmap='hot')
                plt.title('Attention Overlay')
                plt.axis('off')

                plt.tight_layout()
                plt.show()

                # Run classification
                with torch.no_grad():
                    outputs = model(input_tensor, task='classify')
                    logits = outputs['logits']
                    probs = F.softmax(logits, dim=1)[0]
                    pred_idx = probs.argmax().item()
                    pred_class = model.idx_to_class[pred_idx]

                print(f"Prediction: {pred_class} (Confidence: {probs[pred_idx]:.4f})")
                print("The attention map shows which spectrogram patches the CLS token of the last encoder block attends to.")

    # Register event handlers
    sample_slider.observe(on_change, names='value')
    mode_select.observe(on_change, names='value')

    # Create UI layout
    demo_ui = widgets.VBox([
        widgets.HBox([sample_slider, mode_select]),
        output_area
    ])

    # Initial display
    display(demo_ui)
    on_change(None)

    return demo_ui

# Create interactive demo
print("Initializing interactive demo...")
demo = create_interactive_demo(model, test_dataset)
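
The demo above repeats the same de-normalization expression in every display branch. A small helper could replace those copies. The sketch below assumes the inputs were normalized with the ImageNet mean and standard deviation that appear in the plotting code; the name `denormalize` is hypothetical and is not defined elsewhere in this notebook.

```python
import numpy as np

# Channel statistics taken from the plotting code above (ImageNet values)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def denormalize(img_chw, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Undo channel-wise normalization on a [C, H, W] array.

    Returns an [H, W, C] array clipped to [0, 1], ready for plt.imshow.
    """
    img_hwc = np.transpose(np.asarray(img_chw, dtype=np.float64), (1, 2, 0))
    return np.clip(img_hwc * std + mean, 0.0, 1.0)

# Round trip on toy data: normalize an image, then recover it
rng = np.random.default_rng(0)
original = rng.uniform(size=(8, 8, 3))                      # [H, W, C] in [0, 1)
normalized = np.transpose((original - IMAGENET_MEAN) / IMAGENET_STD, (2, 0, 1))
restored = denormalize(normalized)

print(np.allclose(restored, original))  # True
```

For a PyTorch tensor such as `input_img`, the call would be `denormalize(input_img.cpu().numpy())`.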

## 12. Latent Space Manipulation

In [None]:
def interpolate_latent_space(model, img1, img2, num_steps=8):
    """Interpolate between two test samples in latent space and visualize the results.

    `img1` and `img2` are indices into the global `test_dataset`, not image tensors.
    """
    device = next(model.parameters()).device
    model.eval()

    # The arguments are dataset indices
    idx1, idx2 = img1, img2

    # Get images
    input1, _, label1 = test_dataset[idx1]
    input2, _, label2 = test_dataset[idx2]

    # Move to device and add batch dimension
    input1 = input1.unsqueeze(0).to(device)
    input2 = input2.unsqueeze(0).to(device)

    # Get latent representations
    with torch.no_grad():
        latent1 = model(input1, task='encode')['latent']
        latent2 = model(input2, task='encode')['latent']

    # Create interpolation steps
    alphas = np.linspace(0, 1, num_steps)
    interpolated_imgs = []
    interpolated_latents = []

    # Generate interpolated images
    with torch.no_grad():
        for alpha in alphas:
            # Linear interpolation in latent space
            interpolated_latent = (1 - alpha) * latent1 + alpha * latent2
            interpolated_latents.append(interpolated_latent)

            # Decode interpolated latent
            outputs = model.decode(interpolated_latent)
            interpolated_imgs.append(outputs.cpu())

    # Visualize interpolation
    plt.figure(figsize=(15, 8))

    # Show original images
    plt.subplot(3, num_steps, 1)
    img_np = input1.squeeze(0).permute(1, 2, 0).cpu().numpy()
    img_np = (img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])).clip(0, 1)
    plt.imshow(img_np)
    plt.title(f"Class: {test_dataset.classes[label1]}")
    plt.axis('off')

    plt.subplot(3, num_steps, num_steps)
    img_np = input2.squeeze(0).permute(1, 2, 0).cpu().numpy()
    img_np = (img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])).clip(0, 1)
    plt.imshow(img_np)
    plt.title(f"Class: {test_dataset.classes[label2]}")
    plt.axis('off')

    # Show interpolated latent space decodings
    for i, img in enumerate(interpolated_imgs):
        plt.subplot(3, num_steps, num_steps + i + 1)
        img_np = img.squeeze(0).permute(1, 2, 0).numpy()
        img_np = (img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])).clip(0, 1)
        plt.imshow(img_np)
        plt.title(f"α={alphas[i]:.2f}")
        plt.axis('off')

    # Run classification on interpolated latents
    with torch.no_grad():
        probs_list = []
        for latent in interpolated_latents:
            # Get CLS token
            cls_token = latent[:, 0]
            # Get logits
            logits = model.classifier(cls_token)
            # Get probabilities
            probs = F.softmax(logits, dim=1)[0]
            probs_list.append(probs.cpu().numpy())

    # Plot probabilities for each class in the bottom row
    probs_array = np.array(probs_list)
    plt.subplot(3, 1, 3)
    for i, class_name in enumerate(test_dataset.classes):
        plt.plot(alphas, probs_array[:, i], 'o-', label=class_name)

    plt.xlabel('Interpolation Factor (α)')
    plt.ylabel('Class Probability')
    plt.title('Class Probabilities Across Latent Space Interpolation')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

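# Choose two samples to interpolate between: the first test samples of class 0 and class 1.
# Looking the indices up by label avoids hard-coded positions that may belong to
# another class or fall outside the test set.
test_labels = [int(test_dataset[i][2]) for i in range(len(test_dataset))]
sample1_idx = test_labels.index(0)  # first sample from class 0
sample2_idx = test_labels.index(1)  # first sample from class 1

print(f"Interpolating between sample {sample1_idx} ({test_dataset.classes[test_labels[sample1_idx]]}) "
      f"and sample {sample2_idx} ({test_dataset.classes[test_labels[sample2_idx]]})")
interpolate_latent_space(model, sample1_idx, sample2_idx, num_steps=8)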
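
The interpolation used above is a plain convex combination of two latent codes. The sketch below shows the same arithmetic on toy NumPy vectors. The 16-dimensional latents and the random linear classifier `W` are hypothetical stand-ins for the model's latents and `model.classifier`, which may not be linear.

```python
import numpy as np

def lerp(z1, z2, alpha):
    """Linear interpolation: alpha=0 returns z1, alpha=1 returns z2."""
    return (1.0 - alpha) * z1 + alpha * z2

def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
z1 = rng.normal(size=16)        # toy latent code of the first sample
z2 = rng.normal(size=16)        # toy latent code of the second sample
W = rng.normal(size=(5, 16))    # hypothetical linear classifier with 5 classes

alphas = np.linspace(0.0, 1.0, 8)
path = np.stack([lerp(z1, z2, a) for a in alphas])    # one latent per step
probs = np.stack([softmax(W @ z) for z in path])      # class probabilities per step

print(path.shape, probs.shape)  # (8, 16) (8, 5)
```

With a linear classifier, the logits change linearly along the path while the softmax probabilities generally do not, so a sharp crossover in the probability plot does not by itself imply a sharp boundary in latent space.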

## 13. Conclusion

In [None]:
"""
# NeurLLM: Key Findings and Future Work

## What We've Demonstrated

1. **Effective Representation Learning**: On the synthetic EEG data used in this demo, the ViT encoder-decoder architecture learns representations of EEG spectrograms that support classification across the different driving behaviors.

2. **Dual Capability**: A shared encoder feeds both a classification head and a reconstruction decoder, so the same latent representation can be inspected through predicted labels and through reconstructed spectrograms.

3. **Interpretable Attention**: The attention maps show which spectrogram patches the CLS token of the last encoder block attends to, which suggests the time-frequency regions the model relies on for each driving behavior.

4. **Latent Space Structure**: The PCA and t-SNE projections of the latent space show clustering of the different driving behaviors, which indicates that the model has learned class-relevant representations.

## Future Directions

1. **Multimodal Integration**: Extend the model to incorporate other physiological signals (EMG, ECG, etc.) alongside EEG.

2. **Vision-Language Model (VLM) Integration**: Connect the encoder with a language decoder to generate textual descriptions of EEG patterns.

3. **Real-time Processing**: Optimize the model for real-time inference to enable applications in driving monitoring systems.

4. **Transfer Learning**: Explore how pre-training on large datasets can improve performance on smaller, specialized datasets.

5. **Temporal Dynamics**: Enhance the model to better capture the temporal dynamics in continuous EEG recordings.
"""

# Save the model
torch.save({
    'model_state_dict': model.state_dict(),
    'class_names': class_names,
    'history': history
}, 'neurllm_model.pth')

print("Model saved as 'neurllm_model.pth'")
print("\nThank you for exploring NeurLLM!")