In [155]:
import numpy as np
from matplotlib import pyplot as plt
from ipywidgets import interact

import torch
from torch import nn
from torchsde import BrownianInterval, sdeint

import dgl
import dgl.function as fn

import networkx as nx

import os
import gc
import pickle

# デバイス設定
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')


In [156]:
def printNPZ(npz):
    for kw in npz.files:
        print(kw, npz[kw])

In [157]:
dirName = './HiraiwaModel_chem20221102_020319/'

homeName = './'

params = np.load(dirName+'params.npz')
#traj = np.load(dirName+'result.npz')

printNPZ(params)
#printNPZ(traj)

v0 1.0
r 1.0
D 0.1
A 0.0
L 20
rho 1.0
beta 1.0
A_CFs [0.9 0.1]
A_CIL 0.0
cellType_ratio [0.7 0.3]
quiv_colors ['k' 'r']
kappa 1.0
A_Macdonalds [2. 2.]
batch_size 400
state_size 3
brownian_size 1
periodic True
t_max 200
methodSDE heun
isIto False
stepSDE 0.01


In [158]:
subdir_list = [f.path for f in os.scandir(dirName) if f.is_dir()]

print(subdir_list)

datadir_list = [f for f in subdir_list if 'result.npz' in [ff.name for ff in os.scandir(f)]]

print(datadir_list)

