In [1]:
#mport warnings
#arnings.filterwarnings(action='ignore')
import os

import numpy as np

import torch
import torch.optim as optim

from tqdm.notebook import tqdm

from utils import *

In [2]:
data_dir = "data/"
model_dir = "model/"

# Dataset name: passed to `data_preprocessing` and reused as the checkpoint file name.
data_name = "strongly_bimodal"
save_file_name = data_name

K = 4  # latent space dimension
h = 16  # number of nodes in each hidden layer
M = 80  # number of mixture kernels
s_min = 0.5  # lower bound for the scale parameters
s_max = 2  # upper bound for the scale parameters

n_col_pts = 100  # number of evenly spaced evaluation times over the observed time span
learning_rate = 1e-4
EPOCH = 2000000  # maximum number of training iterations
early_stopping = True
if early_stopping:
    # Size of the stability window, counted in logged snapshots
    # (the training loop logs one snapshot every 1000 iterations).
    n_hist = 30
    # Added to the mean in the coefficient-of-variation denominator to avoid division by zero.
    epsilon = 1e-5

In [3]:
# Use the first CUDA GPU if PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"device: {device}")

device: cuda:0


In [4]:
# `data_preprocessing` comes from the project's utils module; its four
# return values are unpacked here by position.
(obs_time, time_trace,
 scaled_trace, mean_scaled_trace) = data_preprocessing(data_name, data_dir, device)

In [5]:
# Build the MDN sized from the data (number of observation times, final time)
# and the config cell, place it on `device`, attach the scaled trace and put
# the module in training mode.
net_MDN = MDN(
    d=len(obs_time),
    final_time=obs_time[-1].item(),
    K=K,
    h=h,
    M=M,
    s_min=s_min,
    s_max=s_max,
    device=device,
).to(device)
net_MDN.data = scaled_trace
net_MDN.train()
print("model created")

model created


In [6]:
# Adam with the AMSGrad variant over all model parameters.
opt = optim.Adam(net_MDN.parameters(), lr=learning_rate, amsgrad=True)

# `n_col_pts` evenly spaced times spanning the observations (presumably
# collocation points), shaped as a column vector on `device`.
r1, r2 = obs_time[0].item(), obs_time[-1].item()
col_time = torch.linspace(r1, r2, n_col_pts).to(device).reshape(-1, 1)
# Gradient-tracking alias of `col_time` (shares storage; `col_time` itself stays
# non-differentiable). This replaces the deprecated `torch.autograd.Variable`
# wrapper with its modern equivalent. NOTE(review): `t_v` is not used by the
# training loop in this notebook.
t_v = col_time.detach().requires_grad_(True)

beta = derivative_penalty(time_trace, M)

In [7]:
# Histories filled once per logged step by the training loop: total loss,
# and the two values returned by `mean_w` (kernel-weight snapshots).
loss_profile, w_list_mean, w_list_std = [], [], []

In [None]:
# Training loop.
# Every `log_every` iterations: log the three loss terms, snapshot the kernel
# weights with `mean_w`, checkpoint on a new best total loss and, if enabled,
# evaluate the early-stopping criterion on the kernel-weight history.
log_every = 1000  # iterations between log lines / weight snapshots / checkpoint checks
cv_threshold = 0.01  # a kernel counts as stable when its weight CV is below this
stable_fraction = 0.99  # stop once more than this fraction of the kernels is stable
l1_loss = torch.nn.L1Loss()  # stateless module: build once instead of every iteration
checkpoint_path = os.path.join(model_dir, save_file_name)
os.makedirs(model_dir, exist_ok=True)  # torch.save fails if the directory is missing


