In [1]:
import pandas as pd
import numpy as np
import torch
import lightning as L
import optuna

from lightning.pytorch.callbacks import Callback
from pytorch_forecasting.metrics.quantile import QuantileLoss
from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory containing train_data.csv (relative to the working directory)
output_dir = "./OutputData/"

In [3]:
# Load the training data and convert the "time" column to datetimes (column position unchanged)
df = (
    pd.read_csv(output_dir + "train_data.csv")
    .assign(time = lambda frame: pd.to_datetime(frame["time"]))
)

In [4]:
# Inspect the loaded data
df

Unnamed: 0,time,consumption_MWh,consumption_lag2,trend,hour_sin,hour_cos,day_sin,day_cos,month_sin,month_cos
0,2018-01-01 02:00:00,24635.32,27412.81,2,7.071068e-01,7.071068e-01,7.818315e-01,0.62349,5.000000e-01,0.866025
1,2018-01-01 03:00:00,23872.12,26324.39,3,8.660254e-01,5.000000e-01,7.818315e-01,0.62349,5.000000e-01,0.866025
2,2018-01-01 04:00:00,23194.89,24635.32,4,9.659258e-01,2.588190e-01,7.818315e-01,0.62349,5.000000e-01,0.866025
3,2018-01-01 05:00:00,23071.96,23872.12,5,1.000000e+00,6.123234e-17,7.818315e-01,0.62349,5.000000e-01,0.866025
4,2018-01-01 06:00:00,23267.90,23194.89,6,9.659258e-01,-2.588190e-01,7.818315e-01,0.62349,5.000000e-01,0.866025
...,...,...,...,...,...,...,...,...,...,...
52577,2023-12-31 19:00:00,35090.93,34549.42,52579,-8.660254e-01,5.000000e-01,-2.449294e-16,1.00000,-2.449294e-16,1.000000
52578,2023-12-31 20:00:00,33310.94,36193.59,52580,-7.071068e-01,7.071068e-01,-2.449294e-16,1.00000,-2.449294e-16,1.000000
52579,2023-12-31 21:00:00,32083.96,35090.93,52581,-5.000000e-01,8.660254e-01,-2.449294e-16,1.00000,-2.449294e-16,1.000000
52580,2023-12-31 22:00:00,30469.49,33310.94,52582,-2.588190e-01,9.659258e-01,-2.449294e-16,1.00000,-2.449294e-16,1.000000


## Data prep: Getting input & output sequences

In [5]:
# Known (lagged by 2) consumption at each row T
past_target = df["consumption_lag2"].to_numpy()

In [6]:
# Trend & seasonality features at each row T (everything except time and the two consumption columns)
non_covariate_columns = ["time", "consumption_MWh", "consumption_lag2"]
historic_covars = df.drop(columns = non_covariate_columns).to_numpy()

In [7]:
# Trend & seasonality features of the next row: row T holds the covariates of T+1 (last row becomes NaN)
dropped_columns = ["time", "consumption_MWh", "consumption_lag2"]
future_covars = df.drop(columns = dropped_columns).shift(-1).to_numpy()

In [8]:
# Consumption of the next row: row T holds consumption_MWh at T+1 (last row becomes NaN)
future_target = df["consumption_MWh"].shift(-1).to_numpy()

In [9]:
# past_target[i] is the consumption_lag2 value at row i (time T)
past_target

array([27412.81, 26324.39, 24635.32, ..., 35090.93, 33310.94, 32083.96])

In [10]:
# future_target[i] is consumption_MWh at row i + 1 (time T+1), the value to predict from row i.
# The last entry is NaN because of shift(-1).
future_target

array([23872.12, 23194.89, 23071.96, ..., 30469.49, 30029.91,      nan])

In [11]:
# Trend & seasonality features at T, one row per timestamp (historic future covariates)
historic_covars

