In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from model import Iwata_simple
from data import Iwata_Dataset_DB_Stock
import numpy as np
import time

from utils import StandardScaler
import pandas as pd
import psycopg2 as pg

In [23]:
bidirectional = True
seq_len=13
enc_in = 5
hidden_size = 64
c_out = 1
s_n_layers = 2
batch_size = 32 # = support size
direcs = 2 if bidirectional else 1
model = Iwata_simple(enc_in, hidden_size, c_out, s_n_layers)
support_set = torch.rand(batch_size, seq_len, enc_in)
query_set = torch.rand(1, seq_len, enc_in)
output = model(support_set, query_set)
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
print(model)

Iwata_simple(
  (support_encoder): LSTM(5, 32, num_layers=2, bidirectional=True)
  (query_encoder): LSTM(5, 64)
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
  )
  (g): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
)


In [5]:
# train model
S = torch.rand(batch_size, seq_len, enc_in)
Q = torch.rand(1, seq_len, enc_in)
y = torch.rand(c_out)
epochs = 10

In [6]:
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_f = nn.MSELoss()

In [None]:
S_N, Q_N, freq, = 32, 1, '1T'
conn = pg.connect("dbname=stonksdb user=postgres password=admin")
iwata_stck_ds = Iwata_Dataset_DB_Stock(conn, 32, 1, size=[seq_len, seq_len, 1], flag='train', features='MS', scale=True)
data_loader = DataLoader(
            iwata_stck_ds,
            batch_size=1, # only works with one as they are sampled already from Q_N, S_N
            shuffle=False,
            drop_last=True)

In [37]:
batch_s, batch_Q, y = S, Q, y
# batch_s = batch_s.to(device)
# batch_Q = batch_Q.to(device)
# y = y.to(device)

train_steps = len(data_loader)
time_now = time.time()

for epoch in range(epochs):
    train_loss = []
    epoch_time = time.time()
    iter_count = 0
    for s_seq_x, q_seq_x, q_seq_y in data_loader:
        s_seq_x = s_seq_x.float().squeeze(0).to(device) # support set
        q_seq_x = q_seq_x.float().to(device) # query set 
        q_seq_y = q_seq_y.float() # query set label 

        optimizer.zero_grad()
        output = model(s_seq_x, q_seq_x)
        loss = loss_f(output, y)
        train_loss.append(loss.item())
        loss.backward()
        optimizer.step()
        iter_count += 1
        if (iter_count) % 100 == 0: 
            print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(iter_count, epoch + 1, loss.item()))
            speed = (time.time()-time_now)/iter_count
            left_time = speed*((epochs - epoch)*train_steps - iter_count)
            print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
            iter_count = 0
            time_now = time.time()

train_loss = np.average(train_loss)
print('Epoch {}/{} \t Time: {:.2f}s \t Loss: {:.4f}'.format(epoch+1, epochs, time.time() - epoch_time, train_loss))


	iters: 100, epoch: 1 | loss: 0.5696072
	speed: 0.0218s/iter; left time: 25032.0064s
	iters: 100, epoch: 1 | loss: 0.7892002
	speed: 0.0213s/iter; left time: 24433.9707s
	iters: 100, epoch: 1 | loss: 0.6570753
	speed: 0.0277s/iter; left time: 31778.2336s
	iters: 100, epoch: 1 | loss: 0.9482577
	speed: 0.0544s/iter; left time: 62431.8814s
	iters: 100, epoch: 1 | loss: 0.9702225
	speed: 0.0301s/iter; left time: 34507.3732s
	iters: 100, epoch: 1 | loss: 0.8559260
	speed: 0.0238s/iter; left time: 27329.8664s
	iters: 100, epoch: 1 | loss: 0.6561556
	speed: 0.0227s/iter; left time: 26055.7155s
	iters: 100, epoch: 1 | loss: 0.9326979
	speed: 0.0227s/iter; left time: 26027.1939s
	iters: 100, epoch: 1 | loss: 0.3719246
	speed: 0.0227s/iter; left time: 26091.7414s
	iters: 100, epoch: 1 | loss: 0.5082205
	speed: 0.0230s/iter; left time: 26379.7215s
	iters: 100, epoch: 1 | loss: 0.9417534
	speed: 0.0224s/iter; left time: 25692.0818s
	iters: 100, epoch: 1 | loss: 0.6433809
	speed: 0.0221s/iter; lef

KeyboardInterrupt: 