# Baseline forecast

- The first scenario is simple: multivariate forecasting of the system based on its past behavior (both observable and control variables).

## Dataloader

In [None]:
cd ..

In [None]:
from typing import Tuple

import numpy as np
import pandas as pd
import pytorch_lightning as pl
from torch.utils.data import Dataset, Subset, DataLoader
from scipy.integrate import odeint

from data.data_module import ThreeTankDataModule

In [None]:
# Build the three-tank data module; batches of 256 samples are requested.
dm = ThreeTankDataModule(batch_size=256)
# Explicit setup is left disabled; presumably the Trainer triggers it during fit/test — TODO confirm.
# dm.setup()

## Vanilla LSTM

### Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchmetrics import MeanSquaredError, MeanAbsoluteError
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

import train_arguments as args



class PLCore(pl.LightningModule):
    """pytorch lightning core module
    This module is the base class for all models.
    It implements the training, validation and test steps; subclasses only
    have to provide ``_shared_step``.
    Args:
        d_seq_in (int): input sequence length
        d_features (int): number of features
        d_seq_out (int): output sequence length
        train_scenario (str): the scenario to train on; must be a key of the
            scenario table below (a ``KeyError`` is raised otherwise)
        lr (float): learning rate
        beta1 (float): beta1 for Adam optimizer
        beta2 (float): beta2 for Adam optimizer
        eps (float): epsilon for Adam optimizer
    """
    def __init__(self, d_seq_in=250, d_features=3, d_seq_out=50, 
                 train_scenario="standard",
                 lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        super().__init__()

        self.d_seq_in = d_seq_in
        self.d_features = d_features
        self.d_seq_out = d_seq_out
        
        self.example_input_array = torch.rand(32, self.d_seq_in, self.d_features)  # 32 as example batch size

        self.save_hyperparameters()  # stores hyperparameters in self.hparams and allows logging
        self.visualization_device = "cpu"  # for visualizations, can be changed in model if necessary

        # Maps a scenario name to the index of the validation dataloader whose
        # loss is tracked per epoch (see validation_step).
        scenario_dict = {
            "standard": 0,
            "fault": 1,
            "noise": 2,
            "duration": 3,
            "scale": 4,
            "switch": 5,
            "q1+v3": 6,
            "q1+v3+rest": 7,
            "v12+v23": 8,
            "standard+": 9,
            "standard++": 10,
            "frequency": 11,
            "time_warp": 12
        }
        self.train_scenario_idx = scenario_dict[train_scenario]
        
        # metrics to keep track of
        metrics = {
            "MAE": MeanAbsoluteError(),
            "MSE": MeanSquaredError()
        }
        # the loss function to use for backpropagation
        self.loss_fct_key = args.LOSS_FCT  
        assert self.loss_fct_key in metrics.keys(), "loss function key should be in metrics"

        # Independent metric instances per stage so their internal states never mix.
        self.train_metrics = nn.ModuleDict({name: metric.clone() for name, metric in metrics.items()})
        self.val_metrics = nn.ModuleDict({name: metric.clone() for name, metric in metrics.items()})
        self.test_metrics = nn.ModuleDict({name: metric.clone() for name, metric in metrics.items()})
        # Per-batch validation losses of the tracked scenario, reduced at epoch end.
        self.validation_step_outputs = []
        # Best (lowest) epoch validation loss seen so far, kept as a plain float.
        self.min_epoch_val_loss = float("inf")

    def _shared_step(self, x, y):
        """Shared step used in training, validation and test step.
        Should return the prediction and the target (y_pred, y).
        """
        raise NotImplementedError("This should be implemented in the model that inherits from PLCore.")

    @torch.no_grad()
    def forward(self, x):
        """Forward pass for pytorch lightning.
        Runs without gradient tracking, so it is meant for inference only;
        the training/validation/test steps call ``_shared_step`` directly.
        Should return the prediction (y_pred)."""
        return self._shared_step(x, None)[0]
    
    def training_step(self, batch, batch_id):
        """Training step for pytorch lightning.
        Logs every metric as ``train_<name>`` and returns the one selected by
        ``self.loss_fct_key`` as the loss for backpropagation.
        """
        x1, x2 = batch
        pred, target = self._shared_step(x1, x2)
        for name, metric in self.train_metrics.items():
            metric_loss = metric(pred, target)
            self.log("train_" + name, metric_loss, logger=True)
            if name == self.loss_fct_key:
                # use this loss function for backpropagation
                loss = metric_loss
        return loss

    def validation_step(self, batch, batch_id, dataloader_idx=0):
        """Validation step for pytorch lightning.
        Logs every metric as ``val_<name>``. Batches coming from the dataloader
        that matches the training scenario additionally contribute to the
        epoch-level loss computed in ``on_validation_epoch_end``.
        ``dataloader_idx`` defaults to 0 so the step also works when only a
        single validation dataloader is used and no index is passed.
        """
        x1, x2 = batch
        pred, target = self._shared_step(x1, x2)
        for name, metric in self.val_metrics.items():
            metric_loss = metric(pred, target)
            self.log("val_" + name, metric_loss, logger=True)
            if name == self.loss_fct_key and dataloader_idx == self.train_scenario_idx:
                # save batch losses of the tracked scenario for the epoch summary
                self.validation_step_outputs.append(metric_loss.detach())

    def on_validation_epoch_end(self):
        """Validation epoch end for pytorch lightning.
        Averages the batch losses collected from the dataloader with index
        ``self.train_scenario_idx``, logs the result as ``ep_val_loss`` and
        updates the best value seen so far.
        """
        if not self.validation_step_outputs:
            # Nothing was collected (e.g. the tracked dataloader index was never
            # seen); torch.stack would raise on an empty list.
            return
        epoch_losses = torch.stack(self.validation_step_outputs)
        mean_loss = torch.mean(epoch_losses)
        self.log("ep_val_loss", mean_loss, prog_bar=True, logger=True)
        self.validation_step_outputs.clear()  # free memory
        # remember the best epoch loss as a python float (not a device tensor)
        epoch_value = mean_loss.item()
        if epoch_value < self.min_epoch_val_loss:
            self.min_epoch_val_loss = epoch_value

    def on_train_end(self):
        """Train end for pytorch lightning."""
        # log best val_loss together with the hyperparameters
        if self.logger:
            self.logger.log_hyperparams(self.hparams, {"hp/min_epoch_val_loss": self.min_epoch_val_loss})

    def test_step(self, batch, batch_id, dataloader_idx=0):
        """Test step for pytorch lightning.
        Logs every metric as ``test_<name>``. ``dataloader_idx`` is accepted
        (default 0) but not used.
        """
        x1, x2 = batch
        pred, target = self._shared_step(x1, x2)
        for name, metric in self.test_metrics.items():
            self.log("test_" + name, metric(pred, target), logger=True)

    def configure_optimizers(self):
        """Configure optimizers for pytorch lightning.
        Returns an Adam optimizer built from the stored hyperparameters and a
        ReduceLROnPlateau scheduler stepped once per epoch.
        """
        optimizer = Adam(
            self.parameters(), 
            lr=self.hparams.lr,  
            betas=(self.hparams.beta1, self.hparams.beta2),
            eps=self.hparams.eps
            )
        scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=25, min_lr=1e-5)
        # NOTE(review): the monitored key is tied to dataloader 0 and to the
        # "/dataloader_idx_N" suffix, which presumably only exists when several
        # validation dataloaders are used — confirm against the data module.
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch", "monitor": f"val_{self.loss_fct_key}/dataloader_idx_0"}]


