In [11]:
import jax
import jax.ops
import jax.numpy as jnp

import flax
from flax import linen as nn
from flax import optim
from flax.linen.recurrent import RNNCellBase

import optax

import numpy as np  # convention: original numpy

from typing import Any, Callable, Sequence, Optional, Tuple, Union
from collections import defaultdict


In [12]:
# Fixed seed so that every run of the notebook starts from the same PRNG state.
seed = 1701
# Root PRNG key derived from the seed.
key = jax.random.PRNGKey(seed)


In [13]:
# Number of keys to derive from the root key.
num_copies = 5
# Split the root key into `num_copies` keys and bind each one to its own name.
rng, key2, key3, key4, key5 = jax.random.split(key, num=num_copies)


## Helper Functions

In [None]:
def add_batch(nest, batch_size: Optional[int]):
    """Adds a batch dimension at axis 0 to the leaves of a nested structure.

    Args:
        nest: pytree whose leaves expose ``.shape`` (e.g. ``jnp.ndarray``).
        batch_size: size of the new leading axis. ``None`` means "no batching"
            and returns ``nest`` unchanged.

    Returns:
        A pytree with the same structure as ``nest`` in which every leaf has
        been broadcast to shape ``(batch_size,) + leaf.shape``.
    """
    if batch_size is None:
        return nest

    def broadcast(x):
        return jnp.broadcast_to(x, (batch_size,) + x.shape)

    # `jax.tree_map` was deprecated and later removed from the top-level
    # namespace; `jax.tree_util.tree_map` is the stable spelling.
    return jax.tree_util.tree_map(broadcast, nest)


## RNN Cells

In [None]:
class RNNCell(RNNCellBase):
    # Width of the hidden state produced by both dense projections.
    hidden_size: int
    # Declared alongside hidden_size; not read anywhere in this class.
    output_size: int

    # `nn.compact` is a plain decorator: it must be applied as `@nn.compact`.
    # Writing `@nn.compact()` calls it without the function and raises.
    @nn.compact
    def __call__(self, carry, input):
        """
        Description:
            One step of a simple recurrent cell.

            W_hh = H_{t-1} @ W_{hh} + b_{hh} - dense projection of the previous hidden state
            W_xh = x_{t} @ W_{xh} + b_{xh} - dense projection of the current input

            H_{t} = f_{w}(H_{t-1}, x_{t})
            H_{t} = relu(W_hh + W_xh)

        Args:
            carry (tuple): pair whose first element is the hidden state from
                the previous time step; the second element is ignored.
            input (jnp.ndarray): input vector for this step

        Returns:
            A tuple with the new carry ``(h_t, h_t)`` and the output ``h_t``.
        """
        ht_1, _ = carry

        h_t = self.rnn_update(input, ht_1)

        return (h_t, h_t), h_t

    def rnn_update(self, input, ht_1):
        """Computes the new hidden state from the input and previous hidden state.

        Must be called from within ``__call__`` so that the inline ``nn.Dense``
        submodules are created inside the compact scope.
        """
        W_hh = nn.Dense(self.hidden_size)(ht_1)
        W_xh = nn.Dense(self.hidden_size)(input)
        # NOTE: the activation here is ReLU, not the tanh of the textbook
        # Elman formulation.
        h_t = nn.relu(W_hh + W_xh)  # H_{t} = relu((H_{t-1} @ W_{hh}) + (x_{t} @ W_{xh}))

        return h_t

    def initial_state(self, batch_size: Optional[int]):
        """Returns a zero hidden state of shape ``(hidden_size,)``.

        When ``batch_size`` is not None a leading batch axis is added via
        ``add_batch``.

        NOTE(review): this returns a single array, while ``__call__`` unpacks
        its carry as a 2-tuple - callers presumably have to pair the state up
        themselves; confirm.
        """
        state = jnp.zeros([self.hidden_size])
        if batch_size is not None:
            state = add_batch(state, batch_size)
        return state


In [None]:
class LSTMCell(RNNCellBase):
    # Width of the hidden state, the cell state and every gate.
    hidden_size: int
    # Declared alongside hidden_size; not read anywhere in this class.
    output_size: int

    # `nn.compact` must be applied bare; `@nn.compact()` raises at class
    # definition time because the decorated function is never passed in.
    @nn.compact
    def __call__(self, carry, input):
        """
        Description:
            i_{t} = sigmoid((W_{ii} @ x_{t} + b_{ii}) + (W_{hi} @ h_{t-1} + b_{hi}))
            f_{t} = sigmoid((W_{if} @ x_{t} + b_{if}) + (W_{hf} @ h_{t-1} + b_{hf}))
            g_{t} = tanh((W_{ig} @ x_{t} + b_{ig}) + (W_{hg} @ h_{t-1} + b_{hg}))
            o_{t} = sigmoid((W_{io} @ x_{t} + b_{io}) + (W_{ho} @ h_{t-1} + b_{ho}))
            c_{t} = f_{t} * c_{t-1} + i_{t} * g_{t}
            h_{t} = o_{t} * tanh(c_{t})

        Args:
            carry (tuple): ``(h_{t-1}, c_{t-1})`` - hidden and cell state from
                the previous time step
            input (jnp.ndarray): input vector for this step

        Returns:
            A tuple with the new carry ``(h_t, c_t)`` and the output ``h_t``.
        """
        ht_1, ct_1 = carry

        # rnn_update returns (h_t, c_t); unpack in that same order so the
        # hidden and cell states are not swapped.
        h_t, c_t = self.rnn_update(input, ht_1, ct_1)

        return (h_t, c_t), h_t

    def rnn_update(self, input, ht_1, ct_1):
        """Computes ``(h_t, c_t)`` from the input and the previous states.

        Must be called from within ``__call__`` so that the inline ``nn.Dense``
        submodules are created inside the compact scope.
        """
        i_ta = nn.Dense(self.hidden_size)(input)
        i_tb = nn.Dense(self.hidden_size)(ht_1)
        i_t = nn.sigmoid(i_ta + i_tb)  # input gate

        o_ta = nn.Dense(self.hidden_size)(input)
        o_tb = nn.Dense(self.hidden_size)(ht_1)
        o_t = nn.sigmoid(o_ta + o_tb)  # output gate

        f_ia = nn.Dense(self.hidden_size)(
            input
        )  # b^{f}_{i} + \sum\limits_{j} U^{f}_{i, j} x^{t}_{j}
        f_ib = nn.Dense(self.hidden_size)(
            ht_1
        )  # \sum\limits_{j} W^{f}_{i, j} h^{(t-1)}_{j}
        f_i = nn.sigmoid(f_ia + f_ib)  # forget gate

        g_ia = nn.Dense(self.hidden_size)(
            input
        )  # b^{g}_{i} + \sum\limits_{j} U^{g}_{i, j} x^{t}_{j}
        g_ib = nn.Dense(self.hidden_size)(
            ht_1
        )  # \sum\limits_{j} W^{g}_{i, j} h^{(t-1)}_{j}
        g_i = nn.tanh(g_ia + g_ib)  # candidate cell state (g_t in the docstring)

        c_t = (f_i * ct_1) + (i_t * g_i)  # internal cell state update

        h_t = o_t * nn.tanh(c_t)  # hidden state update

        return h_t, c_t

    def initial_state(self, batch_size: Optional[int]):
        """Returns a zero state of shape ``(hidden_size,)``.

        When ``batch_size`` is not None a leading batch axis is added via
        ``add_batch``.

        NOTE(review): this returns a single array, while ``__call__`` expects
        a ``(h, c)`` pair - callers presumably build the pair themselves;
        confirm.
        """
        state = jnp.zeros([self.hidden_size])
        if batch_size is not None:
            state = add_batch(state, batch_size)
        return state


In [None]:
class GRUCell(RNNCellBase):
    # Width of the hidden state and of every gate.
    hidden_size: int

    # `nn.compact` must be applied bare; `@nn.compact()` raises because the
    # decorated function is never passed in.
    @nn.compact
    def __call__(self, carry, input):
        """
        Description:
            z_t = sigmoid((W_{iz} @ x_{t} + b_{iz}) + (W_{hz} @ h_{t-1} + b_{hz}))
            r_t = sigmoid((W_{ir} @ x_{t} + b_{ir}) + (W_{hr} @ h_{t-1} + b_{hr}))
            g_t = tanh((W_{ig} @ x_{t} + b_{ig}) + r_t * (W_{hg} @ h_{t-1} + b_{hg}))
            h_t = ((1 - z_t) * h_{t-1}) + (z_t * g_t)

        Args:
            carry (jnp.ndarray | tuple): hidden state from the previous time
                step, either as a bare array or as the ``(h, h)`` pair this
                method returns (only the first element is used)
            input (jnp.ndarray): input vector for this step

        Returns:
            A tuple with the new carry ``(h_t, h_t)`` and the output ``h_t``.
        """
        # This method returns its carry as a pair, so accept that pair back
        # as well as the bare array the original code expected.
        ht_1 = carry[0] if isinstance(carry, (tuple, list)) else carry

        h_t = self.rnn_update(input, ht_1)

        return (h_t, h_t), h_t

    def rnn_update(self, input, ht_1):
        """Computes the new hidden state from the input and previous hidden state.

        Must be called from within ``__call__`` so that the inline ``nn.Dense``
        submodules are created inside the compact scope.
        """
        z_ta = nn.Dense(self.hidden_size)(input)
        z_tb = nn.Dense(self.hidden_size)(ht_1)
        z_t = nn.sigmoid(z_ta + z_tb)  # update gate

        r_ta = nn.Dense(self.hidden_size)(input)
        r_tb = nn.Dense(self.hidden_size)(ht_1)
        r_t = nn.sigmoid(r_ta + r_tb)  # reset gate

        g_ta = nn.Dense(self.hidden_size)(input)
        g_tb = nn.Dense(self.hidden_size)(ht_1)
        # The reset gate scales only the hidden-state projection; it is not
        # added to the input projection before the product.
        g_t = nn.tanh(g_ta + (r_t * g_tb))  # candidate hidden state

        # z_t close to 0 keeps the old state, z_t close to 1 takes the candidate.
        h_t = ((1 - z_t) * ht_1) + (z_t * g_t)  # hidden state update

        return h_t

    def initial_state(self, batch_size: Optional[int]):
        """Returns a zero hidden state of shape ``(hidden_size,)``.

        When ``batch_size`` is not None a leading batch axis is added via
        ``add_batch``.
        """
        state = jnp.zeros([self.hidden_size])
        if batch_size is not None:
            state = add_batch(state, batch_size)
        return state


In [None]:
# The base class is `nn.Module` (the class); `nn.module` is the Python module
# that defines it and cannot be subclassed.
class HiPPOCell(nn.Module):
    """
    Description:
        RNN update function
        τ(h, x) = (1 - g(h, x)) ◦ h + g(h, x) ◦ tanh(Lτ (h, x))
        g(h, x) = σ(Lg(h,x))

        Wraps a recurrent cell: the cell output is projected by a dense layer
        and passed to the HiPPO operator, whose result becomes the second
        element of the carry.

    Args:
        hidden_size (int): size of the zero state built by ``initial_state``
        output_size (int): width of the dense projection fed to ``hippo``
        hippo (HiPPO): # hippo model object
        cell (RNNCellBase): choice of RNN cell object
            - RNNCell
            - LSTMCell
            - GRUCell
    """

    hidden_size: int
    output_size: int

    hippo: HiPPO
    cell: RNNCellBase

    # `nn.compact` must be applied bare; `@nn.compact()` raises because the
    # decorated function is never passed in.
    @nn.compact
    def __call__(self, carry, input):
        """
        Description:
            RNN update function
            τ(h, x) = (1 - g(h, x)) ◦ h + g(h, x) ◦ tanh(Lτ (h, x))
            g(h, x) = σ(Lg(h,x))

        Args:
            carry: carry from the previous time step, in whatever form
                ``self.cell`` expects
            input (jnp.ndarray): input vector for this step

        Returns:
            A tuple with the new carry ``(h_t, c_t)`` and the output ``h_t``,
            where ``c_t`` is the value returned by ``self.hippo``.
        """

        # Only the cell's output is used; the carry it returns is discarded
        # and replaced by (h_t, c_t) below.
        _, h_t = self.cell(carry, input)

        y_t = nn.Dense(self.output_size)(h_t)  # f_t in the paper

        c_t = self.hippo(y_t, init_state=None, kernel=False)

        return (h_t, c_t), h_t

    def initial_state(self, batch_size: Optional[int]):
        """Returns a zero state of shape ``(hidden_size,)``.

        When ``batch_size`` is not None a leading batch axis is added via
        ``add_batch``.
        """
        state = jnp.zeros([self.hidden_size])
        if batch_size is not None:
            state = add_batch(state, batch_size)
        return state


## Deep RNN 

In [None]:
# TODO: refer to https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/recurrent.py#L714-L762
# also refer to https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=DeepRNN#deeprnn
class _DeepRNN(RNNCellBase):
    """Applies a sequence of recurrent cells and plain callables as one cell.

    Attributes:
        layers: recurrent cells (``RNNCellBase`` / ``HiPPOCell``) and plain
            callables, applied in order.
        skip_connections: when True, every layer must be a recurrent cell;
            each layer after the first receives its carry concatenated with
            the initial carry, and the returned carry is the concatenation of
            all per-layer carries along the last axis.
        hidden_to_output_layer: when True, each recurrent output is passed
            through an extra dense layer of the same width.
    """

    layers: Sequence[Any]
    skip_connections: bool
    hidden_to_output_layer: bool
    # `name` is deliberately not re-annotated here: flax.linen.Module already
    # provides `name` (and `parent`) and rejects subclasses that redeclare them.
    # It can still be passed as a keyword argument.

    def setup(self):
        """Validates that skip connections are only used with recurrent layers."""
        if self.skip_connections:
            for layer in self.layers:
                if not isinstance(layer, (RNNCellBase, HiPPOCell)):
                    raise ValueError(
                        f"{self.name} layer {layer} is not a RNNCellBase or HiPPOCell"
                    )

    # Compact because the optional hidden-to-output Dense layers are created
    # inline; submodules may only be created in setup() or a compact method.
    @nn.compact
    def __call__(self, carry, inputs):
        """Runs every layer once and collects the recurrent outputs.

        Args:
            carry: ``(h, c)`` pair fed to the first layer.
            inputs: indexable; ``inputs[k]`` is the input given to the k-th
                recurrent layer.
                NOTE(review): a single input vector would be indexed
                element-wise here - confirm what callers pass.

        Returns:
            ``(out, next_states)`` where ``next_states`` is a tuple with one
            output per recurrent layer and ``out`` is either the last carry or,
            with skip connections, the concatenation of all layer carries.
        """
        tree_map = jax.tree_util.tree_map  # jax.tree_map is deprecated/removed

        def concat(*args):
            return jnp.concatenate(args, axis=-1)

        current_carry = carry
        next_states = []
        h_t_outputs = []
        c_t_outputs = []
        state_idx = 0
        h_t, c_t = current_carry  # c_t may actually be h_t in which case dont use it
        # Initial carry, kept for the skip connections.
        h_t_copy, c_t_copy = current_carry

        for idx, layer in enumerate(self.layers):
            if self.skip_connections and idx > 0:
                skip_h_t = tree_map(concat, h_t, h_t_copy)
                skip_c_t = tree_map(concat, c_t, c_t_copy)
                current_carry = (skip_h_t, skip_c_t)

            if isinstance(layer, (RNNCellBase, HiPPOCell)):
                current_carry, next_state = layer(current_carry, inputs[state_idx])
                if self.hidden_to_output_layer:
                    # Dense acts on the last axis, so size it from the feature
                    # axis; shape[0] would be the batch size for batched data.
                    next_state = nn.Dense(next_state.shape[-1])(next_state)

                h_t, c_t = current_carry
                h_t_outputs.append(h_t)
                c_t_outputs.append(c_t)
                next_states.append(next_state)
                state_idx += 1

            else:
                # Plain callables such as nn.relu are element-wise, so map them
                # over the leaves of the carry instead of calling them on the
                # tuple itself.
                current_carry = tree_map(layer, current_carry)

        if self.skip_connections:
            skip_h_t_out = tree_map(concat, *h_t_outputs)
            skip_c_t_out = tree_map(concat, *c_t_outputs)
            out = (skip_h_t_out, skip_c_t_out)
        else:
            out = current_carry

        return out, tuple(next_states)

    def initial_state(self, batch_size: Optional[int]):
        """Returns one initial state per recurrent layer, in layer order.

        Uses each layer's ``initial_state`` method, which is what the cells in
        this notebook define, and treats the same layer types as recurrent
        that ``__call__`` does.
        """
        return tuple(
            layer.initial_state(batch_size)
            for layer in self.layers
            if isinstance(layer, (RNNCellBase, HiPPOCell))
        )


In [None]:
class DeepRNN(_DeepRNN):
    r"""Wraps a sequence of recurrent cells and callables as a single cell.
        >>> deep_rnn = DeepRNN([
        ...     LSTMCell(hidden_size=4, output_size=4),
        ...     jax.nn.relu,
        ...     LSTMCell(hidden_size=2, output_size=2),
        ... ])
    The state of a :class:`DeepRNN` is a tuple with one element per
    recurrent layer. If no layers are recurrent, the state is an empty
    tuple.

    This is :class:`_DeepRNN` with both flags defaulting to ``False``.
    Construct it with ``layers`` plus optional keyword arguments
    ``skip_connections``, ``hidden_to_output_layer`` and ``name``.
    """

    # Flax modules are dataclasses: the constructor is generated from the
    # fields, and Flax re-creates modules through it (passing `parent=`) when
    # cloning or binding. A hand-written __init__ breaks that, so the defaults
    # are expressed as field defaults instead.
    skip_connections: bool = False
    hidden_to_output_layer: bool = False


## RNN Training Types
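Refer to the [Stanford CS 230 recurrent neural networks cheatsheet](https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-recurrent-neural-networks) for the one-to-one, one-to-many, many-to-one and many-to-many layouts implemented below.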

In [None]:
def train_one2one(params, init_carry, input, sequence_length, cell):
    """One-to-one layout: a single cell step followed by a dense readout.

    Args:
        params: variables passed to ``cell.apply``.
        init_carry: carry given to the cell for its only step.
        input: input for the step; ``input.shape[0]`` sets the readout width.
        sequence_length: not used by this function.
        cell: module exposing ``apply(params, carry, input)``.

    Returns:
        The readout of the cell output.

    NOTE(review): ``nn.Dense`` is instantiated and called here outside any
    ``init``/``apply`` - Flax normally rejects calls on unbound modules, so
    the readout probably has to live inside a Module; confirm.
    """
    _carry, hidden = cell.apply(params, init_carry, input)
    readout = nn.Dense(input.shape[0])
    return readout(hidden)  # output


In [None]:
def train_one2many(params, init_carry, input, T_y, cell, tf_bool=True):
    """One-to-many layout: unrolls the cell and collects one readout per step.

    Args:
        params: variables passed to ``cell.apply``.
        init_carry: carry given to the cell at the first step.
        input: indexable inputs. With ``tf_bool=False`` only ``input[0]`` is
            consumed; with ``tf_bool=True`` every element is fed in turn.
        T_y: number of output steps when ``tf_bool`` is False, given either as
            an integer or as any sized object (its length is used).
        cell: module exposing ``apply(params, carry, input)``.
        tf_bool: when True (presumably "teacher forcing") the cell is fed
            ``input[i]`` at every step; when False it is fed its own previous
            output.

    Returns:
        List with one readout per step.

    NOTE(review): ``nn.Dense`` is instantiated and called here outside any
    ``init``/``apply`` - Flax normally rejects calls on unbound modules, so
    the readout probably has to live inside a Module; confirm.
    """
    output = []
    carry = init_carry  # threaded through every step in both modes
    if not tf_bool:
        try:
            num_steps = len(T_y)
        except TypeError:  # T_y given directly as a step count
            num_steps = int(T_y)

        # Only the first input exists in free-running mode, so the readout
        # width is taken from it rather than from input[i].
        out_features = input[0].shape[0]
        for i in range(num_steps):
            step_input = input[0] if i == 0 else output[i - 1]
            carry, h_t = cell.apply(params, carry, step_input)
            y_t = nn.Dense(out_features)(h_t)  # output
            output.append(y_t)
    else:
        for i in range(len(input)):
            # Continue from the previous step's carry instead of restarting
            # from init_carry on every step.
            carry, h_t = cell.apply(params, carry, input[i])
            y_t = nn.Dense(input[i].shape[0])(h_t)  # output
            output.append(y_t)

    return output


In [None]:
def train_many2one(params, init_carry, input, T_x, cell, tf_bool=True):
    """Many-to-one layout: consumes the sequence and reads out the last step.

    Args:
        params: variables passed to ``cell.apply``.
        init_carry: carry given to the cell at the first step.
        input: indexable inputs, one per step.
        T_x: sized object whose length is the number of steps.
        cell: module exposing ``apply(params, carry, input)``.
        tf_bool: not used by this function.

    Returns:
        The readout of the final step, or ``None`` when there are no steps.

    NOTE(review): ``nn.Dense`` is instantiated and called here outside any
    ``init``/``apply`` - Flax normally rejects calls on unbound modules;
    confirm.
    """
    state = init_carry
    final_readout = None
    for step in range(len(T_x)):
        state, hidden = cell.apply(params, state, input[step])
        is_last_step = step == (len(T_x) - 1)
        if is_last_step:
            final_readout = nn.Dense(input[step].shape[0])(hidden)  # output

    return final_readout


In [None]:
def train_many2many(params, init_carry, input, T_xy, cell, tf_bool=True):
    """Many-to-many layout with equal input and output lengths.

    NOTE: a later cell of this notebook defines another function with the
    same name and a different signature; once that cell runs, this
    definition is shadowed.

    Args:
        params: variables passed to ``cell.apply``.
        init_carry: carry given to the cell at the first step.
        input: indexable inputs, one per step.
        T_xy: sized object whose length is the number of steps.
        cell: module exposing ``apply(params, carry, input)``.
        tf_bool: not used by this function.

    Returns:
        List with one readout per step.

    NOTE(review): ``nn.Dense`` is instantiated and called here outside any
    ``init``/``apply`` - Flax normally rejects calls on unbound modules;
    confirm.
    """
    state = init_carry
    readouts = []
    for step in range(len(T_xy)):
        state, hidden = cell.apply(params, state, input[step])
        step_readout = nn.Dense(input[step].shape[0])(hidden)  # output
        readouts.append(step_readout)

    return readouts


In [None]:
def train_many2many(params, init_carry, input, T_x, T_y, cell, tf_bool=True):
    """Many-to-many layout with different input and output lengths.

    NOTE: an earlier cell of this notebook defines a function with the same
    name and a shorter signature; this definition replaces it.

    Args:
        params: variables passed to ``cell.apply``.
        init_carry: carry given to the cell at the first step.
        input: indexable inputs; indexed from 0 in both phases.
        T_x: sized object whose length is the number of read-only steps.
        T_y: sized object whose length is the number of output steps; must
            not compare equal to ``T_x``.
        cell: module exposing ``apply(params, carry, input)``.
        tf_bool: not used by this function.

    Returns:
        List with one readout per output step.

    Raises:
        AssertionError: if ``T_x == T_y``.

    NOTE(review): ``nn.Dense`` is instantiated and called here outside any
    ``init``/``apply`` - Flax normally rejects calls on unbound modules;
    confirm.
    """
    assert T_x != T_y, "T_x and T_y must be different"

    # Phase 1: advance the carry over T_x steps without producing outputs.
    state = init_carry
    for step in range(len(T_x)):
        state, _hidden = cell.apply(params, state, input[step])

    # Phase 2: keep stepping and collect one readout per step.
    readouts = []
    for step in range(len(T_y)):
        state, hidden = cell.apply(params, state, input[step])
        readouts.append(nn.Dense(input[step].shape[0])(hidden))  # output

    return readouts


## Data Preprocessing


In [None]:
def map_data_2_id(iterable):
    """
    provides mapping to and from ids

    Args:
        iterable: any iterable of hashable elements. It is traversed exactly
            once, so one-shot iterators and generators are supported.

    Returns:
        A tuple ``(id_2_data, data_2_id)``: the first maps each position to
        its element, the second maps each element to its position. For
        repeated elements the last position wins in ``data_2_id``.
    """
    id_2_data = {}
    data_2_id = {}

    # Fill both tables in a single pass; a second pass would see nothing if
    # `iterable` were an exhausted iterator.
    for index, elem in enumerate(iterable):
        id_2_data[index] = elem
        data_2_id[elem] = index

    return (id_2_data, data_2_id)


In [None]:
def one_hot(i, n):
    """
    create vector of size n with 1 at index i

    Args:
        i: index to set.
        n: length of the vector.

    Returns:
        A new array of ``n`` zeros (default float dtype) with entry ``i`` set
        to 1.
    """
    # JAX arrays are immutable; `.at[i].set(1)` returns an updated copy.
    return jnp.zeros(n).at[i].set(1)


In [None]:
def encode(char, mapping=None):
    """One-hot encodes ``char`` using an element-to-id table.

    Args:
        char: element to encode; must be a key of the table.
        mapping: element-to-id dict, e.g. the second value returned by
            ``map_data_2_id``. When omitted, the module-level ``data_2_id``
            is used, which must already exist in the notebook namespace.

    Returns:
        A one-hot vector whose length is the size of the table.

    Raises:
        KeyError: if ``char`` is not in the table.
    """
    table = data_2_id if mapping is None else mapping
    return one_hot(table[char], len(table))


In [None]:
def decode(predictions, id_2_data):
    """Maps a prediction vector back to the element with the highest score.

    Args:
        predictions: array of scores.
        id_2_data: id-to-element table.

    Returns:
        ``id_2_data`` entry for the index of the largest score.
    """
    best_index = jnp.argmax(predictions)
    return id_2_data[int(best_index)]


## Optimizer Helpers


See the [Optax quick-start notebook](https://github.com/deepmind/optax/blob/master/examples/quick_start.ipynb) for a worked example of the optimizer API used below.

In [None]:
def pick_optimizer_fn(name, starting_learning_rate):
    """Builds an optax optimizer selected by name.

    Args:
        name: one of "sgd", "adam", "adagrad" or "rmsprop".
        starting_learning_rate: learning rate passed to the optax constructor.

    Returns:
        The value returned by the matching ``optax`` constructor.

    Raises:
        ValueError: if ``name`` is not one of the supported names.
    """
    # refer to https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules
    supported = ("sgd", "adam", "adagrad", "rmsprop")
    if name not in supported:
        raise ValueError("optimizer name not recognized")

    # Each supported name is also the name of the optax constructor.
    factory = getattr(optax, name)
    return factory(starting_learning_rate)


In [None]:
def pick_scheduler_fn(
    start_learning_rate, steps, decay_rate, init_value, end_val, name
):
    """Builds an optax learning-rate schedule selected by name.

    Args:
        start_learning_rate: initial value for the "exp_decay" schedule.
        steps: ``transition_steps`` for the "exp_decay" and "linear" schedules.
        decay_rate: decay rate for the "exp_decay" schedule.
        init_value: value for "constant"; initial value for "linear".
        end_val: final value for "linear".
        name: one of "constant", "exp_decay" or "linear".

    Returns:
        The schedule returned by the matching ``optax`` constructor.

    Raises:
        ValueError: if ``name` is not one of the supported names.
    """
    # refer to https://optax.readthedocs.io/en/latest/api.html#schedules
    scheduler = None

    if name == "constant":
        scheduler = optax.constant_schedule(init_value)

    elif name == "exp_decay":
        # Use the caller's steps/decay_rate rather than hard-coded constants.
        scheduler = optax.exponential_decay(
            init_value=start_learning_rate,
            transition_steps=steps,
            decay_rate=decay_rate,
        )
    elif name == "linear":
        # transition_steps is a required argument of linear_schedule.
        scheduler = optax.linear_schedule(
            init_value=init_value, end_value=end_val, transition_steps=steps
        )

    else:
        raise ValueError("scheduler name not recognized")

    return scheduler


In [None]:
# TODO: add transformations
# refer to https://optax.readthedocs.io/en/latest/api.html#optax-transformations


In [None]:
#     # A simple update loop.
#     for _ in range(1000):
#     grads = jax.grad(compute_loss)(params, xs, ys)
#     updates, opt_state = gradient_transform.update(grads, opt_state)
#     params = optax.apply_updates(params, updates)

#     assert jnp.allclose(params, target_params), \
#     'Optimization should retrieve the target params used to generate the data.'


In [None]:
# def optimizer_fn(start_learning_rate, params, num_weights, x, y):
#     optimizer = optax.adam(start_learning_rate)
#     # Obtain the `opt_state` that contains statistics for the optimizer.
#     params = {'w': jnp.ones((num_weights,))}
#     opt_state = optimizer.init(params)

#     compute_loss = lambda params, x, y: optax.l2_loss(params['w'].dot(x), y)
#     grads = jax.grad(compute_loss)(params, xs, ys)

#     updates, opt_state = optimizer.update(grads, opt_state)
#     params = optax.apply_updates(params, updates)


## Training Loop


In [None]:
class CharRNN(nn.Module):
    """Character-level model: two HiPPO-wrapped LSTM cells plus a softmax readout.

    Attributes:
        state_size: passed to ``HiPPO`` as ``N``.
        vocab_size: number of symbols; sets the one-hot width, the readout
            width and the ``HiPPO`` length parameters.
        batch_size: not read anywhere in this class.
    """

    state_size: int
    vocab_size: int
    batch_size: int

    def setup(self):
        L = self.vocab_size
        N = self.state_size

        hippo = HiPPO(
            N=N,
            max_length=L,
            measure="legs",
            step=1.0 / L,
            GBT_alpha=0.5,
            seq_L=L,
            v="v",
            lambda_n=1.0,
            fourier_type="fru",
            alpha=0.0,
            beta=1.0,
        )

        # NOTE(review): LSTMCell and HiPPOCell both declare hidden_size and
        # output_size without defaults, so these constructor calls are missing
        # required arguments. The intended sizes are not evident from this
        # notebook - supply them explicitly.
        cell1 = LSTMCell()
        cell2 = LSTMCell()

        # NOTE(review): with skip_connections=True, _DeepRNN.setup raises
        # ValueError for any layer that is not a recurrent cell, which
        # includes nn.relu - drop the activation or disable skip connections.
        self.deep_cell = DeepRNN(
            layers=[
                HiPPOCell(hippo=hippo, cell=cell1),
                nn.relu,
                HiPPOCell(hippo=hippo, cell=cell2),
            ],
            skip_connections=True,
            hidden_to_output_layer=True,
            name="CharRNN",
        )

        # Submodules of a setup()-style module must be created here, not
        # inline in a non-compact __call__.
        self.readout = nn.Dense(self.vocab_size)

    def __call__(self, carry, i):
        """Runs one step for symbol id ``i``.

        Args:
            carry: carry passed to the deep cell.
            i: integer id of the current symbol.

        Returns:
            ``(next_states, predictions)``: the per-layer outputs of the deep
            cell and the softmax distribution over the vocabulary computed
            from the last layer's output.
        """
        input = one_hot(i, self.vocab_size)
        # NOTE(review): the carry returned by the deep cell is not returned to
        # the caller; `next_states` is returned in its place - confirm that is
        # intended.
        carries, next_states = self.deep_cell(carry, input)
        predictions = nn.softmax(self.readout(next_states[-1]))
        return next_states, predictions


In [None]:
def train():
    """Placeholder for the training loop; does nothing yet."""
    return None
