In [51]:
#!pip install dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html

In [52]:
import os
os.environ['DGLBACKEND'] = 'pytorch'

In [53]:
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset

import dgl
import dgl.function as fn

import networkx as nx
from matplotlib import pyplot as plt
import numpy as np

import os
import gc

# デバイス設定
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


In [54]:
def printNPZ(npz):
    for kw in npz.files:
        print(kw, npz[kw])

In [55]:
dirName = '/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/'
savedirName = dirName + 'ActiveNet_vp_rotsym_multiStep_transfer_batchNorm/'
os.makedirs(savedirName, exist_ok=True)

params = np.load(dirName+'params.npz')
#traj = np.load(dirName+'result.npz')

printNPZ(params)
#printNPZ(traj)

v0 1.0
r 1.0
D 0.0
A 0.0
L 20
rho 1.0
beta 1.0
A_CFs [0.9 0.5]
A_CIL 0.0
cellType_ratio [0.7 0.3]
quiv_colors ['k' 'r']
kappa 0.5
A_Macdonalds [0.5 0.5]
batch_size 400
state_size 3
brownian_size 1
periodic True
t_max 1000
methodSDE heun
isIto False
stepSDE 0.01


In [56]:
if params['periodic']:
    L = torch.tensor(params['L'])
    def calc_dr(r1, r2):
        dr = torch.remainder((r1 - r2), L)
        dr[dr > L/2] = dr[dr > L/2] - L
        return dr
else:
    def calc_dr(r1, r2):
        return r1 - r2
    
def makeGraph(x_data, r_thresh):
        Ndata = x_data.size(0)
        dx = calc_dr(torch.unsqueeze(x_data, 0), torch.unsqueeze(x_data, 1))
        dx = torch.sum(dx**2, dim=2)
        edges = torch.argwhere(dx < r_thresh/2)
        return dgl.graph((edges[:,0], edges[:,1]), num_nodes=Ndata)

In [57]:
subdir_list = [f.path for f in os.scandir(dirName) if f.is_dir()]

print(subdir_list)

datadir_list = [f for f in subdir_list if 'result.npz' in [ff.name for ff in os.scandir(f)]]

print(datadir_list)

