In [11]:
import jax
import jax.ops
import jax.numpy as jnp

import flax
from flax import linen as nn
from flax import optim
from flax.linen.recurrent import RNNCellBase

import optax

import numpy as np  # convention: original numpy

from typing import Any, Callable, Sequence, Optional, Tuple, Union


In [12]:
# Fixed integer seed so that every fresh run of the notebook is reproducible.
seed: int = 1701

# Root PRNG key built from the seed.
key = jax.random.PRNGKey(seed)


In [13]:
# How many independent PRNG keys to derive from the root key.
num_copies = 5

# A single split produces all derived keys at once; the count must match the
# number of names being unpacked on the left-hand side.
rng, key2, key3, key4, key5 = jax.random.split(key, num_copies)


In [None]:
class RNNCell(RNNCellBase):
    """Vanilla (Elman) recurrent cell with a tanh non-linearity."""

    # NOTE: `nn.compact` is applied bare; calling it as `nn.compact()` passes no
    # function to the decorator and fails when the class body is evaluated.
    @nn.compact
    def __call__(self, carry, input):
        """
        Description:
            W_hh = H_{t-1} @ W_{hh} + b_{h} - linear map of the previous hidden state
            W_xh = x_{t} @ W_{xh} + b_{x}   - linear map of the current input

            H_{t} = f_{w}(H_{t-1}, x)
            H_{t} = tanh((H_{t-1} @ W_{hh} + b_{h}) + (x_{t} @ W_{xh} + b_{x}))

        Args:
            carry (jnp.ndarray): hidden state from previous time step
            input (jnp.ndarray): # input vector

        Returns:
            A tuple with the new carry and the output (both are the new hidden state).
        """
        ht_1 = carry

        h_t = self.rnn_update(input, ht_1)

        return h_t, h_t

    def rnn_update(self, input, ht_1):
        """
        Description:
            Computes the new hidden state from the input and the previous hidden state.
            Creates `nn.Dense` submodules inline, so it is meant to be invoked from
            the compact `__call__` above.

        Args:
            input (jnp.ndarray): # input vector
            ht_1 (jnp.ndarray): hidden state from previous time step

        Returns:
            The new hidden state, with the same trailing size as `ht_1`.
        """
        # Both projections must land in the hidden space so they can be summed;
        # the feature axis is the last one, which also holds for batched inputs.
        hidden_size = ht_1.shape[-1]

        W_hh = nn.Dense(hidden_size)(ht_1)
        W_xh = nn.Dense(hidden_size)(input)
        h_t = jnp.tanh(W_hh + W_xh)  # H_{t} = tanh((H_{t-1} @ W_{hh}) + (x_{t} @ W_{xh}))

        return h_t


