# This notebook focuses on the `flax.linen.scan` function, a key building block for recurrent models.


In this notebook I show: 
- Two ways of implementing a custom initialization for a Flax model.
- How to use the `nn.scan` function to implement a simple RNN.
- How to use the `tabulate` function to display the model's parameters.
- How to count the FLOPs of this model.
- The maximum FLOP/s achievable on the GPU through JAX.

We implement: 

$
h_t = \tanh(W_{hh} h_{t-1} + (W_{xh} x_t + b_h))
$

where $h_t$ is the hidden state at time $t$, $x_t$ is the input at time $t$, $W_{hh}$, $W_{xh}$ and $b_h$ are the weights and bias of the RNN.

- $W_{hh} h_{t-1}$ is computed with an explicit matrix product (`jnp.dot`); since the code works with row vectors, it actually computes $h_{t-1} W_{hh}$.
- $(W_{xh} x_t + b_h)$ is computed via a `nn.Dense` layer.
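For comparison, the same recurrence can be written with plain `jax.lax.scan`, the JAX primitive that `nn.scan` builds on. This is a minimal sketch under stated assumptions: the parameters live in a hand-built dict with the hypothetical keys `W_hh`, `W_xh` and `b_h`, the bias starts at zero (the `nn.Dense` default), and vectors are rows, so each step computes $\tanh(h_{t-1} W_{hh} + x_t W_{xh} + b_h)$.

```python
import jax
import jax.numpy as jnp


def rnn_scan(params, xs, h0):
    """Hypothetical helper: run the RNN over the leading (time) axis of xs."""

    def step(h, x):
        # Row-vector convention, as in the notebook: h @ W_hh instead of W_hh @ h.
        h = jnp.tanh(h @ params["W_hh"] + x @ params["W_xh"] + params["b_h"])
        return h, h  # (new carry, per-step output)

    return jax.lax.scan(step, h0, xs)


def init_params(key, input_dim, hidden_dim):
    """Hypothetical helper: U(-0.1, 0.1) weights like custom_w_init, zero bias."""
    k1, k2 = jax.random.split(key)
    return {
        "W_hh": jax.random.uniform(k1, (hidden_dim, hidden_dim), minval=-0.1, maxval=0.1),
        "W_xh": jax.random.uniform(k2, (input_dim, hidden_dim), minval=-0.1, maxval=0.1),
        "b_h": jnp.zeros((hidden_dim,)),
    }


params = init_params(jax.random.PRNGKey(0), input_dim=100, hidden_dim=10)
xs = jnp.ones((20, 100))  # 20 timesteps, input size 100
h_last, hist = jax.jit(rnn_scan)(params, xs, jnp.zeros((10,)))
print(h_last.shape, hist.shape)  # (10,) (20, 10)
```

As with the Flax version, `lax.scan` returns the final carry and the stacked per-step outputs, so `h_last` equals `hist[-1]`.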

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # set before JAX initializes its GPU backend

import flax.linen as nn
import jax
import jax.numpy as jnp

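Here `nn.scan` is applied to a function defined inside the `@nn.compact` `__call__` method, so the function can close over `Wh` and `dense_in`. This is not the only option: `nn.scan` can also lift an entire Module class.

The scanned function receives the module as its first argument, i.e. it has the signature `fn(module, carry, x)`, and it must return a `(new_carry, output)` pair.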

In [2]:
def custom_w_init():
    def init(rng, shape):
        return jax.random.uniform(rng, shape, minval=-0.1, maxval=0.1)
    return init

def another_custom_w_init(rng, shape, dtype=jnp.float32):
    return jax.random.uniform(rng, shape, minval=-0.1, maxval=0.1, dtype=dtype)

class RNNCell(nn.Module):
    hidden_dim: int = 10

    @nn.compact
    def __call__(self, x):
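        # Recurrent weights W_hh, created with the closure-style initializer
        Wh = self.param('W', custom_w_init(), (self.hidden_dim, self.hidden_dim))
        # Input projection W_xh x_t + b_h, created with the plain-function initializer
        dense_in = nn.Dense(features=self.hidden_dim, kernel_init=another_custom_w_init)
        h = jnp.zeros((self.hidden_dim,))  # initial carry h_0

        def update(self, h, x):
            h = jnp.tanh(jnp.dot(h, Wh) + dense_in(x))
            return h, h  # Return the new carry and the output

        scan_update = nn.scan(
            update,
            variable_broadcast='params',  # share the same params across all time steps
            in_axes=0,   # scan over the leading (time) axis of x
            out_axes=0   # stack the per-step outputs along axis 0
        )

        return scan_update(self, h, x)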



# Define inputs
x = jnp.ones((20, 100))  # 20 timesteps, input size 100
HIDDEN_DIM = 10
# Initialize
key = jax.random.PRNGKey(0)
model = RNNCell(HIDDEN_DIM)
params = model.init(key, x)
out, hist = model.apply(params, x)
print(out.shape, hist.shape)  # (10,) (20, 10)


(10,) (20, 10)


In [3]:
# Define inputs
HIDDEN_DIM = 100
x = jnp.ones((200000, HIDDEN_DIM))  # 200'000 timesteps, input size 100

# Initialize
key = jax.random.PRNGKey(0)
model = RNNCell(hidden_dim=HIDDEN_DIM)
params = model.init(key, x)
out, hist = model.apply(params, x)
print(out.shape, hist.shape)  # (100,) (200000, 100)


(100,) (200000, 100)


# Computing the FLOPs of this RNN
- $h_{t-1} W_{hh}$ is a matrix-vector product with a 100x100 matrix -> 100x(100 mults + 99 adds) = 19'900 flop
- $W_{xh} x_t$ is a matrix-vector product with a 100x100 matrix -> 100x(100 mults + 99 adds) = 19'900 flop
- Adding the bias $b_h$ (done inside `nn.Dense`) -> 100 adds = 100 flop
- Adding the two terms -> 100 adds = 100 flop
- The $\tanh$ activation, counted as one flop per element -> 100 flop
- The total per time step is 19'900 + 19'900 + 100 + 100 + 100 = 40'100 flop
- We do 200'000 time steps --> 200'000 * 40'100 = 8'020'000'000 flop
- The measured runtime below is 1.89 s --> 8'020'000'000 flop / 1.89 s ≈ 4.24 GFLOP/s
- The jitted runtime below is 1.83 s --> 8'020'000'000 flop / 1.83 s ≈ 4.38 GFLOP/s

In [None]:
%timeit model.apply(params, x) 

1.89 s ± 57.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
jit_model = jax.jit(model.apply)
%timeit jit_model(params, x) 

1.83 s ± 108 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
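The FLOP arithmetic can be scripted so it is easy to redo for other sizes. This sketch uses a hypothetical helper `rnn_step_flop`; it counts each multiply and each add as one FLOP, includes the bias add performed by `nn.Dense`, counts one FLOP per `tanh`, and plugs in the two `%timeit` means measured above (1.89 s and 1.83 s).

```python
def rnn_step_flop(input_dim: int, hidden_dim: int) -> int:
    """Hypothetical helper: FLOPs of one step of h = tanh(h W_hh + x W_xh + b_h)."""
    matvec_h = hidden_dim * (2 * hidden_dim - 1)  # h_{t-1} W_hh
    matvec_x = hidden_dim * (2 * input_dim - 1)   # x_t W_xh
    bias = hidden_dim                             # + b_h
    add = hidden_dim                              # sum of the two products
    tanh = hidden_dim                             # one FLOP per tanh, by convention
    return matvec_h + matvec_x + bias + add + tanh


per_step = rnn_step_flop(100, 100)  # 40_100
total = 200_000 * per_step          # 8_020_000_000
for label, seconds in [("apply", 1.89), ("jit(apply)", 1.83)]:
    print(f"{label}: {total / seconds / 1e9:.2f} GFLOP/s")
```

At roughly 4 GFLOP/s, the sequential dependency between time steps, not the arithmetic, dominates the runtime.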


# Print the model params

In [6]:
# use the tabulate function to see the number of parameters
x = jnp.ones((5, 10))  # 5 timesteps, input size 10
tabulate_fn = nn.tabulate(RNNCell(), jax.random.PRNGKey(0))
print(tabulate_fn(x))
print(jax.tree_util.tree_map(lambda x: x.shape, params))


                                RNNCell Summary                                 
┏━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module  ┃ inputs        ┃ outputs         ┃ params                 ┃
┡━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ RNNCell │ float32[5,10] │ - float32[10]   │ W: float32[10,10]      │
│         │         │               │ - float32[5,10] │                        │
│         │         │               │                 │ 100 (400 B)            │
├─────────┼─────────┼───────────────┼─────────────────┼────────────────────────┤
│ Dense_0 │ Dense   │ float32[10]   │ float32[10]     │ bias: float32[10]      │
│         │         │               │                 │ kernel:
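To double-check the totals in the summary, a hypothetical helper `count_params` can sum the sizes of all leaves of a parameter pytree. So that the example runs on its own, the dict below is hand-built with the shapes of `RNNCell(hidden_dim=10)` applied to inputs of size 10: `W` has 100 entries, the Dense kernel 100 and the bias 10.

```python
import jax
import jax.numpy as jnp


def count_params(tree) -> int:
    """Hypothetical helper: total number of scalar parameters in a pytree of arrays."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))


# Hand-built stand-in with the same structure and shapes as model.init(...) for
# RNNCell(hidden_dim=10) on inputs of size 10.
params = {
    "params": {
        "W": jnp.zeros((10, 10)),
        "Dense_0": {"kernel": jnp.zeros((10, 10)), "bias": jnp.zeros((10,))},
    }
}
print(count_params(params))  # 210
print(jax.tree_util.tree_map(lambda p: p.shape, params))
```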

# Matrix multiplication
Assume both matrices are $N \times N$ and let `bs` be the batch size. Each of the $N^2$ entries of a product needs $N$ multiplications and $N-1$ additions, so:
$ bs \text{ matrix multiplications } * N^2 \text{ entries } * (N \text{ multiplications } + (N-1) \text{ additions}) = bs * N^2 * (2N-1) \text{ flop} $

With `vmap`, the `bs` matrix multiplications run in parallel, so the memory footprint grows from $3N^2$ elements (two inputs and one output) to $(2 \cdot bs + 1) N^2$ elements: the batched input and the batched output each hold $bs \cdot N^2$ elements, while the shared matrix is stored only once. In float32 this is $4 (2 \cdot bs + 1) N^2$ bytes.

| N     | bs    | FLOP              | Time    | Throughput    | Memory  |
|-------|-------|-------------------|---------|---------------|---------|
| 10000 | 1     | 1'999'900'000'000 |   24 ms | 83.33 TFLOP/s | 1.12 GB |
| 5000  | 1     |   249'975'000'000 | 3.39 ms | 73.74 TFLOP/s |  286 MB |
| 2048  | 1     |    17'175'674'880 |  210 us | 81.79 TFLOP/s |   48 MB |
| 2000  | 1     |    15'996'000'000 |  205 us | 78.03 TFLOP/s |   46 MB |
| 1500  | 1     |     6'747'750'000 |  160 us | 42.17 TFLOP/s |   26 MB |
| 1024  | 1     |     2'146'435'072 |   37 us | 58.01 TFLOP/s |   12 MB |
| 1000  | 1     |     1'999'000'000 | 34.5 us | 57.94 TFLOP/s |   11 MB |
| 500   | 1     |       249'750'000 |   45 us |  5.55 TFLOP/s |    3 MB |
| 1000  | 100   |   199'900'000'000 | 2.53 ms | 79.01 TFLOP/s |  767 MB |
| 1000  | 1000  | 1'999'000'000'000 | 24.5 ms | 81.59 TFLOP/s | 7.45 GB |
| 1024  | 1024  | 2'197'949'513'728 | 26.1 ms | 84.21 TFLOP/s |    8 GB |

In [68]:
def flop_compute(N, bs):
    # bs products of two N x N matrices: N^2 entries, each with N mults + (N - 1) adds
    return bs * N**2 * (2*N - 1)
def memory(N, bs):
    # batched input + batched output + one shared matrix, 4 bytes per float32 element
    return ((2*bs + 1) * N**2) * 4 
def flops_compute(N, bs, time_in_s):
    a = flop_compute(N, bs)
    b = memory(N, bs)
    print(f'{a} - {a/time_in_s:.4e} - {b/(1024**3):.2f} GB - {b/(1024**2):.2f} MB - {b/(1024):.2f} KB')

flops_compute(10000, 1, 0.024)
flops_compute(5000, 1, 0.00339)
flops_compute(2048, 1, 0.000210)
flops_compute(2000, 1, 0.000205)
flops_compute(1500, 1, 0.000160)
flops_compute(1024, 1, 0.000037)
flops_compute(1000, 1, 0.0000345)
flops_compute(500, 1, 0.000045)
flops_compute(1000, 100, 0.00253)
flops_compute(1000, 1000, 0.0245)
flops_compute(1024, 1024, 0.0261)


1999900000000 - 8.3329e+13 - 1.12 GB - 1144.41 MB - 1171875.00 KB
249975000000 - 7.3739e+13 - 0.28 GB - 286.10 MB - 292968.75 KB
17175674880 - 8.1789e+13 - 0.05 GB - 48.00 MB - 49152.00 KB
15996000000 - 7.8029e+13 - 0.04 GB - 45.78 MB - 46875.00 KB
6747750000 - 4.2173e+13 - 0.03 GB - 25.75 MB - 26367.19 KB
2146435072 - 5.8012e+13 - 0.01 GB - 12.00 MB - 12288.00 KB
1999000000 - 5.7942e+13 - 0.01 GB - 11.44 MB - 11718.75 KB
249750000 - 5.5500e+12 - 0.00 GB - 2.86 MB - 2929.69 KB
199900000000 - 7.9012e+13 - 0.75 GB - 766.75 MB - 785156.25 KB
1999000000000 - 8.1592e+13 - 7.45 GB - 7633.21 MB - 7816406.25 KB
2197949513728 - 8.4213e+13 - 8.00 GB - 8196.00 MB - 8392704.00 KB


In [52]:
N = 1024
A = jnp.ones((N, N))
B = jnp.ones((N, N))
def matmul(A, B):
    return jnp.dot(A, B)
jit_matmul = jax.jit(matmul)

In [53]:
_ = jit_matmul(A, B)  # warmup

In [54]:
%timeit jit_matmul(A, B)  

36.3 μs ± 40.9 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [55]:
bs = 1024
C = jnp.ones((bs, N, N))
jit_vmap = jax.jit(jax.vmap(matmul, in_axes=(0, None), out_axes=0))
_ = jit_vmap(C, B)  # warmup

In [56]:
%timeit jit_vmap(C, B)

26.1 ms ± 1.32 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [57]:
vmap_jit = jax.vmap(jax.jit(matmul), in_axes=(0, None), out_axes=0)
_ = vmap_jit(C, B)  # warmup

In [58]:
%timeit vmap_jit(C, B)  

26.1 ms ± 26.6 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
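JAX dispatches work asynchronously, so a timing loop should call `block_until_ready()` on the result before stopping the clock; otherwise it may measure dispatch rather than computation. The sketch below, with the hypothetical helpers `matmul_flop` and `measure_tflops`, times a jitted, vmapped `jnp.dot` that way and reports throughput, counting $N^2(2N-1)$ FLOP per $N \times N$ product.

```python
import time

import jax
import jax.numpy as jnp


def matmul_flop(N: int, bs: int = 1) -> int:
    # bs products of two N x N matrices: N^2 entries, each N mults + (N - 1) adds
    return bs * N * N * (2 * N - 1)


def measure_tflops(N: int, bs: int = 1, repeats: int = 10) -> float:
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    A = jax.random.normal(key_a, (bs, N, N))  # batched operand
    B = jax.random.normal(key_b, (N, N))      # shared operand
    f = jax.jit(jax.vmap(jnp.dot, in_axes=(0, None)))
    f(A, B).block_until_ready()  # compile and warm up outside the timed region
    start = time.perf_counter()
    for _ in range(repeats):
        out = f(A, B)
    # Calls on one device execute in order, so waiting on the last result
    # waits for all of them.
    out.block_until_ready()
    elapsed = (time.perf_counter() - start) / repeats
    return matmul_flop(N, bs) / elapsed / 1e12


print(f"{measure_tflops(256, bs=4):.3f} TFLOP/s")
```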


In [19]:
def mvm(A, b):
    return jnp.dot(A, b)

jit_mvm = jax.jit(mvm)
vmap_jit_mvm = jax.vmap(jit_mvm, in_axes=(None, 0), out_axes=0)
_ = vmap_jit_mvm(A, B)  # warmup

In [20]:
%timeit vmap_jit_mvm(A, B)

The slowest run took 37.37 times longer than the fastest. This could mean that an intermediate result is being cached.
3.31 ms ± 6.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
jit_vmap_mvm = jax.jit(jax.vmap(mvm, in_axes=(None, 0), out_axes=0))
_ = jit_vmap_mvm(A, B)  # warmup

In [14]:
%timeit jit_vmap_mvm(A, B)  

46.1 μs ± 8.06 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
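With `in_axes=(None, 0)`, `vmap` keeps `A` fixed and maps over the rows of `B`, so row `i` of the result is `A @ B[i]`. Stacked together, this equals `B @ A.T`, a full matrix product with the same $N^2(2N-1)$ FLOP count as `jnp.dot(A, B)`. A small check, using integer-valued matrices so the comparison is exact:

```python
import jax
import jax.numpy as jnp


def mvm(A, b):
    return jnp.dot(A, b)


N = 4
A = jnp.arange(N * N, dtype=jnp.float32).reshape(N, N)
B = jnp.arange(N * N, dtype=jnp.float32).reshape(N, N)[::-1]

# A is shared (None); B is split along axis 0, so each call sees one row of B.
batched = jax.jit(jax.vmap(mvm, in_axes=(None, 0), out_axes=0))(A, B)
print(jnp.allclose(batched, B @ A.T))  # True
```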


# Benchmarking the softmax function of a vector
- x of size N
- exp(x) = N flop
- sum(exp(x)) = N-1 flop (a reduction, which can be highly optimized)
- exp(x) / sum(exp(x)) = N flop
- Total flop = N + (N-1) + N = 3N - 1 ≈ 3N flop

In [51]:
jit_softmax = jax.jit(jax.nn.softmax)
x = jax.random.normal(jax.random.PRNGKey(0), (500000,))
a = jit_softmax(x)  # warmup
a.shape

(500000,)

In [52]:
%timeit jax.nn.softmax(x)
%timeit jit_softmax(x)

227 μs ± 45.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
x = jax.random.normal(jax.random.PRNGKey(0), (500000,))
cpu_softmax = jax.jit(jax.nn.softmax, device=jax.devices("cpu")[0])
result = cpu_softmax(x)
gpu_softmax = jax.jit(jax.nn.softmax, device=jax.devices("gpu")[0])
result = gpu_softmax(x)


In [11]:
%timeit cpu_softmax(x)
%timeit gpu_softmax(x)

645 μs ± 60.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
30.9 μs ± 1.54 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
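An alternative to the `device=` argument of `jax.jit` (marked experimental in the JAX documentation) is to place the input with `jax.device_put`: a jitted function runs on the device where its committed inputs live. This sketch uses the GPU only if a GPU backend is available, so it also runs on a CPU-only machine.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
try:
    accel = jax.devices("gpu")[0]
except RuntimeError:  # no GPU backend available
    accel = None

softmax = jax.jit(jax.nn.softmax)
x = jax.random.normal(jax.random.PRNGKey(0), (500000,))

# Committing x to a device selects where the jitted computation runs.
y_cpu = softmax(jax.device_put(x, cpu))
print(y_cpu.devices())

if accel is not None:
    y_gpu = softmax(jax.device_put(x, accel))
    print(y_gpu.devices())
```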


In [12]:
def hand_softmax(x):
    return jnp.exp(x) / jnp.sum(jnp.exp(x))
jit_hand_softmax = jax.jit(hand_softmax)
result = jit_hand_softmax(x)  # warmup

In [13]:
%timeit hand_softmax(x)  
%timeit jit_hand_softmax(x)  # ~10% faster than jax.nn.softmax here, likely because jax.nn.softmax also subtracts max(x)

98.5 μs ± 2.77 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
27.9 μs ± 1.89 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
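`jax.nn.softmax` subtracts `max(x)` before exponentiating, for numerical stability, and the hand-written version skips that pass, which plausibly explains the small speed difference. A sketch of both variants (the names `naive_softmax` and `stable_softmax` are hypothetical): under the counting convention above, the stable one costs $(N-1) + N + N + (N-1) + N = 5N - 2$ FLOP instead of $3N - 1$, but it does not overflow for large inputs.

```python
import jax
import jax.numpy as jnp


def naive_softmax(x):
    e = jnp.exp(x)         # N exponentials
    return e / jnp.sum(e)  # N - 1 additions + N divisions


def stable_softmax(x):
    # Subtracting the max leaves the result unchanged mathematically
    # but keeps every exponent <= 0, so exp() cannot overflow.
    z = x - jnp.max(x)     # N - 1 comparisons + N subtractions
    e = jnp.exp(z)         # N exponentials
    return e / jnp.sum(e)  # N - 1 additions + N divisions


x = jnp.array([1000.0, 1001.0, 1002.0])
print(naive_softmax(x))   # exp(1000.0) overflows to inf in float32, so this is all nan
print(stable_softmax(x))
print(jnp.allclose(stable_softmax(x), jax.nn.softmax(x)))  # True
```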