def _stable_kernel_fraction(w_history, n_window, eps, threshold, n_kernels):
    """Return the fraction of kernels whose weight has settled over the last `n_window` snapshots.

    Args:
        w_history: list of weight snapshots from `mean_w`, one entry per logged
            step (assumed 1-D with one value per kernel — TODO confirm).
        n_window: number of most recent snapshots to inspect, including the newest.
        eps: added to the window mean so the coefficient of variation never divides by zero.
        threshold: a kernel is stable when its coefficient of variation is below this.
        n_kernels: denominator of the returned fraction (the number of kernels, M).

    Returns:
        Fraction of `n_kernels` kernels that are stable.
    """
    window = np.array(w_history[-n_window:]).T  # rows: kernels, columns: snapshots
    window = window / window.sum(axis=0)  # normalise each snapshot to sum to 1
    cv = window.std(axis=1) / (window.mean(axis=1) + eps)
    return (cv < threshold).sum() / n_kernels


for n_iter in tqdm(range(1, EPOCH + 1)):
    loss1, loss2, loss3 = train(model=net_MDN, optimizer=opt, loss_f=l1_loss, loss_kl=kl_divergence,
                                col_time=col_time, obs_time=obs_time,
                                time_trace=time_trace, beta=beta, device=device,
                                s_min=s_min, s_max=s_max)
    if n_iter % log_every != 0:
        continue

    # Log the losses and record a snapshot of the kernel weights.
    loss_profile.append(loss1 + loss2 + loss3)
    print("{}/{}: P_loss:{}, D_loss:{}, R_loss:{}".format(
        n_iter, EPOCH, np.round(loss1, 3), np.round(loss2, 3), np.round(loss3, 3)))
    w_mean, w_std = mean_w(net_MDN, mean_scaled_trace, time_trace)
    w_list_mean.append(w_mean)
    w_list_std.append(w_std)

    # Save model whenever the total loss matches or beats every earlier value.
    # This runs before the early-stopping check so that the model at the
    # stopping step, possibly the best one, is not lost by `break`.
    if loss_profile[-1] <= np.min(loss_profile):
        torch.save([net_MDN, loss_profile], checkpoint_path)

    # Early stopping: requires more than `n_hist` snapshots of history.
    if early_stopping and n_iter > n_hist * log_every:
        if _stable_kernel_fraction(w_list_mean, n_hist, epsilon, cv_threshold, M) > stable_fraction:
            print(f"training stopped (w> {stable_fraction:.0%})")
            break
                


  0%|          | 0/2000000 [00:00<?, ?it/s]

1000/2000000: P_loss:22.085, D_loss:9.111, R_loss:6.152
2000/2000000: P_loss:19.042, D_loss:8.392, R_loss:8.525
3000/2000000: P_loss:15.72, D_loss:8.797, R_loss:7.025
4000/2000000: P_loss:14.689, D_loss:9.057, R_loss:5.607
5000/2000000: P_loss:13.311, D_loss:9.277, R_loss:4.959
6000/2000000: P_loss:12.505, D_loss:9.404, R_loss:4.36
7000/2000000: P_loss:11.576, D_loss:9.515, R_loss:3.963
8000/2000000: P_loss:11.013, D_loss:9.692, R_loss:3.599
9000/2000000: P_loss:10.541, D_loss:9.824, R_loss:3.266
10000/2000000: P_loss:10.587, D_loss:9.84, R_loss:3.037
11000/2000000: P_loss:9.956, D_loss:9.964, R_loss:2.847
12000/2000000: P_loss:9.728, D_loss:9.966, R_loss:2.763
13000/2000000: P_loss:9.867, D_loss:9.947, R_loss:2.672
14000/2000000: P_loss:9.32, D_loss:10.025, R_loss:2.55
15000/2000000: P_loss:9.42, D_loss:9.954, R_loss:2.497
16000/2000000: P_loss:9.218, D_loss:10.058, R_loss:2.449
17000/2000000: P_loss:8.728, D_loss:10.073, R_loss:2.469
18000/2000000: P_loss:8.779, D_loss:10.073, R_loss