array([[ 2.00000000e+00,  7.07106781e-01,  7.07106781e-01, ...,
         6.23489802e-01,  5.00000000e-01,  8.66025404e-01],
       [ 3.00000000e+00,  8.66025404e-01,  5.00000000e-01, ...,
         6.23489802e-01,  5.00000000e-01,  8.66025404e-01],
       [ 4.00000000e+00,  9.65925826e-01,  2.58819045e-01, ...,
         6.23489802e-01,  5.00000000e-01,  8.66025404e-01],
       ...,
       [ 5.25810000e+04, -5.00000000e-01,  8.66025404e-01, ...,
         1.00000000e+00, -2.44929360e-16,  1.00000000e+00],
       [ 5.25820000e+04, -2.58819045e-01,  9.65925826e-01, ...,
         1.00000000e+00, -2.44929360e-16,  1.00000000e+00],
       [ 5.25830000e+04, -2.44929360e-16,  1.00000000e+00, ...,
         1.00000000e+00, -2.44929360e-16,  1.00000000e+00]])

In [12]:
# Trend & seasonality features at T+1 (future covariates); the last row is NaN because of shift(-1)
future_covars

array([[ 3.00000000e+00,  8.66025404e-01,  5.00000000e-01, ...,
         6.23489802e-01,  5.00000000e-01,  8.66025404e-01],
       [ 4.00000000e+00,  9.65925826e-01,  2.58819045e-01, ...,
         6.23489802e-01,  5.00000000e-01,  8.66025404e-01],
       [ 5.00000000e+00,  1.00000000e+00,  6.12323400e-17, ...,
         6.23489802e-01,  5.00000000e-01,  8.66025404e-01],
       ...,
       [ 5.25820000e+04, -2.58819045e-01,  9.65925826e-01, ...,
         1.00000000e+00, -2.44929360e-16,  1.00000000e+00],
       [ 5.25830000e+04, -2.44929360e-16,  1.00000000e+00, ...,
         1.00000000e+00, -2.44929360e-16,  1.00000000e+00],
       [            nan,             nan,             nan, ...,
                    nan,             nan,             nan]])

In [13]:
# Drop the last row: its T+1 target and covariates are unknown (NaN after shift(-1)).
# Slicing to a fixed length derived from df keeps this cell idempotent: re-running it
# does not remove one more valid row each time, as `x = x[:-1]` would.
n_valid_rows = len(df) - 1
past_target = past_target[:n_valid_rows]
future_target = future_target[:n_valid_rows]
historic_covars = historic_covars[:n_valid_rows, :]
future_covars = future_covars[:n_valid_rows, :]

In [14]:
# Check shapes: all four arrays should have the same number of rows
shape_report = (
    ("Past target shape: ", past_target),
    ("Historic future covariates shape: ", historic_covars),
    ("Future target shape: ", future_target),
    ("Future covariates shape: ", future_covars),
)
for label, array in shape_report:
    print(f"{label}{array.shape}")

Past target shape: (52581,)
Historic future covariates shape: (52581, 7)
Future target shape: (52581,)
Future covariates shape: (52581, 7)


In [15]:
# Get shifted datasets: the target column followed by the covariate columns.
# df_past pairs consumption_lag2 at T with the covariates at T;
# df_future pairs consumption_MWh at T+1 with the covariates at T+1.
sequence_columns = df.columns.values[2:]
df_past = pd.DataFrame(np.column_stack((past_target, historic_covars)), columns = sequence_columns)
df_future = pd.DataFrame(np.column_stack((future_target, future_covars)), columns = sequence_columns)
df_future = df_future.rename(columns = {"consumption_lag2": "consumption_MWh"})

In [16]:
# consumption_lag2 and covariates at T
df_past

