In [1]:
#!pip install dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html

In [2]:
import os
os.environ['DGLBACKEND'] = 'pytorch'

In [3]:
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset

import dgl
import dgl.function as fn

import networkx as nx
from matplotlib import pyplot as plt
import numpy as np

import os
import gc

# デバイス設定
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


In [4]:
def printNPZ(npz):
    for kw in npz.files:
        print(kw, npz[kw])

In [5]:
dirName = './HiraiwaModel_chem20221102_020319/'
savedirName = dirName + 'ActiveNet_vp_rotsym_noSelfLoop_multiStep_fineTuning_batchNorm/'
os.makedirs(savedirName, exist_ok=True)

params = np.load(dirName+'params.npz')
#traj = np.load(dirName+'result.npz')

printNPZ(params)
#printNPZ(traj)

v0 1.0
r 1.0
D 0.1
A 0.0
L 20
rho 1.0
beta 1.0
A_CFs [0.9 0.1]
A_CIL 0.0
cellType_ratio [0.7 0.3]
quiv_colors ['k' 'r']
kappa 1.0
A_Macdonalds [2. 2.]
batch_size 400
state_size 3
brownian_size 1
periodic True
t_max 200
methodSDE heun
isIto False
stepSDE 0.01


In [6]:
if params['periodic']:
    L = torch.tensor(params['L'])
    def calc_dr(r1, r2):
        dr = torch.remainder((r1 - r2), L)
        dr[dr > L/2] = dr[dr > L/2] - L
        return dr
else:
    def calc_dr(r1, r2):
        return r1 - r2
    
def makeGraph(x_data, r_thresh):
        Ndata = x_data.size(0)
        dx = calc_dr(torch.unsqueeze(x_data, 0), torch.unsqueeze(x_data, 1))
        dx = torch.sum(dx**2, dim=2)
        edges = torch.argwhere(torch.logical_and(dx > 0, dx < r_thresh/2))
        return dgl.graph((edges[:,0], edges[:,1]), num_nodes=Ndata)

In [7]:
subdir_list = [f.path for f in os.scandir(dirName) if f.is_dir()]

print(subdir_list)

datadir_list = [f for f in subdir_list if 'result.npz' in [ff.name for ff in os.scandir(f)]]

print(datadir_list)