147000/2000000: P_loss:0.141, D_loss:8.086, R_loss:0.903
148000/2000000: P_loss:0.093, D_loss:8.137, R_loss:0.875
149000/2000000: P_loss:0.148, D_loss:7.98, R_loss:0.889
150000/2000000: P_loss:0.15, D_loss:8.262, R_loss:0.885
151000/2000000: P_loss:0.094, D_loss:8.12, R_loss:0.874
152000/2000000: P_loss:0.131, D_loss:8.159, R_loss:0.899
153000/2000000: P_loss:0.111, D_loss:8.08, R_loss:0.852
154000/2000000: P_loss:0.091, D_loss:8.204, R_loss:0.883
155000/2000000: P_loss:0.091, D_loss:8.072, R_loss:0.857
156000/2000000: P_loss:0.177, D_loss:8.15, R_loss:0.878
157000/2000000: P_loss:0.22, D_loss:7.972, R_loss:0.883
158000/2000000: P_loss:0.111, D_loss:8.168, R_loss:0.843
159000/2000000: P_loss:0.114, D_loss:7.925, R_loss:0.859
160000/2000000: P_loss:0.12, D_loss:8.154, R_loss:0.84
161000/2000000: P_loss:0.096, D_loss:8.22, R_loss:0.86
162000/2000000: P_loss:0.093, D_loss:8.33, R_loss:0.861
163000/2000000: P_loss:0.083, D_loss:8.274, R_loss:0.856
164000/2000000: P_loss:0.101, D_loss:8.155

292000/2000000: P_loss:0.044, D_loss:8.242, R_loss:0.608
293000/2000000: P_loss:0.045, D_loss:8.195, R_loss:0.596
294000/2000000: P_loss:0.041, D_loss:8.048, R_loss:0.594
295000/2000000: P_loss:0.043, D_loss:8.113, R_loss:0.587
296000/2000000: P_loss:0.062, D_loss:8.153, R_loss:0.593
297000/2000000: P_loss:0.061, D_loss:8.131, R_loss:0.597
298000/2000000: P_loss:0.039, D_loss:8.104, R_loss:0.585
299000/2000000: P_loss:0.046, D_loss:8.139, R_loss:0.598
300000/2000000: P_loss:0.04, D_loss:8.17, R_loss:0.592
301000/2000000: P_loss:0.045, D_loss:8.222, R_loss:0.596
302000/2000000: P_loss:0.046, D_loss:8.044, R_loss:0.596
303000/2000000: P_loss:0.045, D_loss:8.204, R_loss:0.591
304000/2000000: P_loss:0.041, D_loss:8.175, R_loss:0.588
305000/2000000: P_loss:0.043, D_loss:8.148, R_loss:0.586
306000/2000000: P_loss:0.054, D_loss:7.981, R_loss:0.588
307000/2000000: P_loss:0.039, D_loss:8.062, R_loss:0.576
308000/2000000: P_loss:0.044, D_loss:8.211, R_loss:0.59
309000/2000000: P_loss:0.116, D_lo

437000/2000000: P_loss:0.032, D_loss:8.192, R_loss:0.548
438000/2000000: P_loss:0.036, D_loss:8.075, R_loss:0.541
439000/2000000: P_loss:0.04, D_loss:8.065, R_loss:0.546
440000/2000000: P_loss:0.032, D_loss:8.153, R_loss:0.538
441000/2000000: P_loss:0.035, D_loss:8.033, R_loss:0.539
442000/2000000: P_loss:0.032, D_loss:8.095, R_loss:0.55
443000/2000000: P_loss:0.036, D_loss:8.091, R_loss:0.541
444000/2000000: P_loss:0.034, D_loss:8.158, R_loss:0.551
445000/2000000: P_loss:0.044, D_loss:8.028, R_loss:0.549
446000/2000000: P_loss:0.036, D_loss:8.111, R_loss:0.55
447000/2000000: P_loss:0.039, D_loss:8.118, R_loss:0.547
448000/2000000: P_loss:0.035, D_loss:8.155, R_loss:0.547
449000/2000000: P_loss:0.033, D_loss:8.081, R_loss:0.546
450000/2000000: P_loss:0.04, D_loss:8.211, R_loss:0.544
451000/2000000: P_loss:0.05, D_loss:8.095, R_loss:0.547
452000/2000000: P_loss:0.035, D_loss:8.22, R_loss:0.542
453000/2000000: P_loss:0.033, D_loss:8.21, R_loss:0.55
454000/2000000: P_loss:0.032, D_loss:8.

