In [1]:
import warnings

import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import (
    ExplainedVariance,
    MeanAbsoluteError,
    MeanSquaredError,
)

# Silence every Python warning for the rest of the session to keep cell output clean.
# NOTE(review): this blanket filter also hides PyTorch / Lightning / torchmetrics
# warnings that can signal real problems (shape mismatches, deprecated arguments);
# consider narrowing it to specific categories or modules.
warnings.filterwarnings("ignore")

# config
# Render matplotlib figures inline in the notebook.
# NOTE(review): matplotlib is not imported in this cell — presumably used in later cells.
%matplotlib inline
# Render inline figures at high (retina) resolution.
%config InlineBackend.figure_format='retina'

In [2]:
class LSTMTransformerModel(pl.LightningModule):
    """LSTM encoder followed by a Transformer encoder for regression.

    The LSTM runs over the input sequence. Its per-time-step outputs are
    dropped out, projected to ``d_model`` and passed through a stack of
    Transformer encoder layers that attend over the time steps *of each
    sample*. The representation of the last time step is then mapped to
    ``output_dim`` values.

    Batches are consumed in the ``(x, y)`` form used by the step methods:
    ``x["encoder_cont"]`` is the model input and ``y[0]`` the regression
    target. This looks like the pytorch-forecasting ``TimeSeriesDataSet``
    layout (TODO confirm against the datamodule).

    Args:
        input_dim: Number of features per time step fed to the LSTM.
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        output_dim: Number of values predicted per sample.
        droupout: Dropout probability applied to the LSTM outputs. The
            misspelled name is kept so existing callers keep working.
        lr: Learning rate for Adam.
        d_model: Width of the Transformer encoder. Must be divisible by
            ``nhead``.
        nhead: Number of attention heads per Transformer layer.
        num_transformer_layers: Number of stacked Transformer encoder layers.
    """

    def __init__(
        self,
        input_dim=1,
        hidden_dim=64,
        num_layers=2,
        output_dim=1,
        droupout=0.2,
        lr=0.01,
        d_model=64,
        nhead=4,
        num_transformer_layers=2,
    ):
        super().__init__()
        # Record the constructor arguments in self.hparams and in checkpoints,
        # so load_from_checkpoint() rebuilds a model with the same shapes.
        self.save_hyperparameters()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.droupout = droupout
        self.lr = lr
        self.d_model = d_model
        self.nhead = nhead
        self.num_transformer_layers = num_transformer_layers

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.dropout = nn.Dropout(droupout)
        self.fc1 = nn.Linear(in_features=hidden_dim, out_features=d_model)
        # batch_first=True matches the LSTM layout: the encoder receives
        # (batch, seq_len, d_model), so attention runs over the time steps
        # within each sample and never across samples of the batch.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, batch_first=True),
            num_layers=num_transformer_layers,
        )
        self.fc2 = nn.Linear(in_features=d_model, out_features=output_dim)

        # NOTE(review): the same metric instances are shared by the train,
        # validation and test stages; torchmetrics recommends one instance
        # per stage if epoch-level metric objects are ever logged.
        self.mse = MeanSquaredError()
        self.mae = MeanAbsoluteError()
        self.evs = ExplainedVariance()

    def forward(self, x):
        """Predict ``output_dim`` values per sample.

        Args:
            x: Input sequence batch, (batch, seq_len, input_dim) since the
                LSTM is ``batch_first``.

        Returns:
            Tensor of shape (batch, output_dim).
        """
        out, _ = self.lstm(x)  # (batch, seq_len, hidden_dim)
        out = self.dropout(out)
        out = self.fc1(out)  # (batch, seq_len, d_model)
        out = self.transformer_encoder(out)  # self-attention over time steps
        out = out[:, -1, :]  # last time step: (batch, d_model)
        out = self.fc2(out)  # (batch, output_dim)
        return out

    def _shared_step(self, batch, stage):
        """Run the model on one batch, log loss and metrics under ``stage``.

        Logs ``{stage}_loss``, ``{stage}_mse``, ``{stage}_mae`` and
        ``{stage}_evs``, and returns the MSE loss.
        """
        x, y = batch
        y_pred = self(x["encoder_cont"])
        target = y[0]
        # Pass the batch size explicitly: Lightning cannot reliably infer it
        # from the nested (dict, tuple) batch structure.
        batch_size = y_pred.size(0)
        loss = F.mse_loss(y_pred, target)
        self.log(f"{stage}_loss", loss, batch_size=batch_size)
        self.log(f"{stage}_mse", self.mse(y_pred, target), batch_size=batch_size)
        self.log(f"{stage}_mae", self.mae(y_pred, target), batch_size=batch_size)
        self.log(f"{stage}_evs", self.evs(y_pred, target), batch_size=batch_size)
        return loss

    def training_step(self, batch, batch_idx):
        """Training step; returns the MSE loss that Lightning optimizes."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Validation step; logs ``validation_*`` loss and metrics."""
        return self._shared_step(batch, "validation")

    def test_step(self, batch, batch_idx):
        """Test step; logs ``test_*`` loss and metrics."""
        return self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx):
        """Return predictions mapped back to the original target scale.

        Uses ``self.trainer.datamodule.target_scaler.inverse_transform``.
        That object is assumed to be an sklearn-style scaler fitted on the
        target (TODO confirm).
        """
        x, _ = batch
        y_pred = self(x["encoder_cont"])
        # .numpy() only works on CPU tensors; move off the accelerator first.
        y_pred = y_pred.detach().cpu().numpy()
        return self.trainer.datamodule.target_scaler.inverse_transform(y_pred)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)