In this notebook, I will use torch_geometric to predict the development of a graph of positions through time

In [1]:
"""
I realized I am leaning towards this approach https://doi.org/10.1016/j.trc.2020.102635
"""

'\nI realized I am leaning towards this approach https://doi.org/10.1016/j.trc.2020.102635\n'

In [2]:
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

import pickle

import sys
import os
from genericpath import exists

from cell_dataset import CellGraphDataset, load
from cell_model import GraphEvolution, GraphEvolutionDiscr
from cell_utils import GraphingLoss
from cell_training import train, test_single, test_recursive, compute_parameters

import os, psutil
# Handle on the current Python process; reused further down to report memory usage.
process = psutil.Process(os.getpid())
print("Using : ", process.memory_info().rss // 1000000)  # resident set size, in megabytes
# `memory_info().vms` is the virtual size of this process, not free memory, so it
# does not match the "Available" label. Report the system-wide available RAM instead.
print("Available : ", psutil.virtual_memory().available // 1000000)  # in megabytes

# True when torch can see a CUDA device.
print(torch.cuda.is_available())

#https://github.com/clovaai/AdamP
from adamp import AdamP

# Optional dependency: make the local `allium` build importable and try to load it.
# NOTE(review): this is a machine-specific absolute path; consider making it configurable.
sys.path.append('/home/nstillman/1_sbi_activematter/cpp_model')
try:
    import allium
except Exception:
    # Best-effort import: the notebook keeps running without allium. Catching
    # `Exception` (not a bare `except`) lets KeyboardInterrupt and SystemExit
    # propagate instead of being silently swallowed.
    print("Could not import allium")

  from .autonotebook import tqdm as notebook_tqdm


Using :  290
Available :  3271
True
Could not import allium


The data is a graph of cells, each with its own position and velocity.

In the graph, we will start by connecting every pair of nodes (a fully connected graph); later we may switch to radius_graphs to reduce the cost of a pass through the model

In [3]:
# Run identifier: suffix shared by the dataset pickles and the model/loss files.
extension = "_open_ht_lv"

# Path prefixes; the epoch number and the file extension are appended on save/load.
model_path = f"models/model{extension}_"
loss_path = f"models/loss{extension}_"

# Flags forwarded, in this order, to `load` from cell_dataset.
load_all = True        # load directly from a pickle
pre_separated = False  # if three subfolders already exist for train, test and val
override = False       # make this true to always use the same ones

data_train, data_test, data_val = load(load_all, extension, pre_separated, override)

Validation data not found


In [4]:
#INFO : if bg_load is True, this starts the loading, if skipped, bg_loading will take place as soon as a get is called
# Fetch sample 0 from the training set and then from the test set. Both calls unpack
# into the same six names, so only the test-set values remain bound afterwards.
# NOTE(review): the meaning of each returned field is defined by CellGraphDataset.get,
# which is not visible in this notebook - confirm there before relying on them.
rval, edge_index, edge_attr, batch_edge, border, params = data_train.get(0)
rval, edge_index, edge_attr, batch_edge, border, params = data_test.get(0)

# Show the dataset's `wrap` attribute; the same value is passed to the model constructor later.
print("Is data wrapped ? ", data_train.wrap)

Is data wrapped ?  False


Next we need to define the model that will be used :
    > input 
        (1) Graph at a particular time t (nodes having x,y,dx,dy as attributes)
        (2) Graphs up to a particular time [t-a, t] (nodes having x,y as attributes)
    > output
        (a) Graph at the immediate next time step t+1
        (b) Graph [t, t+b]
        (c) Graph at t+b
    > graph size
        (x) Fixed graph size to the most nodes possible (or above)
        (y) Unbounded graph size
            >> idea : graph walks
            >> idea : sampler

The following model will do (1ax)

In [5]:
def start(model : GraphEvolution, optimizer : torch.optim.Optimizer, scheduler  : torch.optim.lr_scheduler._LRScheduler,
          data_train : CellGraphDataset, data_test : CellGraphDataset, device : torch.device, epoch : int, offset : int, grapher : GraphingLoss, save=0, save_datasets=True):
    """Run the train / evaluate / checkpoint loop for epochs ``offset`` .. ``offset + epoch - 1``.

    For every epoch the model is trained once with ``train``, then evaluated with
    ``test_single`` and ``test_recursive``. Both test losses are appended to
    ``grapher.losses`` (recursive first, then single) and the loss curves are re-plotted.

    Relies on the notebook-level globals ``process``, ``extension``, ``model_path``
    and ``loss_path``.

    Args:
        model: network to train; rebound each epoch to the value returned by ``train``.
        optimizer: optimizer forwarded to ``train``.
        scheduler: learning-rate scheduler forwarded to ``train``.
        data_train: dataset used for training.
        data_test: dataset used for evaluation and parameter plots.
        device: device forwarded to the training and evaluation helpers.
        epoch: number of epochs to run in this call.
        offset: index of the first epoch (number of epochs already completed).
        grapher: collects the losses and draws the plots.
        save: checkpoint period in epochs; a falsy value disables checkpointing.
        save_datasets: when true, pickle both datasets right after epoch 0.

    Returns:
        None. Results are exposed through ``grapher`` and the files written to disk.
    """
    # Index of the final epoch of this run. The loop starts at `offset`, so the
    # last index is offset + epoch - 1, not epoch - 1.
    last_epoch = offset + epoch - 1

    for e in range(offset, offset + epoch):

        # The recursive flag is only passed to `train` once the epoch index exceeds 10.
        recursive = e > 10

        model = train(model, optimizer, scheduler, data_train, device, e, process, max_epoch=offset+epoch, recursive=recursive)

        #model.show_gradients()

        if e == 0 and save_datasets:
            # Clear the `thread` attribute before pickling
            # (presumably the loader thread is not picklable - TODO confirm).
            data_train.thread = None
            data_test.thread = None
            with open("data/training" + extension + ".pkl", 'wb') as f:
                pickle.dump(data_train, f)
            # NOTE(review): the trailing space in "data/testing " looks like a typo, but the
            # reader of this file is not visible here, so the path is left unchanged.
            with open("data/testing " + extension + ".pkl", 'wb') as f:
                pickle.dump(data_test, f)
            print("Saved datasets")

        test_loss_s = test_single(model, data_test, device, duration=16)
        test_loss_r = test_recursive(model, data_test, device, duration=16)

        print("Epoch : ", e, "Test loss : ", test_loss_s, "Test loss recursive : ", test_loss_r)

        # Two entries per epoch: callers use len(losses) // 2 as the number of finished epochs.
        grapher.losses.append(test_loss_r)
        grapher.losses.append(test_loss_s)

        grapher.plot_losses()

        if e % 10 == 0:
            all_params_out, all_params_true = compute_parameters(model, data_test, device, duration=-1)
            grapher.plot_params(all_params_out, all_params_true, e, extension=extension)

        # Checkpoint every `save` epochs, and always on the final epoch of this run.
        if save and (e % save == 0 or e == last_epoch):
            torch.save(model.state_dict(), model_path + str(e) + ".pt")
            with open(loss_path + str(e) + ".pkl", 'wb') as f:
                pickle.dump(grapher.losses, f)

In [6]:
# Whether to resume from the checkpoint saved at `epoch_to_load`, if one exists.
# Named `load_checkpoint` rather than `load` so it does not shadow the `load`
# function imported from cell_dataset; shadowing it breaks a re-run of the
# dataset-loading cell.
load_checkpoint = True

epoch_to_load = 0

model = GraphEvolution(in_channels=14, out_channels=4, hidden_channels=32, dropout=0.05, edge_dim=2, messages=10, wrap=data_train.wrap)
#model = GraphEvolution(in_channels=9, out_channels=4, hidden_channels=32, dropout=0.01, edge_dim=2, messages=5, wrap=True)
#model = GraphEvolutionDiscr(in_channels=9, out_channels=4, hidden_channels=16, dropout=0.01, edge_dim=2, messages=5, wrap=True)

# Loss history; stays empty when no checkpoint is restored.
losses = []

checkpoint_file = model_path + str(epoch_to_load) + ".pt"
if load_checkpoint and exists(checkpoint_file):
    with open(loss_path + str(epoch_to_load) + ".pkl", 'rb') as f:
        losses = pickle.load(f)
    # The model is still on the CPU here (it is moved to `device` in a later cell).
    # map_location="cpu" lets a checkpoint saved from a GPU run restore on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_file, map_location="cpu"))
    print("Loaded model")

Loaded model


In [7]:
# Quick status report before training.
rss_megabytes = process.memory_info().rss // 1000000  # in megabytes
# Two loss entries are stored per epoch, so halve the count to get epochs.
epochs_recorded = len(losses) // 2

print("Using : ", rss_megabytes)
print("Losses : ", epochs_recorded)
print("Model : ", model)

Using :  651
Losses :  1
Model :  GraphEvolution(
  (encoder_resize): Linear(in_features=14, out_features=32, bias=True)
  (encoder_resize2): Linear(in_features=32, out_features=32, bias=True)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
        )
        (linear1): Linear(in_features=32, out_features=32, bias=True)
        (dropout): Dropout(p=0.05, inplace=False)
        (linear2): Linear(in_features=32, out_features=32, bias=True)
        (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.05, inplace=False)
        (dropout2): Dropout(p=0.05, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamical

In [8]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# AdamP optimizer (https://github.com/clovaai/AdamP); its settings may deserve further tuning.
optimizer = AdamP(
    model.parameters(),
    lr=5e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=5e-3,
    delta=0.1,
    wd_ratio=0.1,
    nesterov=True,
)

# Cosine annealing with warm restarts: first cycle T_0=10, period multiplier T_mult=2.
scheduler = CosineAnnealingWarmRestarts(
    optimizer=optimizer,
    T_0=10,
    T_mult=2,
    eta_min=1e-12,
)

In [9]:
# Plot helper seeded with the loss history restored from the checkpoint (empty on a fresh run).
grapher = GraphingLoss(losses)

# Number of epochs to run in this session.
epochs = 40

# Fast-forward the scheduler to the number of epochs already recorded
# (two loss entries are stored per epoch).
scheduler.step(len(losses) // 2)

model = model.to(device)

# Optional, currently disabled: move the datasets to the device as well.
#data_train.to(device)
#data_test.to(device)

# Optional, currently disabled: parameter plot for the loaded checkpoint.
#all_params_out, all_params_true = compute_parameters(model.to(device), data_test, device, duration=8)
#grapher.plot_params(all_params_out, all_params_true, epoch_to_load)

In [10]:
# Launch training, resuming from the number of epochs already in `losses`
# (two loss entries are stored per epoch). Checkpoint every 10 epochs.
start(
    model,
    optimizer,
    scheduler,
    data_train,
    data_test,
    device,
    epoch=epochs,
    offset=len(losses) // 2,
    grapher=grapher,
    save=10,
    save_datasets=False,
)

Epoch :  1 Test loss :  -4.114302656650543 Test loss recursive :  0.0024378935096319764
Epoch :  2 Test loss :  -4.064601757526398 Test loss recursive :  0.0028750400798162445
Epoch :  3 Test loss :  -3.99153635263443 Test loss recursive :  0.0033766000042669475
Epoch :  4 Test loss :  -4.1348529410362245 Test loss recursive :  0.00150968996953452
Epoch :  5 Test loss :  -4.1233647918701175 Test loss recursive :  0.001079014661081601
Epoch :  6 Test loss :  -4.170844030380249 Test loss recursive :  0.00036177697693347
Epoch :  7 Test loss :  -4.150682511329651 Test loss recursive :  0.0012480500277888495
Epoch :  8 Test loss :  -4.12133127450943 Test loss recursive :  0.0017036578743136487
Epoch :  9 Test loss :  -4.059589741230011 Test loss recursive :  0.0022280350839719177
Current probability of recursive training :  0
Epoch :  10 Test loss :  -4.0845848965644835 Test loss recursive :  0.0022622124815825373
Epoch :  11 Test loss :  -4.139437003135681 Test loss recursive :  0.0010523

KeyboardInterrupt: 