# Training ELM on WikiText (Text Dataset)

This notebook demonstrates how to train an ELM model on the WikiText dataset using PyTorch Lightning.

WikiText is a language modeling benchmark derived from Wikipedia articles. In this notebook the text is tokenized at either the character or the word level and split into fixed-length sequences for training.

## Imports

In [None]:
import os
import random
from pathlib import Path

import numpy as np
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

In [None]:
# Import ELM components
from elmneuron.expressive_leaky_memory_neuron_v2 import ELM
from elmneuron.tasks.classification_task import ClassificationTask
from elmneuron.text.text_datamodule import WikiText2DataModule, WikiText103DataModule
from elmneuron.transforms import CharTokenization, WordTokenization
from elmneuron.callbacks import (
    SequenceVisualizationCallback,
    MemoryDynamicsCallback,
)

## Configuration

In [None]:
# Seeding & Config
general_seed = 42
# NOTE: changing PYTHONHASHSEED at runtime does not alter hash randomization of
# the already-running interpreter; it only affects processes started afterwards.
os.environ['PYTHONHASHSEED'] = str(general_seed)
random.seed(general_seed)
np.random.seed(general_seed)
torch.manual_seed(general_seed)
torch.cuda.manual_seed(general_seed)
torch.backends.cudnn.deterministic = True

# Dataset config
data_dir = "./data/wikitext"
dataset_version = "wikitext2"  # "wikitext2" or "wikitext103"
batch_size = 64
num_workers = 4
sequence_length = 128  # Context length for training

# Tokenization strategy
tokenization = "char"  # "char" or "word"

# Fail fast on typos so a misspelled option is never silently treated as one
# of the other supported choices further down the notebook.
_valid_dataset_versions = ("wikitext2", "wikitext103")
_valid_tokenizations = ("char", "word")
if dataset_version not in _valid_dataset_versions:
    raise ValueError(
        f"Unknown dataset_version {dataset_version!r}; "
        f"expected one of {_valid_dataset_versions}"
    )
if tokenization not in _valid_tokenizations:
    raise ValueError(
        f"Unknown tokenization {tokenization!r}; "
        f"expected one of {_valid_tokenizations}"
    )

vocab_size = 256 if tokenization == "char" else 10000

# Model config
embedding_dim = 64
num_memory = 200
lambda_value = 5.0
tau_b_value = 1.0
memory_tau_min = 1.0
memory_tau_max = 200.0
learn_memory_tau = False

# Training config
learning_rate = 1e-3
num_epochs = 20

print(f"Training ELM on {dataset_version} with {tokenization}-level tokenization")
print(f"Vocabulary size: {vocab_size}")
print(f"Sequence length: {sequence_length}")

## Setup DataModule

In [None]:
# Create tokenization transform.
# An explicit lookup means an unknown value raises, instead of silently
# falling back to word-level tokenization.
tokenizer_classes = {"char": CharTokenization, "word": WordTokenization}
if tokenization not in tokenizer_classes:
    raise ValueError(
        f"Unknown tokenization {tokenization!r}; "
        f"expected one of {sorted(tokenizer_classes)}"
    )
tokenization_fn = tokenizer_classes[tokenization](vocab_size=vocab_size)

# Create DataModule.
# Both variants take identical arguments, so pick the class and construct it
# once. An unknown version raises instead of silently becoming WikiText-103.
datamodule_classes = {
    "wikitext2": WikiText2DataModule,
    "wikitext103": WikiText103DataModule,
}
if dataset_version not in datamodule_classes:
    raise ValueError(
        f"Unknown dataset_version {dataset_version!r}; "
        f"expected one of {sorted(datamodule_classes)}"
    )
datamodule = datamodule_classes[dataset_version](
    data_dir=data_dir,
    batch_size=batch_size,
    num_workers=num_workers,
    tokenization=tokenization_fn,
    sequence_length=sequence_length,
    embedding_dim=embedding_dim,
)

# Prepare and setup data
datamodule.prepare_data()
datamodule.setup("fit")

print(f"Input dimension: {datamodule.input_dim}")
print(f"Number of classes (vocab): {datamodule.num_classes}")
print(f"Embedding dimension: {embedding_dim}")

## Create Model

In [None]:
# Core ELM model. The input width is the datamodule's input_dim (the per-token
# embedding dimension) and the output width is num_classes (the vocabulary size).
elm_config = dict(
    num_input=datamodule.input_dim,
    num_output=datamodule.num_classes,
    num_memory=num_memory,
    lambda_value=lambda_value,
    tau_b_value=tau_b_value,
    memory_tau_min=memory_tau_min,
    memory_tau_max=memory_tau_max,
    learn_memory_tau=learn_memory_tau,
)
elm_model = ELM(**elm_config)

