In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('../../../'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import jax
import jax.ops
import jax.numpy as jnp
from jax.experimental.host_callback import id_print
from jax.tree_util import Partial

import flax
from flax import linen as nn
from flax.linen.recurrent import RNNCellBase

import optax

import numpy as np  # convention: original numpy

from typing import Any, Callable, Sequence, Optional, Tuple, Union
from collections import defaultdict
from functools import partial
import pprint

from src.models.hippo.hippo import HiPPO


In [3]:
pp = pprint.PrettyPrinter(indent=4)

In [4]:
seed = 1701
key = jax.random.PRNGKey(seed)


In [5]:
num_copies = 2
rng, subkey = jax.random.split(key, num=num_copies)


## Helper Functions

In [6]:
def add_batch(nest, batch_size: Optional[int]):
    """Adds a batch dimension at axis 0 to the leaves of a nested structure."""
    broadcast = lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape)

    return jax.tree_map(broadcast, nest)


## RNN Cells

In [7]:
class RNNCell(RNNCellBase):
    hidden_size: int

    # def setup(self):
        # self.dense_h = Partial(
        #     nn.Dense,
        #     features=self.hidden_size,
        #     use_bias=True,
        #     kernel_init=nn.initializers.orthogonal(),
        #     bias_init=nn.initializers.zeros,
        #     dtype=None,
        #     param_dtype=jnp.float32,
        # )

        # self.dense_o = Partial(
        #     nn.Dense,
        #     features=self.hidden_size,
        #     use_bias=False,
        #     kernel_init=nn.initializers.orthogonal(),
        #     bias_init=nn.initializers.zeros,
        #     dtype=None,
        #     param_dtype=jnp.float32,
        # )

    # @partial(
    #     nn.transforms.scan, variable_broadcast="params", split_rngs={"params": False}
    # )
    @nn.compact
    def __call__(self, carry, input):
        """
        Description:
            W_xh = x_{t} @ W_{xh} - multiply the previous hidden state with
            W_hh = H_{t-1} @ W_{hh} + b_{h} - this a linear layer

            H_{t} = f_{w}(H_{t-1}, x)
            H_{t} = tanh(H_{t-1} @ W_{hh}) + (x_{t} @ W_{xh})

        Args:
            hidden_size (int): hidden state size
            carry (jnp.ndarray): hidden state from previous time step
            input (jnp.ndarray): # input vector

        Returns:
            A tuple with the new carry and the output.
        """
        ht_1, _ = carry

        print(f"inside the rnn, input:\n{input.shape}")

        h_t = self.rnn_update(ht_1, input)

        return (h_t, h_t), h_t

    def rnn_update(self, ht_1, input):
        print(f"inside the rnn update, input:\n{input.shape}")
        print(f"inside the rnn update, ht_1:\n{ht_1.shape}")

        # print(f"self.dense_h(name='dense rnn_wxh layer')(ht_1):\n{self.dense_h(name='dense rnn_wxh layer')(ht_1)}")
        # W_hh = self.dense_h(ht_1)

        W_hh = nn.Dense(self.hidden_size)(ht_1)
        print(f"W_hh:\n{W_hh.shape}")
        # id_print(W_hh, what="BLAH BLAH BLAH", tap_with_device=True)
        # print(f"W_hh:\n{W_hh}")
        # print(f"input.shape:\n{input.shape}")
        W_xh = nn.Dense(self.hidden_size)(input)
        print(f"W_xh:\n{W_xh.shape}")
        # W_xh = self.dense_o(name="dense rnn_wxh layer")(input)
        print(f"W_hh shape:\n{W_hh.shape}")
        print(f"W_xh shape:\n{W_xh.shape}")
        h_t = nn.relu(W_hh + W_xh)  # H_{t} = tanh(H_{t-1} @ W_{hh}) + (x_{t} @ W_{xh})
        print(f"inside the rnn update in an rnn, h_t:\n{h_t.shape}")
        
        return h_t

    @staticmethod
    def initialize_carry(rng, batch_size, hidden_size, init_fn=nn.initializers.zeros):
        """Initialize the RNN cell carry.
        Args:
        rng: random number generator passed to the init_fn.
        batch_dims: a tuple providing the shape of the batch dimensions.
        hidden_size: the size or number of features of the memory.
        init_fn: initializer function for the carry.
        Returns:
        An initialized carry for the given RNN cell.
        """
        key1, key2 = jax.random.split(rng)
        mem_shape = batch_size + (hidden_size,)

        return init_fn(key1, mem_shape), init_fn(key2, mem_shape)


