<a href="https://colab.research.google.com/github/CornerSiow/stacked-lstm/blob/main/Sequence_to_Sequence_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Fetch the repository that provides `utility.py` and the pickled train/test data used below.
# NOTE: on a re-run git refuses to clone into the existing `stacked-lstm/` directory; the `!`
# escape only prints that error, so the rest of the notebook still runs.
!git clone https://github.com/CornerSiow/stacked-lstm.git

Cloning into 'stacked-lstm'...
remote: Enumerating objects: 55, done.[K
remote: Counting objects: 100% (55/55), done.[K
remote: Compressing objects: 100% (53/53), done.[K
remote: Total 55 (delta 25), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (55/55), 2.31 MiB | 2.33 MiB/s, done.
Resolving deltas: 100% (25/25), done.


In [None]:
# Copy the helper module next to the notebook so `from utility import ...` can resolve it
# (assumes the kernel's working directory is /content, as on Colab — confirm if run elsewhere).
!cp /content/stacked-lstm/utility.py /content/utility.py

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import pickle
from tqdm import tqdm
import numpy as np
from torch.utils.data import Dataset, DataLoader
import random
import math
from utility import showScore
from utility import CustomDataset

In [None]:
# Training data: mini-batches of 32, reshuffled every epoch.
# NOTE(review): the inputs are .pickle files; if CustomDataset unpickles them, only load trusted data.
train_dataset = CustomDataset("/content/stacked-lstm/data/data_train.pickle")
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)
# Test data: one sample per batch, in file order.
test_dataset = CustomDataset("/content/stacked-lstm/data/data_test.pickle")
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

In [None]:
class EncoderRNN(nn.Module):
    """Single-layer, batch-first GRU encoder over a sequence of feature vectors.

    Each input frame is linearly projected to `hidden_size` (the layer is stored
    as `embedding`), passed through dropout, and fed to the GRU.

    Args:
        input_size: Width of each input frame (last dimension of `input`).
        hidden_size: Width of the projection and of the GRU state.
        dropout_p: Dropout probability applied to the projected frames.
    """

    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Linear(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        """Encode `input` (assumed (batch, seq_len, input_size), since the GRU is batch_first).

        Returns:
            The `(output, hidden)` tuple produced by the GRU: per-step outputs and
            the final hidden state.
        """
        projected = self.embedding(input)
        return self.gru(self.dropout(projected))

class BahdanauAttention(nn.Module):
    """Additive attention: score_t = Va(tanh(Wa(query) + Ua(key_t))).

    Args:
        hidden_size: Width of both the query and the keys.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        """Attend over `keys` with `query`.

        Args:
            query: Presumably (batch, 1, hidden_size); it is broadcast against `keys`.
            keys: (batch, seq_len, hidden_size); also used as the values.

        Returns:
            (context, weights): context is (batch, 1, hidden_size) and weights is
            (batch, 1, seq_len), softmax-normalised over seq_len.
        """
        energy = torch.tanh(self.Wa(query) + self.Ua(keys))
        # (batch, seq_len, 1) -> (batch, 1, seq_len) so bmm can weight the keys.
        logits = self.Va(energy).squeeze(2).unsqueeze(1)
        attn = F.softmax(logits, dim=-1)
        return torch.bmm(attn, keys), attn

class AttnDecoderRNN(nn.Module):
    """GRU decoder with additive attention that emits a sequence of output frames.

    At every step the previous frame (width `output_size`) is concatenated with an
    attention context over the encoder outputs (width `hidden_size`). The result
    advances a single-layer GRU, and an MLP maps the GRU output to `output_size` values.

    Args:
        hidden_size: GRU state width. It must match the encoder's hidden size, because
            the encoder's final hidden state seeds this GRU and the encoder outputs are
            attended with `hidden_size`-wide projections.
        output_size: Width of each decoded frame, which is also the width of the fed-back input.
        dropout_p: Probability for `self.dropout` (NOTE(review): currently never applied).
        max_length: Number of steps decoded when no `target_tensor` is supplied.
    """

    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=40):
        super(AttnDecoderRNN, self).__init__()
        self.output_size = output_size
        self.max_length = max_length
        # NOTE(review): never used in forward/forward_step. It is kept, and still created
        # first, so state_dict keys and seeded parameter initialisation are unchanged.
        self.embedding = nn.Linear(1, 64)
        self.attention = BahdanauAttention(hidden_size)
        # GRU input per step = [previous frame, attention context]. This was hard-coded
        # to 154 (= 90 + 64), which broke every other size combination.
        self.gru = nn.GRU(output_size + hidden_size, hidden_size, batch_first=True)
        # NOTE(review): defined but never applied, so dropout_p has no effect today.
        self.dropout = nn.Dropout(dropout_p)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        """Decode a sequence, optionally with teacher forcing.

        Args:
            encoder_outputs: Encoder GRU outputs, (batch, src_len, hidden_size).
            encoder_hidden: Encoder final hidden state, (1, batch, hidden_size); seeds the decoder.
            target_tensor: Optional (batch, tgt_len, output_size). When given, step 0 sees a
                zero frame and step i sees target frame i-1 (teacher forcing), for tgt_len
                steps. When omitted, each step's own prediction is fed back for
                `self.max_length` steps.

        Returns:
            (decoder_outputs, decoder_hidden, attentions), shaped
            (batch, steps, output_size), (1, batch, hidden_size) and (batch, steps, src_len).
        """
        batch_size = encoder_outputs.size(0)
        # The zero start frame lives with the encoder outputs; previously it relied on a
        # notebook-global `device` that is defined in a later cell.
        decoder_input = torch.zeros(
            batch_size, 1, self.output_size,
            device=encoder_outputs.device, dtype=encoder_outputs.dtype,
        )
        num_steps = self.max_length if target_tensor is None else target_tensor.size(1)

        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(num_steps):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )

            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: the next input is the ground-truth frame for this step.
                decoder_input = target_tensor[:, i].unsqueeze(1)
            else:
                # Free running: feed back the model's own prediction.
                decoder_input = decoder_output

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions

    def forward_step(self, input, hidden, encoder_outputs):
        """Advance the decoder by one step.

        Args:
            input: Previous frame, (batch, 1, output_size).
            hidden: GRU state, (1, batch, hidden_size). The permute below assumes a
                single-layer GRU, which is how `self.gru` is built.
            encoder_outputs: (batch, src_len, hidden_size); attended over.

        Returns:
            (output, hidden, attn_weights): output is (batch, 1, output_size), hidden is
            the updated GRU state, and attn_weights is (batch, 1, src_len).
        """
        # (1, batch, hidden) -> (batch, 1, hidden) so the state can act as the attention query.
        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)

        input_gru = torch.cat((input, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.mlp(output)

        return output, hidden, attn_weights


In [None]:
# Seed every RNG in use before building the models, so weight initialisation is repeatable.
# NOTE: seeding alone may not make CUDA GRU kernels bit-for-bit deterministic.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(0)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# 66 features per input frame, 64 hidden units, 90 features per decoded frame.
# The encoder is constructed first, exactly as before, so parameter init order is unchanged.
encoder = EncoderRNN(66, 64).to(device)
decoder = AttnDecoderRNN(64, 90).to(device)

In [None]:
# One Adam optimizer per network (encoder first, then decoder), both at lr = 1e-4.
encoder_optimizer, decoder_optimizer = (
    optim.Adam(module.parameters(), lr=0.0001) for module in (encoder, decoder)
)
# Regression objective: mean squared error between predicted and target frames.
criterion = nn.MSELoss()

In [None]:
# Train with full teacher forcing: the target is passed to the decoder, so it always sees
# ground-truth frames. Early stopping halts once the summed epoch loss has not improved
# on the best value for PATIENCE consecutive epochs.
MAX_EPOCHS = 10000
PATIENCE = 300

bar = tqdm(range(MAX_EPOCHS))
prevLoss = math.inf        # best (lowest) summed epoch loss so far
stopCondition = PATIENCE   # non-improving epochs left before stopping
for epoch in bar:
    encoder.train()
    decoder.train()
    totalLoss = 0
    for x, y in train_dataloader:
        input_tensor = x.float().to(device)
        target_tensor = y.float().to(device)

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

        loss = criterion(decoder_outputs, target_tensor)
        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()
        totalLoss += loss.item()

    if totalLoss > prevLoss:
        stopCondition -= 1
    else:
        prevLoss = totalLoss
        stopCondition = PATIENCE
    bar.set_description("loss:{}".format(totalLoss))

    # The counter used to be decremented but never checked, so early stopping never
    # fired and every run went the full MAX_EPOCHS.
    if stopCondition <= 0:
        break
bar.close()
# NOTE(review): the weights left in encoder/decoder are from the last epoch, not the best
# one; consider saving a state_dict copy whenever prevLoss improves.

loss:2.0846536585850117e-07: 100%|██████████| 10000/10000 [2:10:38<00:00,  1.28it/s]


In [11]:
# Evaluate free-running: no target is passed, so the decoder feeds back its own predictions.
# y_true / y_pred become (num_frames, features) tensors: one row per decoded frame.
encoder.eval()
decoder.eval()
y_pred = []
y_true = []
with torch.no_grad():
    for x, y in test_dataloader:
        input_tensor = x.float().to(device)
        target_tensor = y.float()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden)

        # Flatten (batch, steps, features) -> (batch * steps, features) so every sample in
        # the batch is kept. The previous `[0, :, :]` silently dropped all but the first
        # sample whenever batch_size > 1.
        y_true.append(target_tensor.reshape(-1, target_tensor.size(-1)))
        y_pred.append(decoder_outputs.cpu().reshape(-1, decoder_outputs.size(-1)))

y_true = torch.cat(y_true)
y_pred = torch.cat(y_pred)

In [12]:
# Report regression metrics (MAE, r2, explained variance, max absolute error) over all test frames.
# NOTE(review): r2 < 0 despite a tiny MAE and a near-zero training loss may mean the free-running
# decoder drifts on some frames (training always used teacher forcing) — investigate.
showScore(y_true, y_pred)

mean absolute error: 0.00051
r2 score: -2.25415
explained variance score: -2.23265
maximum absolute error: 23.96376
