In [1]:
# NOTE: this reduced implementation may need slightly different hyperparameters to reproduce the original results

## Notebook Config

In [None]:
# Imports
import os
import random
from pathlib import Path

import numpy as np
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

In [None]:
# Import ELM components
from elmneuron.expressive_leaky_memory_neuron_v2 import ELM
from elmneuron.tasks.neuronio_task import NeuronIOTask
from elmneuron.neuronio.neuronio_datamodule import NeuronIODataModule
from elmneuron.neuronio.neuronio_data_utils import (
    NEURONIO_DATA_DIM, 
    NEURONIO_LABEL_DIM,
    get_data_files_from_folder,
)
from elmneuron.transforms import NeuronIORouting
from elmneuron.callbacks import (
    StateRecorderCallback,
    SequenceVisualizationCallback,
    MemoryDynamicsCallback,
)

In [4]:
# General Config
general_config = {
    "seed": 0,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "short_training_run": True,  # TODO: change to full length run
}
# Verbose output is tied to the quick (short) run mode
general_config["verbose"] = general_config["short_training_run"]
torch_device = torch.device(general_config["device"])
print("Torch Device: ", torch_device)

Torch Device:  cpu


In [5]:
# Seeding & Determinism
# NOTE: setting PYTHONHASHSEED here does not change hashing in the already-running
# interpreter; it only affects child processes started afterwards (e.g. spawned workers).
os.environ['PYTHONHASHSEED'] = str(general_config["seed"])
random.seed(general_config["seed"])
np.random.seed(general_config["seed"])
torch.manual_seed(general_config["seed"])
# Seed every visible GPU, not only the current one
torch.cuda.manual_seed_all(general_config["seed"])
torch.backends.cudnn.deterministic = True
# cuDNN autotuning benchmarks candidate kernels and may pick different ones per run,
# so it must be off for deterministic results (PyTorch reproducibility notes)
torch.backends.cudnn.benchmark = False

## Data, Model, Training Config

In [6]:
# NOTE: this step requires the NeuronIO dataset to be downloaded first.

# Train data:
# https://www.kaggle.com/datasets/selfishgene/single-neurons-as-deep-nets-nmda-train-data
# Test data:
# https://www.kaggle.com/datasets/selfishgene/single-neurons-as-deep-nets-nmda-test-data

# Root folder that holds both downloaded dataset folders
data_dir_path = Path("~", "Data").expanduser().resolve()  # TODO: change to neuronio data path
train_data_dir_path = data_dir_path.joinpath("neuronio_train_data")  # TODO: change to train subfolder
test_data_dir_path = data_dir_path.joinpath("neuronio_test_data")  # TODO: change to test subfolder

In [7]:
# Data Config

# Train batches 2..10 are used for training; train batch 1 is held out for validation.
train_data_dirs = [
    str(train_data_dir_path / f"full_ergodic_train_batch_{batch_idx}")
    for batch_idx in range(2, 11)
]
valid_data_dirs = [str(train_data_dir_path / "full_ergodic_train_batch_1")]
test_data_dirs = [str(test_data_dir_path)]

data_config = {
    "train_data_dirs": train_data_dirs,
    "valid_data_dirs": valid_data_dirs,
    "test_data_dirs": test_data_dirs,
    "data_dim": NEURONIO_DATA_DIM,
    "label_dim": NEURONIO_LABEL_DIM,
}

In [None]:
# Model Config
model_config = {
    # Input routing layout (consumed by NeuronIORouting): branches x synapses per branch
    "num_branch": 45,
    "num_synapse_per_branch": 100,
    # Memory units and the bounds of their timescales (kept fixed: not learned)
    "num_memory": 20,
    "memory_tau_min": 1.0,
    "memory_tau_max": 150.0,
    "learn_memory_tau": False,
    "lambda_value": 5.0,
    "tau_b_value": 5.0,
}

In [9]:
# Training Config

