### imports and setup

In [17]:
# ipython extension to autoreload imported modules so that any changes will be up to date before running code in this nb
%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
from utils.jraph_data import get_lorenz_graph_tuples, print_graph_fts
from utils.jraph_training import rollout_loss, train_step, train_step_fn, evaluate_model, train_and_evaluate, create_dataset #, rollout_loss_batched, 
from utils.jraph_models import MLPBlock
import optax
from flax.training import train_state
from flax_gnn_example.configs import mlpblock_test

import numpy as np
import jax.numpy as jnp
import jax
from datetime import datetime

In [19]:
# Lower the root logger's threshold so INFO-level records are emitted.
import logging

logger = logging.getLogger()  # no name given, so this is the root logger
logger.setLevel(level=logging.INFO)

### test single rollout

In [20]:
def get_sample_data(seed=42):
    """Generate a small sample dataset with fixed generation settings.

    Every argument handed to `get_lorenz_graph_tuples` is hard-coded here
    except the random seed.

    Args:
        seed: forwarded unchanged as the `seed` argument of
            `get_lorenz_graph_tuples`.

    Returns:
        The object returned by `get_lorenz_graph_tuples`, unmodified.
    """
    lorenz_kwargs = dict(
        # sample count and window layout
        n_samples=10,
        input_steps=3,
        output_delay=0,
        output_steps=2,
        timestep_duration=1,
        sample_buffer=1,
        time_resolution=100,
        init_buffer_samples=0,
        # split fractions
        train_pct=0.2,
        val_pct=0.4,
        test_pct=0.4,
        # system parameters, passed through as-is
        K=36,
        F=8,
        c=10,
        b=10,
        h=1,
        seed=seed,
        normalize=False,
    )
    return get_lorenz_graph_tuples(**lorenz_kwargs)

# Build the sample dataset with the default seed (42).
# NOTE(review): `sample_dataset` is reassigned from `create_dataset(config=config)`
# in the following cell, so this value is not what the later cells use.
sample_dataset = get_sample_data()

In [21]:
# Load the mlpblock_test config and rebuild `sample_dataset` from it
# (this replaces the dataset produced by get_sample_data() above).
config=mlpblock_test.get_config()
sample_dataset = create_dataset(config=config)

In [22]:
# Full list of training input / target windows
sample_input_batch = sample_dataset['train']['inputs']
sample_target_batch = sample_dataset['train']['targets']

# First training window of each kind
sample_input_window, sample_target_window = (
    sample_dataset['train']['inputs'][0],
    sample_dataset['train']['targets'][0],
)

# First graph inside each of those windows
sample_input_graph, sample_target_graph = (
    sample_input_window[0],
    sample_target_window[0],
)

# Summarise the feature shapes of both graphs
print_graph_fts(sample_input_graph)
print_graph_fts(sample_target_graph)

Number of nodes: 36
Number of edges: 180
Node features shape: (36, 2)
Edge features shape: (1296, 1)
Global features shape: (1, 1)
Number of nodes: 36
Number of edges: 180
Node features shape: (36, 2)
Edge features shape: (1296, 1)
Global features shape: (1, 1)


In [23]:
# Inspect the array fields of the first input graph.
print(sample_input_graph.nodes.shape)
print(sample_input_graph.n_node)
print(sample_input_graph.edges.shape)
print(sample_input_graph.n_edge)
print(sample_input_graph.receivers.shape)
# n_node has shape (1,), so 0 is its only valid index. Index 1 is out of bounds;
# JAX clamps out-of-range indices instead of raising, which would silently hide
# the mistake, so the element is read at index 0 explicitly.
print(sample_input_graph.n_node[0])
print(sample_input_graph.n_node.shape)

(36, 2)
[36]
(1296, 1)
[180]
(1296,)
36
(1,)


In [24]:
# Print the shapes of the main array fields of the first graph in the window,
# one per line, then display its receivers array as the cell output.
print(
    *(
        getattr(sample_input_window[0], field_name).shape
        for field_name in ('nodes', 'edges', 'globals', 'receivers')
    ),
    sep='\n',
)
sample_input_window[0].receivers

