### imports and setup

In [1]:
# IPython extension that re-imports edited modules (e.g. the project's utils.jraph_* modules)
# before code in this notebook runs, so changes to the project code are picked up without
# restarting the kernel. Mode 2 reloads all modules except those excluded via %aimport.
%load_ext autoreload 
%autoreload 2

In [58]:
from utils.jraph_data import get_lorenz_graph_tuples, print_graph_fts
from utils.jraph_training import rollout_loss, train_step, train_step_fn, evaluate_model, train_and_evaluate, create_dataset #, rollout_loss_batched, 
from utils.jraph_models import MLPBlock
import optax
from flax.training import train_state
from flax_gnn_example.configs import mlpblock_test

import numpy as np
import jax.numpy as jnp
import jax
from datetime import datetime

In [3]:
# Logging: raise the root logger to INFO so informational records (e.g. the absl/jax
# messages captured in the outputs below) are emitted.
import logging

logger = logging.getLogger()
logger.setLevel(level=logging.INFO)

### test single rollout

In [4]:
def get_sample_data(seed=42, **overrides):
    """Build a small Lorenz graph dataset for quick smoke tests.

    Wraps ``get_lorenz_graph_tuples`` with a fixed set of small defaults
    (10 samples, K=36, unnormalized). Any other keyword accepted by
    ``get_lorenz_graph_tuples`` can be passed to override a default, e.g.
    ``get_sample_data(K=6, normalize=True)``.

    Args:
        seed: random seed forwarded to ``get_lorenz_graph_tuples``.
        **overrides: keyword arguments that replace the defaults below.

    Returns:
        The object returned by ``get_lorenz_graph_tuples`` unchanged
        (presumably a dict indexed as ``dataset[split]['inputs'|'targets']``,
        like the ``create_dataset`` output used later -- TODO confirm).
    """
    lorenz_kwargs = dict(
        n_samples=10,
        input_steps=3,
        output_delay=0,
        output_steps=2,
        timestep_duration=1,
        sample_buffer=1,
        time_resolution=100,
        init_buffer_samples=0,
        train_pct=.2,
        val_pct=0.4,
        test_pct=0.4,
        K=36,
        F=8,
        c=10,
        b=10,
        h=1,
        seed=seed,
        normalize=False,
    )
    # caller-supplied values win over the defaults above
    lorenz_kwargs.update(overrides)
    return get_lorenz_graph_tuples(**lorenz_kwargs)


# NOTE: this result is replaced in the next cell by the config-driven dataset
# (create_dataset(config=...)); here it only serves as a smoke test of the loader.
sample_dataset = get_sample_data()

INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)


In [5]:
# Build the dataset used by the rest of the notebook from the MLPBlock test config.
# This overwrites the `sample_dataset` produced by get_sample_data() in the previous cell.
config=mlpblock_test.get_config()
sample_dataset = create_dataset(config=config)

In [6]:
# Whole train split: lists of input windows and matching target windows
# (used by the batching experiments further down).
sample_input_batch = sample_dataset['train']['inputs']
sample_target_batch = sample_dataset['train']['targets']

# First window of each; a window is a sequence of graphs (presumably one per timestep).
sample_input_window = sample_input_batch[0]
sample_target_window = sample_target_batch[0]

# First graph of each window, printed for a quick shape check.
sample_input_graph = sample_input_window[0]
sample_target_graph = sample_target_window[0]

print_graph_fts(sample_input_graph)
print_graph_fts(sample_target_graph)

Number of nodes: 6
Number of edges: 30
Node features shape: (6, 2)
Edge features shape: (30, 1)
Global features shape: (1, 1)
Number of nodes: 6
Number of edges: 30
Node features shape: (6, 2)
Edge features shape: (30, 1)
Global features shape: (1, 1)


In [7]:
# Inspect the raw GraphsTuple fields of a single input graph.
print(sample_input_graph.nodes.shape)
print(sample_input_graph.n_node)
print(sample_input_graph.edges.shape)
print(sample_input_graph.n_edge)
print(sample_input_graph.receivers.shape)
# n_node has shape (1,) here (a single graph), so the only valid index is 0.
# Indexing it with [1] does not raise on a JAX array: out-of-bounds gathers are
# clamped, which silently returned n_node[0] instead of flagging the mistake.
print(sample_input_graph.n_node[0])
print(sample_input_graph.n_node.shape)

(6, 2)
[6]
(30, 1)
[30]
(30,)
6
(1,)


In [8]:
# Model: hidden layer sizes per MLP; no global MLP (global_features=None).
hidden_layer_features = {
    'edge': [16, 8],
    'node': [32, 2],
    'global': None,
}
model = MLPBlock(
    edge_features=hidden_layer_features['edge'],
    node_features=hidden_layer_features['node'],
    global_features=hidden_layer_features['global'],
)

# Params: initialise from one sample input window. Only init_rng is consumed here;
# the other half of the split stays in `rng` for later use (dropout key in train_step).
rng, init_rng = jax.random.split(jax.random.key(0))
params = jax.jit(model.init)(init_rng, sample_input_window)

# Optimizer: required by TrainState even when we only evaluate.
learning_rate = 0.001  # Keras' default Adam learning rate
tx = optax.adam(learning_rate=learning_rate)

# TrainState bundles apply_fn, params and optimizer state into one object.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)

input_graph
Number of nodes: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Number of edges: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Node features shape: (6, 2)
Edge features shape: (30, 1)
Global features shape: (1, 1)
> [0;32m/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/utils/jraph_models.py[0m(112)[0;36m__call__[0;34m()[0m
[0;32m    110 [0;31m        [0mprint_graph_fts[0m[0;34m([0m[0minput_graph[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    111 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 112 [0;31m        [0mprocessed_graphs[0m [0;34m=[0m [0mgraph_net[0m[0;34m([0m[0minput_graph[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    113 [0;31m        [0mprint[0m[0;34m([0m[0;34m'processed_graphs'[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    114 [0;31m        [0mprint_graph_fts[0m[0;34m([0m[0mproces

In [9]:
# test single rollout 
# Roll the model forward once per target graph in the window. Judging from the next cell's
# output, this returns a scalar loss and a list with one (n_nodes, n_node_features) prediction
# array per rollout step. rngs=None: no dropout key is passed (presumably deterministic mode).
avg_loss, pred_nodes = rollout_loss(
    state=state, 
    n_rollout_steps=len(sample_target_window),
    input_window_graphs=sample_input_window,
    target_window_graphs=sample_target_window,
    rngs=None,
    )

input_graph
Number of nodes: 6
Number of edges: 30
Node features shape: (6, 2)
Edge features shape: (30, 1)
Global features shape: (1, 1)
> [0;32m/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/utils/jraph_models.py[0m(112)[0;36m__call__[0;34m()[0m
[0;32m    110 [0;31m        [0mprint_graph_fts[0m[0;34m([0m[0minput_graph[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    111 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 112 [0;31m        [0mprocessed_graphs[0m [0;34m=[0m [0mgraph_net[0m[0;34m([0m[0minput_graph[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    113 [0;31m        [0mprint[0m[0;34m([0m[0;34m'processed_graphs'[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    114 [0;31m        [0mprint_graph_fts[0m[0;34m([0m[0mprocessed_graphs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
processed_graphs
Number of nodes: 6
Number of edges: 30
Node fe

In [10]:
# Inspect the rollout output: scalar loss, then the per-step node predictions.
print(avg_loss)
print(type(pred_nodes))
print(len(pred_nodes))

# One prediction array per rollout step, shaped like the input node features.
# NOTE: pred_nodes[1] assumes at least two rollout steps (output_steps=2 in this config).
print(type(pred_nodes[0]))
print(pred_nodes[0].shape)
print(pred_nodes[0])
print(pred_nodes[1])

1.0748875
<class 'list'>
2
<class 'jaxlib.xla_extension.ArrayImpl'>
(6, 2)
[[0.55302817 1.4878951 ]
 [0.55302817 1.4880304 ]
 [0.55302817 1.4881315 ]
 [0.55302817 1.4880161 ]
 [0.56065965 1.4864204 ]
 [0.55302817 1.4879614 ]]
[[0.4858688  1.202163  ]
 [0.5044668  1.2232045 ]
 [0.48546994 1.202471  ]
 [0.4850514  1.2159793 ]
 [0.47401622 1.2174321 ]
 [0.00990715 1.2830077 ]]


### test rollout loss batched

**TODO (later):** batched rollouts do not work yet. I don't know how to fix the issue described below right away, and batching isn't a priority at the moment, so this loose end is left open for now.

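**The problem:** when a list of `GraphsTuple`s is treated as a JAX pytree, each field of the named tuple becomes a separate leaf. This is expected behaviour rather than a bug. `jraph.GraphsTuple` is a `NamedTuple`, and JAX treats named tuples as pytree nodes, so flattening recurses into `nodes`, `edges`, `receivers`, `senders`, `globals`, `n_node` and `n_edge`. The result is n_windows × graphs_per_window × 7 leaves (see the output of the cell below).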

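To fix this, each `GraphsTuple` needs to be treated as a single leaf. I'm not sure yet how best to set this up. Options to try:
- Pass `is_leaf=lambda x: isinstance(x, jraph.GraphsTuple)` to the `jax.tree_util` functions.
- Alternatively, stack the graphs field by field along a new leading axis (e.g. `jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *graphs)`). Each field then becomes one batched array that `jax.vmap` can map over.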

In [11]:
# type(sample_input_batch)
# print_graph_fts(sample_input_batch[0])

In [12]:
# Flatten the training inputs with JAX's default pytree rules. GraphsTuple is a NamedTuple,
# so flattening recurses into it and every field array becomes its own leaf (7 per graph:
# nodes, edges, receivers, senders, globals, n_node, n_edge) instead of one leaf per graph.
# Flatten once and reuse the result (the previous version flattened twice and discarded the first).
batch_leaves = jax.tree_util.tree_leaves(sample_input_batch)
print(f"{len(batch_leaves)} leaves from {len(sample_input_batch)} input windows")
for leaf in batch_leaves:
    print(type(leaf))
    print(leaf.shape)

<class 'numpy.ndarray'>
(6, 2)
<class 'jaxlib.xla_extension.ArrayImpl'>
(30, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(30,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(30,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1,)
<class 'numpy.ndarray'>
(6, 2)
<class 'jaxlib.xla_extension.ArrayImpl'>
(30, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(30,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(30,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1,)
<class 'numpy.ndarray'>
(6, 2)
<class 'jaxlib.xla_extension.ArrayImpl'>
(30, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(30,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(30,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1,)
<class 'numpy.ndarray'>
(6, 2)
<class 'jaxlib.xla_extensio

In [13]:
# batch_avg_loss, batch_pred_nodes = rollout_loss_batched(state, 
#                  sample_input_batch,
#                  sample_target_batch,
#                  None,
#                  )

### test train_step

In [14]:
# check number of params
print(type(params))
print(params.keys())
print(params['params'].keys())
print(params['params']['MLP_0'].keys())
print(params['params']['MLP_0']['Dense_0'].keys())
print(type(params['params']['MLP_0']['Dense_0']['bias']))
print(params['params']['MLP_0']['Dense_0']['bias'].shape)
print(params['params']['MLP_0']['Dense_0']['kernel'].shape)
print(params['params']['MLP_0']['Dense_1']['bias'].shape)
print(params['params']['MLP_0']['Dense_1']['kernel'].shape)
print(params['params']['MLP_1']['Dense_0']['bias'].shape)
print(params['params']['MLP_1']['Dense_0']['kernel'].shape)
print(params['params']['MLP_1']['Dense_1']['bias'].shape)
print(params['params']['MLP_1']['Dense_1']['kernel'].shape)

<class 'dict'>
dict_keys(['params'])
dict_keys(['MLP_0', 'MLP_1'])
dict_keys(['Dense_0', 'Dense_1'])
dict_keys(['bias', 'kernel'])
<class 'jaxlib.xla_extension.ArrayImpl'>
(16,)
(6, 16)
(8,)
(16, 8)
(32,)
(19, 32)
(2,)
(32, 2)


In [15]:
# run train step
# Presumably one optimisation step on the single sample window. The updated state goes
# to `new_state`, so `state` itself is not replaced (state.step is still 0 further down),
# and `pred_nodes` from the single-rollout cell is overwritten.
# `rng` is the half of the set-up split that model.init did not consume; used as the dropout key.
new_state, metrics_update, pred_nodes = train_step_fn(
    state=state,
    n_rollout_steps=len(sample_target_window),
    input_window_graphs=sample_input_window,
    target_window_graphs=sample_target_window,
    rngs={'dropout': rng}
)

input_graph
Number of nodes: 6
Number of edges: 30
Node features shape: (6, 2)
Edge features shape: (30, 1)
Global features shape: (1, 1)
> [0;32m/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/utils/jraph_models.py[0m(112)[0;36m__call__[0;34m()[0m
[0;32m    110 [0;31m        [0mprint_graph_fts[0m[0;34m([0m[0minput_graph[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    111 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 112 [0;31m        [0mprocessed_graphs[0m [0;34m=[0m [0mgraph_net[0m[0;34m([0m[0minput_graph[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    113 [0;31m        [0mprint[0m[0;34m([0m[0;34m'processed_graphs'[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    114 [0;31m        [0mprint_graph_fts[0m[0;34m([0m[0mprocessed_graphs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
processed_graphs
Number of nodes: 6
Number of edges: 30
Node fe

### try fixing jit batching issues

In [16]:
# import jraph
# batch = jax.jit(jraph.batch)(sample_input_batch)
# print_graph_fts(batch)

In [17]:
# from flax_gnn_example.train import unbatch_i
# first_graph = jax.jit(unbatch_i)(sample_input_window, 0)
# first_window = jraph.unbatch(sample_input_window)
def func_with_list(l):
    """Sum the elements of a Python list (toy check that jax.jit accepts a list argument)."""
    return sum(l)


# A plain list is a valid pytree argument: each element is traced as its own leaf.
jax.jit(func_with_list)([1,2,3])

Array(6, dtype=int32, weak_type=True)

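Gave up on this for now — the batching issue is still unresolved. (The toy cell above only confirms that `jax.jit` accepts a plain Python list as a pytree argument.)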

### test evaluate_model 

In [18]:
# Evaluate the untrained `state` (step 0) on the val and test splits; the result is a dict
# keyed by split name holding metrics objects (printed in the next cell).
eval_metrics = evaluate_model(
    state=state,
    n_rollout_steps=len(sample_target_window),
    datasets=sample_dataset,
    # dataset dict: first key = 'train'/'val'/'test', second key = 'inputs'/'targets'
    splits=['val', 'test']
)

input_graph
Number of nodes: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Number of edges: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Node features shape: (6, 2)
Edge features shape: (30, 1)
Global features shape: (1, 1)
> [0;32m/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/utils/jraph_models.py[0m(112)[0;36m__call__[0;34m()[0m
[0;32m    110 [0;31m        [0mprint_graph_fts[0m[0;34m([0m[0minput_graph[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    111 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 112 [0;31m        [0mprocessed_graphs[0m [0;34m=[0m [0mgraph_net[0m[0;34m([0m[0minput_graph[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    113 [0;31m        [0mprint[0m[0;34m([0m[0;34m'processed_graphs'[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    114 [0;31m        [0mprint_graph_fts[0m[0;34m([0m[0mproces

In [19]:
# Each split's `loss` metric stores a running `total` and `count`; presumably an averaging
# metric, so the mean loss per evaluated window is total / count — TODO confirm.
print(eval_metrics)
print(eval_metrics['val'].loss)
print(eval_metrics['val'].loss.total)
print(eval_metrics['val'].loss.count)

{'val': EvalMetrics(_reduction_counter=_ReductionCounter(value=Array(4, dtype=int32, weak_type=True)), loss=Metric.from_output.<locals>.FromOutput(total=Array(685.5943, dtype=float32), count=Array(4., dtype=float32))), 'test': EvalMetrics(_reduction_counter=_ReductionCounter(value=Array(4, dtype=int32, weak_type=True)), loss=Metric.from_output.<locals>.FromOutput(total=Array(597.77374, dtype=float32), count=Array(4., dtype=float32)))}
Metric.from_output.<locals>.FromOutput(total=Array(685.5943, dtype=float32), count=Array(4., dtype=float32))
685.5943
4.0


In [20]:
# state.step is still 0: the train_step cell stored its result in `new_state`
# instead of replacing `state`. Print the type explicitly — a bare non-final
# expression in a cell is evaluated but never displayed.
print(type(state.step))
state.step

0

### test full training pipeline

**TODO:** pick up here and debug.

In [62]:
# Full train/eval loop driven by the MLPBlock test config; each run writes to a fresh
# timestamped workdir (presumably checkpoints/logs, see the Checkpoint.save() log lines).
mlp_config = mlpblock_test.get_config()

# Filesystem-safe timestamp: str(datetime.now()) contains a space and colons
# (e.g. "2024-01-01 12:00:00.123456"), which are awkward in shell paths and invalid on Windows.
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
workdir = f"tests/outputs/train_testing_dir_{run_timestamp}"

trained_state = train_and_evaluate(config=mlp_config, workdir=workdir)

INFO:absl:Hyperparameters: {'F': 8, 'K': 6, 'add_self_loops': True, 'add_undirected_edges': True, 'add_virtual_node': True, 'b': 10, 'batch_size': 3, 'c': 10, 'checkpoint_every_epochs': 1, 'checkpoint_every_steps': 2, 'dropout_rate': 0.1, 'epochs': 4, 'eval_every_epochs': 1, 'eval_every_steps': 1, 'h': 1, 'init_buffer_samples': 0, 'input_steps': 3, 'layer_norm': False, 'learning_rate': 0.001, 'log_every_epochs': 1, 'log_every_steps': 2, 'model': 'MLPBlock', 'n_samples': 10, 'normalize': True, 'optimizer': 'adam', 'output_delay': 0, 'output_steps': 2, 'sample_buffer': 1, 'seed': 42, 'skip_connections': False, 'test_pct': 0.4, 'time_resolution': 100, 'timestep_duration': 1, 'train_pct': 0.2, 'val_pct': 0.4}
INFO:absl:Obtaining datasets.
INFO:absl:Initializing network.
INFO:absl:
+-----------------------------+----------+------+---------+-------+
| Name                        | Shape    | Size | Mean    | Std   |
+-----------------------------+----------+------+---------+-------+
| params

epoch 0
step 0


INFO:absl:[0] train_loss=0.26761993765830994


is_last_step False
num_train_steps 8
(epoch == config.epochs - 1) False



INFO:absl:[0] val_loss=394.91632080078125
INFO:absl:Checkpoint.save() ...
INFO:absl:[0] test_loss=499.5340576171875
INFO:absl:Checkpoint.save() finished after 0.02s.
INFO:absl:[1] val_loss=394.67730712890625
INFO:absl:[1] test_loss=496.6314392089844
INFO:absl:[2] train_loss=4.852296352386475
INFO:absl:Checkpoint.save() ...
INFO:absl:[2] val_loss=394.0526428222656
INFO:absl:[2] test_loss=493.9013671875
INFO:absl:Checkpoint.save() finished after 0.02s.
INFO:absl:[3] val_loss=393.9186096191406
INFO:absl:[3] test_loss=491.12677001953125
INFO:absl:[4] train_loss=4.80951452255249
INFO:absl:Checkpoint.save() ...
INFO:absl:[4] val_loss=393.5780334472656
INFO:absl:[4] test_loss=488.6341247558594
INFO:absl:Checkpoint.save() finished after 0.02s.
INFO:absl:[5] val_loss=393.5848083496094
INFO:absl:[5] test_loss=486.771728515625
INFO:absl:[6] train_loss=4.766618728637695
INFO:absl:Checkpoint.save() ...
INFO:absl:[6] val_loss=393.40545654296875
INFO:absl:[6] test_loss=485.63525390625
INFO:absl:Check

step 1
is_last_step False
num_train_steps 8
(epoch == config.epochs - 1) False

epoch 1
step 2
is_last_step False
num_train_steps 8
(epoch == config.epochs - 1) False

step 3
is_last_step False
num_train_steps 8
(epoch == config.epochs - 1) False

epoch 2
step 4
is_last_step False
num_train_steps 8
(epoch == config.epochs - 1) False

step 5
is_last_step False
num_train_steps 8
(epoch == config.epochs - 1) False

epoch 3
step 6
is_last_step False
num_train_steps 8
(epoch == config.epochs - 1) True

step 7
is_last_step True
num_train_steps 8
(epoch == config.epochs - 1) True



In [67]:
# Sanity-check the returned state: the step count should match num_train_steps (8 in the log
# above). The MLP_0/Dense_0 kernel shape differs from the manual set-up earlier in the notebook
# because this config uses different layer sizes.
print(type(trained_state))
print(trained_state.step)
print(trained_state.params['params']['MLP_0']['Dense_0']['kernel'].shape)

<class 'flax.training.train_state.TrainState'>
8
(6, 4)