In [8]:
class LSTMCell(RNNCellBase):
    hidden_size: int

    @partial(
        nn.transforms.scan, variable_broadcast="params", split_rngs={"params": False}
    )
    @nn.compact
    def __call__(self, carry, input):
        """
        Description:
            i_{t} = sigmoid((W_{ii} @ x_{t} + b_{ii}) + (W_{hi} @ h_{t-1} + b_{hi}))
            f_{t} = sigmoid((W_{if} @ x_{t} + b_{if}) + (W_{hf} @ h_{t-1} + b_{hf}))
            g_{t} = tanh((W_{ig} @ x_{t} + b_{ig}) + (W_{hg} @ h_{t-1} + b_{hg}))
            o_{t} = sigmoid((W_{io} @ x_{t} + b_{io}) + (W_{ho} @ h_{t-1} + b_{ho}))
            c_{t} = f_{t} * c_{t-1} + i_{t} * g_{t}
            h_{t} = o_{t} * tanh(c_{t})

        Args:
            hidden_size (int): hidden state size
            carry (jnp.ndarray): hidden state from previous time step
            input (jnp.ndarray): # input vector

        Returns:
            A tuple with the new carry and the output.
        """
        print(f"inside the LSTMCell, input:\n{input.shape}")
        print(f"inside the LSTMCell, input type:\n{type(input)}")
        
        print(f"inside the LSTMCell, carry:\n{carry.shape}")
        print(f"inside the LSTMCell, carry type:\n{type(carry)}")
        ht_1, ct_1 = carry
        print(f"carry split:\n{ht_1.shape}\n{ct_1.shape}")

        c_t, h_t = self.rnn_update(input, ht_1, ct_1)
        print(f"inside the LSTMCell, c_t:\n{c_t.shape}")
        print(f"inside the LSTMCell, h_t:\n{h_t.shape}")

        return (h_t, c_t), h_t

    # @partial(
    #     nn.transforms.scan, variable_broadcast="params", split_rngs={"params": False}
    # )
    def rnn_update(self, input, ht_1, ct_1):
        print(f"inside the LSTMCell rnn_update, input:\n{input.shape}")
        print(f"inside the LSTMCell rnn_update, ht_1:\n{ht_1.shape}")
        print(f"inside the LSTMCell rnn_update, ct_1:\n{ct_1.shape}")
        # i_ta = partial(
        #     nn.Dense,
        #     features=ht_1.shape()[0],
        #     use_bias=False,
        #     kernel_init=self.recurrent_kernel_init,
        #     bias_init=self.bias_init,
        # )

        i_ta = nn.Dense(features=ht_1.shape()[0])(input)
        i_tb = nn.Dense(features=self.hidden_size)(ht_1)
        i_t = nn.sigmoid(i_ta + i_tb)  # input gate
        print(f"inside the LSTMCell, input gate output:\n{i_t.shape}")

        o_ta = nn.Dense(self.hidden_size)(input)
        o_tb = nn.Dense(self.hidden_size)(ht_1)
        o_t = nn.sigmoid(o_ta + o_tb)  # output gate
        print(f"inside the LSTMCell, output gate output:\n{o_t.shape}")

        f_ia = nn.Dense(self.hidden_size)(
            input
        )  # b^{f}_{i} + \sum\limits_{j} U^{f}_{i, j} x^{t}_{j}
        f_ib = nn.Dense(self.hidden_size)(
            ht_1
        )  # \sum\limits_{j} W^{f}_{i, j} h^{(t-1)}_{j}
        f_i = nn.sigmoid(f_ia + f_ib)  # forget gate
        print(f"inside the LSTMCell, forget gate output:\n{f_i.shape}")

        g_ia = nn.Dense(self.hidden_size)(
            input
        )  # b^{g}_{i} + \sum\limits_{j} U^{g}_{i, j} x^{t}_{j}
        g_ib = nn.Dense(self.hidden_size)(
            ht_1
        )  # \sum\limits_{j} W^{g}_{i, j} h^{(t-1)}_{j}
        g_i = nn.tanh(g_ia + g_ib)  # (external) input gate
        print(f"inside the LSTMCell, (external) input gate output:\n{g_i.shape}")

        c_t = (f_i * ct_1) + (i_t * g_i)  # internal cell state update
        print(f"inside the LSTMCell, cell state output:\n{c_t.shape}")

        h_t = o_t * nn.tanh(c_t)  # hidden state update
        print(f"inside the LSTMCell, hidden state output:\n{h_t.shape}")

        return h_t, c_t

    @staticmethod
    def initialize_carry(rng, batch_size, hidden_size, init_fn=nn.initializers.zeros):
        """Initialize the RNN cell carry.
        Args:
        rng: random number generator passed to the init_fn.
        batch_dims: a tuple providing the shape of the batch dimensions.
        hidden_size: the size or number of features of the memory.
        init_fn: initializer function for the carry.
        Returns:
        An initialized carry for the given RNN cell.
        """
        key1, key2 = jax.random.split(rng)
        mem_shape = batch_size + (hidden_size,)

        return init_fn(key1, mem_shape), init_fn(key2, mem_shape)


