In [71]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as pltcolors

import torch as t
import torch_geometric as ptgeo

from torch_geometric.utils import to_networkx, from_networkx
import networkx as nx

from gninvert.functions import run_GN, gn_time_series, make_color_scale,\
gdisplay, run_and_draw, generate_training_data,\
generate_graphs_from_connections, generate_grid_edge_index
from gninvert.gns import MultiDiffusionGN, EquationGN, FullActInhGN
from gninvert.gnns import LinearGNN
from gninvert.graph_compare import graph_compare, model_compare, model_steps_compare
from gninvert.data_generation import get_TrainingData
from gninvert.gnns import GNN_3Layer
from gninvert.training import fit
from gninvert.hyperparamsearch import hpsearch
import gninvert

import ipywidgets as widgets

plt.rcParams['figure.figsize'] = [10, 4] # default (width, height) in inches for every figure drawn below

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [114]:
# Activator-inhibitor graph network whose dynamics are simulated below.
gn = FullActInhGN(
    spatial_const=10,
    temporal_const=0.01,
    growth_alpha=10,
    growth_rho=1,
    growth_scale=0.05,
    reaction_const=0.2,
    reference_conc=2,
)

GRID_SIZE = 6   # argument to generate_grid_edge_index; presumably a 6x6 grid - confirm
N_STEPS = 300   # number of steps passed to gn_time_series

# One graph on the grid connectivity, with 3 features per node.
gdata = generate_graphs_from_connections(
    generate_grid_edge_index(GRID_SIZE),
    node_feature_num=3,
    num=1)[0]

# Feature 0 (cell size, going by the original comment): homogeneous across nodes.
gdata.x[:, 0] = 0.1
# Features 1 (activator) and 2 (inhibitor): start at zero everywhere...
gdata.x[:, 1:3] = 0
# ...except for an activator seed at nodes 0 and 1 (presumably adjacent, in one
# corner of the grid - confirm against generate_grid_edge_index).
gdata.x[0, 1] = 1
gdata.x[1, 1] = 0.5
# NOTE(review): the inhibitor (feature 2) is not seeded anywhere; if an initial
# inhibitor patch was intended, set non-zero values in column 2 here.

time_series = gn_time_series(gn, N_STEPS, gdata)

In [121]:
def f(x):
    """Draw frame ``x`` of ``time_series`` with ``gdisplay``, then print ``x``.

    Serves as the callback for the interactive slider. Each of the three node
    features gets its own colour scale:
      - feature 0: plasma over [0, 3]
      - feature 1: BuGn over [0, 1]
      - feature 2: Reds over [0, 1]

    NOTE(review): this notebook's step-55 output shows feature 1 reaching ~24
    and feature 2 ~1.3e3, far outside the [0, 1] ranges used here, so those
    channels probably render saturated - consider widening the ranges.
    """
    scale_specs = [
        (0, 3, plt.cm.plasma),
        (0, 1, plt.cm.BuGn),
        (0, 1, plt.cm.Reds),
    ]
    frame = time_series[x]
    gdisplay(
        frame,
        color_scales=[make_color_scale(lo, hi, cmap) for lo, hi, cmap in scale_specs],
    )
    print(x)

# Slider over every simulated step. interact() displays the widget itself; the
# trailing ';' suppresses the repr of the returned function, which is otherwise
# echoed as cell output.
widgets.interact(
    f,
    x=widgets.IntSlider(min=0, max=len(time_series) - 1, step=1, value=0),
);

interactive(children=(IntSlider(value=0, description='x', max=300), Output()), _dom_classes=('widget-interact'…

<function __main__.f(x)>

In [120]:
# Raw node features at one intermediate step, to check their value ranges.
inspect_step = 55
time_series[inspect_step].x

tensor([[1.3002e+00, 8.0193e+00, 3.3766e+00],
        [1.4115e+00, 9.3922e+00, 4.4328e+01],
        [1.4627e+00, 8.1790e+00, 9.8059e+01],
        [1.7303e+00, 3.9275e+00, 1.4093e+01],
        [1.4762e+00, 4.8246e+00, 3.0750e+01],
        [1.1934e+00, 9.3632e+00, 4.8286e+01],
        [1.3698e+00, 6.5719e+00, 3.2597e+01],
        [1.3996e+00, 9.6389e+00, 1.2086e+02],
        [1.2936e+00, 1.9511e+01, 1.2942e+03],
        [1.4484e+00, 8.3313e+00, 7.5066e+01],
        [1.2596e+00, 9.9029e+00, 1.0675e+02],
        [1.0565e+00, 1.8305e+01, 8.0378e+01],
        [1.4564e+00, 5.8130e+00, 1.1678e+02],
        [1.5647e+00, 5.7004e+00, 4.6897e+01],
        [1.3487e+00, 1.0736e+01, 1.4659e+02],
        [1.2174e+00, 1.1970e+01, 1.2751e+02],
        [1.0879e+00, 2.0438e+01, 1.4253e+02],
        [8.7526e-01, 1.1110e+01, 1.5366e+00],
        [1.4438e+00, 7.0980e+00, 1.0086e+02],
        [1.3864e+00, 7.5982e+00, 9.1879e+01],
        [1.1904e+00, 1.5023e+01, 2.2218e+02],
        [1.0461e+00, 2.4416e+01, 1