In [172]:
import numpy as np
from matplotlib import pyplot as plt
from ipywidgets import interact

import torch
from torch import nn
from torchsde import BrownianInterval, sdeint

import dgl
import dgl.function as fn

import networkx as nx

import os
import gc
import pickle

# Device settings. The original runs used the second GPU ('cuda:1'); fall back to the
# first GPU when only one is present (instead of failing with an invalid device ordinal),
# and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')


In [173]:
def printNPZ(npz):
    """Print every entry of a loaded ``.npz`` archive, one ``<name> <value>`` line each.

    Parameters
    ----------
    npz : numpy.lib.npyio.NpzFile
        Result of ``np.load`` on an ``.npz`` file (anything exposing a ``files``
        list of keys plus item access by key works). Entries are printed in
        ``npz.files`` order.
    """
    keys = list(npz.files)
    for key in keys:
        stored = npz[key]
        print(key, stored)

In [174]:
# Ground-truth simulation directory and the notebook root directory.
# NOTE(review): absolute, user-specific paths — adjust for your machine.
dirName = '/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/'
homeName = '/home/uwamichi/jupyter/'

# Parameters the ground-truth simulation was generated with.
params = np.load(os.path.join(dirName, 'params.npz'))

printNPZ(params)
#printNPZ(traj)

v0 1.0
r 1.0
D 0.1
A 0.0
L 20
rho 1.0
beta 1.0
A_CFs [0.9 0.5]
A_CIL 0.0
cellType_ratio [0.7 0.3]
quiv_colors ['k' 'r']
kappa 0.5
A_Macdonalds [0.5 0.5]
batch_size 400
state_size 3
brownian_size 1
periodic True
t_max 1000
methodSDE heun
isIto False
stepSDE 0.01


In [175]:
# Sub-directories of the dataset directory, in the order os.scandir yields them.
# NOTE(review): that order is filesystem-dependent, and i_truth below indexes
# datadir_list by position. If learn_params' train/valid/test indices also refer to
# this list, sorting here would have to be done together with the training code.
with os.scandir(dirName) as entries:
    subdir_list = [entry.path for entry in entries if entry.is_dir()]

print(subdir_list)

# Keep only sub-directories that contain a simulation result file.
datadir_list = [d for d in subdir_list if 'result.npz' in os.listdir(d)]

print(datadir_list)