Unnamed: 0,consumption_lag2,trend,hour_sin,hour_cos,day_sin,day_cos,month_sin,month_cos
0,27412.81,2.0,0.707107,7.071068e-01,7.818315e-01,0.62349,5.000000e-01,0.866025
1,26324.39,3.0,0.866025,5.000000e-01,7.818315e-01,0.62349,5.000000e-01,0.866025
2,24635.32,4.0,0.965926,2.588190e-01,7.818315e-01,0.62349,5.000000e-01,0.866025
3,23872.12,5.0,1.000000,6.123234e-17,7.818315e-01,0.62349,5.000000e-01,0.866025
4,23194.89,6.0,0.965926,-2.588190e-01,7.818315e-01,0.62349,5.000000e-01,0.866025
...,...,...,...,...,...,...,...,...
52576,32670.06,52578.0,-0.965926,2.588190e-01,-2.449294e-16,1.00000,-2.449294e-16,1.000000
52577,34549.42,52579.0,-0.866025,5.000000e-01,-2.449294e-16,1.00000,-2.449294e-16,1.000000
52578,36193.59,52580.0,-0.707107,7.071068e-01,-2.449294e-16,1.00000,-2.449294e-16,1.000000
52579,35090.93,52581.0,-0.500000,8.660254e-01,-2.449294e-16,1.00000,-2.449294e-16,1.000000


In [17]:
# consumption_MWh and covariates at T+1
df_future

Unnamed: 0,consumption_MWh,trend,hour_sin,hour_cos,day_sin,day_cos,month_sin,month_cos
0,23872.12,3.0,8.660254e-01,5.000000e-01,7.818315e-01,0.62349,5.000000e-01,0.866025
1,23194.89,4.0,9.659258e-01,2.588190e-01,7.818315e-01,0.62349,5.000000e-01,0.866025
2,23071.96,5.0,1.000000e+00,6.123234e-17,7.818315e-01,0.62349,5.000000e-01,0.866025
3,23267.90,6.0,9.659258e-01,-2.588190e-01,7.818315e-01,0.62349,5.000000e-01,0.866025
4,23875.44,7.0,8.660254e-01,-5.000000e-01,7.818315e-01,0.62349,5.000000e-01,0.866025
...,...,...,...,...,...,...,...,...
52576,35090.93,52579.0,-8.660254e-01,5.000000e-01,-2.449294e-16,1.00000,-2.449294e-16,1.000000
52577,33310.94,52580.0,-7.071068e-01,7.071068e-01,-2.449294e-16,1.00000,-2.449294e-16,1.000000
52578,32083.96,52581.0,-5.000000e-01,8.660254e-01,-2.449294e-16,1.00000,-2.449294e-16,1.000000
52579,30469.49,52582.0,-2.588190e-01,9.659258e-01,-2.449294e-16,1.00000,-2.449294e-16,1.000000


In [18]:
# Sequence geometry
input_length = 72   # T-N to T hours as input
input_dims = 8      # Consumption lag 2, trend, 6 cyclical columns
output_length = 32  # Only T+8 to T+32 matter, but every step from T+1 is predicted because each step's hidden state feeds the next
n_steps = df_future.shape[0]

In [19]:
# Index of the first 16:00 row with a complete input window before it. The input for origin t
# covers rows [t - input_length, t), so t must be at least input_length (with t = input_length - 1
# the slice would start at -1 and come back empty). This is the first T.
first_t = df.loc[(df.time.dt.hour == 16) & (df.index >= input_length)].index.values[0]

In [20]:
# Index of the first forecast origin T (a 16:00 row)
first_t 

86

In [21]:
# Index of the last 16:00 row whose output window fits in the data. The output for origin t
# is df_future rows [t, t + output_length), and df_future is one row shorter than df (its last,
# unknown row was dropped), so t + output_length must not exceed len(df_future).
# Checking against df's own length was off by one and needed the ad-hoc [-2]; with the correct
# bound the last qualifying row is simply [-1]. This is the last T.
last_t = df.loc[(df.time.dt.hour == 16) & (df.index + output_length <= len(df_future))].index.values[-1]

In [22]:
# Index of the last forecast origin T (a 16:00 row whose output window fits in the data)
last_t 

52526

In [23]:
# One input sequence for origin first_t, built from rows [first_t - input_length, first_t):
# column 0 is consumption_lag2 from df_past; columns 1: are the covariates from df_future,
# i.e. the covariates of the following row, because df_future is shifted by one.
window = slice(first_t - input_length, first_t)
input_seq = np.hstack((
    df_past.iloc[window, [0]].values,
    df_future.iloc[window, 1:].values,
))
input_seq.shape