# `batches_per_epoch` is specified for a reference batch size of 8 and then rescaled so an
# epoch covers the same number of samples whatever `batch_size` is used
# (short run: 1000 * 8 // 32 = 250 batches; full run: 10000 * 8 // 8 = 10000 batches).
REFERENCE_BATCH_SIZE = 8

train_config = dict()
train_config["num_epochs"] = 5 if general_config["short_training_run"] else 30
train_config["learning_rate"] = 5e-4
train_config["batch_size"] = 32 if general_config["short_training_run"] else 8
reference_batches_per_epoch = 1000 if general_config["short_training_run"] else 10000
# Integer arithmetic: int(8 / batch_size * n) can truncate a float such as 2.999... down to 2
train_config["batches_per_epoch"] = max(
    1, REFERENCE_BATCH_SIZE * reference_batches_per_epoch // train_config["batch_size"]
)
train_config["file_load_fraction"] = 0.5 if general_config["short_training_run"] else 0.3
train_config["num_prefetch_batch"] = 100
train_config["num_workers"] = 10  # >0 workers will make the run nondeterministic
# NOTE(review): burn_in_time is not passed to the datamodule or task anywhere in this
# notebook — confirm whether NeuronIOTask/NeuronIODataModule should receive it.
train_config["burn_in_time"] = 150
train_config["input_window_size"] = 500

## Data, Model, Training Setup

In [None]:
# Setup Lightning DataModule

# Get file lists
train_files = get_data_files_from_folder(data_config["train_data_dirs"])
valid_files = get_data_files_from_folder(data_config["valid_data_dirs"])
test_files = get_data_files_from_folder(data_config["test_data_dirs"])

# Fail fast with a clear message if the dataset path is wrong, instead of erroring
# much later inside the dataloaders.
# (assumes get_data_files_from_folder returns a sized collection — TODO confirm)
for split_name, split_files, split_dirs in (
    ("train", train_files, data_config["train_data_dirs"]),
    ("validation", valid_files, data_config["valid_data_dirs"]),
    ("test", test_files, data_config["test_data_dirs"]),
):
    if not split_files:
        raise FileNotFoundError(
            f"No NeuronIO {split_name} files found in {split_dirs}; check data_dir_path."
        )

# Create routing transform (maps the raw input dimension onto branches x synapses)
routing = NeuronIORouting(
    num_input=data_config["data_dim"],
    num_branch=model_config["num_branch"],
    num_synapse_per_branch=model_config["num_synapse_per_branch"],
)

# Calculate total number of synapses after routing
num_synapse = routing.num_synapse
print(f"Total synapses after routing: {num_synapse}")

