In this notebook, I will use torch_geometric to predict how a graph of cell positions develops through time.

In [1]:
import torch_geometric
import torch

from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

import numpy as np

from torch_geometric.utils import to_networkx, from_networkx
from torch_geometric.nn import radius_graph
from torch_geometric.data import Data

import networkx as nx

import glob

import pickle
import lzma

import time

import sys
import os

# Extend the module search path with two hard-coded directories on the remote machine.
# NOTE(review): presumably the `allium` package lives under one of these paths; the pickled
# simulation files need it at load time (see the ModuleNotFoundError further down when the
# notebook is run on a machine where these paths do not exist) - confirm and make configurable.
sys.path.append("/home/nstillman/0_miscellaneous/sbi_model1")
sys.path.append('/home/nstillman/1_sbi_activematter/cpp_model')
#import allium

  from .autonotebook import tqdm as notebook_tqdm


The data is a graph of cells, each with its own position and velocity.

In the graph, we will first start by connecting every pair of nodes, then maybe later build radius graphs (`radius_graph`) to reduce the cost of the pass through the model.

In [2]:
import os.path as osp

from torch_geometric.data import Dataset

#find /scratch/users/nstillman/data-cpp/train/ -name "*fast4p.p" -type f  | head | xargs du
"""
27804   ./p000_r004_sb015_2148010_b039_fast4p.p                                                                         
33133   ./p013_r000_sb009_2147623_b071_fast4p.p                                                                         
9628    ./p056_r000_sb040_2147623_b132_fast4p.p                                                                         
23679   ./p029_r003_sb006_2148010_b038_fast4p.p                                                                         
16676   ./p030_r000_sb038_2147623_b119_fast4p.p                                                                         
39793   ./p019_r001_sb013_2147250_b028_fast4p.p                                                                         
21402   ./p023_r000_sb018_2147623_b131_fast4p.p                                                                         
22863   ./p001_r002_sb001_2148010_b026_fast4p.p                                                                         
38547   ./p040_r000_sb000_2147623_b082_fast4p.p                                                                         
23422   ./p058_r000_sb020_2147623_b134_fast4p.p
"""

