# Prediction of time series with recurrent neural networks

https://github.com/google/flax/blob/main/examples/seq2seq/models.py

* We don't need vocabulary size
* We don't need one-hot encoding
* We don't need to sample from the output of the decoder

In [2]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import optax
import numpy as np

Array = jax.Array
PRNGKey = jax.Array
CellCarry = tuple[Array, Array]


# Set random seed for reproducibility
root_key = jax.random.PRNGKey(0)


# Generate multidimensional sequential data
def generate_multidim_data(input_seq_length, output_seq_length):
    num_points = input_seq_length + output_seq_length
    t = np.linspace(0, 8 * jnp.pi, num_points)
    X = np.empty((num_points, input_dim))

    # Generate different patterns for each dimension
    X[:, 0] = np.sin(t)
    X[:, 1] = np.cos(t)
    X[:, 2] = np.sin(2 * t)

    # let's have one point overlap between input and output
    Y = X[input_seq_length-1:, 0] + X[input_seq_length-1:, 1] + X[input_seq_length:-1, 2]

    return jnp.array(X[:input_seq_length]), jnp.array(Y)


class DecoderCell(nn.RNNCellBase):
    features: int

    @nn.compact
    def __call__(self, carry, x):
        state, last_prediction = carry
        


class Seq2seq(nn.Module):
    """Sequence-to-sequence class using encoder/decoder architecture."""

    encoder_size: int
    decoder_size: int

    @nn.compact
    def __call__(self, X: Array, y: Array) -> tuple[Array, Array]:
        """Applies the seq2seq model."""
        # X shape (batch size, input length, input dims)
        # y shape (batch size, output length)
        encoder = nn.RNN(
            nn.GRUCell(self.encoder_size), return_carry=True, name="encoder"
        )
        decoder = nn.RNN(nn.GRUCell(self.decoder_size), name="decoder")

        state, _ = encoder(X)
        y = decoder(state, y)
        return y


# Create and initialize the model
def create_train_state(key, encoder_size, decoder_size, learning_rate, input_shape, output_size):
    model = Seq2seq(encoder_size, decoder_size)
    params = model.init(key, jnp.ones(input_shape), output_size)["params"]
    tx = optax.nadamw(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)


# Define loss function
def mse_loss(params, model, X, y):
    pred = model.apply(params, X, y[:, 0], y.shape[-1])
    return optax.squared_error(pred, y[:, 1:]).mean()


# Training step
@jax.jit
def train_step(state, X, y):
    loss, grads = jax.value_and_grad(mse_loss)(state.params, state.apply_fn, X, y)
    state = state.apply_gradients(grads=grads)
    return state, loss


# Main training loop
def train_lstm(input_dim, input_seq_length, output_seq_length):
    # Hyperparameters
    hidden_size = 64
    learning_rate = 0.01
    num_epochs = 1000
    batch_size = 32

    # Generate data
    X, y = generate_multidim_data(2000, input_dim, input_seq_length, output_seq_length)

    # Create and initialize model
    state = create_train_state(
        key, hidden_size, input_dim, output_seq_length, learning_rate, X.shape[1:]
    )

    # Training loop
    for epoch in range(num_epochs):
        for i in range(0, len(X), batch_size):
            batch_X = X[i : i + batch_size]
            batch_y = y[i : i + batch_size]
            state, loss = train_step(state, batch_X, batch_y)

        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss}")

    return state


# Run training
input_dim = 5  # Number of input dimensions
input_seq_length = 20  # Length of input sequence (K)
output_seq_length = 10  # Length of output sequence (N)
trained_state = train_lstm(input_dim, input_seq_length, output_seq_length)


# Evaluate the model
def evaluate_model(state, X, y):
    predictions = state.apply_fn({"params": state.params}, X)
    mse = jnp.mean((predictions - y) ** 2)
    print(f"Mean Squared Error: {mse}")

    # Plot results for the first dimension
    import matplotlib.pyplot as plt

    plt.figure(figsize=(12, 6))
    plt.plot(y[0, :, 0], label="True")
    plt.plot(predictions[0, :, 0], label="Predicted")
    plt.legend()
    plt.title("LSTM Multidimensional Sequence-to-Sequence Prediction")
    plt.xlabel("Time Steps")
    plt.ylabel("Value")
    plt.show()


# Generate test data and evaluate
X_test, y_test = generate_multidim_data(
    200, input_dim, input_seq_length, output_seq_length
)

evaluate_model(trained_state, X_test, y_test)

TypeError: LSTMCell.__init__() missing 1 required positional argument: 'features'