<a href="https://colab.research.google.com/github/Sat-A/s5-jax/blob/main/s5-tokenised.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# S5-Tokenised

The goal of this exercise is to adapt an annotated implementation of S5 (https://github.com/JPGoodale/annotated-s5), which was written for classification tasks, to regression tasks.

# Runtime Setup
Set the runtime type to GPU so that training benefits from GPU parallelisation.

In [46]:
!pip install dm-haiku
!pip install hippox



In [47]:
# Core JAX, Haiku, and Optax
import jax
import jax.numpy as jnp
import haiku as hk
import optax
from functools import partial
from typing import NamedTuple, Optional, Tuple, MutableMapping, Any

# Imports from your original notebook
import torch
from hippox.main import Hippo
from tqdm import tqdm
import dataclasses
import numpy as np # Using numpy for data generation

First we'll define some helper functions for discretization and timescale initialization. The SSM equation is naturally continuous, so it must be discretized before it can be unrolled as a linear recurrence like a standard RNN.

In [48]:
# Here we are just using the zero order hold method for its sheer simplicity, with A, B and delta_t denoting the
# state matrix, input matrix and change in timescale respectively.

def discretize(A, B, delta_t):
    Identity = jnp.ones(A.shape[0])
    _A = jnp.exp(A*delta_t)
    _B = (1/A * (_A-Identity)) * B
    return _A, _B

# This is a function used to initialize the trainable timescale parameter.
def log_step_initializer(dt_min=0.001, dt_max=0.1):
    def init(shape, dtype):
        uniform = hk.initializers.RandomUniform()
        return uniform(shape, dtype)*(jnp.log(dt_max) - jnp.log(dt_min)) + jnp.log(dt_min)
    return init

# Taken directly from https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/recurrent.py
def add_batch(nest, batch_size: Optional[int]):
    broadcast = lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape)
    return jax.tree_util.tree_map(broadcast, nest)
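
As a quick sanity check on the zero-order hold formula used in discretize, the sketch below applies it to a single complex eigenvalue in plain Python (no JAX needed) and compares the unrolled recurrence with the closed-form solution of the ODE under a constant input. The function name zoh_scalar and the values of a, b, dt and u are hypothetical, chosen only for illustration.

```python
import cmath

def zoh_scalar(a: complex, b: complex, dt: float):
    """Zero-order hold for the scalar ODE x'(t) = a*x(t) + b*u(t)."""
    a_bar = cmath.exp(a * dt)
    b_bar = (a_bar - 1) / a * b
    return a_bar, b_bar

# Hypothetical values, chosen only for illustration.
a, b, dt, u = complex(-0.5, 2.0), complex(1.0, 0.0), 0.01, 1.0
a_bar, b_bar = zoh_scalar(a, b, dt)

# Unroll the discrete recurrence for `steps` steps, starting from x = 0.
steps = 200
x = 0j
for _ in range(steps):
    x = a_bar * x + b_bar * u

# Closed-form solution of the ODE at time T = steps * dt for constant input u.
T = steps * dt
x_exact = (cmath.exp(a * T) - 1) / a * b * u

zoh_error = abs(x - x_exact)
print(zoh_error)
```

For an input that really is constant over each step, the two agree up to floating-point rounding, which is the sense in which zero-order hold is exact.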

The linear SSM equation is as follows:
$$ x'(t) = Ax(t) + Bu(t) $$
$$ y(t) = Cx(t) + Du(t) $$
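
For reference, holding the input constant over each step of length $\Delta$ (zero-order hold) turns this continuous system into a linear recurrence. The symbols $\bar{A}$, $\bar{B}$ and $\Delta$ are notation introduced here for the discretized matrices and the step size; this is a sketch of the standard result rather than a derivation:

$$ \bar{A} = e^{A\Delta}, \qquad \bar{B} = A^{-1}(\bar{A} - I)B $$
$$ x_k = \bar{A}x_{k-1} + \bar{B}u_k, \qquad y_k = Cx_k + Du_k $$

When $A$ is diagonal, the matrix exponential and the inverse reduce to elementwise operations, which is why the discretize helper can apply jnp.exp and 1/A directly to the vector of eigenvalues.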

We will now implement it as a recurrent Haiku module:

In [49]:
class LinearSSM(hk.RNNCore):
    def __init__(self, state_size: int, name: Optional[str] = None):
        super(LinearSSM, self).__init__(name=name)
        # Stored so that initial_state can build a state of the right size.
        self._state_size = state_size
        # We won't get into the basis measure families here. Just note that they determine the type of
        # orthogonal polynomial we initialize with. The scaled Legendre measure (LegS) introduced in the
        # original HiPPO paper is pretty much the standard initialization and is what is used in the
        # main experiments in the S5 paper. I will also note that the Hippo class uses the diagonal
        # representation of the state matrix by default, as this has become the standard in neural SSMs
        # since it was shown to be as effective as the diagonal plus low rank representation in
        # https://arxiv.org/abs/2203.14343 and then formalized in https://arxiv.org/abs/2206.11893.

        _hippo = Hippo(state_size=state_size, basis_measure='legs')
        # Must be called for parameters to be initialized
        _hippo()

        # We register the real and imaginary components of the state matrix A as separate parameters because
        # they will have separate gradients in training. They are joined back together and then discretized,
        # and backpropagation simply treats this as a transformation of the lambda real and imaginary
        # parameters (lambda is just what we call the diagonalized state matrix).

        self._lambda_real = hk.get_parameter(
            'lambda_real',
            shape=[state_size,],
            init=_hippo.lambda_initializer('real')
        )
        self._lambda_imag = hk.get_parameter(
            'lambda_imaginary',
            shape=[state_size,],
            init=_hippo.lambda_initializer('imaginary')
        )
        self._A = self._lambda_real + 1j * self._lambda_imag

        # For now, these initializations of the input and output matrices B and C match the S4D
        # parameterization for demonstration purposes. We will implement the S5 versions later.

        self._B = hk.get_parameter(
            'B',
            shape=[state_size,],
            init=_hippo.b_initializer()
        )
        self._C = hk.get_parameter(
            'C',
            shape=[state_size, 2],
            init=hk.initializers.RandomNormal(stddev=0.5**0.5)
        )
        self._output_matrix = self._C[..., 0] + 1j * self._C[..., 1]

        # This feed-through matrix basically acts as a residual connection.
        self._D = hk.get_parameter(
            'D',
            [1,],
            init=jnp.ones,
        )

        self._delta_t = hk.get_parameter(
            'delta_t',
            shape=[1,],
            init=log_step_initializer()
        )
        timescale = jnp.exp(self._delta_t)

        self._state_matrix, self._input_matrix = discretize(self._A, self._B, timescale)

    def __call__(self, inputs, prev_state):
        u = inputs[:, jnp.newaxis]
        new_state = self._state_matrix @ prev_state + self._input_matrix @ u
        y_s = self._output_matrix @ new_state
        out = y_s.reshape(-1).real + self._D * u
        return out, new_state

    def initial_state(self, batch_size: Optional[int] = None):
        state = jnp.zeros([self._state_size])
        if batch_size is not None:
            state = add_batch(state, batch_size)
        return state

You may notice that this looks an awful lot like a vanilla RNN cell, just with our special parameterization and without any activations, which is what makes it a linear recurrence. I have defined it as a subclass of Haiku's hk.RNNCore abstract base class so that it can be unrolled using either the hk.dynamic_unroll or hk.static_unroll functions like any other recurrent module. However, if you are familiar with any of the S4 models, you may notice that something crucial is missing here: the convolutional representation. One of the key contributions of the S4 paper was its demonstration that the SSM ODE can be represented either as a linear recurrence, as above, for efficient inference, or as a global convolution for much faster training. That paper and its successors go on to present various kernels for computing this convolution efficiently with Fast Fourier Transforms, greatly improving the computational efficiency of the model. Why, then, have we omitted them? Because the S5 architecture, which we are about to explore, simplifies all this by providing a purely recurrent representation in both training and inference. It does this by using a parallel recurrence that actually looks a lot like a convolution itself! From the paper:

    "We use parallel scans to efficiently compute the states of a discretized linear SSM. Given a binary associative operator • (i.e. (a • b) • c = a • (b • c)) and a sequence of L elements [a1, a2, ..., aL], the scan operation (sometimes referred to as all-prefix-sum) returns the sequence [a1, (a1 • a2), ..., (a1 • a2 • ... • aL)]."

Let's see what this looks like in code, taken straight from the original author's implementation:

In [50]:
@jax.vmap
def binary_operator(q_i, q_j):
    A_i, b_i = q_i
    A_j, b_j = q_j
    return A_j * A_i, A_j * b_i + b_j

def parallel_scan(A, B, C, inputs):
    A_elements = A * jnp.ones((inputs.shape[0], A.shape[0]))
    Bu_elements = jax.vmap(lambda u: B @ u)(inputs)
    # Jax's built-in associative scan really comes in handy here as it executes a similar scan
    # operation as used in a normal recurrent unroll but is specifically tailored to fit an associative
    # operation like the one described in the paper.
    _, xs = jax.lax.associative_scan(binary_operator, (A_elements, Bu_elements))
    return jax.vmap(lambda x: (C @ x).real)(xs)
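
To see why this operator fits the scan described in the quote, the sketch below works through a scalar version in plain Python. It checks that prefix-combining the elements reproduces the sequential recurrence, and that the operator is associative, which is the property that lets the scan regroup the work in parallel. The name combine and all numeric values are hypothetical, chosen only for illustration.

```python
def combine(q_i, q_j):
    """Scalar version of the binary operator: apply step i, then step j."""
    a_i, b_i = q_i
    a_j, b_j = q_j
    return a_j * a_i, a_j * b_i + b_j

# Hypothetical scalar system and inputs, chosen only for illustration.
a_bar = 0.9
bu = [1.0, -2.0, 0.5, 3.0]            # the (B_bar * u_k) terms
elements = [(a_bar, v) for v in bu]

# Sequential recurrence x_k = a_bar * x_{k-1} + bu_k, starting from x_0 = 0.
x, sequential = 0.0, []
for v in bu:
    x = a_bar * x + v
    sequential.append(x)

# Inclusive prefix combination; the second slot of each result is x_k.
prefix, acc = [], None
for e in elements:
    acc = e if acc is None else combine(acc, e)
    prefix.append(acc[1])

# Associativity: regrouping the combines does not change the result.
q1, q2, q3 = elements[:3]
left = combine(combine(q1, q2), q3)
right = combine(q1, combine(q2, q3))
print(sequential, prefix, left, right)
```

This loop still runs sequentially; the point is only that any grouping gives the same states, which is what jax.lax.associative_scan exploits.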

It's that simple! In the original S4 we would have had to apply an independent single-input, single-output (SISO) SSM to each feature of the input sequence, as in this excerpt from Sasha Rush's Flax implementation:



```python
def cloneLayer(layer):
    return flax.linen.vmap(
        layer,
        in_axes=1,
        out_axes=1,
        variable_axes={"params": 1, "cache": 1, "prime": 1},
        split_rngs={"params": True},
    )
SSMLayer = cloneLayer(SSMLayer)
```



Whereas in the S5 we process the entire sequence in one multi-input, multi-output (MIMO) layer.

Let's now rewrite our module as a full S5 layer using this new method. We will add a few extra conditional arguments and change some of the parameterization to match the original paper, and we'll walk through the reasons for these changes below.

In [51]:
# First we add a new helper function for the timescale initialization, this one just takes the previous
# log_step_initializer and stores a bunch of them in an array since our model is now multi-in, multi-out.

def init_log_steps(shape, dtype):
    H = shape[0]
    log_steps = []
    for i in range(H):
        log_step = log_step_initializer()(shape=(1,), dtype=dtype)
        log_steps.append(log_step)

    return jnp.array(log_steps)
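
The initializer draws the logarithm of each step size uniformly, so the step sizes themselves are spread evenly across orders of magnitude between dt_min and dt_max. A minimal plain-Python sketch of the same arithmetic follows; sample_log_step is a hypothetical stand-in for log_step_initializer, and Python's random module replaces the Haiku uniform initializer.

```python
import math
import random

def sample_log_step(rng: random.Random, dt_min: float = 0.001, dt_max: float = 0.1) -> float:
    """Draw log(dt) uniformly between log(dt_min) and log(dt_max)."""
    u = rng.random()  # uniform in [0, 1)
    return u * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)

rng = random.Random(0)
log_steps = [sample_log_step(rng) for _ in range(1000)]

# The model exponentiates the stored parameter to recover the step size.
steps = [math.exp(s) for s in log_steps]

# 0.01 is the geometric midpoint of [0.001, 0.1], so roughly half the
# samples should fall below it.
frac_small = sum(s < 0.01 for s in steps) / len(steps)
print(min(steps), max(steps), frac_small)
```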

In [52]:
# We will also rewrite our discretization for the MIMO context
def discretize(A, B, delta_t):
    Identity = jnp.ones(A.shape[0])
    _A = jnp.exp(A*delta_t)
    _B = (1/A * (_A-Identity))[..., None] * B
    return _A, _B

In [53]:
class S5(hk.Module):
    def __init__(self,
                 state_size: int,

                 # Now that we're MIMO we'll need to know the number of input features, commonly
                 # referred to as the dimension of the model.
                 d_model: int,

                 # We must also now specify the number of blocks that we will split our matrices
                 # into due to the MIMO context.
                 n_blocks: int,

                 # Short for conjugate symmetry. The eigenvalues of a real state matrix come in
                 # complex-conjugate pairs, so we can keep one of each pair and halve the state size.
                 # This is not new to the S5, we just didn't mention it above.
                 conj_sym: bool = True,

                 # Another standard SSM argument that we omitted above for simplicity's sake,
                 # this forces the real part of the state matrix to be negative for better
                 # stability, especially in autoregressive tasks.
                 clip_eigns: bool = False,

                 # Like most RNNs, the S5 can be run in both directions if need be.
                 bidirectional: bool = False,

                 # Rescales delta_t for varying input resolutions, such as different audio
                 # sampling rates.
                 step_rescale: float = 1.0,
                 name: Optional[str] = None
    ):
        super(S5, self).__init__(name=name)
        self.conj_sym = conj_sym
        self.bidirectional = bidirectional

        # Note that the Hippo class takes conj_sym as an argument and will automatically halve
        # the state size provided in its initialization, which is why we need to provide a local
        # state size that matches this for the shape argument in hk.get_parameter().

        if conj_sym:
            _state_size = state_size // 2
        else:
            _state_size = state_size

        # With block_diagonal set as True and the number of blocks provided, our Hippo class
        # will automatically handle this change of structure.

        _hippo = Hippo(
            state_size=state_size,
            basis_measure='legs',
            conj_sym=conj_sym,
            block_diagonal=True,
            n_blocks=n_blocks,
        )
        _hippo()

        self._lambda_real = hk.get_parameter(
            'lambda_real',
            [_state_size],
            init=_hippo.lambda_initializer('real')
        )
        self._lambda_imag = hk.get_parameter(
            'lambda_imaginary',
            [_state_size],
            init=_hippo.lambda_initializer('imaginary')
        )
        if clip_eigns:
            # Assign to self._A in both branches, since discretize() below reads self._A.
            self._A = jnp.clip(self._lambda_real, None, -1e-4) + 1j * self._lambda_imag
        else:
            self._A = self._lambda_real + 1j * self._lambda_imag

        # If you recall, I mentioned above that we are automatically using a diagonalized version of
        # the HiPPO state matrix rather than the pure one, because the pure one is very hard to compute
        # with efficiently. I will now go into a little more detail on how this diagonal representation
        # is derived, as it is important for how we initialize the input and output matrices. The
        # diagonal decomposition of our state matrix is based on an equivalence relation on the SSM
        # parameters: (A, B, C) ~ (V^-1 A V, V^-1 B, C V), with V being the matrix of eigenvectors of
        # our original A matrix and V^-1 its inverse. The Hippo class has already performed the
        # decomposition of A into V^-1 A V automatically, but we have not yet transformed B and C.
        # We will use the eigenvector_transform class method for that below, but first we must
        # initialize B and C from normal distributions: LeCun normal and truncated normal respectively.
        # I will note that there are a few other options provided for C in the original repository but,
        # to keep it simple, we will just use one here.

        b_init = hk.initializers.VarianceScaling()
        b_shape = [state_size, d_model]
        b_init = b_init(b_shape, dtype=jnp.complex64)
        self._B = hk.get_parameter(
            'B',
            [_state_size, d_model, 2],
            init=_hippo.eigenvector_transform(b_init,  concatenate=True),
        )
        B = self._B[..., 0] + 1j * self._B[..., 1]

        c_init = hk.initializers.TruncatedNormal()
        c_shape = [d_model, state_size, 2]
        c_init = c_init(c_shape, dtype=jnp.complex64)
        self._C = hk.get_parameter(
            'C',
            [d_model, _state_size, 2],
            init=_hippo.eigenvector_transform(c_init, inverse=False, concatenate=True),
        )
        # We need two output heads if bidirectional is True.
        if bidirectional:
            self._C2 = hk.get_parameter(
                'C2',
                [d_model, _state_size, 2],
                init=_hippo.eigenvector_transform(c_init, inverse=False, concatenate=True),
            )
            C1 = self._C[..., 0] + 1j * self._C[..., 1]
            C2 = self._C2[..., 0] + 1j * self._C2[..., 1]
            self._output_matrix = jnp.concatenate((C1, C2), axis=-1)
        else:
            self._output_matrix = self._C[..., 0] + 1j * self._C[..., 1]

        self._D = hk.get_parameter(
            'D',
            [d_model,],
            init=hk.initializers.RandomNormal(stddev=1.0)
        )

        self._delta_t = hk.get_parameter(
            'delta_T',
            [_state_size, 1],
            init=init_log_steps
        )
        timescale = step_rescale * jnp.exp(self._delta_t[:, 0])

        # We could also use the bilinear discretization method, but we'll just stick to zoh for now.
        self._state_matrix, self._input_matrix = discretize(self._A, B, timescale)


    def __call__(self, inputs):
        # Note that this is the exact same function as presented above just with alternate procedures
        # depending on the bidirectional and conjugate symmetry arguments

        A_elements = self._state_matrix * jnp.ones((inputs.shape[0], self._state_matrix.shape[0]))
        Bu_elements = jax.vmap(lambda u: self._input_matrix @ u)(inputs)

        _, xs = jax.lax.associative_scan(binary_operator, (A_elements, Bu_elements))

        if self.bidirectional:
            _, xs2 = jax.lax.associative_scan(binary_operator,
                                          (A_elements, Bu_elements),
                                          reverse=True)
            xs = jnp.concatenate((xs, xs2), axis=-1)

        if self.conj_sym:
            ys = jax.vmap(lambda x: 2*(self._output_matrix @ x).real)(xs)
        else:
            ys = jax.vmap(lambda x: (self._output_matrix @ x).real)(xs)

        Du = jax.vmap(lambda u: self._D * u)(inputs)

        return ys + Du
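
The factor of 2 in the conj_sym branch can be checked on a single conjugate pair. The sketch below, in plain Python with hypothetical scalar values, runs the recurrence once for an eigenvalue and once for its conjugate, and compares the summed output with twice the real part of the first run alone.

```python
# Hypothetical (already discretized) values for one conjugate pair.
lam = complex(0.8, 0.3)     # stored eigenvalue
b = complex(0.7, -0.2)      # its input coefficient
c = complex(0.3, 0.5)       # its output coefficient
inputs = [1.0, 0.5, -1.0, 2.0]   # real-valued input sequence

def run(lam, b, c, inputs):
    """Unroll x_k = lam * x_{k-1} + b * u_k and return c * x_k at each step."""
    x, out = 0j, []
    for u in inputs:
        x = lam * x + b * u
        out.append(c * x)
    return out

stored = run(lam, b, c, inputs)
mirrored = run(lam.conjugate(), b.conjugate(), c.conjugate(), inputs)

# Keeping both halves of the pair: the imaginary parts cancel.
full = [p + q for p, q in zip(stored, mirrored)]
# Keeping only the stored half and doubling its real part.
half = [2 * y.real for y in stored]
print(full, half)
```

Because the input is real, the mirrored run is the complex conjugate of the stored run at every step, so only half of the state needs to be computed.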

There we have it, a complete S5 layer! Now let's form a block around it using a structure very similar to a transformer block with a Gated Linear Unit (GLU).

In [54]:
import dataclasses

@dataclasses.dataclass
class S5Block(hk.Module):
    ssm: S5
    d_model: int
    dropout_rate: float
    prenorm: bool
    istraining: bool = True
    name: Optional[str] = None

    def __post_init__(self):
        super(S5Block, self).__post_init__(name=self.name)
        # We could use either layer norm or batch norm.
        self._norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
        self._linear = hk.Linear(self.d_model)

    def __call__(self, x):
        skip = x
        if self.prenorm:
            x = self._norm(x)

        x = self.ssm(x)
        # There are a couple of other GLU patterns we could use here, but once again I have chosen
        # one semi-arbitrarily to avoid cluttering our module with if statements.
        # Only apply dropout in training mode; a rate of 0.0 makes hk.dropout a no-op.
        rate = self.dropout_rate if self.istraining else 0.0
        x1 = hk.dropout(hk.next_rng_key(), rate, jax.nn.gelu(x))
        x = x * jax.nn.sigmoid(self._linear(x1))
        x = hk.dropout(hk.next_rng_key(), rate, x)

        x = skip + x
        if not self.prenorm:
            x = self._norm(x)

        return x

Now let's make a stack of these blocks:

In [55]:
@dataclasses.dataclass
class S5Stack(hk.Module):
    ssm: S5
    d_model: int
    n_layers: int
    dropout_rate: float
    prenorm: bool
    istraining: bool = True
    name: Optional[str] = None

    def __post_init__(self):
        super(S5Stack, self).__post_init__(name=self.name)
        self._encoder = hk.Linear(self.d_model)
        self._layers = [
            S5Block(
                ssm=self.ssm,
                d_model=self.d_model,
                dropout_rate=self.dropout_rate,
                istraining=self.istraining,
                prenorm=self.prenorm,
            )
            for _ in range(self.n_layers)
        ]

    def __call__(self, x):
        x = self._encoder(x)
        for layer in self._layers:
            x = layer(x)
        return x

# Implementing Tokenised prediction in S5

In [56]:
class StockPriceTokeniser:
    """
    A simple quantisation tokeniser.
    Converts continuous stock prices into discrete integer tokens.
    """
    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size
        self.bins = None
        self.min_price = 0.0
        self.max_price = 1.0

    def fit(self, data: jnp.ndarray):
        """Fits the tokeniser to the data to find min/max prices."""
        self.min_price = data.min()
        self.max_price = data.max()
        # Create vocab_size-1 thresholds for binning
        self.bins = jnp.linspace(self.min_price, self.max_price, self.vocab_size - 1)
        print(f"[*] Tokeniser fitted: min={self.min_price:.2f}, max={self.max_price:.2f}")

    def encode(self, prices: jnp.ndarray) -> jnp.ndarray:
        """Converts a sequence of prices into token IDs."""
        if self.bins is None:
            raise ValueError("Tokeniser must be fitted before encoding.")
        # jnp.digitize finds which bin each price belongs to
        # The result is an integer ID from 0 to vocab_size-1
        token_ids = jnp.digitize(prices, self.bins)
        return token_ids.astype(jnp.int32)

    def decode(self, token_ids: jnp.ndarray) -> jnp.ndarray:
        """Converts a sequence of token IDs back into approximate prices (bin centres)."""
        if self.bins is None:
            raise ValueError("Tokeniser must be fitted before decoding.")

        # Create bin centres
        bin_edges = jnp.concatenate([
            jnp.array([self.min_price]),
            self.bins,
            jnp.array([self.max_price])
        ])
        bin_centres = (bin_edges[:-1] + bin_edges[1:]) / 2.0

        # Map IDs to bin centres
        prices = bin_centres[token_ids]
        return prices
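
To make the quantisation error concrete, here is a plain-Python sketch of the same encode/decode arithmetic. It assumes bisect_right as a stand-in for the default behaviour of jnp.digitize, and uses the min/max values printed when the tokeniser is fitted later in this notebook; make_bins, encode and decode are hypothetical helper names.

```python
from bisect import bisect_right

def make_bins(lo: float, hi: float, vocab_size: int):
    """vocab_size - 1 evenly spaced thresholds from lo to hi, as in fit()."""
    n = vocab_size - 1
    return [lo + (hi - lo) * i / (n - 1) for i in range(n)]

def encode(price: float, bins) -> int:
    # Same rule as digitize's default: token i means bins[i-1] <= price < bins[i].
    return bisect_right(bins, price)

def decode(token: int, bins, lo: float, hi: float) -> float:
    # Bin centres are midpoints of consecutive edges, as in decode() above.
    edges = [lo] + bins + [hi]
    return (edges[token] + edges[token + 1]) / 2.0

lo, hi, vocab_size = 97.93, 102.02, 1024
bins = make_bins(lo, hi, vocab_size)
bin_width = bins[1] - bins[0]

prices = [97.93, 99.0, 100.123, 101.5, 102.02]
tokens = [encode(p, bins) for p in prices]
recovered = [decode(t, bins, lo, hi) for t in tokens]
max_error = max(abs(p - r) for p, r in zip(prices, recovered))
print(tokens, max_error, bin_width)
```

Two things follow from this layout: the round-trip error is at most about half a bin width, and the minimum fitted price maps to token 1, so token 0 is only reached by prices below the fitted range.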

In [57]:
# %%
@dataclasses.dataclass
class S5Forecaster(hk.Module):
    ssm: S5
    d_model: int
    n_layers: int
    vocab_size: int  # New: Number of tokens in our price vocabulary
    dropout_rate: float
    prenorm: bool = True
    istraining: bool = True
    name: Optional[str] = None

    def __post_init__(self):
        super(S5Forecaster, self).__post_init__(name=self.name)

        # 1. Token Embedding Layer
        self._embedding = hk.Embed(
            vocab_size=self.vocab_size,
            embed_dim=self.d_model
        )

        # 2. S5 Stack
        self._s5_stack = S5Stack(
            ssm=self.ssm,
            d_model=self.d_model,
            n_layers=self.n_layers,
            dropout_rate=self.dropout_rate,
            istraining=self.istraining,
            prenorm=self.prenorm,
        )

        # 3. Decoder Head (projects back to vocabulary)
        self._decoder = hk.Linear(self.vocab_size)

    def __call__(self, token_ids: jnp.ndarray) -> jnp.ndarray:
        """
        Input: sequence of token IDs, shape [SeqLen]
        Output: sequence of logits, shape [SeqLen, VocabSize]
        """
        # 1. Embed tokens
        x = self._embedding(token_ids)  # [SeqLen] -> [SeqLen, d_model]

        # 2. Process with S5
        x = self._s5_stack(x)           # [SeqLen, d_model] -> [SeqLen, d_model]

        # 3. Decode to logits
        logits = self._decoder(x)       # [SeqLen, d_model] -> [SeqLen, VocabSize]

        return logits

In [58]:
# %%
def generate_synthetic_stock_data(num_series, seq_length):
    """Generates a batch of noisy sine waves as dummy stock data."""
    t = jnp.linspace(0, 4 * jnp.pi, seq_length)
    # Generate a few different frequencies
    base_freqs = jnp.array([1.0, 2.0, 4.0])
    series = []
    key = jax.random.PRNGKey(42)

    for i in range(num_series):
        key, subkey1, subkey2, subkey3, subkey4 = jax.random.split(key, 5)

        # Random phase, amplitude, and noise
        phase = jax.random.uniform(subkey1) * 2 * jnp.pi
        amplitude = jax.random.uniform(subkey2) * 0.5 + 0.5
        noise = jax.random.normal(subkey3, shape=(seq_length,)) * 0.1

        # Pick a base frequency
        freq = jax.random.choice(subkey4, base_freqs)

        # Combine into a series and add a trend
        trend = jnp.linspace(0, 1.0, seq_length) * (jax.random.uniform(key)-0.5) * 2
        s = 100 + amplitude * jnp.sin(freq * t + phase) + noise + trend
        series.append(s)

    return jnp.stack(series)

def create_dataset(
    price_data: jnp.ndarray,
    tokeniser: StockPriceTokeniser,
    seq_len: int
):
    """
    Encodes price data and creates autoregressive (x, y) pairs.
    x = [t_0, t_1, ..., t_n-1]
    y = [t_1, t_2, ..., t_n]
    """
    # 1. Tokenise the entire dataset
    token_data = tokeniser.encode(price_data) # [NumSeries, TotalLength]

    # 2. Create overlapping sequences
    # We'll use a simple approach: just take the first seq_len+1 tokens
    # A real implementation would use sliding windows.

    input_seq_len = seq_len
    target_seq_len = seq_len

    # Ensure data is long enough
    assert token_data.shape[1] > seq_len, "Data is shorter than sequence length"

    # Input tokens: [t_0, ..., t_seq_len-1]
    inputs = token_data[:, :input_seq_len]

    # Target tokens (shifted by 1): [t_1, ..., t_seq_len]
    targets = token_data[:, 1:input_seq_len+1]

    return inputs, targets
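
The comment in create_dataset mentions sliding windows as the more realistic approach. A minimal sketch of that idea on a plain Python list follows; sliding_windows and its stride parameter are hypothetical and are not part of the notebook's pipeline.

```python
def sliding_windows(tokens, seq_len, stride=1):
    """Return (inputs, targets) pairs where targets are inputs shifted by one."""
    pairs = []
    # Each window needs seq_len + 1 tokens, so the last valid start is
    # len(tokens) - seq_len - 1.
    for start in range(0, len(tokens) - seq_len, stride):
        window = tokens[start:start + seq_len + 1]
        pairs.append((window[:-1], window[1:]))
    return pairs

tokens = list(range(10))
pairs = sliding_windows(tokens, seq_len=4, stride=3)
for inputs, targets in pairs:
    print(inputs, targets)
```

With TOTAL_DATA_LEN = SEQ_LEN + 50 and a stride of 1, this would give 50 overlapping windows per series instead of the single window used here.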

In [59]:
# %%
# --- Hyperparameters ---
STATE_SIZE: int = 128
D_MODEL: int = 64
N_LAYERS: int = 3
N_BLOCKS: int = 4
DROPOUT_RATE: float = 0.1
LEARNING_RATE: float = 1e-3
SEED = 0

# New Hyperparameters
VOCAB_SIZE: int = 1024       # Number of "price bins"
SEQ_LEN: int = 784           # Sequence length (kept from MNIST for comparison)
TOTAL_DATA_LEN: int = SEQ_LEN + 50 # Length of underlying data
BATCH_SIZE: int = 64
EPOCHS: int = 50

# %%
# --- Optimizer and Training State ---
class TrainingState(NamedTuple):
    params: hk.Params
    opt_state: optax.OptState
    rng_key: jnp.ndarray

optim = optax.adam(LEARNING_RATE)

# %%
# --- Loss Function ---
def loss_fn(
    params: hk.Params,
    rng_key: jnp.ndarray,
    model: hk.Transformed,
    inputs: jnp.ndarray,  # [Batch, SeqLen]
    targets: jnp.ndarray  # [Batch, SeqLen]
) -> jnp.ndarray:

    # Forward pass: Get logits [Batch, SeqLen, VocabSize]
    logits = model.apply(params, rng_key, inputs)

    # Calculate cross-entropy loss for sequences
    # This compares logits[t] with targets[t] for all t
    loss_per_token = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits,
        labels=targets
    )

    # Return the mean loss over the batch and sequence
    return jnp.mean(loss_per_token)

# --- Accuracy Function ---
def accuracy_fn(
    params: hk.Params,
    rng_key: jnp.ndarray,
    model: hk.Transformed,
    inputs: jnp.ndarray,
    targets: jnp.ndarray
) -> jnp.ndarray:

    logits = model.apply(params, rng_key, inputs)

    # Get the token ID with the highest probability
    predictions = jnp.argmax(logits, axis=-1)

    # Compare predictions to the actual next token
    accuracy = jnp.mean(predictions == targets)
    return accuracy
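
A useful reference point when reading the loss values is the cross-entropy of a model that assigns equal probability to every token. The plain-Python sketch below computes it for a vocabulary of 1024 tokens; cross_entropy is a hypothetical single-position helper, not the optax function used above.

```python
import math

def cross_entropy(logits, label):
    """Softmax cross-entropy for one position with an integer label."""
    m = max(logits)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_norm - logits[label]

vocab_size = 1024

# Uniform prediction: every token gets the same logit.
uniform_loss = cross_entropy([0.0] * vocab_size, label=7)

# A prediction that strongly favours the correct token.
confident_logits = [10.0 if i == 7 else 0.0 for i in range(vocab_size)]
confident_loss = cross_entropy(confident_logits, label=7)

print(uniform_loss, confident_loss, 1 / vocab_size)
```

The uniform loss is ln(1024), about 6.93, and chance accuracy is 1/1024, about 0.001. A loss above 6.93 therefore means the model is doing worse than a uniform guess on average, typically by being confidently wrong.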

In [60]:
# %%
# --- Update and Evaluate Functions ---
_Metrics = MutableMapping[str, Any]

@partial(jax.jit, static_argnums=(2,))
def update(
    state: TrainingState,
    batch: Tuple[jnp.ndarray, jnp.ndarray],
    model: hk.Transformed,
) -> Tuple[TrainingState, _Metrics]:

    inputs, targets = batch
    rng_key, next_rng_key = jax.random.split(state.rng_key)

    # Calculate loss and gradients
    (loss, gradients) = jax.value_and_grad(loss_fn)(
        state.params,
        rng_key,
        model,
        inputs,
        targets
    )

    updates, new_opt_state = optim.update(gradients, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)

    new_state = TrainingState(
        params=new_params,
        opt_state=new_opt_state,
        rng_key=next_rng_key,
    )

    metrics = {'loss': loss}
    return new_state, metrics


@partial(jax.jit, static_argnums=(2,))
def evaluate(
    state: TrainingState,
    batch: Tuple[jnp.ndarray, jnp.ndarray],
    model: hk.Transformed,
) -> _Metrics:

    inputs, targets = batch

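    # Note: forward builds the model with istraining=True and hk.dropout draws from state.rng_key,
    # so dropout is still active here and the evaluation metrics are slightly noisy.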
    loss = loss_fn(state.params, state.rng_key, model, inputs, targets)
    accuracy = accuracy_fn(state.params, state.rng_key, model, inputs, targets)

    metrics = {
        'loss': loss,
        'accuracy': accuracy
    }
    return metrics

In [61]:
# %%
# --- Set random seeds ---
torch.random.manual_seed(SEED)
key = jax.random.PRNGKey(SEED)
rng, init_rng = jax.random.split(key)

# %%
# --- Create Dataset ---
print("[*] Generating synthetic stock data...")
raw_price_data = generate_synthetic_stock_data(
    num_series=BATCH_SIZE*10, # 10 batches for training
    seq_length=TOTAL_DATA_LEN
)

print("[*] Fitting tokeniser...")
tokeniser = StockPriceTokeniser(vocab_size=VOCAB_SIZE)
tokeniser.fit(raw_price_data)

print("[*] Creating tokenised (x, y) pairs...")
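# For this example we slice the data into fixed-size batches by hand.
# A real pipeline would use a proper data loader and batching.

# Validation data. Note that generate_synthetic_stock_data uses a fixed PRNG key (42), so these
# series are identical to the first BATCH_SIZE training series rather than held-out data.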
val_price_data = generate_synthetic_stock_data(BATCH_SIZE, TOTAL_DATA_LEN)
val_inputs, val_targets = create_dataset(val_price_data, tokeniser, SEQ_LEN)
val_batch = (val_inputs, val_targets)

# For training data, we'll create multiple batches
train_price_data = generate_synthetic_stock_data(BATCH_SIZE * 10, TOTAL_DATA_LEN)
train_inputs, train_targets = create_dataset(train_price_data, tokeniser, SEQ_LEN)

num_batches = train_inputs.shape[0] // BATCH_SIZE
train_batches = []
for i in range(num_batches):
    start = i * BATCH_SIZE
    end = (i + 1) * BATCH_SIZE
    train_batches.append((train_inputs[start:end], train_targets[start:end]))


# %%
# --- Initialise Model ---

# Use a dummy batch for initialisation
dummy_batch = train_batches[0][0]

@hk.transform
def forward(x) -> jnp.ndarray:
    # Build the S5 component
    s5_ssm = S5(
        state_size=STATE_SIZE,
        d_model=D_MODEL,
        n_blocks=N_BLOCKS,
    )

    # Build the full forecaster model
    neural_net = S5Forecaster(
        ssm=s5_ssm,
        d_model=D_MODEL,
        n_layers=N_LAYERS,
        vocab_size=VOCAB_SIZE,
        dropout_rate=DROPOUT_RATE,
        istraining=True # We can toggle this for eval if needed
    )

    # vmap over the batch dimension
    return hk.vmap(neural_net, in_axes=0, split_rng=False)(x)

# %%
# Set state
initial_params = forward.init(init_rng, dummy_batch)
initial_opt_state = optim.init(initial_params)

state = TrainingState(
    params=initial_params,
    opt_state=initial_opt_state,
    rng_key=rng
)

[*] Generating synthetic stock data...
[*] Fitting tokeniser...
[*] Tokeniser fitted: min=97.93, max=102.02
[*] Creating tokenised (x, y) pairs...


In [62]:
# %%
# --- Training Loop ---
print("\n[*] Starting training...")
for epoch in range(EPOCHS):

    # --- Training ---
    epoch_losses = []
    for batch in tqdm(train_batches, desc=f"Epoch {epoch + 1}/{EPOCHS} [Train]"):
        state, metrics = update(state, batch, forward)
        epoch_losses.append(metrics['loss'])

    train_loss = jnp.mean(jnp.array(epoch_losses))

    # --- Validation ---
    val_metrics = evaluate(state, val_batch, forward)

    print(
        f"\n=>> Epoch {epoch + 1} Metrics ===\n"
        f"\tTrain Loss: {train_loss:.5f}\n"
        f"\t Val. Loss: {val_metrics['loss']:.5f} -- Val. Accuracy: {val_metrics['accuracy']:.4f}"
    )

print("[*] Training complete.")


[*] Starting training...


Epoch 1/50 [Train]: 100%|██████████| 10/10 [00:24<00:00,  2.48s/it]



=>> Epoch 1 Metrics ===
	Train Loss: 12.46808
	 Val. Loss: 10.19843 -- Val. Accuracy: 0.0009


Epoch 2/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 565.03it/s]



=>> Epoch 2 Metrics ===
	Train Loss: 9.30281
	 Val. Loss: 8.49791 -- Val. Accuracy: 0.0010


Epoch 3/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 586.57it/s]



=>> Epoch 3 Metrics ===
	Train Loss: 8.11612
	 Val. Loss: 7.74346 -- Val. Accuracy: 0.0014


Epoch 4/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 647.21it/s]



=>> Epoch 4 Metrics ===
	Train Loss: 7.54373
	 Val. Loss: 7.33622 -- Val. Accuracy: 0.0017


Epoch 5/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 490.85it/s]



=>> Epoch 5 Metrics ===
	Train Loss: 7.21346
	 Val. Loss: 7.08273 -- Val. Accuracy: 0.0021


Epoch 6/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 561.72it/s]



=>> Epoch 6 Metrics ===
	Train Loss: 7.00340
	 Val. Loss: 6.91889 -- Val. Accuracy: 0.0021


Epoch 7/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 610.38it/s]



=>> Epoch 7 Metrics ===
	Train Loss: 6.86186
	 Val. Loss: 6.80637 -- Val. Accuracy: 0.0024


Epoch 8/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 444.44it/s]



=>> Epoch 8 Metrics ===
	Train Loss: 6.75810
	 Val. Loss: 6.71610 -- Val. Accuracy: 0.0025


Epoch 9/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 655.49it/s]



=>> Epoch 9 Metrics ===
	Train Loss: 6.67960
	 Val. Loss: 6.64613 -- Val. Accuracy: 0.0027


Epoch 10/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 547.83it/s]



=>> Epoch 10 Metrics ===
	Train Loss: 6.61104
	 Val. Loss: 6.57914 -- Val. Accuracy: 0.0029


Epoch 11/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 624.51it/s]



=>> Epoch 11 Metrics ===
	Train Loss: 6.54758
	 Val. Loss: 6.51559 -- Val. Accuracy: 0.0029


Epoch 12/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 590.25it/s]



=>> Epoch 12 Metrics ===
	Train Loss: 6.47967
	 Val. Loss: 6.44455 -- Val. Accuracy: 0.0038


Epoch 13/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 449.72it/s]



=>> Epoch 13 Metrics ===
	Train Loss: 6.39771
	 Val. Loss: 6.34773 -- Val. Accuracy: 0.0046


Epoch 14/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 430.87it/s]



=>> Epoch 14 Metrics ===
	Train Loss: 6.29170
	 Val. Loss: 6.23048 -- Val. Accuracy: 0.0048


Epoch 15/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 422.85it/s]



=>> Epoch 15 Metrics ===
	Train Loss: 6.16541
	 Val. Loss: 6.09156 -- Val. Accuracy: 0.0056


Epoch 16/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 298.10it/s]



=>> Epoch 16 Metrics ===
	Train Loss: 6.02273
	 Val. Loss: 5.94997 -- Val. Accuracy: 0.0059


Epoch 17/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 601.22it/s]



=>> Epoch 17 Metrics ===
	Train Loss: 5.87932
	 Val. Loss: 5.79894 -- Val. Accuracy: 0.0072


Epoch 18/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 573.24it/s]



=>> Epoch 18 Metrics ===
	Train Loss: 5.73446
	 Val. Loss: 5.65598 -- Val. Accuracy: 0.0078


Epoch 19/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 620.96it/s]



=>> Epoch 19 Metrics ===
	Train Loss: 5.60712
	 Val. Loss: 5.53742 -- Val. Accuracy: 0.0082


Epoch 20/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 662.98it/s]



=>> Epoch 20 Metrics ===
	Train Loss: 5.49701
	 Val. Loss: 5.44088 -- Val. Accuracy: 0.0085


Epoch 21/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 590.86it/s]



=>> Epoch 21 Metrics ===
	Train Loss: 5.40769
	 Val. Loss: 5.36038 -- Val. Accuracy: 0.0087


Epoch 22/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 614.86it/s]



=>> Epoch 22 Metrics ===
	Train Loss: 5.33156
	 Val. Loss: 5.28914 -- Val. Accuracy: 0.0096


Epoch 23/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 577.59it/s]



=>> Epoch 23 Metrics ===
	Train Loss: 5.26743
	 Val. Loss: 5.23122 -- Val. Accuracy: 0.0099


Epoch 24/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 400.08it/s]



=>> Epoch 24 Metrics ===
	Train Loss: 5.21280
	 Val. Loss: 5.18003 -- Val. Accuracy: 0.0108


Epoch 25/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 578.57it/s]



=>> Epoch 25 Metrics ===
	Train Loss: 5.16546
	 Val. Loss: 5.13222 -- Val. Accuracy: 0.0122


Epoch 26/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 654.32it/s]



=>> Epoch 26 Metrics ===
	Train Loss: 5.12275
	 Val. Loss: 5.09496 -- Val. Accuracy: 0.0111


Epoch 27/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 590.89it/s]



=>> Epoch 27 Metrics ===
	Train Loss: 5.08584
	 Val. Loss: 5.06363 -- Val. Accuracy: 0.0128


Epoch 28/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 633.93it/s]



=>> Epoch 28 Metrics ===
	Train Loss: 5.05469
	 Val. Loss: 5.03290 -- Val. Accuracy: 0.0119


Epoch 29/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 638.12it/s]



=>> Epoch 29 Metrics ===
	Train Loss: 5.02434
	 Val. Loss: 5.00242 -- Val. Accuracy: 0.0129


Epoch 30/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 612.64it/s]



=>> Epoch 30 Metrics ===
	Train Loss: 4.99829
	 Val. Loss: 4.98053 -- Val. Accuracy: 0.0130


Epoch 31/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 628.67it/s]



=>> Epoch 31 Metrics ===
	Train Loss: 4.97569
	 Val. Loss: 4.96022 -- Val. Accuracy: 0.0137


Epoch 32/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 593.54it/s]



=>> Epoch 32 Metrics ===
	Train Loss: 4.95492
	 Val. Loss: 4.94165 -- Val. Accuracy: 0.0138


Epoch 33/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 639.08it/s]



=>> Epoch 33 Metrics ===
	Train Loss: 4.93465
	 Val. Loss: 4.92558 -- Val. Accuracy: 0.0129


Epoch 34/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 568.00it/s]



=>> Epoch 34 Metrics ===
	Train Loss: 4.91839
	 Val. Loss: 4.90571 -- Val. Accuracy: 0.0139


Epoch 35/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 609.19it/s]



=>> Epoch 35 Metrics ===
	Train Loss: 4.90259
	 Val. Loss: 4.88869 -- Val. Accuracy: 0.0141


Epoch 36/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 632.72it/s]



=>> Epoch 36 Metrics ===
	Train Loss: 4.88869
	 Val. Loss: 4.87734 -- Val. Accuracy: 0.0145


Epoch 37/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 475.37it/s]



=>> Epoch 37 Metrics ===
	Train Loss: 4.87473
	 Val. Loss: 4.86327 -- Val. Accuracy: 0.0140


Epoch 38/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 689.35it/s]



=>> Epoch 38 Metrics ===
	Train Loss: 4.86278
	 Val. Loss: 4.85276 -- Val. Accuracy: 0.0149


Epoch 39/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 633.62it/s]



=>> Epoch 39 Metrics ===
	Train Loss: 4.85199
	 Val. Loss: 4.84259 -- Val. Accuracy: 0.0153


Epoch 40/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 651.23it/s]



=>> Epoch 40 Metrics ===
	Train Loss: 4.84168
	 Val. Loss: 4.83211 -- Val. Accuracy: 0.0144


Epoch 41/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 590.91it/s]



=>> Epoch 41 Metrics ===
	Train Loss: 4.83290
	 Val. Loss: 4.82312 -- Val. Accuracy: 0.0161


Epoch 42/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 511.23it/s]



=>> Epoch 42 Metrics ===
	Train Loss: 4.82426
	 Val. Loss: 4.81643 -- Val. Accuracy: 0.0157


Epoch 43/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 372.04it/s]



=>> Epoch 43 Metrics ===
	Train Loss: 4.81727
	 Val. Loss: 4.80711 -- Val. Accuracy: 0.0168


Epoch 44/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 415.82it/s]



=>> Epoch 44 Metrics ===
	Train Loss: 4.80905
	 Val. Loss: 4.80202 -- Val. Accuracy: 0.0155


Epoch 45/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 393.35it/s]



=>> Epoch 45 Metrics ===
	Train Loss: 4.80297
	 Val. Loss: 4.79927 -- Val. Accuracy: 0.0158


Epoch 46/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 547.61it/s]



=>> Epoch 46 Metrics ===
	Train Loss: 4.79553
	 Val. Loss: 4.78919 -- Val. Accuracy: 0.0163


Epoch 47/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 630.99it/s]



=>> Epoch 47 Metrics ===
	Train Loss: 4.78990
	 Val. Loss: 4.78319 -- Val. Accuracy: 0.0170


Epoch 48/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 606.38it/s]



=>> Epoch 48 Metrics ===
	Train Loss: 4.78469
	 Val. Loss: 4.78068 -- Val. Accuracy: 0.0170


Epoch 49/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 607.40it/s]



=>> Epoch 49 Metrics ===
	Train Loss: 4.77954
	 Val. Loss: 4.77314 -- Val. Accuracy: 0.0162


Epoch 50/50 [Train]: 100%|██████████| 10/10 [00:00<00:00, 568.97it/s]



=>> Epoch 50 Metrics ===
	Train Loss: 4.77450
	 Val. Loss: 4.76984 -- Val. Accuracy: 0.0173
[*] Training complete.
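
To put the final numbers in the log above into perspective, a loss can be converted to a perplexity and compared against a uniform-guess baseline. The sketch below assumes the logged loss is a mean softmax cross-entropy in nats, which this section does not confirm. The helper names and the vocabulary size of 1024 are hypothetical; substitute the notebook's own `VOCAB_SIZE`.

```python
import math

def perplexity(mean_cross_entropy_nats: float) -> float:
    """Effective number of equally likely tokens implied by a cross-entropy in nats."""
    return math.exp(mean_cross_entropy_nats)

def uniform_baseline_loss(vocab_size: int) -> float:
    """Cross-entropy (nats) of a model that assigns equal probability to every token."""
    return math.log(vocab_size)

# Final validation loss copied from the epoch 50 log line.
final_val_loss = 4.76984
final_ppl = perplexity(final_val_loss)

# Hypothetical vocabulary size, used only for illustration.
example_vocab_size = 1024
baseline = uniform_baseline_loss(example_vocab_size)

print(f"perplexity: {final_ppl:.1f}, uniform baseline loss: {baseline:.3f}")
```

Under that assumption, a validation loss well below the uniform baseline means the model has concentrated probability mass on a smaller set of plausible tokens, even while exact-match accuracy stays low.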


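The forecasting cell that follows picks each next token greedily with `jnp.argmax`, and its comments mention `jax.random.categorical` as an alternative. A minimal sketch of that alternative is shown here. The function name `sample_next_token` and the `temperature` parameter are illustrative assumptions and not part of the notebook.

```python
import jax
import jax.numpy as jnp

def sample_next_token(rng_key, last_token_logits, temperature=1.0):
    """Draw one token id from softmax(last_token_logits / temperature).

    Lower temperatures sharpen the distribution towards the greedy choice;
    higher temperatures flatten it.
    """
    scaled_logits = last_token_logits / temperature
    token = jax.random.categorical(rng_key, scaled_logits, axis=-1)
    return token.astype(jnp.int32)

# Toy check: one logit dominates, so sampling should pick index 2.
key = jax.random.PRNGKey(0)
toy_logits = jnp.array([0.0, 0.0, 50.0, 0.0])
toy_token = sample_next_token(key, toy_logits)
```

If this were used inside `scan_body`, the key would need to be split before sampling so that each step draws from a fresh subkey, rather than discarding the second key as the greedy version can afford to do.
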
In [63]:
# %%
@partial(jax.jit, static_argnums=(0, 2))
def generate_forecast(
    model: hk.Transformed,
    params: hk.Params,
    steps_to_forecast: int,
    prompt_tokens: jnp.ndarray, # Shape [SeqLen]
    rng_key: jnp.ndarray
):
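    """
    Autoregressively generates `steps_to_forecast` token predictions.

    Note: `model` is not used in the body; the inference network is rebuilt
    below with istraining=False and applied to `params`.
    Returns the predicted token IDs, one per forecast step.
    """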

    @hk.transform
    def forward_generate(x) -> jnp.ndarray:
        # Build the S5 component
        s5_ssm = S5(
            state_size=STATE_SIZE,
            d_model=D_MODEL,
            n_blocks=N_BLOCKS,
        )
        # Build the full forecaster model, ensuring istraining=False
        neural_net = S5Forecaster(
            ssm=s5_ssm,
            d_model=D_MODEL,
            n_layers=N_LAYERS,
            vocab_size=VOCAB_SIZE,
            dropout_rate=0.0, # No dropout for inference
            istraining=False,
        )
        # NOTE: No vmap, we process a single sequence
        return neural_net(x)


    def scan_body(carry, _):
        """
        The body of the scan function.
        'carry' contains (current_tokens, rng_key)
        """
        current_tokens, rng_key = carry

        # Apply the model
        # We only need logits for the *very last* token to predict the next one
        logits = forward_generate.apply(params, rng_key, current_tokens)
        last_token_logits = logits[-1] # Shape [VocabSize]

        # Get next token prediction (greedy, so the forecast is deterministic)
        # For more varied forecasts, you could sample with jax.random.categorical instead
        next_token = jnp.argmax(last_token_logits, axis=-1)
        next_token = next_token.astype(jnp.int32)

        # Create the new token sequence
        # We roll the window: drop the first token, append the new one
        new_tokens = jnp.roll(current_tokens, -1)
        new_tokens = new_tokens.at[-1].set(next_token)

        # Split RNG for next step (though not strictly needed for argmax)
        rng_key, _ = jax.random.split(rng_key)

        # Return new state and the token we just predicted
        return (new_tokens, rng_key), next_token

    # Initial state for the scan
    initial_carry = (prompt_tokens, rng_key)

    # Run the scan to generate 'steps_to_forecast' new tokens
    _, predicted_token_sequence = jax.lax.scan(
        scan_body, initial_carry, None, length=steps_to_forecast
    )

    return predicted_token_sequence

# %%
# --- Run a Forecast ---

# 1. Get a prompt from the validation data
# We'll use the first sequence from our validation batch
prompt_price_data = val_price_data[0, :SEQ_LEN]
prompt_token_data = tokeniser.encode(prompt_price_data)

# 2. Define how many steps to forecast
FORECAST_HORIZON = 50

print(f"[*] Generating {FORECAST_HORIZON}-step forecast...")

# 3. Generate the sequence of *future token IDs*
# We need to use the *training* state's params, but with a new model
# transform set to istraining=False.
predicted_tokens = generate_forecast(
    forward,  # Pass the hk.transform, not the instance (static arg; the inference net is rebuilt inside)
    state.params,
    FORECAST_HORIZON,
    prompt_token_data,
    state.rng_key
)

# 4. Decode the tokens back into prices
predicted_prices = tokeniser.decode(predicted_tokens)

# 5. Get the "ground truth" prices for comparison
ground_truth_prices = val_price_data[0, SEQ_LEN : SEQ_LEN + FORECAST_HORIZON]

print("[*] Forecast complete.")
print("\n--- Forecast vs. Ground Truth ---")
for i in range(FORECAST_HORIZON):
    print(f"Step {i+1}: Predicted={predicted_prices[i]:.2f}, Actual={ground_truth_prices[i]:.2f}")

[*] Generating 50-step forecast...
[*] Forecast complete.

--- Forecast vs. Ground Truth ---
Step 1: Predicted=100.13, Actual=100.16
Step 2: Predicted=100.10, Actual=99.96
Step 3: Predicted=100.13, Actual=99.98
Step 4: Predicted=100.10, Actual=99.96
Step 5: Predicted=100.13, Actual=99.96
Step 6: Predicted=100.03, Actual=99.84
Step 7: Predicted=100.00, Actual=99.84
Step 8: Predicted=99.95, Actual=99.95
Step 9: Predicted=99.95, Actual=99.95
Step 10: Predicted=99.95, Actual=99.96
Step 11: Predicted=99.85, Actual=99.87
Step 12: Predicted=99.85, Actual=99.94
Step 13: Predicted=99.85, Actual=99.69
Step 14: Predicted=99.85, Actual=99.73
Step 15: Predicted=99.85, Actual=99.47
Step 16: Predicted=99.78, Actual=99.75
Step 17: Predicted=99.77, Actual=99.89
Step 18: Predicted=99.73, Actual=99.67
Step 19: Predicted=99.69, Actual=99.67
Step 20: Predicted=99.64, Actual=99.39
Step 21: Predicted=99.62, Actual=99.57
Step 22: Predicted=99.62, Actual=99.58
Step 23: Predicted=99.62, Actual=99.63
Step 24: Pr