['./HiraiwaModel_chem20221102_020319/20221102_194956', './HiraiwaModel_chem20221102_020319/20221102_203311', './HiraiwaModel_chem20221102_020319/20221102_211612', './HiraiwaModel_chem20221102_020319/20221102_215953', './HiraiwaModel_chem20221102_020319/20221102_224239', './HiraiwaModel_chem20221102_020319/20221102_232519', './HiraiwaModel_chem20221102_020319/20221103_000843', './HiraiwaModel_chem20221102_020319/20221103_005036', './HiraiwaModel_chem20221102_020319/20221103_013303', './HiraiwaModel_chem20221102_020319/20221103_021518', './HiraiwaModel_chem20221102_020319/20221103_025758', './HiraiwaModel_chem20221102_020319/20221103_033955', './HiraiwaModel_chem20221102_020319/20221103_042225', './HiraiwaModel_chem20221102_020319/20221103_050424', './HiraiwaModel_chem20221102_020319/20221103_054634', './HiraiwaModel_chem20221102_020319/20221103_062842', './HiraiwaModel_chem20221102_020319/20221103_071105', './HiraiwaModel_chem20221102_020319/20221103_075340', './HiraiwaModel_chem2022110

In [8]:
modeldirName = dirName + 'ActiveNet_vp_rotsym_noSelfLoop_batchNorm/'

#datename = '20220921_160055'
i_model = -1

model_files_list = [[os.path.join(c, ff) for ff in f if ff.endswith('_Model.pt')]
               for c, s, f in os.walk(modeldirName)]
print(model_files_list)

model_files = []
for i in range(len(model_files_list)):
    for j in range(len(model_files_list[i])):
        model_files.append(model_files_list[i][j])
print(model_files)

model_dir = os.path.dirname(model_files[i_model])

model_name = model_files[i_model].replace('_Model.pt', '').replace(modeldirName, '')
print(model_name)

[[], ['./HiraiwaModel_chem20221102_020319/ActiveNet_vp_rotsym_noSelfLoop_batchNorm/20221130_024129/20221130_024129_Model.pt']]
['./HiraiwaModel_chem20221102_020319/ActiveNet_vp_rotsym_noSelfLoop_batchNorm/20221130_024129/20221130_024129_Model.pt']
20221130_024129/20221130_024129


In [9]:
class myDataset(Dataset):
    def __init__(self, data_x, celltype_List, t_yseq=1):
        super().__init__()
        
        self.data_x = data_x # List of tensors
        #self.data_y = data_y
        self.celltype_List = celltype_List
        
        self.data_len = np.array([xx.size(0) for xx in self.data_x])
        self.t_yseq = t_yseq
        
        self.data_len_cumsum = np.cumsum(self.data_len - (self.t_yseq - 1))
        
    def __len__(self):
        return (self.data_len - (self.t_yseq - 1)).sum()
    
    def __getitem__(self, index):
        id_List = np.argwhere(index<self.data_len_cumsum)[0,0]
        
        if id_List:
            id_tensor = index - self.data_len_cumsum[id_List-1]
        else:
            id_tensor = index
        
        return self.data_x[id_List][id_tensor], self.data_x[id_List][id_tensor:(id_tensor+self.t_yseq)], self.celltype_List[id_List]

In [10]:
#dr_thresh = 7
dt = 1
batch_size = 8

T_pred = 5

N_data = len(datadir_list)

#TR_VA_rate = np.array([0.6, 0.2])

TR_last = 20
VA_last = 23

shuffle_inds = np.arange(N_data, dtype=int)
np.random.shuffle(shuffle_inds)

train_inds = shuffle_inds[:TR_last]
valid_inds = shuffle_inds[TR_last:VA_last]
test_inds = shuffle_inds[VA_last:]

celltype_lst = []

train_x = []
valid_x = []
test_x = []

train_y = []
valid_y = []
test_y = []

train_ct = []
valid_ct = []
test_ct = []

for i_dir, subdirName in enumerate(datadir_list):
    
    traj = np.load(subdirName+'/result.npz')
    
    xy_t = torch.tensor(traj['xy'])#[:-1,:,:])
    #v_t = calc_dr(torch.tensor(traj['xy'][1:,:,:]), torch.tensor(traj['xy'][:-1,:,:])) / dt
    p_t = torch.unsqueeze(torch.tensor(traj['theta']), dim=2)#[:-1,:]), dim=2)
    #w_t = torch.unsqueeze(torch.tensor((traj['theta'][1:,:]-traj['theta'][:-1,:])%(2*np.pi)/dt), dim=2)
    
    if i_dir in train_inds:
        train_x.append(torch.concat((xy_t, p_t), -1))
        #train_y.append(torch.concat((v_t, w_t), -1))
        train_ct.append(torch.tensor(traj['celltype_label']).view(-1,1))

    if i_dir in valid_inds:
        valid_x.append(torch.concat((xy_t, p_t), -1))
        #valid_y.append(torch.concat((v_t, w_t), -1))
        valid_ct.append(torch.tensor(traj['celltype_label']).view(-1,1))
        
    if i_dir in test_inds:
        test_x.append(torch.concat((xy_t, p_t), -1))
        #test_y.append(torch.concat((v_t, w_t), -1))
        test_ct.append(torch.tensor(traj['celltype_label']).view(-1,1))
    
train_dataset = myDataset(train_x, train_ct, t_yseq=T_pred)

valid_dataset = myDataset(valid_x, valid_ct, t_yseq=T_pred)

test_dataset = myDataset(test_x, test_ct, t_yseq=T_pred)


train_data = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, pin_memory=True)
valid_data = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, pin_memory=True)
test_data = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, pin_memory=True)

del train_x, train_ct, train_dataset
del valid_x, valid_ct, valid_dataset
del test_x, test_ct, test_dataset
gc.collect()