['/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/20230228_024901', '/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/20230228_064020', '/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/20230228_103154', '/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/20230228_142155', '/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/20230228_181400', '/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/20230228_220458', '/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/20230301_015428', '/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/20230301_054439', '/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/20230301_093612', '/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/20230301_132715', '/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/ActiveNet_vp_rotsym_multiStep_batchNorm', '/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/ActiveNet_vp_rotsym_batchNorm', '/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/Acti

In [58]:
modeldirName = dirName + 'ActiveNet_vp_rotsym_batchNorm/'

#datename = '20220921_160055'
i_model = -1

model_files_list = [[os.path.join(c, ff) for ff in f if ff.endswith('_Model.pt')]
               for c, s, f in os.walk(modeldirName)]
print(model_files_list)

model_files = []
for i in range(len(model_files_list)):
    for j in range(len(model_files_list[i])):
        model_files.append(model_files_list[i][j])
print(model_files)

model_dir = os.path.dirname(model_files[i_model])

model_name = model_files[i_model].replace('_Model.pt', '').replace(modeldirName, '')
print(model_name)

[[], ['/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/ActiveNet_vp_rotsym_batchNorm/20230301_231603/20230301_231603_Model.pt']]
['/home/uwamichi/jupyter/HiraiwaModel_chem20230227_160803/ActiveNet_vp_rotsym_batchNorm/20230301_231603/20230301_231603_Model.pt']
20230301_231603/20230301_231603


In [59]:
class myDataset(Dataset):
    def __init__(self, data_x, celltype_List, t_yseq=1):
        super().__init__()
        
        self.data_x = data_x # List of tensors
        #self.data_y = data_y
        self.celltype_List = celltype_List
        
        self.data_len = np.array([xx.size(0) for xx in self.data_x])
        self.t_yseq = t_yseq
        
        self.data_len_cumsum = np.cumsum(self.data_len - (self.t_yseq - 1))
        
    def __len__(self):
        return (self.data_len - (self.t_yseq - 1)).sum()
    
    def __getitem__(self, index):
        id_List = np.argwhere(index<self.data_len_cumsum)[0,0]
        
        if id_List:
            id_tensor = index - self.data_len_cumsum[id_List-1]
        else:
            id_tensor = index
        
        return self.data_x[id_List][id_tensor], self.data_x[id_List][id_tensor:(id_tensor+self.t_yseq)], self.celltype_List[id_List]

In [60]:
#dr_thresh = 7
dt = 1
batch_size = 8

T_pred = 10

N_data = len(datadir_list)

#TR_VA_rate = np.array([0.6, 0.2])

TR_last = 3
VA_last = 4

shuffle_inds = np.arange(N_data, dtype=int)
np.random.shuffle(shuffle_inds)

train_inds = shuffle_inds[:TR_last]
valid_inds = shuffle_inds[TR_last:VA_last]
test_inds = shuffle_inds[VA_last:]

celltype_lst = []

train_x = []
valid_x = []
test_x = []

train_y = []
valid_y = []
test_y = []

train_ct = []
valid_ct = []
test_ct = []

for i_dir, subdirName in enumerate(datadir_list):
    
    traj = np.load(subdirName+'/result.npz')
    
    xy_t = torch.tensor(traj['xy'])#[:-1,:,:])
    #v_t = calc_dr(torch.tensor(traj['xy'][1:,:,:]), torch.tensor(traj['xy'][:-1,:,:])) / dt
    p_t = torch.unsqueeze(torch.tensor(traj['theta']), dim=2)#[:-1,:]), dim=2)
    #w_t = torch.unsqueeze(torch.tensor((traj['theta'][1:,:]-traj['theta'][:-1,:])%(2*np.pi)/dt), dim=2)
    
    if i_dir in train_inds:
        train_x.append(torch.concat((xy_t, p_t), -1))
        #train_y.append(torch.concat((v_t, w_t), -1))
        train_ct.append(torch.tensor(traj['celltype_label']).view(-1,1))

    if i_dir in valid_inds:
        valid_x.append(torch.concat((xy_t, p_t), -1))
        #valid_y.append(torch.concat((v_t, w_t), -1))
        valid_ct.append(torch.tensor(traj['celltype_label']).view(-1,1))
        
    if i_dir in test_inds:
        test_x.append(torch.concat((xy_t, p_t), -1))
        #test_y.append(torch.concat((v_t, w_t), -1))
        test_ct.append(torch.tensor(traj['celltype_label']).view(-1,1))
    
train_dataset = myDataset(train_x, train_ct, t_yseq=T_pred)

valid_dataset = myDataset(valid_x, valid_ct, t_yseq=T_pred)

test_dataset = myDataset(test_x, test_ct, t_yseq=T_pred)


train_data = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, pin_memory=True)
valid_data = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, pin_memory=True)
test_data = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, pin_memory=True)

del train_x, train_ct, train_dataset
del valid_x, valid_ct, valid_dataset
del test_x, test_ct, test_dataset
gc.collect()

#print(data)
#print(data.num_graphs)
#print(data.x)
#print(data.y)
#print(data.edge_index)

308

In [61]:
print(train_inds)

[5 4 3]


In [62]:
dir(train_data)

