In [23]:
import numpy as np
from matplotlib import pyplot as plt
from ipywidgets import interact

import torch
from torch import nn
from torchsde import BrownianInterval, sdeint

import dgl
import dgl.function as fn

import networkx as nx

import os
import gc
import pickle

# Device selection: prefer the second GPU (cuda:1) when it exists, fall back to
# the first GPU on single-GPU machines, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')


In [24]:
def printNPZ(npz):
    """Print every entry of an npz archive as ``<key> <value>``, one per line.

    Entries are visited in the order given by ``npz.files``.
    """
    entries = ((key, npz[key]) for key in npz.files)
    for key, value in entries:
        print(key, value)

In [25]:
# Directory of the simulation run, and the path prefix that is stripped from
# data directory names when building output paths.
dirName = './HiraiwaModel_chem20221102_020319/'
homeName = './'

# Parameters stored with the run.
params = np.load(os.path.join(dirName, 'params.npz'))
#traj = np.load(dirName+'result.npz')

printNPZ(params)
#printNPZ(traj)

v0 1.0
r 1.0
D 0.1
A 0.0
L 20
rho 1.0
beta 1.0
A_CFs [0.9 0.1]
A_CIL 0.0
cellType_ratio [0.7 0.3]
quiv_colors ['k' 'r']
kappa 1.0
A_Macdonalds [2. 2.]
batch_size 400
state_size 3
brownian_size 1
periodic True
t_max 200
methodSDE heun
isIto False
stepSDE 0.01


In [26]:
# All sub-directories directly below the run directory.
subdir_list = [entry.path for entry in os.scandir(dirName) if entry.is_dir()]

print(subdir_list)

# Of those, the ones that contain an entry named 'result.npz'.
datadir_list = [subdir for subdir in subdir_list if 'result.npz' in os.listdir(subdir)]

print(datadir_list)