In [9]:
class GRUCell(RNNCellBase):
    hidden_size: int

    @partial(
        nn.transforms.scan, variable_broadcast="params", split_rngs={"params": False}
    )
    @nn.compact
    def __call__(self, carry, input):
        """
        Description:
            z_t = sigmoid((W_{iz} @ x_{t} + b_{iz}) + (W_{hz} @ h_{t-1} + b_{hz}))
            r_t = sigmoid((W_{ir} @ x_{t} + b_{ir}) + (W_{hr} @ h_{t-1} + b_{hr}))
            g_t = tanh(((W_{ig} @ x_{t} + b_{ig}) + r_t) * (W_{hg} @ h_{t-1} + b_{hg}))
            h_t = (z_t * h_{t-1}) + ((1 - z_t) * g_i)

        Args:
            hidden_size (int): hidden state size
            carry (jnp.ndarray): hidden state from previous time step
            input (jnp.ndarray): # input vector

        Returns:
            A tuple with the new carry and the output.
        """
        ht_1 = carry

        h_t = self.rnn_update(input, ht_1)

        return (h_t, h_t), h_t

    # @partial(
    #     nn.transforms.scan, variable_broadcast="params", split_rngs={"params": False}
    # )
    def rnn_update(self, input, ht_1):

        z_ta = nn.Dense(self.hidden_size)(input)
        z_tb = nn.Dense(self.hidden_size)(ht_1)
        z_t = nn.sigmoid(z_ta + z_tb)  # reset gate

        r_ta = nn.Dense(self.hidden_size)(input)
        r_tb = nn.Dense(self.hidden_size)(ht_1)
        r_t = nn.sigmoid(r_ta + r_tb)  # update gate

        g_ta = nn.Dense(self.hidden_size)(input)
        g_tb = nn.Dense(self.hidden_size)(ht_1)
        g_t = nn.tanh((g_ta + r_t) * g_tb)  # (external) input gate

        h_t = ((1 - z_t) * ht_1) + (z_t * g_t)  # internal cell state update

        return h_t

    @staticmethod
    def initialize_carry(rng, batch_size, hidden_size, init_fn=nn.initializers.zeros):
        """Initialize the RNN cell carry.
        Args:
        rng: random number generator passed to the init_fn.
        batch_dims: a tuple providing the shape of the batch dimensions.
        hidden_size: the size or number of features of the memory.
        init_fn: initializer function for the carry.
        Returns:
        An initialized carry for the given RNN cell.
        """
        key1, key2 = jax.random.split(rng)
        mem_shape = batch_size + (hidden_size,)

        return init_fn(key1, mem_shape), init_fn(key2, mem_shape)