In [None]:
class LSTMCell(RNNCellBase):
    """Long short-term memory cell whose carry is the pair (hidden state, cell state)."""

    # NOTE: `nn.compact` is applied bare; `nn.compact()` would be called without
    # the function it is supposed to wrap.
    @nn.compact
    def __call__(self, carry, input):
        """
        Description:
            i_{t} = sigmoid((W_{ii} @ x_{t} + b_{ii}) + (W_{hi} @ h_{t-1} + b_{hi}))
            f_{t} = sigmoid((W_{if} @ x_{t} + b_{if}) + (W_{hf} @ h_{t-1} + b_{hf}))
            g_{t} = tanh((W_{ig} @ x_{t} + b_{ig}) + (W_{hg} @ h_{t-1} + b_{hg}))
            o_{t} = sigmoid((W_{io} @ x_{t} + b_{io}) + (W_{ho} @ h_{t-1} + b_{ho}))
            c_{t} = f_{t} * c_{t-1} + i_{t} * g_{t}
            h_{t} = o_{t} * tanh(c_{t})

        Args:
            carry (Tuple[jnp.ndarray, jnp.ndarray]): (h_{t-1}, c_{t-1}), the hidden
                and cell state from the previous time step, in that order
            input (jnp.ndarray): # input vector

        Returns:
            A tuple with the new carry (h_{t}, c_{t}) and the output h_{t}.
        """
        ht_1, ct_1 = carry

        # rnn_update returns (h_t, c_t); unpack in that same order so the hidden
        # and cell state are not swapped in the carry.
        h_t, c_t = self.rnn_update(input, ht_1, ct_1)

        return (h_t, c_t), h_t

    def rnn_update(self, input, ht_1, ct_1):
        """
        Description:
            Evaluates the four gates and the state updates of one LSTM step.
            Creates `nn.Dense` submodules inline, so it is meant to be invoked from
            the compact `__call__` above.

        Args:
            input (jnp.ndarray): # input vector
            ht_1 (jnp.ndarray): hidden state from previous time step
            ct_1 (jnp.ndarray): cell state from previous time step

        Returns:
            A tuple (h_t, c_t) with the new hidden state and the new cell state.
        """
        # Every gate lives in the hidden space: both the input projection and the
        # hidden projection must produce `hidden_size` features to be summed and
        # later multiplied element-wise with the cell state.
        hidden_size = ht_1.shape[-1]

        i_ta = nn.Dense(hidden_size)(input)
        i_tb = nn.Dense(hidden_size)(ht_1)
        i_t = nn.sigmoid(i_ta + i_tb)  # input gate

        o_ta = nn.Dense(hidden_size)(input)
        o_tb = nn.Dense(hidden_size)(ht_1)
        o_t = nn.sigmoid(o_ta + o_tb)  # output gate

        f_ia = nn.Dense(hidden_size)(
            input
        )  # b^{f}_{i} + \sum\limits_{j} U^{f}_{i, j} x^{t}_{j}
        f_ib = nn.Dense(hidden_size)(
            ht_1
        )  # \sum\limits_{j} W^{f}_{i, j} h^{(t-1)}_{j}
        f_i = nn.sigmoid(f_ia + f_ib)  # forget gate

        g_ia = nn.Dense(hidden_size)(
            input
        )  # b^{g}_{i} + \sum\limits_{j} U^{g}_{i, j} x^{t}_{j}
        g_ib = nn.Dense(hidden_size)(
            ht_1
        )  # \sum\limits_{j} W^{g}_{i, j} h^{(t-1)}_{j}
        g_i = jnp.tanh(g_ia + g_ib)  # candidate cell values

        c_t = (f_i * ct_1) + (i_t * g_i)  # internal cell state update

        h_t = o_t * jnp.tanh(c_t)  # hidden state update

        return h_t, c_t


In [None]:
class GRUCell(RNNCellBase):
    """Gated recurrent unit cell; the carry is the hidden state."""

    # NOTE: `nn.compact` is applied bare; `nn.compact()` would be called without
    # the function it is supposed to wrap.
    @nn.compact
    def __call__(self, carry, input):
        """
        Description:
            z_t = sigmoid((W_{iz} @ x_{t} + b_{iz}) + (W_{hz} @ h_{t-1} + b_{hz}))
            r_t = sigmoid((W_{ir} @ x_{t} + b_{ir}) + (W_{hr} @ h_{t-1} + b_{hr}))
            g_i = tanh((W_{ig} @ x_{t} + b_{ig}) + r_t * (W_{hg} @ h_{t-1} + b_{hg}))
            h_t = (z_t * h_{t-1}) + ((1 - z_t) * g_i)

        Args:
            carry (jnp.ndarray): hidden state from previous time step
            input (jnp.ndarray): # input vector

        Returns:
            A tuple with the new carry and the output (both are the new hidden state).
        """
        ht_1 = carry

        h_t = self.rnn_update(input, ht_1)

        return h_t, h_t

    def rnn_update(self, input, ht_1):
        """
        Description:
            Evaluates the update gate, reset gate, candidate state and the new
            hidden state of one GRU step. Creates `nn.Dense` submodules inline, so
            it is meant to be invoked from the compact `__call__` above.

        Args:
            input (jnp.ndarray): # input vector
            ht_1 (jnp.ndarray): hidden state from previous time step

        Returns:
            The new hidden state, with the same trailing size as `ht_1`.
        """
        # All gates are combined element-wise with the hidden state, so every
        # projection must produce `hidden_size` features.
        hidden_size = ht_1.shape[-1]

        z_ta = nn.Dense(hidden_size)(input)
        z_tb = nn.Dense(hidden_size)(ht_1)
        z_t = nn.sigmoid(z_ta + z_tb)  # update gate

        r_ta = nn.Dense(hidden_size)(input)
        r_tb = nn.Dense(hidden_size)(ht_1)
        r_t = nn.sigmoid(r_ta + r_tb)  # reset gate

        g_ia = nn.Dense(hidden_size)(input)
        g_ib = nn.Dense(hidden_size)(ht_1)
        # The reset gate scales only the hidden-state contribution; the input
        # contribution is added unscaled.
        g_i = jnp.tanh(g_ia + (r_t * g_ib))  # candidate hidden state

        h_t = (z_t * ht_1) + ((1 - z_t) * g_i)  # hidden state update

        return h_t


