in this notebook we try to learn how to use flax and test an implementation of unrolling/recurrence. 

### imports and setup

In [350]:
# ipython extension to autoreload imported modules so that any changes will be up to date before running code in this nb
%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [351]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
import jax.random
import flax
import jraph
import optax
from flax.training import train_state # Simple train state for the common case with a single Optax optimizer.

from utils.jraph_models import MLPBlock, MLPGraphNetwork
from utils.jraph_training import train_step, one_step_loss

from typing import Optional, Any
import time


### python for-loops

In [352]:
def count_to_x(x):
    for i in range(x):
        print(i)

In [353]:
%time count_to_x(10)

0
1
2
3
4
5
6
7
8
9
CPU times: user 333 µs, sys: 210 µs, total: 543 µs
Wall time: 427 µs


In [354]:
jit_count_to_x = jit(count_to_x)
# jit_count_to_x(10) # this will fail

In [355]:
jit_count_to_10 = jit(count_to_x, static_argnums=(10,))

In [356]:
%time jit_count_to_10

CPU times: user 2 µs, sys: 1 µs, total: 3 µs
Wall time: 6.2 µs


<PjitFunction of <function count_to_x at 0x14d0b3c10>>

In [357]:
%time jit_count_to_10

CPU times: user 2 µs, sys: 1 µs, total: 3 µs
Wall time: 6.2 µs


<PjitFunction of <function count_to_x at 0x14d0b3c10>>

In [394]:
jit_count_to_x = jit(count_to_x, static_argnames=["x"])

In [395]:
%time jit_count_to_x(5)

0
1
2
3
4
CPU times: user 37.2 ms, sys: 1.95 ms, total: 39.1 ms
Wall time: 39 ms


In [396]:
%time jit_count_to_x(5)

CPU times: user 37 µs, sys: 4 µs, total: 41 µs
Wall time: 43.9 µs


yay, so we've demonstrated that jitting a function can make it much faster after it compiles. we also see that if we want to jit a for loop then we'll have to set the static argnum. 

### flax RNNs

now let's try working with the flax RNN submodule. 

In [358]:
def get_dummy_graphtuple(seed=42):
    K = 36
    n_fts = 2
    rng = np.random.default_rng(seed=seed)
    dummy_data = rng.random((K, n_fts))   # array of shape (K, num_fts=2)

    # define edges
    receivers = []
    senders = []
    edge_fts = []

    for i in range(K):
        senders += [i] * 5
        receivers += [i, (i + 1) % K, (i + 2) % K, (i - 1) % K, (i - 2) % K]

        # edge features = length + direction of edge
        edge_fts += [
            [0],  # self edge
            [1],  # receiver is 1 node to the right
            [2],  # receiver is 2 nodes to the right
            [-1],  # receiver is 1 node to the left
            [-2]  # receiver is 2 nodes to the left
        ]

    return jraph.GraphsTuple(
        globals=jnp.array([[1.]]),  # placeholder global features for now (was an empty array and None both causing errors down the line?)
        # globals=jnp.array([]),  # no global features for now
        # globals=None,  # no global features for now
        nodes=jnp.array(
            dummy_data),  # node features = state values. shape of (K, 2)
        edges=jnp.array(edge_fts, dtype=float),
        receivers=jnp.array(receivers),
        senders=jnp.array(senders),
        n_node=jnp.array([K]),
        n_edge=jnp.array([K * 5]))

In [359]:
test_input_graph = get_dummy_graphtuple()
test_target_graphs = [get_dummy_graphtuple()]

set up params and state for a single prediction step

In [360]:
net = MLPBlock()
rng = jax.random.key(0)
rng, init_rng = jax.random.split(rng)

params = jax.jit(net.init)(init_rng, test_input_graph)

learning_rate = 0.001 # default learning rate for adam in keras 
tx = optax.adam(learning_rate=learning_rate)

state = train_state.TrainState.create(
    apply_fn=net.apply, params=params, tx=tx
)


In [361]:
params