(72, 8)

In [24]:
# One output sequence: df_future rows [first_t, first_t + output_length), i.e. the target
# consumption_MWh at T+1 .. T+output_length together with the covariates of those same steps.
# NOTE(review): target and covariates come from the same df_future row, so the target at T+1 is
# paired with the covariates at T+1, not T+2 as the note below states. Confirm the intended alignment.
output_rows = slice(first_t, first_t + output_length)
output_seq = df_future.iloc[output_rows].values
output_seq.shape

(32, 8)

We pair every target value at T+1 with the future covariates of the target value at T+2.
\
This is because the target at T+1 and the future covariates at T+2 will be past target & future covariates in the next step. LSTMs and RNNs can only forecast 1 step at a time by their nature. 
\
For validation & prediction steps, we will replace the future targets after T+1 with predictions from the previous step, as these will be unknown values at real prediction time.
\
During training, we still use the real target values for all prediction steps "in hindsight" (teacher forcing), because feeding the model's own, initially poor, predictions back as inputs may mislead training. In a real-life scenario, we'd have the "hindsight" values available in the historic data, just like we do here.

In [25]:
# One forecast origin per day (every 24 hourly rows) from first_t through last_t inclusive
n_sequences = 1 + (last_t - first_t) // 24
print(f"Number of possible sequences: {n_sequences}")

Number of possible sequences: 2186


In [26]:
# Get all sequences: one input/output pair per daily 16:00 origin t (every 24 hourly rows).
# Windows are collected in lists and stacked once at the end. Growing the arrays with
# np.concatenate inside the loop copied all previous sequences on every iteration (quadratic).
input_windows = []
output_windows = []
for t in range(first_t, last_t + 1, 24):

    # Input: consumption_lag2 (df_past) + covariates (df_future) over rows [t - input_length, t)
    input_windows.append(np.concatenate((
        df_past.iloc[(t - input_length):t, 0].values.reshape(-1, 1),
        df_future.iloc[(t - input_length):t, 1:].values
    ), axis = 1))

    # Output: target + covariates (df_future) over rows [t, t + output_length)
    output_windows.append(df_future.iloc[t:(t + output_length), :].values)

# np.stack also fails loudly if any window came out with the wrong length
input_sequences = np.stack(input_windows)    # (n_sequences, input_length, input_dims)
output_sequences = np.stack(output_windows)  # (n_sequences, output_length, input_dims)


In [27]:
# Expected: (n_sequences, input_length, input_dims)
input_sequences.shape

(2186, 72, 8)

In [28]:
# Expected: (n_sequences, output_length, input_dims)
output_sequences.shape

(2186, 32, 8)

In [29]:
# Should be the last past target of the first input sequence: row (first_t - 1) in df_past,
# because the input window ends just before first_t.
# The same value is the future target at row (first_t - 4) in df_future: consumption_lag2 at
# row i is consumption_MWh at row i - 2, and df_future row j holds consumption_MWh at row j + 1.
input_sequences[0, -1, 0] 

39635.29

In [30]:
# Past targets around the end of the first input window: rows first_t - 4 .. first_t + 2
# (computed from first_t instead of hard-coding 82:89, which only matched first_t == 86)
df_past.iloc[(first_t - 4):(first_t + 3), 0]

82    40593.83
83    40955.07
84    39505.55
85    39635.29
86    39952.75
87    39649.45
88    40063.17
Name: consumption_lag2, dtype: float64

In [31]:
# Should be the first future target of the first output sequence: row (first_t) in df_future,
# since the output window starts at first_t.
output_sequences[0, 0, 0] 

40487.65

In [32]:
# Future targets around the start of the first output window: rows first_t - 4 .. first_t + 2
# (computed from first_t instead of hard-coding 82:89, which only matched first_t == 86)
df_future.iloc[(first_t - 4):(first_t + 3), 0]

82    39635.29
83    39952.75
84    39649.45
85    40063.17
86    40487.65
87    39936.25
88    38772.68
Name: consumption_MWh, dtype: float64

