In [None]:
import torch
import torch.nn as nn
import numpy as np
from random import random
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd

In [None]:
class Encoder(nn.Module):
    """LSTM encoder returning every per-step output plus the final (h, c) state.

    The constructor arguments are forwarded to ``nn.LSTM`` and are also stored
    as plain attributes on the module so that other code can read them back.
    """

    def __init__(self, input_size: int, hidden_size: int,
                 num_layers: int=1, batch_first: bool=True, dropout: float=0):
        super().__init__()
        # Configuration, kept for later inspection.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.dropout = dropout
        # The recurrent layer itself.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            dropout=dropout,
        )

    def forward(self, x, h0, c0):
        """Run ``x`` through the LSTM starting from the state ``(h0, c0)``.

        Returns a 3-tuple: the LSTM outputs for every time step, the final
        hidden state and the final cell state.
        """
        initial_state = (h0, c0)
        hidden_states, final_state = self.lstm(x, initial_state)
        h, c = final_state
        return hidden_states, h, c

class Decoder(nn.Module):
    """LSTM decoder followed by dropout, layer normalisation and a linear head.

    ``lstm_dropout`` is passed to ``nn.LSTM``; ``dropout_layer_dropout`` is the
    probability of the separate ``nn.Dropout`` applied to the LSTM outputs.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 num_layers: int=1, batch_first: bool=True, lstm_dropout: float=0, dropout_layer_dropout: float=0):
        super().__init__()
        # Configuration, kept for later inspection.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        # Sub-modules, created in pipeline order.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            dropout=lstm_dropout,
        )
        self.dropout_layer = nn.Dropout(dropout_layer_dropout)
        self.layernorm = nn.LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x, h0, c0):
        """Decode ``x`` starting from the LSTM state ``(h0, c0)``.

        A 2-D ``x`` gets a length-1 sequence axis inserted (axis 1 when
        ``batch_first`` is true, axis 0 otherwise) before it reaches the LSTM.

        Returns the projected outputs together with the new hidden and cell
        states of the LSTM.
        """
        if x.dim() == 2:
            seq_axis = 1 if self.batch_first else 0
            x = x.unsqueeze(seq_axis)

        # For a single-step input the LSTM output is (batch_size, 1, hidden_size)
        # when batch_first is true and (1, batch_size, hidden_size) otherwise.
        lstm_out, (h, c) = self.lstm(x, (h0, c0))
        normalised = self.layernorm(self.dropout_layer(lstm_out))
        return self.linear(normalised), h, c


class Attention(nn.Module):
    """Bilinear attention over the encoder hidden states.

    The learnable matrix ``W`` has shape ``(decoder_output_size, hidden_size)``.
    ``W0`` is used as its initial value only when it is given and has exactly
    that shape; otherwise ``W`` is Xavier-uniform initialised.
    """

    def __init__(self, decoder_output_size, hidden_size, W0=None):
        super().__init__()
        expected_shape = (decoder_output_size, hidden_size)
        if W0 is None or W0.shape != expected_shape:
            # No usable initial value: draw a fresh Xavier-uniform matrix.
            W0 = nn.init.xavier_uniform_(torch.empty(*expected_shape))
        self.W = nn.Parameter(W0)

    def forward(self, h, encoder_hidden_states, batch_first):
        """Return the attention-weighted sum of the encoder hidden states.

        ``h`` is the decoder state stack; only its last entry ``h[-1]`` (shape
        ``(batch_size, decoder_output_size)``) is used as the query.
        ``encoder_hidden_states`` is ``(batch_size, L, hidden_size)`` when
        ``batch_first`` is true and ``(L, batch_size, hidden_size)`` otherwise,
        where ``L`` is the source sequence length.

        The result has shape ``(batch_size, hidden_size)``.
        """
        query = h[-1]

        # Work batch-major internally regardless of the caller's layout.
        if batch_first:
            states = encoder_hidden_states
        else:
            states = encoder_hidden_states.transpose(0, 1)

        # score[b, l] = query[b] @ W @ states[b, l]
        scores = torch.einsum("bd, dh, blh -> bl", query, self.W, states)
        # Normalise over the source positions.
        attn_weights = F.softmax(scores, dim=1)
        return torch.einsum("bl, blh -> bh", attn_weights, states)


class Seq2Seq(nn.Module):
    """Encoder/decoder wrapper that feeds an attention context to every decoder step.

    The three sub-modules are moved to CUDA when it is available, otherwise to
    the CPU; the chosen device is stored in ``self.device``.
    """

    def __init__(self, encoder, decoder, attention):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder = encoder.to(self.device)
        self.decoder = decoder.to(self.device)
        self.attention = attention.to(self.device)

    def _context(self, h, encoder_hidden_states, batch_first):
        """Return the attention context for state ``h`` with a length-1 time axis added."""
        context_vec = self.attention(h, encoder_hidden_states, batch_first)
        # (batch_size, 1, hidden_size) if batch_first else (1, batch_size, hidden_size)
        return context_vec.unsqueeze(1 if batch_first else 0)

    def forward(self, source, x0, h0, c0,
                transform: callable=lambda x: x, teacher_forcing_ratio: float=0.0, Y=None, L=None):
        """Encode ``source`` and decode a target sequence step by step.

        Args:
            source: input to the encoder.
            x0: first decoder input. May be 3-D (with a length-1 time axis) or
                2-D ``(batch_size, features)``; a 2-D tensor gets the time
                axis inserted.
            h0, c0: initial LSTM state handed to the encoder.
            transform: applied to the previous decoder output before it is fed
                back as the next decoder input (identity by default).
            teacher_forcing_ratio: probability, per step, of feeding the ground
                truth ``Y`` at the previous step instead of the model output.
            Y: ground truths, ``(batch_size, L, decoder.output_size)`` if the
                encoder is batch-first, else ``(L, batch_size, decoder.output_size)``.
                When given, the target length is taken from ``Y``.
            L: target sequence length, used only when ``Y`` is None. Defaults
                to the source sequence length.

        Returns:
            The decoder outputs for all ``L`` steps, laid out like ``Y``.

        Raises:
            ValueError: if teacher forcing is requested without ``Y``, or if
                the target length is smaller than 1.
        """
        source, x0, h0, c0 = source.to(self.device), x0.to(self.device), h0.to(self.device), c0.to(self.device)
        batch_first = self.encoder.batch_first
        time_dim = 1 if batch_first else 0

        if Y is None:
            if teacher_forcing_ratio != 0.0:
                # Teacher forcing needs ground truths to feed back.
                raise ValueError("teacher_forcing_ratio must be 0 when Y is None")
            if L is None:
                print("Target sequence length not provided. Defaulting to input sequence length.")
                L = source.shape[time_dim]
        else:
            Y = Y.to(self.device)
            L = Y.shape[time_dim]

        if L < 1:
            raise ValueError("target sequence length must be at least 1")

        # torch.cat below joins along the feature axis (dim=2), so x0 must be 3-D.
        if x0.dim() == 2:
            x0 = x0.unsqueeze(time_dim)

        # The encoder takes the initial state as two separate arguments.
        encoder_hidden_states, h, c = self.encoder(source, h0, c0)

        # First step: the query for the attention is the encoder's final state.
        context_vec = self._context(h, encoder_hidden_states, batch_first)
        x, h, c = self.decoder(torch.cat((x0, context_vec), dim=2), h, c)
        steps = [x.squeeze(time_dim)]  # each entry: (batch_size, decoder.output_size)

        for t in range(1, L):
            context_vec = self._context(h, encoder_hidden_states, batch_first)
            # random() is in [0, 1), so a ratio of 0 never selects the teacher
            # branch and a ratio of 1 always does.
            if random() < teacher_forcing_ratio:
                previous_truth = Y[:, t - 1] if batch_first else Y[t - 1]
                x_ = previous_truth.unsqueeze(time_dim)
            else:
                x_ = transform(x)
            x, h, c = self.decoder(torch.cat((x_, context_vec), dim=2), h, c)
            steps.append(x.squeeze(time_dim))

        # Stacking along the time axis yields the same layout as Y.
        return torch.stack(steps, dim=time_dim)

In [None]:
# batching from scratch, not used
def batching(X, Y, batch_size, batch_first):
    """Shuffle the samples and split X and Y into aligned lists of batches.

    Samples are indexed along axis 0 of ``X`` and ``Y``. A fresh random
    permutation is drawn on every call, and the same permutation is applied to
    both tensors so that each X batch stays paired with its Y batch.

    Args:
        X: inputs; axis 0 is the sample axis.
        Y: targets; indexed with the same sample indices as ``X``.
        batch_size: maximum number of samples per batch; the last batch holds
            the remainder and may be smaller.
        batch_first: if False, the first two axes of every batch are swapped
            so that the batch axis comes second.

    Returns:
        ``(X_batched, Y_batched)``, two lists of equal length. No empty batch
        is produced; both lists are empty when there are no samples.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    num_samples = X.shape[0]
    indices = torch.randperm(num_samples)
    # Stepping by batch_size never yields an empty trailing group, including
    # when num_samples is an exact multiple of batch_size.
    indices_groups = [indices[start:start + batch_size]
                      for start in range(0, num_samples, batch_size)]
    X_batched = [X[group] for group in indices_groups]
    Y_batched = [Y[group] for group in indices_groups]
    if not batch_first:
        # Swap the batch and sequence axes; for 3-D input this equals permute(1, 0, 2).
        X_batched = [x.transpose(0, 1) for x in X_batched]
        Y_batched = [y.transpose(0, 1) for y in Y_batched]
    return X_batched, Y_batched