{'params': {'MLP_0': {'Dense_0': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32),
    'kernel': Array([[ 0.12708823, -0.12388758,  0.34716293,  0.7363299 , -0.05217113,
             0.2559718 ,  0.1492992 , -0.06294186, -0.3936614 ,  0.39498648,
            -0.03390806,  0.42035106,  0.43463182, -0.6656627 , -0.04596043,
             0.5623761 ],
           [ 0.36956677, -0.09241027,  0.64535147, -0.57547694,  0.13745499,
            -0.02492646,  0.08678089, -0.2533039 , -0.01573062,  0.07404543,
            -0.21427661, -0.00971881, -0.7193296 , -0.37759158, -0.44224206,
            -0.2494645 ],
           [ 0.5481207 , -0.1552033 , -0.10531062, -0.7850843 ,  0.14456064,
            -0.35050648, -0.11576676,  0.40200943, -0.28664398,  0.07857388,
            -0.14365411, -0.15011843,  0.72559404,  0.48221382,  0.5161432 ,
             0.59971863],
           [ 0.14962323, -0.10533044,  0.09056184,  0.3435282 , -0.57027584,
       

In [362]:
print(net.tabulate(jax.random.key(0), test_input_graph)) # visualize a table of the model layers by passing an RNG key and template image input


[3m                                MLPBlock Summary                                [0m
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath          [0m[1m [0m┃[1m [0m[1mmodule  [0m[1m [0m┃[1m [0m[1minputs        [0m[1m [0m┃[1m [0m[1moutputs       [0m[1m [0m┃[1m [0m[1mparams        [0m[1m [0m┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩
│                │ MLPBlock │ -              │ -              │                │
│                │          │ [2mfloat32[0m[36,2]  │ [2mfloat32[0m[36,2]  │                │
│                │          │ -              │ -              │                │
│                │          │ [2mfloat32[0m[180,1] │ [2mfloat32[0m[180,1] │                │
│                │          │ - [2mint32[0m[180]   │ - [2mint32[0m[180]   │                │
│                │          │ - [2mint32[0m[180]   │ - [2mint32[0m[180]   │               

single prediction step

In [363]:
pred_graph = state.apply_fn(state.params, test_input_graph) 
preds = pred_graph.nodes
targets = test_target_graphs[0].nodes

In [364]:
pred_graph.edges

Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
      

In [365]:
for ft in range(2):
    print(f"feature {ft}")
    for node in range(36):
        print(f"pred: {preds[node][ft]}, target: {targets[node][ft]}")

feature 0
pred: 0.7739560604095459, target: 0.7739560604095459
pred: 0.8585979342460632, target: 0.8585979342460632
pred: 0.09417735040187836, target: 0.09417735040187836
pred: 0.7611396908760071, target: 0.7611396908760071
pred: 0.12811362743377686, target: 0.12811362743377686
pred: 0.3707980215549469, target: 0.3707980215549469
pred: 0.6438651084899902, target: 0.6438651084899902
pred: 0.44341421127319336, target: 0.44341421127319336
pred: 0.554584801197052, target: 0.554584801197052
pred: 0.8276311755180359, target: 0.8276311755180359
pred: 0.7580877542495728, target: 0.7580877542495728
pred: 0.9706979990005493, target: 0.9706979990005493
pred: 0.7854065895080566, target: 0.7783834934234619
pred: 0.4667209982872009, target: 0.4667209982872009
pred: 0.15428949892520905, target: 0.15428949892520905
pred: 0.7447621822357178, target: 0.7447621822357178
pred: 0.32582536339759827, target: 0.32582536339759827
pred: 0.4695558249950409, target: 0.4695558249950409
pred: 0.12992151081562042, t

single train step

In [369]:
new_state, metrics_update, pred_nodes = train_step(
    state, test_input_graph, test_target_graphs
)

In [None]:
state.step

0

In [378]:
assert (jnp.array(1) == 1)

In [374]:
float(new_state.step) == 1
new_state.step

1

In [None]:
params['params']['MLP_0']['Dense_0'] 

{'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32),
 'kernel': Array([[ 0.12708823, -0.12388758,  0.34716293,  0.7363299 , -0.05217113,
          0.2559718 ,  0.1492992 , -0.06294186, -0.3936614 ,  0.39498648,
         -0.03390806,  0.42035106,  0.43463182, -0.6656627 , -0.04596043,
          0.5623761 ],
        [ 0.36956677, -0.09241027,  0.64535147, -0.57547694,  0.13745499,
         -0.02492646,  0.08678089, -0.2533039 , -0.01573062,  0.07404543,
         -0.21427661, -0.00971881, -0.7193296 , -0.37759158, -0.44224206,
         -0.2494645 ],
        [ 0.5481207 , -0.1552033 , -0.10531062, -0.7850843 ,  0.14456064,
         -0.35050648, -0.11576676,  0.40200943, -0.28664398,  0.07857388,
         -0.14365411, -0.15011843,  0.72559404,  0.48221382,  0.5161432 ,
          0.59971863],
        [ 0.14962323, -0.10533044,  0.09056184,  0.3435282 , -0.57027584,
          0.7743562 ,  0.48538765, -0.05884578,  0.91562   ,  0.64872515,
     

In [None]:
new_state.params['params']['MLP_0'].keys()

dict_keys(['Dense_0', 'Dense_1'])

In [None]:
state

TrainState(step=0, apply_fn=<bound method Module.apply of MLPBlock(
    # attributes
    dropout_rate = 0
    skip_connections = True
    layer_norm = False
    deterministic = True
)>, params={'params': {'MLP_0': {'Dense_0': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32), 'kernel': Array([[ 0.12708823, -0.12388758,  0.34716293,  0.7363299 , -0.05217113,
         0.2559718 ,  0.1492992 , -0.06294186, -0.3936614 ,  0.39498648,
        -0.03390806,  0.42035106,  0.43463182, -0.6656627 , -0.04596043,
         0.5623761 ],
       [ 0.36956677, -0.09241027,  0.64535147, -0.57547694,  0.13745499,
        -0.02492646,  0.08678089, -0.2533039 , -0.01573062,  0.07404543,
        -0.21427661, -0.00971881, -0.7193296 , -0.37759158, -0.44224206,
        -0.2494645 ],
       [ 0.5481207 , -0.1552033 , -0.10531062, -0.7850843 ,  0.14456064,
        -0.35050648, -0.11576676,  0.40200943, -0.28664398,  0.07857388,
        -0.14365411, -0.15011843, 

In [None]:
float(metrics_update.loss.total)

0.08238545060157776

In [373]:
int(metrics_update.loss.count)

1

In [None]:
# compute loss by hand to compare? 
mse = jnp.mean(jnp.square(preds - targets))
mse

Array(0.08238545, dtype=float32)

test multi-step rollout

In [392]:
n_steps = 5
test_target_graphs = [get_dummy_graphtuple(seed=i) for i in range(n_steps)]
new_state, metrics_update, pred_nodes = train_step(
    state, test_input_graph, test_target_graphs, n_steps=n_steps
)

In [393]:
pred_nodes

[Array([[0.77395606, 0.8975426 ],
        [0.85859793, 1.0016221 ],
        [0.09417735, 1.2731246 ],
        [0.7611397 , 1.125962  ],
        [0.12811363, 0.99429697],
        [0.37079802, 1.2111586 ],
        [0.6438651 , 1.1443826 ],
        [0.4434142 , 0.78128695],
        [0.5545848 , 0.5819175 ],
        [0.8276312 , 0.92344344],
        [0.75808775, 0.7155212 ],
        [0.970698  , 1.1413416 ],
        [0.7854066 , 0.6675924 ],
        [0.466721  , 0.5153479 ],
        [0.1542895 , 0.94776165],
        [0.7447622 , 1.1977992 ],
        [0.32582536, 0.88596725],
        [0.46955582, 0.71777076],
        [0.12992151, 0.94309694],
        [0.22690935, 1.0196851 ],
        [0.4371519 , 1.1684438 ],
        [0.7002651 , 0.8094662 ],
        [0.8322598 , 1.1067413 ],
        [0.38747838, 0.6576911 ],
        [0.6824955 , 0.63843   ],
        [0.1999082 , 0.49439633],
        [0.78692436, 0.9510118 ],
        [0.7051654 , 1.056159  ],
        [0.45891577, 0.9395046 ],
        [0.139

In [380]:
new_state
# TODO check if this still works with jit 

TrainState(step=1, apply_fn=<bound method Module.apply of MLPBlock(
    # attributes
    dropout_rate = 0
    skip_connections = True
    layer_norm = False
    deterministic = True
)>, params={'params': {'MLP_0': {'Dense_0': {'bias': Array([ 0.00099999,  0.00099996, -0.00099999, -0.00099999, -0.00099999,
       -0.00099999,  0.00099999,  0.00099999, -0.00099999,  0.00099999,
        0.        , -0.00099999, -0.00099999,  0.00099999, -0.00099999,
       -0.00099999], dtype=float32), 'kernel': Array([[ 0.12808822, -0.12488751,  0.34616295,  0.7353299 , -0.05117113,
         0.2549718 ,  0.15029919, -0.06194187, -0.39266142,  0.3939865 ,
        -0.03390806,  0.41935107,  0.43363184, -0.6646627 , -0.04496043,
         0.5613761 ],
       [ 0.37056676, -0.09341023,  0.6443515 , -0.57647693,  0.136455  ,
        -0.02592645,  0.08778089, -0.2523039 , -0.01673061,  0.07504543,
        -0.21427661, -0.0107188 , -0.7203296 , -0.3765916 , -0.44324204,
        -0.2504645 ],
       [ 0.54912066,