# Baseline forecast

## Vanilla MLP

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [5]:
class PLCore(pl.LightningModule):
    """PyTorch Lightning core module shared by the forecast models.

    Subclasses implement ``_shared_step`` (loss computation) and ``forward``
    (inference). This class supplies the train/validation/test steps, the
    hyperparameter logging and the optimizer/scheduler configuration.

    Args:
        lr: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        # Stores the init arguments in self.hparams and makes them loggable.
        self.save_hyperparameters()

    def _shared_step(self, x, y):
        """Compute the loss; used by the training, validation and test steps.

        Args:
            x: Pair ``(x1, s1)`` taken from the batch.
            y: ``x2`` taken from the batch.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError

    @torch.no_grad()
    def forward(self, x):
        """Inference entry point, executed with gradient tracking disabled.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError

    def _batch_loss(self, batch):
        """Unpack ``((x1, x2), (s1, _))`` and return ``_shared_step``'s loss."""
        (x1, x2), (s1, _) = batch
        return self._shared_step((x1, s1), x2)

    def training_step(self, batch, batch_id):
        """Log and return the loss of one training batch as ``train_loss``."""
        loss = self._batch_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_id):
        """Log and return the loss of one validation batch as ``val_loss``."""
        loss = self._batch_loss(batch)
        self.log("val_loss", loss)
        return loss

    def validation_epoch_end(self, outputs):
        """Log the hyperparameters with the epoch's smallest validation-step loss.

        Args:
            outputs: Losses returned by ``validation_step`` during the epoch.
        """
        # torch.stack raises on an empty list, so skip epochs that produced
        # no validation outputs instead of crashing.
        if self.logger and outputs:
            val_loss = torch.min(torch.stack(outputs))
            self.logger.log_hyperparams(self.hparams, {"hp/epoch_val_loss": val_loss})

    def test_step(self, batch, batch_id):
        """Log the loss of one test batch as ``test_loss``."""
        loss = self._batch_loss(batch)
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler driven by ``train_loss``.

        The scheduler is stepped per epoch and multiplies the learning rate by
        0.5 (patience 10), never going below 1e-5.
        """
        optimizer = Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=10, min_lr=1e-5)
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch", "monitor": "train_loss"}]

In [3]:
class MLP(PLCore):
    """Vanilla MLP baseline: one hidden ReLU layer between two linear layers.

    Args:
        d_in: Size of the feature vector fed to the first linear layer. It
            must equal the number of features stacked in ``_shared_step``
            (three initial levels plus the settings without their last entry).
            NOTE(review): the default of 400 looks inconsistent with that
            feature count — confirm.
        d_hidden: Width of the hidden layer.
        d_out: Number of values predicted per sample; must equal the length
            of the h2 series used as the target.
        lr: Learning rate, forwarded to ``PLCore``.
    """

    def __init__(self, d_in=400, d_hidden=512, d_out=150, lr=1e-3):
        # Pass lr through explicitly so PLCore.__init__ receives the caller's
        # value instead of its own default.
        super().__init__(lr=lr)

        self.ffn = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        )

    def _shared_step(self, x, y):
        """Return the MSE between the predicted and the observed h2 series.

        Args:
            x: Pair ``(x1, s1)`` as passed by the ``PLCore`` step methods.
            y: ``x2`` as passed by the ``PLCore`` step methods; not used by
                this baseline. NOTE(review): confirm whether the target
                should come from ``x2`` rather than ``x1``.

        Returns:
            Scalar mean-squared-error loss.
        """
        x1, s1 = x

        # NOTE(review): assumes x1 is (batch, time, channel) with channels
        # h1, h2, h3 in positions 0..2 — confirm against the dataset.
        h1 = x1[:, 0, 0]
        h2 = x1[:, 0, 1]
        h3 = x1[:, 0, 2]
        h2_all = x1[:, :, 1]  # keep as time series (regression target)

        # q1, q3, kv1, kv2, kv3, no duration (always 50).
        # NOTE(review): taken from s1 because the batch's settings are the
        # only input matching this description; the previous code sliced
        # `y` (= x2) here — confirm.
        settings = s1[:, 0, :-1]

        features = torch.stack([h1, h2, h3, *torch.unbind(settings, dim=-1)], dim=-1)

        h2_pred = self.ffn(features)

        # mse_loss takes (input, target): prediction first.
        return F.mse_loss(h2_pred, h2_all)

    @torch.no_grad()
    def forward(self, y):
        """Run the network on an assembled feature tensor without gradients.

        Args:
            y: Feature tensor whose last dimension has size ``d_in``.

        Returns:
            The predicted values, last dimension of size ``d_out``.
        """
        h2_pred = self.ffn(y)
        return h2_pred

350