In [None]:
class HiPPOCell(nn.Module):
    '''
        Description:
            RNN update function
            τ(h, x) = (1 - g(h, x)) ◦ h + g(h, x) ◦ tanh(Lτ (h, x)) 
            g(h, x) = σ(Lg(h,x))

        Args:
            hippo (HiPPO): # hippo model object
            cell (RNNCellBase): choice of RNN cell object
                - RNNCell 
                - LSTMCell
                - GRUCell
    '''
    # The base class must be `nn.Module` (the class); `nn.module` is the
    # submodule of flax.linen and cannot be subclassed.
    hippo: HiPPO
    cell: RNNCellBase

    # `nn.compact` is applied bare; `nn.compact()` would be called without the
    # function it is supposed to wrap.
    @nn.compact
    def __call__(self, carry, input):
        '''
        Description:
            Runs the wrapped recurrent cell for one step, projects its output to
            the input's feature size and feeds that projection to the HiPPO
            operator.

            RNN update function
            τ(h, x) = (1 - g(h, x)) ◦ h + g(h, x) ◦ tanh(Lτ (h, x)) 
            g(h, x) = σ(Lg(h,x))
            
        Args:
            carry (jnp.ndarray): hidden state from previous time step
            input (jnp.ndarray): # input vector
            
        Returns:
            A tuple with the new carry (h_t, c_t) and the output h_t.
        '''

        # NOTE(review): `carry` is forwarded to the wrapped cell unchanged and the
        # cell's own new carry is discarded, while this method returns a
        # (h_t, c_t) pair as carry. Confirm that the wrapped cell accepts that
        # pair as its carry, and whether c_{t-1} should be concatenated to the
        # input as in the HiPPO-RNN formulation.
        _, h_t = self.cell(carry, input)

        # Project back to the input feature size (last axis, so batched inputs
        # are handled too).
        y_t = nn.Dense(input.shape[-1])(h_t) # f_t in the paper

        # NOTE(review): the HiPPO call signature is not visible in this file;
        # arguments are passed through exactly as before.
        c_t = self.hippo(y_t, init_state=None, kernel=False)

        return (h_t, c_t), h_t
        
        

In [None]:
class RNN_o2m(nn.Module):
    """_summary_
        An RNN with one-to-many connections: a single input is consumed at the
        first step and every later step is fed the previous step's output.

    Args:
        cell (_type_): choice of RNN cell object
        sequence_length (int): number of output steps to generate
        carry (Tuple[jnp.ndarray, jnp.ndarray]): default initial carry, used when
            `__call__` is given `carry=None`

    """

    cell: RNNCellBase
    sequence_length: int
    carry: Tuple[jnp.ndarray, jnp.ndarray]

    # `nn.compact` is applied bare; `nn.compact()` would be called without the
    # function it is supposed to wrap.
    @nn.compact
    def __call__(self, carry, x):
        """Unroll the cell for `sequence_length` steps.

        Args:
            carry: initial carry for the cell; when None the `carry` attribute
                of the module is used instead.
            x (jnp.ndarray): the single input fed at the first step.

        Returns:
            A list with one output per step.
        """
        # Honour the carry supplied by the caller; fall back to the configured
        # default only when none is given.
        if carry is None:
            carry = self.carry

        # One read-out layer shared by all time steps. Its width equals the input
        # feature size (last axis) so each output can be fed back as the next
        # input.
        readout = nn.Dense(x.shape[-1])

        output = []
        step_input = x
        for _ in range(self.sequence_length):
            carry, h_t = self.cell(carry, step_input)

            y_t = readout(h_t)  # output
            output.append(y_t)

            step_input = y_t  # the next step consumes this step's output

        return output


