# Tutorial: Training a model
In this tutorial, we will train an event-based state-space model on a reduced version of the [Spiking Heidelberg Digits](https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/) dataset.
For training on larger datasets or multiple GPUs, we recommend using the training script `run_training.py` instead.

## Setup

Install the required packages and set up the configuration. To install the packages, run
```
pip3 install -r requirements.txt
```

Directories for loading datasets, model checkpoints and saving results are defined in the configuration file `system/local.yaml`.
Please set your directories accordingly.

## Data loading
The SHD dataset contains 20 classes, digits from 0 to 9 in both German and English. 
We will use a reduced version of the dataset containing only two digits to train the model to non-trivial performance in reasonable time even on CPUs.

[Download the training and test dataset](https://zenkelab.org/datasets/) and unpack the archives to `./data/`.

In [None]:
from torch.utils.data import Dataset, DataLoader, random_split
import h5py
import numpy as np

class SpikingHeidelbergDigits(Dataset):
    def __init__(self, path_to_file):
        self.num_classes = 2
        self.num_channels = 700
        self.path_to_file = path_to_file
        
        # load the dataset
        with h5py.File(path_to_file, 'r') as f:
            self.channels = f['spikes']['units'][:]
            self.timesteps = f['spikes']['times'][:]
            self.labels = f['labels'][:]
        
        # filter the dataset to contain only two classes
        mask = (self.labels == 0) | (self.labels == 1)
        self.channels = self.channels[mask]
        self.timesteps = self.timesteps[mask]
        self.labels = self.labels[mask]
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        # create tonic-like structured arrays
        dtype = np.dtype([("t", int), ("x", int), ("p", int)])
        struct_arr = np.empty_like(self.channels[idx], dtype=dtype)
        
        # convert spike times from seconds to microseconds
        timesteps = self.timesteps[idx] * 1e6
        
        struct_arr['t'] = timesteps
        struct_arr['x'] = self.channels[idx]
        struct_arr['p'] = 1
        
        # one-hot encoding of labels (required for CutMix augmentation)
        label = np.eye(self.num_classes)[self.labels[idx]].astype(np.int32)
            
        return struct_arr, label
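
To make the output of `__getitem__` concrete, the following minimal sketch builds the same structured array and one-hot label from synthetic values. The spike times, channels, and label are made up for illustration and do not come from the SHD files.

```python
import numpy as np

# Synthetic stand-ins for one sample (hypothetical values, not real SHD data)
spike_times_s = np.array([0.125, 0.25, 0.5])  # spike times in seconds
spike_units = np.array([12, 699, 3])          # input channel of each spike
label_index = 1                               # class index after filtering
num_classes = 2

# Structured array with the same fields as in __getitem__
dtype = np.dtype([("t", int), ("x", int), ("p", int)])
events = np.empty_like(spike_units, dtype=dtype)
events["t"] = spike_times_s * 1e6  # scale by 1e6, as in the dataset class
events["x"] = spike_units
events["p"] = 1                    # constant polarity for every event

# One-hot label, built the same way as in the dataset class
one_hot = np.eye(num_classes)[label_index].astype(np.int32)

print(events)
print(one_hot)
```

Each event is a `(t, x, p)` record, so a sample is a variable-length list of events rather than a dense time-by-channel matrix.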

In [None]:
# Load the training and test dataset
train_dataset = SpikingHeidelbergDigits('data/shd_train.h5')
test_dataset = SpikingHeidelbergDigits('data/shd_test.h5')

Check the length of the datasets to verify that the data loaded successfully.

In [None]:
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

Now, create a validation set by randomly splitting the training dataset, and create data loaders for training, validation, and test datasets.

In [None]:
# Split the training dataset into training (80%) and validation (20%)
num_train = int(0.8 * len(train_dataset))
num_val = len(train_dataset) - num_train
train_dataset, val_dataset = random_split(train_dataset, [num_train, num_val])

# Create data loaders
from event_ssm.dataloading import event_stream_collate_fn
from functools import partial

collate_fn = partial(event_stream_collate_fn, resolution=(700,), pad_unit=8192)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

## Model definition
We use the [hydra](https://hydra.cc/docs/intro/) package for efficient configuration management. Define the model configuration in a config file in the `configs` directory.

In [None]:
from hydra import compose, initialize
from omegaconf import OmegaConf, open_dict

with initialize(version_base=None, config_path="configs", job_name="training tutorial"):
    cfg = compose(config_name="base", overrides=["task=tutorial"])

with open_dict(cfg):    
    # optax updates the schedule every iteration and not every epoch
    cfg.optimizer.total_steps = cfg.training.num_epochs * len(train_loader) // cfg.optimizer.accumulation_steps
    cfg.optimizer.warmup_steps = cfg.optimizer.warmup_epochs * len(train_loader) // cfg.optimizer.accumulation_steps
    
    # scale learning rate by batch size
    cfg.optimizer.ssm_lr = cfg.optimizer.ssm_base_lr * cfg.training.per_device_batch_size * cfg.optimizer.accumulation_steps

print(OmegaConf.to_yaml(cfg))
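
The step counts above are measured in optimizer updates, not in batches, which is why they are divided by the number of gradient accumulation steps. The arithmetic can be sketched in plain Python. All numbers below are hypothetical and are not the values from the tutorial configuration.

```python
# Hypothetical values, chosen only to illustrate the arithmetic
num_epochs = 10
warmup_epochs = 2
batches_per_epoch = 25       # plays the role of len(train_loader)
accumulation_steps = 4
per_device_batch_size = 32
ssm_base_lr = 1e-5

# One optimizer update happens every `accumulation_steps` batches
total_steps = num_epochs * batches_per_epoch // accumulation_steps
warmup_steps = warmup_epochs * batches_per_epoch // accumulation_steps

# Learning rate scaled linearly with the effective batch size
effective_batch_size = per_device_batch_size * accumulation_steps
ssm_lr = ssm_base_lr * effective_batch_size

print(total_steps, warmup_steps, ssm_lr)
```

Note that `*` and `//` have equal precedence and associate left to right, so the product is computed first and the integer division is applied last.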

Now, create the model using the configuration defined above.

In [None]:
from event_ssm.ssm import init_S5SSM
from event_ssm.seq_model import BatchClassificationModel

ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)
model = BatchClassificationModel(
    ssm=ssm_init_fn,
    num_classes=test_dataset.num_classes,
    num_embeddings=test_dataset.num_channels,
    **cfg.model.ssm,
)


Initialize the training state by passing the first training batch through the model as a dummy input.

In [None]:
import jax
from event_ssm.train_utils import init_model_state

# pick the first batch from the training loader
batch = next(iter(train_loader))
inputs, targets, timesteps, lengths = batch

# initialize the training state
key = jax.random.PRNGKey(cfg.seed)
state = init_model_state(key, model, inputs, timesteps, lengths, cfg.optimizer)

## Inspect the model
The model parameters are accessible as part of the training state. 
We will look into the spectrum of the recurrent operator here.
The model was initialized with a single stage of blocks, so all sequence layers are stored under `stages_0`.

In [None]:
def get_spectrum(state):
    params = state.params['encoder']['stages_0']
    lambda_bar = []
    time_scales = []
    for name, sequence_layer in params.items():
        # read lambda parameters
        Lambda_im = sequence_layer['S5SSM_0']['Lambda_im']
        Lambda_re = sequence_layer['S5SSM_0']['Lambda_re']
        
        # read and compute delta and Lambda
        delta = np.exp(sequence_layer['S5SSM_0']['log_step'][:, 0])
        Lambda = Lambda_re + 1j * Lambda_im
        
        # compute lambda_bar and time scales
        lambda_bar.append(np.exp(Lambda * delta))
        time_scales.append(1 / np.abs(Lambda) / delta)
    return lambda_bar, time_scales
spectrum, time_scales = get_spectrum(state)
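
To see what these two quantities mean, here is a small sketch that applies the same formulas as `get_spectrum` to hand-picked numbers. The eigenvalues and step sizes are hypothetical and are not taken from the model.

```python
import numpy as np

# Hypothetical continuous-time eigenvalues and step sizes (not model parameters)
Lambda = np.array([-0.5 + 2.0j, -0.1 + 0.0j])
delta = np.array([0.1, 0.5])

# Same formulas as in get_spectrum
lambda_bar = np.exp(Lambda * delta)
time_scales = 1 / np.abs(Lambda) / delta

# The modulus of the discretized eigenvalue depends only on the real part:
# |exp(Lambda * delta)| = exp(Re(Lambda) * delta)
modulus = np.abs(lambda_bar)

print(modulus)
print(time_scales)
```

Since `delta` is positive, an eigenvalue with a negative real part always maps to a point strictly inside the unit circle, which is the condition for a stable recurrence.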

Plot the spectrum of the recurrent operator and the corresponding time scales upon initialization.

In [None]:
import matplotlib.pyplot as plt

def plot_spectrum(spectrum):
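    # one panel per layer
    fig, axes = plt.subplots(1, len(spectrum), figsize=(len(spectrum) * 4, 4))
    axes = np.atleast_1d(axes)  # a single layer yields one Axes, not an array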
    # draw the unit circle
    theta = np.linspace(0, 2 * np.pi, 100)  # 100 points from 0 to 2*pi
    x = np.cos(theta)
    y = np.sin(theta)
    
    # plot the spectrum
    for i, (ax, layer) in enumerate(zip(axes, spectrum)):
        ax.plot(x, y, 'r', linewidth=1)
        ax.scatter(np.real(layer), np.imag(layer), marker='o', alpha=0.8)
    
        # format axis
        ax.set_title(f'Layer {i}')
        ax.set_aspect('equal', adjustable='box')
        ax.set_xlim(-1.1, 1.1)
        ax.set_ylim(-1.1, 1.1)
    
    plt.tight_layout()
    plt.show()
    
plot_spectrum(spectrum)

In [None]:
def plot_time_scales(time_scales):
    log_scales = np.log2(np.stack(time_scales).flatten())
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.hist(log_scales)
    
    # format axis
    max_scale = np.max(np.ceil(log_scales))
    min_scale = np.min(np.floor(log_scales))
    ax.set_xlim((min_scale, max_scale))
    xticks = np.arange(1 + max_scale - min_scale) + min_scale
    ax.set_xticks(xticks, (2 ** xticks).astype(np.int32))
    ax.set_title('Distribution of time scales')
    ax.set_xlabel('Time scale')
    ax.set_ylabel('Count')
    plt.show()
    
plot_time_scales(time_scales)

## Train the model
For training, we provide a trainer module that hides the boilerplate code and exposes a simple interface. It loops through the data loader, computes the loss, and updates the model parameters. To do so, it needs `training_step` and `evaluation_step` functions that it calls on the model in each iteration. Both are already implemented in `event_ssm.train_utils` and can be used directly.

In [None]:
from event_ssm.train_utils import training_step, evaluation_step
from event_ssm.trainer import TrainerModule

# just-in-time compile the training and evaluation functions
train_step = jax.jit(training_step)
eval_step = jax.jit(evaluation_step)

# initialize the trainer module
num_devices = 1
trainer = TrainerModule(
    train_state=state,
    training_step_fn=train_step,
    evaluation_step_fn=eval_step,
    world_size=num_devices,
    config=cfg,
)

We are now ready to start the training loop. 

**Note:** JAX compiles your program just-in-time (JIT) to optimize performance. This means that the first iteration of the training loop will be slower than the following ones.  
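
The compilation overhead can be observed on a toy function. The sketch below is independent of the training code: `toy_step` is a hypothetical stand-in, and the measured times depend on your hardware.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_step(x):
    # stand-in for a training step; any jitted function behaves the same way
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((256, 256))

# first call: traces and compiles the function, then runs it
start = time.perf_counter()
first = toy_step(x).block_until_ready()
first_call = time.perf_counter() - start

# second call with the same input shape and dtype: reuses the compiled code
start = time.perf_counter()
second = toy_step(x).block_until_ready()
second_call = time.perf_counter() - start

print(f"first call:  {first_call:.4f} s")
print(f"second call: {second_call:.4f} s")
```

`block_until_ready()` is needed for a fair measurement because JAX dispatches computations asynchronously.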

In [None]:
# generate random key for dropout
key, dropout_key = jax.random.split(key)

# train the model
trainer.train_model(
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    dropout_key=dropout_key
)

## Inspect the trained model
We now have a toy model trained on the reduced SHD dataset.
Let's look into the spectrum of the recurrent operator after training.

In [None]:
spectrum, time_scales = get_spectrum(trainer.train_state)
plot_spectrum(spectrum)
plot_time_scales(time_scales)

## Assignment
The function `apply_ssm` in `event_ssm/ssm.py` implements the recurrent operator with an associative scan. On highly parallel GPUs, this can speed up training on very long sequences. 
On CPUs, however, the overhead of the scan operation can slow down training.
Your task is to implement a CPU-friendly version of the recurrent operator in `event_ssm/ssm.py` and compare the training time with the original implementation.
We suggest implementing a step-by-step recurrence with [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) instead of the currently used [`jax.lax.associative_scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.associative_scan.html) for this purpose.
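
As a starting point, the sketch below contrasts the two scan primitives on a toy diagonal recurrence `x_k = a_k * x_{k-1} + b_k` with a zero initial state. It is not the repository's `apply_ssm`: the function names, shapes, and the zero initial state are assumptions made for illustration, and the actual function may differ in signature and details.

```python
import jax
import jax.numpy as jnp


def binary_operator(left, right):
    # compose two affine maps x -> a * x + b, applying `left` first
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r


def recurrence_associative(a, b):
    # parallel prefix composition along the time axis (axis 0)
    _, x = jax.lax.associative_scan(binary_operator, (a, b))
    return x


def recurrence_sequential(a, b):
    # one time step after another; the carry is the previous state
    def step(x_prev, inputs):
        a_k, b_k = inputs
        x_k = a_k * x_prev + b_k
        return x_k, x_k

    x0 = jnp.zeros_like(b[0])
    _, x = jax.lax.scan(step, x0, (a, b))
    return x


# toy inputs: 16 time steps, 4 state dimensions
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (16, 4), minval=0.5, maxval=0.99)
b = jax.random.normal(key_b, (16, 4))

x_parallel = recurrence_associative(a, b)
x_sequential = recurrence_sequential(a, b)
print(jnp.max(jnp.abs(x_parallel - x_sequential)))
```

Both functions compute the same states up to floating-point error, so the comparison in the assignment comes down to run time on your hardware.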