# This notebook focuses on the `flax.linen.scan` function, which is key for any recurrent model.


In this notebook I show: 
- Two ways of implementing a custom initialization for a Flax model.
- How to use the `nn.scan` function to implement a simple RNN.
- How to use the `tabulate` function to display the model's parameters.
- How to compute the FLOPs of this model.
- The maximum FLOPs achievable by the GPU through JAX.

We implement: 

$
h_t = \tanh(W_{hh} h_{t-1} + (W_{xh} x_t + b_h))
$

where $h_t$ is the hidden state at time $t$, $x_t$ is the input at time $t$, $W_{hh}$, $W_{xh}$ and $b_h$ are the weights and bias of the RNN.

- $W_{hh} h_{t-1}$ is computed via an explicit matrix multiplication with `jnp.dot`. In practice we compute $h_{t-1} W_{hh}$, treating $h_{t-1}$ as a row vector.
- $(W_{xh} x_t + b_h)$ is computed via a `nn.Dense` layer.
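As a point of reference, the same recurrence can be sketched with plain `jax.lax.scan`, without Flax. This is only an illustration of the equation above, not the model used in the rest of the notebook. The function name `rnn_scan`, the sizes and the explicit weight arrays are assumptions made for this example, and the row-vector convention ($h_{t-1} W_{hh}$) is used throughout.

```python
import jax
import jax.numpy as jnp

def rnn_scan(W_hh, W_xh, b_h, xs):
    """Run h_t = tanh(h_{t-1} W_hh + x_t W_xh + b_h) over the leading axis of xs."""
    def step(h_prev, x_t):
        h_t = jnp.tanh(jnp.dot(h_prev, W_hh) + jnp.dot(x_t, W_xh) + b_h)
        return h_t, h_t  # (new carry, per-step output)

    h0 = jnp.zeros((W_hh.shape[0],))  # initial hidden state
    return jax.lax.scan(step, h0, xs)

# Hypothetical sizes, chosen only for illustration.
T, D_IN, D_H = 5, 3, 4
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
W_hh = jax.random.uniform(k1, (D_H, D_H), minval=-0.1, maxval=0.1)
W_xh = jax.random.uniform(k2, (D_IN, D_H), minval=-0.1, maxval=0.1)
b_h = jnp.zeros((D_H,))
xs = jnp.ones((T, D_IN))

h_last, h_all = rnn_scan(W_hh, W_xh, b_h, xs)
print(h_last.shape, h_all.shape)  # (4,) (5, 4)
```

The scan returns the final carry and the stacked per-step outputs, so `h_last` equals `h_all[-1]`.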

In [1]:
import flax.linen as nn
import jax.numpy as jnp
import jax
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# jax memory allocation
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.95'

As far as I know, to scan a plain function, `nn.scan` has to be used inside a `__call__` method, on a function that is also defined inside that `__call__` method.

Also, the scanned function must use the signature `fn(self, carry, x)` and return a `(new_carry, output)` pair.

In [None]:
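# Way 1: a factory that returns an initializer with the signature init(rng, shape),
# which is how `self.param` calls it.
def custom_w_init():
    def init(rng, shape):
        return jax.random.uniform(rng, shape, minval=-0.1, maxval=0.1)
    return init

# Way 2: a plain function with the (rng, shape, dtype) signature that `nn.Dense`
# uses for `kernel_init`.
def another_custom_w_init(rng, shape, dtype=jnp.float32):
    return jax.random.uniform(rng, shape, minval=-0.1, maxval=0.1, dtype=dtype)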

class RNNCell(nn.Module):
    hidden_dim: int = 10

    @nn.compact
    def __call__(self, x):
        Wh = self.param('W', custom_w_init(), (self.hidden_dim, self.hidden_dim))
        dense_in = nn.Dense(features=self.hidden_dim, kernel_init=another_custom_w_init)
        h = jnp.zeros((self.hidden_dim,))

        def update(self, h, x):
            h = jnp.tanh(jnp.dot(h, Wh) + dense_in(x))
            return h, h # Return the new carry and the output   
        
        scan_update = nn.scan(
            update,
            variable_broadcast='params',
            in_axes=0,
            out_axes=0
        )
        
        return scan_update(self, h, x)



# Define inputs
x = jnp.ones((20, 100))  # 20 timesteps, input size 100
HIDDEN_DIM = 10
# Initialize
key = jax.random.PRNGKey(0)
model = RNNCell(HIDDEN_DIM)
params = model.init(key, x)
out, hist = model.apply(params, x)
print(out.shape, hist.shape)  # (10,) (20, 10)


In [3]:
# Define inputs
HIDDEN_DIM = 100
x = jnp.ones((200000, HIDDEN_DIM))  # 200,000 timesteps, input size 100

# Initialize
key = jax.random.PRNGKey(0)
model = RNNCell(hidden_dim=HIDDEN_DIM)
params = model.init(key, x)
out, hist = model.apply(params, x)
print(out.shape, hist.shape)  # (100,) (200000, 100)


(100,) (200000, 100)


# Compute the FLOPs of this RNN
- $h_{t-1} W_{hh}$ is a vector-matrix product with a 100x100 matrix -> 100 x (100 mults + 99 adds) = 19'900 FLOPs
- $W_{xh} x_t$ is a matrix-vector product with a 100x100 matrix -> 100 x (100 mults + 99 adds) = 19'900 FLOPs
- Adding the two resulting vectors of size 100 -> 100 adds = 100 FLOPs (the bias add inside `nn.Dense`, another 100 adds, is ignored here)
- The activation function, counted as one FLOP per element -> 100 FLOPs
- The total per time step is 19'900 + 19'900 + 100 + 100 = 40'000 FLOPs
- We do 200'000 time steps -> 200'000 * 40'000 = 8'000'000'000 FLOPs
- The eager runtime is 1.38 s -> 8'000'000'000 FLOPs / 1.38 s ≈ 5.80 GFLOP/s
- The jit runtime is 1.34 s -> 8'000'000'000 FLOPs / 1.34 s ≈ 5.97 GFLOP/s
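The arithmetic above can be checked with a few lines of plain Python. This is a minimal sketch that uses the same counting convention (one FLOP per multiply, add, or `tanh` element, bias add ignored) and the `%timeit` means of 1.38 s (eager) and 1.34 s (jit) reported in this notebook. The helper name `rnn_step_flops` is an assumption made for this example.

```python
def rnn_step_flops(hidden_dim: int, input_dim: int) -> int:
    """FLOPs for one RNN step under the counting convention described above."""
    matvec_hh = hidden_dim * (hidden_dim + (hidden_dim - 1))  # h_{t-1} W_hh
    matvec_xh = hidden_dim * (input_dim + (input_dim - 1))    # W_xh x_t
    vector_add = hidden_dim                                    # sum of the two products
    activation = hidden_dim                                    # tanh, one FLOP per element
    return matvec_hh + matvec_xh + vector_add + activation

per_step = rnn_step_flops(hidden_dim=100, input_dim=100)
total = per_step * 200_000  # 200'000 time steps
print(per_step, total)  # 40000 8000000000

for label, seconds in [("eager", 1.38), ("jit", 1.34)]:
    print(f"{label}: {total / seconds / 1e9:.2f} GFLOP/s")
# eager: 5.80 GFLOP/s
# jit: 5.97 GFLOP/s
```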

In [4]:
%timeit model.apply(params, x) 

1.38 s ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
jit_model = jax.jit(model.apply)
%timeit jit_model(params, x) 

1.34 s ± 10.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Print the model params

In [6]:
# use the tabulate function to see the number of parameters
x = jnp.ones((5, 10))  # 5 timesteps, input size 10
tabulate_fn = nn.tabulate(RNNCell(), jax.random.PRNGKey(0))
print(tabulate_fn(x))
print(jax.tree_util.tree_map(lambda x: x.shape, params))


                                RNNCell Summary
┏━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module  ┃ inputs        ┃ outputs         ┃ params                 ┃
┡━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ RNNCell │ float32[5,10] │ - float32[10]   │ W: float32[10,10]      │
│         │         │               │ - float32[5,10] │                        │
│         │         │               │                 │ 100 (400 B)            │
├─────────┼─────────┼───────────────┼─────────────────┼────────────────────────┤
│ Dense_0 │ Dense   │ float32[10]   │ float32[10]     │ bias: float32[10]      │
│         │         │               │                 │ kernel: [2m