In [1]:
# NOTE: this is a reduced implementation; it may need slightly different hyperparameters to reproduce the original results.

In [None]:
# Imports
import os
import random
from pathlib import Path

import numpy as np
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

In [None]:
# Import ELM components
from elmneuron.expressive_leaky_memory_neuron_v2 import ELM
from elmneuron.tasks.classification_task import ClassificationTask
from elmneuron.shd.shd_datamodule import SHDDataModule
from elmneuron.shd.shd_download_utils import get_shd_dataset
from elmneuron.callbacks import (
    SequenceVisualizationCallback,
    MemoryDynamicsCallback,
)

In [None]:
# Seeding & Config
# Fix every RNG this notebook touches so that runs are reproducible.
general_seed = 0
# NOTE: PYTHONHASHSEED is only read at interpreter start-up, so setting it here does
# not change hashing in the already-running kernel; it only reaches child processes
# started afterwards that read the environment.
os.environ['PYTHONHASHSEED'] = str(general_seed)
random.seed(general_seed)
np.random.seed(general_seed)
torch.manual_seed(general_seed)
torch.cuda.manual_seed_all(general_seed)  # seed every visible GPU, not only the current one
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # cuDNN autotuning may select non-deterministic kernels

# Dataset config
dataset_type = "shd"  # "shd" or "shdadding"
# Number of target classes per supported dataset variant; anything else is rejected
# instead of silently falling back to a wrong class count.
_num_classes_by_dataset = {"shd": 20, "shdadding": 19}
if dataset_type not in _num_classes_by_dataset:
    raise ValueError(
        f"Unknown dataset_type {dataset_type!r}; "
        f"expected one of {sorted(_num_classes_by_dataset)}"
    )
data_path = "./shd_data"
bin_size = 10  # time-bin width; also used as the ELM's delta_t (units presumably ms — confirm)
num_classes = _num_classes_by_dataset[dataset_type]
num_input_channel = 700  # number of input channels

# Model config
lambda_value = 5.0
num_memory = 100
memory_tau_min = 1.0
memory_tau_max = 500.0
learn_memory_tau = False
tau_b_value = float(bin_size)  # tied to the bin width

# Training config
learning_rate = 5e-3
num_epochs = 10
batch_size = 8
dropout = 0.5  # passed to the datamodule (presumably input dropout augmentation — confirm)

In [5]:
# Data download: fetch the SHD train/test HDF5 files into the configured `data_path`
# (reported as "Available at: ..." when already present). The download root and
# sub-directory are derived from `data_path` instead of re-assigning it here, so the
# config cell stays the single source of truth for the data location.
shd_data_dir = Path(data_path)
get_shd_dataset(str(shd_data_dir.parent), shd_data_dir.name)

Available at: ./shd_data/shd_train.h5
Available at: ./shd_data/shd_test.h5


In [None]:
# Setup DataModule
# All SHDDataModule settings in one place; valid_fraction=0.2 presumably holds out
# 20% of the training split for validation (confirm in SHDDataModule), seeded with
# general_seed.
datamodule_kwargs = dict(
    data_dir=data_path,
    dataset_type=dataset_type,
    batch_size=batch_size,
    bin_size=bin_size,
    valid_fraction=0.2,
    num_workers=4,
    seed=general_seed,
    dropout=dropout,
)
datamodule = SHDDataModule(**datamodule_kwargs)

# Prepare data (download if needed), then build the training/validation datasets
datamodule.prepare_data()
datamodule.setup("fit")

# Short summary of what the datamodule produced
dataset_summary = {
    "Training samples": len(datamodule.train_dataset),
    "Validation samples": len(datamodule.val_dataset),
    "Input dimension": datamodule.input_dim,
    "Number of classes": datamodule.num_classes,
}
for summary_label, summary_value in dataset_summary.items():
    print(f"{summary_label}: {summary_value}")

In [None]:
# Create ELM model and Lightning task wrapper

