In [1]:
#!pip install dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html

In [2]:
import os
os.environ['DGLBACKEND'] = 'pytorch'

In [3]:
import torch
from torch import nn

import dgl
import dgl.function as fn

import networkx as nx
from matplotlib import pyplot as plt
import numpy as np

import os
import gc

# デバイス設定
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [4]:
def printNPZ(npz):
    for kw in npz.files:
        print(kw, npz[kw])

In [10]:
dirName = './HiraiwaModel_chem20220922_180005/'
savedirName = dirName + 'ActiveNet_vp_rotsym/'
os.makedirs(savedirName, exist_ok=True)

params = np.load(dirName+'params.npz')
#traj = np.load(dirName+'result.npz')

printNPZ(params)
#printNPZ(traj)

v0 1.0
r 1.0
D 0.1
A 0.0
L 20
rho 1.0
beta 1.0
A_CFs [0.1 0.9]
A_CIL 0.0
cellType_ratio [0.5 0.5]
quiv_colors ['k' 'r']
kappa 0.5
A_Macdonalds [2.  0.2]
batch_size 400
state_size 3
brownian_size 1
periodic True
t_max 1000
methodSDE heun
isIto False
stepSDE 0.01


In [11]:
if params['periodic']:
    L = torch.tensor(params['L'])
    def calc_dr(r1, r2):
        dr = torch.remainder((r1 - r2), L)
        dr[dr > L/2] = dr[dr > L/2] - L
        return dr
else:
    def calc_dr(r1, r2):
        return r1 - r2
    
def makeGraph(x_data, r_thresh):
        Ndata = x_data.size(0)
        dx = calc_dr(torch.unsqueeze(x_data, 0), torch.unsqueeze(x_data, 1))
        dx = torch.sum(dx**2, dim=2)
        edges = torch.argwhere(dx < r_thresh/2)
        return dgl.graph((edges[:,0], edges[:,1]), num_nodes=Ndata)

In [12]:
subdir_list = [f.path for f in os.scandir(dirName) if f.is_dir()]

print(subdir_list)

datadir_list = [f for f in subdir_list if 'result.npz' in [ff.name for ff in os.scandir(f)]]

print(datadir_list)

['./HiraiwaModel_chem20220916_150816/20221013_130509', './HiraiwaModel_chem20220916_150816/20221013_222508', './HiraiwaModel_chem20220916_150816/20221014_020927', './HiraiwaModel_chem20220916_150816/20221014_055334', './HiraiwaModel_chem20220916_150816/20221014_093729', './HiraiwaModel_chem20220916_150816/ActiveNet_celltypes', './HiraiwaModel_chem20220916_150816/ActiveNet_vp_rotsym']
['./HiraiwaModel_chem20220916_150816/20221013_130509', './HiraiwaModel_chem20220916_150816/20221013_222508', './HiraiwaModel_chem20220916_150816/20221014_020927', './HiraiwaModel_chem20220916_150816/20221014_055334', './HiraiwaModel_chem20220916_150816/20221014_093729']


In [39]:
dr_thresh = 7
dt = 1
batch_size = 8

N_data = len(datadir_list)

#TR_VA_rate = np.array([0.6, 0.2])

TR_last = 3
VA_last = 4

shuffle_inds = np.arange(N_data, dtype=int)
np.random.shuffle(shuffle_inds)

train_inds = shuffle_inds[:TR_last]
valid_inds = shuffle_inds[TR_last:VA_last]
test_inds = shuffle_inds[VA_last:]

celltype_lst = []

train_x = []
valid_x = []
test_x = []

train_y = []
valid_y = []
test_y = []

train_i_dir = []
valid_i_dir = []
test_i_dir = []

for i_dir, subdirName in enumerate(datadir_list):
    
    traj = np.load(subdirName+'/result.npz')
    
    celltype_lst.append(torch.tensor(traj['celltype_label']).view(-1,1))

    xy_t = torch.tensor(traj['xy'][:-1,:,:])
    v_t = calc_dr(torch.tensor(traj['xy'][1:,:,:]), torch.tensor(traj['xy'][:-1,:,:])) / dt
    p_t = torch.unsqueeze(torch.tensor(traj['theta'][:-1,:]), dim=2)
    w_t = torch.unsqueeze(torch.tensor((traj['theta'][1:,:]-traj['theta'][:-1,:])%(2*np.pi)/dt), dim=2)
    
    if i_dir in train_inds:
        train_x.append(torch.concat((xy_t, p_t), -1))
        train_y.append(torch.concat((v_t, w_t), -1))
        train_i_dir.append(torch.ones([xy_t.size(0)])*i_dir)

    if i_dir in valid_inds:
        valid_x.append(torch.concat((xy_t, p_t), -1))
        valid_y.append(torch.concat((v_t, w_t), -1))
        valid_i_dir.append(torch.ones([xy_t.size(0)])*i_dir)
        
    if i_dir in test_inds:
        test_x.append(torch.concat((xy_t, p_t), -1))
        test_y.append(torch.concat((v_t, w_t), -1))
        test_i_dir.append(torch.ones([xy_t.size(0)])*i_dir)
    
train_dataset = torch.utils.data.TensorDataset(
    torch.concat(train_x, 0), 
    torch.concat(train_y, 0), 
    torch.concat(train_i_dir, 0))

valid_dataset = torch.utils.data.TensorDataset(
    torch.concat(valid_x, 0), 
    torch.concat(valid_y, 0), 
    torch.concat(valid_i_dir, 0))

test_dataset = torch.utils.data.TensorDataset(
    torch.concat(test_x, 0), 
    torch.concat(test_y, 0), 
    torch.concat(test_i_dir, 0))

train_data = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, pin_memory=True)
valid_data = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, pin_memory=True)
test_data = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, pin_memory=True)

del train_x, train_y, train_i_dir, train_dataset
del valid_x, valid_y, valid_i_dir, valid_dataset
del test_x, test_y, test_i_dir, test_dataset
gc.collect()

