# This notebook focuses on the `flax.linen.scan` function, which is key for any recurrent model.


In this notebook I show: 
- Two ways of implementing a custom initialization for a Flax model.
- How to use the `nn.scan` function to implement a simple RNN.
- How to use the `tabulate` function to display the model's parameters.
- How to compute the FLOPs of this model.
- How to estimate the maximum FLOP/s achievable on the GPU through JAX.

We implement: 

$
h_t = \tanh(W_{hh} h_{t-1} + (W_{xh} x_t + b_h))
$

where $h_t$ is the hidden state at time $t$, $x_t$ is the input at time $t$, $W_{hh}$, $W_{xh}$ and $b_h$ are the weights and bias of the RNN.

- $W_{hh} h_{t-1}$ is computed via an explicit matrix multiplication with `jnp.dot` (in the code we actually compute the row-vector form $h_{t-1} W_{hh}$).
- $(W_{xh} x_t + b_h)$ is computed via a `nn.Dense` layer.
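To make the recurrence concrete before bringing in Flax, here is a minimal sketch of the same update written with plain `jax.lax.scan`. The function name `rnn_scan` and the explicit weight arrays `W_hh`, `W_xh`, `b_h` are illustrative assumptions; the notebook's model stores these inside a Flax module instead.

```python
import jax
import jax.numpy as jnp


def rnn_scan(W_hh, W_xh, b_h, xs):
    """Run h_t = tanh(h_{t-1} W_hh + x_t W_xh + b_h) over the leading axis of xs."""

    def step(h, x):
        h_new = jnp.tanh(jnp.dot(h, W_hh) + jnp.dot(x, W_xh) + b_h)
        return h_new, h_new  # (new carry, per-step output)

    h0 = jnp.zeros((W_hh.shape[0],))
    return jax.lax.scan(step, h0, xs)


key_hh, key_xh = jax.random.split(jax.random.PRNGKey(0))
W_hh = jax.random.uniform(key_hh, (10, 10), minval=-0.1, maxval=0.1)
W_xh = jax.random.uniform(key_xh, (100, 10), minval=-0.1, maxval=0.1)
b_h = jnp.zeros((10,))
xs = jnp.ones((20, 100))  # 20 timesteps, input size 100

h_last, hs = rnn_scan(W_hh, W_xh, b_h, xs)
print(h_last.shape, hs.shape)
```

The scan returns the final hidden state and the stacked history of hidden states, which is the same `(carry, outputs)` structure that `nn.scan` produces further down.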

In [2]:
import flax.linen as nn
import jax.numpy as jnp
import jax
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

As far as I know, when `nn.scan` is applied to a plain function, it has to be used inside a `__call__` method, on a function that is also defined inside that `__call__` method.

Also, the scanned function must use the signature `fn(self, carry, x)` and return a `(carry, output)` pair.

In [3]:
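# Way 1: a factory that returns the initializer. The inner function only
# takes (rng, shape), which matches how self.param calls it below.
def custom_w_init():
    def init(rng, shape):
        return jax.random.uniform(rng, shape, minval=-0.1, maxval=0.1)
    return init

# Way 2: a plain function with the (rng, shape, dtype) signature,
# suitable for the kernel_init argument of nn.Dense.
def another_custom_w_init(rng, shape, dtype=jnp.float32):
    return jax.random.uniform(rng, shape, minval=-0.1, maxval=0.1, dtype=dtype)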

class RNNCell(nn.Module):
    hidden_dim: int = 10

    @nn.compact
    def __call__(self, x):
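        # Recurrent weights W_hh, created directly with self.param
        Wh = self.param('W', custom_w_init(), (self.hidden_dim, self.hidden_dim))
        # Input projection W_xh x_t + b_h
        dense_in = nn.Dense(features=self.hidden_dim, kernel_init=another_custom_w_init)
        # Initial hidden state h_0
        h = jnp.zeros((self.hidden_dim,))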

        def update(self, h, x):
            h = jnp.tanh(jnp.dot(h, Wh) + dense_in(x))
            return h, h # Return the new carry and the output   
        
        scan_update = nn.scan(
            update,
            variable_broadcast='params',
            in_axes=0,
            out_axes=0
        )
        
        return scan_update(self, h, x)



# Define inputs
x = jnp.ones((20, 100))  # 20 timesteps, input size 100
HIDDEN_DIM = 10
# Initialize
key = jax.random.PRNGKey(0)
model = RNNCell(HIDDEN_DIM)
params = model.init(key, x)
out, hist = model.apply(params, x)  # final hidden state, all hidden states
print(out.shape, hist.shape)  # (10,) (20, 10)


(10,) (20, 10)
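A note on the two initialization styles used above. As far as I can tell, `nn.Dense` calls its `kernel_init` with `(rng, shape, dtype)`, so an initializer meant for both `self.param` and `nn.Dense` should accept an optional `dtype`. The sketch below combines the factory style with that signature; `uniform_init` and its `scale` argument are hypothetical names, not part of the notebook.

```python
import jax
import jax.numpy as jnp


def uniform_init(scale=0.1):
    """Factory returning an initializer that samples from U(-scale, scale)."""

    def init(rng, shape, dtype=jnp.float32):
        return jax.random.uniform(rng, shape, dtype=dtype,
                                  minval=-scale, maxval=scale)

    return init


init_fn = uniform_init(0.1)
w = init_fn(jax.random.PRNGKey(0), (100, 10))
print(w.shape, w.dtype, float(jnp.abs(w).max()) <= 0.1)
```

The factory form is handy when the range should be configurable per layer; the plain-function form is enough when it is fixed.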


In [4]:
# Define inputs
HIDDEN_DIM = 100
x = jnp.ones((200000, HIDDEN_DIM))  # 200000 timesteps, input size 100

# Initialize
key = jax.random.PRNGKey(0)
model = RNNCell(hidden_dim=HIDDEN_DIM)
params = model.init(key, x)
out, hist = model.apply(params, x)
print(out.shape, hist.shape)  # (100,) (200000, 100)


(100,) (200000, 100)


# Compute the FLOPs of this RNN
- $h_{t-1} W_{hh}$ is a vector-matrix product with a 100x100 matrix -> 100 x (100 mults + 99 adds) = 19'900 flop
- $W_{xh} x_t$ is a vector-matrix product with a 100x100 matrix -> 100 x (100 mults + 99 adds) = 19'900 flop
- Adding the two resulting vectors of size 100 -> 100 adds = 100 flop (the bias add inside `nn.Dense`, another 100 adds, is ignored here)
- The activation function is counted as one flop per element -> 100 flop
- The total per time step is 19'900 + 19'900 + 100 + 100 = 40'000 flop
- We do 200'000 time steps -> 200'000 * 40'000 = 8'000'000'000 flop
- The runtime is 1.95s -> 8'000'000'000 flop / 1.95s ≈ 4'102'564'103 flop/s ≈ 4.1 GFLOP/s
- The jit runtime is 1.85s -> 8'000'000'000 flop / 1.85s ≈ 4'324'324'324 flop/s ≈ 4.3 GFLOP/s
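The estimate can be reproduced with a few lines of plain Python. This is a sketch under the same counting assumptions as the list (one flop per `tanh` element, bias add ignored); the helper name `rnn_flops_per_step` is hypothetical. Summing the four terms gives 40'000 flop per time step.

```python
def rnn_flops_per_step(hidden_dim, input_dim):
    """FLOP count for one step of h = tanh(h W_hh + x W_xh), bias ignored."""
    recurrent = hidden_dim * (2 * hidden_dim - 1)  # h W_hh: mults + adds
    projection = hidden_dim * (2 * input_dim - 1)  # x W_xh: mults + adds
    vector_add = hidden_dim                        # sum of the two vectors
    activation = hidden_dim                        # tanh, one flop per element
    return recurrent + projection + vector_add + activation


per_step = rnn_flops_per_step(hidden_dim=100, input_dim=100)
total = per_step * 200_000  # 200'000 time steps

# Runtimes quoted in the text above, in seconds.
for label, seconds in [("eager", 1.95), ("jit", 1.85)]:
    print(f"{label}: {total / seconds / 1e9:.2f} GFLOP/s")
```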

In [6]:
%timeit model.apply(params, x)  # un-jitted apply

ScopeParamShapeError: Initializer expected to generate shape (100, 100) but got shape (10, 100) instead for parameter "kernel" in "/Dense_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

In [12]:
jit_model = jax.jit(model.apply)
%timeit jit_model(params, x)  # jitted apply, about 1.86 s here

1.86 s ± 124 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Print the model params

In [5]:
# use the tabulate function to see the number of parameters
x = jnp.ones((5, 10))  # 5 timesteps, input size 10
tabulate_fn = nn.tabulate(RNNCell(), jax.random.PRNGKey(0))
print(tabulate_fn(x))
# jax.tree_map is deprecated in recent JAX releases; use jax.tree_util.tree_map
print(jax.tree_util.tree_map(lambda p: p.shape, params))


                                RNNCell Summary
┏━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module  ┃ inputs        ┃ outputs         ┃ params                 ┃
┡━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ RNNCell │ float32[5,10] │ - float32[10]   │ W: float32[10,10]      │
│         │         │               │ - float32[5,10] │                        │
│         │         │               │                 │ 100 (400 B)            │
├─────────┼─────────┼───────────────┼─────────────────┼────────────────────────┤
│ Dense_0 │ Dense   │ float32[10]   │ float32[10]     │ bias: float32[10]      │
│         │         │               │                 │ kernel:
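The parameter count that `tabulate` reports can be checked by hand. The shapes below are what the model definition implies for `hidden_dim=10` and a 10-dimensional input; the `(10, 10)` kernel shape for `Dense_0` is inferred from the `(input_dim, features)` convention of `nn.Dense`, not read from the output above.

```python
import math

# Parameter shapes implied by RNNCell(hidden_dim=10) on inputs of size 10.
param_shapes = {
    "W": (10, 10),               # recurrent weights
    "Dense_0/kernel": (10, 10),  # inferred: (input_dim, features)
    "Dense_0/bias": (10,),
}

counts = {name: math.prod(shape) for name, shape in param_shapes.items()}
total_params = sum(counts.values())
total_bytes = total_params * 4  # float32 takes 4 bytes

print(counts)
print(total_params, "parameters,", total_bytes, "bytes")
```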

- Matrix multiplication of two $N \times N$ matrices: the result has $N^2$ entries and each needs $N$ multiplications and $N-1$ additions, so the total is $N^2 (2N - 1) \approx 2N^3$ flop.
- N = 10000: 10000 x 10000 x (10000 multiplications + 9999 additions) = 1.9999e12 flop
  - 1.9999e12 flop / 24 ms ≈ 83.3 TFLOP/s
  - memory: 10000x10000 x 4 bytes * 3 matrices = 1.2GB
- N = 5000: 5000 x 5000 x (5000 multiplications + 4999 additions) = 2.49975e11 flop
  - 2.49975e11 flop / 3.39 ms ≈ 73.7 TFLOP/s
  - memory: 5000x5000 x 4 bytes * 3 matrices = 300MB
- N = 2048: 2048 x 2048 x (2048 multiplications + 2047 additions) = 17'175'674'880 flop
  - 17'175'674'880 flop / 210 us ≈ 81.8 TFLOP/s
- N = 2000: 2000 x 2000 x (2000 multiplications + 1999 additions) = 1.5996e10 flop
  - 1.5996e10 flop / 205 us ≈ 78.0 TFLOP/s
- N = 1500: 1500 x 1500 x (1500 multiplications + 1499 additions) = 6.74775e9 flop
  - 6.74775e9 flop / 160 us ≈ 42.2 TFLOP/s
- N = 1024: 1024 x 1024 x (1024 multiplications + 1023 additions) = 2'146'435'072 flop
  - 2'146'435'072 flop / 37 us ≈ 58.0 TFLOP/s
  - memory: 1024x1024 x 4 bytes * 3 matrices ≈ 12.6MB
- N = 1000: 1000 x 1000 x (1000 multiplications + 999 additions) = 1.999e9 flop
  - 1.999e9 flop / 34.5 us ≈ 57.9 TFLOP/s
- N = 500: 500 x 500 x (500 multiplications + 499 additions) = 2.4975e8 flop
  - 2.4975e8 flop / 45 us ≈ 5.55 TFLOP/s
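A small helper makes it easy to recompute throughput from the measured times. This sketch counts every one of the N² output entries (N multiplications and N-1 additions each), which is the usual count for a dense N x N by N x N product; `matmul_flops` is a hypothetical helper name, and the timings are the ones quoted in the list.

```python
def matmul_flops(n):
    """FLOPs for multiplying two dense n x n matrices: n^2 * (2n - 1)."""
    return n * n * (2 * n - 1)


# Measured runtimes from the list above, in seconds.
timings_s = {
    10000: 24e-3,
    5000: 3.39e-3,
    2048: 210e-6,
    2000: 205e-6,
    1500: 160e-6,
    1024: 37e-6,
    1000: 34.5e-6,
    500: 45e-6,
}

tflops = {n: matmul_flops(n) / t / 1e12 for n, t in timings_s.items()}
for n, value in tflops.items():
    print(f"N={n:>5}: {matmul_flops(n):.4e} flop, {value:.2f} TFLOP/s")
```

Throughput drops sharply for the smallest size, where fixed launch and dispatch overhead is probably a large share of the measured time.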

In [19]:
A = jnp.ones((10000, 10000))
B = jnp.ones((10000, 10000))
def matmul(A, B):
    return jnp.dot(A, B)
jit_matmul = jax.jit(matmul)

In [20]:
_ = jit_matmul(A, B)  # warmup

In [None]:
%timeit jit_matmul(A, B)  # about 24 ms for 10000x10000

24.1 ms ± 21.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
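One caveat on timing: JAX dispatches work asynchronously, so a call can return before the device has finished computing. A cautious way to time a jitted function is to call `.block_until_ready()` on the result, as in this sketch. The helper name `time_jitted` and the small 256x256 matrices are illustrative choices, not the setup measured above.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, repeats=10):
    """Average wall-clock seconds per call, waiting for the device each time."""
    fn(*args).block_until_ready()  # warm-up call, triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


jit_matmul = jax.jit(jnp.dot)
A = jnp.ones((256, 256))
B = jnp.ones((256, 256))

seconds = time_jitted(jit_matmul, A, B)
flops = 256 * 256 * (2 * 256 - 1)
print(f"{seconds * 1e6:.1f} us per call, {flops / seconds / 1e9:.2f} GFLOP/s")
```

With `%timeit`, the equivalent would be `%timeit jit_matmul(A, B).block_until_ready()`.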



# Benchmarking the softmax function of a vector
- x of size N
- exp(x) = N flop (counting one flop per exponential)
- sum(exp(x)) = N-1 flop (can be highly optimized)
- exp(x) / sum(exp(x)) = N flop
- Total = 3N - 1 ≈ 3N flop

In [51]:
jit_softmax = jax.jit(jax.nn.softmax)
x = jax.random.normal(jax.random.PRNGKey(0), (500000,))
a = jit_softmax(x)  # warmup
a.shape

(500000,)

In [52]:
%timeit jax.nn.softmax(x)  # eager
%timeit jit_softmax(x)  # jitted

227 μs ± 45.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
x = jax.random.normal(jax.random.PRNGKey(0), (500000,))
cpu_softmax = jax.jit(jax.nn.softmax, device=jax.devices("cpu")[0])
result = cpu_softmax(x)
gpu_softmax = jax.jit(jax.nn.softmax, device=jax.devices("gpu")[0])
result = gpu_softmax(x)


In [11]:
%timeit cpu_softmax(x)  # CPU backend
%timeit gpu_softmax(x)  # GPU backend

645 μs ± 60.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
30.9 μs ± 1.54 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
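These timings can be turned into throughput with the 3N - 1 count from the start of this section. A minimal sketch, with `softmax_flops` as a hypothetical helper name and the two runtimes copied from the output above:

```python
def softmax_flops(n):
    """exp (n) + sum (n - 1) + divide (n), one flop per elementwise op."""
    return n + (n - 1) + n


N = 500_000
timings_s = {"cpu": 645e-6, "gpu": 30.9e-6}  # measured runtimes in seconds

throughput = {dev: softmax_flops(N) / t / 1e9 for dev, t in timings_s.items()}
for dev, gflops in throughput.items():
    print(f"{dev}: {gflops:.2f} GFLOP/s")
print(f"GPU speed-up: {timings_s['cpu'] / timings_s['gpu']:.1f}x")
```

Both figures are far below what a matrix multiplication reaches on the same hardware, which suggests (as an assumption, not something measured here) that softmax on a vector is limited by memory traffic and launch overhead rather than by arithmetic.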


In [12]:
def hand_softmax(x):
    return jnp.exp(x) / jnp.sum(jnp.exp(x))
jit_hand_softmax = jax.jit(hand_softmax)
result = jit_hand_softmax(x)  # warmup

In [13]:
%timeit hand_softmax(x)  # eager
%timeit jit_hand_softmax(x)  # always about 10% faster than jax.nn.softmax; possibly because jax.nn.softmax also subtracts max(x) for numerical stability

98.5 μs ± 2.77 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
27.9 μs ± 1.89 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
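One plausible reason for the gap (an assumption, not something measured here) is that `jax.nn.softmax` subtracts the maximum before exponentiating, which costs an extra reduction but avoids overflow. The sketch below shows what that buys; `stable_softmax` is a hypothetical hand-written equivalent and the test vectors are made up for illustration.

```python
import jax
import jax.numpy as jnp


def hand_softmax(x):
    # Direct formula, as in the notebook: overflows when exp(x) is too large.
    return jnp.exp(x) / jnp.sum(jnp.exp(x))


def stable_softmax(x):
    # Shift by the maximum so the largest exponent is exp(0) = 1.
    e = jnp.exp(x - jnp.max(x))
    return e / jnp.sum(e)


small = jnp.array([0.5, -1.0, 2.0, 0.0])
large = jnp.array([1000.0, 0.0])

print(hand_softmax(small), jax.nn.softmax(small))    # these agree
print(hand_softmax(large))                           # exp(1000) overflows in float32
print(stable_softmax(large), jax.nn.softmax(large))  # both stay finite
```

For standard-normal inputs like the ones benchmarked here the direct formula is safe, so the faster hand-written version is fine in that setting.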

