### imports and setup

In [4]:
# IPython extension that automatically reloads imported modules, so that any
# edits to them are picked up before code in this notebook runs.
# Re-running this cell only produces an "already loaded" notice.
%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
import jax
import jax.numpy as jnp
import jraph
import numpy as np
import optax
from flax.training import train_state

from flax_gnn_example.train import rollout_loss, train_step, train_step_fn #, rollout_loss_batched, 
from utils.jraph_data import get_lorenz_graph_tuples, print_graph_fts
from utils.jraph_models import MLPBlock

### test single rollout

In [6]:
def get_sample_data(seed=42):
    """Return a tiny dataset from ``get_lorenz_graph_tuples`` for quick tests.

    Every generation setting is fixed below; only ``seed`` can be varied by
    the caller. The full dict returned by ``get_lorenz_graph_tuples`` is
    handed back unchanged.
    """
    fixed_settings = {
        # windowing
        'n_samples': 2,
        'input_steps': 3,
        'output_delay': 0,
        'output_steps': 2,
        'timestep_duration': 1,
        'sample_buffer': 1,
        'time_resolution': 100,
        'init_buffer_samples': 0,
        # everything goes into the train split
        'train_pct': 1.0,
        'val_pct': 0,
        'test_pct': 0,
        # simulation settings passed straight through
        'K': 36,
        'F': 8,
        'c': 10,
        'b': 10,
        'h': 1,
        'normalize': False,
    }
    return get_lorenz_graph_tuples(seed=seed, **fixed_settings)


sample_dataset = get_sample_data()

In [9]:
# Everything stored under the train split's 'inputs' / 'targets' keys.
sample_input_batch = sample_dataset['train']['inputs']
sample_target_batch = sample_dataset['train']['targets']

# First window of each, and the first graph inside that window.
sample_input_window, sample_target_window = sample_input_batch[0], sample_target_batch[0]
sample_input_graph, sample_target_graph = sample_input_window[0], sample_target_window[0]

print_graph_fts(sample_input_graph)
print_graph_fts(sample_target_graph)

Number of nodes: 36
Number of edges: 180
Node features shape: (36, 2)
Edge features shape: (180, 1)
Global features shape: (1, 1)
Number of nodes: 36
Number of edges: 180
Node features shape: (36, 2)
Edge features shape: (180, 1)
Global features shape: (1, 1)


In [10]:
# Inspect the raw array attributes of a single input graph.
print(sample_input_graph.nodes.shape)
print(sample_input_graph.n_node)
print(sample_input_graph.edges.shape)
print(sample_input_graph.n_edge)
print(sample_input_graph.receivers.shape)
# n_node has shape (1,), so its only valid index is 0. The previous index 1
# was out of bounds and printed a value only because JAX clamps
# out-of-range indices instead of raising.
print(sample_input_graph.n_node[0])
print(sample_input_graph.n_node.shape)

(36, 2)
[36]
(180, 1)
[180]
(180,)
36
(1,)


In [12]:
# --- model ---
# Feature widths handed to MLPBlock for each feature kind; 'global' is None.
hidden_layer_features = {
    'edge': [16, 8],
    'node': [32, 2],
    'global': None,
}
model = MLPBlock(
    edge_features=hidden_layer_features['edge'],
    node_features=hidden_layer_features['node'],
    global_features=hidden_layer_features['global'],
)

# --- parameters ---
# Split a fresh key: `init_rng` is used for parameter initialisation, while
# `rng` stays available for later cells.
rng, init_rng = jax.random.split(jax.random.key(0))
params = jax.jit(model.init)(init_rng, sample_input_window)

# --- optimizer (the state needs one even if we aren't training) ---
learning_rate = 0.001  # default learning rate for adam in keras
tx = optax.adam(learning_rate=learning_rate)

# --- state object: keeps track of the model, params, and optimizer ---
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)

In [15]:
# Single-window rollout check: n_steps is set to the number of graphs in the
# target window, and no rngs are supplied.
avg_loss, pred_nodes = rollout_loss(
    state=state,
    input_window_graphs=sample_input_window,
    target_window_graphs=sample_target_window,
    n_steps=len(sample_target_window),
    rngs=None,
)

In [16]:
# Loss value first, then the container type and length of the predictions.
print(avg_loss)
print(type(pred_nodes), len(pred_nodes), sep='\n')

# Type and shape of the first prediction, followed by the first two entries.
print(type(pred_nodes[0]), pred_nodes[0].shape, sep='\n')
print(pred_nodes[0], pred_nodes[1], sep='\n')

0.0023148963
<class 'list'>
2
<class 'jaxlib.xla_extension.ArrayImpl'>
(36,)
[8.      8.      8.      8.      8.      8.      8.      8.      8.
 8.      8.      8.      8.      8.      8.      8.      8.00025 8.
 8.      8.      8.      8.      8.      8.      8.      8.      8.
 8.      8.      8.      8.      8.      8.      8.      8.      8.     ]
[7.9902935 7.9929175 7.990142  7.9900727 7.9901466 7.990148  7.9901466
 7.990146  7.9901466 7.9901466 7.9901466 7.9901466 7.9901466 7.9901466
 7.990147  7.990166  7.990394  7.9901447 7.9901266 7.9901466 7.990147
 7.9901466 7.990146  7.9901466 7.9901466 7.9901466 7.9901466 7.9901466
 7.9901466 7.9901466 7.9901466 7.9901466 7.9901447 7.9900723 7.9873753
 7.9207993]


### test rollout loss batched

We're hitting an issue here that I don't know how to fix immediately, and batching isn't our top priority right now, so I'm leaving this loose end hanging. **TODO:** revisit later.

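The problem: when we treat a list of `GraphsTuple`s as a JAX pytree, each attribute of every named tuple becomes its own leaf (a `GraphsTuple` is a named tuple, and JAX flattens named tuples as pytree nodes rather than treating them as leaves). So we end up with `n_windows * n_fields` leaves, where `n_fields` is the number of array attributes in a `GraphsTuple`.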

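To fix it, we'd need each `GraphsTuple` to be treated as a single leaf. Not sure yet how to set this up; one option to try is the `is_leaf` argument of the `jax.tree_util` functions, e.g. `is_leaf=lambda x: isinstance(x, jraph.GraphsTuple)`.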

In [None]:
# Print the container type explicitly: a bare expression that is not the last
# line of a cell is evaluated but never displayed.
print(type(sample_input_batch))
print_graph_fts(sample_input_batch[0])

Number of graphs: 3
Number of nodes: [36 36 36]
Number of edges: [180 180 180]
Node features (total) shape: (108, 2)
Edge features (total) shape: (540, 1)
Global features shape: (3, 1)


In [None]:
# Flatten the batch into its pytree leaves once and report the type and shape
# of each leaf.
for leaf in jax.tree_util.tree_leaves(sample_input_batch):
    print(type(leaf))
    print(leaf.shape)

<class 'jaxlib.xla_extension.ArrayImpl'>
(108, 2)
<class 'jaxlib.xla_extension.ArrayImpl'>
(540, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(540,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(540,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(3, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(3,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(3,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(108, 2)
<class 'jaxlib.xla_extension.ArrayImpl'>
(540, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(540,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(540,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(3, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(3,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(3,)


In [None]:
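# NOTE: left disabled -- rollout_loss_batched is commented out of the imports,
# pending the batching TODO described in this section.
# batch_avg_loss, batch_pred_nodes = rollout_loss_batched(state, 
#                  sample_input_batch,
#                  sample_target_batch,
#                  None,
#                  )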

### test train_step

In [17]:
# check number of params
print(type(params))
print(params.keys())
print(params['params'].keys())
print(params['params']['MLP_0'].keys())
print(params['params']['MLP_0']['Dense_0'].keys())
print(type(params['params']['MLP_0']['Dense_0']['bias']))
print(params['params']['MLP_0']['Dense_0']['bias'].shape)
print(params['params']['MLP_0']['Dense_0']['kernel'].shape)
print(params['params']['MLP_0']['Dense_1']['bias'].shape)
print(params['params']['MLP_0']['Dense_1']['kernel'].shape)
print(params['params']['MLP_1']['Dense_0']['bias'].shape)
print(params['params']['MLP_1']['Dense_0']['kernel'].shape)
print(params['params']['MLP_1']['Dense_1']['bias'].shape)
print(params['params']['MLP_1']['Dense_1']['kernel'].shape)

<class 'dict'>
dict_keys(['params'])
dict_keys(['MLP_0', 'MLP_1'])
dict_keys(['Dense_0', 'Dense_1'])
dict_keys(['bias', 'kernel'])
<class 'jaxlib.xla_extension.ArrayImpl'>
(16,)
(6, 16)
(8,)
(16, 8)
(32,)
(19, 32)
(2,)
(32, 2)


In [23]:
# run train step
new_state, metrics_update, pred_nodes = train_step_fn(
    state=state,
    n_steps=len(sample_target_window),
    input_window_graphs=sample_input_window,
    target_window_graphs=sample_target_window,
    rngs={'dropout': rng}
)

> [0;32m/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/flax_gnn_example/train.py[0m(233)[0;36mtrain_step_fn[0;34m()[0m
[0;32m    231 [0;31m    [0;31m# print('grads', grads['params']['MLP_1']['Dense_1']['kernel'])[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    232 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 233 [0;31m    [0mstate[0m [0;34m=[0m [0mstate[0m[0;34m.[0m[0mapply_gradients[0m[0;34m([0m[0mgrads[0m[0;34m=[0m[0mgrads[0m[0;34m)[0m [0;31m# update params in the state[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    234 [0;31m[0;34m[0m[0m
[0m[0;32m    235 [0;31m    [0mmetrics_update[0m [0;34m=[0m [0mTrainMetrics[0m[0;34m.[0m[0msingle_from_model_output[0m[0;34m([0m[0mloss[0m[0;34m=[0m[0mloss[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]

In [None]:
# Combine the GraphsTuples in the sample batch into a single GraphsTuple with
# jraph.batch (run under jit) and print its summary.
# `jraph` is imported in the notebook's top import cell.
batch = jax.jit(jraph.batch)(sample_input_batch)
print_graph_fts(batch)

Number of graphs: 6
Number of nodes: [36 36 36 36 36 36]
Number of edges: [180 180 180 180 180 180]
Node features (total) shape: (216, 2)
Edge features (total) shape: (1080, 1)
Global features shape: (6, 1)


### try fixing jit issues

In [None]:
# from flax_gnn_example.train import unbatch_i
# first_graph = jax.jit(unbatch_i)(sample_input_window, 0)
# first_window = jraph.unbatch(sample_input_window)
def func_with_list(l):
    """Return the sum of the items in ``l``, starting from 0.

    Tiny stand-in function used in this notebook to try ``jax.jit`` on a
    function that takes a plain Python list.
    """
    return sum(l, 0)

# Call the jitted function with a plain Python list argument.
# NOTE(review): presumably this checks that jit accepts a list input (lists are JAX pytrees) -- confirm intent.
jax.jit(func_with_list)([1,2,3])