(36, 2)
(1296, 1)
(1, 1)
(1296,)


Array([ 0,  1,  2, ..., 33, 34, 35], dtype=int32)

In [25]:
# Build the model, its initial parameters, and the TrainState used by the tests below.

# layer sizes handed to MLPBlock for each component (None is passed for globals)
hidden_layer_features = {'edge': [16, 8], 'node': [32, 2], 'global': None}
model = MLPBlock(
    edge_features=hidden_layer_features['edge'],
    node_features=hidden_layer_features['node'],
    global_features=hidden_layer_features['global'],
)

# initialise the parameters from a fixed PRNG key, using one input window as the example input
rng, init_rng = jax.random.split(jax.random.key(0))
params = jax.jit(model.init)(init_rng, sample_input_window)

# optimizer (the state needs one even if we aren't training)
learning_rate = 0.001  # default learning rate for adam in keras
tx = optax.adam(learning_rate=learning_rate)

# the state object keeps track of the model's apply function, the params, and the optimizer
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)

In [26]:
# Single rollout over one window: roll out as many steps as there are target graphs.
# No RNGs are supplied for this call.
avg_loss, pred_nodes = rollout_loss(
    state=state,
    input_window_graphs=sample_input_window,
    target_window_graphs=sample_target_window,
    n_rollout_steps=len(sample_target_window),
    rngs=None,
)

In [27]:
# Loss value, then the container type and length of the predictions
print(avg_loss, type(pred_nodes), len(pred_nodes), sep='\n')

# Type and shape of one prediction, followed by the first two predictions themselves
print(type(pred_nodes[0]), pred_nodes[0].shape, sep='\n')
print(pred_nodes[0], pred_nodes[1], sep='\n')

1758.22
<class 'list'>
2
<class 'jaxlib.xla_extension.ArrayImpl'>
(36, 2)
[[ 0.6568187  58.017986  ]
 [ 0.6568187  58.017982  ]
 [ 0.6568187  58.017986  ]
 [ 0.6568187  58.017742  ]
 [ 0.6568187  58.017822  ]
 [ 0.6568187  58.01782   ]
 [ 0.6568187  58.017822  ]
 [ 0.6568187  58.017826  ]
 [ 0.6568187  58.017826  ]
 [ 0.6568187  58.017822  ]
 [ 0.6568187  58.017822  ]
 [ 0.6568187  58.01777   ]
 [ 0.6568187  58.017757  ]
 [ 0.6568187  58.018368  ]
 [ 0.6568187  58.01841   ]
 [ 0.6568187  58.018414  ]
 [ 0.66054416 58.000698  ]
 [ 0.6568187  58.01778   ]
 [ 0.6568187  58.017986  ]
 [ 0.6568187  58.017982  ]
 [ 0.6568187  58.017986  ]
 [ 0.6568187  58.017742  ]
 [ 0.6568187  58.017822  ]
 [ 0.6568187  58.01782   ]
 [ 0.6568187  58.017822  ]
 [ 0.6568187  58.017826  ]
 [ 0.6568187  58.017826  ]
 [ 0.6568187  58.017822  ]
 [ 0.6568187  58.017822  ]
 [ 0.6568187  58.01777   ]
 [ 0.6568187  58.017757  ]
 [ 0.6568187  58.018368  ]
 [ 0.6568187  58.01841   ]
 [ 0.6568187  58.018414  ]
 [ 0.656

### test rollout loss batched

We hit an issue here that I don't know how to fix right away, and batching isn't our top priority right now, so I'm leaving this loose end hanging. **TODO:** revisit later.

The problem: when a list of `GraphsTuple`s is treated as a JAX pytree, each attribute of the named tuple becomes its own leaf (the `GraphsTuple` is flattened like any other container rather than kept whole). So instead of one leaf per graph, we end up with n_windows × (number of elements in a `GraphsTuple`) leaves.

To fix this, each `GraphsTuple` would need to be treated as a single leaf. I'm not sure yet how to set that up.

In [28]:
# type(sample_input_batch)
# print_graph_fts(sample_input_batch[0])

In [29]:
# Print the type and shape of every leaf JAX produces when the batch is
# flattened as a pytree. The batch is flattened once, directly in the loop header.
for leaf in jax.tree_util.tree_leaves(sample_input_batch):
    print(type(leaf))
    print(leaf.shape)

<class 'numpy.ndarray'>
(36, 2)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1296, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1296,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1296,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1,)
<class 'numpy.ndarray'>
(36, 2)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1296, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1296,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1296,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1,)
<class 'numpy.ndarray'>
(36, 2)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1296, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1296,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1296,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1, 1)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1,)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1,)
<class 'numpy.ndarray'>
(36, 2)
<clas