# Create DataModule; validation/test use 1/10 and 1/5 of the training batch count,
# clamped to at least one batch so small configs never yield an empty evaluation.
datamodule = NeuronIODataModule(
    train_files=train_files,
    val_files=valid_files,
    test_files=test_files,
    routing=routing,
    batch_size=train_config["batch_size"],
    input_window_size=train_config["input_window_size"],
    file_load_fraction=train_config["file_load_fraction"],
    num_workers=train_config["num_workers"],
    num_prefetch_batch=train_config["num_prefetch_batch"],
    train_batches_per_epoch=train_config["batches_per_epoch"],
    val_batches_per_epoch=max(1, train_config["batches_per_epoch"] // 10),
    test_batches_per_epoch=max(1, train_config["batches_per_epoch"] // 5),
    seed=general_config["seed"],
    verbose=general_config["verbose"],
)

In [None]:
# Create ELM model and Lightning task wrapper

# Base ELM model; its input size is the synapse count produced by routing,
# not the raw data dimension.
elm_kwargs = dict(
    num_input=num_synapse,
    num_output=NEURONIO_LABEL_DIM,
    num_memory=model_config["num_memory"],
    lambda_value=model_config["lambda_value"],
    tau_b_value=model_config["tau_b_value"],
    memory_tau_min=model_config["memory_tau_min"],
    memory_tau_max=model_config["memory_tau_max"],
    learn_memory_tau=model_config["learn_memory_tau"],
)
elm_model = ELM(**elm_kwargs)

# Cosine schedule length = total number of training batches.
# NOTE(review): this assumes NeuronIOTask steps the scheduler once per batch; if it
# steps per epoch, T_max should be num_epochs instead — confirm.
total_scheduler_steps = train_config["num_epochs"] * train_config["batches_per_epoch"]
lightning_module = NeuronIOTask(
    model=elm_model,
    learning_rate=train_config["learning_rate"],
    optimizer="adam",
    scheduler="cosine",
    scheduler_kwargs={"T_max": total_scheduler_steps},
)

num_params = sum(param.numel() for param in elm_model.parameters())
print(f"Model initialized with {num_params} parameters")

In [None]:
# Setup Lightning Trainer with callbacks

callbacks = [
    # Model checkpointing, keeping the 3 best epochs by validation spike AUC.
    # The monitored metric name contains "/", which Lightning would otherwise copy into
    # the filename and turn into nested folders; auto_insert_metric_name=False plus an
    # explicit, slash-free label keeps every checkpoint directly in ./checkpoints.
    ModelCheckpoint(
        dirpath="./checkpoints",
        filename="elm-epoch={epoch:02d}-val_spike_auc={val/spike_auc:.4f}",
        auto_insert_metric_name=False,
        monitor="val/spike_auc",
        mode="max",
        save_top_k=3,
        save_last=True,
    ),
    # Early stopping on the same validation metric
    EarlyStopping(
        monitor="val/spike_auc",
        patience=10,
        mode="max",
        verbose=True,
    ),
    # Visualization callbacks
    StateRecorderCallback(
        record_every_n_epochs=5,
        num_samples=4,
        save_dir="./states",
    ),
    SequenceVisualizationCallback(
        log_every_n_epochs=5,
        num_samples=4,
        task_type="regression",
        save_dir="./visualizations",
        log_to_wandb=False,  # Set to True if using wandb
    ),
    MemoryDynamicsCallback(
        log_every_n_epochs=5,
        num_samples=2,
        save_dir="./memory",
        log_to_wandb=False,
    ),
]

# Create trainer
trainer = Trainer(
    max_epochs=train_config["num_epochs"],
    accelerator="auto",
    devices=1,
    callbacks=callbacks,
    gradient_clip_val=1.0,
    deterministic=True,
    log_every_n_steps=50,
    enable_progress_bar=True,
)

## Training

In [None]:
# Train the model on the datamodule's training split (validated on its validation split)
trainer.fit(model=lightning_module, datamodule=datamodule)

## Evaluation

In [None]:
# Test the model, restoring the best checkpoint tracked by ModelCheckpoint first
trainer.test(model=lightning_module, datamodule=datamodule, ckpt_path="best")

In [None]:
# Save the best model's ELM weights.
# After `fit`, `lightning_module` holds the LAST epoch's weights. Load the best checkpoint
# explicitly rather than relying on an earlier `trainer.test(ckpt_path="best")` call having
# restored it, so this cell is correct however the preceding cells were run.
best_ckpt_path = trainer.checkpoint_callback.best_model_path
if not best_ckpt_path:
    raise RuntimeError(
        "No best checkpoint recorded; was 'val/spike_auc' logged during validation?"
    )
# weights_only=False: Lightning checkpoints also store non-tensor metadata. Only load
# checkpoints produced by this notebook (unpickling untrusted files is unsafe).
best_checkpoint = torch.load(best_ckpt_path, map_location="cpu", weights_only=False)
lightning_module.load_state_dict(best_checkpoint["state_dict"])
torch.save(lightning_module.model.state_dict(), "./neuronio_best_model.pt")
print("Model saved to ./neuronio_best_model.pt")