['/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/20221026_220625', '/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/20221027_020808', '/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/20221027_061543', '/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/20221027_102254', '/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/20221027_144800', '/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/ActiveNet_vp_rotsym_batchNorm', '/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/20221027_190650', '/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/ActiveNet_vp_rotsym_multiStep_transfer_batchNorm', '/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/ActiveNet_vp_rotsym_multiStep_fineTuning_batchNorm', '/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/20221114_163735', '/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/20221114_205206', '/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/20221115_004702', '/home/uwamichi/j

In [176]:
# Trained model family to simulate.
modeldirName = dirName + 'ActiveNet_vp_rotsym_noSelfLoop_multiStep_fineTuning_batchNorm/'

i_model = -1  # index into the sorted list of *_Model.pt files (-1 = last)

# One list of *_Model.pt paths per directory visited by os.walk.
model_files_list = [[os.path.join(c, ff) for ff in f if ff.endswith('_Model.pt')]
                    for c, s, f in os.walk(modeldirName)]
print(model_files_list)

# Flatten and sort. os.walk order is filesystem-dependent, so without sorting i_model
# would not select a reproducible file. When run directories are named
# YYYYMMDD_HHMMSS (as in the runs seen here), sorted order is chronological.
model_files = sorted(path for group in model_files_list for path in group)
print(model_files)
if not model_files:
    raise FileNotFoundError('no *_Model.pt files found under ' + modeldirName)

model_dir = os.path.dirname(model_files[i_model])

# Path of the chosen model relative to modeldirName, without the '_Model.pt' suffix;
# used below to locate the model and its sibling *_Separation.npz file.
model_name = model_files[i_model].replace('_Model.pt', '').replace(modeldirName, '')
print(model_name)

[[], ['/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/ActiveNet_vp_rotsym_noSelfLoop_multiStep_fineTuning_batchNorm/20230304_132719/20230304_132719_Model.pt']]
['/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/ActiveNet_vp_rotsym_noSelfLoop_multiStep_fineTuning_batchNorm/20230304_132719/20230304_132719_Model.pt']
20230304_132719/20230304_132719


In [177]:
# Whether the rollout starts from a frame of a ground-truth trajectory.
load_initState = True
initTime = 0  # index of the ground-truth frame used as the initial condition

if not load_initState:
    savedirStr = model_dir + '/'
else:
    i_truth = 0  # which entry of datadir_list supplies the initial state
    traj = np.load(datadir_list[i_truth] + '/result.npz')
    printNPZ(traj)
    # Mirror the data directory's path (relative to homeName) under the model directory.
    savedirStr = os.path.join(model_dir, datadir_list[i_truth].replace(homeName, '') + '/')

print(savedirStr)

xy [[[ 2.7825236  10.17232   ]
  [ 3.5684276   8.793808  ]
  [ 3.642801   17.762491  ]
  ...
  [15.283232    8.7073345 ]
  [ 1.0766518  18.73103   ]
  [ 0.5558777   4.839964  ]]

 [[ 2.5694928   9.920237  ]
  [ 2.9387658   7.848478  ]
  [ 3.368579   17.33183   ]
  ...
  [16.0392      9.724034  ]
  [ 1.6472214  18.467205  ]
  [ 0.456154    5.886875  ]]

 [[ 2.095088    9.80876   ]
  [ 2.0751727   7.4935527 ]
  [ 3.0540552  16.295012  ]
  ...
  [16.967674   10.063215  ]
  [ 2.24791    18.004068  ]
  [ 0.5111869   6.8831882 ]]

 ...

 [[ 0.9594116  14.447151  ]
  [16.365234    7.797249  ]
  [ 8.001495   18.281395  ]
  ...
  [ 4.8424225  15.38649   ]
  [ 4.554535   14.569389  ]
  [17.788956    6.647293  ]]

 [[ 0.67907715 14.403763  ]
  [16.566132    7.063217  ]
  [ 8.017166   17.700539  ]
  ...
  [ 4.755966   15.080368  ]
  [ 4.269821   14.435646  ]
  [18.349945    5.7295227 ]]

 [[ 0.64634705 14.523315  ]
  [16.50682     6.1477737 ]
  [ 7.859291   17.04631   ]
  ...
  [ 4.547867   14.831

In [178]:
# Side length of the square simulation box as a 0-d tensor. Defined for both boundary
# conditions because later cells also use L for plotting (L.item(), axis limits) and
# for random initial positions, not only for the periodic wrap-around.
L = torch.tensor(params['L'])

if params['periodic']:
    def calc_dr(r1, r2):
        """Displacement ``r1 - r2`` under periodic boundaries (minimum image).

        Each component is first wrapped into ``[0, L)`` and then values above ``L/2``
        are shifted down by ``L``, giving components in ``(-L/2, L/2]``.
        Broadcasts like ``r1 - r2``.
        """
        dr = torch.remainder(r1 - r2, L)
        # torch.where instead of masked in-place assignment: same values, no in-place op.
        return torch.where(dr > L/2, dr - L, dr)
else:
    def calc_dr(r1, r2):
        """Displacement ``r1 - r2`` for open (non-periodic) boundaries."""
        return r1 - r2


def makeGraph(x_data, r_thresh):
    """Build a DGL graph linking every pair of distinct, nearby cells.

    Parameters
    ----------
    x_data : torch.Tensor
        Cell positions, one row per node, shape (N, dim).
    r_thresh : float
        Cutoff parameter. NOTE(review): the test below is on the *squared* distance,
        ``|dr|^2 < r_thresh/2``, which is an effective radius of ``sqrt(r_thresh/2)``
        rather than ``|dr| < r_thresh``. It is left unchanged because the trained model
        was presumably built with this same rule; confirm against the training code
        before changing it.

    Returns
    -------
    dgl.DGLGraph
        Graph with N nodes and an edge i -> j for every ordered pair with
        ``0 < |calc_dr(x_j, x_i)|^2 < r_thresh/2``. Pairs at zero distance, including
        self-pairs, get no edge. Both directions are present because the distance
        is symmetric.
    """
    Ndata = x_data.size(0)
    # dx[i, j] = calc_dr(x_j, x_i); then squared Euclidean norm over the coordinate axis.
    dx = calc_dr(torch.unsqueeze(x_data, 0), torch.unsqueeze(x_data, 1))
    dx = torch.sum(dx**2, dim=2)
    edges = torch.argwhere(torch.logical_and(dx > 0, dx < r_thresh/2))
    return dgl.graph((edges[:, 0], edges[:, 1]), num_nodes=Ndata)

In [179]:
class NeuralNet(nn.Module):
    """Four-layer MLP used for the pairwise interaction heads.

    Layout: three blocks of ``Linear -> ReLU [-> BatchNorm1d] [-> Dropout]`` of width
    ``Nchannels``, followed by a final ``Linear`` to ``out_channels``.

    The attribute names (``layer1``..``layer4``, ``bNorm1``..``bNorm3``, ``dropout``,
    ``batchN``, ``activation``) appear in the saved model's state (e.g.
    ``interactNN.layer1.weight``), so they must not be renamed.

    Parameters
    ----------
    in_channels, out_channels : int
        Input / output feature sizes.
    Nchannels : int
        Hidden width.
    dropout : float
        Dropout probability; a falsy value disables dropout.
    batchN : bool
        Apply BatchNorm1d after each hidden activation.
    flgBias : bool
        Whether the Linear layers carry a bias.
    """

    def __init__(self, in_channels, out_channels, Nchannels, dropout=0, batchN=False, flgBias=False):
        super().__init__()

        # 0 (falsy) disables dropout; forward() checks the truthiness of this attribute.
        self.dropout = nn.Dropout(p=dropout) if dropout else 0

        if batchN:
            for k in (1, 2, 3):
                setattr(self, f'bNorm{k}', nn.BatchNorm1d(Nchannels))
        self.batchN = batchN

        # Created in order layer1..layer4 so parameter order and RNG use stay the same.
        widths = [in_channels, Nchannels, Nchannels, Nchannels, out_channels]
        for k in range(1, 5):
            setattr(self, f'layer{k}', nn.Linear(widths[k - 1], widths[k], bias=flgBias))

        self.activation = nn.ReLU()

    def reset_parameters(self):
        """Re-initialise the four Linear layers (batch-norm modules are left as they are)."""
        for k in range(1, 5):
            getattr(self, f'layer{k}').reset_parameters()

    def forward(self, x):
        """Apply the three hidden blocks, then the output Linear layer."""
        h = x
        for k in (1, 2, 3):
            h = self.activation(getattr(self, f'layer{k}')(h))
            if self.batchN:
                h = getattr(self, f'bNorm{k}')(h)
            if self.dropout:
                h = self.dropout(h)
        return self.layer4(h)

class ActiveNet(nn.Module):
    """Graph neural network giving each cell's rate of change of (position, theta).

    Neighbours are found with ``makeGraph`` (cutoff derived from ``self.r``). For each
    edge, the relative position and relative heading are expressed in the body frame
    of the receiving (dst) cell, i.e. rotated by its heading theta:

    * ``interactNN`` predicts a velocity contribution in that body frame, which is
      rotated back to the lab frame;
    * ``thetaDotNN`` predicts a contribution to the heading rate.

    Messages are summed over incoming edges. A learned self-propulsion speed
    ``selfpropel`` along each cell's own heading is then added to the velocity part.

    NOTE(review): features are built by slicing ``[..., :1]`` / ``[..., 1:]``, so the
    code assumes ``xy_dim == 2``.

    Parameters
    ----------
    xy_dim : int
        Number of spatial coordinates per cell.
    r : float
        Interaction-radius parameter passed (halved) to ``makeGraph`` in ``forward``.
    dropout, batchN, bias, Nchannels
        Forwarded to both ``NeuralNet`` heads.
    """
    def __init__(self, xy_dim, r, dropout=0, batchN=False, bias=False, Nchannels=128):
        super().__init__()

        # Input features: 2 relative-position + 2 relative-heading components
        # + dst and src cell types (6 for xy_dim == 2); output: body-frame velocity.
        self.interactNN = NeuralNet(xy_dim*2 + 2, xy_dim, Nchannels, dropout, batchN, bias)

        # Same input features; output: scalar heading-rate contribution.
        self.thetaDotNN = NeuralNet(xy_dim*2 + 2, 1, Nchannels, dropout, batchN, bias)
        
        # Learned self-propulsion speed (re-initialised in reset_parameters).
        self.selfpropel = nn.Parameter(torch.tensor(0.0, requires_grad=True, device=device))

        #self.Normalizer = nn.Softmax(dim=1)

        self.xy_dim = xy_dim
        
        self.r = r

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both MLP heads and draw selfpropel from U(0, 1) (nn.init.uniform_ defaults)."""
        self.interactNN.reset_parameters()

        self.thetaDotNN.reset_parameters()
        
        nn.init.uniform_(self.selfpropel)

        #self.bias.data.zero_()
        
    def load_celltypes(self, celltype):
        """Store per-cell type labels, used as node feature 'type' in forward().

        Must be called before forward(). Presumably a (N, 1) float tensor on the
        model's device, as built in the cell-type assignment cell — confirm.
        """
        self.celltype = celltype

    def calc_message(self, edges):
        """DGL message function: returns {'m': [v_x, v_y, theta_rate]} per edge."""
        # Relative position of the receiving cell w.r.t. the sending cell (lab frame).
        dx = calc_dr(edges.dst['x'], edges.src['x'])

        costheta = torch.cos(edges.dst['theta'])
        sintheta = torch.sin(edges.dst['theta'])

        # Rotate dx into the receiver's body frame: components parallel / perpendicular
        # to the receiver's heading.
        dx_para = costheta * dx[..., :1] + sintheta * dx[..., 1:]
        dx_perp = costheta * dx[..., 1:] - sintheta * dx[..., :1]

        # Sender's heading expressed relative to the receiver's heading.
        p_para_src = torch.cos(edges.src['theta'] - edges.dst['theta'])
        p_perp_src = torch.sin(edges.src['theta'] - edges.dst['theta'])

        rot_m_v = self.interactNN(torch.concat((dx_para, dx_perp, 
                                                p_para_src, p_perp_src,
                                                edges.dst['type'], edges.src['type']), -1))

        # Rotate the predicted body-frame velocity back to the lab frame.
        m_v = torch.concat((costheta * rot_m_v[..., :1] - sintheta * rot_m_v[..., 1:],
                            costheta * rot_m_v[..., 1:] + sintheta * rot_m_v[..., :1]), -1)

        # Heading-rate contribution (a scalar, so no rotation needed).
        m_theta = self.thetaDotNN(torch.concat((dx_para, dx_perp, 
                                                p_para_src, p_perp_src, 
                                                edges.dst['type'], edges.src['type']), -1))
        
        return {'m': torch.concat((m_v, m_theta), -1)}
        
    def forward(self, xv):
        """Compute per-cell rates for state ``xv``.

        ``xv`` holds positions in its first ``xy_dim`` columns and theta in column
        ``xy_dim`` (presumably shape (N, xy_dim + 1)). The returned tensor has the
        same layout: summed interaction velocity plus self-propulsion, then the
        summed heading rate.
        """
        # NOTE(review): makeGraph compares squared distance with r_thresh/2, so the
        # effective neighbour condition here is |dr|^2 < self.r/4.
        r_g = makeGraph(xv[..., :self.xy_dim], self.r/2)
        r_g.ndata['x'] = xv[..., :self.xy_dim]
        r_g.ndata['theta'] = xv[..., self.xy_dim:(self.xy_dim+1)]
        r_g.ndata['type'] = self.celltype
        r_g.update_all(self.calc_message, fn.sum('m', 'a'))
        # Add self-propulsion selfpropel * (cos theta, sin theta) in place to the velocity columns.
        r_g.ndata['a'][..., :self.xy_dim] = r_g.ndata['a'][..., :self.xy_dim] + self.selfpropel * torch.concat((torch.cos(r_g.ndata['theta']), torch.sin(r_g.ndata['theta'])), -1)
        
        return r_g.ndata['a']



In [180]:
# Load the full pickled model (architecture + weights) onto the selected device.
# This requires the NeuralNet / ActiveNet class definitions above to be in scope.
# NOTE(security): torch.load unpickles arbitrary objects — only load trusted files.
filename1 = modeldirName + model_name + '_Model.pt'
gnn_model = torch.load(filename1, map_location=device)
gnn_model.eval()

# Training metadata stored next to the model (cutoff, data split, loss log, ...).
filename3 = modeldirName + model_name + '_Separation.npz'
learn_params = np.load(filename3, allow_pickle=True)

printNPZ(learn_params)

dr_thresh 4
T_pred 5
batch_size 8
train_inds [2 9 6 4 3]
valid_inds [7 1]
test_inds [8 5 0]
val_loss_log [[1.63403279 0.11493566]
 [1.59570481 0.1151024 ]
 [1.5764672  0.11543358]
 [1.55472892 0.11587476]
 [1.54151713 0.11615335]
 [1.53351366 0.11645507]
 [1.52429478 0.11667096]
 [1.5201491  0.11672764]
 [1.51461845 0.11697252]
 [1.51065312 0.11719412]
 [1.50759546 0.11734923]
 [1.50427528 0.11750269]
 [1.50065303 0.11732401]
 [1.50009504 0.11725311]
 [1.49723444 0.11737162]
 [1.49572535 0.11729963]
 [1.4941661  0.11728718]
 [1.49268577 0.11731092]
 [1.4917414  0.11721653]
 [1.49045308 0.11708841]
 [1.48883917 0.11705846]
 [1.48662562 0.11690125]
 [1.48532288 0.11683401]
 [1.48446726 0.1165512 ]
 [1.48351788 0.1164945 ]
 [1.48163097 0.11641084]
 [1.48036438 0.11629311]
 [1.47842579 0.11617087]
 [1.4779427  0.11590116]
 [1.47635261 0.11571354]
 [1.47503093 0.1156854 ]
 [1.47337949 0.11555678]
 [1.47259492 0.11564068]
 [1.47153208 0.11548702]
 [1.47073437 0.11525437]
 [1.46819007 0.11497

In [181]:
# Debug view: list the attributes and methods of the loaded model object.
dir(gnn_model)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_buffers',
 '_call_impl',
 '_forward_hooks',
 '_forward_pre_hooks',
 '_get_backward_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_set',
 '_parameters',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_slow_forward',
 '_state_dict_hooks',
 '_ver

In [182]:
# Neighbour cutoff parameter used during training (read from the training metadata).
dr_thresh = learn_params['dr_thresh'].item()

# Noise strength for the heading. Deliberately fixed to 0 (deterministic rollout)
# instead of params['D'], the value used to generate the ground truth.
D = 0.0

# Boundary conditions (L=100 in article); the box-size tensor L is defined where
# calc_dr is set up.
periodic = params['periodic']

# Cell-type composition and the arrow colour used for each type.
cellType_ratio = params['cellType_ratio']
quiv_colors = params['quiv_colors']

# Number of cells, state dimension (x, y, theta), and number of noise channels.
batch_size = params['batch_size']
state_size = 3
brownian_size = 1

# Number of simulated steps, counted from the initial frame initTime
# (duration 6400 in the article).
t_max = params['t_max'] - initTime

# Step length passed to the Brownian sampler. calc_onestep adds the model output once
# per step without scaling it by this value, so it should stay 1.
stepSDE = 1


In [183]:
# Override the radius stored in the loaded model with the cutoff saved in the training
# metadata; ActiveNet.forward builds its neighbour graph from self.r.
gnn_model.r = dr_thresh

In [184]:
# Sanity check: display the radius the model will now use.
gnn_model.r

4

In [185]:
# Dump all learned parameters of the loaded model (long output).
print(gnn_model.state_dict())

OrderedDict([('selfpropel', tensor(0.9251, device='cuda:1')), ('interactNN.layer1.weight', tensor([[ 0.2185,  0.2012,  0.0633, -0.1603, -0.2546,  0.2063],
        [ 0.0789,  0.0199, -0.2185, -0.0179,  0.0370,  0.2715],
        [-0.0823,  0.1029, -0.3984, -0.0192,  0.3186, -0.0922],
        [ 0.0451,  0.2858, -0.0804, -0.3068,  0.0925,  0.0947],
        [-0.4581,  0.0992, -0.2529, -0.3365,  0.3901, -0.0395],
        [-0.0803,  0.0635, -0.0571,  0.2846, -0.2998,  0.3691],
        [ 0.4822,  0.2574,  0.0750,  0.1314,  0.0157, -0.3035],
        [ 0.4956, -0.4538, -0.0259, -0.0735, -0.0517, -0.0602],
        [-0.5238, -0.5016, -0.0672, -0.1873, -0.1848, -0.1180],
        [ 0.3227, -0.3686, -0.2994,  0.1840, -0.3364,  0.3206],
        [-0.3876, -0.3320, -0.1246, -0.0276, -0.3060, -0.2636],
        [ 0.0414, -0.2237,  0.0751, -0.2037,  0.3727, -0.0216],
        [-0.0592, -0.5208, -0.0320,  0.0850, -0.0405, -0.1056],
        [ 0.1455, -0.2446, -0.3514, -0.0719,  0.2618,  0.1284],
        [-0.0

In [186]:
# Assign a cell type (and its arrow colour) to each of the batch_size cells, in
# contiguous blocks whose sizes follow cellType_ratio.
celltype_label = torch.zeros([batch_size, 1], device=device)
# Size the string dtype from quiv_colors so multi-character colour names are not
# truncated (np.full(..., '') alone gives a 1-character '<U1' array).
quiv_label = np.full([batch_size], '', dtype=np.asarray(quiv_colors).dtype)

# Block ends come from the rounded *cumulative* ratio, so per-type rounding errors do
# not accumulate. Rounding each type separately could leave trailing cells unassigned
# (type 0, empty colour). For ratios that give exact counts, the result is unchanged.
bounds = np.round(np.cumsum(cellType_ratio) * batch_size).astype(int).tolist()
start = 0
for ict, stop in enumerate(bounds):
    celltype_label[start:stop, 0] = ict
    quiv_label[start:stop] = quiv_colors[ict]
    start = stop

gnn_model.load_celltypes(celltype_label)

In [187]:
# Heading-noise amplitude: the theta update gets sqrt(2 D) * dW.
sigma = np.sqrt(2*D)

# Brownian-motion sampler over the whole run [0, t_max], one noise channel per cell.
bm = BrownianInterval(t0=0, t1=float(t_max), size=(batch_size, brownian_size), dt=stepSDE, device=device)


def calc_onestep(x, t):
    """Advance the state ``x`` by one step starting at time ``t``.

    The model output is added to ``x`` once (no scaling by stepSDE). Gaussian noise
    ``sigma * (W(t + stepSDE) - W(t))`` is then added to the last column (theta) only.
    ``x`` is presumably (batch_size, state_size) = [x, y, theta] per cell.
    """
    drift = gnn_model(x)
    x_next = x + drift
    noise = sigma * bm(t, t + stepSDE)
    x_next[..., -1:] = x_next[..., -1:] + noise
    return x_next


In [188]:
# Initial state y0: one row [x, y, theta] per cell, shape (batch_size, state_size).
# ys stores the trajectory on the integer time grid ts, shape (len(ts), batch_size, state_size).
if load_initState:
    # Start from frame initTime of the ground-truth trajectory.
    xy0 = torch.tensor(traj['xy'][initTime], device=device)
    theta0 = torch.tensor(traj['theta'][initTime], device=device).view(-1, 1)
    y0 = torch.concat((xy0, theta0), 1)
else:
    # Random headings in [0, 2*pi) and uniformly random positions in the box.
    theta0 = torch.rand((batch_size, 1), device=device) * (2 * np.pi)
    xy0 = torch.rand((batch_size, state_size - 1), device=device) * L
    y0 = torch.concat((xy0, theta0), 1)

print(y0)

ts = np.linspace(0, t_max, t_max + 1)

ys = np.zeros((len(ts), batch_size, state_size))
ys[0] = y0.cpu().detach().numpy()

print(ys[0])

tensor([[ 2.7825, 10.1723,  3.5049],
        [ 3.5684,  8.7938,  4.1048],
        [ 3.6428, 17.7625,  4.2182],
        ...,
        [15.2832,  8.7073,  0.2160],
        [ 1.0767, 18.7310,  5.8443],
        [ 0.5559,  4.8400,  1.5180]], device='cuda:1')
[[ 2.78252363 10.17232037  3.50493574]
 [ 3.56842756  8.79380798  4.10477257]
 [ 3.64280105 17.76249123  4.21824741]
 ...
 [15.28323174  8.70733452  0.21601154]
 [ 1.07665181 18.73102951  5.84431362]
 [ 0.55587769  4.83996391  1.51804745]]


In [189]:
# Roll the learned model forward from y0 without tracking gradients.
# ys[0] already holds y0; step k writes the state reached from time ts[k-1].
y = y0.clone()

with torch.no_grad():
    for step in range(1, len(ts)):
        y = calc_onestep(y, ts[step - 1])
        ys[step] = y.cpu().detach().numpy()

In [190]:
# Global matplotlib defaults for every figure below: large fonts, 8x6 in figures,
# tight bounding boxes when saving, big markers and thick lines.
plt.rcParams.update({
    'font.size': 18,
    'figure.figsize': [8, 6],
    'savefig.bbox': 'tight',
    'lines.markersize': 11.0,
    'lines.linewidth': 3.5,
})

In [191]:
cm = plt.cm.hsv  # NOTE(review): unused here — arrows are coloured by cell type

# Slider covers every stored frame. ys holds len(ts) = t_max + 1 frames, so the
# previous upper bound t_max - 1 made the final frame unreachable.
@interact(t_plot=(0, ys.shape[0] - 1))
def f(t_plot):
    """Draw frame ``t_plot`` of the simulated trajectory ``ys`` as a quiver plot.

    Positions are wrapped into [0, L). Each cell is drawn as a unit arrow along its
    heading theta, coloured by its cell type (quiv_label).
    """
    box = L.item()
    pos_x = ys[t_plot, :, 0] % box
    pos_y = ys[t_plot, :, 1] % box
    px = np.cos(ys[t_plot, :, 2])
    py = np.sin(ys[t_plot, :, 2])

    fig, ax = plt.subplots()
    ax.set_aspect('equal')
    im = ax.quiver(pos_x, pos_y, px, py, color=quiv_label, scale=1, scale_units='xy')
    ax.set_xlim(0, box)
    ax.set_ylim(0, box)
    im.set_clim(0, 2*np.pi)


interactive(children=(IntSlider(value=499, description='t_plot', max=999), Output()), _dom_classes=('widget-in…

In [192]:
import datetime
import codecs  # kept for any later cells that rely on it; no longer used here

now = datetime.datetime.now()
nowstr = now.strftime('%Y%m%d_%H%M%S')

# Output directory: <savedirStr>simulate_D<D>_init<initTime>/<timestamp>/
savedirName = savedirStr + 'simulate_D'+str(D)+'_init' +str(initTime)+ '/' + nowstr + '/'
os.makedirs(savedirName, exist_ok=True)

np.savez(savedirName + 'results.npz', ys=ys, ts=ts,
         celltype_label=celltype_label.cpu().detach().numpy(), quiv_label=quiv_label)

np.savez(savedirName + nowstr + '_Separation.npz', dr_thresh=dr_thresh, batch_size=batch_size)

np.savez(os.path.join(savedirName, 'params.npz'),
         D = D, dr_thresh = dr_thresh,
         cellType_ratio = cellType_ratio, quiv_colors = quiv_colors,
         batch_size = batch_size, state_size = state_size, brownian_size = brownian_size,
         t_max = t_max, stepSDE = stepSDE)

# Human-readable copy of the parameters just saved. They are read into a separate name
# so the ground-truth `params` used by earlier cells is not overwritten. Both files are
# closed deterministically.
loaddir = savedirName
with np.load(os.path.join(loaddir, 'params.npz'), allow_pickle=True) as sim_params:
    txtstring = ["{} = {}".format(k, sim_params[k]) for k in sim_params.files]
with open(os.path.join(loaddir, 'params.txt'), 'w', encoding='utf-8') as fh:
    print(*txtstring, sep="\n", file=fh)

# One PNG per stored frame, including the final one (ys has len(ts) frames).
# Local names avoid clobbering the simulation state `y` and other globals.
box = L.item()
for t_plot in range(ys.shape[0]):
    pos_x = ys[t_plot, :, 0] % box
    pos_y = ys[t_plot, :, 1] % box
    px = np.cos(ys[t_plot, :, 2])
    py = np.sin(ys[t_plot, :, 2])

    fig, ax = plt.subplots()
    ax.set_aspect('equal')
    im = ax.quiver(pos_x, pos_y, px, py, color=quiv_label, scale=1, scale_units='xy')
    ax.set_xlim(0, box)
    ax.set_ylim(0, box)
    im.set_clim(0, 2*np.pi)

    fig.savefig(savedirName + 't' + format(t_plot, '04d') + '.png')
    plt.close(fig)


In [193]:
# Base directory under which this run's simulate_D*_init* outputs were written.
print(savedirStr)

/home/uwamichi/jupyter/HiraiwaModel_chem20220922_180005/ActiveNet_vp_rotsym_noSelfLoop_multiStep_fineTuning_batchNorm/20230304_132719/HiraiwaModel_chem20220922_180005/20221026_220625/