In [10]:
class HiPPOCell(nn.Module):
    """
    Description:
        RNN update function
        τ(h, x) = (1 - g(h, x)) ◦ h + g(h, x) ◦ tanh(Lτ (h, x))
        g(h, x) = σ(Lg(h,x))

    Args:
        hidden_size (int): hidden state size
        output_size (int): output size
        hippo (HiPPO): hippo model object
        cell (RNNCellBase): choice of RNN cell object
            - RNNCell
            - LSTMCell
            - GRUCell
    """

    hidden_size: int
    output_size: int
    hippo: HiPPO
    model: RNNCellBase

    # def setup(self):
    #     self.dense_y = Partial(
    #         nn.Dense,
    #         features=self.output_size,
    #         use_bias=True,
    #         kernel_init=nn.initializers.orthogonal(),
    #         bias_init=nn.initializers.zeros,
    #         dtype=None,
    #         param_dtype=jnp.float32,
    #     )
    #     self.cell = self.model(self.hidden_size)

    # @partial(
    #     nn.transforms.scan, variable_broadcast="params", split_rngs={"params": False}
    # )
    @nn.compact
    def __call__(self, carry, input):
        """
        Description:
            RNN update function
            τ(h, x) = (1 - g(h, x)) ◦ h + g(h, x) ◦ tanh(Lτ (h, x))
            g(h, x) = σ(Lg(h,x))

        Args:
            carry (jnp.ndarray): hidden state from previous time step
            input (jnp.ndarray): # input vector

        Returns:
            A tuple with the new carry and the output.
        """

        print(f"inside hippo cell, input:\n{input}")
        h_t, c_t = carry
        print(f"inside hippo cell, h_t:\n{h_t}\nc_t:\n{c_t}")
        print(f"inside hippo cell, the cell:\n{self.model}")
        _, h_t = self.model(self.hidden_size)(carry, input)
        print(f"inside hippo cell, h_t:\n{h_t.shape}")

        # y_t = nn.Dense(self.output_size)(h_t)  # f_t in the paper
        # print(f"inside hippo cell, y_t: \n{y_t}")

        # c_t = self.hippo(y_t, init_state=None, kernel=False)
        # print(f"inside hippo cell, c_t: \n{c_t}")

        return self.rnn_update(input, h_t)

    # @partial(
    #     nn.transforms.scan, variable_broadcast="params", split_rngs={"params": False}
    # )
    def rnn_update(self, input, h_t):

        # y_t = self.dense_y(name="dense hippo input layer")(h_t)
        print(f"inside hippo cell in the update, h_t:\n{h_t.shape}")
        print(f"inside hippo cell in the update, input:\n{input.shape}")
        y_t = nn.Dense(1)(h_t)  # f_t in the paper
        print(f"inside hippo cell, before reshape y_t: \n{y_t.shape}")
        y_t = jnp.swapaxes(y_t, 1, 0)
        y_t = jnp.swapaxes(y_t, 2, 1)
        print(f"inside hippo cell, y_t: \n{y_t.shape}")
        
        c_t = self.hippo(y_t, init_state=h_t, kernel=False)
        print(f"inside hippo cell, c_t: \n{c_t}")
        print(f"inside hippo cell, c_t: \n{c_t.shape}")

        return (h_t, c_t), h_t

    @staticmethod
    def initialize_carry(rng, batch_size, hidden_size, init_fn=nn.initializers.zeros):
        """Initialize the RNN cell carry.
        Args:
        rng: random number generator passed to the init_fn.
        batch_dims: a tuple providing the shape of the batch dimensions.
        hidden_size: the size or number of features of the memory.
        init_fn: initializer function for the carry.
        Returns:
        An initialized carry for the given RNN cell.
        """
        key1, key2 = jax.random.split(rng)
        mem_shape = batch_size + (hidden_size,)

        return init_fn(key1, mem_shape), init_fn(key2, mem_shape)


## Deep RNN 

