In [11]:
import jax
import jax.ops
import jax.numpy as jnp

import flax
from flax import linen as nn
from flax import optim
from flax.linen.recurrent import RNNCellBase

import optax

import numpy as np  # convention: original numpy

from typing import Any, Callable, Sequence, Optional, Tuple, Union
from collections import defaultdict


In [12]:
# Fixed seed so that a Restart & Run All reproduces the same random numbers.
seed = 1701
key = jax.random.PRNGKey(seed=seed)


In [13]:
# Derive `num_copies` independent sub-keys from the root `key`; the unpacking
# below must list exactly `num_copies` names.
num_copies = 5
rng, key2, key3, key4, key5 = jax.random.split(key, num_copies)


## RNN Cells

In [None]:
class RNNCell(RNNCellBase):
    """Elman (vanilla) RNN cell whose carry is the hidden state alone.

    The hidden size is taken from the last axis of the incoming carry, so the
    initial carry must already have the desired hidden width.
    """

    @nn.compact
    def __call__(self, carry, input):
        """Apply one recurrent step.

        Description:
            H_{t} = tanh(H_{t-1} @ W_{hh} + b_{h} + x_{t} @ W_{xh} + b_{x})

        Args:
            carry (jnp.ndarray): hidden state H_{t-1} from the previous time
                step; its last axis is the hidden size.
            input (jnp.ndarray): input vector x_{t}.

        Returns:
            A tuple ``(new_carry, output)``; for this cell both are H_{t}.
        """
        ht_1 = carry

        h_t = self.rnn_update(input, ht_1)

        return h_t, h_t

    def rnn_update(self, input, ht_1):
        """Compute H_{t} from x_{t} and H_{t-1}.

        Creates ``nn.Dense`` submodules inline, so it is meant to be called
        from the compact ``__call__`` above.
        """
        hidden_size = ht_1.shape[-1]

        # Both projections map into the hidden space so they can be summed even
        # when the input width differs from the hidden width.
        W_hh = nn.Dense(hidden_size)(ht_1)
        W_xh = nn.Dense(hidden_size)(input)
        h_t = jnp.tanh(W_hh + W_xh)

        return h_t


In [None]:
class LSTMCell(RNNCellBase):
    """Long short-term memory (LSTM) cell with carry ``(h, c)``.

    The hidden size is taken from the last axis of ``h`` in the incoming carry.
    """

    @nn.compact
    def __call__(self, carry, input):
        """Apply one recurrent step.

        Description:
            i_{t} = sigmoid((W_{ii} @ x_{t} + b_{ii}) + (W_{hi} @ h_{t-1} + b_{hi}))
            f_{t} = sigmoid((W_{if} @ x_{t} + b_{if}) + (W_{hf} @ h_{t-1} + b_{hf}))
            g_{t} = tanh((W_{ig} @ x_{t} + b_{ig}) + (W_{hg} @ h_{t-1} + b_{hg}))
            o_{t} = sigmoid((W_{io} @ x_{t} + b_{io}) + (W_{ho} @ h_{t-1} + b_{ho}))
            c_{t} = f_{t} * c_{t-1} + i_{t} * g_{t}
            h_{t} = o_{t} * tanh(c_{t})

        Args:
            carry (tuple): ``(h_{t-1}, c_{t-1})`` from the previous time step.
            input (jnp.ndarray): input vector x_{t}.

        Returns:
            A tuple ``((h_t, c_t), h_t)``: the new carry and the output.
        """
        ht_1, ct_1 = carry

        # rnn_update returns (h_t, c_t) in that order; unpacking it the other
        # way round would swap the hidden and cell states.
        h_t, c_t = self.rnn_update(input, ht_1, ct_1)

        return (h_t, c_t), h_t

    def rnn_update(self, input, ht_1, ct_1):
        """Compute ``(h_t, c_t)`` from x_{t}, h_{t-1} and c_{t-1}.

        Creates ``nn.Dense`` submodules inline, so it is meant to be called
        from the compact ``__call__`` above.
        """
        # Every gate lives in the hidden space, whatever the input width is.
        hidden_size = ht_1.shape[-1]

        i_ta = nn.Dense(hidden_size)(input)
        i_tb = nn.Dense(hidden_size)(ht_1)
        i_t = jax.nn.sigmoid(i_ta + i_tb)  # input gate

        o_ta = nn.Dense(hidden_size)(input)
        o_tb = nn.Dense(hidden_size)(ht_1)
        o_t = jax.nn.sigmoid(o_ta + o_tb)  # output gate

        f_ta = nn.Dense(hidden_size)(input)  # W_{if} @ x_{t} + b_{if}
        f_tb = nn.Dense(hidden_size)(ht_1)  # W_{hf} @ h_{t-1} + b_{hf}
        f_t = jax.nn.sigmoid(f_ta + f_tb)  # forget gate

        g_ta = nn.Dense(hidden_size)(input)  # W_{ig} @ x_{t} + b_{ig}
        g_tb = nn.Dense(hidden_size)(ht_1)  # W_{hg} @ h_{t-1} + b_{hg}
        g_t = jnp.tanh(g_ta + g_tb)  # candidate cell input

        c_t = (f_t * ct_1) + (i_t * g_t)  # internal cell state update

        h_t = o_t * jnp.tanh(c_t)  # hidden state update

        return h_t, c_t


In [None]:
class GRUCell(RNNCellBase):
    """Gated recurrent unit (GRU) cell whose carry is the hidden state alone.

    The hidden size is taken from the last axis of the incoming carry.
    """

    @nn.compact
    def __call__(self, carry, input):
        """Apply one recurrent step.

        Description:
            z_t = sigmoid((W_{iz} @ x_{t} + b_{iz}) + (W_{hz} @ h_{t-1} + b_{hz}))
            r_t = sigmoid((W_{ir} @ x_{t} + b_{ir}) + (W_{hr} @ h_{t-1} + b_{hr}))
            g_t = tanh((W_{ig} @ x_{t} + b_{ig}) + r_t * (W_{hg} @ h_{t-1} + b_{hg}))
            h_t = (z_t * h_{t-1}) + ((1 - z_t) * g_t)

        Args:
            carry (jnp.ndarray): hidden state h_{t-1} from the previous time step.
            input (jnp.ndarray): input vector x_{t}.

        Returns:
            A tuple ``(new_carry, output)``; for this cell both are h_t.
        """
        ht_1 = carry

        h_t = self.rnn_update(input, ht_1)

        return h_t, h_t

    def rnn_update(self, input, ht_1):
        """Compute h_t from x_{t} and h_{t-1}.

        Creates ``nn.Dense`` submodules inline, so it is meant to be called
        from the compact ``__call__`` above.
        """
        hidden_size = ht_1.shape[-1]

        z_ta = nn.Dense(hidden_size)(input)
        z_tb = nn.Dense(hidden_size)(ht_1)
        z_t = jax.nn.sigmoid(z_ta + z_tb)  # update gate

        r_ta = nn.Dense(hidden_size)(input)
        r_tb = nn.Dense(hidden_size)(ht_1)
        r_t = jax.nn.sigmoid(r_ta + r_tb)  # reset gate

        g_ta = nn.Dense(hidden_size)(input)
        g_tb = nn.Dense(hidden_size)(ht_1)
        # The reset gate scales only the recurrent term, not the whole sum.
        g_t = jnp.tanh(g_ta + r_t * g_tb)  # candidate hidden state

        # Interpolate between the previous state and the candidate, as in the
        # docstring formula.
        h_t = (z_t * ht_1) + ((1 - z_t) * g_t)

        return h_t


In [None]:
class HiPPOCell(nn.Module):
    '''
        Description:
            Runs an RNN cell, projects its hidden state and passes the
            projection to a HiPPO memory.

            RNN update function
            τ(h, x) = (1 - g(h, x)) ◦ h + g(h, x) ◦ tanh(Lτ(h, x))
            g(h, x) = σ(Lg(h, x))

        Args:
            hippo (HiPPO): hippo model object
            cell (RNNCellBase): choice of RNN cell object
                - RNNCell
                - LSTMCell
                - GRUCell
    '''
    # NOTE(review): `HiPPO` is not defined in this part of the notebook; the
    # string annotation keeps class creation from raising NameError when it
    # is defined elsewhere or later.
    hippo: "HiPPO"
    cell: RNNCellBase

    @nn.compact
    def __call__(self, carry, input):
        '''
        Description:
            RNN update function
            τ(h, x) = (1 - g(h, x)) ◦ h + g(h, x) ◦ tanh(Lτ(h, x))
            g(h, x) = σ(Lg(h, x))

        Args:
            carry: state from the previous time step; passed unchanged to
                `self.cell`.
            input (jnp.ndarray): input vector

        Returns:
            A tuple `((h_t, c_t), h_t)`: the new carry and the output.
        '''
        # NOTE(review): the cell's own new carry is discarded, and the carry
        # returned below is (h_t, c_t), which is handed whole to `self.cell`
        # on the next step. That only lines up if the cell accepts this pair
        # as its carry; confirm the intended state threading.
        _, h_t = self.cell(carry, input)

        y_t = nn.Dense(input.shape[-1])(h_t)  # f_t in the paper

        # NOTE(review): call signature of `self.hippo` kept from the original
        # sketch; confirm it against the HiPPO definition.
        c_t = self.hippo(y_t, init_state=None, kernel=False)

        return (h_t, c_t), h_t
        
        

## Deep RNN 

In [None]:
# TODO: refer to https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/recurrent.py#L714-L762
# also refer to https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=DeepRNN#deeprnn
class _DeepRNN(nn.Module):
    """Stack of layers applied in order, ported from haiku's ``DeepRNN``.

    Layers that are ``RNNCellBase`` instances are recurrent: each consumes one
    entry of the carry tuple, in layer order. Any other layer is called as a
    plain function of the running input.

    Attributes:
        layers: layers to apply, first to last.
        skip_connections: if True, every layer after the first also receives
            the original input concatenated on the last axis, and the output is
            the concatenation of all recurrent layers' outputs.
    """

    layers: Sequence[Any]
    skip_connections: bool = False
    # The original sketch also declared `name: Optional[str]`; Flax modules
    # already provide a `name` field and reject redeclaring it.

    def setup(self):
        if self.skip_connections:
            for layer in self.layers:
                assert isinstance(layer, nn.Module), "layer must be a nn.Module or an instance of a subclass of nn.Module"

    def __call__(self, carry, input):
        """Apply every layer once.

        Args:
            carry: tuple with one entry per recurrent layer, in layer order.
            input: input to the first layer.

        Returns:
            ``(output, next_carry)``. This is haiku's (output, state) order;
            note that the Flax cells themselves return (carry, output).
        """
        def concat(*args):
            return jnp.concatenate(args, axis=-1)

        current_inputs = input
        next_states = []
        outputs = []
        state_idx = 0
        for idx, layer in enumerate(self.layers):
            if self.skip_connections and idx > 0:
                current_inputs = jax.tree_util.tree_map(concat, input, current_inputs)

            if isinstance(layer, RNNCellBase):
                # Flax cells use (carry, x) -> (new_carry, y).
                next_state, current_inputs = layer(carry[state_idx], current_inputs)
                outputs.append(current_inputs)
                next_states.append(next_state)
                state_idx += 1
            else:
                current_inputs = layer(current_inputs)

        if self.skip_connections:
            out = jax.tree_util.tree_map(concat, *outputs)
        else:
            out = current_inputs

        return out, tuple(next_states)

    def initial_state(self, batch_size: Optional[int]):
        """Collect the initial carry of every recurrent layer, in layer order."""
        # NOTE(review): the cells defined in this notebook do not implement
        # `initial_state`; this mirrors haiku's API and needs a per-cell carry
        # initializer before it can be used.
        return tuple(
            layer.initial_state(batch_size)
            for layer in self.layers
            if isinstance(layer, RNNCellBase))

## RNN Training Types
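Refer to the [Stanford CS 230 recurrent neural networks cheatsheet](https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-recurrent-neural-networks) for the one-to-one, one-to-many, many-to-one and many-to-many architectures implemented below.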

In [None]:
def train_one2one(params, init_carry, input, sequence_length, cell, readout=None):
    """One-to-one: a single recurrent step followed by a read-out.

    Args:
        params: variables passed to ``cell.apply``.
        init_carry: carry for the single step.
        input (jnp.ndarray): the single input x.
        sequence_length: unused; kept for signature compatibility.
        cell: object whose ``apply(params, carry, x)`` returns
            ``(new_carry, h)``.
        readout (callable, optional): maps ``h`` to the output. When omitted,
            an inline ``nn.Dense`` with ``input.shape[-1]`` features is used.
            That only works when this function runs inside a Flax
            ``@nn.compact`` method, because an unbound Dense cannot be called.

    Returns:
        The output y.
    """
    _, h_t = cell.apply(params, init_carry, input)
    if readout is not None:
        return readout(h_t)
    return nn.Dense(input.shape[-1])(h_t)  # output


In [None]:
def train_one2many(params, init_carry, input, T_y, cell, tf_bool=True, readout=None):
    """One-to-many: unroll ``cell`` to produce a sequence of outputs.

    Args:
        params: variables passed to ``cell.apply``.
        init_carry: carry for the first step.
        input: indexable sequence of input vectors. With ``tf_bool=False``
            only ``input[0]`` is used.
        T_y: sized object whose ``len`` is the number of steps when
            ``tf_bool=False`` (ignored otherwise).
        cell: object whose ``apply(params, carry, x)`` returns
            ``(new_carry, h)``.
        tf_bool (bool): teacher forcing. If True, one step is run per element
            of ``input``, each consuming that element. If False, step 0
            consumes ``input[0]`` and each later step consumes the previous
            output.
        readout (callable, optional): maps ``h`` to the output ``y``. When
            omitted, an inline ``nn.Dense`` is used, which only works when this
            function runs inside a Flax ``@nn.compact`` method.

    Returns:
        list of outputs, one per step.
    """
    def _read_out(h_t, like):
        # `like` sets the width of the default Dense read-out.
        if readout is not None:
            return readout(h_t)
        return nn.Dense(like.shape[-1])(h_t)

    output = []
    carry = init_carry
    if not tf_bool:
        for i in range(len(T_y)):
            # Free-running: feed the previous prediction back in.
            x_t = input[0] if i == 0 else output[i - 1]
            carry, h_t = cell.apply(params, carry, x_t)
            # Outputs are fed back as inputs, so they take the input width.
            output.append(_read_out(h_t, input[0]))
    else:
        for x_t in input:
            # Thread the carry across steps instead of restarting from init_carry.
            carry, h_t = cell.apply(params, carry, x_t)
            output.append(_read_out(h_t, x_t))

    return output


In [None]:
def train_many2one(params, init_carry, input, T_x, cell, tf_bool=True, readout=None):
    """Many-to-one: consume ``len(T_x)`` inputs and read out the last state.

    Args:
        params: variables passed to ``cell.apply``.
        init_carry: carry for the first step.
        input: indexable sequence of input vectors.
        T_x: sized object whose ``len`` is the number of steps.
        cell: object whose ``apply(params, carry, x)`` returns
            ``(new_carry, h)``.
        tf_bool: unused; kept for signature parity with the other helpers.
        readout (callable, optional): maps the final ``h`` to the output. When
            omitted, an inline ``nn.Dense`` is used, which only works when this
            function runs inside a Flax ``@nn.compact`` method.

    Returns:
        The output for the last step, or None when ``len(T_x) == 0``.
    """
    n_steps = len(T_x)
    if n_steps == 0:
        return None

    carry = init_carry
    for i in range(n_steps):
        carry, h_t = cell.apply(params, carry, input[i])

    # Only the final hidden state is read out, so project once after the loop.
    if readout is not None:
        return readout(h_t)
    return nn.Dense(input[n_steps - 1].shape[-1])(h_t)  # output


In [None]:
def train_many2many(params, init_carry, input, T_xy, cell, tf_bool=True, readout=None):
    """Many-to-many with equal input/output length: one output per input.

    Args:
        params: variables passed to ``cell.apply``.
        init_carry: carry for the first step.
        input: indexable sequence of input vectors.
        T_xy: sized object whose ``len`` is the number of steps.
        cell: object whose ``apply(params, carry, x)`` returns
            ``(new_carry, h)``.
        tf_bool: unused; kept for signature parity with the other helpers.
        readout (callable, optional): maps each ``h`` to an output. When
            omitted, an inline ``nn.Dense`` is used, which only works when this
            function runs inside a Flax ``@nn.compact`` method.

    Returns:
        list of ``len(T_xy)`` outputs.
    """
    output = []
    carry = init_carry
    for i in range(len(T_xy)):
        x_t = input[i]
        carry, h_t = cell.apply(params, carry, x_t)
        if readout is not None:
            y_t = readout(h_t)
        else:
            y_t = nn.Dense(x_t.shape[-1])(h_t)  # output
        output.append(y_t)

    return output


In [None]:
def train_many2many_encoder_decoder(params, init_carry, input, T_x, T_y, cell, tf_bool=True, readout=None):
    """Many-to-many with unequal input/output lengths (encoder-decoder).

    This was previously also named ``train_many2many``, which silently replaced
    the equal-length helper of that name; it has its own name now.

    Args:
        params: variables passed to ``cell.apply``.
        init_carry: carry for the first encoder step.
        input: indexable sequence of input vectors.
        T_x: sized object whose ``len`` is the number of encoder steps.
        T_y: sized object whose ``len`` is the number of decoder steps.
        cell: object whose ``apply(params, carry, x)`` returns
            ``(new_carry, h)``.
        tf_bool: unused; kept for signature parity with the other helpers.
        readout (callable, optional): maps each decoder ``h`` to an output.
            When omitted, an inline ``nn.Dense`` is used, which only works when
            this function runs inside a Flax ``@nn.compact`` method.

    Returns:
        list of ``len(T_y)`` outputs.

    Raises:
        AssertionError: if ``len(T_x) == len(T_y)``.
    """
    # T_x / T_y are only ever used through len(), so compare their lengths.
    assert len(T_x) != len(T_y), "T_x and T_y must be different"
    output = []
    carry = init_carry

    # Encoder: consume len(T_x) inputs without emitting outputs.
    for i in range(len(T_x)):
        carry, _ = cell.apply(params, carry, input[i])

    # Decoder: emit len(T_y) outputs starting from the encoder's final carry.
    # NOTE(review): the decoder re-reads input[0 .. len(T_y) - 1], the same
    # inputs the encoder started with; confirm whether it should consume
    # later inputs or its own previous outputs instead.
    for i in range(len(T_y)):
        x_t = input[i]
        carry, h_t = cell.apply(params, carry, x_t)
        if readout is not None:
            y_t = readout(h_t)
        else:
            y_t = nn.Dense(x_t.shape[-1])(h_t)  # output
        output.append(y_t)

    return output
    

## Data Preprocessing


In [None]:
def map_data_2_id(iterable):
    """
    Build forward and reverse lookup tables between ids and elements.

    Args:
        iterable: elements to index. It is consumed exactly once, so one-shot
            iterators and generators work. Elements must be hashable.

    Returns:
        (id_2_data, data_2_id): ``id_2_data`` maps position -> element and
        ``data_2_id`` maps element -> position. If an element occurs more
        than once, ``data_2_id`` keeps its last position.
    """
    # Materialise once; iterating `iterable` twice would leave data_2_id
    # empty for generators.
    id_2_data = dict(enumerate(iterable))
    # Insertion order is preserved, so later duplicates overwrite earlier ones.
    data_2_id = {elem: idx for idx, elem in id_2_data.items()}

    return (id_2_data, data_2_id)


In [None]:
def one_hot(i, n):
    """
    Create a float vector of length ``n`` with a 1 at index ``i`` and 0 elsewhere.
    """
    return jnp.zeros(n).at[i].set(1)


In [None]:
def encode(char, data_2_id=None):
    """One-hot encode ``char`` with a symbol -> id vocabulary.

    Args:
        char: symbol to encode; must be a key of the vocabulary.
        data_2_id (dict, optional): vocabulary, e.g. the second element
            returned by ``map_data_2_id``. If omitted, a module-level
            ``data_2_id`` is used, which keeps older call sites working.

    Returns:
        A vector of length ``len(data_2_id)`` from ``one_hot``.

    Raises:
        NameError: if no vocabulary is passed and no module-level
            ``data_2_id`` exists.
        KeyError: if ``char`` is not in the vocabulary.
    """
    if data_2_id is None:
        try:
            data_2_id = globals()["data_2_id"]
        except KeyError:
            raise NameError("name 'data_2_id' is not defined") from None
    return one_hot(data_2_id[char], len(data_2_id))


In [None]:
def decode(predictions, id_2_data):
    """Return the symbol whose id has the highest score in ``predictions``."""
    best_id = int(jnp.argmax(predictions))
    return id_2_data[best_id]


## Optimizer Helpers


See the [optax quick-start notebook](https://github.com/deepmind/optax/blob/master/examples/quick_start.ipynb) for an end-to-end example of building an optimizer and applying its updates.

In [None]:
def pick_optimizer_fn(name, starting_learning_rate):
    """Build an optax optimizer by name.

    Args:
        name (str): one of "sgd", "adam", "adagrad", "rmsprop".
        starting_learning_rate: passed unchanged to the optax constructor.

    Returns:
        The optimizer returned by the matching optax constructor.

    Raises:
        ValueError: for any other name.
    """
    # refer to https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules
    supported = ("sgd", "adam", "adagrad", "rmsprop")
    if name not in supported:
        raise ValueError("optimizer name not recognized")
    # Each supported name is also the attribute name of its optax constructor.
    return getattr(optax, name)(starting_learning_rate)


In [None]:
def pick_scheduler_fn(start_learning_rate, steps, decay_rate, init_value, end_val, name="constant"):
    """Build an optax learning-rate schedule by name.

    Args:
        start_learning_rate: constant value ("constant") or initial value
            ("exp_decay").
        steps: ``transition_steps`` for "exp_decay" and "linear".
        decay_rate: decay factor for "exp_decay".
        init_value: initial value for "linear".
        end_val: final value for "linear".
        name (str): "constant", "exp_decay" or "linear". Previously this was
            read from an undefined global.

    Returns:
        The schedule built by the matching optax function.

    Raises:
        ValueError: for any other name (consistent with ``pick_optimizer_fn``).
    """
    # refer to https://optax.readthedocs.io/en/latest/api.html#schedules
    if name == "constant":
        return optax.constant_schedule(start_learning_rate)
    if name == "exp_decay":
        # Use the caller's steps / decay_rate instead of hard-coded 1000 / 0.99.
        return optax.exponential_decay(
            init_value=start_learning_rate, transition_steps=steps, decay_rate=decay_rate
        )
    if name == "linear":
        return optax.linear_schedule(
            init_value=init_value, end_value=end_val, transition_steps=steps
        )
    raise ValueError("scheduler name not recognized")


In [None]:
# TODO: add transformations
# refer to https://optax.readthedocs.io/en/latest/api.html#optax-transformations


In [None]:
#     # A simple update loop.
#     for _ in range(1000):
#     grads = jax.grad(compute_loss)(params, xs, ys)
#     updates, opt_state = gradient_transform.update(grads, opt_state)
#     params = optax.apply_updates(params, updates)

#     assert jnp.allclose(params, target_params), \
#     'Optimization should retrieve the target params used to generate the data.'


In [None]:
# def optimizer_fn(start_learning_rate, params, num_weights, x, y):
#     optimizer = optax.adam(start_learning_rate)
#     # Obtain the `opt_state` that contains statistics for the optimizer.
#     params = {'w': jnp.ones((num_weights,))}
#     opt_state = optimizer.init(params)
    
#     compute_loss = lambda params, x, y: optax.l2_loss(params['w'].dot(x), y)
#     grads = jax.grad(compute_loss)(params, xs, ys)
    
#     updates, opt_state = optimizer.update(grads, opt_state)
#     params = optax.apply_updates(params, updates)
    
    

## Training Loop


In [None]:
def train():
    """Training-loop placeholder: not implemented yet and returns None."""
    return None