In [None]:
import torch
import torch.nn as nn

from models.core import PLCore



class LSTM(PLCore):
	"""LSTM model.
	Encodes the input sequence with an LSTM and maps the final hidden state of
	the last layer to the forecast, either all steps at once or one step at a
	time (autoregressive).
	Args:
		d_hidden (int): hidden dimension
		n_layers (int): number of layers
		bidirectional (bool): whether to use bidirectional LSTM
		dropout (float): dropout rate (only applied when n_layers > 1)
		autoregressive (bool): whether to predict one output at a time
		**kwargs: forwarded to PLCore
	"""
	def __init__(self, d_hidden, n_layers=1, bidirectional=False, dropout=0.5, autoregressive=False, **kwargs):
		super().__init__(**kwargs)
		
		# self.d_features and self.d_seq_out from parent class

		self.lstm = nn.LSTM(
			input_size=self.d_features,
			hidden_size=d_hidden,
			num_layers=n_layers,
			bidirectional=bidirectional,
			batch_first=True,
			dropout=dropout if n_layers > 1 else 0
		)
		
		# A bidirectional LSTM yields two hidden states per layer, so the
		# head receives twice the hidden width.
		d_state = d_hidden * (2 if bidirectional else 1)
		if not autoregressive:
			# Predict all outputs at once
			self.fc = nn.Linear(d_state, self.d_features * self.d_seq_out)
		else:
			# Predict one output at a time
			self.fc = nn.Linear(d_state, self.d_features)
		self.bidirectional = bidirectional
		self.autoregressive = autoregressive

	def _last_layer_state(self, h):
		"""Return the final hidden state of the last LSTM layer.
		For a bidirectional LSTM the last two entries of ``h`` are the forward
		and backward states of the last layer; they are concatenated so the
		result matches the input width of ``self.fc``.
		"""
		if self.bidirectional:
			return torch.cat((h[-2, :, :], h[-1, :, :]), dim=-1)
		return h[-1, :, :]

	def _shared_step(self, x, y):
		"""Compute the forecast for ``x`` and return it together with ``y``.
		Returns:
			(y_pred, y): prediction of shape (batch_size, d_seq_out, d_features)
			and the unchanged target.
		"""
		# x: (batch_size, d_seq_in, d_features)
		# y: (batch_size, d_seq_out, d_features)
		b_size = x.size(0)

		if not self.autoregressive:
			_, (h, _) = self.lstm(x)
			state = self._last_layer_state(h)
			y_pred = self.fc(state).view(b_size, self.d_seq_out, self.d_features)
		else:
			y_pred = []
			_, (h, c) = self.lstm(x)
			output = self.fc(self._last_layer_state(h)).unsqueeze(1)
			y_pred.append(output)

			# Autoregressive forecasting: feed each prediction back in as a
			# one-step sequence, carrying the recurrent state along.
			for _ in range(self.d_seq_out - 1):
				_, (h, c) = self.lstm(output, (h, c))
				output = self.fc(self._last_layer_state(h)).unsqueeze(1)
				y_pred.append(output)

			y_pred = torch.cat(y_pred, dim=1)

		return y_pred, y