['./HiraiwaModel_chem20221102_020319/20221102_194956', './HiraiwaModel_chem20221102_020319/20221102_203311', './HiraiwaModel_chem20221102_020319/20221102_211612', './HiraiwaModel_chem20221102_020319/20221102_215953', './HiraiwaModel_chem20221102_020319/20221102_224239', './HiraiwaModel_chem20221102_020319/20221102_232519', './HiraiwaModel_chem20221102_020319/20221103_000843', './HiraiwaModel_chem20221102_020319/20221103_005036', './HiraiwaModel_chem20221102_020319/20221103_013303', './HiraiwaModel_chem20221102_020319/20221103_021518', './HiraiwaModel_chem20221102_020319/20221103_025758', './HiraiwaModel_chem20221102_020319/20221103_033955', './HiraiwaModel_chem20221102_020319/20221103_042225', './HiraiwaModel_chem20221102_020319/20221103_050424', './HiraiwaModel_chem20221102_020319/20221103_054634', './HiraiwaModel_chem20221102_020319/20221103_062842', './HiraiwaModel_chem20221102_020319/20221103_071105', './HiraiwaModel_chem20221102_020319/20221103_075340', './HiraiwaModel_chem2022110

In [27]:
modeldirName = dirName + 'ActiveNet_vp_rotsym_noSelfLoop_multiStep_fineTuning_batchNorm/'

#datename = '20220921_160055'
# Index into the sorted list of model files (-1 selects the last one).
i_model = -1

# One list of '*_Model.pt' paths per directory visited by os.walk.
model_files_list = [[os.path.join(c, ff) for ff in f if ff.endswith('_Model.pt')]
               for c, s, f in os.walk(modeldirName)]
print(model_files_list)

# Flatten and sort: os.walk yields entries in a filesystem-dependent order, so
# without sorting the file picked by i_model would not be reproducible.
model_files = sorted(path for group in model_files_list for path in group)
print(model_files)

if not model_files:
    raise FileNotFoundError("No '*_Model.pt' file found under " + modeldirName)

model_dir = os.path.dirname(model_files[i_model])

# Path of the selected model relative to modeldirName, without the '_Model.pt' suffix.
model_name = model_files[i_model].replace('_Model.pt', '').replace(modeldirName, '')
print(model_name)

[[], ['./HiraiwaModel_chem20221102_020319/ActiveNet_vp_rotsym_noSelfLoop_multiStep_fineTuning_batchNorm/20221130_164456/20221130_164456_Model.pt'], [], [], [], []]
['./HiraiwaModel_chem20221102_020319/ActiveNet_vp_rotsym_noSelfLoop_multiStep_fineTuning_batchNorm/20221130_164456/20221130_164456_Model.pt']
20221130_164456/20221130_164456


In [28]:
# When True, the initial state is taken from frame `initTime` of a recorded
# trajectory; when False, a later cell draws a random initial state.
load_initState = False
initTime = 0

if load_initState:
    # Index into datadir_list of the recorded run that provides the initial state.
    i_truth = 0
    traj = np.load(datadir_list[i_truth]+'/result.npz')
    printNPZ(traj)
    # Output goes to <model_dir>/<data directory path with homeName removed>/.
    savedirStr = os.path.join(model_dir, datadir_list[i_truth].replace(homeName, '')+'/')
else:
    # Output goes directly below the model directory.
    savedirStr = model_dir+'/'
    

In [29]:
# Box size. Defined for both boundary conditions because later cells use L to
# draw the random initial positions and to set the plot limits.
L = torch.tensor(params['L'])

if params['periodic']:
    def calc_dr(r1, r2):
        """Return r1 - r2 wrapped into the periodic box of size L.

        Each component is first reduced to [0, L) and then shifted so that
        values above L/2 become negative (minimum-image displacement).
        """
        dr = torch.remainder((r1 - r2), L)
        dr[dr > L/2] = dr[dr > L/2] - L
        return dr
else:
    def calc_dr(r1, r2):
        """Return the plain displacement r1 - r2 (no periodic wrapping)."""
        return r1 - r2
    
def makeGraph(x_data, r_thresh):
    """Build a DGL graph connecting pairs of points that are close to each other.

    x_data is indexed as (point, coordinate). An edge (i, j) is created when the
    squared length of calc_dr(x_data[j], x_data[i]) is strictly positive and
    below r_thresh/2; the positivity test excludes self-loops.
    """
    num_nodes = x_data.size(0)
    # pairwise[i, j] = calc_dr(x_data[j], x_data[i])
    pairwise = calc_dr(x_data.unsqueeze(0), x_data.unsqueeze(1))
    sq_dist = (pairwise ** 2).sum(dim=2)
    # NOTE(review): a squared distance is compared against r_thresh/2 (and the
    # caller already passes r/2) -- confirm this scaling is intended.
    linked = (sq_dist > 0) & (sq_dist < r_thresh / 2)
    pairs = torch.argwhere(linked)
    return dgl.graph((pairs[:, 0], pairs[:, 1]), num_nodes=num_nodes)

In [30]:
class NeuralNet(nn.Module):
    """Fully connected network: three hidden ReLU layers and a linear output layer.

    Each hidden layer is Linear -> ReLU, optionally followed by BatchNorm1d
    (``batchN``) and Dropout (``dropout``). ``flgBias`` switches the bias of
    every linear layer. Attribute names are part of the saved-model format and
    must not be renamed.
    """

    def __init__(self, in_channels, out_channels, Nchannels, dropout=0, batchN=False, flgBias=False):
        super().__init__()

        # A falsy `dropout` disables dropout; the attribute then holds the number 0.
        self.dropout = nn.Dropout(p=dropout) if dropout else 0

        # Batch-norm layers exist only when requested.
        if batchN:
            self.bNorm1 = nn.BatchNorm1d(Nchannels)
            self.bNorm2 = nn.BatchNorm1d(Nchannels)
            self.bNorm3 = nn.BatchNorm1d(Nchannels)
        self.batchN = batchN

        self.layer1 = nn.Linear(in_channels, Nchannels, bias=flgBias)
        self.layer2 = nn.Linear(Nchannels, Nchannels, bias=flgBias)
        self.layer3 = nn.Linear(Nchannels, Nchannels, bias=flgBias)
        self.layer4 = nn.Linear(Nchannels, out_channels, bias=flgBias)

        self.activation = nn.ReLU()

    def reset_parameters(self):
        """Re-initialise the four linear layers (batch-norm layers are left untouched)."""
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            layer.reset_parameters()

    def forward(self, x):
        """Apply the three hidden blocks in order, then the linear output layer."""
        out = x
        for idx in (1, 2, 3):
            out = self.activation(getattr(self, 'layer%d' % idx)(out))
            if self.batchN:
                out = getattr(self, 'bNorm%d' % idx)(out)
            if self.dropout:
                out = self.dropout(out)

        return self.layer4(out)

class ActiveNet(nn.Module):
    """Graph network that maps a state (x, theta) to a per-node update.

    For every edge, two NeuralNet instances receive the same features, all
    expressed in the frame of the destination node's heading: the relative
    position, the cosine/sine of the relative heading, and the types of both
    nodes. Messages are summed per node and a learned self-propulsion along the
    node's own heading is added to the positional part.
    """

    def __init__(self, xy_dim, r, dropout=0, batchN=False, bias=False, Nchannels=128):
        super().__init__()

        # Edge feature count: xy_dim*2 + 2, shared by both networks.
        n_features = xy_dim*2 + 2
        self.interactNN = NeuralNet(n_features, xy_dim, Nchannels, dropout, batchN, bias)
        self.thetaDotNN = NeuralNet(n_features, 1, Nchannels, dropout, batchN, bias)

        # Scalar magnitude of the self-propulsion term.
        self.selfpropel = nn.Parameter(torch.tensor(0.0, requires_grad=True, device=device))

        self.xy_dim = xy_dim
        self.r = r

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both sub-networks and draw selfpropel from nn.init.uniform_."""
        self.interactNN.reset_parameters()
        self.thetaDotNN.reset_parameters()
        nn.init.uniform_(self.selfpropel)

    def load_celltypes(self, celltype):
        """Store the per-node type tensor; forward() attaches it to the graph as 'type'."""
        self.celltype = celltype

    def calc_message(self, edges):
        """Compute the message 'm' (positional part followed by the theta part) for each edge."""
        theta_dst = edges.dst['theta']
        cos_dst = torch.cos(theta_dst)
        sin_dst = torch.sin(theta_dst)

        # Displacement between destination and source, rotated into the
        # destination's heading frame.
        rel_pos = calc_dr(edges.dst['x'], edges.src['x'])
        dx_para = cos_dst * rel_pos[..., :1] + sin_dst * rel_pos[..., 1:]
        dx_perp = cos_dst * rel_pos[..., 1:] - sin_dst * rel_pos[..., :1]

        # Heading of the source relative to the destination.
        rel_theta = edges.src['theta'] - theta_dst
        features = torch.cat((dx_para, dx_perp,
                              torch.cos(rel_theta), torch.sin(rel_theta),
                              edges.dst['type'], edges.src['type']), -1)

        # Positional message is predicted in the heading frame and rotated back.
        rot_m_v = self.interactNN(features)
        m_v = torch.cat((cos_dst * rot_m_v[..., :1] - sin_dst * rot_m_v[..., 1:],
                         cos_dst * rot_m_v[..., 1:] + sin_dst * rot_m_v[..., :1]), -1)

        m_theta = self.thetaDotNN(features)

        return {'m': torch.cat((m_v, m_theta), -1)}

    def forward(self, xv):
        """Build the neighbour graph from xv and return the summed messages plus self-propulsion.

        The first xy_dim columns of xv are positions and the next column is the
        heading angle. load_celltypes() must have been called beforehand.
        """
        positions = xv[..., :self.xy_dim]
        r_g = makeGraph(positions, self.r/2)
        r_g.ndata['x'] = positions
        r_g.ndata['theta'] = xv[..., self.xy_dim:(self.xy_dim+1)]
        r_g.ndata['type'] = self.celltype
        r_g.update_all(self.calc_message, fn.sum('m', 'a'))

        # Add the self-propulsion along each node's own heading to the positional part.
        heading = torch.cat((torch.cos(r_g.ndata['theta']), torch.sin(r_g.ndata['theta'])), -1)
        r_g.ndata['a'][..., :self.xy_dim] = r_g.ndata['a'][..., :self.xy_dim] + self.selfpropel * heading

        return r_g.ndata['a']



In [31]:
# Load the trained model. The file holds the whole pickled module, so the
# NeuralNet / ActiveNet classes must already be defined in this notebook.
# NOTE(review): unpickling executes code from the file -- only load trusted model files.
filename1 = modeldirName + model_name + '_Model.pt'
gnn_model = torch.load(filename1, map_location=torch.device(device)) # previously: pickle.load(open(filename1, 'rb'))
# Inference mode (affects the dropout / batch-norm layers).
gnn_model.eval()

# Metadata saved next to the model.
filename3 = modeldirName + model_name + '_Separation.npz'
learn_params = np.load(filename3, allow_pickle=True)

printNPZ(learn_params)

dr_thresh 4
T_pred 5
batch_size 8
train_inds [ 9  4 16  1 20 11  2 17  3 13  5  7 19 23 24  6  0 21 18  8]
valid_inds [22 15 10]
test_inds [14 12]
val_loss_log [[1.4807345  0.22311947]
 [1.47467708 0.22288243]
 [1.46873724 0.22261202]
 [1.46383406 0.22252834]
 [1.46241756 0.22251541]
 [1.45830144 0.22229846]
 [1.4527992  0.222151  ]
 [1.44936371 0.22212393]
 [1.44625042 0.22192037]
 [1.44326783 0.22176205]
 [1.44102412 0.22160732]
 [1.43928129 0.22138297]
 [1.43684123 0.22136499]
 [1.43371219 0.221305  ]
 [1.43149202 0.22114536]
 [1.42898182 0.22118869]
 [1.42733962 0.22114763]
 [1.42507449 0.22106917]
 [1.42353522 0.22113154]
 [1.42252574 0.22101138]
 [1.42104922 0.22096432]
 [1.42078752 0.22097922]
 [1.41982086 0.22111379]
 [1.41859626 0.22115344]
 [1.41787878 0.22097809]
 [1.41752318 0.22083972]
 [1.41657339 0.22084716]
 [1.41595671 0.22076702]
 [1.41476222 0.22079086]
 [1.41263541 0.22074121]
 [1.41147006 0.22065643]
 [1.40921505 0.22059962]
 [1.40953807 0.22070963]
 [1.4093303  0.

In [32]:
# Debugging aid: list the attributes of the loaded model.
dir(gnn_model)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_buffers',
 '_call_impl',
 '_forward_hooks',
 '_forward_pre_hooks',
 '_get_backward_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_set',
 '_parameters',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_slow_forward',
 '_state_dict_hooks',
 '_ver

In [33]:
# Threshold distance for making edges, taken from the metadata saved with the model.
dr_thresh = learn_params['dr_thresh'].item()

# Fixed noise strength; set to 0 here instead of the value stored in params['D'].
D = 0.0 # params['D']

# Boundary conditions (L=100 in the article); L itself is defined together with calc_dr.
#L = params['L']
periodic = params['periodic']

# Cell type ratio and the quiver colour used for each type.
cellType_ratio = params['cellType_ratio']
quiv_colors = params['quiv_colors']

# Number of cells, size of the state per cell, and number of noise sources.
batch_size, state_size, brownian_size = params['batch_size'], 3, 1

# Duration of the simulation (6400 in the article), shortened by the starting frame.
t_max =  params['t_max'] - initTime

# Method to solve the SDE (unused here).
#methodSDE = 'euler'
#isIto = True

# Time step of one simulated update.
stepSDE = 1


In [34]:
# Overwrite the cutoff stored in the loaded model with the value from the training metadata.
gnn_model.r = dr_thresh

In [35]:
# Display the cutoff now used by the model.
gnn_model.r

4

In [36]:
# Debugging aid: dump all learned parameters (the output is very long).
print(gnn_model.state_dict())

OrderedDict([('selfpropel', tensor(0.9046, device='cuda:1')), ('interactNN.layer1.weight', tensor([[ 6.8693e-02,  2.2930e-02,  2.8067e-01, -3.1570e-01, -2.2064e-01,
         -4.7282e-01],
        [ 3.6042e-02,  1.7907e-01, -4.5696e-01, -3.4711e-01,  3.2970e-01,
         -7.4623e-02],
        [ 3.3213e-01, -1.0251e-01, -1.4388e-01,  3.2349e-02, -1.4225e-02,
          1.8594e-01],
        [ 8.1909e-02, -5.9736e-01, -6.5860e-02, -6.6977e-02, -9.3137e-02,
         -3.7936e-01],
        [ 1.7986e-01, -2.9950e-01, -1.8117e-01,  2.3075e-01, -9.3827e-02,
          7.1142e-04],
        [ 1.7699e-01,  1.2099e-01, -4.0220e-01, -1.5253e-01, -3.9307e-01,
          7.2420e-02],
        [ 2.8259e-01,  1.4268e-01, -2.3052e-01, -7.4569e-02,  2.1027e-01,
          1.1247e-01],
        [ 5.5769e-01,  4.0678e-01, -5.3094e-02,  8.4206e-02,  6.0408e-02,
          2.0130e-02],
        [-8.2915e-02,  2.6401e-01,  1.1098e-01,  2.4484e-02,  1.4117e-01,
         -2.9952e-01],
        [ 4.4458e-01, -6.7364e-02, -

In [37]:
# Assign a type index and a quiver colour to every cell: consecutive blocks of
# cells get type 0, 1, ... with block sizes round(batch_size * ratio).
# NOTE(review): if the rounded block sizes do not add up to batch_size, the
# remaining cells keep type 0 and an empty colour string -- confirm the ratios.
celltype_label = torch.zeros([batch_size, 1], device=device)
quiv_label = np.full([batch_size], '')
count = 0  # start index of the current block
for ict, ctr in enumerate(cellType_ratio):
    Nct = int(np.round(batch_size*ctr)) + count  # end index (exclusive) of the current block
    celltype_label[count:Nct, 0] = ict
    quiv_label[count:Nct] = quiv_colors[ict]
    count = Nct
    
# Hand the type labels to the model; forward() attaches them to the graph nodes.
gnn_model.load_celltypes(celltype_label.to(device))

In [38]:
# Noise amplitude applied to the Brownian increments in calc_onestep.
sigma = np.sqrt(2*D)

# Brownian motion source over [0, t_max] with one sample of size brownian_size per cell.
bm = BrownianInterval(t0=0, 
                      t1=float(t_max), 
                      size=(batch_size, brownian_size),
                      dt=stepSDE,
                      device=device)

def calc_onestep(x, t):
    """Advance the state x by one step starting at time t.

    The model output is added to x, and the Brownian increment over
    [t, t+stepSDE], scaled by sigma, is added to the last state component.
    """
    update = gnn_model(x)
    noise = sigma * bm(t, t+stepSDE)

    stepped = x + update
    stepped[..., -1:] += noise

    return stepped


In [39]:
# Initial state y0 has shape (batch_size, state_size): positions followed by the angle theta.
# ys stores every simulated frame and has shape (t_size, batch_size, state_size).

if load_initState:
    # Start from frame `initTime` of the recorded trajectory.
    # NOTE(review): torch.tensor keeps the dtype of the recorded arrays --
    # confirm it matches the dtype of the model parameters.
    theta0 = torch.tensor(traj['theta'][initTime], device=device).view(-1,1)
    y0 = torch.concat((torch.tensor(traj['xy'][initTime], device=device), 
                       theta0), 1)

else:
    # Random start: angles uniform in [0, 2*pi), positions uniform in [0, L).
    # No random seed is set in the visible cells, so this differs between runs.
    theta0 = torch.rand((batch_size, 1), device=device) * (2 * np.pi)
    y0 = torch.concat((torch.rand((batch_size, state_size-1), device=device) * L,
                       theta0), 1)

print(y0)

# Frame times 0, 1, ..., t_max.
ts = np.linspace(0, t_max, t_max+1)

# Trajectory buffer; frame 0 is the initial state.
ys = np.zeros((len(ts), batch_size, state_size))
ys[0] = y0.cpu().detach().numpy()

print(ys[0])

tensor([[ 7.2510, 14.7978,  1.4797],
        [11.2607,  2.9181,  4.1369],
        [18.8038,  4.5726,  3.5098],
        ...,
        [ 9.3847,  2.1352,  4.2550],
        [16.6394,  8.4456,  3.5421],
        [13.5632, 10.4455,  4.0421]], device='cuda:1')
[[ 7.25104904 14.797822    1.47971272]
 [11.26065159  2.91810226  4.13685608]
 [18.80384254  4.5726161   3.50983477]
 ...
 [ 9.38468552  2.13523912  4.25496912]
 [16.63944054  8.44555855  3.54212236]
 [13.56318855 10.44546318  4.04212189]]


In [40]:
# Roll the model forward one step per entry of ts (except the last) and store
# each new state in the next frame of ys. Gradients are not needed here.
y = y0.clone()

with torch.no_grad():
    for it, t in enumerate(ts[:-1]):
        y = calc_onestep(y, t)
        ys[it+1] = y.cpu().detach().numpy()

In [41]:
# Global matplotlib style for all figures in this notebook.
plt.rcParams.update({
    'font.size': 18,
    'figure.figsize': [8,6],
    'savefig.bbox': 'tight',
    'lines.markersize': 11.0,
    'lines.linewidth': 3.5,
})

In [42]:
cm = plt.cm.hsv

@interact(t_plot=(0, t_max-1))
def f(t_plot):
    """Draw frame `t_plot` of ys as arrows at the cell positions, pointing along theta."""
    frame = ys[t_plot]
    box = L.item()

    fig, ax = plt.subplots()
    ax.set_aspect('equal')
    # Positions are folded back into the box before plotting.
    im = ax.quiver(frame[:, 0] % box, frame[:, 1] % box,
                   np.cos(frame[:, 2]), np.sin(frame[:, 2]),
                   color=quiv_label, scale = 1, scale_units='xy')
    plt.xlim(0, L)
    plt.ylim(0, L)
    #fig.colorbar(im)
    im.set_clim(0, 2*np.pi)


interactive(children=(IntSlider(value=99, description='t_plot', max=199), Output()), _dom_classes=('widget-int…

In [43]:
import pickle
import datetime
import codecs

# Output directory: <savedirStr>simulate_D<D>_init<initTime>/<timestamp>/
now = datetime.datetime.now()
nowstr = now.strftime('%Y%m%d_%H%M%S')

savedirName = savedirStr + 'simulate_D'+str(D)+'_init' +str(initTime)+ '/' + nowstr + '/'
os.makedirs(savedirName, exist_ok=True)

# Simulated trajectory together with the cell-type labels and quiver colours.
np.savez(savedirName + 'results.npz', ys=ys, ts=ts, 
         celltype_label=celltype_label.cpu().detach().numpy(), quiv_label=quiv_label)

#torch.save(gnn_model.to('cpu'), savedirName + nowstr + '_Model.pt')

np.savez(savedirName + nowstr + '_Separation.npz', dr_thresh=dr_thresh, batch_size=batch_size)

# Settings used for this simulation.
np.savez(os.path.join(savedirName, 'params.npz'),
         D = D, dr_thresh = dr_thresh,
         cellType_ratio = cellType_ratio, quiv_colors = quiv_colors,
         batch_size = batch_size, state_size = state_size, brownian_size = brownian_size,
         t_max = t_max, stepSDE = stepSDE)

loaddir = savedirName

# Read the saved settings back to write a human-readable copy. They are loaded
# under their own name so the notebook-level `params` (which earlier cells read,
# e.g. params['periodic']) is not replaced, and the archive is closed afterwards.
with np.load(loaddir+'/params.npz', allow_pickle=True) as saved_params:
    txtstring = ["{} = {}".format(k, saved_params[k]) for k in saved_params.files]

# The text file is closed explicitly so its content is flushed to disk.
with codecs.open(loaddir+'/params.txt', 'w', 'utf-8') as params_txt:
    print(*txtstring, sep="\n", file=params_txt)


def _save_frame(t_plot):
    """Render frame `t_plot` of ys as a quiver plot and save it as t<NNNN>.png in savedirName."""
    box = L.item()
    # Positions are folded back into the box; locals keep the notebook-level
    # state tensor `y` from being overwritten.
    x = ys[t_plot,:,0] % box
    y = ys[t_plot,:,1] % box
    px = np.cos(ys[t_plot,:,2])
    py = np.sin(ys[t_plot,:,2])

    fig, ax = plt.subplots()
    ax.set_aspect('equal')
    im = ax.quiver(x, y, px, py, color=quiv_label, scale = 1, scale_units='xy')
    ax.set_xlim(0, box)
    ax.set_ylim(0, box)
    #fig.colorbar(im)
    im.set_clim(0, 2*np.pi)

    fig.savefig(savedirName + 't'+format(t_plot, '04d')+'.png')

    # Close the figure so that figures do not accumulate in memory.
    plt.close(fig)


# NOTE(review): ys holds t_max+1 frames, so the final frame is not saved here -- confirm intended.
for t_plot in range(t_max):
    _save_frame(t_plot)


In [44]:
# Show the base directory under which the simulation results were saved.
print(savedirStr)

./HiraiwaModel_chem20221102_020319/ActiveNet_vp_rotsym_noSelfLoop_multiStep_fineTuning_batchNorm/20221130_164456/
