In [1]:
import numpy as np
import torch

from loader.DataLoader import read_sequence, read_dataframe
from loader.DataTransformer import lag_list, moving_average
from model.Attention_LstmModel import Attention_LstmModel

In [2]:
# Script parameters.
# MODE names an optional preprocessing variant: 'MA' (moving average),
# 'D1' (lag of degree 1), 'DMA' (decaying moving average); anything else means no change.
# NOTE(review): MODE is commented out and none of the cells visible here read it.
# MODE = 'MA'
LAG = 16  # second argument of lag_list(sequence, LAG) in the data-preparation cell

In [3]:
# ---- data preparation ----
# Load the 'all' dataset as a NumPy array; its last column is used as the target.
sequence = read_dataframe('all').to_numpy()

# Variance of the target column, used by the training loop to turn MSE into RRSE.
y_var = np.var(sequence[:, -1])

# Turn the raw sequence into delayed sequences (indexed below as a 3-D array).
shifted_sequence = lag_list(sequence, LAG)

# Inputs: every step of each delayed sequence except the last, without column 0.
x_windows = shifted_sequence[:, :-1, 1:]
# Targets: last column of the last step of each delayed sequence, as a column vector.
y_last = shifted_sequence[:, -1, -1].reshape(-1, 1)

# Convert to torch tensors of the default tensor type.
# NOTE(review): the int32 cast drops any fractional part of the targets —
# confirm the target column really holds integer values.
x_train = torch.from_numpy(x_windows.astype('float64')).type(torch.Tensor)
y_train = torch.from_numpy(y_last.astype('int32')).type(torch.Tensor)

# ---- model hyper-parameters ----
input_dim, hidden_dim, num_layers, output_dim = x_train.shape[-1], 64, 2, 1

In [10]:
# Seed torch's RNG before the model is built so that any torch-based weight
# initialisation is repeatable across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = Attention_LstmModel(input_dim, hidden_dim, num_layers, output_dim)

# train
# Full-batch training: each epoch is a single forward/backward pass over x_train.
# NOTE(review): the recorded log for this cell stays at RRSE ~ 1.0 after 2000+
# epochs, i.e. no better than predicting the target mean. The MSE of ~2E+07
# suggests large-magnitude, unscaled targets — consider standardising inputs
# and targets before training; confirm against the data.
num_epochs = 3_000
loss_fn = torch.nn.MSELoss()
optimiser = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()
for epoch in range(1, num_epochs + 1):
    y_pred = model(x_train)[0]  # element 0 of the model's return value is used as the prediction
    loss = loss_fn(y_pred, y_train)
    if epoch % 100 == 0:
        mse = loss.item()  # read the scalar once and reuse it for both metrics
        # RRSE = sqrt(MSE / Var(y)); a value of 1.0 matches a constant mean predictor.
        print("Epoch: %d | MSE: %.2E | RRSE: %.2E" % (epoch, mse, np.sqrt(mse / y_var)))
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

# test
# NOTE(review): this evaluates on the first five *training* samples, so it is a
# sanity check of the fit, not a held-out test.
model.eval()
with torch.no_grad():  # inference only: skip building the autograd graph
    y_pred = model(x_train[:5])[0]
y_pred = y_pred.numpy().reshape(-1)  # tensor -> flat NumPy array
print("sample prediction:  ", y_pred)

y_train_sample = y_train[:5].detach().numpy().reshape(-1)
print("sample true result: ", y_train_sample)

# verify
# Rounded values are kept for the exact-match check below, which is left disabled.
y_pred_round = [round(p) for p in y_pred]
y_train_round = [round(p) for p in y_train_sample]

# assert (y_pred_round == y_train_round)

Epoch: 100 | MSE: 2.05E+07 | RRSE: 1.03E+00
Epoch: 200 | MSE: 2.04E+07 | RRSE: 1.03E+00
Epoch: 300 | MSE: 2.03E+07 | RRSE: 1.03E+00
Epoch: 400 | MSE: 2.02E+07 | RRSE: 1.02E+00
Epoch: 500 | MSE: 2.01E+07 | RRSE: 1.02E+00
Epoch: 600 | MSE: 2.00E+07 | RRSE: 1.02E+00
Epoch: 700 | MSE: 1.99E+07 | RRSE: 1.02E+00
Epoch: 800 | MSE: 1.98E+07 | RRSE: 1.01E+00
Epoch: 900 | MSE: 1.97E+07 | RRSE: 1.01E+00
Epoch: 1000 | MSE: 1.96E+07 | RRSE: 1.01E+00
Epoch: 1100 | MSE: 1.95E+07 | RRSE: 1.01E+00
Epoch: 1200 | MSE: 1.95E+07 | RRSE: 1.01E+00
Epoch: 1300 | MSE: 1.94E+07 | RRSE: 1.01E+00
Epoch: 1400 | MSE: 1.94E+07 | RRSE: 1.00E+00
Epoch: 1500 | MSE: 1.93E+07 | RRSE: 1.00E+00
Epoch: 1600 | MSE: 1.92E+07 | RRSE: 1.00E+00
Epoch: 1700 | MSE: 1.92E+07 | RRSE: 1.00E+00
Epoch: 1800 | MSE: 1.92E+07 | RRSE: 9.98E-01
Epoch: 1900 | MSE: 1.91E+07 | RRSE: 9.97E-01
Epoch: 2000 | MSE: 1.91E+07 | RRSE: 9.96E-01
Epoch: 2100 | MSE: 1.90E+07 | RRSE: 9.94E-01
Epoch: 2200 | MSE: 1.91E+07 | RRSE: 9.97E-01
Epoch: 2300 | MSE: 