# Online Inference Tutorial
In this tutorial, we will implement online inference with event-based state-space models.
Online inference is the process of classifying events as they arrive, in real time.
On many edge systems the batch size is 1, and the model has to sustain a required throughput, measured in events per second.
Here, you will test whether your CPU can run real-time classification with EventSSM.
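Running in real time means the measured throughput is at least the rate at which the sensor emits events. The sketch below shows that check. The function name `realtime_factor` and the example rates are hypothetical and are not part of EventSSM:

```python
def realtime_factor(events_processed: int, seconds: float, sensor_event_rate: float) -> float:
    """Ratio of measured throughput to the rate at which the sensor emits events.

    A value >= 1.0 means classification keeps up with the sensor.
    """
    throughput = events_processed / seconds  # events actually classified per second
    return throughput / sensor_event_rate


# hypothetical numbers: 2**18 events classified in 0.5 s, sensor emitting 100k events/s
factor = realtime_factor(2 ** 18, 0.5, 100_000.0)
print(f"real-time factor: {factor:.2f}")  # prints "real-time factor: 5.24"
```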

The tutorial requires basic familiarity with JAX.

In [None]:
from hydra import initialize, compose
from omegaconf import OmegaConf as om

import jax
import jax.numpy as jnp

from event_ssm.ssm import init_S5SSM
from event_ssm.seq_model import ClassificationModel

## Step 1: Load the model

In [None]:
# Set config_path to the configs directory of the EventSSM repository (event_ssm/configs)
config_path = "../../event_ssm/configs"

# Load configurations
with initialize(version_base=None, config_path=config_path):
    cfg = compose(config_name="base.yaml", overrides=["model=dvs/small"])

In [None]:
# Print the configuration
print(om.to_yaml(cfg.model))

In [None]:
# Set the random seed manually for reproducibility.
key = jax.random.PRNGKey(cfg.seed)
init_key, data_key = jax.random.split(key)
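JAX has no global random state. Every random function takes an explicit key, and `jax.random.split` derives independent subkeys from it. The following standalone check shows that fixing the seed makes both the keys and the random draws reproducible. The seed `0` is arbitrary; the notebook uses `cfg.seed`:

```python
import jax
import jax.numpy as jnp

seed = 0  # arbitrary; the notebook uses cfg.seed
init_key, data_key = jax.random.split(jax.random.PRNGKey(seed))
init_key2, data_key2 = jax.random.split(jax.random.PRNGKey(seed))

# same seed -> identical keys -> identical random tokens
a = jax.random.randint(data_key, shape=(5,), minval=0, maxval=10)
b = jax.random.randint(data_key2, shape=(5,), minval=0, maxval=10)
print(bool(jnp.array_equal(a, b)))  # True
```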

In [None]:
# Model initialization in Flax
ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)

# number of classes (dummy value for this tutorial)
classes = 10

# number of tokens for a 128x128 DVS sensor with 2 polarities
num_tokens = 128 * 128 * 2
model = ClassificationModel(
    ssm=ssm_init_fn,
    num_classes=classes,
    num_embeddings=num_tokens,
    **cfg.model.ssm,
)

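EventSSM subsamples the event sequence in multiple stages to reduce the computational cost.
Each of the `num_stages` stages pools the sequence with stride `pooling_stride`, so the total subsampling factor is `pooling_stride ** num_stages`. Let's compute it: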

In [None]:
total_subsampling = cfg.model.ssm.pooling_stride ** cfg.model.ssm.num_stages
print(f"Total subsampling: {total_subsampling}")
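The following sketch tracks the sequence length through the pooling stages to show how they compose. It assumes that every stage divides the length by `pooling_stride` and that the length is divisible at each stage. The values `pooling_stride=4` and `num_stages=3` are illustrative only and do not come from the `dvs/small` config; the real values are in `cfg.model.ssm`. Under these assumptions, a chunk of exactly `total_subsampling` events shrinks to a single output step:

```python
def lengths_per_stage(seq_len: int, pooling_stride: int, num_stages: int):
    """Sequence length at the input and after each pooling stage (illustrative model)."""
    lengths = [seq_len]
    for _ in range(num_stages):
        # assumption: each stage keeps one step out of every `pooling_stride` steps
        if lengths[-1] % pooling_stride != 0:
            raise ValueError("length must be divisible by pooling_stride at every stage")
        lengths.append(lengths[-1] // pooling_stride)
    return lengths


pooling_stride, num_stages = 4, 3  # illustrative values only
total_subsampling = pooling_stride ** num_stages
print(total_subsampling)  # 64
print(lengths_per_stage(total_subsampling, pooling_stride, num_stages))  # [64, 16, 4, 1]
```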

In [None]:
# initialize model parameters
x = jnp.zeros(total_subsampling, dtype=jnp.int32)
t = jnp.ones(total_subsampling)
variables = model.init(
        {"params": init_key},
        x, t, total_subsampling, False
    )
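After `model.init`, `variables` is a nested dictionary (a pytree) of arrays. You can count its parameters with `jax.tree_util.tree_leaves`. In the sketch below, the toy `toy_variables` dict and its `encoder`/`decoder` entries are hypothetical stand-ins; in the notebook, pass the real `variables` instead:

```python
import jax
import jax.numpy as jnp


def count_params(tree) -> int:
    """Total number of scalar entries over all array leaves of a pytree."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))


# toy stand-in with the same top-level layout as Flax variables: {"params": {...}}
toy_variables = {
    "params": {
        "encoder": {"embedding": jnp.zeros((32768, 8))},
        "decoder": {"kernel": jnp.zeros((8, 10)), "bias": jnp.zeros((10,))},
    }
}
print(count_params(toy_variables))  # 262234
```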

## Step 2: Run the model on random data
Generate a random sequence of integer tokens, JIT-compile the model, and classify the sequence online in chunks of `total_subsampling` events.
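Classifying online here means splitting the event stream into consecutive chunks of `total_subsampling` events. The following sketch shows that chunking. It assumes that trailing events which do not fill a whole chunk are dropped. With the tutorial's `sequence_length = 2 ** 18`, there is no remainder whenever `total_subsampling` is a power of two no larger than that. `chunk_stream` is a hypothetical helper:

```python
import jax.numpy as jnp


def chunk_stream(tokens, timesteps, chunk_size: int):
    """Reshape 1-D event arrays into (num_chunks, chunk_size), dropping an incomplete tail."""
    num_chunks = tokens.shape[0] // chunk_size
    usable = num_chunks * chunk_size
    return (
        tokens[:usable].reshape(num_chunks, chunk_size),
        timesteps[:usable].reshape(num_chunks, chunk_size),
    )


tokens = jnp.arange(10, dtype=jnp.int32)
timesteps = jnp.ones(10)
tok_chunks, ts_chunks = chunk_stream(tokens, timesteps, chunk_size=4)
print(tok_chunks.shape)  # (2, 4)
print(tok_chunks[1])     # [4 5 6 7]
```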

In [None]:
# Generate random data
sequence_length = 2 ** 18
tokens = jax.random.randint(data_key, shape=(sequence_length,), minval=0, maxval=num_tokens)
timesteps = jnp.ones(sequence_length)
print("Sequence length:", sequence_length)

In [None]:
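# jit compile the model; length and train are fixed Python values bound with partial
from functools import partial
model_function = jax.jit(partial(model.apply, length=total_subsampling, train=False))

# run the model once on the first total_subsampling tokens to trigger compilation,
# and block until it finishes so that compilation is excluded from the timing below
jax.block_until_ready(
    model_function(variables, tokens[:total_subsampling], timesteps[:total_subsampling])
)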
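The `partial` above binds `length` and `train` as plain Python values before `jax.jit` traces the function. This matters because `jax.jit` traces the arguments it receives, while a value that determines an array shape must be known at trace time. Passing such a value as a traced argument would raise an error. The toy sketch below uses a hypothetical function `pooled_mean` to show two options: binding the value with `partial`, or using the `static_argnames` option of `jax.jit`:

```python
from functools import partial

import jax
import jax.numpy as jnp


def pooled_mean(x, length):
    # `length` sets the reshape size, so it must be a concrete Python int when tracing
    usable = (x.shape[0] // length) * length
    return x[:usable].reshape(-1, length).mean(axis=1)


f_partial = jax.jit(partial(pooled_mean, length=4))        # length bound before tracing
f_static = jax.jit(pooled_mean, static_argnames="length")  # length marked static

x = jnp.arange(8.0)
print(f_partial(x))            # [1.5 5.5]
print(f_static(x, length=4))   # [1.5 5.5]
```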

In [None]:
from tqdm import tqdm
from time import time

start = time()
# Loop through the stream in chunks of total_subsampling events and measure the throughput.
# JAX uses asynchronous dispatch, so we block until the computation is done to get a
# reasonable timing estimate: call jax.block_until_ready on the final output of the loop.
for i in tqdm(range(0, sequence_length, total_subsampling)):
    output = model_function(
        variables,
        tokens[i:i + total_subsampling],
        timesteps[i:i + total_subsampling],
    )
jax.block_until_ready(output)
end = time()
print(f"Time taken: {end - start:.2f}s")
print(f"Events per second: {sequence_length / (end - start):.2f}")
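`block_until_ready` matters because a JAX call can return as soon as the work is dispatched. If you stop a timer right after the call, you may measure dispatch instead of computation. The standalone sketch below uses a toy jitted function, `toy_step`, with arbitrary array sizes. It times the call correctly by compiling first and then waiting for the result:

```python
from time import perf_counter

import jax
import jax.numpy as jnp


@jax.jit
def toy_step(x):
    # arbitrary compute-heavy stand-in for a model call
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((512, 512))
jax.block_until_ready(toy_step(x))  # warm-up: compile outside the timed region

start = perf_counter()
out = toy_step(x)           # may return before the computation has finished
jax.block_until_ready(out)  # wait until the result is actually computed
elapsed = perf_counter() - start
print(f"{elapsed * 1e3:.2f} ms per call")
```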

## Step 3: Optimize the inference speed
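We suggest using [jax.lax.scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) instead of a Python for loop to further speed up inference. Inside `jax.jit`, the scan runs the loop over chunks as a single compiled computation, which avoids dispatching one model call per chunk from Python.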
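`jax.lax.scan(f, init, xs)` applies `f(carry, x)` to each slice of `xs` along the leading axis. It threads the carry from step to step and stacks the per-step outputs. The toy sketch below checks that a jitted scan matches the equivalent Python loop. The running-sum carry is only illustrative; the tutorial's `step` carries no state:

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # carry: running sum over all chunks so far; output: sum of the current chunk
    chunk_sum = x.sum()
    return carry + chunk_sum, chunk_sum


xs = jnp.arange(12.0).reshape(3, 4)  # 3 chunks of 4 values

final, per_chunk = jax.jit(lambda xs: jax.lax.scan(step, jnp.float32(0.0), xs))(xs)

# equivalent Python loop over the leading axis
carry, outs = jnp.float32(0.0), []
for x in xs:
    carry, y = step(carry, x)
    outs.append(y)

print(per_chunk)  # [ 6. 22. 38.]
print(final)      # 66.0
```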

In [None]:
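def step(carry, inputs):
    # classify one chunk of total_subsampling events; the carry is unused
    x, t = inputs
    logits = model.apply(variables, x, t, total_subsampling, False)
    return carry, logits

# split the stream into chunks of shape (num_chunks, total_subsampling)
token_chunks = tokens.reshape(-1, total_subsampling)
timestep_chunks = timesteps.reshape(-1, total_subsampling)

# jit the scan so it is compiled once: the first call jit-compiles and then iterates
scan_model = jax.jit(lambda xs, ts: jax.lax.scan(step, None, (xs, ts)))
_, logits = jax.block_until_ready(scan_model(token_chunks, timestep_chunks))

In [None]:
# measure run-time (compilation already happened in the previous cell)
start = time()
_, logits = jax.block_until_ready(scan_model(token_chunks, timestep_chunks))
end = time()
print(f"Time taken: {end - start:.2f}s")
print(f"Events per second: {sequence_length / (end - start):.2f}")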

How many events per second can you classify on your CPU?