In [11]:
# TODO: refer to https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/recurrent.py#L714-L762
# also refer to https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=DeepRNN#deeprnn
class _DeepRNN(RNNCellBase):
    hidden_size: int
    layers: Sequence[Any]
    skip_connections: bool
    hidden_to_output_layer: bool
    layer_name: Optional[str]

    def setup(self):
        if self.skip_connections:
            for layer in self.layers:
                if not (isinstance(layer, RNNCellBase) or isinstance(layer, HiPPOCell)):
                    raise ValueError(
                        "skip_connections requires for all layers to be "
                        "`hk.RNNCore`s. Layers is: {}".format(self.layers)
                    )
                    # raise ValueError(
                    #     f"{self.layer_name} layer {layer} is not a RNNCellBase or HiPPOCell"
                    # )

    def __call__(self, carry, inputs):
        current_carry = carry
        next_states = []
        h_t_outputs = []
        c_t_outputs = []
        state_idx = 0
        # print(f"inside deep rnn, inputs:\n{inputs.shape}")
        # print(f"inside deep rnn, carry:\n{carry.shape}")
        h_t, c_t = carry  # c_t may actually be h_t in which case dont use it
        (
            h_t_copy,
            c_t_copy,
        ) = current_carry  # c_t may actually be h_t in which case dont use it
        concat = lambda *args: jnp.concatenate(args, axis=-1)
        print(f"before main loop")
        for idx, layer in enumerate(self.layers):
            print(f"inside deep rnn cell, h_t:\n{h_t.shape}\nc_t:\n{c_t.shape}")
            print(f"inside deep rnn cell, THE INPUTS:\n{inputs[idx].shape}")
            if self.skip_connections and idx > 0:
                skip_h_t = jax.tree_map(concat, h_t, h_t_copy)
                skip_c_t = jax.tree_map(concat, c_t, c_t_copy)
                current_carry = tuple([skip_h_t, skip_c_t])

            if isinstance(layer, RNNCellBase) or isinstance(layer, HiPPOCell):
                # print(f"inside deep rnn, inputs:\n{inputs}")
                # print(f"inside deep rnn, state_idx:\n{state_idx}")
                print(f"inside deep rnn, inputs[state_idx]:\n{inputs[state_idx].shape}")
                h_t, c_t, next_state = layer(
                    current_carry, inputs[state_idx]
                )  # problem line
                h_t_outputs.append(h_t)
                c_t_outputs.append(c_t)
                next_states.append(next_state)
                state_idx += 1

            else:
                print(f"current_carry before layer: {current_carry.shape}")
                print(f"layer: {layer}")
                current_carry = layer(current_carry)
                print(f"current_carry:\n {current_carry.shape}")

        print(f"third conditional")
        if self.skip_connections:
            skip_h_t_out = jax.tree_map(concat, *h_t_outputs)
            skip_c_t_out = jax.tree_map(concat, *c_t_outputs)
            next_carry = (skip_h_t_out, skip_c_t_out)
        else:
            next_carry = current_carry

        print(f"next_states before tuple:\n", next_states)
        pp.pprint(layer)
        print(f"carry before return B:\n", next_carry)

        return next_carry, next_states

    @staticmethod
    def initialize_state(
        num_layers,
        rng,
        batch_size: tuple,
        output_size: int,
        init_fn=nn.initializers.zeros,
    ):
        states = []
        for i in range(num_layers):
            print(f"Layer: {i}\n")
            states.append(
                _DeepRNN.init_state(
                    rng=rng,
                    batch_size=batch_size,
                    output_size=output_size,
                    init_fn=init_fn,
                )
            )

        return states

    @staticmethod
    def init_state(
        rng, batch_size: tuple, output_size: int, init_fn=nn.initializers.zeros
    ):
        print(f"batch_size: {(batch_size,)}")
        print(f"output_size: {(output_size,)}")
        mem_shape = (batch_size,) + (output_size,)
        print(f"state mem_shape: {mem_shape}")

        return init_fn(rng, mem_shape)

    @staticmethod
    def initialize_carry(
        rng, batch_size: tuple, hidden_size: int, init_fn=nn.initializers.zeros
    ):
        print(f"batch_size: {batch_size}")
        print(f"hidden_size: {(hidden_size,)}")
        mem_shape = batch_size + (1, hidden_size)
        print(f"carry mem_shape: {mem_shape}")

        return init_fn(rng, mem_shape), init_fn(rng, mem_shape)


In [12]:
class DeepRNN(_DeepRNN):
    r"""Wraps a sequence of cores and callables as a single core.
        >>> deep_rnn = hk.DeepRNN([
        ...     LSTMCell(hidden_size=4),
        ...     jax.nn.relu,
        ...     LSTMCell(hidden_size=2),
        ... ])
    The state of a :class:`DeepRNN` is a tuple with one element per
    :class:`RNNCore`. If no layers are :class:`RNNCore`\ s, the state is an empty
    tuple.
    """

    def __init__(
        self,
        hidden_size: int,
        layers: Sequence[Any],
        skip_connections: Optional[bool] = False,
        hidden_to_output_layer: Optional[bool] = False,
        name: Optional[str] = None,
    ):
        super().__init__(
            hidden_size=hidden_size,
            layers=layers,
            skip_connections=skip_connections,
            hidden_to_output_layer=hidden_to_output_layer,
            layer_name=name,
        )


## RNN Training Types
refer to [this](https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-recurrent-neural-networks)

## Data Preprocessing


In [13]:
def map_data_2_id(iterable):
    """
    provides mapping to and from ids
    """
    
    id_2_data = {}
    data_2_id = {}

    for id, elem in enumerate(iterable):
        id_2_data[id] = elem

    for id, elem in enumerate(iterable):
        data_2_id[elem] = id

    return (id_2_data, data_2_id)