# Wrap the model in the Lightning classification task. With
# output_selection="all", a prediction is made at every timestep
# (next-token prediction).
# NOTE(review): T_max = num_epochs * 1000 is a rough guess. Whether it should
# count batches or epochs depends on how ClassificationTask steps the cosine
# scheduler, which is not visible here. If it steps per batch, consider
# num_epochs * len(datamodule.train_dataloader()). TODO confirm.
task_config = dict(
    model=elm_model,
    learning_rate=learning_rate,
    optimizer="adam",
    scheduler="cosine",
    scheduler_kwargs={"T_max": num_epochs * 1000},
    output_selection="all",
)
lightning_module = ClassificationTask(**task_config)

# Count every parameter tensor's elements (trainable or not).
num_params = 0
for param in elm_model.parameters():
    num_params += param.numel()
print(f"Model initialized with {num_params:,} parameters")

## Training

In [None]:
# Setup callbacks
callbacks = [
    # Keep the 3 checkpoints with the lowest validation loss, plus the latest.
    # The monitored metric name "val/loss" contains a "/". With the default
    # auto_insert_metric_name=True, Lightning would write "val/loss=..." into
    # the filename and create a nested directory per checkpoint. So we insert
    # our own "name=" prefixes and disable automatic insertion.
    ModelCheckpoint(
        dirpath="./checkpoints_wikitext",
        filename="elm-wikitext-epoch={epoch:02d}-val_loss={val/loss:.4f}",
        auto_insert_metric_name=False,
        monitor="val/loss",
        mode="min",
        save_top_k=3,
        save_last=True,
    ),
    # Stop after 5 validation checks without an improvement in val/loss.
    EarlyStopping(
        monitor="val/loss",
        patience=5,
        mode="min",
        verbose=True,
    ),
    # Memory dynamics visualization, written to ./memory_wikitext every 5 epochs.
    MemoryDynamicsCallback(
        log_every_n_epochs=5,
        num_samples=2,
        save_dir="./memory_wikitext",
        log_to_wandb=False,
    ),
]

# Create trainer
trainer = Trainer(
    max_epochs=num_epochs,
    accelerator="auto",
    devices=1,
    callbacks=callbacks,
    deterministic=True,
    log_every_n_steps=50,
    enable_progress_bar=True,
    gradient_clip_val=1.0,  # Important for language modeling
)

# Train the model
print("Starting training...")
trainer.fit(lightning_module, datamodule=datamodule)

## Testing

In [None]:
# Evaluate on the test split using the best checkpoint (lowest val/loss)
# saved during training.
print("Testing model...")
test_results = trainer.test(lightning_module, datamodule=datamodule, ckpt_path="best")

# Perplexity is exp(loss).
# NOTE(review): this assumes test/loss is the mean per-token cross-entropy in
# nats. Confirm against ClassificationTask's loss.
test_loss = test_results[0]['test/loss']
perplexity = torch.tensor(test_loss).exp()
print(f"Test Perplexity: {perplexity:.2f}")

## Save Model

In [None]:
# Persist only the wrapped ELM's parameters as a plain state_dict
# (not a full Lightning checkpoint).
# NOTE(review): lightning_module holds the best-checkpoint weights only if
# trainer.test(..., ckpt_path="best") has restored them, as the Testing cell
# does. If this cell is run right after fit, these are presumably the
# final-epoch weights. Confirm.
model_path = "./wikitext_best_model.pt"
weights = lightning_module.model.state_dict()
torch.save(weights, model_path)
print(f"Model saved to {model_path}")

## Notes

### Language Modeling Task

This notebook demonstrates **next-token prediction** (language modeling):
- Input: Sequence of tokens (characters or words)
- Output: Prediction for next token at each timestep
- Metric: Cross-entropy loss, perplexity

### Tokenization Strategies

1. **Character-level** (vocab_size ~256):
   - Pros: Small vocabulary; few or no out-of-vocabulary (OOV) tokens
   - Cons: Longer sequences; each token carries less meaning

2. **Word-level** (vocab_size ~10K):
   - Pros: Shorter sequences; each token carries more meaning
   - Cons: Larger vocabulary; words outside the vocabulary (OOV) cannot be represented directly

### Dataset Variants

- **WikiText-2**: ~2M tokens (smaller, faster)
- **WikiText-103**: ~103M tokens (larger, more data)

### Custom Text

To use your own text data:

```python
from elmneuron.text.text_datamodule import CustomTextDataModule

datamodule = CustomTextDataModule(
    train_file="path/to/train.txt",
    val_file="path/to/val.txt",
    test_file="path/to/test.txt",
    batch_size=64,
    tokenization=CharTokenization(vocab_size=256),
    sequence_length=128,
)
```