### imports and setup

In [17]:
# ipython extension to autoreload imported modules so that any changes will be up to date before running code in this nb
%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
import os

import jax.numpy as jnp
import numpy as np

from utils.jraph_data import get_lorenz_graph_tuples, print_graph_fts
from utils.lorenz import load_lorenz96_2coupled

### test the graphtuples

In [19]:
# Sampling / windowing settings. They are kept as notebook-level names because
# later cells pass the same values again.
n_samples = 100
input_steps, output_delay, output_steps = 6, 0, 4
timestep_duration = 3
time_resolution = 100
sample_buffer = 2  # buffer between consecutive samples
init_buffer_samples = 100  # buffer at the start of the dataset so the system can settle

# Gather every argument in one mapping, then generate the desired dataset with
# a train/val split and subsampled windows.
lorenz_dataset_kwargs = dict(
    n_samples=n_samples,
    input_steps=input_steps,
    output_delay=output_delay,
    output_steps=output_steps,
    timestep_duration=timestep_duration,
    sample_buffer=sample_buffer,
    time_resolution=time_resolution,
    init_buffer_samples=init_buffer_samples,
    train_pct=0.7,
    val_pct=0.3,
    test_pct=0.0,
    K=36,
    F=8,
    c=10,
    b=10,
    h=1,
    seed=42,
)
graph_tuple_dict = get_lorenz_graph_tuples(**lorenz_dataset_kwargs)

In [20]:
# Feature summary of the first training input window.
print_graph_fts(graph_tuple_dict['train']['inputs'][0])

# Container type that holds the sample windows ...
print(type(graph_tuple_dict['train']['inputs']))
# ... and how many windows each of these lists contains.
for split_name, field_name in (('train', 'inputs'), ('train', 'targets'), ('val', 'inputs')):
    print(len(graph_tuple_dict[split_name][field_name]))

# A single window: its type, then the length of an input and a target window.
first_input_window = graph_tuple_dict['train']['inputs'][0]
print(type(first_input_window))
print(len(first_input_window))
print(len(graph_tuple_dict['train']['targets'][0]))

# Per-graph node counts stored inside that window.
node_counts = first_input_window.n_node
print(node_counts.shape[0])
print(node_counts)
print(type(node_counts))

# Compare the stored node counts against the expected 6 graphs of 36 nodes.
expected_node_counts = [36] * 6
print(np.array(expected_node_counts))

print(np.array_equal(node_counts, expected_node_counts))

Number of graphs: 6
Number of nodes: [36 36 36 36 36 36]
Number of edges: [180 180 180 180 180 180]
Node features (total) shape: (216, 2)
Edge features (total) shape: (1080, 1)
Global features shape: (6, 1)
<class 'list'>
70
70
30
<class 'jraph._src.graph.GraphsTuple'>
7
7
6
[36 36 36 36 36 36]
<class 'jaxlib.xla_extension.ArrayImpl'>
[36 36 36 36 36 36]
True


In [21]:
# Shape of the node-feature array of the first training input window.
first_train_window = graph_tuple_dict['train']['inputs'][0]
first_train_window.nodes.shape

(216, 2)

In [22]:
# Load the raw Lorenz data, presumably to cross-check it against the node
# features stored in the graph tuples.
# NOTE(review): the fallback below is a machine-specific absolute path. Set the
# LORENZ_DATA_PATH environment variable to load test.npz from another location.
raw_data_path = os.environ.get(
    "LORENZ_DATA_PATH",
    "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/data/test.npz",
)
_, raw_data = load_lorenz96_2coupled(raw_data_path)

# Rows 3600, 3603, ..., 3615 (every 3rd row, 6 rows in total), first 36 columns.
# The earlier single-row lookup was dropped: it was not the cell's last
# expression, so its value was never displayed.
raw_data[3600:3616:3, :36]