In [14]:
def one_hot(i, n):
    """
    create vector of size n with 1 at index i
    """
    x = jnp.zeros(n, dtype=int)
    print(f"i:\n{i}")
    print(f"type(i):\n{type(i)}")
    print(f"n:\n{n}")
    print(f"type(n):\n{type(n)}")
    x = x.at[i].set(1) #jax.vmap(x.at[i].set(1), in_axes=(0, 0), out_axes=0)(i, x)

    print(f"i:\n{i}")
    print(f"type(x):\n{type(x)}")
    print(f"x:\n{x}")
    # array = x.at[i].set(1)
    # print(f"array:\n{array}")
    # x = x[i].at[i].set(1)
    # print(x)
    return x


In [15]:
def get_text(fname):
  with open(fname, "r") as reader:
    data = reader.read()
  return data

In [16]:
def prep_data(data):
  chars = list(set(data))
  vocab_size = len(chars)
  char_to_id, id_to_char = map_data_2_id(chars)
  char_to_id = {value:key for key, value in char_to_id.items()}
  # data converted to ids
  # data_id = [char_to_id[char] for char in data]
  data_id = [char_to_id[char] for char in data]
  return data_id, char_to_id, id_to_char

In [17]:
data = "abcd...abcd..."
data_id, char_to_id, id_to_char = prep_data(data)
data_id[:10]


[2, 3, 4, 0, 1, 1, 1, 2, 3, 4]

In [18]:
def encode(char):
    return one_hot(data_2_id[char], len(data_2_id))

In [19]:
def decode(predictions, id_2_data):
    return id_2_data[int(jnp.argmax(predictions))]


## Optimizer Helpers


