In [11]:
import jax
import jax.ops
import jax.numpy as jnp

import flax
from flax import linen as nn
from flax import optim

import optax

import numpy as np  # convention: original numpy

from typing import Any, Callable, Sequence, Optional, Tuple, Union

In [12]:
seed = 1701
key = jax.random.PRNGKey(seed)

In [13]:
num_copies = 5
rng, key2, key3, key4, key5 = jax.random.split(key, num=num_copies)

In [14]:
class RNNCell(nn.Module):
  @nn.compact
  def __call__(self, state, x):
    # Wh @ h + Wx @ x + b can be efficiently computed
    # by concatenating the vectors and then having a single dense layer
    x = jnp.concatenate([state, x])
    new_state = jnp.tanh(nn.Dense(state.shape[0])(x))
    return new_state

In [15]:
L = 100

x = jax.random.normal(rng, shape=(L, 1))

params = RNNCell.init(key2, x)
    
c_legs = RNNCell.apply(params, x)
    

AttributeError: 'DeviceArray' object has no attribute 'init_with_output'

In [None]:
class RNN(nn.Module):
    """
    Description:
     Wh @ ht_1 + Wx @ x + b

    Args:
        W_xh (jnp.ndarray): # weights associated with the input x_{i} and the hidden state h_{t}
        W_hh (jnp.ndarray): # weights associated with the hidden state h_{t} and the hidden state h_{t-1}
        W_hy (jnp.ndarray): # weights associated with the hidden state h_{t} and the output y_{t}
    """
    W_xh: jnp.ndarray
    W_hh: jnp.ndarray
    W_hy: jnp.ndarray
    
    
    
    def setup(self):
        
        pass
    
    def __call__(self, ht_1, input):
        '''
        Description:
            W_xh = Wx @ x + b - this a linear layer
            W_hh = Wh @ ht_1 - multiply the previous hidden state with 
            
            h_{t} = f_{w}(Wh @ ht_1 + Wx @ x + b)
            h_{t} = Wh @ ht_1 + Wx @ x + b

        Args:
            ht_1 (jnp.ndarray): hidden state from previous time step (aka the current cell state)
            input (jnp.ndarray): # input vector
            
        Returns:
            ht (jnp.ndarray): hidden state from current time step
        '''
        
        self.W_xh = nn.Dense(ht_1.shape[0])(input)
        self.W_hh = self.W_hh @ ht_1
        x = ht_1 @ x
        h_t = jnp.tanh(nn.Dense(ht_1.shape[0])(x))
        return h_t
    
class RNNCell(nn.Module):
    """
    Description:
     Wh @ ht_1 + Wx @ x + b

    Args:
        W_xh (jnp.ndarray): weights associated with the input x_{i} and the hidden state h_{t}
        W_hh (jnp.ndarray): weights associated with the hidden state h_{t} and the hidden state h_{t-1}
        W_hy (jnp.ndarray): weights associated with the hidden state h_{t} and the output y_{t}
        hidden_dim (int): dimension of the hidden state
        io_dim (int): dimension of the input and output
    """
    W_xh: jnp.ndarray
    W_hh: jnp.ndarray
    W_hy: jnp.ndarray
    hidden_dim: int
    io_dim: int
    
    def setup(self):
        # self.Wxh = np.random.randn(hidden_size, vocab_size)*0.01 # input to hidden
        # self.Whh = np.random.randn(hidden_size, hidden_size)*0.01 # hidden to hidden
        # self.Why = np.random.randn(vocab_size, hidden_size)*0.01 # hidden to output
        
        self.W_xh = jnp.ndarray(np.empty((self.hidden_dim, self.io_dim), dtype=jnp.float64))
        self.W_hh = jnp.ndarray(np.empty((self.hidden_dim, self.hidden_dim), dtype=jnp.float64))
        
        
    def __call__(self, ht_1, input):
        '''
        Description:
            W_xh = x_{t} @ W_{xh} - multiply the previous hidden state with 
            W_hh = H_{t-1} @ W_{hh} + b_{h} - this a linear layer 
            
            H_{t} = f_{w}(H_{t-1}, x)
            H_{t} = tanh(H_{t-1} @ W_{hh}) + (x_{t} @ W_{xh})

        Args:
            ht_1 (jnp.ndarray): hidden state from previous time step (aka the current cell state)
            input (jnp.ndarray): # input vector
            
        Returns:
            h_t (jnp.ndarray): hidden state from current time step
        '''
        W_hh = nn.Dense(ht_1.shape[0])(ht_1) 
        W_xh = input @ self.W_xh
        h_t = jnp.tanh(W_hh + W_xh) # H_{t} = tanh(H_{t-1} @ W_{hh}) + (x_{t} @ W_{xh})
        y_t = nn.Dense(h_t.shape[0])(h_t) # \hat{y}_{t} = W^{T}_{hy} @ h_{t} + b^{T}_{hy}
        
        return h_t, y_t
    
    

In [None]:
class RNN(nn.Module):
    """
    Description:
    

    Args:
        nn (_type_): _description_
    """
    cell: Callable[jnp.ndarray, jnp.ndarray]
    hidden_dim: int
    output_dim: int
    
    def setup(self):
        self.rnn_cell = self.cell(W_hh=self., hidden_dim=self.hidden_dim)
        pass
    
    def __call__(self, ht_1, input):
        h_t = self.rnn_cell(ht_1, input)
        y_t = nn.Dense(h_t.shape[0])(h_t) # \hat{y}_{t} = W^{T}_{hy} @ h_{t}
        return h_t, y_t