['./HiraiwaModel_chem20221102_020319/20221102_194956', './HiraiwaModel_chem20221102_020319/20221102_203311', './HiraiwaModel_chem20221102_020319/20221102_211612', './HiraiwaModel_chem20221102_020319/20221102_215953', './HiraiwaModel_chem20221102_020319/20221102_224239', './HiraiwaModel_chem20221102_020319/20221102_232519', './HiraiwaModel_chem20221102_020319/20221103_000843', './HiraiwaModel_chem20221102_020319/20221103_005036', './HiraiwaModel_chem20221102_020319/20221103_013303', './HiraiwaModel_chem20221102_020319/20221103_021518', './HiraiwaModel_chem20221102_020319/20221103_025758', './HiraiwaModel_chem20221102_020319/20221103_033955', './HiraiwaModel_chem20221102_020319/20221103_042225', './HiraiwaModel_chem20221102_020319/20221103_050424', './HiraiwaModel_chem20221102_020319/20221103_054634', './HiraiwaModel_chem20221102_020319/20221103_062842', './HiraiwaModel_chem20221102_020319/20221103_071105', './HiraiwaModel_chem20221102_020319/20221103_075340', './HiraiwaModel_chem2022110

In [159]:
modeldirName = dirName + 'ActiveNet_vp_rotsym_multiStep_fineTuning_batchNorm/'

#datename = '20220921_160055'
i_model = -1

model_files_list = [[os.path.join(c, ff) for ff in f if ff.endswith('_Model.pt')]
               for c, s, f in os.walk(modeldirName)]
print(model_files_list)

model_files = []
for i in range(len(model_files_list)):
    for j in range(len(model_files_list[i])):
        model_files.append(model_files_list[i][j])
print(model_files)

model_dir = os.path.dirname(model_files[i_model])

model_name = model_files[i_model].replace('_Model.pt', '').replace(modeldirName, '')
print(model_name)

[[], ['./HiraiwaModel_chem20221102_020319/ActiveNet_vp_rotsym_multiStep_fineTuning_batchNorm/20221123_044031/20221123_044031_Model.pt'], [], [], ['./HiraiwaModel_chem20221102_020319/ActiveNet_vp_rotsym_multiStep_fineTuning_batchNorm/20221124_000148/20221124_000148_Model.pt'], [], [], ['./HiraiwaModel_chem20221102_020319/ActiveNet_vp_rotsym_multiStep_fineTuning_batchNorm/20221124_172342/20221124_172342_Model.pt'], [], []]
['./HiraiwaModel_chem20221102_020319/ActiveNet_vp_rotsym_multiStep_fineTuning_batchNorm/20221123_044031/20221123_044031_Model.pt', './HiraiwaModel_chem20221102_020319/ActiveNet_vp_rotsym_multiStep_fineTuning_batchNorm/20221124_000148/20221124_000148_Model.pt', './HiraiwaModel_chem20221102_020319/ActiveNet_vp_rotsym_multiStep_fineTuning_batchNorm/20221124_172342/20221124_172342_Model.pt']
20221124_172342/20221124_172342


In [160]:
load_initState = True
initTime = 100

if load_initState:
    i_truth = 0
    traj = np.load(datadir_list[i_truth]+'/result.npz')
    printNPZ(traj)
    savedirStr = os.path.join(model_dir, datadir_list[i_truth].replace(homeName, '')+'/')
else:
    savedirStr = model_dir+'/'
    

xy [[[ 4.1139326   9.296628  ]
  [ 3.4176195  16.287083  ]
  [ 6.668969    5.7918215 ]
  ...
  [ 2.3774242  14.191948  ]
  [ 4.774294   14.416741  ]
  [ 2.1923792   5.9275446 ]]

 [[ 4.5529084   9.618775  ]
  [ 3.3233438  16.671354  ]
  [ 5.9192305   5.181894  ]
  ...
  [ 2.9622722  14.693307  ]
  [ 5.2210836  14.138584  ]
  [ 2.4284892   5.112486  ]]

 [[ 5.2404866   9.577545  ]
  [ 2.967069   16.444715  ]
  [ 5.8765817   4.613716  ]
  ...
  [ 3.7002106  14.917825  ]
  [ 6.095678   14.030148  ]
  [ 1.9746629   4.689925  ]]

 ...

 [[15.047613    3.320736  ]
  [ 3.3207378  14.982803  ]
  [ 9.36577    12.797707  ]
  ...
  [ 2.278626    2.2870636 ]
  [ 2.4501953   5.7985215 ]
  [19.353088   18.924608  ]]

 [[14.632973    3.4523087 ]
  [ 3.3929787  15.064716  ]
  [ 8.738223   13.414352  ]
  ...
  [ 2.9922526   2.4920807 ]
  [ 1.7580891   6.2111015 ]
  [ 0.05151939 18.215548  ]]

 [[14.365591    3.562416  ]
  [ 3.263052   15.059608  ]
  [ 7.9789047  13.866756  ]
  ...
  [ 3.8677142   2.087

In [161]:
if params['periodic']:
    L = torch.tensor(params['L'])
    def calc_dr(r1, r2):
        dr = torch.remainder((r1 - r2), L)
        dr[dr > L/2] = dr[dr > L/2] - L
        return dr
else:
    def calc_dr(r1, r2):
        return r1 - r2
    
def makeGraph(x_data, r_thresh):
        Ndata = x_data.size(0)
        dx = calc_dr(torch.unsqueeze(x_data, 0), torch.unsqueeze(x_data, 1))
        dx = torch.sum(dx**2, dim=2)
        edges = torch.argwhere(dx < r_thresh/2)
        return dgl.graph((edges[:,0], edges[:,1]), num_nodes=Ndata)

In [162]:
class NeuralNet(nn.Module):
    def __init__(self, in_channels, out_channels, Nchannels, dropout=0, batchN=False, flgBias=False):
        super(NeuralNet, self).__init__()

        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0
            
        if batchN:
            self.bNorm1 = nn.BatchNorm1d(Nchannels)
            self.bNorm2 = nn.BatchNorm1d(Nchannels)
            self.bNorm3 = nn.BatchNorm1d(Nchannels)
            
        self.batchN=batchN
        
        self.layer1 = nn.Linear(in_channels, Nchannels, bias=flgBias)
        self.layer2 = nn.Linear(Nchannels, Nchannels, bias=flgBias)
        self.layer3 = nn.Linear(Nchannels, Nchannels, bias=flgBias)
        self.layer4 = nn.Linear(Nchannels, out_channels, bias=flgBias)

        self.activation = nn.ReLU()

    def reset_parameters(self):
        self.layer1.reset_parameters()
        self.layer2.reset_parameters()
        self.layer3.reset_parameters()
        self.layer4.reset_parameters()
        #nn.init.zeros_(self.layer1.weight)
        #nn.init.zeros_(self.layer2.weight)
        #nn.init.zeros_(self.layer3.weight)
        #nn.init.zeros_(self.layer4.weight)
        
    def forward(self, x):
        out = self.activation(self.layer1(x))
        if self.batchN:
            out = self.bNorm1(out)
        if self.dropout:
            out = self.dropout(out)
        
        out = self.activation(self.layer2(out))
        if self.batchN:
            out = self.bNorm2(out)
        if self.dropout:
            out = self.dropout(out)
        
        out = self.activation(self.layer3(out))
        if self.batchN:
            out = self.bNorm3(out)
        if self.dropout:
            out = self.dropout(out)
        
        out = self.layer4(out)

        return out

class ActiveNet(nn.Module):
    def __init__(self, xy_dim, r, dropout=0, batchN=False, bias=False, Nchannels=128):
        super().__init__()

        self.interactNN = NeuralNet(xy_dim*2 + 2, xy_dim, Nchannels, dropout, batchN, bias)

        self.thetaDotNN = NeuralNet(xy_dim*2 + 2, 1, Nchannels, dropout, batchN, bias)
        
        self.selfpropel = nn.Parameter(torch.tensor(0.0, requires_grad=True, device=device))

        #self.Normalizer = nn.Softmax(dim=1)

        self.xy_dim = xy_dim
        
        self.r = r

        self.reset_parameters()

    def reset_parameters(self):
        self.interactNN.reset_parameters()

        self.thetaDotNN.reset_parameters()
        
        nn.init.uniform_(self.selfpropel)

        #self.bias.data.zero_()
        
    def load_celltypes(self, celltype):
        self.celltype = celltype

    def calc_message(self, edges):
        dx = calc_dr(edges.dst['x'], edges.src['x'])

        costheta = torch.cos(edges.dst['theta'])
        sintheta = torch.sin(edges.dst['theta'])

        dx_para = costheta * dx[..., :1] + sintheta * dx[..., 1:]
        dx_perp = costheta * dx[..., 1:] - sintheta * dx[..., :1]

        p_para_src = torch.cos(edges.src['theta'] - edges.dst['theta'])
        p_perp_src = torch.sin(edges.src['theta'] - edges.dst['theta'])

        rot_m_v = self.interactNN(torch.concat((dx_para, dx_perp, 
                                                p_para_src, p_perp_src,
                                                edges.dst['type'], edges.src['type']), -1))

        m_v = torch.concat((costheta * rot_m_v[..., :1] - sintheta * rot_m_v[..., 1:],
                            costheta * rot_m_v[..., 1:] + sintheta * rot_m_v[..., :1]), -1)

        m_theta = self.thetaDotNN(torch.concat((dx_para, dx_perp, 
                                                p_para_src, p_perp_src, 
                                                edges.dst['type'], edges.src['type']), -1))
        
        return {'m': torch.concat((m_v, m_theta), -1)}
        
    def forward(self, xv):
        r_g = makeGraph(xv[..., :self.xy_dim], self.r/2)
        r_g.ndata['x'] = xv[..., :self.xy_dim]
        r_g.ndata['theta'] = xv[..., self.xy_dim:(self.xy_dim+1)]
        r_g.ndata['type'] = self.celltype
        r_g.update_all(self.calc_message, fn.sum('m', 'a'))
        r_g.ndata['a'][..., :self.xy_dim] = r_g.ndata['a'][..., :self.xy_dim] + self.selfpropel * torch.concat((torch.cos(r_g.ndata['theta']), torch.sin(r_g.ndata['theta'])), -1)
        
        return r_g.ndata['a']



In [163]:
filename1 = modeldirName + model_name + '_Model.pt'
gnn_model = torch.load(filename1, map_location=torch.device(device)) #pickle.load(open(filename1, 'rb'))
gnn_model.eval()

filename3 = modeldirName + model_name + '_Separation.npz'
learn_params = np.load(filename3, allow_pickle=True)

printNPZ(learn_params)

dr_thresh 4
T_pred 5
batch_size 8
train_inds [15  9 23 17  0 19  4 16 20 14 10  5  2  7 21 22  8 12 18 13]
valid_inds [ 6  1 11]
test_inds [24  3]
val_loss_log [[1.45646583 0.21329873]
 [1.45091671 0.21317007]
 [1.44348577 0.21293874]
 [1.43768798 0.21275028]
 [1.43034618 0.21247833]
 [1.42535211 0.2122647 ]
 [1.41899505 0.21234914]
 [1.41271807 0.21203104]
 [1.40926855 0.21196829]
 [1.40479076 0.21166787]
 [1.4011826  0.21154321]
 [1.40106776 0.21162673]
 [1.39587382 0.21153385]
 [1.39185597 0.2112902 ]
 [1.38959325 0.2113577 ]
 [1.38676675 0.21135884]
 [1.38532993 0.21148182]
 [1.38472513 0.21153093]
 [1.38401124 0.21140037]
 [1.3823777  0.21127214]
 [1.37910575 0.21112209]
 [1.37695371 0.21105833]
 [1.37386404 0.2109144 ]
 [1.37339088 0.21081962]
 [1.37087448 0.21086789]
 [1.36986768 0.21087584]
 [1.36893867 0.21076649]
 [1.3660223  0.21081922]
 [1.36364621 0.21073553]
 [1.36388904 0.21065095]
 [1.36212226 0.21064733]
 [1.35886983 0.21081248]
 [1.35815139 0.21072975]
 [1.35609521 0.

In [164]:
dir(gnn_model)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_buffers',
 '_call_impl',
 '_forward_hooks',
 '_forward_pre_hooks',
 '_get_backward_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_set',
 '_parameters',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_slow_forward',
 '_state_dict_hooks',
 '_ver

In [165]:
# threshold distance for making edges
dr_thresh = learn_params['dr_thresh'].item()

# fixed noise strength
D = 0.0 #params['D']

# boundary conditions (L=100 in article)
#L = params['L']
periodic = params['periodic']

# cell type ratio
cellType_ratio = params['cellType_ratio']
quiv_colors = params['quiv_colors']

# numbers of cells, freedom, and noise source
batch_size, state_size, brownian_size = params['batch_size'], 3, 1

# duration of simulation (6400 in the article)
t_max =  params['t_max'] - initTime

# method to solve SDE
#methodSDE = 'euler'
#isIto = True

# time step to simulate
stepSDE = 1


In [166]:
gnn_model.r = dr_thresh

In [167]:
gnn_model.r

4

In [168]:
print(gnn_model.state_dict())

OrderedDict([('selfpropel', tensor(0.9674, device='cuda:1')), ('interactNN.layer1.weight', tensor([[-1.5789e-01, -1.8519e-01,  1.4166e-01, -4.5040e-02,  1.4097e-01,
         -1.2141e-01],
        [-6.7769e-01,  8.1622e-01,  2.5423e-01, -2.3726e-02, -5.1402e-02,
         -1.5248e-02],
        [-1.4998e-01, -1.1237e-02, -8.8741e-02, -2.6726e-01, -3.6581e-01,
         -9.9635e-03],
        [-1.8666e-01, -2.4193e-01,  3.5165e-01,  1.4185e-01,  1.4809e-01,
          1.5322e-01],
        [ 1.6536e-01, -3.4551e-01, -3.2431e-01,  9.5790e-02,  3.4689e-01,
          4.8828e-02],
        [ 1.4116e-01,  4.9433e-01, -1.6227e-02,  1.4102e-02,  1.4293e-01,
          1.8939e-01],
        [ 1.7721e-01, -1.1081e+00, -2.6408e-01,  2.6522e-01,  5.3711e-02,
          6.9439e-02],
        [ 5.8838e-02, -1.1318e-01,  3.2413e-01,  2.4427e-01, -2.3609e-01,
         -2.5406e-01],
        [-8.2445e-01,  6.6213e-01,  5.1267e-02, -2.2211e-01, -7.7745e-02,
         -3.6832e-02],
        [ 1.7080e-01, -5.7209e-01, -

In [169]:
celltype_label = torch.zeros([batch_size, 1], device=device)
quiv_label = np.full([batch_size], '')
count = 0
for ict, ctr in enumerate(cellType_ratio):
    Nct = int(np.round(batch_size*ctr)) + count
    celltype_label[count:Nct, 0] = ict
    quiv_label[count:Nct] = quiv_colors[ict]
    count = Nct
    
gnn_model.load_celltypes(celltype_label.to(device))

In [170]:
sigma = np.sqrt(2*D)

bm = BrownianInterval(t0=0, 
                      t1=float(t_max), 
                      size=(batch_size, brownian_size),
                      dt=stepSDE,
                      device=device)

def calc_onestep(x, t):

    dv_dtheta = gnn_model(x)
    
    out = x + dv_dtheta
    out[..., -1:] = out[..., -1:] + sigma * bm(t, t+stepSDE)

    return out


In [171]:
# Initial state y0, the SDE is solved over the interval [ts[0], ts[-1]].
# ys will have shape (t_size, batch_size, state_size)

if load_initState:
    theta0 = torch.tensor(traj['theta'][initTime], device=device).view(-1,1)
    y0 = torch.concat((torch.tensor(traj['xy'][initTime], device=device), 
                       theta0), 1)

else:
    theta0 = torch.rand((batch_size, 1), device=device) * (2 * np.pi)
    y0 = torch.concat((torch.rand((batch_size, state_size-1), device=device) * L,
                       theta0), 1)

print(y0)

ts = np.linspace(0, t_max, t_max+1)

ys = np.zeros((len(ts), batch_size, state_size))
ys[0] = y0.cpu().detach().numpy()

print(ys[0])

tensor([[ 3.8594, 10.5606,  3.8949],
        [18.7685, 14.7998,  2.2557],
        [17.9473, 18.5409,  6.2108],
        ...,
        [ 0.8962, 14.0865,  4.0011],
        [ 5.2347,  7.9910,  3.4275],
        [ 5.1084,  7.5699,  3.2847]], device='cuda:1')
[[ 3.85938144 10.56056499  3.89494419]
 [18.76848602 14.7998333   2.25569105]
 [17.94725418 18.54092598  6.21079731]
 ...
 [ 0.89619774 14.08654499  4.00110388]
 [ 5.23468733  7.99098873  3.42754221]
 [ 5.10841942  7.56990242  3.28469181]]


In [172]:
y = y0.clone()

with torch.no_grad():
    for it, t in enumerate(ts[:-1]):
        y = calc_onestep(y, t)
        ys[it+1] = y.cpu().detach().numpy()

In [173]:
plt.rcParams['font.size']=18
plt.rcParams['figure.figsize'] = [8,6]
plt.rcParams['savefig.bbox'] = 'tight'

plt.rcParams['lines.markersize'] = 11.0

plt.rcParams['lines.linewidth'] = 3.5

In [174]:
cm = plt.cm.hsv

@interact(t_plot=(0, t_max-1))
def f(t_plot):

  x = ys[t_plot,:,0] % L.item()
  y = ys[t_plot,:,1] % L.item()
  px = np.cos(ys[t_plot,:,2])
  py = np.sin(ys[t_plot,:,2])

  fig, ax = plt.subplots()
  ax.set_aspect('equal')
  im = ax.quiver(x, y, px, py, color=quiv_label, scale = 1, scale_units='xy')
  plt.xlim(0, L)
  plt.ylim(0, L)
  #fig.colorbar(im)
  im.set_clim(0, 2*np.pi)


interactive(children=(IntSlider(value=49, description='t_plot', max=99), Output()), _dom_classes=('widget-inte…

In [175]:
import pickle
import datetime
import codecs

now = datetime.datetime.now()
nowstr = now.strftime('%Y%m%d_%H%M%S')

savedirName = savedirStr + 'simulate_D'+str(D)+'_init' +str(initTime)+ '/' + nowstr + '/'
os.makedirs(savedirName, exist_ok=True)

np.savez(savedirName + 'results.npz', ys=ys, ts=ts, 
         celltype_label=celltype_label.cpu().detach().numpy(), quiv_label=quiv_label)

#torch.save(gnn_model.to('cpu'), savedirName + nowstr + '_Model.pt')

np.savez(savedirName + nowstr + '_Separation.npz', dr_thresh=dr_thresh, batch_size=batch_size)

np.savez(os.path.join(savedirName, 'params.npz'),
         D = D, dr_thresh = dr_thresh,
         cellType_ratio = cellType_ratio, quiv_colors = quiv_colors,
         batch_size = batch_size, state_size = state_size, brownian_size = brownian_size,
         t_max = t_max, stepSDE = stepSDE)

loaddir = savedirName
params = np.load(loaddir+'/params.npz', allow_pickle=True)

txtstring = []
for k in params.files:
    txtstring.append("{} = {}".format(k, params[k]))

print(*txtstring, sep="\n", file=codecs.open(loaddir+'/params.txt', 'w', 'utf-8'))

for t_plot in range(t_max):

  x = ys[t_plot,:,0] % L.item()
  y = ys[t_plot,:,1] % L.item()
  px = np.cos(ys[t_plot,:,2])
  py = np.sin(ys[t_plot,:,2])

  fig, ax = plt.subplots()
  ax.set_aspect('equal')
  im = ax.quiver(x, y, px, py, color=quiv_label, scale = 1, scale_units='xy')
  plt.xlim(0, L)
  plt.ylim(0, L)
  #fig.colorbar(im)
  im.set_clim(0, 2*np.pi)

  fig.savefig(savedirName + 't'+format(t_plot, '04d')+'.png')

  plt.clf()
  plt.close()


In [176]:
print(savedirStr)

./HiraiwaModel_chem20221102_020319/ActiveNet_vp_rotsym_multiStep_fineTuning_batchNorm/20221124_172342/HiraiwaModel_chem20221102_020319/20221102_194956/
