# Recurrent Neural Networks

### How is an RNN trained: per token or per sequence?

In [4]:
from typing import List

import numpy
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.nn.initializers import lecun_normal

In [5]:
class RecurrentLayer(nn.Module):
    """Single recurrent step that mixes the previous hidden state with an input.

    The hidden state and the input are each projected to the hidden width by
    their own Dense layer; the projections are summed and passed through tanh.
    """

    @nn.compact
    def __call__(self, hidden_state, x):
        """Advance the recurrence by one step.

        Args:
            hidden_state: Previous hidden state. Its last axis is used as the
                hidden width.
            x: Input for this step. Its projection must broadcast against
                the projected hidden state.

        Returns:
            A ``(updated_hidden_state, output)`` tuple: the summed
            projections before the non-linearity, and their tanh.
        """
        # Read the width from the feature (last) axis so a batched hidden
        # state of shape (batch, hidden) works as well as a 1-D one.
        hidden_dim = hidden_state.shape[-1]
        hidden_transformation = nn.Dense(features=hidden_dim, kernel_init=lecun_normal())(hidden_state)
        x_transformation = nn.Dense(features=hidden_dim, kernel_init=lecun_normal())(x)
        # Combine the *projected* tensors. Summing the raw hidden_state and x
        # ignores the learned weights and fails whenever their shapes differ.
        updated_hidden_state = hidden_transformation + x_transformation
        # NOTE(review): the carried state is the pre-activation sum; a textbook
        # Elman RNN carries tanh(...) instead. Kept as-is for compatibility.
        output = nn.tanh(updated_hidden_state)
        return updated_hidden_state, output

class RNN(nn.Module):
    """Stack of ``RecurrentLayer`` steps followed by a dense softmax head."""

    # Widths of the dense head. Every entry but the last gets a ReLU layer;
    # the last entry is the width of the softmax output.
    linear_sizes: List[int]

    @nn.compact
    def __call__(self, hidden_states, x):
        """Run one step through every recurrent layer, then the dense head.

        Args:
            hidden_states: Iterable of per-layer hidden states; one
                ``RecurrentLayer`` is created for each entry.
            x: Input for this step, fed to the first recurrent layer.

        Returns:
            A ``(updated_hidden_states, probabilities)`` tuple: the new
            per-layer hidden states packed into one array, and the softmax
            output of the dense head.
        """
        updated_hidden_states = []
        for hidden_state in hidden_states:
            # Each layer's output becomes the next layer's input.
            updated_hidden_state, x = RecurrentLayer()(hidden_state, x)
            updated_hidden_states.append(updated_hidden_state)

        # Iterate over the sizes themselves; range() cannot take a list.
        for dim in self.linear_sizes[:-1]:
            x = nn.Dense(features=dim, kernel_init=lecun_normal())(x)
            x = nn.relu(x)
        # The output width is the last element, not the leading slice.
        x = nn.Dense(features=self.linear_sizes[-1], kernel_init=lecun_normal())(x)
        x = nn.softmax(x)

        return jnp.array(updated_hidden_states), x


In [7]:
# Dense-head widths; the final entry (3) is the width of the softmax output.
layers = [128, 128, 3]
model = RNN(linear_sizes=layers)
rng = jax.random.PRNGKey(42)
rng, hidden_init_rng = jax.random.split(rng, 2)
# One initial hidden state of width 128 for each of the 3 recurrent layers.
# float32 is requested explicitly: without jax_enable_x64, a float64 request
# is not honoured by JAX, and float32 matches batch_in and the Dense params.
hidden_states = lecun_normal()(hidden_init_rng, shape=(3, 128), dtype=jnp.float32)

# Dummy input of shape (10, 10), used only to trace shapes during init.
batch_in = jnp.ones((10, 10))
rng, init_rng = jax.random.split(rng, 2)
params = model.init(init_rng, hidden_states, batch_in)

# LSTMs

In [None]:
class LSTMCell(nn.Module):
    """One LSTM step built from four gate Dense layers plus an output Dense."""

    @nn.compact
    def __call__(self, hidden_state, long_term_memory, x):
        """Advance the cell by one step.

        Args:
            hidden_state: Previous hidden state; its last axis is the hidden
                width.
            long_term_memory: Previous cell state; its last axis is the
                cell width used for every gate.
            x: Input for this step, concatenated with ``hidden_state`` along
                the last axis.

        Returns:
            A ``(h_t, c_t)`` tuple with the new hidden state and cell state.
        """
        # Gate widths come from the feature (last) axis. shape[0] is the
        # batch axis for batched inputs, which made the gates unable to
        # broadcast against the cell state.
        memory_dim = long_term_memory.shape[-1]
        hidden_dim = hidden_state.shape[-1]
        concat = jnp.concatenate([hidden_state, x], axis=-1)

        # How much of the previous cell state to keep.
        forget_gate = nn.sigmoid(nn.Dense(features=memory_dim, kernel_init=lecun_normal())(concat))
        intermediate_c = long_term_memory * forget_gate

        # How much of the new candidate values to write into the cell state.
        input_gate = nn.sigmoid(nn.Dense(features=memory_dim, kernel_init=lecun_normal())(concat))
        candidate_gate = nn.tanh(nn.Dense(features=memory_dim, kernel_init=lecun_normal())(concat))
        c_t = input_gate * candidate_gate + intermediate_c

        # How much of long term memory should be added to hidden state
        output_gate = nn.sigmoid(nn.Dense(features=memory_dim, kernel_init=lecun_normal())(concat))

        # NOTE(review): a textbook LSTM uses h_t = output_gate * tanh(c_t); the
        # extra Dense + tanh here projects to the hidden width and is kept
        # as written.
        h_t = nn.tanh(nn.Dense(features=hidden_dim, kernel_init=lecun_normal())(output_gate * nn.tanh(c_t)))
        return h_t, c_t