In [5]:
#!pip install dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html

In [6]:
import os
os.environ['DGLBACKEND'] = 'pytorch'

In [72]:
import torch
from torch import nn

import dgl
import dgl.function as fn

import networkx as nx
from matplotlib import pyplot as plt
import numpy as np

import os
import gc

# デバイス設定
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [8]:
def printNPZ(npz):
    for kw in npz.files:
        print(kw, npz[kw])

In [29]:
dirName = './HiraiwaModel_chem20220916_150816/'
savedirName = dirName + 'ActiveNet_vp_rotsym/'
os.makedirs(savedirName, exist_ok=True)

params = np.load(dirName+'params.npz')
traj = np.load(dirName+'result.npz')

printNPZ(params)
printNPZ(traj)

v0 1.0
r 1.0
D 0.1
A 0.0
L 20
rho 1.0
beta 1.0
A_CFs [0.1 0.9]
A_CIL 0.0
cellType_ratio [0.5 0.5]
quiv_colors ['k' 'r']
kappa 0.5
A_Macdonalds [2.  0.2]
batch_size 400
state_size 3
brownian_size 1
periodic True
t_max 1000
methodSDE heun
isIto False
stepSDE 0.01
xy [[[17.648756    7.6790533 ]
  [12.942514   10.766076  ]
  [ 1.1944778  10.673973  ]
  ...
  [11.097479   10.400983  ]
  [ 6.7026143   6.4549184 ]
  [ 3.938607   13.604414  ]]

 [[18.273907    7.9960933 ]
  [13.940152   10.937266  ]
  [ 2.0769567  10.777703  ]
  ...
  [11.048374    9.450381  ]
  [ 6.242977    5.886885  ]
  [ 4.581452   13.308872  ]]

 [[18.163706    8.013234  ]
  [14.596919   11.515686  ]
  [ 2.7184641  11.058035  ]
  ...
  [10.640358    8.482995  ]
  [ 5.569696    5.964412  ]
  [ 4.982173   12.816619  ]]

 ...

 [[19.538105    8.743137  ]
  [ 0.2186718   5.0863724 ]
  [18.22847     4.743004  ]
  ...
  [ 0.6903497   2.5728188 ]
  [ 2.156601    7.8962593 ]
  [ 1.3712387   7.0354767 ]]

 [[19.575407    8.669537 

In [47]:
if params['periodic']:
    L = torch.tensor(params['L'])
    def calc_dr(r1, r2):
        dr = torch.remainder((r1 - r2), L)
        dr[dr > L/2] = dr[dr > L/2] - L
        return dr
else:
    def calc_dr(r1, r2):
        return r1 - r2
    
def makeGraph(x_data, r_thresh):
        Ndata = x_data.size(0)
        dx = calc_dr(torch.unsqueeze(x_data, 0), torch.unsqueeze(x_data, 1))
        dx = torch.sum(dx**2, dim=2)
        edges = torch.argwhere(dx < r_thresh/2)
        return dgl.graph((edges[:,0], edges[:,1]), num_nodes=Ndata)

In [48]:
dr_thresh = 4
dt = 1
batch_size = 8

N_data = params['t_max']

TR_VA_rate = np.array([0.6, 0.2])

TR_last = np.int_(np.ceil(TR_VA_rate[0] * N_data))
VA_last = np.int_(np.ceil((TR_VA_rate[0]+TR_VA_rate[1]) * N_data))

shuffle_inds = np.arange(N_data, dtype=int)
np.random.shuffle(shuffle_inds)

train_inds = shuffle_inds[:TR_last]
valid_inds = shuffle_inds[TR_last:VA_last]
test_inds = shuffle_inds[VA_last:]

celltypes = torch.tensor(traj['celltype_label']).view(-1,1)

xy_t = torch.tensor(traj['xy'][:-1,:,:])
v_t = calc_dr(torch.tensor(traj['xy'][1:,:,:]), torch.tensor(traj['xy'][:-1,:,:])) / dt
p_t = torch.unsqueeze(torch.tensor(traj['theta'][:-1,:]), dim=2)
w_t = torch.unsqueeze(torch.tensor((traj['theta'][1:,:]-traj['theta'][:-1,:])%(2*np.pi)/dt), dim=2)
#dr = (calc_dr(xy_t[:,:1], xy_t[:,:1].T)**2 + calc_dr(xy_t[:,1:2], xy_t[:,1:2].T)**2)**0.5
#edge_inds = torch.nonzero(torch.logical_and(dr <= dr_thresh, dr>0)).T

train_dataset = torch.utils.data.TensorDataset(
    torch.concat((xy_t[train_inds], p_t[train_inds]), -1),
    torch.concat((v_t[train_inds], w_t[train_inds]), -1))

valid_dataset = torch.utils.data.TensorDataset(
    torch.concat((xy_t[valid_inds], p_t[valid_inds]), -1),
    torch.concat((v_t[valid_inds], w_t[valid_inds]), -1))

test_dataset = torch.utils.data.TensorDataset(
    torch.concat((xy_t[test_inds], p_t[test_inds]), -1),
    torch.concat((v_t[test_inds], w_t[test_inds]), -1))


train_data = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, pin_memory=True)
valid_data = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, pin_memory=True)
test_data = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, pin_memory=True)