#print(data)
#print(data.num_graphs)
#print(data.x)
#print(data.y)
#print(data.edge_index)

3256

In [40]:
print(train_inds)

[1 2 4]


In [41]:
dir(train_data)

['_DataLoader__initialized',
 '_DataLoader__multiprocessing_context',
 '_IterableDataset_len_called',
 '__annotations__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_auto_collation',
 '_dataset_kind',
 '_get_iterator',
 '_get_shared_seed',
 '_index_sampler',
 '_is_protocol',
 '_iterator',
 'batch_sampler',
 'batch_size',
 'check_worker_number_rationality',
 'collate_fn',
 'dataset',
 'drop_last',
 'generator',
 'multiprocessing_context',
 'num_workers',
 'persistent_workers',
 'pin_memory',
 'pin_memory_device',
 'prefetch_factor',
 'sampler',
 'timeout',
 'worker_i

In [42]:
def plotGraph(data):

    # networkxのグラフに変換
    nxg = dgl.to_networkx(data)

    # 可視化のためのページランク計算
    pr = nx.pagerank(nxg)
    pr_max = np.array(list(pr.values())).max()

    # 可視化する際のノード位置
    draw_pos = nx.spring_layout(nxg, seed=0) 

    # ノードの色設定
    cmap = plt.get_cmap('tab10')
    labels = data.y.numpy()
    colors = [cmap(l) for l in labels]

    # 図のサイズ
    fig0 = plt.figure(figsize=(10, 10))

    # 描画
    nx.draw_networkx_nodes(nxg, 
                          draw_pos,
                          node_size=[v / pr_max * 1000 for v in pr.values()])#,
                          #node_color=colors, alpha=0.5)
    nx.draw_networkx_edges(nxg, draw_pos, arrowstyle='-', alpha=0.2)
    nx.draw_networkx_labels(nxg, draw_pos, font_size=10)

    #plt.title('KarateClub')
    plt.show()

    return fig0

In [43]:
class NeuralNet(nn.Module):
    def __init__(self, in_channels, out_channels, Nchannels, flgBias=False):
        super(NeuralNet, self).__init__()

        self.layer1 = nn.Linear(in_channels, Nchannels, bias=flgBias)
        self.layer2 = nn.Linear(Nchannels, Nchannels, bias=flgBias)
        self.layer3 = nn.Linear(Nchannels, Nchannels, bias=flgBias)
        self.layer4 = nn.Linear(Nchannels, out_channels, bias=flgBias)

        self.activation = nn.ReLU()

    def reset_parameters(self):
        self.layer1.reset_parameters()
        self.layer2.reset_parameters()
        self.layer3.reset_parameters()
        self.layer4.reset_parameters()
        #nn.init.zeros_(self.layer1.weight)
        #nn.init.zeros_(self.layer2.weight)
        #nn.init.zeros_(self.layer3.weight)
        #nn.init.zeros_(self.layer4.weight)
        
    def forward(self, x):
        out = self.activation(self.layer1(x))
        out = self.activation(self.layer2(out))
        out = self.activation(self.layer3(out))
        out = self.layer4(out)

        return out

class ActiveNet(nn.Module):
    def __init__(self, xy_dim, r, Nchannels=128):
        super().__init__()

        self.interactNN = NeuralNet(xy_dim*2 + 2, xy_dim, Nchannels, False)

        self.thetaDotNN = NeuralNet(xy_dim*2 + 2, 1, Nchannels, False)
        
        self.selfpropel = nn.Parameter(torch.tensor(0.0, requires_grad=True, device=device))

        #self.Normalizer = nn.Softmax(dim=1)

        self.xy_dim = xy_dim
        
        self.r = r

        self.reset_parameters()

    def reset_parameters(self):
        self.interactNN.reset_parameters()

        self.thetaDotNN.reset_parameters()
        
        nn.init.uniform_(self.selfpropel)

        #self.bias.data.zero_()
        
    def load_celltypes(self, celltype):
        self.celltype = celltype

    def calc_message(self, edges):
        dx = calc_dr(edges.dst['x'], edges.src['x'])

        costheta = torch.cos(edges.dst['theta'])
        sintheta = torch.sin(edges.dst['theta'])

        dx_para = costheta * dx[..., :1] + sintheta * dx[..., 1:]
        dx_perp = costheta * dx[..., 1:] - sintheta * dx[..., :1]

        p_para_src = torch.cos(edges.src['theta'] - edges.dst['theta'])
        p_perp_src = torch.sin(edges.src['theta'] - edges.dst['theta'])

        rot_m_v = self.interactNN(torch.concat((dx_para, dx_perp, 
                                                p_para_src, p_perp_src,
                                                edges.dst['type'], edges.src['type']), -1))

        m_v = torch.concat((costheta * rot_m_v[..., :1] - sintheta * rot_m_v[..., 1:],
                            costheta * rot_m_v[..., 1:] + sintheta * rot_m_v[..., :1]), -1)

        m_theta = self.thetaDotNN(torch.concat((dx_para, dx_perp, 
                                                p_para_src, p_perp_src, 
                                                edges.dst['type'], edges.src['type']), -1))
        
        return {'m': torch.concat((m_v, m_theta), -1)}
        
    def forward(self, xv):
        r_g = makeGraph(xv[..., :self.xy_dim], self.r/2)
        r_g.ndata['x'] = xv[..., :self.xy_dim]
        r_g.ndata['theta'] = xv[..., self.xy_dim:(self.xy_dim+1)]
        r_g.ndata['type'] = self.celltype
        r_g.update_all(self.calc_message, fn.sum('m', 'a'))
        r_g.ndata['a'][..., :self.xy_dim] = r_g.ndata['a'][..., :self.xy_dim] + self.selfpropel * torch.concat((torch.cos(r_g.ndata['theta']), torch.sin(r_g.ndata['theta'])), -1)
        
        return r_g.ndata['a']



In [44]:
def myLoss(out, target):
    dv = torch.sum(torch.square(out[..., :xy_dim] - target[..., :xy_dim]), dim=-1)
    dcos = 1 - torch.cos(out[..., xy_dim] - target[..., xy_dim])
    return torch.mean(dv + dcos)

In [54]:
# モデルのインスタンス生成
xy_dim = 2

model = ActiveNet(xy_dim, dr_thresh, Nchannels=128).to(device)
# input data
#data = dataset[0]

model.train()

# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0.9)
#optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

val_loss_log = []

val_loss_min = np.Inf

# learnig loop
for epoch in range(3000):
    for batch_x, batch_y, batch_i_dir in train_data:
        optimizer.zero_grad()
        loss = 0
        for ib in range(batch_x.size(0)):
            model.load_celltypes(celltype_lst[int(batch_i_dir[ib])].to(device))
            out = model(batch_x[ib].to(device))
            loss = loss + myLoss(out, batch_y[ib].to(device))
        loss = loss / batch_x.size(0)
        loss.backward()
        optimizer.step()
    val_loss = 0
    val_count = 0
    with torch.no_grad():
        for batch_x, batch_y, batch_i_dir in valid_data:
            for ib in range(batch_x.size(0)):
                model.load_celltypes(celltype_lst[int(batch_i_dir[ib])].to(device))
                val_out = model(batch_x[ib].to(device))
                val_loss = val_loss + myLoss(val_out, batch_y[ib].to(device))
            val_count = val_count + batch_x.size(0)
    val_loss = val_loss/val_count
    print('Epoch %d | train Loss: %.4f | valid Loss: %.4f' % (epoch, loss.item(), val_loss.item()))
    val_loss_log.append(val_loss.cpu().item())
    if val_loss.item() < val_loss_min:
        stored_model = model
        val_loss_min = val_loss.item()

Epoch 0 | train Loss: 0.1808 | valid Loss: 0.2232
Epoch 1 | train Loss: 0.1700 | valid Loss: 0.2032
Epoch 2 | train Loss: 0.1687 | valid Loss: 0.2071
Epoch 3 | train Loss: 0.1668 | valid Loss: 0.1995
Epoch 4 | train Loss: 0.1639 | valid Loss: 0.1966
Epoch 5 | train Loss: 0.1619 | valid Loss: 0.1925
Epoch 6 | train Loss: 0.1608 | valid Loss: 0.1895
Epoch 7 | train Loss: 0.1597 | valid Loss: 0.1887
Epoch 8 | train Loss: 0.1599 | valid Loss: 0.1860
Epoch 9 | train Loss: 0.1607 | valid Loss: 0.1822
Epoch 10 | train Loss: 0.1599 | valid Loss: 0.1778
Epoch 11 | train Loss: 0.1594 | valid Loss: 0.1773
Epoch 12 | train Loss: 0.1592 | valid Loss: 0.1786
Epoch 13 | train Loss: 0.1574 | valid Loss: 0.1727
Epoch 14 | train Loss: 0.1573 | valid Loss: 0.1767
Epoch 15 | train Loss: 0.1567 | valid Loss: 0.1746
Epoch 16 | train Loss: 0.1550 | valid Loss: 0.1735
Epoch 17 | train Loss: 0.1541 | valid Loss: 0.1736
Epoch 18 | train Loss: 0.1529 | valid Loss: 0.1713
Epoch 19 | train Loss: 0.1521 | valid Los

Epoch 160 | train Loss: 0.1115 | valid Loss: 0.1139
Epoch 161 | train Loss: 0.1115 | valid Loss: 0.1137
Epoch 162 | train Loss: 0.1113 | valid Loss: 0.1137
Epoch 163 | train Loss: 0.1114 | valid Loss: 0.1139
Epoch 164 | train Loss: 0.1113 | valid Loss: 0.1137
Epoch 165 | train Loss: 0.1112 | valid Loss: 0.1135
Epoch 166 | train Loss: 0.1110 | valid Loss: 0.1133
Epoch 167 | train Loss: 0.1114 | valid Loss: 0.1134
Epoch 168 | train Loss: 0.1110 | valid Loss: 0.1124
Epoch 169 | train Loss: 0.1110 | valid Loss: 0.1129
Epoch 170 | train Loss: 0.1116 | valid Loss: 0.1132
Epoch 171 | train Loss: 0.1113 | valid Loss: 0.1131
Epoch 172 | train Loss: 0.1115 | valid Loss: 0.1132
Epoch 173 | train Loss: 0.1110 | valid Loss: 0.1130
Epoch 174 | train Loss: 0.1116 | valid Loss: 0.1131
Epoch 175 | train Loss: 0.1116 | valid Loss: 0.1130
Epoch 176 | train Loss: 0.1112 | valid Loss: 0.1128
Epoch 177 | train Loss: 0.1111 | valid Loss: 0.1126
Epoch 178 | train Loss: 0.1118 | valid Loss: 0.1125
Epoch 179 | 

Epoch 318 | train Loss: 0.1019 | valid Loss: 0.1080
Epoch 319 | train Loss: 0.1020 | valid Loss: 0.1079
Epoch 320 | train Loss: 0.1019 | valid Loss: 0.1079
Epoch 321 | train Loss: 0.1019 | valid Loss: 0.1079
Epoch 322 | train Loss: 0.1019 | valid Loss: 0.1080
Epoch 323 | train Loss: 0.1018 | valid Loss: 0.1079
Epoch 324 | train Loss: 0.1018 | valid Loss: 0.1079
Epoch 325 | train Loss: 0.1018 | valid Loss: 0.1079
Epoch 326 | train Loss: 0.1018 | valid Loss: 0.1080
Epoch 327 | train Loss: 0.1018 | valid Loss: 0.1079
Epoch 328 | train Loss: 0.1018 | valid Loss: 0.1079
Epoch 329 | train Loss: 0.1018 | valid Loss: 0.1078
Epoch 330 | train Loss: 0.1018 | valid Loss: 0.1079
Epoch 331 | train Loss: 0.1019 | valid Loss: 0.1079
Epoch 332 | train Loss: 0.1019 | valid Loss: 0.1078
Epoch 333 | train Loss: 0.1018 | valid Loss: 0.1078
Epoch 334 | train Loss: 0.1018 | valid Loss: 0.1078
Epoch 335 | train Loss: 0.1018 | valid Loss: 0.1078
Epoch 336 | train Loss: 0.1016 | valid Loss: 0.1078
Epoch 337 | 

Epoch 476 | train Loss: 0.1007 | valid Loss: 0.1067
Epoch 477 | train Loss: 0.1007 | valid Loss: 0.1067
Epoch 478 | train Loss: 0.1007 | valid Loss: 0.1067
Epoch 479 | train Loss: 0.1008 | valid Loss: 0.1069
Epoch 480 | train Loss: 0.1007 | valid Loss: 0.1067
Epoch 481 | train Loss: 0.1006 | valid Loss: 0.1067
Epoch 482 | train Loss: 0.1006 | valid Loss: 0.1066
Epoch 483 | train Loss: 0.1007 | valid Loss: 0.1066
Epoch 484 | train Loss: 0.1007 | valid Loss: 0.1067
Epoch 485 | train Loss: 0.1007 | valid Loss: 0.1068
Epoch 486 | train Loss: 0.1007 | valid Loss: 0.1067
Epoch 487 | train Loss: 0.1009 | valid Loss: 0.1069
Epoch 488 | train Loss: 0.1007 | valid Loss: 0.1066
Epoch 489 | train Loss: 0.1006 | valid Loss: 0.1065
Epoch 490 | train Loss: 0.1006 | valid Loss: 0.1065
Epoch 491 | train Loss: 0.1006 | valid Loss: 0.1067
Epoch 492 | train Loss: 0.1006 | valid Loss: 0.1067
Epoch 493 | train Loss: 0.1007 | valid Loss: 0.1066
Epoch 494 | train Loss: 0.1005 | valid Loss: 0.1066
Epoch 495 | 

Epoch 634 | train Loss: 0.0997 | valid Loss: 0.1060
Epoch 635 | train Loss: 0.0997 | valid Loss: 0.1059
Epoch 636 | train Loss: 0.0996 | valid Loss: 0.1059
Epoch 637 | train Loss: 0.0996 | valid Loss: 0.1060
Epoch 638 | train Loss: 0.0996 | valid Loss: 0.1059
Epoch 639 | train Loss: 0.0997 | valid Loss: 0.1059
Epoch 640 | train Loss: 0.0997 | valid Loss: 0.1059
Epoch 641 | train Loss: 0.0997 | valid Loss: 0.1059
Epoch 642 | train Loss: 0.0996 | valid Loss: 0.1059
Epoch 643 | train Loss: 0.0996 | valid Loss: 0.1059
Epoch 644 | train Loss: 0.0996 | valid Loss: 0.1058
Epoch 645 | train Loss: 0.0996 | valid Loss: 0.1058
Epoch 646 | train Loss: 0.0999 | valid Loss: 0.1060
Epoch 647 | train Loss: 0.0995 | valid Loss: 0.1058
Epoch 648 | train Loss: 0.0999 | valid Loss: 0.1060
Epoch 649 | train Loss: 0.0995 | valid Loss: 0.1058
Epoch 650 | train Loss: 0.0996 | valid Loss: 0.1058
Epoch 651 | train Loss: 0.0994 | valid Loss: 0.1058
Epoch 652 | train Loss: 0.0995 | valid Loss: 0.1059
Epoch 653 | 

Epoch 792 | train Loss: 0.0986 | valid Loss: 0.1054
Epoch 793 | train Loss: 0.0986 | valid Loss: 0.1055
Epoch 794 | train Loss: 0.0986 | valid Loss: 0.1054
Epoch 795 | train Loss: 0.0985 | valid Loss: 0.1055
Epoch 796 | train Loss: 0.0985 | valid Loss: 0.1053
Epoch 797 | train Loss: 0.0985 | valid Loss: 0.1055
Epoch 798 | train Loss: 0.0986 | valid Loss: 0.1054
Epoch 799 | train Loss: 0.0988 | valid Loss: 0.1053
Epoch 800 | train Loss: 0.0985 | valid Loss: 0.1054
Epoch 801 | train Loss: 0.0985 | valid Loss: 0.1054
Epoch 802 | train Loss: 0.0985 | valid Loss: 0.1054
Epoch 803 | train Loss: 0.0985 | valid Loss: 0.1055
Epoch 804 | train Loss: 0.0985 | valid Loss: 0.1054
Epoch 805 | train Loss: 0.0986 | valid Loss: 0.1055
Epoch 806 | train Loss: 0.0985 | valid Loss: 0.1055
Epoch 807 | train Loss: 0.0985 | valid Loss: 0.1053
Epoch 808 | train Loss: 0.0985 | valid Loss: 0.1054
Epoch 809 | train Loss: 0.0985 | valid Loss: 0.1055
Epoch 810 | train Loss: 0.0984 | valid Loss: 0.1054
Epoch 811 | 

Epoch 950 | train Loss: 0.0980 | valid Loss: 0.1053
Epoch 951 | train Loss: 0.0980 | valid Loss: 0.1053
Epoch 952 | train Loss: 0.0981 | valid Loss: 0.1053
Epoch 953 | train Loss: 0.0983 | valid Loss: 0.1053
Epoch 954 | train Loss: 0.0981 | valid Loss: 0.1053
Epoch 955 | train Loss: 0.0981 | valid Loss: 0.1052
Epoch 956 | train Loss: 0.0987 | valid Loss: 0.1052
Epoch 957 | train Loss: 0.0981 | valid Loss: 0.1053
Epoch 958 | train Loss: 0.0981 | valid Loss: 0.1052
Epoch 959 | train Loss: 0.0985 | valid Loss: 0.1056
Epoch 960 | train Loss: 0.0981 | valid Loss: 0.1053
Epoch 961 | train Loss: 0.0981 | valid Loss: 0.1053
Epoch 962 | train Loss: 0.0982 | valid Loss: 0.1052
Epoch 963 | train Loss: 0.0981 | valid Loss: 0.1053
Epoch 964 | train Loss: 0.0981 | valid Loss: 0.1052
Epoch 965 | train Loss: 0.0985 | valid Loss: 0.1052
Epoch 966 | train Loss: 0.0982 | valid Loss: 0.1052
Epoch 967 | train Loss: 0.0982 | valid Loss: 0.1052
Epoch 968 | train Loss: 0.0981 | valid Loss: 0.1053
Epoch 969 | 

Epoch 1106 | train Loss: 0.0979 | valid Loss: 0.1052
Epoch 1107 | train Loss: 0.0978 | valid Loss: 0.1051
Epoch 1108 | train Loss: 0.0977 | valid Loss: 0.1051
Epoch 1109 | train Loss: 0.0977 | valid Loss: 0.1051
Epoch 1110 | train Loss: 0.0976 | valid Loss: 0.1051
Epoch 1111 | train Loss: 0.0979 | valid Loss: 0.1052
Epoch 1112 | train Loss: 0.0977 | valid Loss: 0.1051
Epoch 1113 | train Loss: 0.0978 | valid Loss: 0.1051
Epoch 1114 | train Loss: 0.0982 | valid Loss: 0.1051
Epoch 1115 | train Loss: 0.0978 | valid Loss: 0.1050
Epoch 1116 | train Loss: 0.0977 | valid Loss: 0.1051
Epoch 1117 | train Loss: 0.0977 | valid Loss: 0.1051
Epoch 1118 | train Loss: 0.0979 | valid Loss: 0.1050
Epoch 1119 | train Loss: 0.0984 | valid Loss: 0.1051
Epoch 1120 | train Loss: 0.0982 | valid Loss: 0.1051
Epoch 1121 | train Loss: 0.0979 | valid Loss: 0.1051
Epoch 1122 | train Loss: 0.0982 | valid Loss: 0.1051
Epoch 1123 | train Loss: 0.0978 | valid Loss: 0.1050
Epoch 1124 | train Loss: 0.0982 | valid Loss: 

Epoch 1261 | train Loss: 0.0976 | valid Loss: 0.1050
Epoch 1262 | train Loss: 0.0976 | valid Loss: 0.1049
Epoch 1263 | train Loss: 0.0974 | valid Loss: 0.1050
Epoch 1264 | train Loss: 0.0981 | valid Loss: 0.1049
Epoch 1265 | train Loss: 0.0979 | valid Loss: 0.1048
Epoch 1266 | train Loss: 0.0979 | valid Loss: 0.1050
Epoch 1267 | train Loss: 0.0980 | valid Loss: 0.1049
Epoch 1268 | train Loss: 0.0976 | valid Loss: 0.1049
Epoch 1269 | train Loss: 0.0975 | valid Loss: 0.1049
Epoch 1270 | train Loss: 0.0975 | valid Loss: 0.1049
Epoch 1271 | train Loss: 0.0975 | valid Loss: 0.1050
Epoch 1272 | train Loss: 0.0974 | valid Loss: 0.1049
Epoch 1273 | train Loss: 0.0975 | valid Loss: 0.1048
Epoch 1274 | train Loss: 0.0975 | valid Loss: 0.1050
Epoch 1275 | train Loss: 0.0975 | valid Loss: 0.1049
Epoch 1276 | train Loss: 0.0983 | valid Loss: 0.1049
Epoch 1277 | train Loss: 0.0975 | valid Loss: 0.1049
Epoch 1278 | train Loss: 0.0982 | valid Loss: 0.1049
Epoch 1279 | train Loss: 0.0974 | valid Loss: 

Epoch 1416 | train Loss: 0.0971 | valid Loss: 0.1047
Epoch 1417 | train Loss: 0.0976 | valid Loss: 0.1047
Epoch 1418 | train Loss: 0.0971 | valid Loss: 0.1047
Epoch 1419 | train Loss: 0.0970 | valid Loss: 0.1048
Epoch 1420 | train Loss: 0.0971 | valid Loss: 0.1047
Epoch 1421 | train Loss: 0.0971 | valid Loss: 0.1047
Epoch 1422 | train Loss: 0.0970 | valid Loss: 0.1047
Epoch 1423 | train Loss: 0.0977 | valid Loss: 0.1047
Epoch 1424 | train Loss: 0.0972 | valid Loss: 0.1047
Epoch 1425 | train Loss: 0.0972 | valid Loss: 0.1047
Epoch 1426 | train Loss: 0.0972 | valid Loss: 0.1048
Epoch 1427 | train Loss: 0.0977 | valid Loss: 0.1047
Epoch 1428 | train Loss: 0.0971 | valid Loss: 0.1047
Epoch 1429 | train Loss: 0.0971 | valid Loss: 0.1047
Epoch 1430 | train Loss: 0.0971 | valid Loss: 0.1047
Epoch 1431 | train Loss: 0.0972 | valid Loss: 0.1047
Epoch 1432 | train Loss: 0.0972 | valid Loss: 0.1047
Epoch 1433 | train Loss: 0.0978 | valid Loss: 0.1047
Epoch 1434 | train Loss: 0.0971 | valid Loss: 

Epoch 1571 | train Loss: 0.0974 | valid Loss: 0.1046
Epoch 1572 | train Loss: 0.0973 | valid Loss: 0.1046
Epoch 1573 | train Loss: 0.0967 | valid Loss: 0.1046
Epoch 1574 | train Loss: 0.0974 | valid Loss: 0.1046
Epoch 1575 | train Loss: 0.0968 | valid Loss: 0.1046
Epoch 1576 | train Loss: 0.0969 | valid Loss: 0.1045
Epoch 1577 | train Loss: 0.0973 | valid Loss: 0.1044
Epoch 1578 | train Loss: 0.0973 | valid Loss: 0.1047
Epoch 1579 | train Loss: 0.0968 | valid Loss: 0.1046
Epoch 1580 | train Loss: 0.0968 | valid Loss: 0.1046
Epoch 1581 | train Loss: 0.0968 | valid Loss: 0.1046
Epoch 1582 | train Loss: 0.0970 | valid Loss: 0.1045
Epoch 1583 | train Loss: 0.0974 | valid Loss: 0.1047
Epoch 1584 | train Loss: 0.0974 | valid Loss: 0.1045
Epoch 1585 | train Loss: 0.0973 | valid Loss: 0.1045
Epoch 1586 | train Loss: 0.0967 | valid Loss: 0.1046
Epoch 1587 | train Loss: 0.0968 | valid Loss: 0.1046
Epoch 1588 | train Loss: 0.0968 | valid Loss: 0.1046
Epoch 1589 | train Loss: 0.0974 | valid Loss: 

Epoch 1726 | train Loss: 0.0966 | valid Loss: 0.1045
Epoch 1727 | train Loss: 0.0967 | valid Loss: 0.1045
Epoch 1728 | train Loss: 0.0967 | valid Loss: 0.1045
Epoch 1729 | train Loss: 0.0966 | valid Loss: 0.1044
Epoch 1730 | train Loss: 0.0967 | valid Loss: 0.1045
Epoch 1731 | train Loss: 0.0967 | valid Loss: 0.1045
Epoch 1732 | train Loss: 0.0967 | valid Loss: 0.1045
Epoch 1733 | train Loss: 0.0966 | valid Loss: 0.1044
Epoch 1734 | train Loss: 0.0966 | valid Loss: 0.1045
Epoch 1735 | train Loss: 0.0967 | valid Loss: 0.1044
Epoch 1736 | train Loss: 0.0967 | valid Loss: 0.1045
Epoch 1737 | train Loss: 0.0967 | valid Loss: 0.1044
Epoch 1738 | train Loss: 0.0967 | valid Loss: 0.1045
Epoch 1739 | train Loss: 0.0966 | valid Loss: 0.1044
Epoch 1740 | train Loss: 0.0967 | valid Loss: 0.1044
Epoch 1741 | train Loss: 0.0967 | valid Loss: 0.1044
Epoch 1742 | train Loss: 0.0967 | valid Loss: 0.1044
Epoch 1743 | train Loss: 0.0966 | valid Loss: 0.1044
Epoch 1744 | train Loss: 0.0966 | valid Loss: 

Epoch 1881 | train Loss: 0.0964 | valid Loss: 0.1045
Epoch 1882 | train Loss: 0.0966 | valid Loss: 0.1045
Epoch 1883 | train Loss: 0.0965 | valid Loss: 0.1045
Epoch 1884 | train Loss: 0.0964 | valid Loss: 0.1045
Epoch 1885 | train Loss: 0.0964 | valid Loss: 0.1045
Epoch 1886 | train Loss: 0.0963 | valid Loss: 0.1047
Epoch 1887 | train Loss: 0.0964 | valid Loss: 0.1045
Epoch 1888 | train Loss: 0.0964 | valid Loss: 0.1045
Epoch 1889 | train Loss: 0.0965 | valid Loss: 0.1046
Epoch 1890 | train Loss: 0.0970 | valid Loss: 0.1046
Epoch 1891 | train Loss: 0.0963 | valid Loss: 0.1045
Epoch 1892 | train Loss: 0.0965 | valid Loss: 0.1046
Epoch 1893 | train Loss: 0.0964 | valid Loss: 0.1045
Epoch 1894 | train Loss: 0.0964 | valid Loss: 0.1045
Epoch 1895 | train Loss: 0.0964 | valid Loss: 0.1045
Epoch 1896 | train Loss: 0.0963 | valid Loss: 0.1046
Epoch 1897 | train Loss: 0.0964 | valid Loss: 0.1045
Epoch 1898 | train Loss: 0.0963 | valid Loss: 0.1045
Epoch 1899 | train Loss: 0.0971 | valid Loss: 

Epoch 2036 | train Loss: 0.0961 | valid Loss: 0.1045
Epoch 2037 | train Loss: 0.0962 | valid Loss: 0.1045
Epoch 2038 | train Loss: 0.0962 | valid Loss: 0.1044
Epoch 2039 | train Loss: 0.0962 | valid Loss: 0.1044
Epoch 2040 | train Loss: 0.0963 | valid Loss: 0.1044
Epoch 2041 | train Loss: 0.0963 | valid Loss: 0.1045
Epoch 2042 | train Loss: 0.0963 | valid Loss: 0.1045
Epoch 2043 | train Loss: 0.0962 | valid Loss: 0.1044
Epoch 2044 | train Loss: 0.0963 | valid Loss: 0.1044
Epoch 2045 | train Loss: 0.0962 | valid Loss: 0.1045
Epoch 2046 | train Loss: 0.0963 | valid Loss: 0.1045
Epoch 2047 | train Loss: 0.0962 | valid Loss: 0.1044
Epoch 2048 | train Loss: 0.0962 | valid Loss: 0.1044
Epoch 2049 | train Loss: 0.0962 | valid Loss: 0.1044
Epoch 2050 | train Loss: 0.0961 | valid Loss: 0.1045
Epoch 2051 | train Loss: 0.0961 | valid Loss: 0.1044
Epoch 2052 | train Loss: 0.0961 | valid Loss: 0.1044
Epoch 2053 | train Loss: 0.0962 | valid Loss: 0.1045
Epoch 2054 | train Loss: 0.0961 | valid Loss: 

Epoch 2191 | train Loss: 0.0962 | valid Loss: 0.1042
Epoch 2192 | train Loss: 0.0963 | valid Loss: 0.1042
Epoch 2193 | train Loss: 0.0961 | valid Loss: 0.1043
Epoch 2194 | train Loss: 0.0962 | valid Loss: 0.1042
Epoch 2195 | train Loss: 0.0962 | valid Loss: 0.1042
Epoch 2196 | train Loss: 0.0967 | valid Loss: 0.1042
Epoch 2197 | train Loss: 0.0964 | valid Loss: 0.1042
Epoch 2198 | train Loss: 0.0962 | valid Loss: 0.1043
Epoch 2199 | train Loss: 0.0961 | valid Loss: 0.1042
Epoch 2200 | train Loss: 0.0964 | valid Loss: 0.1042
Epoch 2201 | train Loss: 0.0956 | valid Loss: 0.1050
Epoch 2202 | train Loss: 0.0964 | valid Loss: 0.1042
Epoch 2203 | train Loss: 0.0962 | valid Loss: 0.1043
Epoch 2204 | train Loss: 0.0963 | valid Loss: 0.1042
Epoch 2205 | train Loss: 0.0962 | valid Loss: 0.1043
Epoch 2206 | train Loss: 0.0962 | valid Loss: 0.1042
Epoch 2207 | train Loss: 0.0963 | valid Loss: 0.1042
Epoch 2208 | train Loss: 0.0962 | valid Loss: 0.1042
Epoch 2209 | train Loss: 0.0965 | valid Loss: 

Epoch 2346 | train Loss: 0.0963 | valid Loss: 0.1043
Epoch 2347 | train Loss: 0.0962 | valid Loss: 0.1043
Epoch 2348 | train Loss: 0.0961 | valid Loss: 0.1043
Epoch 2349 | train Loss: 0.0963 | valid Loss: 0.1043
Epoch 2350 | train Loss: 0.0961 | valid Loss: 0.1042
Epoch 2351 | train Loss: 0.0961 | valid Loss: 0.1043
Epoch 2352 | train Loss: 0.0962 | valid Loss: 0.1042
Epoch 2353 | train Loss: 0.0967 | valid Loss: 0.1042
Epoch 2354 | train Loss: 0.0963 | valid Loss: 0.1042
Epoch 2355 | train Loss: 0.0962 | valid Loss: 0.1042
Epoch 2356 | train Loss: 0.0962 | valid Loss: 0.1042
Epoch 2357 | train Loss: 0.0962 | valid Loss: 0.1042
Epoch 2358 | train Loss: 0.0968 | valid Loss: 0.1043
Epoch 2359 | train Loss: 0.0962 | valid Loss: 0.1042
Epoch 2360 | train Loss: 0.0963 | valid Loss: 0.1042
Epoch 2361 | train Loss: 0.0963 | valid Loss: 0.1042
Epoch 2362 | train Loss: 0.0961 | valid Loss: 0.1042
Epoch 2363 | train Loss: 0.0962 | valid Loss: 0.1042
Epoch 2364 | train Loss: 0.0960 | valid Loss: 

Epoch 2501 | train Loss: 0.0959 | valid Loss: 0.1042
Epoch 2502 | train Loss: 0.0960 | valid Loss: 0.1043
Epoch 2503 | train Loss: 0.0960 | valid Loss: 0.1043
Epoch 2504 | train Loss: 0.0960 | valid Loss: 0.1043
Epoch 2505 | train Loss: 0.0958 | valid Loss: 0.1043
Epoch 2506 | train Loss: 0.0960 | valid Loss: 0.1043
Epoch 2507 | train Loss: 0.0959 | valid Loss: 0.1043
Epoch 2508 | train Loss: 0.0959 | valid Loss: 0.1043
Epoch 2509 | train Loss: 0.0960 | valid Loss: 0.1043
Epoch 2510 | train Loss: 0.0959 | valid Loss: 0.1043
Epoch 2511 | train Loss: 0.0959 | valid Loss: 0.1043
Epoch 2512 | train Loss: 0.0959 | valid Loss: 0.1043
Epoch 2513 | train Loss: 0.0962 | valid Loss: 0.1043
Epoch 2514 | train Loss: 0.0960 | valid Loss: 0.1043
Epoch 2515 | train Loss: 0.0960 | valid Loss: 0.1043
Epoch 2516 | train Loss: 0.0960 | valid Loss: 0.1043
Epoch 2517 | train Loss: 0.0959 | valid Loss: 0.1043
Epoch 2518 | train Loss: 0.0966 | valid Loss: 0.1044
Epoch 2519 | train Loss: 0.0960 | valid Loss: 

Epoch 2656 | train Loss: 0.0957 | valid Loss: 0.1044
Epoch 2657 | train Loss: 0.0958 | valid Loss: 0.1044
Epoch 2658 | train Loss: 0.0958 | valid Loss: 0.1044
Epoch 2659 | train Loss: 0.0958 | valid Loss: 0.1044
Epoch 2660 | train Loss: 0.0958 | valid Loss: 0.1044
Epoch 2661 | train Loss: 0.0957 | valid Loss: 0.1044
Epoch 2662 | train Loss: 0.0958 | valid Loss: 0.1044
Epoch 2663 | train Loss: 0.0958 | valid Loss: 0.1044
Epoch 2664 | train Loss: 0.0960 | valid Loss: 0.1052
Epoch 2665 | train Loss: 0.0958 | valid Loss: 0.1051
Epoch 2666 | train Loss: 0.0958 | valid Loss: 0.1045
Epoch 2667 | train Loss: 0.0957 | valid Loss: 0.1044
Epoch 2668 | train Loss: 0.0957 | valid Loss: 0.1046
Epoch 2669 | train Loss: 0.0957 | valid Loss: 0.1044
Epoch 2670 | train Loss: 0.0960 | valid Loss: 0.1045
Epoch 2671 | train Loss: 0.0959 | valid Loss: 0.1043
Epoch 2672 | train Loss: 0.0956 | valid Loss: 0.1044
Epoch 2673 | train Loss: 0.0956 | valid Loss: 0.1045
Epoch 2674 | train Loss: 0.0957 | valid Loss: 

Epoch 2811 | train Loss: 0.0954 | valid Loss: 0.1044
Epoch 2812 | train Loss: 0.0955 | valid Loss: 0.1045
Epoch 2813 | train Loss: 0.0955 | valid Loss: 0.1044
Epoch 2814 | train Loss: 0.0956 | valid Loss: 0.1044
Epoch 2815 | train Loss: 0.0955 | valid Loss: 0.1045
Epoch 2816 | train Loss: 0.0955 | valid Loss: 0.1048
Epoch 2817 | train Loss: 0.0953 | valid Loss: 0.1045
Epoch 2818 | train Loss: 0.0955 | valid Loss: 0.1045
Epoch 2819 | train Loss: 0.0955 | valid Loss: 0.1045
Epoch 2820 | train Loss: 0.0954 | valid Loss: 0.1044
Epoch 2821 | train Loss: 0.0956 | valid Loss: 0.1044
Epoch 2822 | train Loss: 0.0954 | valid Loss: 0.1044
Epoch 2823 | train Loss: 0.0952 | valid Loss: 0.1048
Epoch 2824 | train Loss: 0.0954 | valid Loss: 0.1044
Epoch 2825 | train Loss: 0.0954 | valid Loss: 0.1044
Epoch 2826 | train Loss: 0.0955 | valid Loss: 0.1044
Epoch 2827 | train Loss: 0.0958 | valid Loss: 0.1045
Epoch 2828 | train Loss: 0.0959 | valid Loss: 0.1044
Epoch 2829 | train Loss: 0.0959 | valid Loss: 

Epoch 2966 | train Loss: 0.0953 | valid Loss: 0.1046
Epoch 2967 | train Loss: 0.0950 | valid Loss: 0.1046
Epoch 2968 | train Loss: 0.0953 | valid Loss: 0.1045
Epoch 2969 | train Loss: 0.0953 | valid Loss: 0.1045
Epoch 2970 | train Loss: 0.0957 | valid Loss: 0.1045
Epoch 2971 | train Loss: 0.0953 | valid Loss: 0.1045
Epoch 2972 | train Loss: 0.0951 | valid Loss: 0.1047
Epoch 2973 | train Loss: 0.0953 | valid Loss: 0.1045
Epoch 2974 | train Loss: 0.0952 | valid Loss: 0.1045
Epoch 2975 | train Loss: 0.0952 | valid Loss: 0.1046
Epoch 2976 | train Loss: 0.0950 | valid Loss: 0.1046
Epoch 2977 | train Loss: 0.0950 | valid Loss: 0.1046
Epoch 2978 | train Loss: 0.0953 | valid Loss: 0.1049
Epoch 2979 | train Loss: 0.0957 | valid Loss: 0.1051
Epoch 2980 | train Loss: 0.0953 | valid Loss: 0.1045
Epoch 2981 | train Loss: 0.0952 | valid Loss: 0.1045
Epoch 2982 | train Loss: 0.0951 | valid Loss: 0.1045
Epoch 2983 | train Loss: 0.0955 | valid Loss: 0.1045
Epoch 2984 | train Loss: 0.0952 | valid Loss: 

In [55]:
model.state_dict()

OrderedDict([('selfpropel', tensor(0.9331, device='cuda:0')),
             ('interactNN.layer1.weight',
              tensor([[ 9.7235e-02,  3.2959e-02, -8.2636e-01, -6.3180e-02,  2.3767e-01,
                       -7.2867e-01],
                      [-9.6117e-01,  2.8790e-01,  2.0872e-01, -2.7871e-02, -4.5141e-01,
                       -3.6082e-01],
                      [-3.1515e-01,  6.4058e-01,  2.8309e-02,  1.3396e-02, -3.1494e-01,
                        2.0588e-01],
                      [ 3.1310e-01, -6.5077e-02,  1.8764e-01,  4.9433e-01,  1.2934e-01,
                       -4.8865e-01],
                      [ 1.2724e+00,  1.6389e+00, -4.8133e-01, -5.5493e-02, -1.1790e+00,
                       -8.0350e-01],
                      [ 2.3021e-01, -3.5742e-01, -1.6928e-01,  3.5300e-01,  1.1564e-01,
                       -1.3222e-03],
                      [ 3.7582e-01,  4.3525e-01,  3.9143e-02, -1.2452e+00, -9.0765e-01,
                       -9.8433e-01],
                     

In [56]:
# モデルを評価モードに設定
stored_model.eval()

# 推論
test_loss = 0
test_count = 0
with torch.no_grad():
    for batch_x, batch_y, batch_i_dir in test_data:
        for ib in range(batch_x.size(0)):
            model.load_celltypes(celltype_lst[int(batch_i_dir[ib])].to(device))
            test_out = model(batch_x[ib].to(device))
            test_loss = test_loss + myLoss(test_out, batch_y[ib].to(device))
        test_count = test_count + batch_x.size(0)
test_loss = test_loss/test_count
print('test Loss: %.4f' % (test_loss.item()))

test Loss: 0.1021


In [57]:
import pickle
import datetime

now = datetime.datetime.now()
nowstr = now.strftime('%Y%m%d_%H%M%S')

os.makedirs(savedirName + nowstr + '/', exist_ok=True)

SyntaxError: invalid syntax (2675390831.py, line 7)

In [None]:
stored_model.selfpropel.detach()

In [None]:
stored_model = stored_model.to('cpu')

filename1 = savedirName + nowstr + '/' + nowstr + '_Model.pkl'
with open(filename1, "wb") as f:
    pickle.dump(stored_model, f)

filename1_2 = savedirName + nowstr + '/' + nowstr + '_Model.pt'
torch.save(stored_model, filename1_2)

filename2 = savedirName + nowstr + '/' + nowstr
torch.save(stored_model.interactNN.state_dict(), filename2 + '_interactNN.pkl')
torch.save(stored_model.thetaDotNN.state_dict(), filename2 + '_thetaDotNN.pkl')
torch.save(stored_model.selfpropel.detach(), filename2 + '_selfpropel.pkl')

filename3 = savedirName + nowstr + '/' + nowstr + '_Separation.npz'
np.savez(filename3, dr_thresh=dr_thresh, batch_size=batch_size,
         train_inds=train_inds, valid_inds=valid_inds, test_inds=test_inds, 
         val_loss_log=val_loss_log, test_loss=test_loss.cpu())

filename4 = savedirName + nowstr + '/' + nowstr + '_optimizer.pt'
torch.save(optimizer, filename4)