# Import packages

In [1]:
import os
import random
import sys
sys.path.append("../..")
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib inline

In [2]:
import torch
from torch import nn
from torch import distributions
from torch.nn.parameter import Parameter
from train import train
from models.models_new import Renorm_Dynamic
from dynamic_models_sis_new import Simple_Spring_Model
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KernelDensity
from sklearn.decomposition import PCA
from sklearn.model_selection import GridSearchCV
from datetime import datetime
# Timestamp taken when this cell runs (its later use is not shown in this cell).
t0 = datetime.now()

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
device

device(type='cuda', index=0)

# Generate data

In [3]:
# Data-generation settings (exact meaning of each argument is defined by
# Simple_Spring_Model.generate_multistep_sir — TODO confirm against that module).
mul_batch_size = [0, 5000, 3000, 1000]  # passed as size_list; an earlier setting was [0, 10, 4500, 1500]
sigma = 0.03  # passed to the generator (and later to train) as `sigma`
rou = -0.5    # passed to the generator (and later to train) as `rou`
steps = 7     # number of steps for the training trajectories (test set uses 10)

# SIR parameters shared by the train and test sets, so the two can never drift apart.
sir_params = dict(lam=1, miu=0.5, dt=0.01)

# Seed every RNG in use (Python `random` is imported above, so seed it too)
# to keep the generated data reproducible across kernel restarts.
seed = 2050
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

spring = Simple_Spring_Model(device=device)
test_data = spring.generate_multistep_sir(size_list=[500, 500], steps=10, sigma=sigma, rou=rou, **sir_params)
train_data = spring.generate_multistep_sir(size_list=mul_batch_size, steps=steps, sigma=sigma, rou=rou, **sir_params)

# Train NIS

In [4]:
# Hyper-parameters for the NIS run; the NIS+ cell reuses the same values.
sz = 4
scale = 2
L = 1
mae2_w = 3
T_total = 30001  # passed to train() as T2 (presumably the number of training iterations)

# torch.save does not create missing directories. Create the output directory
# *before* the long training run, so a fresh checkout does not fail only at the end.
os.makedirs('mdl_data', exist_ok=True)

eis_nis, term1s_nis, term2s_nis, losses_nis, MAEs_mstep_nis, net = train(
    train_data=train_data, test_data=test_data,
    sigma=sigma, rou=rou, sz=sz, scale=scale, L=L,
    mae2_w=mae2_w, T2=T_total, framework='nis')
torch.save(net.state_dict(), 'mdl_data/test nis.mdl')

check point-------o_0-> 10:50:13.140550; lasting 0:00:02.579769 seconds
Epoch: 0
Train loss: 0.5320
dEI: 0.5823
term1: -0.2108
term2: 1.3755
Test multistep loss: 0.5802
------------------------------------------------------------------------------------------------------------------------
check point-------o_1-> 10:50:18.224970; lasting 0:00:05.084420 seconds
check point-------o_0-> 10:50:24.060743; lasting 0:00:05.835773 seconds
Epoch: 500
Train loss: 0.0772
dEI: 1.6582
term1: 1.6337
term2: 1.6826
Test multistep loss: 0.0758
------------------------------------------------------------------------------------------------------------------------
check point-------o_1-> 10:50:31.516196; lasting 0:00:07.455453 seconds
check point-------o_0-> 10:50:37.367951; lasting 0:00:05.851755 seconds
Epoch: 1000
Train loss: 0.0379
dEI: 2.0680
term1: 2.4721
term2: 1.6638
Test multistep loss: 0.0384
--------------------------------------------------------------------------------------------------------

# Train NIS+

In [5]:
# Train NIS+ on the same data with the same hyper-parameters as NIS, changing only
# the `framework` flag so the two runs are directly comparable.
# This call is unpacked into five results, so no model is saved from this run.
(eis_nisp, term1s_nisp, term2s_nisp,
 losses_nisp, MAEs_mstep_nisp) = train(
    framework='nis+',
    train_data=train_data,
    test_data=test_data,
    sigma=sigma,
    rou=rou,
    sz=sz,
    scale=scale,
    L=L,
    mae2_w=mae2_w,
    T2=T_total,
)

check point-------o_0-> 11:07:14.811925; lasting 0:00:00.056965 seconds
Epoch: 0
Train loss: 0.5016
dEI: 0.6823
term1: -0.0044
term2: 1.3690
Test multistep loss: 0.4765
------------------------------------------------------------------------------------------------------------------------
check point-------o_1-> 11:07:19.714841; lasting 0:00:04.902916 seconds
check point-------o_0-> 11:07:31.103184; lasting 0:00:11.388343 seconds
Epoch: 500


KeyboardInterrupt: 

# Results

In [None]:
# J (logged as dEI) over training for NIS vs NIS+.
# Metrics are presumably logged every 500 iterations (matching the "Epoch: 0/500/1000"
# printouts). Each curve gets its own x-axis: the two runs may have logged a different
# number of points (e.g. an interrupted run), and a shared x-array would make plot() raise.
log_every = 500
fig, ax = plt.subplots(dpi=300)
for series, label in ((eis_nis, 'NIS'), (eis_nisp, 'NIS+')):
    ax.plot(np.arange(len(series)) * log_every, series, label=label)
ax.set_title(r'$\mathcal{J}$')
ax.set_xlabel('epoch')
ax.legend()
plt.show()

In [None]:
# Training loss over training for NIS vs NIS+.
# Metrics are presumably logged every 500 iterations. Each curve gets its own x-axis,
# so runs with a different number of logged points do not make plot() raise.
log_every = 500
use_log_y = False  # set True to view the loss on a logarithmic y axis
fig, ax = plt.subplots(dpi=300)
for series, label in ((losses_nis, 'NIS'), (losses_nisp, 'NIS+')):
    ax.plot(np.arange(len(series)) * log_every, series, label=label)
ax.set_title(r'loss')
ax.set_xlabel('epoch')
if use_log_y:
    ax.set_yscale('log')
ax.legend()
plt.show()

In [None]:
# Multistep MAE on the test set over training for NIS vs NIS+.
# Metrics are presumably logged every 500 iterations. Each curve gets its own x-axis,
# so runs with a different number of logged points do not make plot() raise.
log_every = 500
fig, ax = plt.subplots(dpi=300)
for series, label in ((MAEs_mstep_nis, 'NIS'), (MAEs_mstep_nisp, 'NIS+')):
    ax.plot(np.arange(len(series)) * log_every, series, label=label)
ax.set_title(r'multistep mae')
ax.set_xlabel('epoch')
ax.legend()
plt.show()