def initializer(model, batch_size):
    """Return random initial states ``(h0, c0)`` for the model's encoder.

    Both tensors have shape ``(num_layers, batch_size, hidden_size)`` and are
    drawn independently from a normal distribution with mean 0 and standard
    deviation 0.01. The layer count and hidden size are read from
    ``model.encoder``, since these states are what the encoder is started with.
    """
    state_shape = (model.encoder.num_layers, batch_size, model.encoder.hidden_size)
    h0 = torch.empty(state_shape).normal_(mean=0.0, std=0.01)
    c0 = torch.empty(state_shape).normal_(mean=0.0, std=0.01)
    return h0, c0


class MyDataset(Dataset):
    """Pairs inputs ``X`` with targets ``Y``, sample by sample.

    Both containers are indexed with the same sample index; the dataset length
    is taken from ``X`` alone.
    """

    def __init__(self, X, Y): # X.shape[0] = Y.shape[0] = num_samples
        self.X, self.Y = X, Y

    def __len__(self):
        """Number of samples, i.e. ``len(X)``."""
        sample_count = len(self.X)
        return sample_count

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair stored at position ``idx``."""
        features = self.X[idx]
        target = self.Y[idx]
        return features, target


### needs revising for multivariate if input_size != output_size
def train_seq2seq_model(model, criterion, optimizer, train_dataset: MyDataset,
                        num_epochs: int, batch_size: int, transform: callable, teacher_forcing_ratio: float):
    """Train ``model`` on ``train_dataset`` for ``num_epochs`` epochs.

    Each epoch iterates over a shuffled ``DataLoader``. For every batch the
    encoder state is re-drawn with ``initializer``, the model is run with the
    batch targets as ground truth for teacher forcing, the loss is
    back-propagated, gradients are clipped to a norm of 1.0, and the optimizer
    takes one step.

    Args:
        model: the sequence-to-sequence model; ``model.encoder.batch_first``
            decides the tensor layout handed to it.
        criterion: loss called as ``criterion(outputs, targets)``.
        optimizer: optimizer over the model parameters.
        train_dataset: dataset yielding ``(X, Y)`` sample pairs.
        num_epochs: number of passes over the dataset.
        batch_size: batch size given to the ``DataLoader``; the last batch of
            an epoch may be smaller.
        transform: forwarded to the model (applied to decoder outputs that are
            fed back as inputs).
        teacher_forcing_ratio: forwarded to the model.
    """
    batch_first = model.encoder.batch_first
    dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    model.train()  # make sure dropout layers are active while training
    for epoch_idx in range(num_epochs):
        print(f"Epoch {epoch_idx+1}...")
        # X_batch: (batch, L, encoder_input_size); Y_batch: (batch, L, decoder_output_size)
        for X_batch, Y_batch in dataloader:
            # The last batch can hold fewer than batch_size samples, so size
            # the initial states from the batch actually received.
            current_batch_size = X_batch.shape[0]

            # TODO: x0 initialization needs revisiting. For now the last source
            # step is used; the slice keeps the length-1 time axis so that x0
            # can be concatenated with the 3-D attention context.
            x0 = X_batch[:, -1:, :]
            if not batch_first:
                X_batch, Y_batch, x0 = X_batch.transpose(0, 1), Y_batch.transpose(0, 1), x0.transpose(0, 1)
            optimizer.zero_grad()

            # forward pass
            h0, c0 = initializer(model, current_batch_size)
            outputs = model(X_batch, x0, h0, c0,
                            transform=transform,
                            teacher_forcing_ratio=teacher_forcing_ratio,
                            Y=Y_batch)
            # The model returns its outputs on its own device; the targets
            # must live on the same device for the loss.
            loss = criterion(outputs, Y_batch.to(outputs.device))
            loss.backward() # back-propagation
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # avoids exploding gradients
            optimizer.step()
    print("Finished training.")