In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from matplotlib import pyplot as plt
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from models import SciNet
import pandas as pd
from utils import target_loss 
from loader import build_dataloader
import torch.optim.lr_scheduler as lr_scheduler
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [13]:
t_size = 80
size = 150
t_max = 5
t = np.linspace(0, t_max, t_size)
min_fr, max_fr = 0.01, 100
fr = np.random.uniform(min_fr, max_fr, size)
start_st, end_st = 0.01, 100
st = np.logspace(np.log10(start_st), np.log10(end_st), size, endpoint = True)

In [14]:
# the function that we generate the data with
'''def f(t, st, fr):
    return st**2 * fr * (1 - t/st - np.exp(-t/st))'''
def f(t, st, fr):
    return  st**2 + fr **2 + t
data = []
for st_ in st:
    for fr_ in fr:
        example = list(f(t, st_, fr_))
        t_pred = np.random.uniform(0, t_max)
        pred = f(t_pred,st_,fr_)
        example.append(fr_)
        example.append(st_)
        example.append(t_pred)
        example.append(pred)
        data.append(example)
data = np.array(data)
colummns = [str(i) for i in range(t_size)]
colummns.append("fr")
colummns.append("st")
colummns.append("t_pred")
colummns.append("pred") 
df = pd.DataFrame(data,columns=colummns)
df.shape

(22500, 84)

In [15]:
# Setup scinet model with 3 latent neurons
scinet = SciNet(t_size,1,3,100)

# Load and prepare training data
dataloader = build_dataloader(batch_size =100, size=t_size)

In [16]:
# Training setup
SAVE_PATH = "trained_models/scinet1.dat"
N_EPOCHS = 130
optimizer = optim.Adam(scinet.parameters(), lr=0.001)
hist_error = []
hist_loss = []
scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.009, total_iters=N_EPOCHS)
beta = 0.5

# Training loop
for epoch in range(N_EPOCHS):  
    epoch_error = []
    epoch_loss = []
    for i_batch, minibatch in enumerate(dataloader):
        time_series, _, _, question, answer = minibatch['time_series'] / 5, minibatch['fr'] / 5, minibatch['st'] / 5, minibatch['question'] / 5, minibatch['answer'] / 5
        
        # concat the time series with the features
        inputs = torch.cat((time_series, question.reshape(-1, 1)), 1)
        # build the output
        outputs = answer

        optimizer.zero_grad()
        # print(inputs[0])
        pred = scinet.forward(inputs)
        # print(pred)
        # break
        loss = target_loss(pred, outputs) + beta * scinet.kl_loss
        loss.backward()
        optimizer.step()
        error = torch.mean(torch.sqrt((pred[:,0]-outputs)**2)).detach().numpy()
        epoch_error.append(error)
        epoch_loss.append(loss.data.detach().numpy())
    # break
    hist_error.append(np.mean(epoch_error))
    hist_loss.append(np.mean(epoch_loss))

    before_lr = optimizer.param_groups[0]["lr"]
    scheduler.step()
    after_lr = optimizer.param_groups[0]["lr"]
    print("Epoch %d: SGD lr %.6f -> %.6f" % (epoch+1, before_lr, after_lr))
    
    print("Epoch %d -- loss %f, RMS error %f " % (epoch+1, hist_loss[-1], hist_error[-1]))
torch.save(scinet.state_dict(), SAVE_PATH)
print("Model saved to %s" % SAVE_PATH)

Epoch 1: SGD lr 0.001000 -> 0.000992
Epoch 1 -- loss 185.734283, RMS error 0.787579 
Epoch 2: SGD lr 0.000992 -> 0.000985
Epoch 2 -- loss 94.510986, RMS error 0.517777 
Epoch 3: SGD lr 0.000985 -> 0.000977
Epoch 3 -- loss 59.778454, RMS error 0.368741 
Epoch 4: SGD lr 0.000977 -> 0.000970
Epoch 4 -- loss 48.419079, RMS error 0.332723 
Epoch 5: SGD lr 0.000970 -> 0.000962
Epoch 5 -- loss 46.031620, RMS error 0.320696 
Epoch 6: SGD lr 0.000962 -> 0.000954
Epoch 6 -- loss 44.111805, RMS error 0.304555 
Epoch 7: SGD lr 0.000954 -> 0.000947
Epoch 7 -- loss 45.322041, RMS error 0.294003 
Epoch 8: SGD lr 0.000947 -> 0.000939
Epoch 8 -- loss 41.507080, RMS error 0.273516 
Epoch 9: SGD lr 0.000939 -> 0.000931
Epoch 9 -- loss 42.506229, RMS error 0.270787 
Epoch 10: SGD lr 0.000931 -> 0.000924
Epoch 10 -- loss 35.841282, RMS error 0.214470 
Epoch 11: SGD lr 0.000924 -> 0.000916
Epoch 11 -- loss 24.845491, RMS error 0.118114 
Epoch 12: SGD lr 0.000916 -> 0.000909
Epoch 12 -- loss 22.834450, RMS e

In [None]:


print("Original Data:", data)
print("Window Means:", window_means)


In [15]:
# Plot some training history data
%matplotlib inline 
f, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
ax1.plot(hist_error)
ax1.set_ylabel("Amplitude RMSE")
ax2.plot(hist_loss)
ax2.set_ylabel("Loss")
ax2.set_xlabel("Epoch")
plt.show()

: 

: 