['_DataLoader__initialized',
 '_DataLoader__multiprocessing_context',
 '_IterableDataset_len_called',
 '__annotations__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_auto_collation',
 '_dataset_kind',
 '_get_iterator',
 '_get_shared_seed',
 '_index_sampler',
 '_is_protocol',
 '_iterator',
 'batch_sampler',
 'batch_size',
 'check_worker_number_rationality',
 'collate_fn',
 'dataset',
 'drop_last',
 'generator',
 'multiprocessing_context',
 'num_workers',
 'persistent_workers',
 'pin_memory',
 'pin_memory_device',
 'prefetch_factor',
 'sampler',
 'timeout',
 'worker_i

In [63]:
len(train_data)

372

In [64]:
print(train_data.dataset.data_len_cumsum)

[ 992 1984 2976]


In [65]:
def plotGraph(data):

    # networkxのグラフに変換
    nxg = dgl.to_networkx(data)

    # 可視化のためのページランク計算
    pr = nx.pagerank(nxg)
    pr_max = np.array(list(pr.values())).max()

    # 可視化する際のノード位置
    draw_pos = nx.spring_layout(nxg, seed=0) 

    # ノードの色設定
    cmap = plt.get_cmap('tab10')
    labels = data.y.numpy()
    colors = [cmap(l) for l in labels]

    # 図のサイズ
    fig0 = plt.figure(figsize=(10, 10))

    # 描画
    nx.draw_networkx_nodes(nxg, 
                          draw_pos,
                          node_size=[v / pr_max * 1000 for v in pr.values()])#,
                          #node_color=colors, alpha=0.5)
    nx.draw_networkx_edges(nxg, draw_pos, arrowstyle='-', alpha=0.2)
    nx.draw_networkx_labels(nxg, draw_pos, font_size=10)

    #plt.title('KarateClub')
    plt.show()

    return fig0

In [66]:
class NeuralNet(nn.Module):
    def __init__(self, in_channels, out_channels, Nchannels, dropout=0, batchN=False, flgBias=False):
        super(NeuralNet, self).__init__()

        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0
            
        if batchN:
            self.bNorm1 = nn.BatchNorm1d(Nchannels)
            self.bNorm2 = nn.BatchNorm1d(Nchannels)
            self.bNorm3 = nn.BatchNorm1d(Nchannels)
            
        self.batchN=batchN
        
        self.layer1 = nn.Linear(in_channels, Nchannels, bias=flgBias)
        self.layer2 = nn.Linear(Nchannels, Nchannels, bias=flgBias)
        self.layer3 = nn.Linear(Nchannels, Nchannels, bias=flgBias)
        self.layer4 = nn.Linear(Nchannels, out_channels, bias=flgBias)

        self.activation = nn.ReLU()

    def reset_parameters(self):
        self.layer1.reset_parameters()
        self.layer2.reset_parameters()
        self.layer3.reset_parameters()
        self.layer4.reset_parameters()
        #nn.init.zeros_(self.layer1.weight)
        #nn.init.zeros_(self.layer2.weight)
        #nn.init.zeros_(self.layer3.weight)
        #nn.init.zeros_(self.layer4.weight)
        
    def forward(self, x):
        out = self.activation(self.layer1(x))
        if self.batchN:
            out = self.bNorm1(out)
        if self.dropout:
            out = self.dropout(out)
        
        out = self.activation(self.layer2(out))
        if self.batchN:
            out = self.bNorm2(out)
        if self.dropout:
            out = self.dropout(out)
        
        out = self.activation(self.layer3(out))
        if self.batchN:
            out = self.bNorm3(out)
        if self.dropout:
            out = self.dropout(out)
        
        out = self.layer4(out)

        return out

class ActiveNet(nn.Module):
    def __init__(self, xy_dim, r, dropout=0, batchN=False, bias=False, Nchannels=128):
        super().__init__()

        self.interactNN = NeuralNet(xy_dim*2 + 2, xy_dim, Nchannels, dropout, batchN, bias)

        self.thetaDotNN = NeuralNet(xy_dim*2 + 2, 1, Nchannels, dropout, batchN, bias)
        
        self.selfpropel = nn.Parameter(torch.tensor(0.0, requires_grad=True, device=device))

        #self.Normalizer = nn.Softmax(dim=1)

        self.xy_dim = xy_dim
        
        self.r = r

        self.reset_parameters()

    def reset_parameters(self):
        self.interactNN.reset_parameters()

        self.thetaDotNN.reset_parameters()
        
        nn.init.uniform_(self.selfpropel)

        #self.bias.data.zero_()
        
    def load_celltypes(self, celltype):
        self.celltype = celltype

    def calc_message(self, edges):
        dx = calc_dr(edges.dst['x'], edges.src['x'])

        costheta = torch.cos(edges.dst['theta'])
        sintheta = torch.sin(edges.dst['theta'])

        dx_para = costheta * dx[..., :1] + sintheta * dx[..., 1:]
        dx_perp = costheta * dx[..., 1:] - sintheta * dx[..., :1]

        p_para_src = torch.cos(edges.src['theta'] - edges.dst['theta'])
        p_perp_src = torch.sin(edges.src['theta'] - edges.dst['theta'])

        rot_m_v = self.interactNN(torch.concat((dx_para, dx_perp, 
                                                p_para_src, p_perp_src,
                                                edges.dst['type'], edges.src['type']), -1))

        m_v = torch.concat((costheta * rot_m_v[..., :1] - sintheta * rot_m_v[..., 1:],
                            costheta * rot_m_v[..., 1:] + sintheta * rot_m_v[..., :1]), -1)

        m_theta = self.thetaDotNN(torch.concat((dx_para, dx_perp, 
                                                p_para_src, p_perp_src, 
                                                edges.dst['type'], edges.src['type']), -1))
        
        return {'m': torch.concat((m_v, m_theta), -1)}
        
    def forward(self, xv):
        r_g = makeGraph(xv[..., :self.xy_dim], self.r/2)
        r_g.ndata['x'] = xv[..., :self.xy_dim]
        r_g.ndata['theta'] = xv[..., self.xy_dim:(self.xy_dim+1)]
        r_g.ndata['type'] = self.celltype
        r_g.update_all(self.calc_message, fn.sum('m', 'a'))
        r_g.ndata['a'][..., :self.xy_dim] = r_g.ndata['a'][..., :self.xy_dim] + self.selfpropel * torch.concat((torch.cos(r_g.ndata['theta']), torch.sin(r_g.ndata['theta'])), -1)
        
        return r_g.ndata['a']



In [67]:
def myLoss(out, target):
    #dv = torch.sum(torch.square(out[..., :xy_dim] - target[..., :xy_dim]), dim=-1)
    dv = torch.sum(torch.square(calc_dr(out[..., :xy_dim], target[..., :xy_dim])), dim=-1)
    dcos = torch.cos(out[..., xy_dim] - target[..., xy_dim])
    
    wei_shape = np.ones([dv.dim()], dtype=int)
    wei_shape[0] = T_pred
    wei = torch.tensor(np.reshape(1/np.arange(1, T_pred+1), wei_shape)).to(dv.device)
    wei = wei/wei.mean()
    
    return torch.mean(dv*wei), torch.mean((1-dcos)*wei)

In [68]:
filename1 = modeldirName + model_name + '_Model.pt'
model = torch.load(filename1, map_location=torch.device(device)) #pickle.load(open(filename1, 'rb'))
model.train()

filename3 = modeldirName + model_name + '_Separation.npz'
learn_params = np.load(filename3, allow_pickle=True)

printNPZ(learn_params)

dr_thresh = learn_params['dr_thresh'].item()

dr_thresh 4
batch_size 8
train_inds [6 1 4 5 7]
valid_inds [2 9]
test_inds [0 8 3]
val_loss_log [[0.06957185 0.01138402]
 [0.06715702 0.01138812]
 [0.05852827 0.01076387]
 [0.05464074 0.01047792]
 [0.04124076 0.01028508]
 [0.04581858 0.01010125]
 [0.03636987 0.00993934]
 [0.03513648 0.00985879]
 [0.03662048 0.00865206]
 [0.0346748  0.00870927]
 [0.0328907  0.00827903]
 [0.03297586 0.00823795]
 [0.02731015 0.00815105]
 [0.03475118 0.00824428]
 [0.03300626 0.00805876]
 [0.03231701 0.00799449]
 [0.03065317 0.00789906]
 [0.03004111 0.00781185]
 [0.02944164 0.00775791]
 [0.02152422 0.00777418]
 [0.02077205 0.00775165]
 [0.02282816 0.00787244]
 [0.02027514 0.00775065]
 [0.02011793 0.007732  ]
 [0.02029534 0.00774424]
 [0.01848694 0.00769201]
 [0.02006832 0.00769452]
 [0.01811177 0.0076291 ]
 [0.0193208  0.00749339]
 [0.01931852 0.00757757]
 [0.01905341 0.00756452]
 [0.01934462 0.00758915]
 [0.01876732 0.00760899]
 [0.01857739 0.00751577]
 [0.01504676 0.00637165]
 [0.01473385 0.00635003]
 [0.

In [69]:
# モデルのインスタンス生成
xy_dim = 2

#model = ActiveNet(xy_dim, dr_thresh, dropout=0, batchN=True, bias=True, Nchannels=128).to(device)
# input data
#data = dataset[0]

def calc_multiSteps(x0):
    outs = []
    x_i = x0
    for i_step in range(T_pred):
        x_i = x_i + model(x_i) * dt
        outs.append(x_i.clone())
    return torch.stack(outs, dim=0)

# optimizer
#optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-6)#, weight_decay=5e-4)
#optimizer = torch.optim.Adadelta(model.parameters())#, rho=0.95)#, lr=1e-1, momentum=0.9)

scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5, verbose=True)

val_loss_log = []

val_loss_min = np.Inf

# learnig loop
for epoch in range(50):
    model.train()
    for batch_x, batch_y, batch_ct in train_data:
        optimizer.zero_grad()
        lossv = 0
        losstheta = 0
        for ib in range(batch_x.size(0)):
            model.load_celltypes(batch_ct[ib].to(device))
            out = calc_multiSteps(batch_x[ib].to(device))
            lv, ltheta = myLoss(out, batch_y[ib].to(device))
            lossv = lossv + lv
            losstheta = losstheta + ltheta
        lossv = lossv / batch_x.size(0)
        losstheta = losstheta / batch_x.size(0)
        (lossv+losstheta).backward()
        optimizer.step()
    model.eval()
    val_lossv = 0
    val_losstheta = 0
    val_count = 0
    with torch.no_grad():
        for batch_x, batch_y, batch_ct in valid_data:
            for ib in range(batch_x.size(0)):
                model.load_celltypes(batch_ct[ib].to(device))
                val_out = calc_multiSteps(batch_x[ib].to(device))
                lv, ltheta = myLoss(val_out, batch_y[ib].to(device))
                val_lossv = val_lossv + lv
                val_losstheta = val_losstheta + ltheta
            val_count = val_count + batch_x.size(0)
    val_lossv = val_lossv/val_count
    val_losstheta = val_losstheta/val_count
    val_loss = val_lossv + val_losstheta
    scheduler.step(val_loss)
    print('Epoch %d | train Loss: [%.4f, %.4f] | valid Loss: [%.4f, %.4f]' % (epoch,
                                                                              lossv.item(), 
                                                                              losstheta.item(),
                                                                              val_lossv.item(), 
                                                                              val_losstheta.item()))
    val_loss_log.append([val_lossv.cpu().item(), val_losstheta.cpu().item()])
    if val_loss.item() < val_loss_min:
        stored_model = model
        val_loss_min = val_loss.item()

Epoch 0 | train Loss: [3.6737, 0.2377] | valid Loss: [3.4095, 0.2219]
Epoch 1 | train Loss: [3.4591, 0.2413] | valid Loss: [3.3853, 0.2215]
Epoch 2 | train Loss: [3.5780, 0.2353] | valid Loss: [3.3881, 0.2219]
Epoch 3 | train Loss: [3.5287, 0.2365] | valid Loss: [3.3911, 0.2213]
Epoch 4 | train Loss: [3.5383, 0.2395] | valid Loss: [3.3930, 0.2217]
Epoch 5 | train Loss: [3.5548, 0.2363] | valid Loss: [3.3893, 0.2217]
Epoch 6 | train Loss: [3.5902, 0.2358] | valid Loss: [3.3933, 0.2219]
Epoch 00008: reducing learning rate of group 0 to 1.5000e-06.
Epoch 7 | train Loss: [3.5181, 0.2336] | valid Loss: [3.3881, 0.2215]
Epoch 8 | train Loss: [3.6187, 0.2362] | valid Loss: [3.3897, 0.2220]
Epoch 9 | train Loss: [3.6344, 0.2367] | valid Loss: [3.3870, 0.2221]
Epoch 10 | train Loss: [3.5825, 0.2347] | valid Loss: [3.3933, 0.2219]
Epoch 11 | train Loss: [3.5599, 0.2349] | valid Loss: [3.3925, 0.2219]
Epoch 12 | train Loss: [3.6539, 0.2398] | valid Loss: [3.3890, 0.2220]
Epoch 13 | train Loss: [3

In [70]:
print(batch_x.shape)
print(batch_y.shape)
print(batch_ct.shape)

torch.Size([8, 400, 3])
torch.Size([8, 10, 400, 3])
torch.Size([8, 400, 1])


In [71]:
model.state_dict()

OrderedDict([('selfpropel', tensor(1.0137, device='cuda:0')),
             ('interactNN.layer1.weight',
              tensor([[-6.2236e-01,  4.4161e-01, -3.2166e-01,  2.3285e-01, -3.5735e-02,
                        1.4747e-02],
                      [ 1.2277e-03, -2.4975e-01, -1.7714e-02,  2.3002e-01,  1.0570e-01,
                        1.0600e-01],
                      [ 8.0876e-02, -6.8451e-02,  6.2910e-02,  2.8809e-02, -2.7933e-01,
                       -3.7462e-01],
                      [ 2.7385e-01, -3.7914e-01, -3.7419e-01, -2.9074e-02,  1.2012e-02,
                        5.6747e-03],
                      [-2.0044e-01,  4.9273e-02, -2.8922e-01,  1.4737e-01,  2.3178e-01,
                        1.0604e-01],
                      [-4.1265e-01, -2.9591e-01, -1.8542e-01, -1.3536e-01, -2.8173e-02,
                        1.6512e-02],
                      [-3.4694e-02,  1.0687e-01,  2.0695e-02, -1.5187e-01, -4.3258e-03,
                        1.0713e-02],
                     

In [72]:
# モデルを評価モードに設定
stored_model.eval()

# 推論
test_lossv = 0
test_losstheta = 0
test_count = 0
with torch.no_grad():
    for batch_x, batch_y, batch_ct in test_data:
        for ib in range(batch_x.size(0)):
            model.load_celltypes(batch_ct[ib].to(device))
            test_out = calc_multiSteps(batch_x[ib].to(device))
            lv, ltheta = myLoss(test_out, batch_y[ib].to(device))
            test_lossv = test_lossv + lv
            test_losstheta = test_losstheta + ltheta
        test_count = test_count + batch_x.size(0)
test_lossv = test_lossv/test_count
test_losstheta = test_losstheta/test_count
print('test Loss: [%.4f, %.4f]' % (test_lossv.item(), test_losstheta.item()))
test_loss = [test_lossv.item(), test_losstheta.item()]

test Loss: [3.4976, 0.2283]


In [73]:
import pickle
import datetime

now = datetime.datetime.now()
nowstr = now.strftime('%Y%m%d_%H%M%S')

os.makedirs(savedirName + nowstr + '/', exist_ok=True)

In [74]:
stored_model.selfpropel.detach()

tensor(1.0137, device='cuda:0')

In [75]:
stored_model = stored_model.to('cpu')

filename1 = savedirName + nowstr + '/' + nowstr + '_Model.pkl'
with open(filename1, "wb") as f:
    pickle.dump(stored_model, f)

filename1_2 = savedirName + nowstr + '/' + nowstr + '_Model.pt'
torch.save(stored_model, filename1_2)

filename2 = savedirName + nowstr + '/' + nowstr
torch.save(stored_model.interactNN.state_dict(), filename2 + '_interactNN.pkl')
torch.save(stored_model.thetaDotNN.state_dict(), filename2 + '_thetaDotNN.pkl')
torch.save(stored_model.selfpropel.detach(), filename2 + '_selfpropel.pkl')

filename3 = savedirName + nowstr + '/' + nowstr + '_Separation.npz'
np.savez(filename3, dr_thresh=dr_thresh, T_pred=T_pred, batch_size=batch_size,
         train_inds=train_inds, valid_inds=valid_inds, test_inds=test_inds, 
         val_loss_log=val_loss_log, test_loss=test_loss,
         transfer_origin=model_dir)

filename4 = savedirName + nowstr + '/' + nowstr + '_optimizer.pt'
torch.save(optimizer, filename4)