Sequencing seems successful.

## Preprocessing, Torch datasets & dataloaders

In [33]:
# Get indices for a chronological train - val - test split (~60% / 20% / remainder)
n_total_sequences = input_sequences.shape[0]
sixty_percent = int(n_total_sequences * 0.6)
twenty_percent = int(n_total_sequences * 0.2)
train_end, val_end = sixty_percent, sixty_percent + twenty_percent

In [34]:
# Perform train - val - test split; sequences stay in chronological order
train_part = slice(0, train_end)   # Training data at validation step
train_full = slice(0, val_end)     # Training data at testing step (train + val)
val_part = slice(train_end, val_end)
test_part = slice(val_end, None)

tr_input, tr_output = input_sequences[train_part], output_sequences[train_part]
train_input, train_output = input_sequences[train_full], output_sequences[train_full]
val_input, val_output = input_sequences[val_part], output_sequences[val_part]
test_input, test_output = input_sequences[test_part], output_sequences[test_part]

We have to scale the past consumption & trend values in the input sequences, and the future consumption & trend values in the output sequences, because they'll be the past values as the forecast horizon expands.
\
We also need to back-transform the network's final predictions with the same fitted parameters, so the scaler is a class that stores the minima & maxima from fitting rather than a stateless function.

In [59]:
# Define scaling class for sequence data
class sequence_scaler:
    """Min-max scaler for sequence arrays whose last axis holds the features.

    Each feature is scaled independently into ``feature_range``. The minimum and maximum are
    taken over BOTH the input and the output sequences passed to ``fit``, so values that move
    from the output window into the input window during forecasting share one scale.
    """

    def __init__(self, feature_range = (-1, 1)):
        self.lower = feature_range[0]
        self.upper = feature_range[1]

    def fit(self, input_seq, output_seq):
        """Record per-feature minima & maxima over input_seq and output_seq.

        Args:
            input_seq: array of shape (n_sequences, input_length, n_features).
            output_seq: array of shape (n_sequences, output_length, n_features).
        Returns:
            self, so calls can be chained.
        """
        input_seq = np.asarray(input_seq)
        output_seq = np.asarray(output_seq)

        # Get number of features
        self.num_features = input_seq.shape[2]
        n = self.num_features

        # Reduce over the sequence & time axes, then combine both arrays feature by feature
        mins = np.minimum(
            input_seq[:, :, :n].min(axis = (0, 1)),
            output_seq[:, :, :n].min(axis = (0, 1))
        )
        maxs = np.maximum(
            input_seq[:, :, :n].max(axis = (0, 1)),
            output_seq[:, :, :n].max(axis = (0, 1))
        )

        self.feature_mini = mins.tolist()
        self.feature_maxi = maxs.tolist()
        return self

    def _scale_params(self):
        """Return (minima, ranges) as arrays, with zero ranges replaced by 1."""
        mini = np.asarray(self.feature_mini, dtype = float)
        span = np.asarray(self.feature_maxi, dtype = float) - mini
        # A constant feature would otherwise yield 0 / 0 = NaN; it is mapped to the lower bound instead
        span = np.where(span == 0, 1.0, span)
        return mini, span

    def transform(self, sequences):
        """Scale an array of shape (..., num_features) into feature_range.

        Extra trailing features beyond num_features are ignored.
        """
        values = np.asarray(sequences)[..., :self.num_features]
        mini, span = self._scale_params()
        std = (values - mini) / span
        return std * (self.upper - self.lower) + self.lower

    def backtransform(self, sequences, feature = None):
        """Invert transform, mapping scaled values back to the original units.

        Args:
            sequences: scaled values (NumPy array or CPU tensor).
            feature: None to invert an array whose last axis holds all num_features features
                (as returned by transform). Alternatively, an int: every value is then treated
                as that single feature. For example, feature = 0 back-transforms consumption
                predictions of any shape, such as (N, output_length, n_quantiles).
        Returns:
            Array of the same shape as sequences.
        """
        values = np.asarray(sequences, dtype = float)
        mini, span = self._scale_params()
        if feature is not None:
            mini, span = mini[feature], span[feature]
        std = (values - self.lower) / (self.upper - self.lower)
        return std * span + mini