582000/2000000: P_loss:0.028, D_loss:8.209, R_loss:0.534
583000/2000000: P_loss:0.035, D_loss:8.128, R_loss:0.525
584000/2000000: P_loss:0.027, D_loss:8.125, R_loss:0.53
585000/2000000: P_loss:0.028, D_loss:8.14, R_loss:0.531
586000/2000000: P_loss:0.033, D_loss:8.118, R_loss:0.526
587000/2000000: P_loss:0.036, D_loss:8.057, R_loss:0.528
588000/2000000: P_loss:0.033, D_loss:8.059, R_loss:0.528
589000/2000000: P_loss:0.03, D_loss:8.074, R_loss:0.536
590000/2000000: P_loss:0.027, D_loss:8.191, R_loss:0.527
591000/2000000: P_loss:0.032, D_loss:8.111, R_loss:0.523
592000/2000000: P_loss:0.032, D_loss:8.038, R_loss:0.531
593000/2000000: P_loss:0.029, D_loss:8.021, R_loss:0.525
594000/2000000: P_loss:0.027, D_loss:8.164, R_loss:0.535
595000/2000000: P_loss:0.028, D_loss:8.191, R_loss:0.535
596000/2000000: P_loss:0.032, D_loss:8.185, R_loss:0.534
597000/2000000: P_loss:0.028, D_loss:8.212, R_loss:0.531
598000/2000000: P_loss:0.03, D_loss:8.128, R_loss:0.532
599000/2000000: P_loss:0.026, D_los

727000/2000000: P_loss:0.027, D_loss:8.13, R_loss:0.519
728000/2000000: P_loss:0.028, D_loss:8.099, R_loss:0.521
729000/2000000: P_loss:0.025, D_loss:8.131, R_loss:0.523
730000/2000000: P_loss:0.024, D_loss:8.174, R_loss:0.529
731000/2000000: P_loss:0.034, D_loss:8.189, R_loss:0.533
732000/2000000: P_loss:0.024, D_loss:8.208, R_loss:0.531
733000/2000000: P_loss:0.024, D_loss:8.151, R_loss:0.529
734000/2000000: P_loss:0.027, D_loss:8.136, R_loss:0.534
735000/2000000: P_loss:0.027, D_loss:8.093, R_loss:0.524
736000/2000000: P_loss:0.039, D_loss:8.084, R_loss:0.527
737000/2000000: P_loss:0.027, D_loss:8.127, R_loss:0.531
738000/2000000: P_loss:0.027, D_loss:8.138, R_loss:0.535
739000/2000000: P_loss:0.027, D_loss:8.13, R_loss:0.538
740000/2000000: P_loss:0.026, D_loss:8.125, R_loss:0.528
741000/2000000: P_loss:0.029, D_loss:8.024, R_loss:0.532
742000/2000000: P_loss:0.025, D_loss:8.196, R_loss:0.531
743000/2000000: P_loss:0.027, D_loss:8.173, R_loss:0.524
744000/2000000: P_loss:0.025, D_l

872000/2000000: P_loss:0.035, D_loss:8.086, R_loss:0.527
873000/2000000: P_loss:0.027, D_loss:8.052, R_loss:0.529
874000/2000000: P_loss:0.025, D_loss:8.225, R_loss:0.526
875000/2000000: P_loss:0.025, D_loss:8.069, R_loss:0.524
876000/2000000: P_loss:0.032, D_loss:8.028, R_loss:0.526
877000/2000000: P_loss:0.03, D_loss:8.11, R_loss:0.53
878000/2000000: P_loss:0.023, D_loss:8.18, R_loss:0.535
879000/2000000: P_loss:0.027, D_loss:8.215, R_loss:0.528
880000/2000000: P_loss:0.029, D_loss:8.093, R_loss:0.525
881000/2000000: P_loss:0.026, D_loss:8.122, R_loss:0.535
882000/2000000: P_loss:0.025, D_loss:8.107, R_loss:0.525
883000/2000000: P_loss:0.025, D_loss:8.267, R_loss:0.526
884000/2000000: P_loss:0.029, D_loss:8.039, R_loss:0.532
885000/2000000: P_loss:0.022, D_loss:8.118, R_loss:0.527
886000/2000000: P_loss:0.024, D_loss:8.005, R_loss:0.536
887000/2000000: P_loss:0.025, D_loss:7.917, R_loss:0.528
888000/2000000: P_loss:0.031, D_loss:8.154, R_loss:0.532
889000/2000000: P_loss:0.026, D_los

