
# S5-JAX Predictor

The goal of this exercise is to adapt an annotated implementation of S5 (https://github.com/JPGoodale/annotated-s5), originally written for classification tasks, to a regression task: predicting the next value of a stock-price series.

# Runtime Setup
Make sure the runtime type is set to GPU so that JAX can take advantage of GPU parallelisation.

In [23]:
!pip install dm-haiku
!pip install hippox



In [54]:
# Let's go ahead and import hippox and the core JAX libraries we will be using:
import jax
import numpy as np
import jax.numpy as jnp
import haiku as hk
from hippox.main import Hippo
from typing import Optional
# New imports for stock data
import yfinance as yf
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import torch  # Make sure torch is imported if not already
from torch.utils.data import DataLoader, Dataset as TorchDataset
from typing import NamedTuple
import matplotlib.pyplot as plt

First we'll define some helper functions for discretization and timescale initialization. The SSM equation is continuous-time, so it must be discretized before it can be unrolled as a linear recurrence, like a standard RNN.

In [26]:
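# Here we use the zero-order hold (ZOH) method for its simplicity, with A, B and delta_t denoting the
# state matrix, input matrix and discretization step size respectively. Because A is diagonal (stored
# as a vector), the identity matrix is represented by a vector of ones and all operations are elementwise.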

def discretize(A, B, delta_t):
    Identity = jnp.ones(A.shape[0])
    _A = jnp.exp(A*delta_t)
    _B = (1/A * (_A-Identity)) * B
    return _A, _B

# Initializer for the trainable timescale parameter: samples log(delta_t) uniformly in [log(dt_min), log(dt_max)).
def log_step_initializer(dt_min=0.001, dt_max=0.1):
    def init(shape, dtype):
        uniform = hk.initializers.RandomUniform()
        return uniform(shape, dtype)*(jnp.log(dt_max) - jnp.log(dt_min)) + jnp.log(dt_min)
    return init

# Taken directly from https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/recurrent.py
def add_batch(nest, batch_size: Optional[int]):
    broadcast = lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape)
    return jax.tree_util.tree_map(broadcast, nest)
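As a sanity check on the two helpers above, here is a small sketch written in NumPy rather than JAX for portability. It makes two checks, and the function names in it are hypothetical. First, it compares the diagonal ZOH formula used in `discretize`, $\bar{B} = A^{-1}(e^{A\Delta} - I)B$, with a direct midpoint-rule approximation of the ZOH integral $\int_0^{\Delta} e^{A\tau}\,d\tau\,B$. Second, it reproduces the sampling rule of `log_step_initializer`, which draws $\log\Delta$ uniformly and therefore spreads $\Delta$ evenly across decades between `dt_min` and `dt_max`.

```python
import numpy as np

# --- ZOH check (same formula as discretize, for a diagonal A given by its eigenvalues) ---
def zoh_diagonal(lam, b, dt):
    a_bar = np.exp(lam * dt)
    b_bar = (a_bar - 1.0) / lam * b
    return a_bar, b_bar

def zoh_by_quadrature(lam, b, dt, n_steps=100_000):
    """Midpoint-rule approximation of B_bar = (integral_0^dt exp(lam * tau) dtau) * b."""
    taus = (np.arange(n_steps) + 0.5) * (dt / n_steps)
    integral = np.exp(np.outer(taus, lam)).sum(axis=0) * (dt / n_steps)
    return integral * b

lam = np.array([-0.5 + 1.0j, -0.5 - 1.0j, -2.0 + 0.0j])  # diagonal of a stable A
b = np.array([1.0, 1.0, 0.5])
a_bar, b_bar = zoh_diagonal(lam, b, dt=0.1)
b_ref = zoh_by_quadrature(lam, b, dt=0.1)
print("max |B_bar - quadrature|:", np.max(np.abs(b_bar - b_ref)))

# --- Log-uniform timescales (same rule as log_step_initializer) ---
def sample_log_steps(n, dt_min=0.001, dt_max=0.1, seed=0):
    rng = np.random.default_rng(seed)
    u = rng.uniform(size=n)  # U[0, 1), like hk.initializers.RandomUniform()'s defaults
    return u * (np.log(dt_max) - np.log(dt_min)) + np.log(dt_min)

dt = np.exp(sample_log_steps(10_000))
# [0.001, 0.1) spans two decades, so roughly half the samples should fall below 0.01.
print("fraction of steps below 0.01:", np.mean(dt < 0.01))
```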

The linear SSM equations are as follows:
$$ x'(t) = Ax(t) + Bu(t) $$
$$ y(t) = Cx(t) + Du(t) $$
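Applying the zero-order hold from `discretize` with step size $\Delta$ turns these equations into the discrete linear recurrence below. For a diagonal $A$, the matrix exponential is just the elementwise exponential of the diagonal.

$$ \bar{A} = e^{A\Delta}, \qquad \bar{B} = A^{-1}\left(e^{A\Delta} - I\right)B $$
$$ x_k = \bar{A}x_{k-1} + \bar{B}u_k, \qquad y_k = Cx_k + Du_k $$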

We will now implement it as a recurrent Haiku module:

In [27]:
class LinearSSM(hk.RNNCore):
    def __init__(self, state_size: int, name: Optional[str] = None):
        super(LinearSSM, self).__init__(name=name)
        # We won't get into the basis measure families here; just note that they determine the
        # type of orthogonal polynomial basis we initialize with. The scaled Legendre measure (LegS)
        # introduced in the original HiPPO paper is pretty much the standard initialization and is
        # what is used in the main experiments of the S5 paper. Note also that the Hippo class uses
        # the diagonal representation of the state matrix by default. This has become the standard in
        # neural SSMs since it was shown to be as effective as the diagonal-plus-low-rank
        # representation in https://arxiv.org/abs/2203.14343, and it was then formalized in
        # https://arxiv.org/abs/2206.11893.

        _hippo = Hippo(state_size=state_size, basis_measure='legs')
        # Must be called for parameters to be initialized
        _hippo()

        # We register the real and imaginary parts of the state matrix A as separate real-valued
        # parameters so that each receives its own gradient. They are recombined into a complex
        # matrix and then discretized; backpropagation simply flows through this transformation
        # back to the lambda real and imaginary parameters (lambda is just what we call the
        # diagonalized state matrix).

        self._lambda_real = hk.get_parameter(
            'lambda_real',
            shape=[state_size,],
            init=_hippo.lambda_initializer('real')
        )
        self._lambda_imag = hk.get_parameter(
            'lambda_imaginary',
            shape=[state_size,],
            init=_hippo.lambda_initializer('imaginary')
        )
        self._A = self._lambda_real + 1j * self._lambda_imag

        # For now, these initializations of the input and output matrices B and C match the S4D
        # parameterization for demonstration purposes; we will implement the S5 versions later.

        self._B = hk.get_parameter(
            'B',
            shape=[state_size,],
            init=_hippo.b_initializer()
        )
        self._C = hk.get_parameter(
            'C',
            shape=[state_size, 2],
            init=hk.initializers.RandomNormal(stddev=0.5**0.5)
        )
        self._output_matrix = self._C[..., 0] + 1j * self._C[..., 1]

        # This feed-through matrix basically acts as a residual connection.
        self._D = hk.get_parameter(
            'D',
            [1,],
            init=jnp.ones,
        )

        self._delta_t = hk.get_parameter(
            'delta_t',
            shape=[1,],
            init=log_step_initializer()
        )
        timescale = jnp.exp(self._delta_t)

        self._state_matrix, self._input_matrix = discretize(self._A, self._B, timescale)

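    def __call__(self, inputs, prev_state):
        # `inputs` is the scalar input u_k (or a [batch] vector) and `prev_state` is x_{k-1}
        # with shape [state_size] (or [batch, state_size]).
        u = jnp.asarray(inputs)
        # The state matrix is diagonal (stored as a vector), so the update is elementwise.
        new_state = self._state_matrix * prev_state + self._input_matrix * u[..., jnp.newaxis]
        y_s = new_state @ self._output_matrix
        out = y_s.real + self._D * u
        return out, new_state

    def initial_state(self, batch_size: Optional[int] = None):
        # The state is complex-valued because the diagonalized state matrix is complex.
        state = jnp.zeros(self._state_matrix.shape, dtype=jnp.complex64)
        if batch_size is not None:
            state = add_batch(state, batch_size)
        return state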

You may notice that this looks a lot like a vanilla RNN cell, just with our special parameterization and without any activations, hence a linear recurrence. I have made it a subclass of Haiku's `hk.RNNCore` abstract base class so that it can be unrolled with either `hk.dynamic_unroll` or `hk.static_unroll`, like any other recurrent module. However, if you are familiar with any of the S4 models, you may notice that something crucial is missing here: the convolutional representation.

One of the key contributions of the S4 paper was showing that the SSM can be computed either as a linear recurrence, as above, for efficient inference, or as a global convolution for much faster training. That paper and its successors then present various kernels for computing this convolution efficiently with Fast Fourier Transforms, greatly improving the model's computational efficiency. So why have we omitted them? The S5 architecture we are about to explore simplifies all of this by using a purely recurrent representation in both training and inference. It does this with a parallel scan, which, like a convolution, processes the whole sequence in parallel. From the paper:

    "We use parallel scans to efficiently compute the states of a discretized linear SSM. Given a binary associative operator • (i.e. (a • b) • c = a • (b • c)) and a sequence of L elements [a1, a2, ..., aL], the scan operation (sometimes referred to as all-prefix-sum) returns the sequence [a1, (a1 • a2), ..., (a1 • a2 • ... • aL)]."
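To make the recurrence/convolution duality described above concrete, start from $x_{-1} = 0$ and unroll $x_k = \bar{A}x_{k-1} + \bar{B}u_k$. This gives $y_k = \sum_{j=0}^{k} C\bar{A}^{j}\bar{B}\,u_{k-j} + Du_k$, a causal convolution of the input with the kernel $K = (C\bar{B}, C\bar{A}\bar{B}, C\bar{A}^2\bar{B}, \dots)$. The NumPy sketch below checks that both views produce the same output. It uses a single SISO SSM with a diagonal $\bar{A}$ and arbitrary illustrative values. Note that it builds the kernel directly rather than using the FFT-based kernels from the S4 papers.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 4, 16                                            # state size, sequence length
a_bar = np.exp((-0.3 + 1j * rng.normal(size=N)) * 0.1)  # stable diagonal A-bar
b_bar = rng.normal(size=N) + 0j
c = rng.normal(size=N) + 1j * rng.normal(size=N)
d = 0.5
u = rng.normal(size=L)

# Recurrent view: one step at a time, x_k = A_bar x_{k-1} + B_bar u_k, y_k = Re(C x_k) + D u_k.
x = np.zeros(N, dtype=complex)
y_rec = np.empty(L)
for k in range(L):
    x = a_bar * x + b_bar * u[k]
    y_rec[k] = (c @ x).real + d * u[k]

# Convolutional view: kernel K_j = Re(C A_bar^j B_bar), then a causal convolution with u.
kernel = np.array([(c @ (a_bar**j * b_bar)).real for j in range(L)])
y_conv = np.convolve(u, kernel)[:L] + d * u

print("max difference:", np.max(np.abs(y_rec - y_conv)))
```

The recurrent form costs O(L) sequential steps, while the convolutional form can be evaluated in parallel, which is the trade-off the S4 kernels and the S5 parallel scan both address.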

Let's see what this looks like in code, taken straight from the original authors' implementation:

In [28]:
@jax.vmap
def binary_operator(q_i, q_j):
    A_i, b_i = q_i
    A_j, b_j = q_j
    return A_j * A_i, A_j * b_i + b_j

def parallel_scan(A, B, C, inputs):
    A_elements = A * jnp.ones((inputs.shape[0], A.shape[0]))
    Bu_elements = jax.vmap(lambda u: B @ u)(inputs)
    # JAX's built-in jax.lax.associative_scan computes every prefix result of an associative
    # operator in parallel (with logarithmic depth), instead of strictly one step after another
    # as in a normal recurrent unroll, which is exactly what the operator described in the paper needs.
    _, xs = jax.lax.associative_scan(binary_operator, (A_elements, Bu_elements))
    return jax.vmap(lambda x: (C @ x).real)(xs)
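Because `binary_operator` is associative, the prefix results can be combined in a tree rather than strictly left to right. The NumPy sketch below illustrates this with the simple Hillis–Steele scheme, which needs only ⌈log₂ L⌉ rounds of vectorized work. It then checks the result against a sequential loop for a diagonal SSM. This shows the idea only: `jax.lax.associative_scan` has its own implementation, and the helper names here are hypothetical.

```python
import numpy as np

def combine(earlier, later):
    """Same associative operator as binary_operator above, applied elementwise."""
    a_i, b_i = earlier
    a_j, b_j = later
    return a_j * a_i, a_j * b_i + b_j

def hillis_steele_scan(a, b):
    """Inclusive scan of (a, b) pairs with shape [L, N] in ceil(log2(L)) vectorized rounds."""
    a, b = a.copy(), b.copy()
    offset = 1
    while offset < a.shape[0]:
        # Every position k >= offset absorbs the partial result that ends at k - offset.
        new_a, new_b = combine((a[:-offset], b[:-offset]), (a[offset:], b[offset:]))
        a[offset:], b[offset:] = new_a, new_b
        offset *= 2
    return a, b

rng = np.random.default_rng(0)
L, N = 10, 3
a_bar = np.exp(-0.1 + 0.5j * rng.normal(size=N))
bu = rng.normal(size=(L, N)) + 0j                # stands in for B_bar @ u_k at each step
a_elems = np.broadcast_to(a_bar, (L, N)).copy()  # like A_elements in parallel_scan

_, xs_scan = hillis_steele_scan(a_elems, bu)

# Sequential reference: x_k = A_bar * x_{k-1} + B_bar u_k, starting from x_{-1} = 0.
xs_loop = np.zeros((L, N), dtype=complex)
x = np.zeros(N, dtype=complex)
for k in range(L):
    x = a_bar * x + bu[k]
    xs_loop[k] = x

print("max difference:", np.max(np.abs(xs_scan - xs_loop)))
```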

It's that simple! In the original S4, we would have had to apply an independent single-input, single-output (SISO) SSM to each feature of the input sequence, as in this excerpt from Sasha Rush's Flax implementation:



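```python
# Excerpt: `SSMLayer` is the single-feature (SISO) Flax module defined in that implementation.
import flax.linen

def cloneLayer(layer):
    # Vectorize the module over the feature axis (axis 1), giving each feature its own parameters.
    return flax.linen.vmap(
        layer,
        in_axes=1,
        out_axes=1,
        variable_axes={"params": 1, "cache": 1, "prime": 1},
        split_rngs={"params": True},
    )

SSMLayer = cloneLayer(SSMLayer)
```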



Whereas in the S5 we process the entire sequence in one multi-input, multi-output (MIMO) layer.
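The difference shows up directly in the parameter shapes. The NumPy sketch below uses illustrative sizes and names and shows one recurrence step for each design:

- **S4-style layer.** It keeps H independent SISO SSMs, each with its own N-dimensional state. Feature h only ever sees its own input channel.
- **S5-style layer.** It keeps a single P-dimensional state shared by all H features, with a P×H input matrix B and an H×P output matrix C. The state therefore mixes information across features.

```python
import numpy as np

rng = np.random.default_rng(0)
H, N, P = 8, 16, 32        # features, per-feature SISO state size, shared MIMO state size
u_k = rng.normal(size=H)   # the input vector at one time step

# S4-style: H independent SISO SSMs (one row of diagonal parameters per feature).
a_siso = np.exp(-0.05 + 1j * rng.normal(size=(H, N)))
b_siso = rng.normal(size=(H, N)) + 0j
c_siso = rng.normal(size=(H, N)) + 0j
x_siso = np.zeros((H, N), dtype=complex)
x_siso = a_siso * x_siso + b_siso * u_k[:, None]  # feature h only sees u_k[h]
y_siso = (c_siso * x_siso).sum(axis=1).real       # shape [H]

# S5-style: one MIMO SSM with a shared diagonal state.
a_mimo = np.exp(-0.05 + 1j * rng.normal(size=P))
b_mimo = rng.normal(size=(P, H)) + 0j
c_mimo = rng.normal(size=(H, P)) + 0j
x_mimo = np.zeros(P, dtype=complex)
x_mimo = a_mimo * x_mimo + b_mimo @ u_k           # every state sees every feature
y_mimo = (c_mimo @ x_mimo).real                   # shape [H]

print(y_siso.shape, y_mimo.shape)  # (8,) (8,)
```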

Let's now rewrite our module as a full S5 layer using this new method. We will add a few extra optional arguments and change some of the parameterization to match the original paper; we'll walk through the reasons for these changes below.

In [29]:
# First we add a new helper function for the timescale initialization. It calls the previous
# log_step_initializer once per row and stacks the results, so that each state dimension gets its own timescale.

def init_log_steps(shape, dtype):
    H = shape[0]
    log_steps = []
    for i in range(H):
        log_step = log_step_initializer()(shape=(1,), dtype=dtype)
        log_steps.append(log_step)

    return jnp.array(log_steps)

In [30]:
# We will also rewrite our discretization for the MIMO context
def discretize(A, B, delta_t):
    Identity = jnp.ones(A.shape[0])
    _A = jnp.exp(A*delta_t)
    _B = (1/A * (_A-Identity))[..., None] * B
    return _A, _B

In [31]:
class S5(hk.Module):
    def __init__(self,
                 state_size: int,

                 # Now that we're MIMO we'll need to know the number of input features, commonly
                 # referred to as the dimension of the model.
                 d_model: int,

                 # We must also now specify the number of blocks that we will split our matrices
                 # into due to the MIMO context.
                 n_blocks: int,

                 # Short for conjugate symmetry. The HiPPO state matrix is real, so its complex
                 # eigenvalues come in conjugate pairs; we can keep only one member of each pair
                 # (halving the state size) and recover the output by taking twice the real part.
                 # This is not new to the S5, we just didn't mention it above.
                 conj_sym: bool = True,

                 # Another standard SSM argument that we omitted above for simplicity's sake:
                 # this forces the real parts of the eigenvalues to be negative for better
                 # stability, especially in autoregressive tasks.
                 clip_eigns: bool = False,

                 # Like other RNNs, the S5 can be run in both directions if need be.
                 bidirectional: bool = False,

                 # Rescales delta_t for varying input resolutions, such as different audio
                 # sampling rates.
                 step_rescale: float = 1.0,
                 name: Optional[str] = None
    ):
        super(S5, self).__init__(name=name)
        self.conj_sym = conj_sym
        self.bidirectional = bidirectional

        # Note that the Hippo class takes conj_sym as an argument and will automatically halve
        # the state size provided at initialization, which is why we need a local state size
        # that matches it for the shape argument in hk.get_parameter().

        if conj_sym:
            _state_size = state_size // 2
        else:
            _state_size = state_size

        # With block_diagonal set as True and the number of blocks provided, our Hippo class
        # will automatically handle this change of structure.

        _hippo = Hippo(
            state_size=state_size,
            basis_measure='legs',
            conj_sym=conj_sym,
            block_diagonal=True,
            n_blocks=n_blocks,
        )
        _hippo()

        self._lambda_real = hk.get_parameter(
            'lambda_real',
            [_state_size],
            init=_hippo.lambda_initializer('real')
        )
        self._lambda_imag = hk.get_parameter(
            'lambda_imaginary',
            [_state_size],
            init=_hippo.lambda_initializer('imaginary')
        )
        if clip_eigns:
            # Clip the real parts so every eigenvalue stays strictly in the left half-plane.
            self._A = jnp.clip(self._lambda_real, None, -1e-4) + 1j * self._lambda_imag
        else:
            self._A = self._lambda_real + 1j * self._lambda_imag

        # If you recall, I mentioned above that we automatically use a diagonalized version of
        # the HiPPO state matrix rather than the original dense one, which is expensive to compute
        # with. Here is a little more detail on how this diagonal representation is derived, as it
        # matters for how we initialize the input and output matrices. The diagonalization relies
        # on the equivalence of SSM parameters under a change of basis:
        # (A, B, C) ~ (V⁻¹AV, V⁻¹B, CV), where V is the matrix of eigenvectors of the original A
        # and V⁻¹ is its inverse. The Hippo class has already computed V⁻¹AV, but we still have to
        # transform B and C, which we do below with the eigenvector_transform method. First we
        # initialize B with a LeCun-normal style initializer (hk.initializers.VarianceScaling with
        # its default arguments) and C with a truncated normal. The original repository provides a
        # few other options for C, but to keep things simple we use just one here.

        b_init = hk.initializers.VarianceScaling()
        b_shape = [state_size, d_model]
        b_init = b_init(b_shape, dtype=jnp.complex64)
        self._B = hk.get_parameter(
            'B',
            [_state_size, d_model, 2],
            init=_hippo.eigenvector_transform(b_init,  concatenate=True),
        )
        B = self._B[..., 0] + 1j * self._B[..., 1]

        c_init = hk.initializers.TruncatedNormal()
        c_shape = [d_model, state_size, 2]
        c_init = c_init(c_shape, dtype=jnp.complex64)
        self._C = hk.get_parameter(
            'C',
            [d_model, _state_size, 2],
            init=_hippo.eigenvector_transform(c_init, inverse=False, concatenate=True),
        )
        # We need two output heads if bidirectional is True.
        if bidirectional:
            self._C2 = hk.get_parameter(
                'C2',
                [d_model, _state_size, 2],
                init=_hippo.eigenvector_transform(c_init, inverse=False, concatenate=True),
            )
            C1 = self._C[..., 0] + 1j * self._C[..., 1]
            C2 = self._C2[..., 0] + 1j * self._C2[..., 1]
            self._output_matrix = jnp.concatenate((C1, C2), axis=-1)
        else:
            self._output_matrix = self._C[..., 0] + 1j * self._C[..., 1]

        self._D = hk.get_parameter(
            'D',
            [d_model,],
            init=hk.initializers.RandomNormal(stddev=1.0)
        )

        self._delta_t = hk.get_parameter(
            'delta_T',
            [_state_size, 1],
            init=init_log_steps
        )
        timescale = step_rescale * jnp.exp(self._delta_t[:, 0])

        # We could also use the bilinear discretization method, but we'll just stick to zoh for now.
        self._state_matrix, self._input_matrix = discretize(self._A, B, timescale)


    def __call__(self, inputs):
        # This is the same parallel scan as in parallel_scan above, with extra branches
        # for the bidirectional and conjugate-symmetry options.

        A_elements = self._state_matrix * jnp.ones((inputs.shape[0], self._state_matrix.shape[0]))
        Bu_elements = jax.vmap(lambda u: self._input_matrix @ u)(inputs)

        _, xs = jax.lax.associative_scan(binary_operator, (A_elements, Bu_elements))

        if self.bidirectional:
            _, xs2 = jax.lax.associative_scan(binary_operator,
                                          (A_elements, Bu_elements),
                                          reverse=True)
            xs = jnp.concatenate((xs, xs2), axis=-1)

        if self.conj_sym:
            ys = jax.vmap(lambda x: 2*(self._output_matrix @ x).real)(xs)
        else:
            ys = jax.vmap(lambda x: (self._output_matrix @ x).real)(xs)

        Du = jax.vmap(lambda u: self._D * u)(inputs)

        return ys + Du
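This sketch shows why `2 * (...).real` is correct when `conj_sym=True`. Suppose the eigenvalues come in complex-conjugate pairs and the entries of B and C are conjugate-paired in the same way. Then each pair contributes a term plus its complex conjugate to the output, and that sum is twice the real part of either term. The NumPy code below uses arbitrary illustrative values: it keeps only one member of each pair, then checks that doubling the real part reproduces the output of the full, conjugate-paired system.

```python
import numpy as np

rng = np.random.default_rng(0)
half, L = 3, 12
lam_half = -0.2 + 1j * rng.uniform(0.5, 2.0, size=half)  # one eigenvalue from each pair
b_half = rng.normal(size=half) + 1j * rng.normal(size=half)
c_half = rng.normal(size=half) + 1j * rng.normal(size=half)

# Full system: append the complex conjugate of every eigenvalue, B entry and C entry.
lam_full = np.concatenate([lam_half, lam_half.conj()])
b_full = np.concatenate([b_half, b_half.conj()])
c_full = np.concatenate([c_half, c_half.conj()])

def run(lam, b, c, u, dt=0.1):
    """Diagonal SSM with ZOH discretization (as in discretize); returns the complex C x_k."""
    a_bar = np.exp(lam * dt)
    b_bar = (a_bar - 1.0) / lam * b
    x = np.zeros_like(lam)
    ys = []
    for u_k in u:
        x = a_bar * x + b_bar * u_k
        ys.append(c @ x)
    return np.array(ys)

u = rng.normal(size=L)                               # real-valued input
y_full = run(lam_full, b_full, c_full, u)            # imaginary parts cancel
y_half = 2 * run(lam_half, b_half, c_half, u).real   # what conj_sym=True computes
print("max difference:", np.max(np.abs(y_full.real - y_half)))
```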

There we have it, a complete S5 layer! Now let's form a block around it using a structure very similar to a transformer block with a Gated Linear Unit (GLU).

In [10]:
import dataclasses

@dataclasses.dataclass
class S5Block(hk.Module):
    ssm: S5
    d_model: int
    dropout_rate: float
    prenorm: bool
    istraining: bool = True
    name: Optional[str] = None

    def __post_init__(self):
        super(S5Block, self).__post_init__()
        # We could use either layer norm or batch norm.
        self._norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
        self._linear = hk.Linear(self.d_model)

    def __call__(self, x):
        skip = x
        if self.prenorm:
            x = self._norm(x)

        x = self.ssm(x)
        # There are a couple of other GLU patterns we could use here, but once again I have chosen
        # one semi-arbitrarily to avoid cluttering our module with if statements.
        # Dropout is only applied in training mode.
        rate = self.dropout_rate if self.istraining else 0.0
        x1 = hk.dropout(hk.next_rng_key(), rate, jax.nn.gelu(x))
        x = x * jax.nn.sigmoid(self._linear(x1))
        x = hk.dropout(hk.next_rng_key(), rate, x)

        x = skip + x
        if not self.prenorm:
            x = self._norm(x)

        return x
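Without dropout and normalization, the body of `S5Block` computes `skip + x * sigmoid(W gelu(x) + bias)`, where `x` is the SSM output. The NumPy sketch below shows this per-time-step arithmetic. The weights `w` and `bias` are hypothetical stand-ins for the `hk.Linear` parameters. `gelu` uses the tanh approximation, which is what `jax.nn.gelu` uses by default. Because the sigmoid gate lies in (0, 1), it can only scale each SSM output channel down, never flip its sign or amplify it.

```python
import numpy as np

def gelu_tanh(x):
    # tanh approximation of GELU (jax.nn.gelu's default, approximate=True)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_update(ssm_out, skip, w, bias):
    """Dropout-free sketch of the S5Block body for one sequence: [L, d_model] -> [L, d_model]."""
    gate = sigmoid(gelu_tanh(ssm_out) @ w + bias)  # hk.Linear applied to gelu(x)
    return skip + ssm_out * gate, gate

rng = np.random.default_rng(0)
L, d_model = 5, 4
skip = rng.normal(size=(L, d_model))
ssm_out = rng.normal(size=(L, d_model))  # stands in for self.ssm(...) applied to the block input
w = rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
bias = np.zeros(d_model)
out, gate = gated_update(ssm_out, skip, w, bias)
print(out.shape)  # (5, 4)
```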

Now let's make a stack of these blocks:

In [32]:
@dataclasses.dataclass
class S5Stack(hk.Module):
    ssm: S5
    d_model: int
    n_layers: int
    dropout_rate: float
    prenorm: bool
    istraining: bool = True
    name: Optional[str] = None

    def __post_init__(self):
        super(S5Stack, self).__post_init__(name=self.name)
        self._encoder = hk.Linear(self.d_model)
        self._layers = [
            S5Block(
                ssm=self.ssm,
                d_model=self.d_model,
                dropout_rate=self.dropout_rate,
                istraining=self.istraining,
                prenorm=self.prenorm,
            )
            for _ in range(self.n_layers)
        ]

    def __call__(self, x):
        x = self._encoder(x)
        for layer in self._layers:
            x = layer(x)
        return x

# Implementing Regression in S5

Next, we create a dataset class and a loader function for the stock data.

In [47]:
# --- New Data Section ---

# This custom Dataset class will hold our sliding windows
class StockWindowDataset(TorchDataset):
    def __init__(self, data, labels):
        # Explicitly convert JAX array -> NumPy array -> Torch tensor
        self.data = torch.tensor(np.array(data), dtype=torch.float32)
        self.labels = torch.tensor(np.array(labels), dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # We need to add an extra dimension for the input feature
        return self.data[idx][..., None], self.labels[idx]

def create_stock_dataset(batch_size=64, seq_length=7):
    print("[*] Generating Stock Price Dataset...")

    # 1. Download data
    # Let's use the S&P 500 index as an example
    ticker = "^GSPC"
    data = yf.download(ticker, start="2000-01-01", end="2023-12-31")

    # We use the daily closing price as the single input feature
    prices = data[['Close']].values

    # 2. Normalise the data
    # We fit the scaler ONLY on the training data to prevent data leakage
    train_split_idx = int(len(prices) * 0.8)
    scaler = MinMaxScaler(feature_range=(0, 1))

    # Fit on training data
    scaler.fit(prices[:train_split_idx])

    # Transform all data
    prices_scaled = scaler.transform(prices).flatten() # Flatten to 1D array

    # 3. Create sliding windows
    X, y = [], []
    for i in range(len(prices_scaled) - seq_length):
        X.append(prices_scaled[i : i + seq_length])
        y.append(prices_scaled[i + seq_length])

    X = jnp.array(X)
    y = jnp.array(y)

    # 4. Split into train and test sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, shuffle=False # Time-series data should not be shuffled
    )

    print(f"[*] Train samples: {len(X_train)}, Test samples: {len(X_test)}")

    # 5. Create DataLoaders
    train_dataset = StockWindowDataset(X_train, y_train)
    test_dataset = StockWindowDataset(X_test, y_test)

    trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # N_CLASSES is now 1 (we predict one value)
    # SEQ_LENGTH is our window size (7)
    # IN_DIM is 1 (just the price)
    N_CLASSES = 1
    IN_DIM = 1

    # We also return the scaler so we can invert the predictions later
    return trainloader, testloader, N_CLASSES, seq_length, IN_DIM, scaler

# Replace the Datasets dict and create_dataset function
Datasets = {
    "stock-prediction": create_stock_dataset,
}

# Simple class for storing our dataset parameters
class Dataset(NamedTuple):
    trainloader: DataLoader
    testloader: DataLoader
    n_classes: int  # This will be 1
    seq_length: int # This will be 7
    d_input: int    # This will be 1
    scaler: MinMaxScaler # Store the scaler
    classification: bool

def create_dataset(dataset: str, batch_size: int, seq_length: int) -> Dataset:
    classification = 'classification' in dataset
    dataset_init = Datasets[dataset]

    # Pass seq_length to the function
    trainloader, testloader, n_classes, seq_len, d_input, scaler = dataset_init(
        batch_size=batch_size, seq_length=seq_length
    )

    return Dataset(
        trainloader,
        testloader,
        n_classes,
        seq_len,
        d_input,
        scaler, # Add scaler here
        classification
    )
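The Python loop in `create_stock_dataset` is easy to read. An equivalent vectorized construction uses `numpy.lib.stride_tricks.sliding_window_view`, which is available in NumPy 1.20 and later. The sketch below runs it on a toy series standing in for `prices_scaled` and checks the result against the loop. The helper name `make_windows` is hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def make_windows(series, seq_length):
    """Inputs are windows series[i:i+seq_length]; each target is the value right after its window."""
    X = sliding_window_view(series, seq_length)[:-1]  # drop the last window: it has no next value
    y = series[seq_length:]
    return X, y

series = np.arange(10, dtype=np.float32)  # toy stand-in for prices_scaled
X, y = make_windows(series, seq_length=7)
print(X.shape, y.shape)  # (3, 7) (3,)
```

`sliding_window_view` returns a read-only view rather than a copy, which is fine here because the windows are immediately converted to arrays or tensors.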

# Creating our Regressor

In [51]:
# Replace S5Classifier with S5Regressor

@dataclasses.dataclass
class S5Regressor(hk.Module):
    ssm: S5
    d_model: int
    d_output: int  # This will be 1 for our task
    n_layers: int
    dropout_rate: float
    mode: str = 'last'  # 'last' is more common for time-series prediction
    prenorm: bool = True
    istraining: bool = True
    name: Optional[str] = None

    def __post_init__(self):
        super(S5Regressor, self).__post_init__(name=self.name)
        self._encoder = S5Stack(
            ssm=self.ssm,
            d_model=self.d_model,
            n_layers=self.n_layers,
            dropout_rate=self.dropout_rate,
            istraining=self.istraining,
            prenorm=self.prenorm,
        )
        self._decoder = hk.Linear(self.d_output)

    def __call__(self, x):
        x = self._encoder(x)

        # For one-step-ahead prediction, the hidden state at the last time step has seen
        # the whole window, so 'last' is a natural default.
        if self.mode == 'pool':
            x = jnp.mean(x, axis=0)
        elif self.mode == 'last':
            x = x[-1]  # Take the last element of the sequence
        else:
            raise NotImplementedError("Mode must be in ['pool', 'last']")

        x = self._decoder(x)

        # Unlike the classifier, no log_softmax is applied: we want the raw regression output.
        # We also squeeze the last dimension so that, after hk.vmap, predictions match the
        # label shape (batch_size,).
        return jnp.squeeze(x, axis=-1)

Next we will set some hyperparameters:

In [36]:
# --- Set Hyperparameters ---
# These are reasonable starting values, but they will likely need tuning for this new task.
STATE_SIZE: int = 64   # Reduced for a simpler 1D input
D_MODEL: int = 32    # Reduced
N_LAYERS: int = 4
N_BLOCKS: int = 4
EPOCHS: int = 50
BATCH_SIZE: int = 64
DROPOUT_RATE: float = 0.1
LEARNING_RATE: float = 1e-3
SEED = 0
SEQ_LENGTH: int = 7  # 7-day input window (trading days)

To keep things simple for this example, we will just use a plain Adam optimizer; results can likely be improved with extra techniques such as learning rate schedules (e.g. cosine annealing) and weight decay.

In [37]:
import optax

class TrainingState(NamedTuple):
    params: hk.Params
    opt_state: optax.OptState
    rng_key: jnp.ndarray

optim = optax.adam(LEARNING_RATE)
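The plain `optax.adam` above could be swapped for a scheduled optimizer. For reference, the sketch below reproduces in plain Python the formula documented for `optax.cosine_decay_schedule(init_value, decay_steps, alpha)`. The step count is an assumption for illustration: 4,824 training windows in batches of 64 give 76 batches per epoch, times 50 epochs.

```python
import math

def cosine_decay(step, init_value=1e-3, decay_steps=50 * 76, alpha=0.0):
    """Learning rate at `step`: decays from init_value to alpha * init_value over decay_steps."""
    progress = min(step, decay_steps) / decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return init_value * ((1.0 - alpha) * cosine + alpha)

for step in (0, 50 * 76 // 2, 50 * 76):
    print(step, cosine_decay(step))
```

With Optax itself, a schedule object can be passed wherever a learning rate is expected. A hedged example is `optax.adamw(learning_rate=optax.cosine_decay_schedule(LEARNING_RATE, EPOCHS * 76), weight_decay=1e-4)`, which also adds decoupled weight decay. The `weight_decay` value here is a placeholder, not a tuned setting.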


Now for the loss and update functions:

In [38]:
from functools import partial
from typing import Tuple, MutableMapping, Any
_Metrics = MutableMapping[str, Any]

# Replace cross_entropy_loss and compute_accuracy

@partial(jnp.vectorize, signature="(),()->()")
def mse_loss(prediction, target) -> jnp.ndarray:
    """Computes Mean Squared Error loss."""
    return (prediction - target) ** 2

# We no longer need compute_accuracy

# --- Modify the 'update' function ---
@partial(jax.jit, static_argnums=(3, 4))
def update(
        state: TrainingState,
        inputs: jnp.ndarray,
        targets: jnp.ndarray,
        model: hk.transform,
        classification: bool,) -> Tuple[TrainingState, _Metrics]:

    rng_key, next_rng_key = jax.random.split(state.rng_key)

    def loss_fn(params):
        # 'predictions' instead of 'logits'
        predictions = model.apply(params, rng_key, inputs)

        # Use MSE loss
        _loss = jnp.mean(mse_loss(predictions, targets))

        # We return loss as the primary value and as an aux value
        return _loss, _loss

    # This 'if' is no longer needed, but harmless if left.
    # if not classification:
    #     targets = inputs[:, :, 0]

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    # (loss, loss_metric), gradients
    (loss, _), gradients = grad_fn(state.params)

    updates, new_opt_state = optim.update(gradients, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)

    new_state = TrainingState(
        params=new_params,
        opt_state=new_opt_state,
        rng_key=next_rng_key,
    )
    metrics = {
        'loss': loss, # We only track loss now
    }

    return new_state, metrics

# --- Modify the 'evaluate' function ---
@partial(jax.jit, static_argnums=(3, 4))
def evaluate(
        state: TrainingState,
        inputs: jnp.ndarray,
        targets: jnp.ndarray,
        model: hk.transform,
        classification,) -> _Metrics:

    rng_key, _ = jax.random.split(state.rng_key, 2)

    # if not classification:
    #     targets = inputs[:, :, 0]

    predictions = model.apply(state.params, rng_key, inputs)
    loss = jnp.mean(mse_loss(predictions, targets))

    metrics = {
        'loss': loss, # We only track loss
    }

    return metrics
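The losses above are MSEs on min-max scaled prices. With the default `feature_range=(0, 1)`, `MinMaxScaler` maps a price p to (p - data_min) / (data_max - data_min). Errors therefore scale linearly by `data_max - data_min`, so the RMSE in index points is `sqrt(MSE) * (data_max - data_min)`. The NumPy sketch below uses made-up numbers. In the notebook, the range would come from the fitted scaler's `data_range_` attribute, or predictions can be mapped back with `scaler.inverse_transform`. The helper names are hypothetical.

```python
import numpy as np

def rmse_in_price_units(mse_scaled, data_min, data_max):
    """Convert an MSE computed on [0, 1] min-max scaled data into an RMSE in price units."""
    return np.sqrt(mse_scaled) * (data_max - data_min)

def inverse_minmax(x_scaled, data_min, data_max):
    """Same mapping as MinMaxScaler.inverse_transform for feature_range=(0, 1)."""
    return x_scaled * (data_max - data_min) + data_min

# Made-up illustration values (not taken from the training run).
data_min, data_max = 700.0, 3_700.0
pred_scaled = np.array([0.50, 0.52])
true_scaled = np.array([0.51, 0.50])
mse = np.mean((pred_scaled - true_scaled) ** 2)
rmse_points = rmse_in_price_units(mse, data_min, data_max)

# The same quantity computed directly on de-normalized prices.
direct = np.sqrt(np.mean((inverse_minmax(pred_scaled, data_min, data_max)
                          - inverse_minmax(true_scaled, data_min, data_max)) ** 2))
print(rmse_points, direct)
```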


Now we call these update and evaluate functions in their respective epochs:

In [39]:
from tqdm import tqdm

# --- Modify 'training_epoch' ---
def training_epoch(
        state: TrainingState,
        trainloader: DataLoader,
        model: hk.transform,
        classification: bool = False,) -> Tuple[TrainingState, jnp.ndarray]: # Removed accuracy output

    batch_losses = [] # Removed batch_accuracies
    for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader)):
        inputs = jnp.array(inputs.numpy())
        targets = jnp.array(targets.numpy())
        state, metrics = update(
            state, inputs, targets,
            model, classification
        )
        batch_losses.append(metrics['loss'])
        # Removed accuracy

    return (
        state,
        jnp.mean(jnp.array(batch_losses)),
        # Removed accuracy
    )

# --- Modify 'validation_epoch' ---
def validation_epoch(
    state: TrainingState,
    testloader: DataLoader,
    model: hk.transform,
    classification: bool = True,) -> jnp.ndarray: # Removed accuracy output

    losses = [] # Removed accuracies
    for batch_idx, (inputs, targets) in enumerate(tqdm(testloader)):
        inputs = jnp.array(inputs.numpy())
        targets = jnp.array(targets.numpy())
        metrics = evaluate(
            state, inputs, targets,
            model, classification
        )
        losses.append(metrics['loss'])
        # Removed accuracy

    return jnp.mean(jnp.array(losses)) # Removed accuracy

In [40]:
import torch
# Set random number generators
torch.random.manual_seed(SEED)
key = jax.random.PRNGKey(SEED)
rng, init_rng = jax.random.split(key)

In [48]:
# Note: We now pass SEQ_LENGTH
ds = create_dataset(
    'stock-prediction',
    BATCH_SIZE,
    seq_length=SEQ_LENGTH
)
init_data = jnp.array(next(iter(ds.trainloader))[0].numpy())

[*********************100%***********************]  1 of 1 completed

[*] Generating Stock Price Dataset...
[*] Train samples: 4824, Test samples: 1206





In [49]:
# In Haiku, we have to call our model inside a transformed function using hk.transform for it to become
# functionally pure and compatible with essential JAX functions like jax.grad(). Here we are using hk.vmap
# instead of jax.vmap because we are calling it from within a hk.transform.
@hk.transform
def forward(x) -> jnp.ndarray:
    # Use S5Regressor
    neural_net = S5Regressor(
        S5(
            STATE_SIZE,
            D_MODEL,
            N_BLOCKS,
        ),
        D_MODEL,
        ds.n_classes,  # This will be 1
        N_LAYERS,
        DROPOUT_RATE,
    )
    return hk.vmap(neural_net, split_rng=False)(x)

In [52]:
# Set state
initial_params = forward.init(init_rng, init_data)
initial_opt_state = optim.init(initial_params)

state = TrainingState(
    params=initial_params,
    opt_state=initial_opt_state,
    rng_key=rng)



# Training and Plotting

In [53]:
for epoch in range(EPOCHS):
    print(f"[*] Training Epoch {epoch + 1}...")
    # Update function signatures
    state, training_loss = training_epoch(
        state,
        ds.trainloader,
        forward,
        ds.classification
    )
    print(f"[*] Running Epoch {epoch + 1} Validation...")
    # Update function signatures
    test_loss = validation_epoch(
        state,
        ds.testloader,
        forward,
        ds.classification
    )

    # Updated print statement
    print(f"\n=>> Epoch {epoch + 1} Metrics ===")
    print(
        f"\tTrain Loss (MSE): {training_loss:.5f}\n\t Test Loss (MSE): {test_loss:.5f}"
    )

[*] Training Epoch 1...


100%|██████████| 76/76 [00:18<00:00,  4.10it/s]


[*] Running Epoch 1 Validation...


100%|██████████| 19/19 [00:09<00:00,  2.07it/s]



=>> Epoch 1 Metrics ===
	Train Loss (MSE): 2.65943
	 Test Loss (MSE): 1.46223
[*] Training Epoch 2...


100%|██████████| 76/76 [00:00<00:00, 330.85it/s]


[*] Running Epoch 2 Validation...


100%|██████████| 19/19 [00:00<00:00, 559.63it/s]



=>> Epoch 2 Metrics ===
	Train Loss (MSE): 0.96576
	 Test Loss (MSE): 0.26262
[*] Training Epoch 3...


100%|██████████| 76/76 [00:00<00:00, 365.42it/s]


[*] Running Epoch 3 Validation...


100%|██████████| 19/19 [00:00<00:00, 575.11it/s]



=>> Epoch 3 Metrics ===
	Train Loss (MSE): 0.65894
	 Test Loss (MSE): 1.07399
[*] Training Epoch 4...


100%|██████████| 76/76 [00:00<00:00, 377.82it/s]


[*] Running Epoch 4 Validation...


100%|██████████| 19/19 [00:00<00:00, 567.09it/s]



=>> Epoch 4 Metrics ===
	Train Loss (MSE): 0.46738
	 Test Loss (MSE): 0.03111
[*] Training Epoch 5...


100%|██████████| 76/76 [00:00<00:00, 372.24it/s]


[*] Running Epoch 5 Validation...


100%|██████████| 19/19 [00:00<00:00, 579.72it/s]



=>> Epoch 5 Metrics ===
	Train Loss (MSE): 0.43795
	 Test Loss (MSE): 0.13936
[*] Training Epoch 6...


100%|██████████| 76/76 [00:00<00:00, 336.36it/s]


[*] Running Epoch 6 Validation...


100%|██████████| 19/19 [00:00<00:00, 555.60it/s]



=>> Epoch 6 Metrics ===
	Train Loss (MSE): 0.27025
	 Test Loss (MSE): 0.13029
[*] Training Epoch 7...


100%|██████████| 76/76 [00:00<00:00, 356.19it/s]


[*] Running Epoch 7 Validation...


100%|██████████| 19/19 [00:00<00:00, 556.29it/s]



=>> Epoch 7 Metrics ===
	Train Loss (MSE): 0.27289
	 Test Loss (MSE): 0.61472
[*] Training Epoch 8...


100%|██████████| 76/76 [00:00<00:00, 364.48it/s]


[*] Running Epoch 8 Validation...


100%|██████████| 19/19 [00:00<00:00, 530.63it/s]



=>> Epoch 8 Metrics ===
	Train Loss (MSE): 0.14729
	 Test Loss (MSE): 0.05707
[*] Training Epoch 9...


100%|██████████| 76/76 [00:00<00:00, 359.11it/s]


[*] Running Epoch 9 Validation...


100%|██████████| 19/19 [00:00<00:00, 414.86it/s]



=>> Epoch 9 Metrics ===
	Train Loss (MSE): 0.11345
	 Test Loss (MSE): 0.60465
[*] Training Epoch 10...


100%|██████████| 76/76 [00:00<00:00, 364.17it/s]


[*] Running Epoch 10 Validation...


100%|██████████| 19/19 [00:00<00:00, 581.10it/s]



=>> Epoch 10 Metrics ===
	Train Loss (MSE): 0.10805
	 Test Loss (MSE): 0.16906
[*] Training Epoch 11...


100%|██████████| 76/76 [00:00<00:00, 369.56it/s]


[*] Running Epoch 11 Validation...


100%|██████████| 19/19 [00:00<00:00, 560.17it/s]



=>> Epoch 11 Metrics ===
	Train Loss (MSE): 0.12195
	 Test Loss (MSE): 0.10933
[*] Training Epoch 12...


100%|██████████| 76/76 [00:00<00:00, 348.43it/s]


[*] Running Epoch 12 Validation...


100%|██████████| 19/19 [00:00<00:00, 549.76it/s]



=>> Epoch 12 Metrics ===
	Train Loss (MSE): 0.09899
	 Test Loss (MSE): 0.07827
[*] Training Epoch 13...


100%|██████████| 76/76 [00:00<00:00, 337.74it/s]


[*] Running Epoch 13 Validation...


100%|██████████| 19/19 [00:00<00:00, 551.26it/s]



=>> Epoch 13 Metrics ===
	Train Loss (MSE): 0.08241
	 Test Loss (MSE): 0.00561
[*] Training Epoch 14...


100%|██████████| 76/76 [00:00<00:00, 349.81it/s]


[*] Running Epoch 14 Validation...


100%|██████████| 19/19 [00:00<00:00, 578.08it/s]



=>> Epoch 14 Metrics ===
	Train Loss (MSE): 0.05938
	 Test Loss (MSE): 0.01048
[*] Training Epoch 15...


100%|██████████| 76/76 [00:00<00:00, 361.29it/s]


[*] Running Epoch 15 Validation...


100%|██████████| 19/19 [00:00<00:00, 534.28it/s]



=>> Epoch 15 Metrics ===
	Train Loss (MSE): 0.06950
	 Test Loss (MSE): 0.24101
[*] Training Epoch 16...


100%|██████████| 76/76 [00:00<00:00, 351.79it/s]


[*] Running Epoch 16 Validation...


100%|██████████| 19/19 [00:00<00:00, 574.56it/s]



=>> Epoch 16 Metrics ===
	Train Loss (MSE): 0.05916
	 Test Loss (MSE): 0.11834
[*] Training Epoch 17...


100%|██████████| 76/76 [00:00<00:00, 340.31it/s]


[*] Running Epoch 17 Validation...


100%|██████████| 19/19 [00:00<00:00, 560.97it/s]



=>> Epoch 17 Metrics ===
	Train Loss (MSE): 0.06413
	 Test Loss (MSE): 0.18363
[*] Training Epoch 18...


100%|██████████| 76/76 [00:00<00:00, 277.38it/s]


[*] Running Epoch 18 Validation...


100%|██████████| 19/19 [00:00<00:00, 414.11it/s]



=>> Epoch 18 Metrics ===
	Train Loss (MSE): 0.03331
	 Test Loss (MSE): 0.00855
[*] Training Epoch 19...


100%|██████████| 76/76 [00:00<00:00, 246.86it/s]


[*] Running Epoch 19 Validation...


100%|██████████| 19/19 [00:00<00:00, 382.80it/s]



=>> Epoch 19 Metrics ===
	Train Loss (MSE): 0.04843
	 Test Loss (MSE): 0.00143
[*] Training Epoch 20...


100%|██████████| 76/76 [00:00<00:00, 260.76it/s]


[*] Running Epoch 20 Validation...


100%|██████████| 19/19 [00:00<00:00, 419.90it/s]



=>> Epoch 20 Metrics ===
	Train Loss (MSE): 0.04799
	 Test Loss (MSE): 0.07406
[*] Training Epoch 21...


100%|██████████| 76/76 [00:00<00:00, 261.59it/s]


[*] Running Epoch 21 Validation...


100%|██████████| 19/19 [00:00<00:00, 413.58it/s]



=>> Epoch 21 Metrics ===
	Train Loss (MSE): 0.04352
	 Test Loss (MSE): 0.03603
[*] Training Epoch 22...


100%|██████████| 76/76 [00:00<00:00, 257.22it/s]


[*] Running Epoch 22 Validation...


100%|██████████| 19/19 [00:00<00:00, 383.72it/s]



=>> Epoch 22 Metrics ===
	Train Loss (MSE): 0.04145
	 Test Loss (MSE): 0.09056
[*] Training Epoch 23...


100%|██████████| 76/76 [00:00<00:00, 233.93it/s]


[*] Running Epoch 23 Validation...


100%|██████████| 19/19 [00:00<00:00, 375.88it/s]



=>> Epoch 23 Metrics ===
	Train Loss (MSE): 0.03797
	 Test Loss (MSE): 0.04545
[*] Training Epoch 24...


100%|██████████| 76/76 [00:00<00:00, 337.30it/s]


[*] Running Epoch 24 Validation...


100%|██████████| 19/19 [00:00<00:00, 545.85it/s]



=>> Epoch 24 Metrics ===
	Train Loss (MSE): 0.03318
	 Test Loss (MSE): 0.18809
[*] Training Epoch 25...


100%|██████████| 76/76 [00:00<00:00, 367.87it/s]


[*] Running Epoch 25 Validation...


100%|██████████| 19/19 [00:00<00:00, 566.44it/s]



=>> Epoch 25 Metrics ===
	Train Loss (MSE): 0.03785
	 Test Loss (MSE): 0.00729
[*] Training Epoch 26...


100%|██████████| 76/76 [00:00<00:00, 341.91it/s]


[*] Running Epoch 26 Validation...


100%|██████████| 19/19 [00:00<00:00, 529.64it/s]



=>> Epoch 26 Metrics ===
	Train Loss (MSE): 0.02569
	 Test Loss (MSE): 0.07025
[*] Training Epoch 27...


100%|██████████| 76/76 [00:00<00:00, 347.76it/s]


[*] Running Epoch 27 Validation...


100%|██████████| 19/19 [00:00<00:00, 533.02it/s]



=>> Epoch 27 Metrics ===
	Train Loss (MSE): 0.01904
	 Test Loss (MSE): 0.00063
[*] Training Epoch 28...


100%|██████████| 76/76 [00:00<00:00, 350.31it/s]


[*] Running Epoch 28 Validation...


100%|██████████| 19/19 [00:00<00:00, 563.43it/s]



=>> Epoch 28 Metrics ===
	Train Loss (MSE): 0.02409
	 Test Loss (MSE): 0.03122
[*] Training Epoch 29...


100%|██████████| 76/76 [00:00<00:00, 367.59it/s]


[*] Running Epoch 29 Validation...


100%|██████████| 19/19 [00:00<00:00, 546.13it/s]



=>> Epoch 29 Metrics ===
	Train Loss (MSE): 0.03161
	 Test Loss (MSE): 0.00120
[*] Training Epoch 30...


100%|██████████| 76/76 [00:00<00:00, 335.70it/s]


[*] Running Epoch 30 Validation...


100%|██████████| 19/19 [00:00<00:00, 531.77it/s]



=>> Epoch 30 Metrics ===
	Train Loss (MSE): 0.01779
	 Test Loss (MSE): 0.01117
[*] Training Epoch 31...


100%|██████████| 76/76 [00:00<00:00, 359.69it/s]


[*] Running Epoch 31 Validation...


100%|██████████| 19/19 [00:00<00:00, 538.28it/s]



=>> Epoch 31 Metrics ===
	Train Loss (MSE): 0.02065
	 Test Loss (MSE): 0.01526
[*] Training Epoch 32...


100%|██████████| 76/76 [00:00<00:00, 357.17it/s]


[*] Running Epoch 32 Validation...


100%|██████████| 19/19 [00:00<00:00, 558.48it/s]



=>> Epoch 32 Metrics ===
	Train Loss (MSE): 0.01950
	 Test Loss (MSE): 0.01762
[*] Training Epoch 33...


100%|██████████| 76/76 [00:00<00:00, 340.68it/s]


[*] Running Epoch 33 Validation...


100%|██████████| 19/19 [00:00<00:00, 546.33it/s]



=>> Epoch 33 Metrics ===
	Train Loss (MSE): 0.01507
	 Test Loss (MSE): 0.00111
[*] Training Epoch 34...


100%|██████████| 76/76 [00:00<00:00, 365.97it/s]


[*] Running Epoch 34 Validation...


100%|██████████| 19/19 [00:00<00:00, 536.80it/s]



=>> Epoch 34 Metrics ===
	Train Loss (MSE): 0.02173
	 Test Loss (MSE): 0.02188
[*] Training Epoch 35...


100%|██████████| 76/76 [00:00<00:00, 362.05it/s]


[*] Running Epoch 35 Validation...


100%|██████████| 19/19 [00:00<00:00, 558.59it/s]



=>> Epoch 35 Metrics ===
	Train Loss (MSE): 0.01269
	 Test Loss (MSE): 0.00461
[*] Training Epoch 36...


100%|██████████| 76/76 [00:00<00:00, 366.13it/s]


[*] Running Epoch 36 Validation...


100%|██████████| 19/19 [00:00<00:00, 556.11it/s]



=>> Epoch 36 Metrics ===
	Train Loss (MSE): 0.01207
	 Test Loss (MSE): 0.00287
[*] Training Epoch 37...


100%|██████████| 76/76 [00:00<00:00, 318.33it/s]


[*] Running Epoch 37 Validation...


100%|██████████| 19/19 [00:00<00:00, 554.69it/s]



=>> Epoch 37 Metrics ===
	Train Loss (MSE): 0.01512
	 Test Loss (MSE): 0.00296
[*] Training Epoch 38...


100%|██████████| 76/76 [00:00<00:00, 350.54it/s]


[*] Running Epoch 38 Validation...


100%|██████████| 19/19 [00:00<00:00, 536.39it/s]



=>> Epoch 38 Metrics ===
	Train Loss (MSE): 0.01288
	 Test Loss (MSE): 0.01814
[*] Training Epoch 39...


100%|██████████| 76/76 [00:00<00:00, 352.05it/s]


[*] Running Epoch 39 Validation...


100%|██████████| 19/19 [00:00<00:00, 540.22it/s]



=>> Epoch 39 Metrics ===
	Train Loss (MSE): 0.01319
	 Test Loss (MSE): 0.00710
[*] Training Epoch 40...


100%|██████████| 76/76 [00:00<00:00, 367.42it/s]


[*] Running Epoch 40 Validation...


100%|██████████| 19/19 [00:00<00:00, 546.26it/s]



=>> Epoch 40 Metrics ===
	Train Loss (MSE): 0.01592
	 Test Loss (MSE): 0.01143
[*] Training Epoch 41...


100%|██████████| 76/76 [00:00<00:00, 343.58it/s]


[*] Running Epoch 41 Validation...


100%|██████████| 19/19 [00:00<00:00, 554.62it/s]



=>> Epoch 41 Metrics ===
	Train Loss (MSE): 0.01762
	 Test Loss (MSE): 0.00891
[*] Training Epoch 42...


100%|██████████| 76/76 [00:00<00:00, 352.89it/s]


[*] Running Epoch 42 Validation...


100%|██████████| 19/19 [00:00<00:00, 556.99it/s]



=>> Epoch 42 Metrics ===
	Train Loss (MSE): 0.01210
	 Test Loss (MSE): 0.01255
[*] Training Epoch 43...


100%|██████████| 76/76 [00:00<00:00, 371.53it/s]


[*] Running Epoch 43 Validation...


100%|██████████| 19/19 [00:00<00:00, 558.68it/s]



=>> Epoch 43 Metrics ===
	Train Loss (MSE): 0.00951
	 Test Loss (MSE): 0.00464
[*] Training Epoch 44...


100%|██████████| 76/76 [00:00<00:00, 340.30it/s]


[*] Running Epoch 44 Validation...


100%|██████████| 19/19 [00:00<00:00, 533.72it/s]



=>> Epoch 44 Metrics ===
	Train Loss (MSE): 0.01032
	 Test Loss (MSE): 0.01130
[*] Training Epoch 45...


100%|██████████| 76/76 [00:00<00:00, 350.02it/s]


[*] Running Epoch 45 Validation...


100%|██████████| 19/19 [00:00<00:00, 512.00it/s]



=>> Epoch 45 Metrics ===
	Train Loss (MSE): 0.00813
	 Test Loss (MSE): 0.00094
[*] Training Epoch 46...


100%|██████████| 76/76 [00:00<00:00, 348.83it/s]


[*] Running Epoch 46 Validation...


100%|██████████| 19/19 [00:00<00:00, 538.73it/s]



=>> Epoch 46 Metrics ===
	Train Loss (MSE): 0.00901
	 Test Loss (MSE): 0.01361
[*] Training Epoch 47...


100%|██████████| 76/76 [00:00<00:00, 358.88it/s]


[*] Running Epoch 47 Validation...


100%|██████████| 19/19 [00:00<00:00, 546.89it/s]



=>> Epoch 47 Metrics ===
	Train Loss (MSE): 0.01092
	 Test Loss (MSE): 0.00128
[*] Training Epoch 48...


100%|██████████| 76/76 [00:00<00:00, 336.00it/s]


[*] Running Epoch 48 Validation...


100%|██████████| 19/19 [00:00<00:00, 545.67it/s]



=>> Epoch 48 Metrics ===
	Train Loss (MSE): 0.00868
	 Test Loss (MSE): 0.00792
[*] Training Epoch 49...


100%|██████████| 76/76 [00:00<00:00, 349.23it/s]


[*] Running Epoch 49 Validation...


100%|██████████| 19/19 [00:00<00:00, 533.65it/s]



=>> Epoch 49 Metrics ===
	Train Loss (MSE): 0.01074
	 Test Loss (MSE): 0.06159
[*] Training Epoch 50...


100%|██████████| 76/76 [00:00<00:00, 349.88it/s]


[*] Running Epoch 50 Validation...


100%|██████████| 19/19 [00:00<00:00, 538.86it/s]


=>> Epoch 50 Metrics ===
	Train Loss (MSE): 0.00634
	 Test Loss (MSE): 0.00110
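The test MSE in this run is much noisier than the training MSE. It falls to 0.03111 at epoch 4, then climbs back to 0.61472 at epoch 7, and its lowest value (0.00063, epoch 27) comes well before the final epoch (0.00110). To compare epochs without scrolling through the log, one option is a small standard-library parser for the exact format printed by the loop above. This is a sketch: `parse_metrics` and `best_epoch` are hypothetical helpers that are not part of the notebook, and the regex assumes the print statement is left unchanged.

```python
import re

# Matches one "=>> Epoch N Metrics ===" block as printed by the training loop.
# \s* absorbs the newline, tab and stray space before each metric line.
METRIC_RE = re.compile(
    r"=>> Epoch (\d+) Metrics ===\s*"
    r"Train Loss \(MSE\): ([0-9.]+)\s*"
    r"Test Loss \(MSE\): ([0-9.]+)"
)


def parse_metrics(log_text):
    """Return a list of (epoch, train_mse, test_mse) tuples in log order."""
    return [
        (int(epoch), float(train), float(test))
        for epoch, train, test in METRIC_RE.findall(log_text)
    ]


def best_epoch(history):
    """Return the (epoch, train_mse, test_mse) entry with the lowest test MSE."""
    return min(history, key=lambda row: row[2])


# Two epochs copied from the log above.
sample_log = """=>> Epoch 26 Metrics ===
\tTrain Loss (MSE): 0.02569
\t Test Loss (MSE): 0.07025
[*] Training Epoch 27...
=>> Epoch 27 Metrics ===
\tTrain Loss (MSE): 0.01904
\t Test Loss (MSE): 0.00063
"""

history = parse_metrics(sample_log)
print(history)              # [(26, 0.02569, 0.07025), (27, 0.01904, 0.00063)]
print(best_epoch(history))  # (27, 0.01904, 0.00063)
```

With `history` built from the full cell output, the train and test columns can be passed to `plt.plot` (matplotlib is already imported at the top of the notebook) to visualise how far the two curves diverge from epoch to epoch.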



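The losses in this log are reported as mean squared error, the regression counterpart of the classification objective in the original annotated S5. As a minimal sketch (the helpers `batch_mse` and `epoch_mse` are hypothetical and not taken from this notebook), an epoch-level MSE can be computed by pooling squared errors and element counts across batches. The result matches the mean of per-batch MSEs only when every batch has the same number of elements; a smaller final batch otherwise gets over-weighted.

```python
import jax
import jax.numpy as jnp


@jax.jit
def batch_mse(predictions, targets):
    # Mean squared error over every element of one batch.
    return jnp.mean(jnp.square(predictions - targets))


def epoch_mse(batches):
    """Pooled MSE over an iterable of (predictions, targets) pairs.

    Assumes predictions and targets in each pair have the same shape.
    Squared errors and element counts are summed across all batches, so each
    batch is weighted by its true size rather than counting as one full batch.
    """
    total_sq_err = 0.0
    total_count = 0
    for predictions, targets in batches:
        predictions = jnp.asarray(predictions)
        targets = jnp.asarray(targets)
        total_sq_err += float(jnp.sum(jnp.square(predictions - targets)))
        total_count += predictions.size
    return total_sq_err / total_count


batches = [
    (jnp.array([1.0, 2.0]), jnp.array([1.0, 4.0])),  # batch MSE = 2.0
    (jnp.array([0.0]), jnp.array([1.0])),            # batch MSE = 1.0
]
print(epoch_mse(batches))  # 5/3, about 1.667; averaging the batch MSEs gives 1.5
```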