In [60]:
# Scale validation data: the scaler is fitted on the training portion only (tr_input, tr_output)
scaler_val = sequence_scaler(feature_range = (-1, 1))
_ = scaler_val.fit(tr_input, tr_output)

In [66]:
# Apply the fitted min-max scaling (default feature_range of (-1, 1)) to the training inputs
scaler_val.transform(tr_input)

array([-3.08952997e-01, -1.00000000e+00, -1.00000000e+00, -2.22044605e-16,
        8.01937736e-01,  6.03875472e-01,  5.00000000e-01,  8.66025404e-01])

In [None]:
# Define Torch dataset class
class SequenceDataset(torch.utils.data.Dataset):
    """Map-style dataset of paired (input sequence, output sequence) float32 tensors.

    Args:
        input_seq: array-like, expected shape (n_sequences, input_length, n_features).
        output_seq: array-like, expected shape (n_sequences, output_length, n_features).
    Raises:
        ValueError: if input_seq and output_seq hold a different number of sequences.
    """

    # Store preprocessed input & output sequences (copied into float32 tensors)
    def __init__(self, input_seq, output_seq):
        self.input_seq = torch.tensor(input_seq, dtype = torch.float32)
        self.output_seq = torch.tensor(output_seq, dtype = torch.float32)
        # __len__ uses the inputs only, so a shorter output array would otherwise fail much later, inside a DataLoader
        if self.input_seq.shape[0] != self.output_seq.shape[0]:
            raise ValueError(
                "input_seq and output_seq must contain the same number of sequences, "
                f"got {self.input_seq.shape[0]} and {self.output_seq.shape[0]}"
            )

    # Return data length
    def __len__(self):
        return len(self.input_seq)

    # Return a pair of input & output sequences
    def __getitem__(self, idx):
        return self.input_seq[idx], self.output_seq[idx]

## Model testing

