# Training ELM on MNIST (Vision Dataset)

This notebook demonstrates how to train an Expressive Leaky Memory (ELM) neuron model on the MNIST dataset using PyTorch Lightning.

Each 28x28 MNIST image is turned into a sequence in one of two ways. It can be flattened pixel by pixel into 784 timesteps, or split into 4x4 patches that give 49 timesteps.

## Imports

In [None]:
import os
import random
from pathlib import Path

import numpy as np
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

In [None]:
# Import ELM components
from elmneuron.expressive_leaky_memory_neuron_v2 import ELM
from elmneuron.tasks.classification_task import ClassificationTask
from elmneuron.vision.vision_datamodule import MNISTDataModule
from elmneuron.transforms import FlattenSequentialization, PatchSequentialization
from elmneuron.callbacks import (
    SequenceVisualizationCallback,
    MemoryDynamicsCallback,
)

## Configuration

In [None]:
# ---- Reproducibility --------------------------------------------------------
general_seed = 42

# Exported for Python processes launched later. The hash seed of the
# interpreter that is already running was fixed at startup, and this
# assignment does not change it.
os.environ['PYTHONHASHSEED'] = str(general_seed)

# Seed the stdlib, NumPy and PyTorch RNGs in turn (torch.cuda.manual_seed
# targets the current CUDA device and is a lazy no-op without CUDA).
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(general_seed)
torch.backends.cudnn.deterministic = True

# ---- Dataset ----------------------------------------------------------------
data_dir = "./data/mnist"
batch_size, num_workers = 128, 4

# ---- Sequentialization ------------------------------------------------------
# "flatten": one pixel per step (784 steps); "patch": 4x4 patches (49 steps).
seq_strategy = "flatten"

# ---- Model ------------------------------------------------------------------
num_memory = 100
lambda_value, tau_b_value = 5.0, 1.0
memory_tau_min, memory_tau_max = 1.0, 100.0
learn_memory_tau = False

# ---- Optimization -----------------------------------------------------------
learning_rate, num_epochs = 1e-3, 20

print(f"Training ELM on MNIST with {seq_strategy} sequentialization")

## Setup DataModule

In [None]:
# Create the sequentialization transform that turns each image into a sequence.
# MNIST images are 28x28 with a single (grayscale) channel.
image_size = 28
num_channels = 1

if seq_strategy == "flatten":
    # One pixel per timestep, read row by row.
    sequentialization = FlattenSequentialization(flatten_order="row_major")
    seq_length = image_size * image_size  # 784
    input_dim = num_channels  # 1 feature per step
elif seq_strategy == "patch":
    # Square patch_size x patch_size patches, each flattened into one timestep.
    patch_size = 4
    if image_size % patch_size != 0:
        raise ValueError(
            f"patch_size={patch_size} does not evenly divide image_size={image_size}"
        )
    sequentialization = PatchSequentialization(patch_size=patch_size, flatten_patches=True)
    num_patches = (image_size // patch_size) ** 2  # 7x7 = 49
    seq_length = num_patches
    input_dim = patch_size * patch_size * num_channels  # 16
else:
    # Fail fast instead of silently falling back to "patch" on a typo.
    raise ValueError(
        f"Unknown seq_strategy {seq_strategy!r}; expected 'flatten' or 'patch'"
    )

# Create DataModule
datamodule = MNISTDataModule(
    data_dir=data_dir,
    batch_size=batch_size,
    num_workers=num_workers,
    sequentialization=sequentialization,
    normalize=True,
)

# Download (if needed) and build the train/validation splits.
datamodule.prepare_data()
datamodule.setup("fit")

print(f"Training samples: {len(datamodule.mnist_train)}")
print(f"Validation samples: {len(datamodule.mnist_val)}")
print(f"Sequence length: {seq_length}")
print(f"Input dimension: {input_dim}")
print(f"Number of classes: {datamodule.num_classes}")

## Create Model

In [None]:
# Create base ELM model
elm_model = ELM(
    num_input=input_dim,
    num_output=datamodule.num_classes,
    num_memory=num_memory,
    lambda_value=lambda_value,
    tau_b_value=tau_b_value,
    memory_tau_min=memory_tau_min,
    memory_tau_max=memory_tau_max,
    learn_memory_tau=learn_memory_tau,
)

# Batches per epoch, taken from the real training DataLoader. This matches what
# the loader yields, including a final partial batch unless the loader drops
# it. The previous len(train) // batch_size always floored that batch away.
steps_per_epoch = len(datamodule.train_dataloader())

# Wrap in Lightning classification task.
# NOTE(review): T_max = epochs * steps_per_epoch assumes ClassificationTask
# steps the cosine scheduler once per optimizer step. If it steps per epoch,
# T_max should be num_epochs instead. TODO confirm.
lightning_module = ClassificationTask(
    model=elm_model,
    learning_rate=learning_rate,
    optimizer="adam",
    scheduler="cosine",
    scheduler_kwargs={"T_max": num_epochs * steps_per_epoch},
    output_selection="mean",  # Average over sequence for classification
)

num_params = sum(p.numel() for p in elm_model.parameters())
num_trainable = sum(p.numel() for p in elm_model.parameters() if p.requires_grad)
print(f"Model initialized with {num_params:,} parameters")
print(f"Trainable parameters: {num_trainable:,}")

## Training

In [None]:
# Setup callbacks
callbacks = [
    # Keep the 3 best checkpoints by validation accuracy, plus the last one.
    # The metric name contains "/". With Lightning's default
    # auto_insert_metric_name=True, "{val/accuracy}" would expand to
    # "val/accuracy=..." and create nested sub-folders. Disable the
    # auto-insertion and spell out a slash-free prefix instead.
    ModelCheckpoint(
        dirpath="./checkpoints_mnist",
        filename="elm-mnist-epoch={epoch:02d}-val_acc={val/accuracy:.4f}",
        auto_insert_metric_name=False,
        monitor="val/accuracy",
        mode="max",
        save_top_k=3,
        save_last=True,
    ),
    # Stop after 5 epochs without a val/accuracy improvement.
    EarlyStopping(
        monitor="val/accuracy",
        patience=5,
        mode="max",
        verbose=True,
    ),
    # Visualization callbacks (write figures to local folders, no W&B).
    SequenceVisualizationCallback(
        log_every_n_epochs=5,
        num_samples=4,
        task_type="classification",
        save_dir="./visualizations_mnist",
        log_to_wandb=False,
    ),
    MemoryDynamicsCallback(
        log_every_n_epochs=5,
        num_samples=2,
        save_dir="./memory_mnist",
        log_to_wandb=False,
    ),
]

# Create trainer (single device; deterministic=True for reproducible runs).
trainer = Trainer(
    max_epochs=num_epochs,
    accelerator="auto",
    devices=1,
    callbacks=callbacks,
    deterministic=True,
    log_every_n_steps=50,
    enable_progress_bar=True,
)

# Train the model
print("Starting training...")
trainer.fit(lightning_module, datamodule=datamodule)

## Testing

In [None]:
# Evaluate on the test split after restoring the "best" checkpoint tracked by
# the ModelCheckpoint callback (highest val/accuracy).
print("Testing model...")
trainer.test(
    model=lightning_module,
    datamodule=datamodule,
    ckpt_path="best",
)

## Save Model

In [None]:
# Save the weights of the best checkpoint (highest val/accuracy).
# After fit() the module holds the final trained weights. Lightning restores
# the best ones into it in place only as a side effect of
# trainer.test(..., ckpt_path="best"). Reload them explicitly so this cell
# does not depend on the test cell having run first.
save_path = "./mnist_best_model.pt"
checkpoint_callback = trainer.checkpoint_callback
best_ckpt_path = checkpoint_callback.best_model_path if checkpoint_callback is not None else ""
if best_ckpt_path:
    # This file was written by our own ModelCheckpoint above, so it is trusted.
    # It also stores non-tensor training metadata, which newer torch versions
    # may refuse to unpickle unless weights_only=False.
    checkpoint = torch.load(best_ckpt_path, map_location="cpu", weights_only=False)
    lightning_module.load_state_dict(checkpoint["state_dict"])
else:
    print("No best checkpoint recorded; saving the current in-memory weights.")

# Store CPU copies so the file can be loaded on machines without a GPU.
cpu_state_dict = {
    name: tensor.detach().cpu()
    for name, tensor in lightning_module.model.state_dict().items()
}
torch.save(cpu_state_dict, save_path)
print(f"Model saved to {save_path}")

## Notes

### Sequentialization Strategies

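1. **Flatten** (row-major): Treats the 28x28 image as a 784-step sequence with one pixel per step
   - Pros: Simple; horizontally adjacent pixels stay adjacent in time
   - Cons: Long sequence; vertically adjacent pixels end up 28 steps apart, so the 2D structure is not explicit

2. **Patches** (4x4): Treats the image as 49 patches of 16 pixels each, one patch per step
   - Pros: 16x shorter sequence; each step covers a local 2D neighborhood
   - Cons: Larger per-step input (16 features instead of 1)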

### Other Vision Datasets

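To use Fashion-MNIST, CIFAR-10, or CIFAR-100, replace the DataModule. Also update `seq_length` and `input_dim` to match the new image size and channel count; CIFAR images are 32x32 with 3 channels, while MNIST images are 28x28 with 1.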

```python
from elmneuron.vision.vision_datamodule import (
    FashionMNISTDataModule,
    CIFAR10DataModule,
    CIFAR100DataModule,
)

# Example: CIFAR-10 (32x32x3)
datamodule = CIFAR10DataModule(
    data_dir="./data/cifar10",
    batch_size=128,
    sequentialization=FlattenSequentialization(),
)
```