# Base ELM model: input/output sizes come from the datamodule so they always match
# the binned data; the remaining hyperparameters come from the config cell.
elm_model = ELM(
    num_input=datamodule.input_dim,
    num_output=datamodule.num_classes,
    num_memory=num_memory,
    lambda_value=lambda_value,
    tau_b_value=tau_b_value,
    memory_tau_min=memory_tau_min,
    memory_tau_max=memory_tau_max,
    learn_memory_tau=learn_memory_tau,
    delta_t=bin_size,
)

# Cosine schedule length in optimizer steps. Use the real number of batches per epoch
# (it accounts for a partial last batch or drop_last). The previous
# len(train_dataset) // batch_size formula under-counts when the dataset size is not
# a multiple of batch_size, so the LR would start rising again before training ends.
# NOTE(review): assumes ClassificationTask steps the scheduler once per optimizer step
# (not once per epoch) — TODO confirm.
steps_per_epoch = len(datamodule.train_dataloader())
total_scheduler_steps = num_epochs * steps_per_epoch

# Wrap in Lightning classification task
lightning_module = ClassificationTask(
    model=elm_model,
    learning_rate=learning_rate,
    optimizer="adamax",
    scheduler="cosine",
    scheduler_kwargs={"T_max": total_scheduler_steps},
    output_selection="last",  # Use last timestep for SHD classification
)

# Total vs. trainable counts differ when some parameters are frozen
# (e.g. possibly the memory time constants when learn_memory_tau=False).
num_params = sum(p.numel() for p in elm_model.parameters())
num_trainable_params = sum(p.numel() for p in elm_model.parameters() if p.requires_grad)
print(f"Model initialized with {num_params} parameters")
print(f"Trainable parameters: {num_trainable_params}")

In [None]:
# Setup Lightning Trainer

callbacks = [
    # Model checkpointing: keep the 3 best checkpoints by validation accuracy plus
    # the last one. The monitored metric name contains "/", which Lightning would turn
    # into an extra sub-directory if it auto-inserted "val/accuracy=" into the
    # filename. Metric names are therefore spelled out explicitly and auto-insertion
    # is disabled.
    ModelCheckpoint(
        dirpath="./checkpoints_shd",
        filename="elm-shd-epoch={epoch:02d}-val_accuracy={val/accuracy:.4f}",
        auto_insert_metric_name=False,
        monitor="val/accuracy",
        mode="max",
        save_top_k=3,
        save_last=True,
    ),
    # Early stopping: stop if validation accuracy has not improved for 5 epochs
    EarlyStopping(
        monitor="val/accuracy",
        patience=5,
        mode="max",
        verbose=True,
    ),
    # Visualization callbacks (written to local directories, not to wandb)
    SequenceVisualizationCallback(
        log_every_n_epochs=5,
        num_samples=4,
        task_type="classification",
        save_dir="./visualizations_shd",
        log_to_wandb=False,
    ),
    MemoryDynamicsCallback(
        log_every_n_epochs=5,
        num_samples=2,
        save_dir="./memory_shd",
        log_to_wandb=False,
    ),
]

# Create trainer: single device, deterministic algorithms for reproducibility
trainer = Trainer(
    max_epochs=num_epochs,
    accelerator="auto",
    devices=1,
    callbacks=callbacks,
    deterministic=True,
    log_every_n_steps=50,
    enable_progress_bar=True,
)

# Train, evaluate on the test split, and export the model weights
trainer.fit(lightning_module, datamodule=datamodule)

# Evaluate the best checkpoint (selected by the ModelCheckpoint monitor).
# NOTE(review): the export below yields the *best* weights only because
# trainer.test(ckpt_path="best") restores that checkpoint into `lightning_module`
# (Lightning's documented behavior) — confirm for the installed version; without the
# test call, the final-epoch weights would be saved instead.
test_metrics = trainer.test(lightning_module, datamodule=datamodule, ckpt_path="best")

best_model_file = "./shd_best_model.pt"
best_model_state = lightning_module.model.state_dict()
torch.save(best_model_state, best_model_file)
print(f"Model saved to {best_model_file}")