In [None]:
# Define model class
class StatefulQuantileLSTM(L.LightningModule):
    """LSTM that forecasts ``output_length`` steps one at a time, with one output per quantile.

    Each forecast step runs the current input window (N, input_length, input_size) through the
    LSTM, starting from the hidden & cell states left by the previous step. It maps the LSTM
    output at the last time step to ``len(quantiles)`` predictions. Then the window is shifted
    by one row and the next row of the output sequence is appended:

    * training: with the real target in column 0 ("in hindsight", teacher forcing);
    * validation / prediction: with column 0 replaced by the model's own forecast for that step
      (the quantile closest to 0.5), because real targets are unknown at prediction time.

    Within a training epoch, the final LSTM states of one batch initialise the next batch. The
    states of the last training batch initialise validation & prediction. States are stored per
    sample, so they are only reused when the batch size matches.

    hyperparams_dict keys: input_length, output_length, input_size, horizon, quantiles,
    learning_rate, lr_decay, num_layers, hidden_size, dropout_rate.
    """

    # Initialize model
    def __init__(self, hyperparams_dict):

        # Delegate function to parent class
        super().__init__()

        # Save external hyperparameters so they are available when loading saved models
        self.save_hyperparameters(logger = False)

        # Define hyperparameters
        self.input_length = hyperparams_dict["input_length"] # Length of input sequence (not used by the computation itself)
        self.output_length = hyperparams_dict["output_length"] # Number of forecast steps; step 0 is T+1
        self.input_size = hyperparams_dict["input_size"] # Number of features (network inputs)
        self.horizon = hyperparams_dict["horizon"] # 0-based index of the first forecast step included in the loss
        self.quantiles = hyperparams_dict["quantiles"] # Provide as list of floats: [0.025, 0.5, 0.975]
        self.learning_rate = hyperparams_dict["learning_rate"]
        self.lr_decay = hyperparams_dict["lr_decay"] # gamma of the ExponentialLR scheduler
        self.num_layers = hyperparams_dict["num_layers"] # Number of layers in the LSTM block
        self.hidden_size = hyperparams_dict["hidden_size"] # Number of units in each LSTM block = LSTM block output size
        self.dropout_rate = hyperparams_dict["dropout_rate"]

        # Index of the quantile closest to 0.5; its prediction is fed back as the target during autoregressive rollout
        self._median_idx = min(range(len(self.quantiles)), key = lambda i: abs(self.quantiles[i] - 0.5))

        # Define architecture

        # LSTM: input (N, L, input_size), states ((num_layers, N, hidden_size), (num_layers, N, hidden_size))
        # -> output (N, L, hidden_size) plus the last hidden & cell states of every layer.
        # torch applies dropout only between stacked layers, so it is disabled for a single layer.
        self.lstm = torch.nn.LSTM(
            input_size = self.input_size,
            hidden_size = self.hidden_size,
            num_layers = self.num_layers,
            dropout = self.dropout_rate if self.num_layers > 1 else 0.0,
            batch_first = True
        )

        # Output layer: LSTM output at the last time step (N, hidden_size) -> (N, n_quantiles).
        # QuantileLoss reads prediction i for quantile i, so one output unit per quantile is needed.
        self.output_layer = torch.nn.Linear(
            in_features = self.hidden_size,
            out_features = len(self.quantiles)
        )

        # Loss function: Quantile loss
        self.loss = QuantileLoss(quantiles = self.quantiles)

        # Carried LSTM states: "last" links consecutive training batches within an epoch,
        # "final" holds the last training states for validation & prediction. Initialised here
        # so the first training batch and Lightning's sanity-check validation can test for None.
        self._last_hiddens_train = None
        self._last_cells_train = None
        self._final_hiddens_train = None
        self._final_cells_train = None

    # Define forward propagation
    def forward(self, input_chunk, prev_states = None):
        """Run one forecast step.

        Args:
            input_chunk: (N, L, input_size) input window.
            prev_states: optional (hidden, cell) tuple, each (num_layers, N, hidden_size);
                the LSTM uses zeros when None.
        Returns:
            (last_hidden_states, last_cell_states, preds); preds has shape (N, n_quantiles) and
            is computed from the LSTM output at the last time step of the window.
        """
        if prev_states is None:
            lstm_output, (last_hidden_states, last_cell_states) = self.lstm(input_chunk)
        else:
            lstm_output, (last_hidden_states, last_cell_states) = self.lstm(input_chunk, prev_states)

        # Only the last time step is the forecast for this step
        preds = self.output_layer(lstm_output[:, -1, :])

        return last_hidden_states, last_cell_states, preds

    @staticmethod
    def _usable_states(hiddens, cells, batch_size):
        """Return (hiddens, cells) if they can initialise a batch of batch_size sequences, else None."""
        if hiddens is None or cells is None:
            return None
        # States hold one slot per sample (dim 1), e.g. a smaller last batch cannot reuse them
        if hiddens.shape[1] != batch_size:
            return None
        return hiddens, cells

    def _rollout(self, input_sequences, output_sequences, init_states = None, teacher_forcing = True):
        """Forecast all output_length steps one at a time.

        Args:
            input_sequences: (N, input_length, input_size) inputs of step 0.
            output_sequences: (N, output_length, input_size). Row h holds the target (column 0)
                and covariates of step h; row h - 1 is appended to the window before step h.
            init_states: optional (hidden, cell) tuple for step 0.
            teacher_forcing: if False, column 0 of each appended row is replaced by the previous
                step's median-quantile prediction, so real targets are never fed to the model.
        Returns:
            preds of shape (N, output_length, n_quantiles), and the (hidden, cell) after the last step.
        """
        input_seq = input_sequences
        states = init_states
        step_preds = []

        for h in range(self.output_length):
            if h > 0:
                # Slicing h-1:h keeps the time axis: (N, 1, input_size)
                next_row = output_sequences[:, h - 1:h, :]
                if not teacher_forcing:
                    fed_back = step_preds[-1][:, self._median_idx].reshape(-1, 1, 1)
                    next_row = torch.cat((fed_back, next_row[:, :, 1:]), dim = 2)
                # Drop the oldest row & append the new one, so the window length stays constant
                input_seq = torch.cat((input_seq[:, 1:, :], next_row), dim = 1)

            last_hidden_states, last_cell_states, preds = self.forward(input_seq, prev_states = states)
            states = (last_hidden_states, last_cell_states)
            step_preds.append(preds)

        return torch.stack(step_preds, dim = 1), states

    def _horizon_loss(self, preds, output_sequences):
        """Mean quantile loss over forecast steps horizon .. output_length - 1."""
        # QuantileLoss.loss is unreduced: (N, steps, n_quantiles) for a (N, steps) target.
        # .mean() gives the scalar needed by self.log and backprop.
        return self.loss.loss(
            preds[:, self.horizon:, :],
            output_sequences[:, self.horizon:, 0]
        ).mean()

    # Define training step
    def training_step(self, batch, batch_idx):
        input_sequences, output_sequences = batch

        # Continue from the states of the previous batch in this epoch, if compatible
        init_states = self._usable_states(
            self._last_hiddens_train, self._last_cells_train, input_sequences.shape[0]
        )

        # Training uses the real target values at every step ("in hindsight")
        preds, (last_hiddens, last_cells) = self._rollout(
            input_sequences, output_sequences, init_states, teacher_forcing = True
        )
        loss = self._horizon_loss(preds, output_sequences)

        # Log the training loss
        self.log("train_loss", loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)

        # Detach before storing. Otherwise the next batch's backward pass would reach into this
        # batch's graph, which is freed after this batch's own backward pass.
        last_hiddens, last_cells = last_hiddens.detach(), last_cells.detach()

        # Within-epoch states (flushed at epoch end) & final states (kept for inference)
        self._last_hiddens_train, self._last_cells_train = last_hiddens, last_cells
        self._final_hiddens_train, self._final_cells_train = last_hiddens, last_cells

        return loss

    # When a training epoch ends, flush the last hidden & cell states.
    # Final hidden & cell states remain for inference.
    def on_train_epoch_end(self):
        self._last_hiddens_train = None
        self._last_cells_train = None

    # Method to flush the final hidden & cell states left from training, if desired
    def reset_context(self):
        self._final_hiddens_train = None
        self._final_cells_train = None

    # Define validation_step
    def validation_step(self, batch, batch_idx):
        input_sequences, output_sequences = batch

        # Start from the final training states, if compatible with this batch
        init_states = self._usable_states(
            self._final_hiddens_train, self._final_cells_train, input_sequences.shape[0]
        )

        # Targets after T+1 are unknown at prediction time, so the model's own forecasts are fed back
        preds, _ = self._rollout(input_sequences, output_sequences, init_states, teacher_forcing = False)
        loss = self._horizon_loss(preds, output_sequences)

        # Log the val. loss
        self.log("val_loss", loss, on_step = True, on_epoch = True, prog_bar = True, logger = True)

        return loss

    # Define prediction_step
    def predict_step(self, batch, batch_idx):
        """Autoregressive forecast for a batch of (input_sequences, output_sequences).

        Only the covariates (columns 1:) of output_sequences are used; the target column is
        replaced by the model's own forecasts. Returns preds of shape
        (N, output_length, n_quantiles), in the same (possibly scaled) units as the training targets.
        """
        input_sequences, output_sequences = batch[0], batch[1]
        init_states = self._usable_states(
            self._final_hiddens_train, self._final_cells_train, input_sequences.shape[0]
        )
        preds, _ = self._rollout(input_sequences, output_sequences, init_states, teacher_forcing = False)
        return preds

    # Define optimizer & learning rate scheduler
    def configure_optimizers(self):

        # Adam optimizer
        optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)

        # Exponential LR scheduler
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
          optimizer, gamma = self.lr_decay)

        return {
        "optimizer": optimizer,
        "lr_scheduler": {
          "scheduler": lr_scheduler
          }
        }
        