#print(data)
#print(data.num_graphs)
#print(data.x)
#print(data.y)
#print(data.edge_index)

374

In [11]:
print(train_inds)

[ 9  4 16  1 20 11  2 17  3 13  5  7 19 23 24  6  0 21 18  8]


In [12]:
dir(train_data)

['_DataLoader__initialized',
 '_DataLoader__multiprocessing_context',
 '_IterableDataset_len_called',
 '__annotations__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_auto_collation',
 '_dataset_kind',
 '_get_iterator',
 '_get_shared_seed',
 '_index_sampler',
 '_is_protocol',
 '_iterator',
 'batch_sampler',
 'batch_size',
 'check_worker_number_rationality',
 'collate_fn',
 'dataset',
 'drop_last',
 'generator',
 'multiprocessing_context',
 'num_workers',
 'persistent_workers',
 'pin_memory',
 'pin_memory_device',
 'prefetch_factor',
 'sampler',
 'timeout',
 'worker_i

In [13]:
len(train_data)

493

In [14]:
print(train_data.dataset.data_len_cumsum)

[ 197  394  591  788  985 1182 1379 1576 1773 1970 2167 2364 2561 2758
 2955 3152 3349 3546 3743 3940]


In [15]:
def plotGraph(data):

    # networkxのグラフに変換
    nxg = dgl.to_networkx(data)

    # 可視化のためのページランク計算
    pr = nx.pagerank(nxg)
    pr_max = np.array(list(pr.values())).max()

    # 可視化する際のノード位置
    draw_pos = nx.spring_layout(nxg, seed=0) 

    # ノードの色設定
    cmap = plt.get_cmap('tab10')
    labels = data.y.numpy()
    colors = [cmap(l) for l in labels]

    # 図のサイズ
    fig0 = plt.figure(figsize=(10, 10))

    # 描画
    nx.draw_networkx_nodes(nxg, 
                          draw_pos,
                          node_size=[v / pr_max * 1000 for v in pr.values()])#,
                          #node_color=colors, alpha=0.5)
    nx.draw_networkx_edges(nxg, draw_pos, arrowstyle='-', alpha=0.2)
    nx.draw_networkx_labels(nxg, draw_pos, font_size=10)

    #plt.title('KarateClub')
    plt.show()

    return fig0

In [16]:
class NeuralNet(nn.Module):
    def __init__(self, in_channels, out_channels, Nchannels, dropout=0, batchN=False, flgBias=False):
        super(NeuralNet, self).__init__()

        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0
            
        if batchN:
            self.bNorm1 = nn.BatchNorm1d(Nchannels)
            self.bNorm2 = nn.BatchNorm1d(Nchannels)
            self.bNorm3 = nn.BatchNorm1d(Nchannels)
            
        self.batchN=batchN
        
        self.layer1 = nn.Linear(in_channels, Nchannels, bias=flgBias)
        self.layer2 = nn.Linear(Nchannels, Nchannels, bias=flgBias)
        self.layer3 = nn.Linear(Nchannels, Nchannels, bias=flgBias)
        self.layer4 = nn.Linear(Nchannels, out_channels, bias=flgBias)

        self.activation = nn.ReLU()

    def reset_parameters(self):
        self.layer1.reset_parameters()
        self.layer2.reset_parameters()
        self.layer3.reset_parameters()
        self.layer4.reset_parameters()
        #nn.init.zeros_(self.layer1.weight)
        #nn.init.zeros_(self.layer2.weight)
        #nn.init.zeros_(self.layer3.weight)
        #nn.init.zeros_(self.layer4.weight)
        
    def forward(self, x):
        out = self.activation(self.layer1(x))
        if self.batchN:
            out = self.bNorm1(out)
        if self.dropout:
            out = self.dropout(out)
        
        out = self.activation(self.layer2(out))
        if self.batchN:
            out = self.bNorm2(out)
        if self.dropout:
            out = self.dropout(out)
        
        out = self.activation(self.layer3(out))
        if self.batchN:
            out = self.bNorm3(out)
        if self.dropout:
            out = self.dropout(out)
        
        out = self.layer4(out)

        return out

class ActiveNet(nn.Module):
    def __init__(self, xy_dim, r, dropout=0, batchN=False, bias=False, Nchannels=128):
        super().__init__()

        self.interactNN = NeuralNet(xy_dim*2 + 2, xy_dim, Nchannels, dropout, batchN, bias)

        self.thetaDotNN = NeuralNet(xy_dim*2 + 2, 1, Nchannels, dropout, batchN, bias)
        
        self.selfpropel = nn.Parameter(torch.tensor(0.0, requires_grad=True, device=device))

        #self.Normalizer = nn.Softmax(dim=1)

        self.xy_dim = xy_dim
        
        self.r = r

        self.reset_parameters()

    def reset_parameters(self):
        self.interactNN.reset_parameters()

        self.thetaDotNN.reset_parameters()
        
        nn.init.uniform_(self.selfpropel)

        #self.bias.data.zero_()
        
    def load_celltypes(self, celltype):
        self.celltype = celltype

    def calc_message(self, edges):
        dx = calc_dr(edges.dst['x'], edges.src['x'])

        costheta = torch.cos(edges.dst['theta'])
        sintheta = torch.sin(edges.dst['theta'])

        dx_para = costheta * dx[..., :1] + sintheta * dx[..., 1:]
        dx_perp = costheta * dx[..., 1:] - sintheta * dx[..., :1]

        p_para_src = torch.cos(edges.src['theta'] - edges.dst['theta'])
        p_perp_src = torch.sin(edges.src['theta'] - edges.dst['theta'])

        rot_m_v = self.interactNN(torch.concat((dx_para, dx_perp, 
                                                p_para_src, p_perp_src,
                                                edges.dst['type'], edges.src['type']), -1))

        m_v = torch.concat((costheta * rot_m_v[..., :1] - sintheta * rot_m_v[..., 1:],
                            costheta * rot_m_v[..., 1:] + sintheta * rot_m_v[..., :1]), -1)

        m_theta = self.thetaDotNN(torch.concat((dx_para, dx_perp, 
                                                p_para_src, p_perp_src, 
                                                edges.dst['type'], edges.src['type']), -1))
        
        return {'m': torch.concat((m_v, m_theta), -1)}
        
    def forward(self, xv):
        r_g = makeGraph(xv[..., :self.xy_dim], self.r/2)
        r_g.ndata['x'] = xv[..., :self.xy_dim]
        r_g.ndata['theta'] = xv[..., self.xy_dim:(self.xy_dim+1)]
        r_g.ndata['type'] = self.celltype
        r_g.update_all(self.calc_message, fn.sum('m', 'a'))
        r_g.ndata['a'][..., :self.xy_dim] = r_g.ndata['a'][..., :self.xy_dim] + self.selfpropel * torch.concat((torch.cos(r_g.ndata['theta']), torch.sin(r_g.ndata['theta'])), -1)
        
        return r_g.ndata['a']



In [17]:
def myLoss(out, target):
    #dv = torch.sum(torch.square(out[..., :xy_dim] - target[..., :xy_dim]), dim=-1)
    dv = torch.sum(torch.square(calc_dr(out[..., :xy_dim], target[..., :xy_dim])), dim=-1)
    dcos = torch.cos(out[..., xy_dim] - target[..., xy_dim])
    
    wei_shape = np.ones([dv.dim()], dtype=int)
    wei_shape[0] = T_pred
    wei = torch.tensor(np.reshape(1/np.arange(1, T_pred+1), wei_shape)).to(dv.device)
    wei = wei/wei.mean()
    
    return torch.mean(dv*wei), torch.mean((1-dcos)*wei)

In [18]:
filename1 = modeldirName + model_name + '_Model.pt'
model = torch.load(filename1, map_location=torch.device(device)) #pickle.load(open(filename1, 'rb'))
model.train()

filename3 = modeldirName + model_name + '_Separation.npz'
learn_params = np.load(filename3, allow_pickle=True)

printNPZ(learn_params)

dr_thresh = learn_params['dr_thresh'].item()

dr_thresh 4
batch_size 8
train_inds [11  1  4 10 15 13 19 22 21  9  8 20 18  2 16  7 24  6 14 23]
valid_inds [ 0  3 12]
test_inds [ 5 17]
val_loss_log [[0.10067808 0.0744601 ]
 [0.08914171 0.07282238]
 [0.09470647 0.0721049 ]
 [0.08078571 0.07164061]
 [0.0797892  0.07127005]
 [0.08634418 0.07097195]
 [0.07360431 0.07071819]
 [0.07060766 0.07049644]
 [0.06887938 0.07027841]
 [0.06684761 0.07006979]
 [0.06723124 0.06988583]
 [0.06572321 0.06973875]
 [0.06415123 0.06959546]
 [0.06298081 0.06939325]
 [0.06295352 0.06916527]
 [0.06279114 0.06893577]
 [0.06475446 0.06869691]
 [0.06269334 0.06849652]
 [0.06630635 0.06829   ]
 [0.07724824 0.06808818]
 [0.07457949 0.06788816]
 [0.06410416 0.06772624]
 [0.06332196 0.06756904]
 [0.06122893 0.06741692]
 [0.05964878 0.06730241]
 [0.05886233 0.06721356]
 [0.05889479 0.06713305]
 [0.0584549  0.06705505]
 [0.0583207  0.06697793]
 [0.05772057 0.06692173]
 [0.05756785 0.06685448]
 [0.05784079 0.06678727]
 [0.05773148 0.06672858]
 [0.05812796 0.06666937]

In [19]:
model.state_dict()

OrderedDict([('selfpropel', tensor(0.9046, device='cuda:0')),
             ('interactNN.layer1.weight',
              tensor([[ 6.8693e-02,  2.2930e-02,  2.8067e-01, -3.1570e-01, -2.2064e-01,
                       -4.7282e-01],
                      [ 3.6042e-02,  1.7907e-01, -4.5696e-01, -3.4711e-01,  3.2970e-01,
                       -7.4623e-02],
                      [ 3.3213e-01, -1.0251e-01, -1.4388e-01,  3.2349e-02, -1.4225e-02,
                        1.8594e-01],
                      [ 8.1909e-02, -5.9736e-01, -6.5860e-02, -6.6977e-02, -9.3137e-02,
                       -3.7936e-01],
                      [ 1.7986e-01, -2.9950e-01, -1.8117e-01,  2.3075e-01, -9.3827e-02,
                        7.1192e-04],
                      [ 1.7699e-01,  1.2099e-01, -4.0220e-01, -1.5253e-01, -3.9307e-01,
                        7.2420e-02],
                      [ 2.8259e-01,  1.4268e-01, -2.3052e-01, -7.4569e-02,  2.1027e-01,
                        1.1247e-01],
                     

In [20]:
list(model.state_dict().keys())

['selfpropel',
 'interactNN.layer1.weight',
 'interactNN.layer1.bias',
 'interactNN.layer2.weight',
 'interactNN.layer2.bias',
 'interactNN.layer3.weight',
 'interactNN.layer3.bias',
 'interactNN.layer4.weight',
 'interactNN.layer4.bias',
 'thetaDotNN.layer1.weight',
 'thetaDotNN.layer1.bias',
 'thetaDotNN.layer2.weight',
 'thetaDotNN.layer2.bias',
 'thetaDotNN.layer3.weight',
 'thetaDotNN.layer3.bias',
 'thetaDotNN.layer4.weight',
 'thetaDotNN.layer4.bias']

In [21]:
list(model.state_dict().values())

[tensor(0.9046, device='cuda:0'),
 tensor([[ 6.8693e-02,  2.2930e-02,  2.8067e-01, -3.1570e-01, -2.2064e-01,
          -4.7282e-01],
         [ 3.6042e-02,  1.7907e-01, -4.5696e-01, -3.4711e-01,  3.2970e-01,
          -7.4623e-02],
         [ 3.3213e-01, -1.0251e-01, -1.4388e-01,  3.2349e-02, -1.4225e-02,
           1.8594e-01],
         [ 8.1909e-02, -5.9736e-01, -6.5860e-02, -6.6977e-02, -9.3137e-02,
          -3.7936e-01],
         [ 1.7986e-01, -2.9950e-01, -1.8117e-01,  2.3075e-01, -9.3827e-02,
           7.1192e-04],
         [ 1.7699e-01,  1.2099e-01, -4.0220e-01, -1.5253e-01, -3.9307e-01,
           7.2420e-02],
         [ 2.8259e-01,  1.4268e-01, -2.3052e-01, -7.4569e-02,  2.1027e-01,
           1.1247e-01],
         [ 5.5769e-01,  4.0678e-01, -5.3094e-02,  8.4206e-02,  6.0408e-02,
           2.0130e-02],
         [-8.2915e-02,  2.6401e-01,  1.1098e-01,  2.4484e-02,  1.4117e-01,
          -2.9952e-01],
         [ 4.4458e-01, -6.7364e-02, -2.6988e-01, -3.1733e-01, -1.4858e-01,


In [22]:
for i in model.state_dict():
    print(i)
    print(model.state_dict()[i])

selfpropel
tensor(0.9046, device='cuda:0')
interactNN.layer1.weight
tensor([[ 6.8693e-02,  2.2930e-02,  2.8067e-01, -3.1570e-01, -2.2064e-01,
         -4.7282e-01],
        [ 3.6042e-02,  1.7907e-01, -4.5696e-01, -3.4711e-01,  3.2970e-01,
         -7.4623e-02],
        [ 3.3213e-01, -1.0251e-01, -1.4388e-01,  3.2349e-02, -1.4225e-02,
          1.8594e-01],
        [ 8.1909e-02, -5.9736e-01, -6.5860e-02, -6.6977e-02, -9.3137e-02,
         -3.7936e-01],
        [ 1.7986e-01, -2.9950e-01, -1.8117e-01,  2.3075e-01, -9.3827e-02,
          7.1192e-04],
        [ 1.7699e-01,  1.2099e-01, -4.0220e-01, -1.5253e-01, -3.9307e-01,
          7.2420e-02],
        [ 2.8259e-01,  1.4268e-01, -2.3052e-01, -7.4569e-02,  2.1027e-01,
          1.1247e-01],
        [ 5.5769e-01,  4.0678e-01, -5.3094e-02,  8.4206e-02,  6.0408e-02,
          2.0130e-02],
        [-8.2915e-02,  2.6401e-01,  1.1098e-01,  2.4484e-02,  1.4117e-01,
         -2.9952e-01],
        [ 4.4458e-01, -6.7364e-02, -2.6988e-01, -3.1733e-01

In [23]:
keyval = zip(list(model.state_dict().keys()), list(model.state_dict().values()))

keysGroups = [['layer1', 'bNorm1', 'selfpropel'],
              ['layer2', 'bNorm2'],
              ['layer3', 'bNorm3'],
              ['layer4']]

keysRemove = [['running_mean', 'running_var', 'num_batches_tracked']]*4

paramsGroups = []
for kG, kR in zip(keysGroups, keysRemove):
    paramsTmp = []
    for paramName, paramVal in model.named_parameters():
        if any([kN in paramName for kN in kG]) and (not any([kRN in paramName for kRN in kR])):
            #paramVal = model.state_dict()[paramName]
            paramVal.requires_grad = True
            paramsTmp.append(paramVal)        
    paramsGroups.append(paramsTmp)

print(paramsGroups)

#paramsGroups[0][0].register_hook(lambda grad: print('grad', grad))

lrs = np.array([1, 1e1, 1e2, 1e3]) * 1e-9


[[Parameter containing:
tensor(0.9046, device='cuda:0', requires_grad=True), Parameter containing:
tensor([[ 6.8693e-02,  2.2930e-02,  2.8067e-01, -3.1570e-01, -2.2064e-01,
         -4.7282e-01],
        [ 3.6042e-02,  1.7907e-01, -4.5696e-01, -3.4711e-01,  3.2970e-01,
         -7.4623e-02],
        [ 3.3213e-01, -1.0251e-01, -1.4388e-01,  3.2349e-02, -1.4225e-02,
          1.8594e-01],
        [ 8.1909e-02, -5.9736e-01, -6.5860e-02, -6.6977e-02, -9.3137e-02,
         -3.7936e-01],
        [ 1.7986e-01, -2.9950e-01, -1.8117e-01,  2.3075e-01, -9.3827e-02,
          7.1192e-04],
        [ 1.7699e-01,  1.2099e-01, -4.0220e-01, -1.5253e-01, -3.9307e-01,
          7.2420e-02],
        [ 2.8259e-01,  1.4268e-01, -2.3052e-01, -7.4569e-02,  2.1027e-01,
          1.1247e-01],
        [ 5.5769e-01,  4.0678e-01, -5.3094e-02,  8.4206e-02,  6.0408e-02,
          2.0130e-02],
        [-8.2915e-02,  2.6401e-01,  1.1098e-01,  2.4484e-02,  1.4117e-01,
         -2.9952e-01],
        [ 4.4458e-01, -6.736

In [24]:
# モデルのインスタンス生成
xy_dim = 2

#model = ActiveNet(xy_dim, dr_thresh, dropout=0, batchN=True, bias=True, Nchannels=128).to(device)
# input data
#data = dataset[0]

def calc_multiSteps(x0):
    outs = []
    x_i = x0
    for i_step in range(T_pred):
        x_i = x_i + model(x_i) * dt
        outs.append(x_i.clone())
    return torch.stack(outs, dim=0)

# optimizer
#optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
optimizer = torch.optim.Adam([
    {'params': pG, 'lr': lr} for pG, lr in zip(paramsGroups, lrs)
])
#optimizer = torch.optim.Adadelta(model.parameters())#, rho=0.95)#, lr=1e-1, momentum=0.9)

scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5, verbose=True)

val_loss_log = []

val_loss_min = np.Inf

# learnig loop
for epoch in range(50):
    model.train()
    for batch_x, batch_y, batch_ct in train_data:
        optimizer.zero_grad()
        lossv = 0
        losstheta = 0
        for ib in range(batch_x.size(0)):
            model.load_celltypes(batch_ct[ib].to(device))
            out = calc_multiSteps(batch_x[ib].to(device))
            lv, ltheta = myLoss(out, batch_y[ib].to(device))
            lossv = lossv + lv
            losstheta = losstheta + ltheta
        lossv = lossv / batch_x.size(0)
        losstheta = losstheta / batch_x.size(0)
        (lossv+losstheta).backward()
        optimizer.step()
    model.eval()
    val_lossv = 0
    val_losstheta = 0
    val_count = 0
    with torch.no_grad():
        for batch_x, batch_y, batch_ct in valid_data:
            for ib in range(batch_x.size(0)):
                model.load_celltypes(batch_ct[ib].to(device))
                val_out = calc_multiSteps(batch_x[ib].to(device))
                lv, ltheta = myLoss(val_out, batch_y[ib].to(device))
                val_lossv = val_lossv + lv
                val_losstheta = val_losstheta + ltheta
            val_count = val_count + batch_x.size(0)
    val_lossv = val_lossv/val_count
    val_losstheta = val_losstheta/val_count
    val_loss = val_lossv + val_losstheta
    scheduler.step(val_loss)
    print('Epoch %d | train Loss: [%.4f, %.4f] | valid Loss: [%.4f, %.4f]' % (epoch,
                                                                              lossv.item(), 
                                                                              losstheta.item(),
                                                                              val_lossv.item(), 
                                                                              val_losstheta.item()))
    val_loss_log.append([val_lossv.cpu().item(), val_losstheta.cpu().item()])
    if val_loss.item() < val_loss_min:
        stored_model = model
        val_loss_min = val_loss.item()

Epoch 0 | train Loss: [1.4314, 0.2154] | valid Loss: [1.4807, 0.2231]
Epoch 1 | train Loss: [1.4384, 0.2161] | valid Loss: [1.4747, 0.2229]
Epoch 2 | train Loss: [1.4327, 0.2176] | valid Loss: [1.4687, 0.2226]
Epoch 3 | train Loss: [1.4293, 0.2169] | valid Loss: [1.4638, 0.2225]
Epoch 4 | train Loss: [1.4232, 0.2178] | valid Loss: [1.4624, 0.2225]
Epoch 5 | train Loss: [1.4348, 0.2172] | valid Loss: [1.4583, 0.2223]
Epoch 6 | train Loss: [1.4143, 0.2174] | valid Loss: [1.4528, 0.2222]
Epoch 7 | train Loss: [1.4084, 0.2177] | valid Loss: [1.4494, 0.2221]
Epoch 8 | train Loss: [1.4103, 0.2156] | valid Loss: [1.4463, 0.2219]
Epoch 9 | train Loss: [1.4131, 0.2154] | valid Loss: [1.4433, 0.2218]
Epoch 10 | train Loss: [1.4145, 0.2153] | valid Loss: [1.4410, 0.2216]
Epoch 11 | train Loss: [1.4287, 0.2156] | valid Loss: [1.4393, 0.2214]
Epoch 12 | train Loss: [1.4303, 0.2158] | valid Loss: [1.4368, 0.2214]
Epoch 13 | train Loss: [1.4297, 0.2158] | valid Loss: [1.4337, 0.2213]
Epoch 14 | train

In [25]:
print(batch_x.shape)
print(batch_y.shape)
print(batch_ct.shape)

torch.Size([7, 400, 3])
torch.Size([7, 5, 400, 3])
torch.Size([7, 400, 1])


In [26]:
# モデルを評価モードに設定
stored_model.eval()

# 推論
test_lossv = 0
test_losstheta = 0
test_count = 0
with torch.no_grad():
    for batch_x, batch_y, batch_ct in test_data:
        for ib in range(batch_x.size(0)):
            model.load_celltypes(batch_ct[ib].to(device))
            test_out = calc_multiSteps(batch_x[ib].to(device))
            lv, ltheta = myLoss(test_out, batch_y[ib].to(device))
            test_lossv = test_lossv + lv
            test_losstheta = test_losstheta + ltheta
        test_count = test_count + batch_x.size(0)
test_lossv = test_lossv/test_count
test_losstheta = test_losstheta/test_count
print('test Loss: [%.4f, %.4f]' % (test_lossv.item(), test_losstheta.item()))
test_loss = [test_lossv.item(), test_losstheta.item()]

test Loss: [1.3656, 0.2169]


In [27]:
import pickle
import datetime

now = datetime.datetime.now()
nowstr = now.strftime('%Y%m%d_%H%M%S')

os.makedirs(savedirName + nowstr + '/', exist_ok=True)

In [28]:
stored_model.selfpropel.detach()

tensor(0.9046, device='cuda:0')

In [29]:
stored_model = stored_model.to('cpu')

filename1 = savedirName + nowstr + '/' + nowstr + '_Model.pkl'
with open(filename1, "wb") as f:
    pickle.dump(stored_model, f)

filename1_2 = savedirName + nowstr + '/' + nowstr + '_Model.pt'
torch.save(stored_model, filename1_2)

filename2 = savedirName + nowstr + '/' + nowstr
torch.save(stored_model.interactNN.state_dict(), filename2 + '_interactNN.pkl')
torch.save(stored_model.thetaDotNN.state_dict(), filename2 + '_thetaDotNN.pkl')
torch.save(stored_model.selfpropel.detach(), filename2 + '_selfpropel.pkl')

filename3 = savedirName + nowstr + '/' + nowstr + '_Separation.npz'
np.savez(filename3, dr_thresh=dr_thresh, T_pred=T_pred, batch_size=batch_size,
         train_inds=train_inds, valid_inds=valid_inds, test_inds=test_inds, 
         val_loss_log=val_loss_log, test_loss=test_loss,
         transfer_origin=model_dir, 
         lrs=lrs, keysGroups=keysGroups, keysRemove=keysRemove)

filename4 = savedirName + nowstr + '/' + nowstr + '_optimizer.pt'
torch.save(optimizer, filename4)

  val = np.asanyarray(val)