In [30]:
# batch_avg_loss, batch_pred_nodes = rollout_loss_batched(state, 
#                  sample_input_batch,
#                  sample_target_batch,
#                  None,
#                  )

### test train_step

In [31]:
# Walk down the parameter tree one level at a time to show its structure
print(type(params))
print(params.keys())
print(params['params'].keys())
print(params['params']['MLP_0'].keys())
print(params['params']['MLP_0']['Dense_0'].keys())
print(type(params['params']['MLP_0']['Dense_0']['bias']))

# Bias and kernel shapes for both Dense layers of both MLPs, one per line
print(
    *(
        params['params'][mlp_name][dense_name][leaf_name].shape
        for mlp_name in ('MLP_0', 'MLP_1')
        for dense_name in ('Dense_0', 'Dense_1')
        for leaf_name in ('bias', 'kernel')
    ),
    sep='\n',
)

<class 'dict'>
dict_keys(['params'])
dict_keys(['MLP_0', 'MLP_1'])
dict_keys(['Dense_0', 'Dense_1'])
dict_keys(['bias', 'kernel'])
<class 'jaxlib.xla_extension.ArrayImpl'>
(16,)
(6, 16)
(8,)
(16, 8)
(32,)
(19, 32)
(2,)
(32, 2)


In [32]:
# Run one training step on the sample window.
# `pred_nodes` is rebound here, replacing the value from the single-rollout test.
# NOTE(review): `rng` is the key kept from the earlier jax.random.split; consider
# splitting off a dedicated dropout key rather than passing the carried key itself.
new_state, metrics_update, pred_nodes = train_step_fn(
    state=state,
    input_window_graphs=sample_input_window,
    target_window_graphs=sample_target_window,
    n_rollout_steps=len(sample_target_window),
    rngs={'dropout': rng},
)

### try fixing jit batching issues

In [33]:
# import jraph
# batch = jax.jit(jraph.batch)(sample_input_batch)
# print_graph_fts(batch)

In [34]:
# from flax_gnn_example.train import unbatch_i
# first_graph = jax.jit(unbatch_i)(sample_input_window, 0)
# first_window = jraph.unbatch(sample_input_window)
def func_with_list(l):
    """Add up the items of `l`, starting from 0, and return the total.

    An empty iterable yields 0.
    """
    return sum(l, 0)

# Sanity check: jit-compile the list-summing helper and call it on a plain Python list of ints.
jax.jit(func_with_list)([1,2,3])

Array(6, dtype=int32, weak_type=True)

Giving up on this for now.

### test evaluate_model 

In [35]:
# Evaluate the (untrained) state on the validation and test splits.
# `datasets` is indexed first by split (train/test/val), then by input/target.
eval_metrics = evaluate_model(
    state=state,
    datasets=sample_dataset,
    splits=['val', 'test'],
    n_rollout_steps=len(sample_target_window),
)

In [36]:
# Show the full metrics dict, then drill into the validation loss metric:
# the metric object itself, its accumulated total, and its count.
print(
    eval_metrics,
    eval_metrics['val'].loss,
    eval_metrics['val'].loss.total,
    eval_metrics['val'].loss.count,
    sep='\n',
)

{'val': EvalMetrics(_reduction_counter=_ReductionCounter(value=Array(4, dtype=int32, weak_type=True)), loss=Metric.from_output.<locals>.FromOutput(total=Array(23624.05, dtype=float32), count=Array(4., dtype=float32))), 'test': EvalMetrics(_reduction_counter=_ReductionCounter(value=Array(4, dtype=int32, weak_type=True)), loss=Metric.from_output.<locals>.FromOutput(total=Array(41627.9, dtype=float32), count=Array(4., dtype=float32)))}
Metric.from_output.<locals>.FromOutput(total=Array(23624.05, dtype=float32), count=Array(4., dtype=float32))
23624.05
4.0