In [None]:
def train_one2one(
    init_carry,
    input,
    sequence_length,
    cell,
    tf_bool=True,
    targets=None,
    readout=None,
):
    """Unroll `cell` for `sequence_length` steps and collect the per-step outputs.

    The first step is fed `input`. Each later step is fed either the previous
    ground-truth target (teacher forcing) or the previous output of this loop.

    Args:
        init_carry: initial carry handed to `cell` at the first step.
        input: input for the first step.
        sequence_length (int): number of steps to unroll.
        cell: callable `(carry, x) -> (new_carry, h_t)`.
        tf_bool (bool): enable teacher forcing. It only has an effect when
            `targets` is given; otherwise the previous output is fed back.
        targets: optional indexable of ground-truth values; with teacher forcing
            step `i > 0` is fed `targets[i - 1]`.
        readout: optional callable mapping `h_t` to the step output. Defaults to
            the identity. A layer such as `nn.Dense` cannot be created here
            because this function runs outside any module, so an initialised
            read-out has to be passed in by the caller.

    Returns:
        A list with one output per step (empty when `sequence_length` is 0).
    """
    teacher_forcing = tf_bool and targets is not None

    output = []
    carry = init_carry

    for i in range(sequence_length):
        if i == 0:
            step_input = input
        elif teacher_forcing:
            step_input = targets[i - 1]
        else:
            step_input = output[i - 1]

        carry, h_t = cell(carry, step_input)

        y_t = h_t if readout is None else readout(h_t)  # output
        output.append(y_t)

    # Returned on every path; previously only the teacher-forcing branch did.
    return output


In [None]:
class RNN_m2o(nn.Module):
    """_summary_
        An RNN with many-to-one connections.

        NOTE(review): the loop below feeds `x` at the first step and then feeds
        each step's output back as the next input, returning every step's
        output. That is the one-to-many pattern; a many-to-one network would
        iterate over the time steps of `x` and return only the final output.
        The return value is kept as-is so existing callers do not break -
        confirm the intended behaviour.

    Args:
        cell (_type_): choice of RNN cell object
        sequence_length (int): number of steps to unroll
        carry (Tuple[jnp.ndarray, jnp.ndarray]): initial carry handed to the cell
            at the first step

    """

    cell: RNNCellBase
    sequence_length: int
    carry: Tuple[jnp.ndarray, jnp.ndarray]

    # `nn.compact` is applied bare; `nn.compact()` would be called without the
    # function it is supposed to wrap.
    @nn.compact
    def __call__(self, x):
        """Unroll the cell for `sequence_length` steps starting from `self.carry`.

        Args:
            x (jnp.ndarray): input fed at the first step.

        Returns:
            A list with one output per step.
        """
        # One read-out layer shared by all time steps. Its width equals the input
        # feature size (last axis) so each output can be fed back as the next
        # input.
        readout = nn.Dense(x.shape[-1])

        output = []
        carry = self.carry
        step_input = x
        for _ in range(self.sequence_length):
            carry, h_t = self.cell(carry, step_input)

            y_t = readout(h_t)  # output
            output.append(y_t)

            step_input = y_t  # the next step consumes this step's output

        return output