array([[ 9.20309163,  0.10097562,  2.70563981,  2.2889173 ,  4.93039924,
        11.30840708, -1.38595368,  2.63419645,  7.95192663,  1.43438903,
        -1.42082268,  2.30244857,  3.42416195,  4.87855049,  4.51646427,
        -1.07710273,  3.00946656,  6.99104909,  6.12433051, -0.27011117,
         3.89495546,  5.66476719,  1.99112679, -3.06823388,  1.9413291 ,
         0.01839206,  4.35759736,  6.51211141, -0.78984769,  4.12645985,
         9.07101638, -2.57900585,  3.88142137,  1.34250433,  1.96287859,
         4.03334224],
       [ 8.90840078, -0.07021375,  2.84319918,  2.88652768,  5.6435903 ,
        10.52825977, -1.92097817,  2.92823958,  8.17242816,  0.63536023,
        -1.31975251,  2.35411458,  3.98855673,  5.17456602,  3.91327855,
        -1.05094979,  3.05779617,  7.63500519,  5.41440325, -0.60555415,
         4.00364149,  5.95317358,  0.98025528, -2.89512759,  2.21806075,
         0.71339013,  4.5026242 ,  6.41297943, -0.58423494,  4.17772833,
         8.82043128, -2.24875

In [23]:
# Column 0 of the first training input window's node features, laid out as a
# (6, 36) grid.
first_window_nodes = graph_tuple_dict['train']['inputs'][0].nodes
first_feature_column = first_window_nodes[:, 0]
print(jnp.reshape(first_feature_column, (6, 36)))

[[ 9.203092    0.10097562  2.7056398   2.2889173   4.9303994  11.308407
  -1.3859537   2.6341965   7.9519267   1.434389   -1.4208226   2.3024485
   3.424162    4.8785505   4.516464   -1.0771028   3.0094666   6.9910493
   6.1243305  -0.27011117  3.8949554   5.6647673   1.9911268  -3.068234
   1.9413291   0.01839206  4.3575974   6.512111   -0.7898477   4.12646
   9.071016   -2.579006    3.8814213   1.3425044   1.9628786   4.0333424 ]
 [ 8.908401   -0.07021375  2.8431993   2.8865278   5.6435905  10.52826
  -1.9209782   2.9282396   8.172428    0.63536024 -1.3197525   2.3541145
   3.9885566   5.174566    3.9132786  -1.0509498   3.0577962   7.635005
   5.4144034  -0.60555416  4.0036416   5.9531736   0.9802553  -2.8951275
   2.2180607   0.7133901   4.502624    6.4129796  -0.58423495  4.177728
   8.820432   -2.2487593   4.514402    2.0992928   2.147029    4.4703975 ]
 [ 8.54527    -0.28519565  3.008693    3.5687807   6.387768    9.465805
  -2.5100567   3.182503    8.379879   -0.19135566 -1.076

In [26]:
# Build the dataset again with the same settings, this time asking for
# normalized data.
normed_dataset_kwargs = {
    'n_samples': n_samples,
    'input_steps': input_steps,
    'output_delay': output_delay,
    'output_steps': output_steps,
    'timestep_duration': timestep_duration,
    'sample_buffer': sample_buffer,
    'time_resolution': time_resolution,
    'init_buffer_samples': init_buffer_samples,
    'train_pct': 0.7,
    'val_pct': 0.3,
    'test_pct': 0.0,
    'K': 36,
    'F': 8,
    'c': 10,
    'b': 10,
    'h': 1,
    'seed': 42,
    'normalize': True,
}
normed_graph_tuple_dict = get_lorenz_graph_tuples(**normed_dataset_kwargs)

In [27]:
# First training input window of the normalized dataset.
sample_graph = normed_graph_tuple_dict['train']['inputs'][0]

# Column 0 of its node features, laid out as a (6, 36) grid.
# The bare `.nodes.shape` expression that used to sit here was removed: it was
# not the cell's last expression, so its value was never displayed.
print(jnp.reshape(sample_graph.nodes[:, 0], (6, 36)))


[[ 1.94605434e+00 -6.42786026e-01  9.80373099e-02 -2.04876680e-02
   7.30807483e-01  2.54485178e+00 -1.06570113e+00  7.77172744e-02
   1.59019578e+00 -2.63534158e-01 -1.07561862e+00 -1.66390985e-02
   3.02400678e-01  7.16060519e-01  6.13075256e-01 -9.77857172e-01
   1.84452280e-01  1.31690121e+00  1.07038748e+00 -7.48331189e-01
   4.36304599e-01  9.39677715e-01 -1.05185792e-01 -1.54417837e+00
  -1.19349331e-01 -6.66274607e-01  5.67890048e-01  1.18068087e+00
  -8.96155596e-01  5.02149582e-01  1.90848923e+00 -1.40503120e+00
   4.32455212e-01 -2.89668143e-01 -1.13220192e-01  4.75664884e-01]
 [ 1.86223781e+00 -6.91476047e-01  1.37162209e-01  1.49485782e-01
   9.33654547e-01  2.32296133e+00 -1.21787381e+00  1.61349535e-01
   1.65291119e+00 -4.90795374e-01 -1.04687214e+00 -1.94415415e-03
   4.62926835e-01  8.00253749e-01  4.41516131e-01 -9.70418751e-01
   1.98198274e-01  1.50005627e+00  8.68468761e-01 -8.43738496e-01
   4.67217326e-01  1.02170682e+00 -3.92699689e-01 -1.49494314e+00
  -4.0640

### set up graph tuple batches