In [37]:
# Show both the type and the value of the state's step counter. Only the last
# expression of a cell is displayed automatically, so the type is printed explicitly.
print(type(state.step))
state.step

0

### test full training pipeline

TODO PICK UP HERE AND DEBUG

In [42]:
# Run the full train/evaluate pipeline into a fresh, timestamped output directory.
mlp_config = mlpblock_test.get_config()
# Format the timestamp explicitly: str(datetime.now()) contains a space and colons,
# which need quoting in shells and are not valid in Windows paths. Microseconds are
# kept so that runs started in quick succession still get distinct directories.
workdir = f"tests/outputs/train_testing_dir_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"

trained_state, train_metrics, eval_metrics_dict = train_and_evaluate(config=mlp_config, workdir=workdir)

INFO:absl:Obtaining datasets.
INFO:absl:Hyperparameters: {'F': 8, 'K': 36, 'add_self_loops': True, 'add_undirected_edges': True, 'add_virtual_node': True, 'b': 10, 'batch_size': 3, 'c': 10, 'checkpoint_every_epochs': 1, 'dropout_rate': 0.1, 'edge_features': (4, 8), 'epochs': 4, 'eval_every_epochs': 1, 'global_features': None, 'h': 1, 'init_buffer_samples': 0, 'input_steps': 3, 'layer_norm': False, 'learning_rate': 0.001, 'log_every_epochs': 1, 'model': 'MLPBlock', 'n_samples': 10, 'node_features': (32, 2), 'normalize': True, 'optimizer': 'adam', 'output_delay': 0, 'output_steps': 2, 'sample_buffer': 1, 'seed': 42, 'skip_connections': False, 'test_pct': 0.4, 'time_resolution': 100, 'timestep_duration': 1, 'train_pct': 0.2, 'val_pct': 0.4}
INFO:absl:Initializing network.
INFO:absl:
+-----------------------------+----------+------+---------+-------+
| Name                        | Shape    | Size | Mean    | Std   |
+-----------------------------+----------+------+---------+-------+
| par

init_epoch 0


INFO:absl:Finished training step 0.
INFO:absl:Finished training step 1.
INFO:absl:[2] train_loss=13.500608444213867
INFO:absl:Checkpoint.save() ...
INFO:absl:[2] val_loss=935.670166015625
INFO:absl:[2] test_loss=1591.06005859375
INFO:absl:Checkpoint.save() finished after 0.04s.
INFO:absl:[4] train_loss=17.09326934814453
INFO:absl:[4] val_loss=872.2733764648438
INFO:absl:Checkpoint.save() ...
INFO:absl:[4] test_loss=1466.304931640625
INFO:absl:Checkpoint.save() finished after 0.03s.
INFO:absl:[6] train_loss=13.328892707824707
INFO:absl:[6] val_loss=811.263916015625
INFO:absl:Checkpoint.save() ...
INFO:absl:[6] test_loss=1344.5447998046875
INFO:absl:Checkpoint.save() finished after 0.02s.
INFO:absl:[8] train_loss=12.038654327392578
INFO:absl:Checkpoint.save() ...
INFO:absl:[8] val_loss=754.405029296875
INFO:absl:[8] test_loss=1231.849609375
INFO:absl:Checkpoint.save() finished after 0.03s.


In [43]:
# Type of the returned state, the step it finished on, and the shape of the first
# Dense kernel in MLP_0, one per line.
print(
    type(trained_state),
    trained_state.step,
    trained_state.params['params']['MLP_0']['Dense_0']['kernel'].shape,
    sep='\n',
)

<class 'flax.training.train_state.TrainState'>
8
(6, 4)


In [46]:
# Accumulated validation loss from the final evaluation. This is the metric's
# `total` field; the metric also carries a `count` field (see the eval_metrics output above).
eval_metrics_dict['val'].loss.total

Array(3017.62, dtype=float32)