del train_dataset, valid_dataset, test_dataset#, xy_t, v_t, p_t, w_t, traj
gc.collect()

#print(data)
#print(data.num_graphs)
#print(data.x)
#print(data.y)
#print(data.edge_index)

55170

In [49]:
print(train_inds)

[894 886   0 912 738 252 976 740 706 893  77  65 632 988 994 741 281 107
 209 380 167  36 306 465 780 495 805 331 580 724 813 225 714 384 825 244
 214 980 311 294 713 129 884 151 677 357 471 180 761 148 139 478 618 337
 523 312 823 873 559 549 274 895 930 695  71 206  12 801 902 491 907 984
 546 177  67 275 233 649 844 159  20 897 972 513 842  42 947 304 621  41
 106 635  58 663 485 919 182 323 819 387 112 981 585 817 608  81 918 676
 455 703 675 325  30 969 411 310 155 882 355 735 709 545 305  98 183 747
 496 221  64 407 377 992 689 434 732 836 584 512 739 885  28 948 648 314
 218 787 226 694 273 505 997 757 335  22 259 808 558 750 864 567  62 846
 718 719  61 963 303 667 923 629  94 596 480 292 989  66 650 540 843 669
 746 647 944 804 101 899 279  29 520 749 230 205 157 514 861 163 800 379
 334 392 852 204 889 430 625 726 865 662 451 553 126 258 481 140 474 942
 791 535 581 468 658 913 644 720 615 727 701 797 122   4 792 196 315 869
  15 822 493 194 100  31 360 807 166 270   2 810 36

In [50]:
dir(train_data)

['_DataLoader__initialized',
 '_DataLoader__multiprocessing_context',
 '_IterableDataset_len_called',
 '__annotations__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_auto_collation',
 '_dataset_kind',
 '_get_iterator',
 '_get_shared_seed',
 '_index_sampler',
 '_is_protocol',
 '_iterator',
 'batch_sampler',
 'batch_size',
 'check_worker_number_rationality',
 'collate_fn',
 'dataset',
 'drop_last',
 'generator',
 'multiprocessing_context',
 'num_workers',
 'persistent_workers',
 'pin_memory',
 'pin_memory_device',
 'prefetch_factor',
 'sampler',
 'timeout',
 'worker_i

In [51]:
def plotGraph(data):

    # networkxのグラフに変換
    nxg = dgl.to_networkx(data)

    # 可視化のためのページランク計算
    pr = nx.pagerank(nxg)
    pr_max = np.array(list(pr.values())).max()

    # 可視化する際のノード位置
    draw_pos = nx.spring_layout(nxg, seed=0) 

    # ノードの色設定
    cmap = plt.get_cmap('tab10')
    labels = data.y.numpy()
    colors = [cmap(l) for l in labels]

    # 図のサイズ
    fig0 = plt.figure(figsize=(10, 10))

    # 描画
    nx.draw_networkx_nodes(nxg, 
                          draw_pos,
                          node_size=[v / pr_max * 1000 for v in pr.values()])#,
                          #node_color=colors, alpha=0.5)
    nx.draw_networkx_edges(nxg, draw_pos, arrowstyle='-', alpha=0.2)
    nx.draw_networkx_labels(nxg, draw_pos, font_size=10)

    #plt.title('KarateClub')
    plt.show()

    return fig0

In [81]:
class NeuralNet(nn.Module):
    def __init__(self, in_channels, out_channels, Nchannels, flgBias=False):
        super(NeuralNet, self).__init__()

        self.layer1 = nn.Linear(in_channels, Nchannels, bias=flgBias)
        self.layer2 = nn.Linear(Nchannels, Nchannels, bias=flgBias)
        self.layer3 = nn.Linear(Nchannels, Nchannels, bias=flgBias)
        self.layer4 = nn.Linear(Nchannels, out_channels, bias=flgBias)

        self.activation = nn.ReLU()

    def reset_parameters(self):
        self.layer1.reset_parameters()
        self.layer2.reset_parameters()
        self.layer3.reset_parameters()
        self.layer4.reset_parameters()
        #nn.init.zeros_(self.layer1.weight)
        #nn.init.zeros_(self.layer2.weight)
        #nn.init.zeros_(self.layer3.weight)
        #nn.init.zeros_(self.layer4.weight)
        
    def forward(self, x):
        out = self.activation(self.layer1(x))
        out = self.activation(self.layer2(out))
        out = self.activation(self.layer3(out))
        out = self.layer4(out)

        return out

class ActiveNet(nn.Module):
    def __init__(self, xy_dim, r, Nchannels=128):
        super().__init__()

        self.interactNN = NeuralNet(xy_dim*2 + 2, xy_dim, Nchannels, False)

        self.thetaDotNN = NeuralNet(xy_dim*2 + 2, 1, Nchannels, False)
        
        self.selfpropel = nn.Parameter(torch.tensor(0.0, requires_grad=True, device=device))

        #self.Normalizer = nn.Softmax(dim=1)

        self.xy_dim = xy_dim
        
        self.r = r

        self.reset_parameters()

    def reset_parameters(self):
        self.interactNN.reset_parameters()

        self.thetaDotNN.reset_parameters()
        
        nn.init.uniform_(self.selfpropel)

        #self.bias.data.zero_()
        
    def load_celltypes(self, celltype):
        self.celltype = celltype

    def calc_message(self, edges):
        dx = calc_dr(edges.dst['x'], edges.src['x'])

        costheta = torch.cos(edges.dst['theta'])
        sintheta = torch.sin(edges.dst['theta'])

        dx_para = costheta * dx[..., :1] + sintheta * dx[..., 1:]
        dx_perp = costheta * dx[..., 1:] - sintheta * dx[..., :1]

        p_para_src = torch.cos(edges.src['theta'] - edges.dst['theta'])
        p_perp_src = torch.sin(edges.src['theta'] - edges.dst['theta'])

        rot_m_v = self.interactNN(torch.concat((dx_para, dx_perp, 
                                                p_para_src, p_perp_src,
                                                edges.dst['type'], edges.src['type']), -1))

        m_v = torch.concat((costheta * rot_m_v[..., :1] - sintheta * rot_m_v[..., 1:],
                            costheta * rot_m_v[..., 1:] + sintheta * rot_m_v[..., :1]), -1)

        m_theta = self.thetaDotNN(torch.concat((dx_para, dx_perp, 
                                                p_para_src, p_perp_src, 
                                                edges.dst['type'], edges.src['type']), -1))
        
        return {'m': torch.concat((m_v, m_theta), -1)}
        
    def forward(self, xv):
        r_g = makeGraph(xv[..., :self.xy_dim], self.r/2)
        r_g.ndata['x'] = xv[..., :self.xy_dim]
        r_g.ndata['theta'] = xv[..., self.xy_dim:(self.xy_dim+1)]
        r_g.ndata['type'] = self.celltype
        r_g.update_all(self.calc_message, fn.sum('m', 'a'))
        r_g.ndata['a'][..., :self.xy_dim] = r_g.ndata['a'][..., :self.xy_dim] + self.selfpropel * torch.concat((torch.cos(r_g.ndata['theta']), torch.sin(r_g.ndata['theta'])), -1)
        
        return r_g.ndata['a']



In [99]:
def myLoss(out, target):
    dv = torch.sum(torch.square(out[..., :xy_dim] - target[..., :xy_dim]), dim=-1)
    dcos = 1 - torch.cos(out[..., xy_dim] - target[..., xy_dim])
    return torch.mean(dv + dcos)

In [102]:
# モデルのインスタンス生成
xy_dim = 2

model = ActiveNet(xy_dim, dr_thresh, Nchannels=128).to(device)
model.load_celltypes(celltypes.to(device))
# input data
#data = dataset[0]

model.train()

# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
#optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

val_loss_log = []

val_loss_min = np.Inf

# learnig loop
for epoch in range(300):
    for batch_x, batch_y in train_data:
        optimizer.zero_grad()
        loss = 0
        for ib in range(batch_x.size(0)):
            out = model(batch_x[ib].to(device))
            loss = loss + myLoss(out, batch_y[ib].to(device))
        loss = loss / batch_x.size(0)
        loss.backward()
        optimizer.step()
    val_loss = 0
    val_count = 0
    with torch.no_grad():
        for batch_x, batch_y in valid_data:
            for ib in range(batch_x.size(0)):
                val_out = model(batch_x[ib].to(device))
                val_loss = val_loss + myLoss(val_out, batch_y[ib].to(device))
            val_count = val_count + batch_x.size(0)
    val_loss = val_loss/val_count
    print('Epoch %d | train Loss: %.4f | valid Loss: %.4f' % (epoch, loss.item(), val_loss.item()))
    val_loss_log.append(val_loss.cpu().item())
    if val_loss.item() < val_loss_min:
        stored_model = model
        val_loss_min = val_loss.item()

Epoch 0 | train Loss: 0.1606 | valid Loss: 0.1879
Epoch 1 | train Loss: 0.1495 | valid Loss: 0.1706
Epoch 2 | train Loss: 0.1474 | valid Loss: 0.1638
Epoch 3 | train Loss: 0.1438 | valid Loss: 0.1617
Epoch 4 | train Loss: 0.1412 | valid Loss: 0.1593
Epoch 5 | train Loss: 0.1395 | valid Loss: 0.1575
Epoch 6 | train Loss: 0.1381 | valid Loss: 0.1561
Epoch 7 | train Loss: 0.1370 | valid Loss: 0.1550
Epoch 8 | train Loss: 0.1359 | valid Loss: 0.1539
Epoch 9 | train Loss: 0.1353 | valid Loss: 0.1530
Epoch 10 | train Loss: 0.1346 | valid Loss: 0.1521
Epoch 11 | train Loss: 0.1341 | valid Loss: 0.1507
Epoch 12 | train Loss: 0.1337 | valid Loss: 0.1491
Epoch 13 | train Loss: 0.1331 | valid Loss: 0.1473
Epoch 14 | train Loss: 0.1322 | valid Loss: 0.1460
Epoch 15 | train Loss: 0.1314 | valid Loss: 0.1452
Epoch 16 | train Loss: 0.1305 | valid Loss: 0.1446
Epoch 17 | train Loss: 0.1296 | valid Loss: 0.1439
Epoch 18 | train Loss: 0.1288 | valid Loss: 0.1434
Epoch 19 | train Loss: 0.1279 | valid Los

Epoch 160 | train Loss: 0.0966 | valid Loss: 0.1138
Epoch 161 | train Loss: 0.0948 | valid Loss: 0.1125
Epoch 162 | train Loss: 0.0945 | valid Loss: 0.1126
Epoch 163 | train Loss: 0.0952 | valid Loss: 0.1128
Epoch 164 | train Loss: 0.0953 | valid Loss: 0.1131
Epoch 165 | train Loss: 0.0946 | valid Loss: 0.1125
Epoch 166 | train Loss: 0.0942 | valid Loss: 0.1119
Epoch 167 | train Loss: 0.0948 | valid Loss: 0.1129
Epoch 168 | train Loss: 0.0942 | valid Loss: 0.1125
Epoch 169 | train Loss: 0.0954 | valid Loss: 0.1133
Epoch 170 | train Loss: 0.0943 | valid Loss: 0.1124
Epoch 171 | train Loss: 0.0949 | valid Loss: 0.1128
Epoch 172 | train Loss: 0.0951 | valid Loss: 0.1130
Epoch 173 | train Loss: 0.0935 | valid Loss: 0.1118
Epoch 174 | train Loss: 0.0940 | valid Loss: 0.1123
Epoch 175 | train Loss: 0.0944 | valid Loss: 0.1125
Epoch 176 | train Loss: 0.0944 | valid Loss: 0.1126
Epoch 177 | train Loss: 0.0935 | valid Loss: 0.1116
Epoch 178 | train Loss: 0.0934 | valid Loss: 0.1121
Epoch 179 | 

In [103]:
model.state_dict()

OrderedDict([('selfpropel', tensor(0.6257, device='cuda:0')),
             ('interactNN.layer1.weight',
              tensor([[ 4.6246e-01,  4.2752e-01,  8.4108e-02, -2.6066e-01, -4.0486e-02,
                       -1.8207e-01],
                      [-2.4736e-01,  4.8428e-01, -7.8616e-03, -1.3916e-01, -1.1201e-01,
                        1.8069e-01],
                      [ 3.4992e-01,  1.7218e-01, -5.1868e-02,  1.9980e-01, -1.1329e-01,
                        2.9223e-01],
                      [ 2.7807e-01,  1.5620e-02, -2.3227e-01,  3.4111e-01,  3.3095e-01,
                       -9.9264e-02],
                      [ 3.3405e-01,  4.7301e-01,  8.9246e-02,  2.4959e-01,  8.5286e-02,
                       -1.0712e-01],
                      [-3.1115e-01,  2.7157e-01, -2.8529e-01,  6.1973e-02, -3.7968e-01,
                       -2.5930e-01],
                      [ 3.6479e-01, -1.2185e-01,  1.1302e-01,  1.6711e-01, -1.1096e-01,
                       -3.1224e-01],
                     

In [104]:
# モデルを評価モードに設定
stored_model.eval()

# 推論
test_loss = 0
test_count = 0
with torch.no_grad():
    for batch in test_data:
        for ib in range(batch_x.size(0)):
            test_out = model(batch_x[ib].to(device))
            test_loss = test_loss + myLoss(test_out, batch_y[ib].to(device))
        test_count = test_count + batch_x.size(0)
test_loss = test_loss/test_count
print('test Loss: %.4f' % (test_loss.item()))

test Loss: 0.1060


In [105]:
import pickle
import datetime

now = datetime.datetime.now()
nowstr = now.strftime('%Y%m%d_%H%M%S')

In [106]:
stored_model.selfpropel.detach()

tensor(0.6257, device='cuda:0')

In [107]:
stored_model = stored_model.to('cpu')

filename1 = savedirName + nowstr + '_Model.pkl'
with open(filename1, "wb") as f:
    pickle.dump(stored_model, f)

filename1_2 = savedirName + nowstr + '_Model.pt'
torch.save(stored_model, filename1_2)

filename2 = savedirName + nowstr
torch.save(stored_model.interactNN.state_dict(), filename2 + '_interactNN.pkl')
torch.save(stored_model.thetaDotNN.state_dict(), filename2 + '_thetaDotNN.pkl')
torch.save(stored_model.selfpropel.detach(), filename2 + '_selfpropel.pkl')

filename3 = savedirName + nowstr + '_Separation.npz'
np.savez(filename3, dr_thresh=dr_thresh, batch_size=batch_size,
         train_inds=train_inds, valid_inds=valid_inds, test_inds=test_inds, 
         val_loss_log=val_loss_log, test_loss=test_loss.cpu())

filename4 = savedirName + nowstr + '_optimizer.pt'
torch.save(optimizer, filename4)