# Online Inference Tutorial
In this tutorial, we implement online inference with event-based state-space models.
Online inference means classifying events as they arrive, in real time.
On many edge systems the batch size is 1, and the model has to sustain a given throughput, measured in events per second.
Here, you will test whether your CPU can run real-time classification with EventSSM.

The tutorial requires basic familiarity with JAX.

```python
from hydra import initialize, compose
from omegaconf import OmegaConf as om

import jax
import jax.numpy as jnp

from event_ssm.ssm import init_S5SSM
from event_ssm.seq_model import ClassificationModel
```

## Step 1: Load the model

```python
# Load the configuration, selecting the small DVS model
with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="base.yaml", overrides=["model=dvs/small"])
```

```python
# Print the model configuration
print(om.to_yaml(cfg.model))
```

```yaml
ssm_init:
  C_init: lecun_normal
  dt_min: 0.001
  dt_max: 0.1
  conj_sym: false
  clip_eigs: true
ssm:
  discretization: async
  d_model: 128
  d_ssm: 128
  ssm_block_size: 16
  num_stages: 2
  num_layers_per_stage: 3
  dropout: 0.25
  classification_mode: timepool
  prenorm: true
  batchnorm: false
  bn_momentum: 0.95
  pooling_stride: 16
  pooling_mode: timepool
  state_expansion_factor: 2
```
```python
# Create the PRNG keys from the configured seed for reproducibility
key = jax.random.PRNGKey(cfg.seed)
init_key, data_key = jax.random.split(key)
```

```python
# Model initialization in Flax
ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)

# Number of classes (dummy value for this tutorial)
num_classes = 10

# Number of tokens for a 128x128 DVS sensor with 2 polarities
num_tokens = 128 * 128 * 2
model = ClassificationModel(
    ssm=ssm_init_fn,
    num_classes=num_classes,
    num_embeddings=num_tokens,
    **cfg.model.ssm,
)
```

EventSSM subsamples the sequence in multiple stages to reduce the computational cost.
Each stage pools by a factor of `pooling_stride`, so the total subsampling factor is `pooling_stride ** num_stages`:

```python
total_subsampling = cfg.model.ssm.pooling_stride ** cfg.model.ssm.num_stages
print(f"Total subsampling: {total_subsampling}")
```

```
Total subsampling: 256
```
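To see what this factor means for sequence lengths, here is a small sketch in plain Python. It assumes that every stage shortens the sequence by exactly `pooling_stride` (which matches the strided layers printed when the model is initialized) and that the input length is divisible by the stride at every stage; the helper name `stage_lengths` is hypothetical.

```python
def stage_lengths(seq_len: int, pooling_stride: int, num_stages: int) -> list:
    """Sequence length at the input and after each pooling stage.

    Assumes each stage divides the length by `pooling_stride` exactly.
    """
    lengths = [seq_len]
    for _ in range(num_stages):
        if lengths[-1] % pooling_stride != 0:
            raise ValueError("length must be divisible by pooling_stride at every stage")
        lengths.append(lengths[-1] // pooling_stride)
    return lengths


# One chunk of 256 events shrinks to a single step after two stages of stride 16
print(stage_lengths(256, 16, 2))  # [256, 16, 1]
```

Under this assumption, a chunk of `total_subsampling` events is the smallest input that still yields one step at the output of the last stage, which is presumably why the rest of the tutorial feeds the model chunks of exactly that size.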

```python
# Initialize the model parameters with a dummy sequence of `total_subsampling` events.
# Positional arguments: token indices `x`, timesteps `t`, sequence length, train=False
x = jnp.zeros(total_subsampling, dtype=jnp.int32)
t = jnp.ones(total_subsampling)
variables = model.init(
    {"params": init_key},
    x, t, total_subsampling, False
)
```

```
SSM: 128 -> 128 -> 128 (stride 16 with pooling mode timepool)
SSM: 128 -> 128 -> 128
SSM: 128 -> 128 -> 128
SSM: 128 -> 256 -> 256 (stride 16 with pooling mode timepool)
SSM: 256 -> 256 -> 256
SSM: 256 -> 256 -> 256
```
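To check the model size after initialization, you can count the parameters in the `variables` pytree. The sketch below uses `jax.tree_util.tree_leaves`; the helper name `count_params` is hypothetical, and it is demonstrated on a small dummy pytree so that it runs on its own. In the notebook you would call `count_params(variables["params"])`, assuming the usual Flax layout with a `"params"` collection.

```python
import jax
import jax.numpy as jnp


def count_params(tree) -> int:
    """Total number of scalar entries across all array leaves of a pytree."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))


# Dummy pytree standing in for `variables["params"]`
dummy_params = {
    "encoder": {"kernel": jnp.zeros((128, 64)), "bias": jnp.zeros(64)},
    "head": {"kernel": jnp.zeros((64, 10))},
}
print(count_params(dummy_params))  # 8192 + 64 + 640 = 8896
```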

## Step 2: Run the model on random data
Generate a random sequence of integer tokens, JIT-compile the model, and classify the stream online, one chunk of `total_subsampling` events at a time.
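On a real sensor, events arrive one at a time rather than as a precomputed array. Here is a minimal sketch of the buffering this implies, assuming the model is only called on full chunks of `total_subsampling` events: events are appended to a buffer, and a chunk is released for classification whenever the buffer is full. The class `ChunkBuffer` is hypothetical and not part of EventSSM.

```python
from typing import List, Optional, Tuple


class ChunkBuffer:
    """Accumulates (token, timestep) events and releases fixed-size chunks."""

    def __init__(self, chunk_size: int):
        self.chunk_size = chunk_size
        self.tokens: List[int] = []
        self.timesteps: List[float] = []

    def push(self, token: int, timestep: float) -> Optional[Tuple[List[int], List[float]]]:
        """Add one event; return a full chunk when one is ready, else None."""
        self.tokens.append(token)
        self.timesteps.append(timestep)
        if len(self.tokens) < self.chunk_size:
            return None
        chunk = (self.tokens, self.timesteps)
        self.tokens, self.timesteps = [], []  # start a fresh buffer
        return chunk


buffer = ChunkBuffer(chunk_size=4)
chunks = []
for event_id in range(10):
    ready = buffer.push(event_id, 1.0)
    if ready is not None:
        # In the notebook, this is where a full chunk would be passed to the model
        chunks.append(ready)
print(len(chunks), len(buffer.tokens))  # 2 2
```

Keeping every chunk the same size also avoids recompilation, because `jax.jit` compiles a new version of a function for every new input shape.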

```python
# Generate random data: token indices and unit timesteps
sequence_length = 2 ** 18
tokens = jax.random.randint(data_key, shape=(sequence_length,), minval=0, maxval=num_tokens)
timesteps = jnp.ones(sequence_length)
print("Sequence length:", sequence_length)
```

```
Sequence length: 262144
```

```python
from functools import partial

# JIT-compile the model for chunks of `total_subsampling` events.
# `length` and `train` are bound with partial, so they stay static Python values.
jit_apply = jax.jit(partial(model.apply, length=total_subsampling, train=False))

# The first call traces and compiles the model on the dummy chunk from Step 1
jit_apply(variables, x, t)
```

```
SSM: 128 -> 128 -> 128 (stride 16 with pooling mode timepool)
SSM: 128 -> 128 -> 128
SSM: 128 -> 128 -> 128
SSM: 128 -> 256 -> 256 (stride 16 with pooling mode timepool)
SSM: 256 -> 256 -> 256
SSM: 256 -> 256 -> 256

Array([-0.12317943, -0.17902763, -0.26315966,  0.5992651 ,  0.7048361 ,
        1.2036127 ,  0.00121723,  0.41398254,  0.26262668,  0.18357195],      dtype=float32)
```
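The `SSM: ...` lines appear again here, most likely because they are printed from Python code that runs while the model is traced, and `jax.jit` traces a function on its first call. The toy function below (not part of EventSSM) counts how often it is traced: a call with the same input shape and dtype reuses the cached compilation, while a new shape triggers a new trace.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def energy(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return jnp.sum(x ** 2)


energy_jit = jax.jit(energy)
energy_jit(jnp.ones(256))   # first call: trace and compile
energy_jit(jnp.zeros(256))  # same shape and dtype: cached, no new trace
energy_jit(jnp.ones(128))   # new shape: traced and compiled again
print(trace_count)  # 2
```

This is one reason to feed the model chunks of a fixed length: a shorter final chunk would trigger another compilation.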

```python
# Loop through the sequence, one chunk of `total_subsampling` events at a time
from time import time
from tqdm import tqdm

num_iterations = sequence_length // total_subsampling
print(f"Looping through {sequence_length} events with total_subsampling={total_subsampling} --> {num_iterations} iterations")
start = time()
for i in tqdm(range(0, sequence_length, total_subsampling)):
    x_chunk = tokens[i:i + total_subsampling]
    t_chunk = timesteps[i:i + total_subsampling]
    # Wait for the result so that the timing includes the actual computation
    logits = jit_apply(variables, x_chunk, t_chunk).block_until_ready()
end = time()
print(f"Time taken: {end - start:.2f}s")
print(f"Events per second: {sequence_length / (end - start):.2f}")
```

```
Looping through 262144 events with total_subsampling=256 --> 1024 iterations
100%|██████████| 1024/1024 [00:03<00:00, 285.19it/s]
Time taken: 3.59s
Events per second: 72962.94
```
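Whether this throughput counts as real time depends on how many events per second the sensor actually produces, which varies with the scene. Here is a small sketch of the check, using a hypothetical sensor event rate (an assumed value, not a measurement): to keep up, classifying one chunk must on average take no longer than the sensor needs to produce it. The helper names are hypothetical.

```python
def chunk_time_budget(chunk_size: int, sensor_events_per_s: float) -> float:
    """Seconds available to classify one chunk before the next one is complete."""
    return chunk_size / sensor_events_per_s


def keeps_up(measured_events_per_s: float, sensor_events_per_s: float) -> bool:
    """True if the measured throughput matches or exceeds the sensor's event rate."""
    return measured_events_per_s >= sensor_events_per_s


# Hypothetical sensor producing 50,000 events per second
sensor_rate = 50_000.0
print(chunk_time_budget(256, sensor_rate))  # 0.00512 s per chunk of 256 events
print(keeps_up(72962.94, sensor_rate))      # True for the loop measured above
```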
## Step 3: Optimize the inference speed
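We suggest using [jax.lax.scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) instead of the Python for loop to further speed up inference: `scan` runs the whole loop as a single compiled operation, which removes the per-iteration Python dispatch overhead.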
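The pattern in the next cell can be tried on a toy problem first. In this sketch, `classify_chunk` is a hypothetical stand-in for `model.apply`: a Python loop and a `jax.lax.scan` over the same chunks produce the same outputs, but the scan version is compiled as one operation. Passing the weights as an argument to the jitted function, rather than closing over them, keeps them from being baked into the compiled program as constants; this is a design choice of the sketch, not something the tutorial requires.

```python
import jax
import jax.numpy as jnp


def classify_chunk(w, chunk):
    # Hypothetical stand-in for model.apply: one output vector per chunk
    return jnp.tanh(chunk @ w)


@jax.jit
def classify_all(w, chunks):
    def step(carry, chunk):
        return carry, classify_chunk(w, chunk)  # carry is passed through unchanged

    _, outputs = jax.lax.scan(step, None, chunks)
    return outputs


w = 0.01 * jnp.ones((4, 3))
chunks = jnp.arange(24.0).reshape(6, 4)  # 6 chunks, each a vector of length 4

loop_out = jnp.stack([classify_chunk(w, c) for c in chunks])
scan_out = classify_all(w, chunks)
print(scan_out.shape)  # (6, 3)
print(jnp.allclose(loop_out, scan_out, atol=1e-6))  # True
```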

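```python
def step(carry, inputs):
    # Classify one chunk of events; the carry is passed through unchanged
    x, t = inputs
    logits = model.apply(variables, x, t, total_subsampling, False)
    return carry, logits

# Reshape the event stream into chunks of shape (sequence_length // total_subsampling, total_subsampling)
tokens = tokens.reshape(-1, total_subsampling)
timesteps = timesteps.reshape(-1, total_subsampling)

# Run the scan: the first call traces and compiles `step`, then iterates over all chunks.
# lax.scan returns (final_carry, stacked_outputs); keep only the logits.
_, logits = jax.lax.scan(step, init=None, xs=(tokens, timesteps))
```

```
SSM: 128 -> 128 -> 128 (stride 16 with pooling mode timepool)
SSM: 128 -> 128 -> 128
SSM: 128 -> 128 -> 128
SSM: 128 -> 256 -> 256 (stride 16 with pooling mode timepool)
SSM: 256 -> 256 -> 256
SSM: 256 -> 256 -> 256
```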

```python
# Measure the run time, now that the scan has been compiled by the call above
start = time()
_, logits = jax.block_until_ready(jax.lax.scan(step, init=None, xs=(tokens, timesteps)))
end = time()
print(f"Time taken: {end - start:.2f}s")
print(f"Events per second: {sequence_length / (end - start):.2f}")
```

```
Time taken: 2.65s
Events per second: 99018.86
```
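A single timed run can be noisy, and it includes compilation if the function has not been called before. Here is a minimal benchmarking sketch, assuming a zero-argument callable that returns JAX arrays: it runs a few warm-up calls, then reports the median over several timed runs. The helper name `benchmark` is hypothetical; in the notebook you could pass `lambda: jax.lax.scan(step, None, (tokens, timesteps))`.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def benchmark(fn, num_warmup: int = 2, num_runs: int = 5) -> float:
    """Median wall-clock seconds of fn(), after warm-up calls that absorb compilation."""
    for _ in range(num_warmup):
        jax.block_until_ready(fn())
    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        jax.block_until_ready(fn())  # wait for asynchronous dispatch to finish
        times.append(time.perf_counter() - start)
    return statistics.median(times)


# Toy workload standing in for the scan over all chunks
a = jnp.ones((512, 512))
seconds = benchmark(lambda: jnp.dot(a, a))
print(f"Median time: {seconds * 1e3:.3f} ms")
```

Dividing `sequence_length` by the returned median gives an events-per-second figure that is less sensitive to one-off delays.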

```python
# One vector of logits per chunk of events
logits.shape
```

```
(1024, 10)
```

## Step 4: Run inference on the DVS128 Gestures dataset
Follow the steps in `tutorial_inference.ipynb` to run inference on the DVS128 Gestures dataset with a pretrained model.
Then plot the model's confidence in the correct class over time.
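For this plot, one way to turn per-chunk logits into a confidence curve is to apply a softmax to each row and read off the probability of the correct class. The sketch below does this on random logits shaped like the `(1024, 10)` scan output; `target_class` is a hypothetical label, and with a pretrained model you would use the true label of the recording instead. It also computes a running average, which smooths the curve as more chunks arrive.

```python
import jax
import jax.numpy as jnp


def confidence_over_time(logits, target_class: int):
    """Per-chunk softmax probability of `target_class` and its running mean."""
    probs = jax.nn.softmax(logits, axis=-1)[:, target_class]
    running_mean = jnp.cumsum(probs) / jnp.arange(1, probs.shape[0] + 1)
    return probs, running_mean


# Random logits with the same shape as the scan output in Step 3
example_logits = jax.random.normal(jax.random.PRNGKey(0), (1024, 10))
probs, running_mean = confidence_over_time(example_logits, target_class=3)
print(probs.shape, running_mean.shape)  # (1024,) (1024,)
```

Plotting `probs` and `running_mean` against the chunk index, for example with matplotlib's `plt.plot`, gives the requested curve.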