In this notebook, I will use torch_geometric to predict the developpement of a graph of positions through time

In [8]:
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import torch.nn.functional as F

import numpy as np

import torch_geometric
from torch_geometric.utils import to_networkx, from_networkx
from torch_geometric.nn import radius_graph
from torch_geometric.data import Data

import networkx as nx

import pickle

import time

import sys
import os
from genericpath import exists

import random

model_path = "model.pkl"
loss_path = "loss.pkl"

from cell_dataset import CellGraphDataset
from cell_model import GraphEvolution
from cell_utils import GraphingLoss, make_animation
from cell_training import train, test, run_single, run_single_recursive

import threading
import matplotlib.pyplot as plt

import os, psutil
process = psutil.Process(os.getpid())
print("Using : ", process.memory_info().rss // 1000000)  # in megabytes 
print("Available : ", process.memory_info().vms  // 1000000)  # in megabytes 

print(torch.cuda.is_available())

#https://github.com/clovaai/AdamP
from adamp import AdamP

sys.path.append('/home/nstillman/1_sbi_activematter/cpp_model')
import allium

Using :  2048
Available :  14631
True


The data is a graph of cells having their own positions and velocity.

In the graph, we will first start by connecting all the edges, then maybe later make radius_graphs to reduce the cost of the pass through the model

In [9]:
#path = "data/" #local
path = "/scratch/users/nstillman/data-cpp/" #remote

data_train = CellGraphDataset(root=path + 'train', max_size=200, rdts=True, inmemory=True, bg_load=True)
print("Training data length : ", data_train.len())

data_test = CellGraphDataset(root=path + 'test', max_size=50, inmemory=True, bg_load=True)
print("Test data length : ", data_test.len())
 
data_val = CellGraphDataset(root=path + 'valid', max_size=50, inmemory=True, bg_load=True)
print("Validation data length : ", data_val.len())

Training data length :  200
Test data length :  50
Validation data length :  50


In [10]:
override = True #make this true to always use the same ones

if override :
    data_train.save_or_load_if_exists("train_paths.pkl")
    data_test.save_or_load_if_exists("test_paths.pkl")
    data_val.save_or_load_if_exists("val_paths.pkl")


In [11]:
#INFO : if bg_load is True, this starts the loading, if skipped, bg_loading will take place as soon as a get is called
rval, edge_index, edge_attr, batch_edge, norm_and_std = data_train.get(0)

In [12]:
print(rval.mean(dim=(0,1)))
print(rval.var(dim=(0,1)))

tensor([ 0.0531, -0.0232, -0.0011, -0.0007])
tensor([1.4129, 1.4271, 0.0219, 0.0194])


Next we need to define the model that will be used :
    > input 
        (1) Graph at a particular time t (nodes having x,y,dx,dy as attributes)
        (2) Graphs up to a particular time [t-a, t] (nodes having x,y as attributes)
    > output
        (a) Graph at the immediate next time step t+1
        (b) Graph [t, t+b]
        (c) Graph at t+b
    > graph size
        (x) Fixed graph size to the most nodes possible (or above)
        (y) Unbounded graph size
            >> idea : graph walks
            >> idea : sampler

The following model will do (1ax)

In [4]:
global losses
losses = []

In [5]:
def start(model, optimizer, scheduler, data_train, data_test, device, epoch, offset, save=0, early_stop=False):
    for e in range(offset, offset + epoch):
        model = train(model, optimizer, scheduler, data_train, device, e, process)
            
        test_loss = test(model, data_test, device)
        
        if (e%10 == 0) :
            print("Epoch : ", e, "Test loss : ", test_loss, "                                                         ")
        
        losses.append(test_loss)
        
        if early_stop and len(losses) > 30 :
            recent_losses = min(len(losses), 30)
            y = losses[-recent_losses:]
            
            axis = np.arange(recent_losses)
            A = np.vstack([axis, np.ones(len(axis))]).T
            
            a = np.linalg.lstsq(A, y, rcond=None)
            
            if a[0][0] >= -0.002 :
                print("Early stopping : recent slope at ", a[0][0])
                if (save) :
                    with open(model_path, 'wb') as f:
                        pickle.dump(model, f)
                return
            else : 
                print("Early stopping passed : current slope at ", a[0][0])
        
        if (save and (e%save == 0 or e == epoch-1)) :
            with open(model_path, 'wb') as f:
                pickle.dump(model, f)
            with open(loss_path, 'wb') as f:
                pickle.dump(losses, f)

In [6]:
load = True
if (load and exists(model_path)) :
    with open(model_path, 'rb') as f:
        model = pickle.load(f)
    with open(loss_path, 'rb') as f:
        losses = pickle.load(f)
else :
    model = GraphEvolution(in_channels=4, out_channels=4, hidden_channels=32, heads=7, dropout=0.01, edge_dim=len(data_train.attributes))
    losses = []
    
assert isinstance(model, GraphEvolution)

#might want to investigate AdamP 
optimizer = AdamP(model.parameters(), lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=5e-3, delta=0.1, wd_ratio=0.1, nesterov=True)
scheduler = CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=10, T_mult=2, eta_min=1e-5)

In [None]:
epochs = 501
grapher = GraphingLoss(losses)
scheduler.step(len(losses))
try :
    grapher.gstart(20)
    start(model, optimizer, scheduler, data_train, data_test, "cuda" if torch.cuda.is_available() else "cpu", epochs, len(losses), save=50, early_stop = False)
finally :
    grapher.gstop()

Epoch :  40 Test loss :  0.039442543145269154                                                               
Epoch :  50 Test loss :  0.03852382559096441                                                                
Epoch :  60 Test loss :  0.0352633522124961                                                                 
Epoch :  70 Test loss :  0.038478864203207196                                                               
Epoch :  80 Test loss :  0.035006512058898806                                                               
Epoch :  90 Test loss :  0.034492358285933736                                                               
Epoch :  100 Test loss :  0.033827737369574604                                                              
Epoch :  110 Test loss :  0.03475434453226626                                                               
Epoch :  120 Test loss :  0.03620127455797047                                                               
Epoch :  130 Test l

In [None]:
#things to do :
    #make sure this model doesn't use past times (normally should not since the graph is disconnected and it's based on message passing)