In this notebook, I use `torch_geometric` to predict how a graph of cell positions evolves through time.

In [66]:
import torch_geometric
import torch

import numpy as np

from torch_geometric.utils import to_networkx, from_networkx
from torch_geometric.nn import radius_graph
from torch_geometric.data import Data

import networkx as nx

import glob

import pickle
import lzma

import time

import sys

sys.path.append('/home/nstillman/1_sbi_activematter/cpp_model')
import allium

The data are graphs of cells, where each cell has its own position and velocity.

We first build a fully connected graph, with an edge between every pair of cells. Later we may switch to radius graphs, which keep only edges between cells closer than an interaction cutoff, to reduce the cost of a forward pass through the model.

In [67]:
import os.path as osp

from torch_geometric.data import Dataset

#find /scratch/users/nstillman/data-cpp/train/ -name "*fast4p.p" -type f  | head | xargs du
"""
27804   ./p000_r004_sb015_2148010_b039_fast4p.p                                                                         
33133   ./p013_r000_sb009_2147623_b071_fast4p.p                                                                         
9628    ./p056_r000_sb040_2147623_b132_fast4p.p                                                                         
23679   ./p029_r003_sb006_2148010_b038_fast4p.p                                                                         
16676   ./p030_r000_sb038_2147623_b119_fast4p.p                                                                         
39793   ./p019_r001_sb013_2147250_b028_fast4p.p                                                                         
21402   ./p023_r000_sb018_2147623_b131_fast4p.p                                                                         
22863   ./p001_r002_sb001_2148010_b026_fast4p.p                                                                         
38547   ./p040_r000_sb000_2147623_b082_fast4p.p                                                                         
23422   ./p058_r000_sb020_2147623_b134_fast4p.p
"""

class CellGraphDataset(Dataset):
    """Lazily-loaded dataset of cell-position graphs, one simulation file per item.

    Every file under ``root`` matching ``*fast4p.p*`` is one simulation. It is
    either a plain pickle (``.p``) or an lzma-compressed pickle (``.pz``).
    ``get(idx)`` loads the idx-th file (in sorted path order) and returns one
    ``Data`` graph per time step.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    # No-op overrides of torch_geometric's internal download/process hooks.
    # Nothing is downloaded or pre-processed; files are read on demand in get().
    def _download(self):
        pass

    def _process(self):
        pass

    @property
    def each_path(self):
        """Sorted absolute paths of the simulation files found under ``root``.

        glob() returns entries in filesystem-dependent order. The list is sorted
        so that the idx -> file mapping used by get()/len() is reproducible
        across machines.
        """
        pattern = osp.join(str(self.root), '*fast4p.p*')
        return sorted(osp.abspath(p) for p in glob.glob(pattern))

    def process_file(self, path):
        """Load one simulation file and turn it into a list of graphs.

        Args:
            path: path to a ``.p`` (pickle) or ``.pz`` (lzma-compressed pickle)
                simulation file.

        Returns:
            A list with one ``Data`` per time step ``t``. Each ``Data`` has:
              * ``x`` = ``rval[t, :, :]``, the positions of the N cells;
              * ``edge_index`` = the complete directed graph without
                self-loops, shape ``(2, N*(N-1))``;
              * ``edge_attr`` = the Euclidean distance of every edge at time
                ``t``, shape ``(N*(N-1),)``, in the same order as
                ``edge_index``.

        Raises:
            ValueError: if ``path`` ends neither in ``.p`` nor in ``.pz``.
        """
        print(path)

        # Pick the opener first so exactly one branch runs. The former pair of
        # independent `if`s made every successfully loaded `.p` file fall into
        # the `else` and raise.
        if path.endswith(".p"):
            opener = open
        elif path.endswith(".pz"):
            opener = lzma.open
        else:
            raise ValueError("File type not supported for path: " + path)

        # SECURITY: pickle.load can execute arbitrary code; only load trusted
        # files. Unpickling also needs the simulation's classes importable: the
        # notebook output shows it failing with "No module named 'allium'".
        with opener(path, 'rb') as f:
            x = pickle.load(f)

        # Parameters of interest. They are read here but not used below yet.
        # Attraction force:
        epsilon = x.param.pairatt[0][0]
        # Persistence timescale:
        tau = x.param.tau[0]
        # Active force:
        v0 = x.param.factive[0]

        # The cutoff distance defines the interaction radius. It is only needed
        # by the (currently disabled) radius_graph variant below.
        cutoff = 2 * (x.param.cutoffZ + 2 * x.param.pairatt[0][0])

        # Position data: rval[t, i, :] is the position of cell i at time t.
        rval = x.rval
        # Number of time steps and number of cells.
        T = rval.shape[0]
        N = rval.shape[1]

        # Ideally only edges within `cutoff` would be kept, e.g.
        #   radius_graph(pos, r=cutoff, batch=None, loop=False, max_num_neighbors=100)
        # but that reportedly did not work on GPU, so for now the graph is
        # fully connected.
        # nonzero() walks the off-diagonal mask in row-major order:
        #   (0, 1), (0, 2), ..., (1, 0), (1, 2), ...
        # That is the same (i, j) order as `for i ... for j ... if i != j`.
        # For N <= 1 this gives a correctly shaped (2, 0) tensor.
        edges = (~torch.eye(N, dtype=torch.bool)).nonzero().t().contiguous()
        src = edges[0].numpy()
        dst = edges[1].numpy()

        geom = []
        for t in range(T):
            frame = np.asarray(rval[t, :, :])
            # Distance of every edge at time t, computed for all edges at once.
            # The previous (E, T)-transposed tensor indexed with [t, :] gave
            # edge t's distances over time instead.
            dist = np.linalg.norm(frame[src] - frame[dst], axis=-1)
            geom.append(
                Data(x=torch.tensor(frame), edge_index=edges, edge_attr=torch.as_tensor(dist))
            )

        return geom

    def len(self):
        """Number of simulation files found under ``root``."""
        return len(self.each_path)

    def get(self, idx):
        """Return the list of per-time-step graphs for the idx-th file of each_path.

        NOTE(review): this returns a list rather than a single ``Data``. A
        ``transform`` configured on the dataset would presumably receive the
        whole list; confirm this is intended.
        """
        return self.process_file(self.each_path[idx])

In [68]:
# Alternative data root (presumably the cluster copy): '/scratch/users/nstillman/data-cpp/train/'
def _load_split(split, label):
    """Build the CellGraphDataset rooted at ../data/<split>/ and print its size."""
    dataset = CellGraphDataset(root=f'../data/{split}/')
    print(f"{label} data length : ", dataset.len())
    return dataset


data_train = _load_split('train', "Training")
data_test = _load_split('test', "Test")
data_val = _load_split('valid', "Validation")

Training data length :  36
Test data length :  10
Validation data length :  9


In [69]:
# Load the first training simulation (position 0 in data_train.each_path).
data = data_train.get(idx=0)


c:\Users\gille\Desktop\data\train\p000_r000_sb000_2147955_b045_fast4p.p


ModuleNotFoundError: No module named 'allium'