### Train

In [None]:
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping
from torchmetrics import MeanSquaredError, MeanAbsoluteError

In [None]:
# Stop training once the epoch-level validation loss has not improved for 50
# epochs. "ep_val_loss" is the key the model logs in on_validation_epoch_end;
# the model never logs a "val_loss/..." key, so monitoring that would fail.
callbacks = [EarlyStopping(monitor="ep_val_loss", patience=50)]

# Toggle TensorBoard logging; without a logger the LearningRateMonitor is
# not added either.
use_logger = False
if not use_logger:
    logger = False
else:
    name = "LSTM"
    callbacks.append(LearningRateMonitor())
    logger = TensorBoardLogger(
        "logs/2-baseline",  # change to notebook number
        name=name, 
        default_hp_metric=False
    )

trainer_hparams = dict(
    accelerator='auto', 
    devices=1,
    max_epochs=500,
    log_every_n_steps=10,
    logger=logger,
    callbacks=callbacks,
    enable_checkpointing=False
)
trainer = pl.Trainer(**trainer_hparams)

# Model under training: LSTM with a hidden size of 256, all other settings at their defaults.
model = LSTM(d_hidden=256)

In [None]:
# train: fit the model using the dataloaders provided by the data module
trainer.fit(model=model, datamodule=dm)

