In [1]:
%load_ext autoreload
%autoreload 2

In [5]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as pltcolors

import torch as t
import torch_geometric as ptgeo

from torch_geometric.utils import to_networkx, from_networkx
import networkx as nx

from gninvert.functions import run_GN, gn_time_series, make_color_scale,\
gdisplay, run_and_draw, generate_training_data,\
generate_graphs_from_connections, generate_grid_edge_index
from gninvert.gns import SingleDiffusionGN, MultiDiffusionGN, EquationGN, ActivatorInhibitorGN, FullActInhGN
from gninvert.gnns import LinearGNN, GNN_full
from gninvert.graph_compare import graph_compare, model_compare, model_steps_compare
from gninvert.data_generation import get_TrainingData
from gninvert.gnns import GNN_3Layer
from gninvert.training import fit
from gninvert.hyperparamsearch import hpsearch
import gninvert

import ipywidgets as widgets

# Default width/height (inches) for every matplotlib figure drawn in this notebook.
plt.rcParams.update({'figure.figsize': [10, 4]})


In [7]:
# Candidate GN models; the "Select GN" cell below chooses which one is simulated.

# Diffusion GNs: one diffusion constant, then lists of 2 and 3 equal constants,
# then 3 distinct constants.
gn_d1 = SingleDiffusionGN(diffusion_constant=0.1)
gn_d2 = MultiDiffusionGN(diffusion_constants=[0.1] * 2)
gn_d3 = MultiDiffusionGN(diffusion_constants=[0.1] * 3)
gn_d3v = MultiDiffusionGN(diffusion_constants=[0.15, 0.1, 0.05])

# Activator-inhibitor GNs: the simple variant and the fully parameterised one.
gn_as = ActivatorInhibitorGN(
    act_diff_const=0.1,
    inh_diff_const=0.05,
    growth_const=0.05,
)
gn_af = FullActInhGN(
    spatial_const=10,
    temporal_const=0.01,
    growth_alpha=10,
    growth_rho=1,
    growth_scale=0.05,
    reaction_const=0.2,
    reference_conc=2,
)

# Select the GN to simulate

In [17]:
# The GN simulated and visualised by every cell below. To try another model,
# assign one of the instances defined above instead:
# gn_d1, gn_d2, gn_d3, gn_d3v, gn_as, gn_af.
gn = gn_as

---
...then run all cells below to simulate the selected GN and browse its time series:

In [22]:
# Simulation settings.
GRID_SIZE = 6     # passed to generate_grid_edge_index (presumably a GRID_SIZE x GRID_SIZE grid)
N_STEPS = 50      # number of GN steps passed to gn_time_series
# Set to True to overwrite the generated node features with a hand-crafted
# initial condition (homogeneous cell sizes, activator/inhibitor seeded near
# node 0). Only applied to activator-inhibitor GNs. Off by default.
USE_CORNER_INIT = False

# Build one graph on the grid, with as many node features as the selected GN uses.
# NOTE(review): if generate_graphs_from_connections draws random features,
# seed torch/numpy in the config cell so Restart & Run All is reproducible.
gdata = generate_graphs_from_connections(
    generate_grid_edge_index(GRID_SIZE),
    node_feature_num=gn.node_features,
    num=1)[0]

# Heuristic: a 3-feature GN without `diffusion_constants` is treated as an
# activator-inhibitor model (feature 0 = cell size, features 1/2 = chemicals).
# NOTE(review): any other 3-feature GN lacking that attribute would also match;
# an isinstance check against ActivatorInhibitorGN / FullActInhGN may be safer.
is_act_inh_gn = gn.node_features == 3 and not hasattr(gn, 'diffusion_constants')

if is_act_inh_gn and USE_CORNER_INIT:
    # make the cell sizes homogeneous:
    gdata.x[:, 0] = 0.1
    # activator/inhibitor all start in one corner:
    gdata.x[:, 1] = 0
    gdata.x[:, 2] = 0
    gdata.x[0, 1] = 1
    gdata.x[1, 1] = 0.5
    gdata.x[1, 2] = 0.4
    gdata.x[2, 2] = 0.4

# Run the GN forward; the result is indexed by frame in the viewer cell below.
time_series = gn_time_series(gn, N_STEPS, gdata)

In [23]:
def f(x):
    """Draw frame `x` of `time_series` and print its node-feature matrix.

    Used as the slider callback below; `x` is the frame index.

    Colour scales: for activator-inhibitor GNs, feature 0 uses plasma over
    [0, 3] and features 1 and 2 use BuGn and Reds over [0, 1]. For any other
    GN, every feature uses cividis over [0, 1].
    """
    if is_act_inh_gn:
        color_scales = [
            make_color_scale(0, 3, plt.cm.plasma),  # feature 0 (cell size)
            make_color_scale(0, 1, plt.cm.BuGn),    # feature 1 (presumably activator)
            make_color_scale(0, 1, plt.cm.Reds),    # feature 2 (presumably inhibitor)
        ]
    else:
        color_scales = [make_color_scale(0, 1, plt.cm.cividis)
                        for _ in range(gn.node_features)]
    frame = time_series[x]
    gdisplay(frame, color_scales=color_scales)
    print(frame.x)


# Trailing `;` suppresses the `<function __main__.f(x)>` repr that
# widgets.interact would otherwise echo as this cell's output.
widgets.interact(
    f,
    x=widgets.IntSlider(min=0, max=len(time_series) - 1, step=1, value=0),
);

interactive(children=(IntSlider(value=0, description='x', max=50), Output()), _dom_classes=('widget-interact',…

<function __main__.f(x)>