class CellGraphDataset(Dataset):
    """Lazily loaded dataset of pickled cell simulations (``*fast4p.p*`` files under ``root``).

    Items are read from disk on every :meth:`get` call; nothing is cached or
    pre-processed, which is why ``_download`` and ``_process`` are no-ops.

    :meth:`get` returns a ``(x, edge_index, edge_attr)`` tuple (not a ``Data``
    object), so call ``get``/``len`` directly as the rest of the notebook does.

    Args:
        root: Folder that is searched (non-recursively) for simulation files.
        max_size: Maximum number of files to keep; ``None`` keeps all of them.
        transform, pre_transform, pre_filter: Forwarded to the base ``Dataset``.
    """

    def __init__(self, root, max_size, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

        self.paths = self.each_path(root, max_size)

    def _download(self):
        # Nothing to download: files are expected to already be on disk.
        pass

    def _process(self):
        # Nothing to pre-process: files are parsed on demand in process_file().
        pass

    class each_path():
        """Holder for the list of absolute file paths backing the dataset."""

        def __init__(self, root, max_size):
            self.root = root
            self.max_size = max_size
            self._each_path = self.read()

        def read(self):
            """Return at most ``max_size`` absolute paths matching ``*fast4p.p*`` in ``root``."""
            # glob order is filesystem dependent; sort so the selection is reproducible.
            relative = sorted(s.replace('\\', '/') for s in glob.glob(str(self.root) + '/*fast4p.p*'))
            if self.max_size is not None:
                # Keep the first max_size files (the original used max(), which never truncated).
                relative = relative[:max(0, self.max_size)]
            return [osp.abspath(s) for s in relative]

        def fset(self, value):
            """Replace the stored list of paths."""
            self._each_path = value

        def fget(self):
            """Return the stored list of paths."""
            return self._each_path

    def process_file(self, path):
        """Load one simulation file and build per-time-step graph tensors.

        Args:
            path: Path to a ``.p`` (plain pickle) or ``.pz`` (lzma-compressed pickle) file.

        Returns:
            A tuple ``(rval, edge_index, edges_attr)`` where

            * ``rval`` is the position tensor with two extra trailing features holding the
              displacement since the previous time step (zero for the first step),
            * ``edge_index`` is a list with one ``radius_graph`` edge tensor per time step,
            * ``edges_attr`` is a list with one tensor of edge lengths per time step.

            The last two are lists because the number of edges can differ between time steps.

        Raises:
            ValueError: If the file extension is neither ``.p`` nor ``.pz``.
        """
        print(path)
        # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
        if path.endswith(".p"):
            with open(path, 'rb') as f:
                x = pickle.load(f)
        elif path.endswith(".pz"):
            with lzma.open(path, 'rb') as f:
                x = pickle.load(f)
        else:
            raise ValueError("File type not supported for path: " + path)

        # Other simulation parameters of interest, currently unused:
        #   attraction force      x.param.pairatt[0][0]
        #   persistence timescale x.param.tau[0]
        #   active force          x.param.factive[0]

        # Cutoff distance defines the interaction radius.
        cutoff = 2*(x.param.cutoffZ + 2*x.param.pairatt[0][0])

        # Position data, indexed as rval[t, i, :] = position of cell i at time t.
        # as_tensor also accepts array input; cast to the default float dtype used by the model.
        rval = torch.as_tensor(x.rval, dtype=torch.get_default_dtype())
        T = rval.shape[0]
        N = rval.shape[1]

        # One radius graph per time step.
        edge_index = [
            radius_graph(rval[t], r=cutoff, batch=None, loop=False, max_num_neighbors=100)
            for t in range(T)
        ]

        # Edge attribute: distance between the two endpoints of each edge at that time step.
        edges_attr = [
            torch.norm(rval[t, edges[0], :] - rval[t, edges[1], :], dim=1)
            for t, edges in enumerate(edge_index)
        ]

        # Append (dx, dy): displacement with respect to the previous time step.
        rval = torch.cat((rval, rval.new_zeros((T, N, 2))), dim=2)
        rval[1:, :, 2:4] = rval[1:, :, :2] - rval[:-1, :, :2]

        return rval, edge_index, edges_attr

    def len(self):
        """Number of simulation files in the dataset."""
        return len(self.paths.fget())

    def get(self, idx):
        """Load and return the ``(x, edge_index, edge_attr)`` tuple for file number ``idx``."""
        return self.process_file(self.paths.fget()[idx])

    def dump_source(self, path):
        """Pickle the current list of file paths to ``path``."""
        with open(path, 'wb') as f:
            pickle.dump(self.paths.fget(), f)

    def overwrite_source(self, path):
        """Replace the current list of file paths with the one pickled at ``path``."""
        with open(path, 'rb') as f:
            #set the property to the new list of paths
            self.paths.fset(pickle.load(f))

In [3]:
# Root folder containing the train / test / valid sub-folders.
path = "data/" #local
#path = "/scratch/users/nstillman/data-cpp/" #remote

# Build the three splits first ...
data_train = CellGraphDataset(root=path + 'train', max_size=36)
data_test = CellGraphDataset(root=path + 'test', max_size=10)
data_val = CellGraphDataset(root=path + 'valid', max_size=10)

# ... then report how many files each split picked up.
print("Training data length : ", data_train.len())
print("Test data length : ", data_test.len())
print("Validation data length : ", data_val.len())

Training data length :  36
Test data length :  10
Validation data length :  9


In [4]:
# Freeze the file selection of each split: on the first run the chosen paths are
# written to sources/, on later runs the saved selection replaces the fresh one.
if "sources" not in os.listdir():
    os.mkdir("sources")

for _dataset, _file_name, _file_path in (
    (data_train, "train_paths.pkl", "sources/train_paths.pkl"),
    (data_test, "test_paths.pkl", "sources/test_paths.pkl"),
    (data_val, "val_paths.pkl", "sources/val_paths.pkl"),
):
    if _file_name in os.listdir("sources"):
        #overwrite the paths to the previous configuration
        _dataset.overwrite_source(_file_path)
    else:
        #first time running, dump the paths to a pickle file
        _dataset.dump_source(_file_path)

In [6]:
# Smoke test: load the first training file. Unpickling it requires the `allium`
# module to be importable, otherwise this cell fails with ModuleNotFoundError.
data = data_train.get(0)


c:\Users\gille\Desktop\graph-displacement\data\train\p000_r000_sb000_2147955_b045_fast4p.p


ModuleNotFoundError: No module named 'allium'

Next we need to define the model that will be used :
    > input 
        (1) Graph at a particular time t (nodes having x,y,dx,dy as attributes)
        (2) Graphs up to a particular time [t-a, t] (nodes having x,y as attributes)
    > output
        (a) Graph at the immediate next time step t+1
        (b) Graph [t, t+b]
        (c) Graph at t+b
    > graph size
        (x) Fixed graph size to the most nodes possible (or above)
        (y) Unbounded graph size
            >> idea : graph walks
            >> idea : sampler

The following model implements combination (1ax): input (1), output (a) and graph size (x).

In [None]:
from torch_geometric.nn import GATv2Conv
import torch.nn.functional as F

class GraphEvolution(torch.nn.Module):
    """One-step predictor: a GATv2 layer followed by a linear read-out, added residually.

    Args:
        in_channels: Number of node features; must equal ``out_channels``.
        out_channels: Number of predicted node features.
        hidden_channels: Output size of each attention head.
        heads: Number of attention heads.
        dropout: Dropout probability passed to ``GATv2Conv``.
        edge_dim: Number of features per edge. ``GATv2Conv`` only accepts ``edge_attr``
            when it is built with ``edge_dim``; the default of 1 matches a scalar
            attribute per edge.

    Raises:
        ValueError: If ``in_channels != out_channels`` (required by the residual sum).
    """

    def __init__(self, in_channels, out_channels, hidden_channels, heads, dropout=0.0, edge_dim=1):
        super().__init__()

        #we want the channels to be (x, y, vx, vy)
        if (in_channels != out_channels):
            raise ValueError("in_channels must be equal to out_channels")

        self.conv  = GATv2Conv(in_channels, hidden_channels, heads=heads, dropout=dropout, edge_dim=edge_dim)
        self.post = torch.nn.Linear(heads*hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Predict the next state for every time step in ``x``.

        Args:
            x: Tensor of shape (T, N, in_channels).
            edge_index: Per-step edge indices, indexable by time step (tensor or list).
            edge_attr: Per-step edge attributes, indexable by time step (tensor or list).

        Returns:
            A new tensor with the same shape as ``x``; ``x`` itself is left untouched
            so callers can keep using slices of it as targets.
        """
        #here T is treated as a batch dimension
        predictions = []
        for i in range(x.shape[0]):
            # x[i] is a tensor of shape (N, in_channels)
            hidden = self.conv(x[i], edge_index[i], edge_attr[i])
            # hidden is a tensor of shape (N, heads*hidden_channels)
            hidden = F.elu(hidden)
            # different nodes are treated as batch entries by the linear layer
            delta = self.post(hidden)
            # residual connection: the network predicts the change of the state
            predictions.append(delta + x[i])

        if not predictions:
            # No time steps: return an (empty) copy rather than failing in torch.stack.
            return x.clone()
        return torch.stack(predictions, dim=0)

In [None]:
# Module-level history of test losses, appended to once per epoch and read by the plotter.
losses = list()

In [None]:
import random

def test(model, data, device) :
    model.eval()
    model = model.to(device)
    with torch.no_grad():
        loss_sum = 0
        for i in range(data.len()):
            x, edge_index, edge_attr = data.get(i)
            
            random_number = random.randint(1, x.shape[0]-2)
            
            xshape = x.shape
            
            y = x[random_number+1].to(device)
            x = x[random_number].to(device)
            ei = edge_index[random_number].to(device)
            ea = edge_attr[random_number].to(device)
            
            out = model(x.unsqueeze(0), ei.unsqueeze(0), ea.unsqueeze(0))
            
            loss = F.mse_loss(out, y.unsqueeze(0))
            
            loss_sum = loss_sum + loss.item()
            
        return loss_sum / data.len()

In [None]:
def train(model, optimizer, scheduler, data, device) :
    """Run one pass over ``data``, doing one optimisation step per file.

    For each trajectory the model receives steps ``1 .. T-2`` and is trained to
    reproduce steps ``2 .. T-1`` (step 0 has no velocity, the last step has no target).

    Args:
        model: Module called as ``model(x, edge_index, edge_attr)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped after every file.
        data: Object exposing ``len()`` and ``get(i) -> (x, edge_index, edge_attr)``;
            the edge entries may be stacked tensors or per-step lists of tensors.
        device: Device on which training runs.

    Returns:
        The model (moved to ``device``).
    """
    def to_device(item):
        # Per-step edge data may come as a list because edge counts vary over time.
        if torch.is_tensor(item):
            return item.to(device)
        return [step.to(device) for step in item]

    model.train()
    model = model.to(device)
    for i in range(data.len()):
        optimizer.zero_grad()

        x, edge_index, edge_attr = data.get(i)

        # Too short to form a single (input, target) pair.
        if x.shape[0] < 3:
            continue

        x = x.to(device)
        edge_index = to_device(edge_index)
        edge_attr = to_device(edge_attr)

        #we don't want to predict the last step since we wouldn't have the data for the loss
        #and for the first point we don't have the velocity
        # Clone the inputs: they overlap with the targets, so a model that writes
        # into its input would otherwise corrupt x[2:].
        inputs = x[1:-1].clone()
        targets = x[2:]
        out = model(inputs, edge_index[1:-1], edge_attr[1:-1])

        loss = F.mse_loss(out, targets)
        loss.backward()

        optimizer.step()

        # NOTE(review): the file index is used as the schedule position, so the
        # schedule replays from 0 on every call to train() - confirm this is intended.
        scheduler.step(i)

    return model

In [None]:
def start(model, optimizer, scheduler, data_train, data_test, device, epoch):
    """Alternate one training pass and one evaluation for ``epoch`` rounds.

    After each round the test loss is printed and appended to the module-level
    ``losses`` list. Nothing is returned.
    """
    for epoch_index in range(epoch):
        # one full pass over the training files
        model = train(model, optimizer, scheduler, data_train, device)

        # evaluate the freshly updated model and record the result
        current_loss = test(model, data_test, device)
        print("Epoch : ", epoch_index, "Test loss : ", current_loss)
        losses.append(current_loss)

In [None]:
# Node features in and out are (x, y, dx, dy); 8 attention heads of width 32.
model = GraphEvolution(in_channels=4, out_channels=4, hidden_channels=32, heads=8, dropout=0.0)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
# Cosine schedule with warm restarts; each period is twice as long as the previous one.
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

In [None]:
import threading
import matplotlib.pyplot as plt

class GraphingLoss():
    """Periodically plot the module-level ``losses`` list from a background timer.

    Call :meth:`gstart` to begin plotting every ``timer`` seconds and :meth:`gstop`
    to end it; stopping also cancels the pending timer.

    NOTE(review): plotting happens on a timer thread; matplotlib is not guaranteed
    to be thread-safe, so confirm this works with the backend in use.
    """

    def __init__(self):
        # Kept for backward compatibility; plotting reads the module-level `losses`.
        self.losses = []
        self.stop = False
        self.timer = 0
        # Currently scheduled threading.Timer, if any.
        self._timer = None

    def _schedule(self):
        # Daemon timer so a forgotten plotter never keeps the interpreter alive.
        pending = threading.Timer(self.timer, self.plot_and_reschedule)
        pending.daemon = True
        self._timer = pending
        pending.start()

    def plot_and_reschedule(self):
        """Plot the current losses and schedule the next plot unless stopped."""
        if self.stop:
            return
        plt.plot(losses)
        plt.show()
        if not self.stop:
            self._schedule()

    def gstop(self):
        """Stop plotting and cancel the pending timer."""
        self.stop = True
        if self._timer is not None:
            self._timer.cancel()

    def gstart(self, timer=10):
        """Start plotting every ``timer`` seconds.

        Raises:
            ValueError: If ``timer`` is not a positive integer value.
        """
        if (not timer or timer != int(timer) or timer <= 0):
            raise ValueError("timer must be a positive integer")

        # Restarting replaces any previously scheduled timer instead of stacking a second one.
        if self._timer is not None:
            self._timer.cancel()

        self.timer = timer
        self.stop = False
        self._schedule()

In [None]:
epochs = 100
device = "cuda" if torch.cuda.is_available() else "cpu"

grapher = GraphingLoss()
# gstart() returns nothing, so its result is not kept.
grapher.gstart(10)
try:
    start(model, optimizer, scheduler, data_train, data_test, device, epochs)
finally:
    # Always stop the background plotter, even if training is interrupted or fails.
    grapher.gstop()