# Training ELM on Long Range Arena (LRA)

This notebook demonstrates how to train an ELM model on the Long Range Arena benchmark using PyTorch Lightning.

LRA is a benchmark for evaluating long-context sequence models across 5 diverse tasks:
- **ListOps** (2K): Hierarchical mathematical reasoning
- **Text** (4K): IMDb sentiment analysis (byte-level)
- **Retrieval** (8K): Document matching
- **Image** (1K): CIFAR-10 as sequences
- **Pathfinder** (1K): Visual path connectivity

## Imports

In [None]:
import os
import random
from pathlib import Path

import numpy as np
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

In [None]:
# Import ELM components
from elmneuron.expressive_leaky_memory_neuron_v2 import ELM
from elmneuron.tasks.classification_task import ClassificationTask
from elmneuron.lra.lra_datamodule import (
    ListOpsDataModule,
    LRATextDataModule,
    LRARetrievalDataModule,
    LRAImageDataModule,
    LRAPathfinderDataModule,
)
from elmneuron.callbacks import (
    SequenceVisualizationCallback,
    MemoryDynamicsCallback,
)

## Configuration

In [None]:
# Seeding & Config
general_seed = 42

# Seed every RNG source the notebook touches, in a fixed order, so runs are
# repeatable. Assigning PYTHONHASHSEED at runtime does not change this
# interpreter's own hash randomization (fixed at interpreter startup); it only
# affects child processes that inherit the environment.
os.environ['PYTHONHASHSEED'] = str(general_seed)
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(general_seed)
torch.backends.cudnn.deterministic = True  # request deterministic cuDNN kernels

# --- Dataset -----------------------------------------------------------------
lra_task = "listops"  # one of: "listops", "text", "retrieval", "image", "pathfinder"
data_dir = "./data/lra"
batch_size = 32       # kept small because LRA sequences are long
num_workers = 4

# --- ELM model (tune per task) ----------------------------------------------
num_memory = 200          # use more memory units for longer sequences
lambda_value = 5.0
tau_b_value = 1.0
memory_tau_min = 1.0
memory_tau_max = 500.0    # raise for long-range dependencies
learn_memory_tau = False

# --- Optimization ------------------------------------------------------------
learning_rate = 0.001
num_epochs = 50           # LRA tasks often need many epochs

print("Training ELM on LRA task: " + lra_task)

## Setup DataModule

In [None]:
# Create the DataModule for the selected LRA task. Each branch also records a
# short human-readable description of the task's sequences (`seq_info`) that
# is printed below.
if lra_task == "listops":
    datamodule = ListOpsDataModule(
        data_dir=data_dir,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    seq_info = "2048 tokens, 10 classes"

elif lra_task == "text":
    datamodule = LRATextDataModule(
        data_dir=data_dir,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    seq_info = "4096 bytes, 2 classes (sentiment)"

elif lra_task == "retrieval":
    datamodule = LRARetrievalDataModule(
        data_dir=data_dir,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    seq_info = "8192 tokens (2 docs), 2 classes (match/no-match)"

elif lra_task == "image":
    datamodule = LRAImageDataModule(
        data_dir=data_dir,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    seq_info = "1024 pixels (32x32), 10 classes (CIFAR-10)"

elif lra_task == "pathfinder":
    difficulty = "easy"  # Choose: "easy", "medium", "hard"
    datamodule = LRAPathfinderDataModule(
        data_dir=data_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        difficulty=difficulty,
    )
    seq_info = f"1024 pixels (32x32), 2 classes (connected/not), difficulty: {difficulty}"

else:
    # Fail fast with a clear message instead of a NameError on `datamodule`
    # a few lines further down.
    raise ValueError(
        f"Unknown lra_task {lra_task!r}; expected one of "
        "'listops', 'text', 'retrieval', 'image', 'pathfinder'"
    )

# Prepare and set up data (Lightning convention: prepare_data() for one-time
# download/preprocessing, setup("fit") to build the train/val splits).
datamodule.prepare_data()
datamodule.setup("fit")

print(f"Task: {lra_task}")
print(f"Sequence info: {seq_info}")
print(f"Input dimension: {datamodule.input_dim}")
print(f"Number of classes: {datamodule.num_classes}")
# `train_dataset` may be missing or still None depending on the DataModule;
# only take len() of a real dataset.
train_dataset = getattr(datamodule, "train_dataset", None)
print(f"Training samples: {len(train_dataset) if train_dataset is not None else 'Loading...'}")

## Create Model

In [None]:
# Hyperparameters for the base ELM; input and output sizes come from the
# DataModule so the model matches the selected LRA task.
elm_hparams = dict(
    num_input=datamodule.input_dim,
    num_output=datamodule.num_classes,
    num_memory=num_memory,
    lambda_value=lambda_value,
    tau_b_value=tau_b_value,
    memory_tau_min=memory_tau_min,
    memory_tau_max=memory_tau_max,
    learn_memory_tau=learn_memory_tau,
)
elm_model = ELM(**elm_hparams)

# Wrap the ELM in the Lightning classification task (Adam + cosine schedule),
# classifying from the output at the last timestep.
lightning_module = ClassificationTask(
    model=elm_model,
    learning_rate=learning_rate,
    optimizer="adam",
    scheduler="cosine",
    # NOTE(review): T_max is a rough guess assuming ~1000 scheduler steps per
    # epoch. Whether that is right depends on whether ClassificationTask steps
    # the scheduler per batch or per epoch (not visible here) — confirm and
    # size T_max from the real number of steps.
    scheduler_kwargs={"T_max": num_epochs * 1000},
    output_selection="last",
)

num_params = sum(map(torch.numel, elm_model.parameters()))
print(f"Model initialized with {num_params:,} parameters")

## Training

In [None]:
# Setup callbacks
callbacks = [
    # Keep the 3 best checkpoints by validation accuracy, plus the last one.
    # The monitored metric name contains "/", and with the default
    # auto_insert_metric_name=True Lightning would copy it into the filename
    # (".../val/accuracy=0.8123.ckpt"), creating an extra subdirectory per
    # checkpoint. Insert the labels explicitly instead.
    ModelCheckpoint(
        dirpath=f"./checkpoints_lra_{lra_task}",
        filename=f"elm-lra-{lra_task}-epoch={{epoch:02d}}-val_acc={{val/accuracy:.4f}}",
        auto_insert_metric_name=False,
        monitor="val/accuracy",
        mode="max",
        save_top_k=3,
        save_last=True,
    ),
    # Stop when validation accuracy has not improved for `patience` epochs.
    EarlyStopping(
        monitor="val/accuracy",
        patience=10,  # Higher patience for LRA
        mode="max",
        verbose=True,
    ),
    # Visualization callbacks (careful with long sequences)
    SequenceVisualizationCallback(
        log_every_n_epochs=10,
        num_samples=2,  # Fewer samples for long sequences
        task_type="classification",
        save_dir=f"./visualizations_lra_{lra_task}",
        log_to_wandb=False,
    ),
    MemoryDynamicsCallback(
        log_every_n_epochs=10,
        num_samples=2,
        save_dir=f"./memory_lra_{lra_task}",
        log_to_wandb=False,
    ),
]

# Create trainer
trainer = Trainer(
    max_epochs=num_epochs,
    accelerator="auto",
    devices=1,
    callbacks=callbacks,
    deterministic=True,
    log_every_n_steps=50,
    enable_progress_bar=True,
    gradient_clip_val=1.0,
    # Consider using gradient accumulation for long sequences
    accumulate_grad_batches=1,
)

# Train the model
print("Starting training...")
print("Note: LRA tasks may take significant time due to long sequences")
trainer.fit(lightning_module, datamodule=datamodule)

## Testing

In [None]:
# Evaluate on the test split using the checkpoint that ModelCheckpoint ranked
# best (highest val/accuracy), not the final in-memory weights.
print("Testing model...")
test_results = trainer.test(lightning_module, datamodule=datamodule, ckpt_path="best")

# trainer.test returns a list of metric dicts (one per test dataloader); this
# notebook reads the first one.
first_test_metrics = test_results[0]
test_accuracy = first_test_metrics["test/accuracy"]
print("\nTest Accuracy: " + format(test_accuracy, ".4f"))

## Save Model

In [None]:
# Save the ELM weights from the best checkpoint (highest val/accuracy).
# Reload that checkpoint explicitly rather than trusting the weights currently
# held by `lightning_module`: they may be the final-epoch weights, e.g. if the
# Testing cell (ckpt_path="best") was skipped.
best_ckpt_path = trainer.checkpoint_callback.best_model_path
if best_ckpt_path:
    # This checkpoint was written by this notebook's own Trainer, so full
    # unpickling (weights_only=False) is acceptable; never do this on
    # untrusted files.
    checkpoint = torch.load(best_ckpt_path, map_location="cpu", weights_only=False)
    lightning_module.load_state_dict(checkpoint["state_dict"])
else:
    print("Warning: no best checkpoint found; saving the current in-memory weights.")

model_path = f"./lra_{lra_task}_best_model.pt"
torch.save(lightning_module.model.state_dict(), model_path)
print(f"Model saved to {model_path}")

## Notes

### LRA Benchmark Tasks

| Task | Sequence Length | Classes | Description | Expected Baseline |
|------|----------------|---------|-------------|-------------------|
| ListOps | 2,048 | 10 | Hierarchical mathematical operations | ~35-40% |
| Text | 4,096 | 2 | IMDb sentiment (byte-level) | ~60-65% |
| Retrieval | 8,192 | 2 | Document matching | ~55-60% |
| Image | 1,024 | 10 | CIFAR-10 as sequences | ~40-45% |
| Pathfinder | 1,024 | 2 | Visual path connectivity | ~60-70% |

### Training Tips for LRA

1. **Memory Efficiency**:
   - Use smaller batch sizes (16-32)
   - Consider gradient accumulation (`accumulate_grad_batches` in `Trainer`)
   - Enable mixed precision training (e.g. `precision="16-mixed"` in `Trainer`)

2. **Hyperparameters**:
   - Increase `num_memory` for longer sequences (200-500)
   - Increase `memory_tau_max` for long-range dependencies (500-1000)
   - Use gradient clipping (1.0)

3. **Training Time**:
   - LRA tasks are computationally expensive
   - Consider using GPU acceleration
   - Each epoch may take 10-30 minutes

### Pathfinder Difficulty Levels

The Pathfinder task has three difficulty levels:
- **Easy**: Shorter paths, clearer connections
- **Medium**: Moderate complexity
- **Hard**: Longer paths, more distractors

### Performance Comparison

Compare your results with published baselines:
- Transformers: ~36-71% depending on task (about 54% average in the original LRA paper)
- S4 (State Space Models): ~60-95% depending on task (ListOps is the lowest)
- ELM: Results will vary based on hyperparameters

### Running All Tasks

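To benchmark ELM across all LRA tasks, set `lra_task` to each task in turn and re-run the notebook from the DataModule cell onward (or wrap those cells in a function). A minimal skeleton: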

```python
tasks = ["listops", "text", "retrieval", "image", "pathfinder"]
for task in tasks:
    # Skeleton only: set `lra_task = task`, then re-run the cells from
    # "Setup DataModule" onward to train and evaluate on that task.
    message = f"Training on {task}..."
    print(message)
```