In [None]:
class RNN(nn.Module):
    """
    Description:
        Wrapper that selects a recurrent cell and a connection pattern, and
        applies one cell step followed by a dense read-out.

    Args:
        lstm_size (int): number of units in one LSTM layer.
        num_layers (int): number of stacked LSTM layers.
        keep_prob (float): percentage of cell units to keep in the dropout operation.
        init_learning_rate (float): the learning rate to start with.
        learning_rate_decay (float): decay ratio in later training epochs.
        init_epoch (int): number of epochs using the constant init_learning_rate.
        max_epoch (int): total number of epochs in training
        input_size (int): size of the sliding window / one training data point
        batch_size (int): number of data points to use in one mini-batch.
        rnn_type (str): one of "one2one", "one2many", "many2one", "many2many".
        arch (Callable, optional): kept for constructor compatibility. The
            connection pattern is resolved from `rnn_type` and stored as
            `step_fn` during setup; this field is not consulted.
    """

    # Plain class-level constants (not constructor fields).
    input_size = 1
    num_steps = 30
    lstm_size = 128
    num_layers = 1
    keep_prob = 0.8
    batch_size = 64
    init_learning_rate = 0.001
    learning_rate_decay = 0.99
    init_epoch = 5
    max_epoch = 50

    rnn_type: str
    # `Callable` takes a *list* of argument types; the strategies below accept a
    # varying number of arguments, hence `...`.
    arch: Optional[Callable[..., Any]] = None

    def setup(self):
        """Resolve the connection pattern and build the submodules."""
        # Constructor fields may not be reassigned inside setup(), so the
        # resolved strategy is stored under a separate attribute.
        self.step_fn = self.choose_rnn_type()  # choose rnn type

        # choose_cell returns a class; instantiate it and keep it on the module
        # so that its parameters are registered.
        self.cell = self.choose_cell(cell_type="lstm")()  # choose cell type

        # Read-out layer. Submodules have to be created here (or in a compact
        # method), not in a plain __call__.
        # NOTE(review): assumes the model output has `input_size` features, the
        # same as one input data point - confirm.
        self.readout = nn.Dense(self.input_size)

    def __call__(self, ht_1, input):
        """Apply one recurrent step followed by the read-out.

        Args:
            ht_1: carry from the previous time step, in the format expected by
                the selected cell.
            input (jnp.ndarray): # input vector

        Returns:
            A tuple with the new carry and the read-out output y_t.
        """
        carry, h_t = self.cell(ht_1, input)

        y_t = self.readout(h_t)  # output
        return carry, y_t

    def choose_cell(self, cell_type="lstm"):
        """Return the cell class registered under `cell_type`.

        Raises:
            ValueError: if `cell_type` is not one of "regular", "hippo", "lstm"
                or "gru".
        """
        cell = None
        if cell_type == "regular":
            cell = RNNCell

        elif cell_type == "hippo":
            cell = HiPPOCell

        elif cell_type == "lstm":
            cell = LSTMCell

        elif cell_type == "gru":
            cell = GRUCell

        else:
            raise ValueError("Invalid cell type")

        return cell

    def choose_rnn_type(self):
        """Return the bound strategy method that matches `self.rnn_type`.

        Raises:
            ValueError: if `self.rnn_type` is not a known connection pattern.
        """
        if self.rnn_type == "one2one":
            return self.one_to_one

        elif self.rnn_type == "one2many":
            return self.one_to_many

        elif self.rnn_type == "many2one":
            return self.many_to_one

        elif self.rnn_type == "many2many":
            return self.many_to_many

        else:
            raise ValueError("Invalid rnn type")

    def one_to_one(self, ht_1, input, cell_init):
        """Run one cell step and return only the cell output."""
        _, y_t = cell_init(ht_1, input)

        return y_t

    def one_to_many(self, ht_1, input, cell_init):
        """Run one cell step and return only the cell output."""
        _, y_t = cell_init(ht_1, input)

        return y_t

    def many_to_one(self, ht_1, input, cell_init, output_bool=False):
        """Run one cell step and return either the new carry or the cell output.

        Args:
            output_bool (bool): when True return the first element of the cell's
                result (the new carry); otherwise return the second (the output).
        """
        # No defensive copy: the carry may be a tuple (which has no `.copy()`),
        # and JAX arrays are immutable anyway.
        if output_bool:
            h_t, _ = cell_init(ht_1, input)
            output = h_t
        else:
            _, y_t = cell_init(ht_1, input)
            output = y_t

        return output

    def many_to_many(self, ht_1, input, cell_init):
        """Run one cell step and return both elements of the cell's result."""
        h_t, y_t = cell_init(ht_1, input)

        return h_t, y_t