1017000/2000000: P_loss:0.026, D_loss:8.08, R_loss:0.527
1018000/2000000: P_loss:0.027, D_loss:8.141, R_loss:0.534
1019000/2000000: P_loss:0.025, D_loss:8.131, R_loss:0.522
1020000/2000000: P_loss:0.023, D_loss:8.114, R_loss:0.527
1021000/2000000: P_loss:0.025, D_loss:8.035, R_loss:0.531
1022000/2000000: P_loss:0.035, D_loss:8.042, R_loss:0.534
1023000/2000000: P_loss:0.03, D_loss:7.996, R_loss:0.522
1024000/2000000: P_loss:0.026, D_loss:8.095, R_loss:0.531
1025000/2000000: P_loss:0.031, D_loss:7.905, R_loss:0.521
1026000/2000000: P_loss:0.024, D_loss:8.044, R_loss:0.532
1027000/2000000: P_loss:0.026, D_loss:7.995, R_loss:0.525
1028000/2000000: P_loss:0.033, D_loss:8.038, R_loss:0.539
1029000/2000000: P_loss:0.025, D_loss:8.22, R_loss:0.538
1030000/2000000: P_loss:0.026, D_loss:8.035, R_loss:0.527
1031000/2000000: P_loss:0.023, D_loss:8.059, R_loss:0.519
1032000/2000000: P_loss:0.029, D_loss:8.055, R_loss:0.529
1033000/2000000: P_loss:0.031, D_loss:8.189, R_loss:0.534
1034000/2000000: 

1160000/2000000: P_loss:0.037, D_loss:7.691, R_loss:0.587
1161000/2000000: P_loss:0.037, D_loss:8.023, R_loss:0.587
1162000/2000000: P_loss:0.033, D_loss:7.842, R_loss:0.594
1163000/2000000: P_loss:0.036, D_loss:7.966, R_loss:0.597
1164000/2000000: P_loss:0.034, D_loss:7.944, R_loss:0.594
1165000/2000000: P_loss:0.032, D_loss:8.022, R_loss:0.602
1166000/2000000: P_loss:0.032, D_loss:7.787, R_loss:0.603
1167000/2000000: P_loss:0.039, D_loss:7.795, R_loss:0.591
1168000/2000000: P_loss:0.03, D_loss:8.173, R_loss:0.602
1169000/2000000: P_loss:0.032, D_loss:8.006, R_loss:0.591
1170000/2000000: P_loss:0.034, D_loss:7.996, R_loss:0.596
1171000/2000000: P_loss:0.031, D_loss:7.649, R_loss:0.594
1172000/2000000: P_loss:0.03, D_loss:7.926, R_loss:0.591
1173000/2000000: P_loss:0.033, D_loss:8.047, R_loss:0.602
1174000/2000000: P_loss:0.03, D_loss:7.828, R_loss:0.593
1175000/2000000: P_loss:0.032, D_loss:7.91, R_loss:0.598
1176000/2000000: P_loss:0.029, D_loss:7.781, R_loss:0.598
1177000/2000000: P

In [None]:
# NOTE(review): in an IPython kernel `quit` and `exit` are the same exiter and
# end the kernel session, discarding every in-memory variable (the best model
# is only preserved in the checkpoint written during training).
quit()