check [this](https://github.com/deepmind/optax/blob/master/examples/quick_start.ipynb) out

In [20]:
def pick_optimizer_fn(starting_learning_rate, name="adam"):
    # refer to https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules
    optim = None

    if name == "sgd":
        optim = optax.sgd(starting_learning_rate)
    elif name == "adam":
        optim = optax.adam(
            starting_learning_rate,
        )
    elif name == "adagrad":
        optim = optax.adagrad(starting_learning_rate)
    elif name == "rmsprop":
        optim = optax.rmsprop(starting_learning_rate)
    else:
        raise ValueError("optimizer name not recognized")

    return optim


In [21]:
def pick_scheduler_fn(
    start_learning_rate, steps, decay_rate, init_value, end_val, name
):
    # refer to https://optax.readthedocs.io/en/latest/api.html#schedules
    scheduler = None

    if name == "constant":
        scheduler = optax.constant_schedule(init_value)

    elif name == "exp_decay":
        scheduler = optax.exponential_decay(
            init_value=start_learning_rate, transition_steps=1000, decay_rate=0.99
        )
    elif name == "linear":
        scheduler = optax.linear_schedule(init_value=init_value, end_value=end_val)

    else:
        raise ValueError("scheduler name not recognized")

    return scheduler


In [22]:
# TODO: add transformations
# refer to https://optax.readthedocs.io/en/latest/api.html#optax-transformations


In [23]:
#     # A simple update loop.
#     for _ in range(1000):
#     grads = jax.grad(compute_loss)(params, xs, ys)
#     updates, opt_state = gradient_transform.update(grads, opt_state)
#     params = optax.apply_updates(params, updates)

#     assert jnp.allclose(params, target_params), \
#     'Optimization should retrieve the target params used to generate the data.'


In [24]:
# def optimizer_fn(start_learning_rate, params, num_weights, x, y):
#     optimizer = optax.adam(start_learning_rate)
#     # Obtain the `opt_state` that contains statistics for the optimizer.
#     params = {'w': jnp.ones((num_weights,))}
#     opt_state = optimizer.init(params)

#     compute_loss = lambda params, x, y: optax.l2_loss(params['w'].dot(x), y)
#     grads = jax.grad(compute_loss)(params, xs, ys)

#     updates, opt_state = optimizer.update(grads, opt_state)
#     params = optax.apply_updates(params, updates)


## Training Loop


In [25]:
class RNNLM(nn.Module):
    state_size: int
    vocab_size: int
    hidden_size: int
    hippo_order_N: int
    batch_size: int

    def setup(self):
        L = self.vocab_size

        hippo = HiPPO(
            N=self.hippo_order_N,
            max_length=L,
            measure="legs",
            step=1.0 / L,
            GBT_alpha=0.5,
            seq_L=L,
            v="v",
            lambda_n=1.0,
            fourier_type="fru",
            alpha=0.0,
            beta=1.0,
        )

        cell1 = RNNCell
        cell2 = RNNCell
        cell3 = RNNCell

        input_layers = [
            HiPPOCell(
                hippo=hippo, model=cell1, hidden_size=8, output_size=L
            ),
            HiPPOCell(
                hippo=hippo, model=cell2, hidden_size=64, output_size=L
            ),
            HiPPOCell(
                hippo=hippo, model=cell3, hidden_size=512, output_size=L
            ),
        ]

        self.deep_cell = DeepRNN(
            hidden_size=self.hidden_size,
            layers=input_layers,
            skip_connections=True,
            hidden_to_output_layer=False,
            name="RNNLM",
        )

    # @partial(
    #     nn.transforms.scan, variable_broadcast="params", split_rngs={"params": False}
    # )
    def __call__(self, carry, input):
        # print(f"i before one_hot:\n{input}")
        onehot = lambda x: nn.one_hot(x, self.vocab_size)
        # leaves, treedef = jax.tree_util.tree_flatten(input)
        # print(f"inside call function of RNNLM, leaves:\n{leaves}")

        input = jax.tree_map(onehot, input)

        # input = one_hot(input, self.vocab_size)
        # print(f"inside call function of RNNLM, input:\n{input[0].shape}")
        # print(f"inside call function of RNNLM, carry:\n{carry.shape}")
        # deep_cell = lambda x: self.deep_cell(x, input)
        print(f"the deep_cell:\n{self.deep_cell}")
        # carries, next_states = jax.tree_map(deep_cell, input)
        carries, next_states = self.deep_cell(carry, input)
        print(f"inside call function of RNNLM, next_states:\n{next_states}")
        print(f"inside call function of RNNLM, carries:\n{carries}")
        predictions = nn.softmax(nn.Dense(self.vocab_size)(next_states[-1]))
        print(f"inside call function of RNNLM, predictions:\n{predictions}")
        return carries, next_states, predictions

    # def __call__(self, carry, input):

    #     print(f"inside call function of RNNLM, input:\n{input}")
    #     print(f"inside call function of RNNLM, carry:\n{carry}")
    #     onehot = lambda x: nn.one_hot(x, self.vocab_size)
    #     carries, next_states = jax.pmap(self.deep_cell, axis_name=(1, None))(
    #         carry, jax.tree_map(onehot, input)
    #     )
    #     print(f"inside call function of RNNLM, next_states:\n{next_states}")
    #     print(f"inside call function of RNNLM, carries:\n{carries}")
    #     predictions = nn.softmax(nn.Dense(self.vocab_size)(next_states[-1]))
    #     print(f"inside call function of RNNLM, predictions:\n{predictions}")
    #     return carries, next_states, predictions

    @staticmethod
    def initialize_carry(
        rng, batch_size: tuple, hidden_size: int, init_fn=nn.initializers.zeros
    ):
        return DeepRNN.initialize_carry(
            rng=rng,
            batch_size=batch_size,
            hidden_size=hidden_size,
            init_fn=init_fn,
        )

    @staticmethod
    def initialize_state(
        num_layers,
        rng,
        batch_size: tuple,
        output_size: int,
        init_fn=nn.initializers.zeros,
    ):
        return DeepRNN.initialize_state(
            num_layers=num_layers,
            rng=rng,
            batch_size=batch_size,
            output_size=output_size,
            init_fn=init_fn,
        )


In [26]:
def sample(model, params, bridge, initial="", max_length=100):
    char_to_id, id_to_char = bridge
    state = model.init_state()
    output = initial
    if len(initial) > 0:
        for char in initial[:-1]:
            _, state, _ = model.apply(params, char_to_id[char], state)

    next_char = initial[-1]
    for i in range(max_length):
        state, predictions = model.apply(params, state, char_to_id[next_char], state)
        next_char = decode(predictions, id_to_char)
        output += next_char

    return output


Refer to [this](https://github.com/manifest/flax-extra/blob/48efe1f1515893289b44646977bf5049a340b6c8/docs/notebooks/combinators.ipynb), [this](https://github.com/romanak/pyprobml/blob/65c82b9b43d2100cbc7c59e766161ee801c0f85f/notebooks/book1/15/rnn_jax.ipynb), [this](https://github.com/probml/pyprobml/blob/71d98dcdd3798525353eb1bfb9851b47e9d64bde/notebooks/book1/15/rnn_jax.ipynb) and [this](https://github.com/probml/probml-notebooks/blob/36cb173afce3f4a07a7b475cf8a7937025a60465/notebooks-d2l/rnn_jax.ipynb)

In [27]:
# randomness is handled using explicit keys in Jax
key, subkey = jax.random.split(subkey)

num_layers = 3
state_size = 8
hidden_size = 4
output_size = len(char_to_id)
print(f"output_size:\n{output_size}")
#hippo_order_N = 16
batch_size = 10

print(f"Create Model: RNNLM")
# model = CharRNN(state_size, len(char_to_id))
model = RNNLM(
    state_size=state_size,
    vocab_size=output_size,
    hidden_size=hidden_size,
    hippo_order_N=output_size,
    batch_size=batch_size,
)
print(
    f"carry init:\n{model.initialize_carry(rng=subkey, batch_size=(batch_size,), hidden_size=hidden_size, init_fn=nn.initializers.zeros)}"
)
print(
    f"state init:\n{model.initialize_state(num_layers=num_layers,rng=rng,batch_size=batch_size,output_size=len(char_to_id), init_fn=nn.initializers.zeros,)}"
)

print(f"Init Model: RNNLM")
params = model.init(
    subkey,
    model.initialize_carry(
        rng=key,
        batch_size=(batch_size,),
        hidden_size=hidden_size,
        init_fn=nn.initializers.zeros,
    ),
    model.initialize_state(
        num_layers=num_layers,
        rng=subkey,
        batch_size=batch_size,
        output_size=output_size,
        init_fn=nn.initializers.zeros,
    ),
)

print(f"Model state size: {model.state_size}, vocab size: {model.vocab_size}")
# output: Model state size: 8, vocab size: 5

# run a single example through the model to test that it works
new_state, predictions = model.apply(params, model.initial_carry(), 0)
assert predictions.shape[0] == model.vocab_size

# calling sample on random model leads to random output
sample(model, params, (char_to_id, id_to_char), "abc", max_length=10)
# output: 'abcadbaadbadd'


output_size:
5
Create Model: RNNLM
batch_size: (10,)
hidden_size: (4,)
carry mem_shape: (10, 1, 4)
carry init:
(DeviceArray([[[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]]], dtype=float32), DeviceArray([[[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]],

             [[0., 0., 0., 0.]]], dtype=float32))
Layer: 0

batch_size: (10,)
output_size: (5,)
state mem_shape: (10, 5)
Layer: 1

batch_size: (10,)
output_size: (5,)
state mem_shape: (10, 5)
Layer: 2

batch_size: (10,)
output

TypeError: __init__() got an unexpected keyword argument 'v'

In [None]:
def chunker(seq, size):
    """
    chunks a sequences into two subsequences
    one for inputs, another for targets, by
    shifting the input by 1
    """
    n = len(seq)
    p = 0
    while p + 1 <= n:
        # ensure the last chunk is of equal size
        yield seq[p : min(n - 1, p + size)], seq[(p + 1) : (p + size + 1)]
        p += size


: 

In [None]:
def rnn_loss(params, model, carries, inputs, targets):
    # use lax.scan to efficiently generate a loop over the inputs
    # this function returns the final state, and predictions for every step
    # note: scan input array needs have shape [length, 1]
    final_state, predictions = jax.lax.scan(
        lambda carry, input: model.apply(params, carry, input), carries, np.array([inputs]).T
    )
    loss = np.mean(jax.vmap(optax.softmax_cross_entropy)(predictions, np.array([targets]).T))
    return loss, final_state


# we want both the loss an gradient, we set has_aux because rnn_loss also return final state
# use static_argnums=1 to indicate that the model is static;
# a different model input will require recomplication
# finally, we jit the function to improve runtime
rnn_loss_grad = jax.jit(jax.value_and_grad(rnn_loss, has_aux=True), static_argnums=1)


: 

In [None]:
def batch_step(model, optimizer, state, inputs, targets):
    (loss, state), grad = rnn_loss_grad(optimizer.target, model, state, inputs, targets)
    new_optimizer = optimizer.apply_gradient(grad)
    return new_optimizer, loss, state


def epoch_step(model, optimizer, data, batch_size):
    state = model.init_state()
    total_loss = 0
    for n, (inputs, targets) in enumerate(chunker(data, batch_size)):
        optimizer, loss, state = batch_step(model, optimizer, state, inputs, targets)

        total_loss += loss
    return optimizer, total_loss / (n + 1)


: 

In [None]:
def train():
    pass


: 