# Visualization

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
def plot_sample_forecast(sample, fcast, title=None, display=True):
    """Plot observed tank levels and their forecast for a single sample.

    The observed signal (input window followed by target window) is drawn
    with solid lines, the forecast with dotted lines over the target window.
    A dashed vertical line marks where the input window ends.

    Args:
        sample: pair ``(x1, x2)`` of input window and target window; each is
            indexed as ``[:, 0..2]`` (presumably shape (seq_len, 3) — verify
            against the dataset)
        fcast: forecast for the sample; singleton dimensions are squeezed
            away before its three columns are plotted
        title: str, optional title of the plot
        display: bool, if True the figure is shown, otherwise it is returned
    Returns:
        The plotly figure when ``display`` is False, otherwise None.
    """
    history, future = sample
    forecast = np.squeeze(fcast)
    observed = np.concatenate((history, future))
    palette = [
        '#1f77b4',  # muted blue
        '#d62728',  # brick red
        '#2ca02c',  # cooked asparagus green
        '#17becf',  # blue-teal
        '#ff7f0e',  # safety orange
        '#bcbd22',  # curry yellow-green
        '#9467bd',  # muted purple
        '#8c564b',  # chestnut brown
        '#e377c2',  # raspberry yogurt pink
        '#7f7f7f',  # middle gray
    ]
    n_history = history.shape[0]
    n_future = future.shape[0]

    fig = go.Figure()

    # Observed levels over the full (input + target) horizon.
    observed_labels = ['h1', 'h2', 'h3']
    for k, label in enumerate(observed_labels):
        steps = np.array(range(observed.shape[0]))
        fig.add_trace(go.Scatter(x=steps, y=observed[:, k], name=label,
                      mode="lines", opacity=1, line=dict(color=palette[k])))

    # Forecast levels, aligned with the target window.
    forecast_labels = ['pred_h1', 'pred_h2', 'pred_h3']
    for k, label in enumerate(forecast_labels):
        steps = np.array(range(n_history, n_history + n_future))
        fig.add_trace(go.Scatter(x=steps, y=forecast[:, k], name=label,
                      mode="lines", opacity=1, line=dict(color=palette[3 + k], dash="dot")))

    fig.add_vline(x=len(history), line_dash="dash")
    fig.update_xaxes(tick0=0, dtick=200)
    fig.update_xaxes(title_text=r'time')
    fig.update_layout(width=800, height=500,
                      font_family="Serif", font_size=14,
                      margin_l=5, margin_t=50, margin_b=5, margin_r=5)
    if title is not None:
        fig.update_layout(title=title)

    if not display:
        return fig
    fig.show()


def fcast_overview(datamodule, model, idx=0, title=None, save_path=None):
    """Plots the water level forecast of one sample per scenario.

    The first nine scenarios of ``datamodule.ds_dict`` are drawn into a single
    3x3 grid of subplots, each built with plot_sample_forecast().

    Args:
        datamodule: DataModule exposing ``ds_dict`` (scenario name -> dataset)
        model: Model; it is moved to ``model.visualization_device`` and put
            into eval mode
        idx: int, index of sample to plot
        title: str, name used in the figure title and file name; defaults to
            the model's class name
        save_path: str, directory to save the plot to; if None the figure is
            shown instead
    """
    # Local import: none of the notebook's import cells bring `os` into scope.
    import os

    device = model.visualization_device
    model = model.to(device)
    model.eval()
    datasets = datamodule.ds_dict
    if title is None:
        # avoid a literal "None" in the figure title / file name
        title = type(model).__name__

    # plot water levels
    n_rows = 3
    n_cols = 3
    fig = make_subplots(rows=n_rows, cols=n_cols, shared_xaxes=True, vertical_spacing=0.02)
    n_input_steps = None  # length of the input window, taken from the plotted samples
    for i, (scenario, ds) in enumerate(datasets.items()):
        if i >= n_rows * n_cols:
            break
        row = (i // n_cols) + 1
        col = (i % n_cols) + 1
        sample = ds[idx]
        # add the batch dimension and place the input on the model's device
        x = torch.as_tensor(sample[0]).unsqueeze(0).to(device)
        n_input_steps = x.size(1)
        fcast = model(x).cpu().detach().numpy()
        fcast_plot = plot_sample_forecast(sample, fcast, title=scenario, display=False)
        # copy every trace of the single-sample figure into its grid cell
        for trace in fcast_plot.data:
            fig.add_trace(trace, row=row, col=col)
        fig.update_yaxes(title_text=scenario + f" [DataLoader {i}]", row=row, col=col)

    fig.update_xaxes(tick0=0, dtick=50)
    if n_input_steps is not None:
        # mark the end of the input window in all subplots
        fig.add_vline(x=n_input_steps, line_dash="dash")
    for col in range(n_cols):
        fig.update_xaxes(title_text=r'time', row=n_rows, col=col+1)
    fig.update_layout(showlegend=False)
    fig.update_layout(title=f"Water Level Predictions by {title}")
    # export figure
    if save_path is not None:
        # if the path does not exist, create it
        os.makedirs(save_path, exist_ok=True)
        # join instead of concatenating so a missing trailing slash is harmless
        fig.write_image(os.path.join(save_path, f"water_levels_{title}.png"), width=1200, height=800)
    else:
        fig.show()
                

In [None]:
# Show the forecast overview for sample index 0 of each scenario dataset.
fcast_overview(dm, model, 0)

In [None]:
# Run the test loop of the model on the data module's test dataloader(s).
trainer.test(model=model, datamodule=dm)