# The Annotated S5

Welcome! In this tutorial we will be implementing Simplified State Space Layers for Sequence Modeling, or more simply, S5, in DeepMind's Haiku neural network library for JAX. I've chosen to highlight the S5 because, although it builds on the original S4 and its variations, it provides, as the extra 'S' suggests, a nice simplification of this family of architectures. If you are unfamiliar with any of these previous models, I highly recommend checking out Sasha Rush's excellent blog post on the subject: https://srush.github.io/annotated-s4/, and its corresponding repository: https://github.com/srush/annotated-s4, in which he implements the original S4 in Flax. Both were extremely helpful for me in figuring out how to implement these architectures and serve as the direct inspiration for this notebook, so please do check them out! I have, however, tried to make this notebook as self-contained and accessible as possible, so if you feel like jumping straight in, feel free. I should also mention the original repository from the paper authors: https://github.com/lindermanlab/S5, which itself builds upon Sasha Rush's Flax implementation of S4 and has been the guiding light for my implementation. I have chosen to write mine in Haiku rather than Flax simply because that is my personal framework of choice, so if you are unfamiliar with Haiku, hopefully this can serve as a little introduction to that as well!

I will begin by briefly addressing what many see as the black box of this family of architectures: the High-order Polynomial Projection Operators, or HiPPOs. HiPPO is a recurrent memory mechanism first introduced in https://arxiv.org/abs/2008.07669 that replaces the standard memory unit found in GRUs and LSTMs with an online function approximation based on a series of orthogonal polynomials, thereby creating an internal representation/memory of the function being modeled. The paper gives a short comparison to standard RNN memory units:

    "by stacking multiple units in parallel and choosing a specific update function, we obtain the GRU update cell as a special case. In contrast to HiPPO which uses one hidden feature and projects it onto high order polynomials, these models use many hidden features but only project them with degree 1. This view sheds light on these classic techniques by showing how they can be derived from first principles."

There are a lot of fascinating components that go into these memory structures, which I plan on writing another tutorial about. For now, I will leave it at the short description above and note that we will be initializing our HiPPOs with a little library called Hippox, which I created for that very purpose. Check out the source code if you feel like diving a bit deeper: https://github.com/JPGoodale/hippox (shameless self-promotion), and of course I highly recommend checking out the original papers and their code at https://github.com/HazyResearch/state-spaces.
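To give a feel for what such an initialization contains, here is a minimal sketch of the dense HiPPO-LegS state and input matrices in plain JAX. The function name `make_hippo_legs` is my own for illustration, and this is not how Hippox builds its (diagonalized) parameters internally; it only shows the matrix that those representations start from.

```python
import jax.numpy as jnp

def make_hippo_legs(n):
    # HiPPO-LegS: A[i, k] = -sqrt(2i+1) * sqrt(2k+1) for i > k,
    #             A[i, i] = -(i + 1), and zero above the diagonal.
    p = jnp.sqrt(1.0 + 2.0 * jnp.arange(n))
    A = jnp.tril(p[:, None] * p[None, :]) - jnp.diag(jnp.arange(n))
    # Input matrix: B[i] = sqrt(2i+1).
    B = p
    return -A, B

A_legs, B_legs = make_hippo_legs(4)
```

The matrix is lower triangular with a negative diagonal, which is part of what makes it awkward to work with directly and motivates the diagonal approximations used later.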

In [None]:
# Let's go ahead and import hippox and the core JAX libraries we will be using:
import jax
import jax.numpy as jnp
import haiku as hk
from hippox.main import Hippo
from typing import Optional

First we'll define some helper functions for discretization and timescale initialization as the SSM equation is naturally continuous and must be made discrete to be unrolled as a linear recurrence like standard RNNs.

In [None]:
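# Here we use the zero-order hold (ZOH) method for its simplicity, with A, B and delta_t denoting the
# state matrix, input matrix and step size respectively. Because A is diagonal it is stored as a
# vector, so the matrix exponential and inverse reduce to element-wise operations.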

def discretize(A, B, delta_t):
    Identity = jnp.ones(A.shape[0])
    _A = jnp.exp(A*delta_t)
    _B = (1/A * (_A-Identity)) * B
    return _A, _B

# This is a function used to initialize the trainable timescale parameter.
def log_step_initializer(dt_min=0.001, dt_max=0.1):
    def init(shape, dtype):
        uniform = hk.initializers.RandomUniform()
        return uniform(shape, dtype)*(jnp.log(dt_max) - jnp.log(dt_min)) + jnp.log(dt_min)
    return init

# Taken directly from https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/recurrent.py
def add_batch(nest, batch_size: Optional[int]):
    broadcast = lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape)
    return jax.tree_util.tree_map(broadcast, nest)
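As a sanity check on the zero-order hold formulas, the sketch below compares one discretized step against a fine-grained Euler integration of the continuous ODE for a single scalar state with a constant input. The helper `zoh_diag` and all the numbers are illustrative assumptions, not part of the notebook's model.

```python
import jax
import jax.numpy as jnp

def zoh_diag(a, b, dt):
    # Element-wise zero-order hold for a diagonal (here scalar) state matrix.
    a_bar = jnp.exp(a * dt)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar

# Continuous system x'(t) = a x(t) + b u with the input u held constant over the step.
a, b, u, x0, dt = -0.5, 1.0, 2.0, 1.0, 0.1

a_bar, b_bar = zoh_diag(a, b, dt)
x_zoh = a_bar * x0 + b_bar * u

# Reference: integrate the ODE over [0, dt] with many small Euler steps.
n_steps = 1000
h = dt / n_steps

def euler_step(_, x):
    return x + h * (a * x + b * u)

x_euler = jax.lax.fori_loop(0, n_steps, euler_step, jnp.float32(x0))
```

The two values should agree up to the Euler integration error, which is the sense in which ZOH is exact when the input is constant between samples.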

The linear SSM equation is as follows:
        $ x'(t) = Ax(t) + Bu(t) $
        $ y(t) = Cx(t) + Du(t) $

 We will now implement it as a recurrent Haiku module:

In [None]:
class LinearSSM(hk.RNNCore):
    def __init__(self, state_size: int, name: Optional[str] = None):
        super(LinearSSM, self).__init__(name=name)
        # We won't get into the basis measure families here, just note that they are basically just the
        # type of orthogonal polynomial we initialize with, the scaled Legendre measure (LegS) introduced
        # in the original HiPPO paper is pretty much the standard initialization and is what is used in the
        # main experiments in the S5 paper. I will also note that the Hippo class uses the diagonal representation
        # of the state matrix by default, as this has become the standard in neural SSMs since shown to be
        # equally effective as the diagonal plus low rank representation in https://arxiv.org/abs/2203.14343
        # and then formalized in https://arxiv.org/abs/2206.11893.

        _hippo = Hippo(state_size=state_size, basis_measure='legs')
        # Must be called for parameters to be initialized
        _hippo()

        # We register the real and imaginary components of the state matrix A as separate parameters because
        # they will have separate gradients in training, they will be conjoined back together and then discretized
        # but this will simply be backpropagated through as a transformation of the lambda real and imaginary
        # parameters (lambda is just what we call the diagonalized state matrix).

        self._lambda_real = hk.get_parameter(
            'lambda_real',
            shape=[state_size,],
            init=_hippo.lambda_initializer('real')
        )
        self._lambda_imag = hk.get_parameter(
            'lambda_imaginary',
            shape=[state_size,],
            init=_hippo.lambda_initializer('imaginary')
        )
        self._A = self._lambda_real + 1j * self._lambda_imag

        # For now, these initializations of the input and output matrices B and C match the S4D
        # parameterization for demonstration purposes. We will implement the S5 versions later.

        self._B = hk.get_parameter(
            'B',
            shape=[state_size,],
            init=_hippo.b_initializer()
        )
        self._C = hk.get_parameter(
            'C',
            shape=[state_size, 2],
            init=hk.initializers.RandomNormal(stddev=0.5**0.5)
        )
        self._output_matrix = self._C[..., 0] + 1j * self._C[..., 1]

        # This feed-through matrix basically acts as a residual connection.
        self._D = hk.get_parameter(
            'D',
            [1,],
            init=jnp.ones,
        )

        self._delta_t = hk.get_parameter(
            'delta_t',
            shape=[1,],
            init=log_step_initializer()
        )
        timescale = jnp.exp(self._delta_t)

        self._state_matrix, self._input_matrix = discretize(self._A, self._B, timescale)

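    def __call__(self, inputs, prev_state):
        # A is diagonal and stored as a vector, so the state update is element-wise rather than
        # a matrix product. This assumes a single scalar input feature per (unbatched) time step.
        u = jnp.reshape(inputs, (-1,))
        new_state = self._state_matrix * prev_state + self._input_matrix * u
        y_s = jnp.sum(self._output_matrix * new_state, axis=-1)
        out = y_s.reshape(-1).real + self._D * u
        return out, new_state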

    def initial_state(self, batch_size: Optional[int] = None):
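        # The state has the same (complex) shape and dtype as the diagonal state matrix.
        state = jnp.zeros(self._state_matrix.shape, dtype=self._state_matrix.dtype)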
        if batch_size is not None:
            state = add_batch(state, batch_size)
        return state
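To make the role of the diagonal representation concrete, here is a small sketch in plain JAX that unrolls a dense discrete linear SSM and its diagonalized counterpart with `jax.lax.scan` and compares the outputs. Note the assumptions: I use a random symmetric matrix as a stand-in so that `jnp.linalg.eigh` applies (the HiPPO matrices are not symmetric), and all names and sizes here are illustrative only.

```python
import jax
import jax.numpy as jnp

N, L = 8, 32
k_m, k_b, k_c, k_u = jax.random.split(jax.random.PRNGKey(0), 4)

# Stand-in state matrix: symmetric, rescaled so its spectral radius is below 1.
M = jax.random.normal(k_m, (N, N))
S = 0.5 * (M + M.T)
A = S / (1.1 * jnp.max(jnp.abs(jnp.linalg.eigvalsh(S))))
B = jax.random.normal(k_b, (N,))
C = jax.random.normal(k_c, (N,))
u = jax.random.normal(k_u, (L,))

# Change of basis (A, B, C) -> (V^-1 A V, V^-1 B, C V); V is orthogonal, so V^-1 = V^T.
lam, V = jnp.linalg.eigh(A)
B_diag = V.T @ B
C_diag = C @ V

def unroll(step_fn, us):
    _, ys = jax.lax.scan(step_fn, jnp.zeros(N), us)
    return ys

def dense_step(x, u_k):
    x = A @ x + B * u_k          # full matrix-vector product
    return x, C @ x

def diag_step(x, u_k):
    x = lam * x + B_diag * u_k   # element-wise update
    return x, C_diag @ x

y_dense = unroll(dense_step, u)
y_diag = unroll(diag_step, u)
```

Both systems produce the same outputs, but the diagonal one only needs element-wise multiplications per step, which is what the cell above relies on.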

You may notice that this looks an awful lot like a vanilla RNN cell, just with our special parameterization and without any activations, hence being a linear recurrence. I have implemented it as a subclass of Haiku's hk.RNNCore abstract base class so that it can be unrolled using either the hk.dynamic_unroll or hk.static_unroll functions like any other recurrent module. However, if you are familiar with any of the S4 models, you may be noticing that there's something crucial missing here: the convolutional representation. One of the key contributions of the S4 paper was its demonstration that the SSM ODE can be represented either as a linear recurrence, as above, for efficient inference, or as a global convolution for much faster training. That paper and those that followed go on to present various kernels for efficiently computing this convolution with Fast Fourier Transforms, greatly improving the computational efficiency of the model. Then why have we omitted them? Because the S5 architecture which we are about to explore simplifies all this by providing a purely recurrent representation in both training and inference. It does this by using a parallel recurrence that actually looks a lot like a convolution itself! From the paper:

    "We use parallel scans to efficiently compute the states of a discretized linear SSM. Given a binary associative operator • (i.e. (a • b) • c = a • (b • c)) and a sequence of L elements [a1, a2, ..., aL], the scan operation (sometimes referred to as all-prefix-sum) returns the sequence [a1, (a1 • a2), ..., (a1 • a2 • ... • aL)]."

Let's see what this looks like in code for our SSM, taken straight from the original authors' implementation:

In [None]:
@jax.vmap
def binary_operator(q_i, q_j):
    A_i, b_i = q_i
    A_j, b_j = q_j
    return A_j * A_i, A_j * b_i + b_j

def parallel_scan(A, B, C, inputs):
    A_elements = A * jnp.ones((inputs.shape[0], A.shape[0]))
    Bu_elements = jax.vmap(lambda u: B @ u)(inputs)
    # Jax's built-in associative scan really comes in handy here as it executes a similar scan
    # operation as used in a normal recurrent unroll but is specifically tailored to fit an associative
    # operation like the one described in the paper.
    _, xs = jax.lax.associative_scan(binary_operator, (A_elements, Bu_elements))
    return jax.vmap(lambda x: (C @ x).real)(xs)
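If the scan still feels abstract, the sketch below first runs `jax.lax.associative_scan` with plain addition (giving a running sum), and then checks that the parallel scan with the operator above yields the same states as a step-by-step `jax.lax.scan` of the recurrence x_k = A x_{k-1} + B u_k. The diagonal matrix and inputs are random placeholders, and `sequential_states` and `parallel_states` are names I made up for this comparison.

```python
import jax
import jax.numpy as jnp

# Warm-up: with addition as the operator, the scan is a running sum.
prefix_sums = jax.lax.associative_scan(jnp.add, jnp.arange(1, 6))

@jax.vmap
def binary_operator(q_i, q_j):
    A_i, b_i = q_i
    A_j, b_j = q_j
    return A_j * A_i, A_j * b_i + b_j

def sequential_states(A_bar, Bu):
    # Ordinary recurrence, one step at a time, starting from a zero state.
    def step(x, bu):
        x = A_bar * x + bu
        return x, x
    _, xs = jax.lax.scan(step, jnp.zeros_like(Bu[0]), Bu)
    return xs

def parallel_states(A_bar, Bu):
    # Same recurrence expressed as an associative scan over (A, Bu) pairs.
    A_elements = jnp.broadcast_to(A_bar, Bu.shape)
    _, xs = jax.lax.associative_scan(binary_operator, (A_elements, Bu))
    return xs

L, N = 16, 4
A_bar = 0.9 * jnp.exp(1j * jnp.linspace(0.0, 1.0, N))  # diagonal, |A_bar| < 1
Bu = jax.random.normal(jax.random.PRNGKey(0), (L, N)).astype(jnp.complex64)

xs_seq = sequential_states(A_bar, Bu)
xs_par = parallel_states(A_bar, Bu)
```

The two results match up to floating-point error; the difference is that the associative version can be evaluated in logarithmic depth on parallel hardware.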

It's that simple! In the original S4 we would have had to apply an independent single-input, single-output (SISO) SSM for each feature of the input sequence, such as in this excerpt from Sasha Rush's Flax implementation:

In [None]:
def cloneLayer(layer):
    return flax.linen.vmap(
        layer,
        in_axes=1,
        out_axes=1,
        variable_axes={"params": 1, "cache": 1, "prime": 1},
        split_rngs={"params": True},
    )
SSMLayer = cloneLayer(SSMLayer)

Whereas in the S5 we process the entire sequence in one multi-input, multi-output (MIMO) layer.

Let's now rewrite our module as a full S5 layer using this new method. We will be adding a few extra conditional arguments as well as changing some of the parameterization to match the original paper, and we'll walk through the reasons for these changes in the code comments below.

In [None]:
# First we add a new helper function for the timescale initialization, this one just takes the previous
# log_step_initializer and stores a bunch of them in an array since our model is now multi-in, multi-out.

def init_log_steps(shape, dtype):
    H = shape[0]
    log_steps = []
    for i in range(H):
        log_step = log_step_initializer()(shape=(1,), dtype=dtype)
        log_steps.append(log_step)

    return jnp.array(log_steps)
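The effect of this initializer is that the step sizes themselves are sampled log-uniformly between dt_min and dt_max. A minimal sketch, using `jax.random.uniform` in place of the Haiku initializer so that it runs outside of `hk.transform` (the name `sample_log_steps` is hypothetical):

```python
import jax
import jax.numpy as jnp

def sample_log_steps(key, n, dt_min=0.001, dt_max=0.1):
    # Uniform in log-space between log(dt_min) and log(dt_max), one value per state.
    u = jax.random.uniform(key, (n, 1))
    return u * (jnp.log(dt_max) - jnp.log(dt_min)) + jnp.log(dt_min)

log_steps = sample_log_steps(jax.random.PRNGKey(0), 64)
steps = jnp.exp(log_steps)  # the actual timescales used in discretization
```

Storing the logarithm and exponentiating at use time keeps the timescale positive however the parameter moves during training.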

In [None]:
# We will also rewrite our discretization for the MIMO context
def discretize(A, B, delta_t):
    Identity = jnp.ones(A.shape[0])
    _A = jnp.exp(A*delta_t)
    _B = (1/A * (_A-Identity))[..., None] * B
    return _A, _B

In [None]:
class S5(hk.Module):
    def __init__(self,
                 state_size: int,

                 # Now that we're MIMO we'll need to know the number of input features, commonly
                 # referred to as the dimension of the model.
                 d_model: int,

                 # We must also now specify the number of blocks that we will split our matrices
                 # into due to the MIMO context.
                 n_blocks: int,

                 # Short for conjugate symmetry: the eigenvalues of a real state matrix come in
                 # complex-conjugate pairs, so we can store only half of them and recover the
                 # real output as twice the real part. This is not new to the S5, we just didn't
                 # mention it above.
                 conj_sym: bool = True,

                 # Another standard SSM argument that we omitted above for simplicity's sake,
                 # this forces the real part of the state matrix to be negative for better
                 # stability, especially in autoregressive tasks.
                 clip_eigns: bool = False,

                 # Like most RNNs, the S5 can be run in both directions if need be.
                 bidirectional: bool = False,

                 # Rescales delta_t for varying input resolutions, such as different audio
                 # sampling rates.
                 step_rescale: float = 1.0,
                 name: Optional[str] = None
    ):
        super(S5, self).__init__(name=name)
        self.conj_sym = conj_sym
        self.bidirectional = bidirectional

        # Note that the Hippo class takes conj_sym as an argument and will automatically halve
        # the state size provided in its initialization, which is why we need to provide a local
        # state size that matches this for the shape argument in hk.get_parameter().

        if conj_sym:
            _state_size = state_size // 2
        else:
            _state_size = state_size

        # With block_diagonal set as True and the number of blocks provided, our Hippo class
        # will automatically handle this change of structure.

        _hippo = Hippo(
            state_size=state_size,
            basis_measure='legs',
            conj_sym=conj_sym,
            block_diagonal=True,
            n_blocks=n_blocks,
        )
        _hippo()

        self._lambda_real = hk.get_parameter(
            'lambda_real',
            [_state_size],
            init=_hippo.lambda_initializer('real')
        )
        self._lambda_imag = hk.get_parameter(
            'lambda_imaginary',
            [_state_size],
            init=_hippo.lambda_initializer('imaginary')
        )
        # Both branches must assign self._A, since that is what gets discretized below.
        if clip_eigns:
            self._A = jnp.clip(self._lambda_real, None, -1e-4) + 1j * self._lambda_imag
        else:
            self._A = self._lambda_real + 1j * self._lambda_imag

        # If you recall, I mentioned above that we are automatically using a diagonalized version of
        # the HiPPO state matrix rather than the pure one, because the pure one is very hard to compute
        # with efficiently. I will now go into a little more detail on how this diagonal representation
        # is derived, as it is important for how we initialize the input and output matrices. The
        # diagonal decomposition of our state matrix is based on an equivalence relation on the SSM
        # parameters: (A, B, C) ∼ (V⁻¹AV, V⁻¹B, CV), with V being the matrix of eigenvectors of our
        # original A matrix and V⁻¹ being its inverse. The Hippo class has already performed the
        # decomposition of A into V⁻¹AV automatically, but we have not yet transformed B and C. We will
        # use the eigenvector_transform class method for that below, but first we must initialize B and C
        # from normal distributions, LeCun normal and truncated normal respectively. I will note that
        # there are a few other options provided for C in the original repository but, to keep it simple,
        # we will just use one here.

        b_init = hk.initializers.VarianceScaling()
        b_shape = [state_size, d_model]
        b_init = b_init(b_shape, dtype=jnp.complex64)
        self._B = hk.get_parameter(
            'B',
            [_state_size, d_model, 2],
            init=_hippo.eigenvector_transform(b_init,  concatenate=True),
        )
        B = self._B[..., 0] + 1j * self._B[..., 1]

        c_init = hk.initializers.TruncatedNormal()
        c_shape = [d_model, state_size, 2]
        c_init = c_init(c_shape, dtype=jnp.complex64)
        self._C = hk.get_parameter(
            'C',
            [d_model, _state_size, 2],
            init=_hippo.eigenvector_transform(c_init, inverse=False, concatenate=True),
        )
        # We need two output heads if bidirectional is True.
        if bidirectional:
            self._C2 = hk.get_parameter(
                'C2',
                [d_model, _state_size, 2],
                init=_hippo.eigenvector_transform(c_init, inverse=False, concatenate=True),
            )
            C1 = self._C[..., 0] + 1j * self._C[..., 1]
            C2 = self._C2[..., 0] + 1j * self._C2[..., 1]
            self._output_matrix = jnp.concatenate((C1, C2), axis=-1)
        else:
            self._output_matrix = self._C[..., 0] + 1j * self._C[..., 1]

        self._D = hk.get_parameter(
            'D',
            [d_model,],
            init=hk.initializers.RandomNormal(stddev=1.0)
        )

        self._delta_t = hk.get_parameter(
            'delta_T',
            [_state_size, 1],
            init=init_log_steps
        )
        timescale = step_rescale * jnp.exp(self._delta_t[:, 0])

        # We could also use the bilinear discretization method, but we'll just stick to zoh for now.
        self._state_matrix, self._input_matrix = discretize(self._A, B, timescale)


    def __call__(self, inputs):
        # Note that this is the exact same function as presented above just with alternate procedures
        # depending on the bidirectional and conjugate symmetry arguments

        A_elements = self._state_matrix * jnp.ones((inputs.shape[0], self._state_matrix.shape[0]))
        Bu_elements = jax.vmap(lambda u: self._input_matrix @ u)(inputs)

        _, xs = jax.lax.associative_scan(binary_operator, (A_elements, Bu_elements))

        if self.bidirectional:
            _, xs2 = jax.lax.associative_scan(binary_operator,
                                          (A_elements, Bu_elements),
                                          reverse=True)
            xs = jnp.concatenate((xs, xs2), axis=-1)

        if self.conj_sym:
            ys = jax.vmap(lambda x: 2*(self._output_matrix @ x).real)(xs)
        else:
            ys = jax.vmap(lambda x: (self._output_matrix @ x).real)(xs)

        Du = jax.vmap(lambda u: self._D * u)(inputs)

        return ys + Du
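One detail in the layer above that is worth checking numerically is the factor of 2 under conj_sym. If the full state consists of conjugate pairs, and the output weights are paired in the same way, then the full output is real and equals twice the real part computed from one half. A small sketch with random placeholder values (not the layer's actual parameters):

```python
import jax
import jax.numpy as jnp

half = 4
keys = jax.random.split(jax.random.PRNGKey(0), 4)

# One half of the state and of the output weights.
x_half = jax.random.normal(keys[0], (half,)) + 1j * jax.random.normal(keys[1], (half,))
c_half = jax.random.normal(keys[2], (half,)) + 1j * jax.random.normal(keys[3], (half,))

# The full system also carries the complex conjugates of both.
x_full = jnp.concatenate([x_half, jnp.conj(x_half)])
c_full = jnp.concatenate([c_half, jnp.conj(c_half)])

y_full = jnp.sum(c_full * x_full)             # imaginary parts cancel
y_half = 2.0 * jnp.sum(c_half * x_half).real  # what the conj_sym branch computes
```

This is why halving the state size loses nothing for real-valued inputs and outputs.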

There we have it, a complete S5 layer! Now let's form a block around it using a structure very similar to a transformer block with a Gated Linear Unit (GLU).

In [None]:
import dataclasses

@dataclasses.dataclass
class S5Block(hk.Module):
    ssm: S5
    d_model: int
    dropout_rate: float
    prenorm: bool
    istraining: bool = True
    name: Optional[str] = None

    def __post_init__(self):
        super(S5Block, self).__post_init__(name=self.name)
        # We could use either layer norm or batch norm.
        self._norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
        self._linear = hk.Linear(self.d_model)

    def __call__(self, x):
        skip = x
        if self.prenorm:
            x = self._norm(x)

        x = self.ssm(x)
        # There are a couple of other GLU patterns we could use here, but once again I have chosen
        # one semi-arbitrarily to avoid cluttering our module with if statements. Dropout is only
        # applied when istraining is True.
        rate = self.dropout_rate if self.istraining else 0.0
        x1 = hk.dropout(hk.next_rng_key(), rate, jax.nn.gelu(x))
        x = x * jax.nn.sigmoid(self._linear(x1))
        x = hk.dropout(hk.next_rng_key(), rate, x)

        x = skip + x
        if not self.prenorm:
            x = self._norm(x)

        return x
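Stripped of dropout and normalization, the gating in this block is just an element-wise product between the SSM output and a sigmoid gate computed from it. Here is a minimal sketch in plain JAX, where `W` and `b` are randomly initialized stand-ins for the parameters of the `hk.Linear` layer and `gated_unit` is a name chosen for illustration:

```python
import jax
import jax.numpy as jnp

def gated_unit(x, W, b):
    # Gate values lie in (0, 1) and scale each feature of x independently.
    gate = jax.nn.sigmoid(jax.nn.gelu(x) @ W + b)
    return x * gate

seq_len, d_model = 10, 8
k_x, k_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_x, (seq_len, d_model))
W = jax.random.normal(k_w, (d_model, d_model)) / jnp.sqrt(d_model)
b = jnp.zeros(d_model)

gated = gated_unit(x, W, b)
```

Because the gate is bounded, it can only attenuate features, which is one reason the residual connection around the block matters.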

Now let's make a stack of these blocks:

In [None]:
@dataclasses.dataclass
class S5Stack(hk.Module):
    ssm: S5
    d_model: int
    n_layers: int
    dropout_rate: float
    prenorm: bool
    istraining: bool = True
    name: Optional[str] = None

    def __post_init__(self):
        super(S5Stack, self).__post_init__(name=self.name)
        self._encoder = hk.Linear(self.d_model)
        self._layers = [
            S5Block(
                ssm=self.ssm,
                d_model=self.d_model,
                dropout_rate=self.dropout_rate,
                istraining=self.istraining,
                prenorm=self.prenorm,
            )
            for _ in range(self.n_layers)
        ]

    def __call__(self, x):
        x = self._encoder(x)
        for layer in self._layers:
            x = layer(x)
        return x

And finally a classifier on top, as we will be doing a simple classification task for this tutorial:

In [None]:
@dataclasses.dataclass
class S5Classifier(hk.Module):
    ssm: S5
    d_model: int
    d_output: int
    n_layers: int
    dropout_rate: float
    mode: str = 'pool'
    prenorm: bool = True
    istraining: bool = True
    name: Optional[str] = None

    def __post_init__(self):
        super(S5Classifier, self).__post_init__(name=self.name)
        self._encoder = S5Stack(
            ssm=self.ssm,
            d_model=self.d_model,
            n_layers=self.n_layers,
            dropout_rate=self.dropout_rate,
            istraining=self.istraining,
            prenorm=self.prenorm,
        )
        self._decoder = hk.Linear(self.d_output)

    def __call__(self, x):
        x = self._encoder(x)
        if self.mode == 'pool':
            x = jnp.mean(x, axis=0)
        elif self.mode == 'last':
            x = x[-1]
        else:
            raise NotImplementedError("Mode must be in ['pool', 'last']")
        x = self._decoder(x)
        return jax.nn.log_softmax(x, axis=-1)

Our model is now ready for training! Let's load some data. We will use the classic MNIST benchmark, but unlike the standard CNNs usually used for the task, we will be processing the images as a 1-dimensional sequence.

In [None]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from typing import NamedTuple

# Code taken directly from https://github.com/srush/annotated-s4
def create_mnist_classification_dataset(batch_size=128):
    print("[*] Generating MNIST Classification Dataset...")

    # The usual 28*28 format of the images is now being flattened into a sequence of 784 pixels.
    SEQ_LENGTH, N_CLASSES, IN_DIM = 784, 10, 1
    tf = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),
            transforms.Lambda(lambda x: x.view(IN_DIM, SEQ_LENGTH).t()),
        ]
    )

    train = torchvision.datasets.MNIST(
        "./data", train=True, download=True, transform=tf
    )
    test = torchvision.datasets.MNIST(
        "./data", train=False, download=True, transform=tf
    )

    trainloader = DataLoader(
        train, batch_size, shuffle=True
    )
    testloader = DataLoader(
        test, batch_size, shuffle=False
    )

    return trainloader, testloader, N_CLASSES, SEQ_LENGTH, IN_DIM

Datasets = {
    "mnist-classification": create_mnist_classification_dataset,
}

# Simple class for storing our dataset parameters.
class Dataset(NamedTuple):
    trainloader: DataLoader
    testloader: DataLoader
    n_classes: int
    seq_length: int
    d_input: int
    classification: bool

def create_dataset(dataset: str, batch_size: int) -> Dataset:
    classification = 'classification' in dataset
    dataset_init = Datasets[dataset]
    trainloader, testloader, n_classes, seq_length, d_input = dataset_init(batch_size)
    return Dataset(trainloader, testloader, n_classes, seq_length, d_input, classification)
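The flattening done by the `transforms.Lambda` above can be sketched in JAX on a dummy image, which makes the resulting layout explicit: pixels are read row by row, and each time step carries a single feature. The dummy data here is an assumption for illustration only.

```python
import jax.numpy as jnp

# Dummy "image" in (channels, height, width) layout whose pixel value is its raster index.
img = jnp.arange(28 * 28, dtype=jnp.float32).reshape(1, 28, 28)

# Same operation as x.view(IN_DIM, SEQ_LENGTH).t(): result has shape (784, 1).
seq = img.reshape(1, 784).T
```

So step 28 of the sequence is the first pixel of the second image row, and the model has to carry information across hundreds of steps to relate vertically adjacent pixels.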

Next we will set some hyperparameters:

In [None]:
# I have halved the size of most of these parameters from what is used in the
# paper for easier compute.
STATE_SIZE: int = 128
D_MODEL: int = 64
N_LAYERS: int = 3
N_BLOCKS: int = 4
EPOCHS: int = 50
BATCH_SIZE: int = 64
DROPOUT_RATE: float = 0.1
LEARNING_RATE: float = 1e-3
SEED = 0

To keep things simple for this example, we will just be using a plain Adam optimizer, but the results can be considerably improved with extra techniques such as learning rate schedules (for example cosine annealing) and weight decay.

In [None]:
import optax

class TrainingState(NamedTuple):
    params: hk.Params
    opt_state: optax.OptState
    rng_key: jnp.ndarray

optim = optax.adam(LEARNING_RATE)
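For reference, a cosine annealing schedule of the kind mentioned above is only a few lines. The sketch below writes it out by hand in JAX rather than relying on a particular library helper; the function name and the values for `peak_lr`, `end_lr` and `total_steps` are placeholders, not tuned settings from the paper.

```python
import jax.numpy as jnp

def cosine_lr(step, peak_lr=1e-3, end_lr=1e-5, total_steps=10_000):
    # Decays smoothly from peak_lr at step 0 to end_lr at total_steps.
    frac = jnp.clip(step / total_steps, 0.0, 1.0)
    return end_lr + 0.5 * (peak_lr - end_lr) * (1.0 + jnp.cos(jnp.pi * frac))

lr_start = cosine_lr(0)
lr_mid = cosine_lr(5_000)
lr_end = cosine_lr(10_000)
```

Optax optimizers generally accept a callable from step count to learning rate in place of a float, so a function like this could be passed where LEARNING_RATE is used, but check that against the version you have installed.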


Now for the loss and update functions:

In [None]:
from functools import partial
from typing import Tuple, MutableMapping, Any
_Metrics = MutableMapping[str, Any]

@partial(jnp.vectorize, signature="(c),()->()")
def cross_entropy_loss(logits, label) -> jnp.ndarray:
    one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[0])
    return -jnp.sum(one_hot_label * logits)


@partial(jnp.vectorize, signature="(c),()->()")
def compute_accuracy(logits, label):
    return jnp.argmax(logits) == label


@partial(jax.jit, static_argnums=(3, 4))
def update(
        state: TrainingState,
        inputs: jnp.ndarray,
        targets: jnp.ndarray,
        model: hk.Transformed,
        classification: bool,
) -> Tuple[TrainingState, _Metrics]:

    rng_key, next_rng_key = jax.random.split(state.rng_key)

    def loss_fn(params):
        logits = model.apply(params, rng_key, inputs)
        _loss = jnp.mean(cross_entropy_loss(logits, targets))
        _accuracy = jnp.mean(compute_accuracy(logits, targets))
        return _loss, _accuracy

    if not classification:
        targets = inputs[:, :, 0]

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, accuracy), gradients = grad_fn(state.params)
    updates, new_opt_state = optim.update(gradients, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)

    new_state = TrainingState(
        params=new_params,
        opt_state=new_opt_state,
        rng_key=next_rng_key,
    )
    metrics = {
        'loss': loss,
        'accuracy': accuracy
    }

    return new_state, metrics


@partial(jax.jit, static_argnums=(3, 4))
def evaluate(
        state: TrainingState,
        inputs: jnp.ndarray,
        targets: jnp.ndarray,
        model: hk.Transformed,
        classification,
) -> _Metrics:

    rng_key, _ = jax.random.split(state.rng_key, 2)

    if not classification:
        targets = inputs[:, :, 0]

    logits = model.apply(state.params, rng_key, inputs)
    loss = jnp.mean(cross_entropy_loss(logits, targets))
    accuracy = jnp.mean(compute_accuracy(logits, targets))

    metrics = {
        'loss': loss,
        'accuracy': accuracy
    }

    return metrics
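Since the classifier already returns log-probabilities, the one-hot cross-entropy used above amounts to picking out the log-probability of the correct class. A small batched sketch with made-up logits shows the equivalence:

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 0.2, 3.0]])
labels = jnp.array([0, 2])

log_probs = jax.nn.log_softmax(logits, axis=-1)

# One-hot formulation, as in cross_entropy_loss above.
one_hot = jax.nn.one_hot(labels, num_classes=3)
loss_one_hot = -jnp.sum(one_hot * log_probs, axis=-1)

# Equivalent: index the log-probability of the true class directly.
loss_indexed = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]

accuracy = jnp.mean(jnp.argmax(log_probs, axis=-1) == labels)
```

The notebook's version gets its batching from `jnp.vectorize` with the signature `(c),()->()` instead of explicit axis arguments, but the per-example computation is the same.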


Now we call these update and evaluate functions in their respective epochs:

In [None]:
from tqdm import tqdm

def training_epoch(
        state: TrainingState,
        trainloader: DataLoader,
        model: hk.Transformed,
        classification: bool = False,
) -> Tuple[TrainingState, jnp.ndarray, jnp.ndarray]:

    batch_losses, batch_accuracies = [], []
    for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader)):
        inputs = jnp.array(inputs.numpy())
        targets = jnp.array(targets.numpy())
        state, metrics = update(
            state, inputs, targets,
            model, classification
        )
        batch_losses.append(metrics['loss'])
        batch_accuracies.append(metrics['accuracy'])

    return (
        state,
        jnp.mean(jnp.array(batch_losses)),
        jnp.mean(jnp.array(batch_accuracies))
    )


def validation_epoch(
        state: TrainingState,
        testloader: DataLoader,
        model: hk.Transformed,
        classification: bool = True,
) -> Tuple[jnp.ndarray, jnp.ndarray]:

    losses, accuracies = [], []
    for batch_idx, (inputs, targets) in enumerate(tqdm(testloader)):
        inputs = jnp.array(inputs.numpy())
        targets = jnp.array(targets.numpy())
        metrics = evaluate(
            state, inputs, targets,
            model, classification
        )
        losses.append(metrics['loss'])
        accuracies.append(metrics['accuracy'])

    return jnp.mean(jnp.array(losses)), jnp.mean(jnp.array(accuracies))

In [None]:
import torch
# Set random number generators
torch.random.manual_seed(SEED)
key = jax.random.PRNGKey(SEED)
rng, init_rng = jax.random.split(key)

In [None]:
# Create our dataset and dummy data for initialization of the model's params.
ds = create_dataset('mnist-classification', BATCH_SIZE)
init_data = jnp.array(next(iter(ds.trainloader))[0].numpy())

In [None]:
# In Haiku, we have to call our model inside a transformed function using hk.transform for it to become
# functionally pure and compatible with essential JAX functions like jax.grad(). Here we are using hk.vmap
# instead of jax.vmap because we are calling it from within a hk.transform.
@hk.transform
def forward(x) -> jnp.ndarray:
    neural_net = S5Classifier(
        S5(
            STATE_SIZE,
            D_MODEL,
            N_BLOCKS,
        ),
        D_MODEL,
        ds.n_classes,
        N_LAYERS,
        DROPOUT_RATE,
    )
    return hk.vmap(neural_net, split_rng=False)(x)

In [None]:
# Set state
initial_params = forward.init(init_rng, init_data)
initial_opt_state = optim.init(initial_params)

state = TrainingState(
    params=initial_params,
    opt_state=initial_opt_state,
    rng_key=rng
)

And finally our training loop!

In [None]:
for epoch in range(EPOCHS):
    print(f"[*] Training Epoch {epoch + 1}...")
    state, training_loss, training_accuracy = training_epoch(
        state,
        ds.trainloader,
        forward,
        ds.classification
    )
    print(f"[*] Running Epoch {epoch + 1} Validation...")
    test_loss, test_accuracy = validation_epoch(
        state,
        ds.testloader,
        forward,
        ds.classification
    )
    print(f"\n=>> Epoch {epoch + 1} Metrics ===")
    print(
        f"\tTrain Loss: {training_loss:.5f} -- Train Accuracy:"
        f" {training_accuracy:.4f}\n\t Test Loss: {test_loss:.5f} --  Test"
        f" Accuracy: {test_accuracy:.4f}"
    )