In [88]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from model import Iwata_simple
from data import Iwata_Dataset_DB_Stock
import numpy as np
import time

from utils import StandardScaler
import pandas as pd
import psycopg2 as pg

import os
import sktime
import random

In [89]:
bidirectional = True
seq_len=13
enc_in = 5
hidden_size = 64
c_out = 1
s_n_layers = 2
batch_size = 32 # = support size
direcs = 2 if bidirectional else 1
model = Iwata_simple(enc_in, hidden_size, c_out, s_n_layers)
support_set = torch.rand(batch_size, seq_len, enc_in)
query_set = torch.rand(1, seq_len, enc_in)
output = model(support_set, query_set)
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [90]:
print(model)

Iwata_simple(
  (support_encoder): LSTM(5, 32, num_layers=2, bidirectional=True)
  (query_encoder): LSTM(5, 64)
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
  )
  (g): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
)


In [91]:
# train model
S = torch.rand(batch_size, seq_len, enc_in)
Q = torch.rand(1, seq_len, enc_in)
y = torch.rand(c_out)
epochs = 10

In [92]:
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_f = nn.MSELoss()

In [93]:
S_N, Q_N, freq, = 32, 1, '1T'
conn = pg.connect("dbname=stonksdb user=postgres password=admin")
iwata_stck_ds = Iwata_Dataset_DB_Stock(conn, 32, 1, size=[seq_len, seq_len, 1], flag='train', features='MS', scale=True)
data_loader = DataLoader(
            iwata_stck_ds,
            batch_size=1, # only works with one as they are sampled already from Q_N, S_N
            shuffle=False,
            drop_last=True)



In [103]:
train_steps = len(data_loader)
time_now = time.time()

for epoch in range(epochs):
    train_loss = []
    epoch_time = time.time()
    iter_count = 0
    for s_seq_x, q_seq_x, q_seq_y in data_loader:
        s_seq_x = s_seq_x.float().squeeze(0).to(device) # support set
        q_seq_x = q_seq_x.float().to(device) # query set 
        q_seq_y = q_seq_y.float() # query set label 

        optimizer.zero_grad()
        output = model(s_seq_x, q_seq_x)
        loss = loss_f(output, q_seq_y)
        train_loss.append(loss.item())
        loss.backward()
        optimizer.step()
        iter_count += 1
        if (iter_count) % 100 == 0: 
            print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(iter_count, epoch + 1, loss.item()))
            speed = (time.time()-time_now)/iter_count
            left_time = speed*((epochs - epoch)*train_steps - iter_count)
            print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
            iter_count = 0
            time_now = time.time()

train_loss = np.average(train_loss)
print('Epoch {}/{} \t Time: {:.2f}s \t Loss: {:.4f}'.format(epoch+1, epochs, time.time() - epoch_time, train_loss))


	iters: 100, epoch: 1 | loss: 0.0448884
	speed: 0.0213s/iter; left time: 1859.7937s
	iters: 100, epoch: 1 | loss: 0.0008172
	speed: 0.0204s/iter; left time: 1781.8999s
	iters: 100, epoch: 1 | loss: 0.0016035
	speed: 0.0194s/iter; left time: 1695.3009s
	iters: 100, epoch: 1 | loss: 0.0054420
	speed: 0.0204s/iter; left time: 1777.0671s
	iters: 100, epoch: 1 | loss: 0.0461139
	speed: 0.0194s/iter; left time: 1694.6296s
	iters: 100, epoch: 1 | loss: 0.0085819
	speed: 0.0199s/iter; left time: 1736.8304s
	iters: 100, epoch: 1 | loss: 0.0280992
	speed: 0.0202s/iter; left time: 1758.3066s
	iters: 100, epoch: 1 | loss: 0.0019248
	speed: 0.0212s/iter; left time: 1853.8316s
	iters: 100, epoch: 1 | loss: 0.0029385
	speed: 0.0253s/iter; left time: 2210.1729s
	iters: 100, epoch: 1 | loss: 0.0000380
	speed: 0.0207s/iter; left time: 1808.3484s
	iters: 100, epoch: 1 | loss: 0.0705639
	speed: 0.0207s/iter; left time: 1803.5714s
	iters: 100, epoch: 1 | loss: 0.0104736
	speed: 0.0210s/iter; left time: 183

KeyboardInterrupt: 

In [102]:
s_seq_x.shape, q_seq_x.shape, q_seq_y.shape

(torch.Size([32, 12, 5]), torch.Size([1, 13, 5]), torch.Size([1]))