This notebook contains some sample data to test out different metric implementations. 

### imports and setup

In [1]:
# ipython extension to autoreload imported modules so that any changes will be up to date before running code in this nb
%load_ext autoreload 
%autoreload 2

In [2]:
from utils.jraph_training import create_dataset, create_model, create_optimizer
from utils.jraph_data import print_graph_fts
from flax.training import train_state
import jax 
import jax.numpy as jnp
import jax.scipy as jscipy 
import ml_collections

In [3]:
# set up logging
import logging
logger = logging.getLogger()
logger.setLevel(logging.WARNING)

### get some sample data 

The below function returns a config for a small dataset, which we can then use to create the actual dataset. 

(Don't worry too much about what is happening here; I pasted it for transparency. The key point is that we will get a dataset with 100 total samples: 70 in train, 20 in val, and 10 in test.)

In [4]:
def get_data_config():
    config = ml_collections.ConfigDict()

    config.n_samples=100
    config.input_steps=1
    config.output_delay=8 # predict 24 hrs into the future 
    config.output_steps=4
    # number of raw data points generated per time unit, where one time unit
    # is equivalent to 5 days in the simulation (so 120 means 1 point per hour)
    config.time_resolution=120
    config.timestep_duration=3 # equivalent to 3 hours
    # note: a native 3 hour timestep resolution would be time_resolution = 5*24/3 = 40;
    # with time_resolution=120, a sampling frequency of 3 achieves the same 3 hour timestep
    config.sample_buffer = (
        -1 * (
            config.input_steps + 
            config.output_delay + 
            config.output_steps - 1)
        ) 
        # number of timesteps strictly between the end of one full sample and the start of the next sample
        # we want a negative buffer so that our sample inputs are continuous (i.e. each sample overlaps with the samples that follow it)
    config.init_buffer_samples=100
    config.train_pct=0.7
    config.val_pct=0.2
    config.test_pct=0.1
    config.K=36
    config.F=8
    config.c=10
    config.b=10
    config.h=1
    config.seed=42
    config.normalize=True
    config.fully_connected_edges=False

    return config
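
The timing and windowing arithmetic in the config above can be checked with a few lines of plain Python. This is only a sketch that mirrors the config values. The names `hours_per_timestep`, `window_span`, `stride`, and `split_sizes` are illustrative and not part of the `utils` package, and the stride calculation assumes that one full sample spans `input_steps + output_delay + output_steps` timesteps. How `create_dataset` actually rounds the split sizes is also an assumption here.

```python
# Values copied from get_data_config() above.
n_samples = 100
input_steps = 1
output_delay = 8
output_steps = 4
time_resolution = 120   # raw points per time unit (one time unit = 5 days)
timestep_duration = 3   # sampling frequency applied to the raw points
train_pct, val_pct, test_pct = 0.7, 0.2, 0.1

# 5 days = 120 hours, so 120 raw points per time unit is one point per hour.
hours_per_raw_point = 5 * 24 / time_resolution
hours_per_timestep = hours_per_raw_point * timestep_duration

# Lead time implied by output_delay, matching the "24 hrs" comment.
lead_time_hours = output_delay * hours_per_timestep

# Same expression as config.sample_buffer.
sample_buffer = -1 * (input_steps + output_delay + output_steps - 1)

# Assumption: a full sample covers the input, the delay, and the output steps.
window_span = input_steps + output_delay + output_steps
stride = window_span + sample_buffer  # offset between consecutive sample starts

# Split sizes implied by the percentages (rounded to whole samples).
split_sizes = {
    "train": round(n_samples * train_pct),
    "val": round(n_samples * val_pct),
    "test": round(n_samples * test_pct),
}

print(hours_per_timestep, lead_time_hours, sample_buffer, stride, split_sizes)
```

Under these assumptions the negative buffer makes consecutive samples start one timestep apart, which is what the "continuous" comment in the config describes.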

Now we generate the dataset.

In [5]:
dataset_config = get_data_config()
datasets = create_dataset(dataset_config)

The dataset is a nested dictionary with the following structure: 
{
'train': {
    'inputs': list of windows of graphtuples,
    'targets': list of windows of graphtuples},
'val': {
    'inputs': list of windows of graphtuples,
    'targets': list of windows of graphtuples},
'test': {
    'inputs': list of windows of graphtuples,
    'targets': list of windows of graphtuples},
}

A "window" is a time series of graphs, here representing either the input sequence or the output sequence of data. These are represented as lists, so each value in the nested dict is really a list of lists of GraphsTuple objects. 

Below we demonstrate how to navigate the datasets.

In [6]:
print(datasets.keys())
print(datasets['train'].keys())

input_train_data = datasets['train']['inputs']
first_input_window = input_train_data[0]
first_input_graph = first_input_window[0]
first_target_graph = datasets['train']['targets'][0][0]

print("type(input_train_data)", type(input_train_data))
print("type(first_input_window)", type(first_input_window))
print("type(first_input_graph)", type(first_input_graph))

dict_keys(['train', 'val', 'test'])
dict_keys(['inputs', 'targets'])
type(input_train_data) <class 'list'>
type(first_input_window) <class 'list'>
type(first_input_graph) <class 'jraph._src.graph.GraphsTuple'>
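
The same indexing pattern generalizes to a quick summary of every split. The sketch below runs on a toy nested dict that uses placeholder strings in place of GraphsTuple objects. `summarize_splits` and `toy_datasets` are hypothetical names for illustration, not part of `utils.jraph_data`.

```python
def summarize_splits(datasets):
    """Return {split: {kind: (n_windows, window_length)}} for a nested dataset dict."""
    summary = {}
    for split, parts in datasets.items():
        summary[split] = {}
        for kind, windows in parts.items():
            # Each entry in `windows` is itself a list (one time series of graphs).
            window_length = len(windows[0]) if windows else 0
            summary[split][kind] = (len(windows), window_length)
    return summary

# Toy stand-in: 2 train samples and 1 val/test sample each, with
# 1 input step and 4 output steps per window (placeholder strings, not graphs).
toy_datasets = {
    split: {
        "inputs": [["g_in"] for _ in range(n)],
        "targets": [["g_out"] * 4 for _ in range(n)],
    }
    for split, n in [("train", 2), ("val", 1), ("test", 1)]
}

toy_summary = summarize_splits(toy_datasets)
print(toy_summary["train"])
```

Because the helper relies only on `len` and indexing, it should also work on the real `datasets` object, though that has not been run here.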


In [7]:
# look at basic properties of the graph
print_graph_fts(first_input_graph)

Number of nodes: 36
Number of edges: 180
Node features shape: (36, 2)
Edge features shape: (180, 1)
Global features shape: (1, 1)


Each GraphsTuple contains a lot of data, but the node features are likely to be of most interest, since they are what we're trying to predict. The node features have shape (36, 2), corresponding to the K=36 nodes, each of which has an X1 and an X2 feature. 

In [8]:
first_input_graph.nodes

array([[-1.3649824 ,  0.81143504],
       [ 0.4459807 ,  0.5750941 ],
       [ 0.96717197, -1.0274767 ],
       [-1.8197978 , -1.5187225 ],
       [-0.33013806, -0.45958167],
       [-0.12215422, -0.46869078],
       [ 0.16023503,  0.18824354],
       [ 0.39943892,  1.0756559 ],
       [ 0.94749486, -1.2535653 ],
       [-0.4739116 , -0.6680645 ],
       [-2.2530704 , -2.022228  ],
       [ 0.33636376, -0.7878205 ],
       [ 0.1396998 ,  0.31165692],
       [-0.7822607 , -0.6472315 ],
       [ 0.03536604, -0.9326305 ],
       [ 0.95444596, -0.43482503],
       [ 0.582315  ,  1.6892598 ],
       [ 0.32959026,  0.12167044],
       [ 0.5314968 , -1.1229218 ],
       [ 0.30201438,  0.635556  ],
       [-0.69298774,  0.7974206 ],
       [-1.2602797 , -0.9461577 ],
       [-0.12644374,  0.8608444 ],
       [ 2.0222785 ,  1.5772247 ],
       [-0.06871947, -0.8737022 ],
       [-1.5709201 , -0.31845146],
       [-1.0807154 , -1.23525   ],
       [-1.2676263 , -0.05490148],
       [-0.59919816,
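
Since metrics operate on arrays rather than graphs, it can help to stack the node features of a whole window into one array. The sketch below uses regular numpy and a hypothetical `ToyGraph` stand-in that exposes only a `nodes` field. It is not the jraph GraphsTuple, and `stack_node_features` is an illustrative helper, not part of this repo.

```python
from typing import NamedTuple
import numpy as np

class ToyGraph(NamedTuple):
    """Hypothetical stand-in exposing only the `nodes` field of a GraphsTuple."""
    nodes: np.ndarray

def stack_node_features(window):
    """Stack per-graph node features into one array of shape (T, K, n_features)."""
    return np.stack([graph.nodes for graph in window])

# A toy target window: 4 timesteps, K=36 nodes, 2 features (X1, X2) per node.
rng = np.random.default_rng(0)
toy_window = [ToyGraph(nodes=rng.normal(size=(36, 2))) for _ in range(4)]

stacked = stack_node_features(toy_window)
x1, x2 = stacked[..., 0], stacked[..., 1]  # split the two node features
print(stacked.shape, x1.shape, x2.shape)
```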

#### create a small model to get predictions

The details of this don't really matter for the purpose of this notebook, so I'll just import a basic GNBlock config and setup. 

Note that this model is NOT trained, so the predictions will be random – it'll be used just for testing metric functions. 

Don't worry about any other details here. 

In [9]:
from experiments.configs import GNBlock_baseline
model_config = GNBlock_baseline.get_config()

rng = jax.random.key(0)
rng, init_rng = jax.random.split(rng)
net = create_model(model_config, deterministic=True)
params = jax.jit(net.init)(init_rng, first_input_window)

# Create the optimizer.
tx = create_optimizer(model_config)

# Create the training state.
state = train_state.TrainState.create(
    apply_fn=net.apply, params=params, tx=tx
)

Now, we can get predictions using the following function call. 

pred_graphs_list is a list of predicted graphs (in this case it has only a single element).

In [10]:
pred_graphs_list = state.apply_fn(state.params, first_input_window) 
pred_graph = pred_graphs_list[0]
print(type(pred_graph))

<class 'jraph._src.graph.GraphsTuple'>


We can look at the nodes

In [11]:
pred_graph.nodes

Array([[0.        , 0.6415052 ],
       [0.        , 0.28648597],
       [0.        , 0.36244932],
       [0.10148586, 0.1256536 ],
       [0.        , 0.0408808 ],
       [0.        , 0.        ],
       [0.        , 0.        ],
       [0.        , 0.39947745],
       [0.        , 0.47231084],
       [0.        , 0.05931716],
       [0.41900653, 0.36059526],
       [0.        , 0.12634347],
       [0.        , 0.04503629],
       [0.        , 0.06245288],
       [0.        , 0.10422921],
       [0.        , 0.1124042 ],
       [0.        , 0.694765  ],
       [0.        , 0.04666549],
       [0.        , 0.4672794 ],
       [0.        , 0.17842755],
       [0.        , 0.46139765],
       [0.        , 0.        ],
       [0.        , 0.34355032],
       [0.        , 0.36689094],
       [0.        , 0.27120814],
       [0.        , 0.05685274],
       [0.        , 0.15074416],
       [0.        , 0.04457913],
       [0.        , 0.        ],
       [0.        , 0.3694534 ],
       [0.

### test metrics

This is a generalized implementation of MSE that uses jax numpy (jnp), which is pretty easy to swap in for regular numpy. The same should apply to jax scipy (imported above as jscipy). 

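In this case, I treat the targets and preds variables as arrays. The exact type doesn't matter much: they can be jnp or np arrays. Nested lists would need to be converted first (e.g. with jnp.asarray), since `preds - targets` is not defined for plain Python lists. This also means we can't pass a raw GraphsTuple into the function; we have to select its nodes attribute specifically. 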

In [12]:
def MSE(targets, preds):
    """Mean squared error, averaged over all elements of the arrays."""
    mse = jnp.mean(jnp.square(preds - targets))
    return mse
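
Before applying a metric to graph data, a hand-checkable case is a useful sanity test. The sketch below reimplements the same formula with regular numpy (standing in for jax.numpy) under the hypothetical name `mse`. The `np.asarray` calls are an addition so that nested lists are accepted too.

```python
import numpy as np

def mse(targets, preds):
    """Same formula as MSE above, written with numpy instead of jax.numpy."""
    return np.mean(np.square(np.asarray(preds) - np.asarray(targets)))

targets = np.array([[0.0, 1.0], [2.0, 3.0]])
preds = np.array([[0.0, 0.0], [2.0, 5.0]])

# Squared errors are [[0, 1], [0, 4]], so the mean is 5 / 4.
value = mse(targets, preds)
print(value)  # 1.25

# Identical inputs give zero error, and the metric is symmetric in its arguments.
assert mse(targets, targets) == 0.0
assert mse(targets, preds) == mse(preds, targets)
```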

We can test the metrics like so: 

In [13]:
MSE(first_target_graph.nodes, pred_graph.nodes)

Array(1.0594673, dtype=float32)
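
The single-graph call above extends naturally to whole windows and to per-feature errors. The sketch below is one possible way to do this with regular numpy. `ToyGraph`, `window_mse`, and `per_feature_mse` are hypothetical names, and `ToyGraph` only mimics the `nodes` attribute of a GraphsTuple.

```python
from typing import NamedTuple
import numpy as np

class ToyGraph(NamedTuple):
    """Hypothetical stand-in exposing only the `nodes` field of a GraphsTuple."""
    nodes: np.ndarray

def window_mse(target_window, pred_window):
    """MSE over every node feature of every graph in two equal-length windows."""
    targets = np.stack([g.nodes for g in target_window])
    preds = np.stack([g.nodes for g in pred_window])
    return np.mean(np.square(preds - targets))

def per_feature_mse(target_nodes, pred_nodes):
    """MSE averaged over nodes only, giving one value per feature (X1, X2)."""
    return np.mean(np.square(pred_nodes - target_nodes), axis=0)

# Toy data: targets are all zeros; predictions are 1 for X1 and 2 for X2.
K = 36
target_window = [ToyGraph(nodes=np.zeros((K, 2))) for _ in range(4)]
pred_nodes = np.tile(np.array([1.0, 2.0]), (K, 1))
pred_window = [ToyGraph(nodes=pred_nodes) for _ in range(4)]

overall = window_mse(target_window, pred_window)  # (1 + 4) / 2
by_feature = per_feature_mse(target_window[0].nodes, pred_window[0].nodes)
print(overall, by_feature)
```

Reporting the error per feature can show whether one of X1 or X2 dominates the overall MSE, which a single scalar hides.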