In [1]:
import sys
sys.path.append("../..")
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import random
%matplotlib inline

In [2]:
import torch
from torch import nn
from torch import distributions
from torch.nn.parameter import Parameter
from train import train
from train import train_rnis
from models_new import Renorm_Dynamic
from models_new import Rnis_Dynamic
from ei.EI_calculation import count_parameters
from dynamic_models_sis_new import Simple_Spring_Model
from datetime import datetime
t0 = datetime.now()


use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
device

device(type='cuda', index=0)

# Generate data

In [3]:
mul_batch_size = [0,5000,3000,1000]#[0,10,4500,1500]
sigma = 0.03
rou = -0.5
steps = 7
dt = 0.01

seed = 2050
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
spring = Simple_Spring_Model(device=device)
test_data = spring.generate_multistep_sir(size_list=[500,500], steps=10, sigma=sigma, rou=rou,lam=1,miu=0.5,dt=dt) 
train_data = spring.generate_multistep_sir(size_list=mul_batch_size, steps=steps, sigma=sigma,rou=rou,lam=1,miu=0.5,dt=dt)

# Train RNIS

In [None]:
sz = 4
scale = 2
L = 1
mae2_w = 3
T_total = 30001
eis_rnis, term1s_rnis, term2s_rnis, losses_rnis, MAEs_mstep_rnis, net_rnis = train_rnis(train_data=train_data, test_data=test_data, 
                                                                    sigma=sigma, rou=rou, sz=sz, scale=scale, L=L, 
                                                                    mae2_w=mae2_w, dt=dt, T2=T_total)
count_parameters(net_rnis)

check point-------o_0-> 16:59:31.855712; lasting 0:00:05.164688 seconds
Epoch: 0
Train loss: 0.4519
dEI: 0.4148
term1: -0.4372
term2: 1.2669
Test multistep loss: 0.5445
------------------------------------------------------------------------------------------------------------------------
check point-------o_1-> 16:59:39.227315; lasting 0:00:07.371603 seconds
check point-------o_0-> 16:59:48.374060; lasting 0:00:09.146745 seconds
Epoch: 500
Train loss: 0.0627
dEI: 1.3171
term1: 1.3027
term2: 1.3315
Test multistep loss: 0.0714
------------------------------------------------------------------------------------------------------------------------
check point-------o_1-> 17:00:00.006159; lasting 0:00:11.632099 seconds
check point-------o_0-> 17:00:07.156122; lasting 0:00:07.149963 seconds
Epoch: 1000
Train loss: 0.0358
dEI: 1.7008
term1: 2.0568
term2: 1.3448
Test multistep loss: 0.0383
--------------------------------------------------------------------------------------------------------

In [9]:
count_parameters(net_rnis)

29728

# Train NIS

In [None]:
sz = 4
scale = 2
L = 1
mae2_w = 3
T_total = 30001
eis_nis, term1s_nis, term2s_nis, losses_nis, MAEs_mstep_nis, net_nis = train(train_data=train_data, test_data=test_data, 
                                                                    sigma=sigma, rou=rou, sz=sz, scale=scale, L=L, 
                                                                    mae2_w=mae2_w, dt=dt, T2=T_total,framework='nis')

check point-------o_0-> 10:33:16.031063; lasting 13:58:08.182269 seconds
Epoch: 0
Train loss: 0.4681
dEI: 0.7891
term1: 0.1541
term2: 1.4241
Test multistep loss: 0.4775
------------------------------------------------------------------------------------------------------------------------
check point-------o_1-> 10:33:21.617484; lasting 0:00:05.586421 seconds
check point-------o_0-> 10:33:27.615941; lasting 0:00:05.998457 seconds
Epoch: 500
Train loss: 0.0982
dEI: 1.2274
term1: 0.9382
term2: 1.5166
Test multistep loss: 0.0978
------------------------------------------------------------------------------------------------------------------------
check point-------o_1-> 10:33:34.901426; lasting 0:00:07.285485 seconds
check point-------o_0-> 10:33:40.852963; lasting 0:00:05.951537 seconds
Epoch: 1000
Train loss: 0.0364
dEI: 1.9923
term1: 2.4916
term2: 1.4929
Test multistep loss: 0.0367
--------------------------------------------------------------------------------------------------------

# Train NIS+

In [None]:
eis_nisp, term1s_nisp, term2s_nisp, losses_nisp, MAEs_mstep_nisp, net_nisp = train(train_data=train_data, test_data=test_data, 
                                                                         sigma=sigma, rou=rou, sz=sz, scale=scale, L=L, 
                                                                         mae2_w=mae2_w, dt=dt, T2=T_total,framework='nis+')

# Results

In [None]:
print(count_parameters(net_rnis),count_parameters(net_nis),count_parameters(net_nisp))

In [None]:
plt.figure(dpi=300)
timeseries = np.arange(len(eis_nis))*500
plt.plot(timeseries, eis_nis, label='NIS')
plt.plot(timeseries, eis_nisp, label='NIS+')
plt.plot(timeseries, eis_rnis, label='RNIS')
plt.title(r'$\mathcal{J}$')
plt.legend()
plt.show()

In [None]:
plt.figure(dpi=300)
timeseries = np.arange(len(eis_nis))*500
plt.plot(timeseries, losses_nis, label='NIS')
plt.plot(timeseries, losses_nisp, label='NIS+')
plt.plot(timeseries, losses_rnis, label='RNIS')
plt.title(r'loss')
plt.legend()
#plt.yscale('log')
plt.show()

In [None]:
plt.figure(dpi=300)
timeseries = np.arange(len(eis_nis))*500
plt.plot(timeseries, MAEs_mstep_nis, label='NIS')
plt.plot(timeseries, MAEs_mstep_nisp, label='NIS+')
plt.plot(timeseries, MAEs_mstep_rnis, label='RNIS')
plt.title(r'multistep mae')
plt.legend()
plt.show()