# Lag-Llama testing
In this notebook, we produce forecasts with Lag-Llama, both zero-shot and with the fine-tuned version. 

The model is pretty heavy and inference is slow. I recommend you run this notebook on a GPU-enabled environment, such as Colab.

This notebook relies on the demo notebooks open-sourced by the authors of Lag-Llama: [Notebook 1](https://colab.research.google.com/drive/1DRAzLUPxsd-0r8b-o4nlyFXrjw_ZajJJ?usp=sharing) and [Notebook 2](https://colab.research.google.com/drive/1uvTmh-pe1zO5TeaaRVDdoEWJ5dFDI-pA?usp=sharing)

## 0. Download and install Lag-Llama 
(If not installed already)

In [None]:
!git clone https://github.com/time-series-foundation-models/lag-llama/

Cloning into 'lag-llama'...
remote: Enumerating objects: 319, done.[K
remote: Counting objects: 100% (157/157), done.[K
remote: Compressing objects: 100% (71/71), done.[K
remote: Total 319 (delta 111), reused 105 (delta 84), pack-reused 162[K
Receiving objects: 100% (319/319), 232.35 KiB | 19.36 MiB/s, done.
Resolving deltas: 100% (152/152), done.


In [None]:
cd lag-llama

/content/lag-llama


In [None]:
!pip3 install -r requirements.txt --quiet
!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir lag-llama

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m19.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m47.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m70.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.3/12.3 MB[0m [31m67.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m75.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m778.1/778.1 kB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.3/21.3 MB[0m [31m45.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m28.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━

## 1. Imports and definitions

In [None]:
from itertools import islice
from tqdm.autonotebook import tqdm

import torch

from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.common import ListDataset

import pandas as pd
import numpy as np

from sklearn.metrics import mean_absolute_error

from utils.utils import set_seed
from torch import manual_seed

from lag_llama.gluon.estimator import LagLlamaEstimator

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import plotly.graph_objects as go
import plotly.io as pio
import plotly.offline as poff
import seaborn as sns

pio.templates.default = "seaborn"
plt.style.use('seaborn-v0_8-darkgrid')

set_seed(42)

  from tqdm.autonotebook import tqdm


In [None]:
def create_gluonts_dataset(df, freq, target_column):
    """
    Creates a GluonTS ListDataset from a pandas DataFrame.
    
    Parameters:
    df (pandas.DataFrame): DataFrame containing the time series data.
    freq (str): Frequency string of the time series (e.g., 'D' for daily, 'H' for hourly).
    target_column (str): Name of the column in df that contains the target values.
    
    Returns:
    ListDataset: A GluonTS ListDataset object containing the time series data.
    """
    series = {
        "start": df.index[0],  # start date of the time series
        "target": df[target_column].values,  # target values
    }

    dataset = ListDataset([series], freq=freq)
    return dataset

In [None]:
def get_lag_llama_predictions(dataset,
                                model_ckpt,
                                prediction_length = 24,
                                context_length=None,
                                num_samples=100,
                                device="cuda",
                                batch_size=64,
                                nonnegative_pred_samples=True,
                              ):
    
    """
    Generates predictions using the LagLlama model from a given dataset.
    
    Parameters:
    dataset (ListDataset): The dataset containing the time series data for prediction.
    model_ckpt (str): Path to the model checkpoint.
    prediction_length (int): Number of time steps to predict. Default is 24.
    context_length (int, optional): Number of time steps used for context. If None, it is set from the model checkpoint.
    num_samples (int): Number of prediction samples. Default is 100.
    device (str): Device to run the model on ('cuda' or 'cpu'). Default is 'cuda'.
    batch_size (int): Batch size for predictions. Default is 64.
    nonnegative_pred_samples (bool): Whether to ensure non-negative prediction samples. Default is True.
    
    Returns:
    tuple: A tuple containing:
        - forecasts (list): List of forecasted values.
        - tss (list): List of ground truth values.
    """

    manual_seed(42)
    _device = torch.device(device)
    _ckpt = torch.load(model_ckpt, map_location=_device)
    estimator_args = _ckpt["hyper_parameters"]["model_kwargs"]
    if context_length == None:
      context_length=estimator_args['context_length']

    estimator = LagLlamaEstimator(
        ckpt_path = model_ckpt,
        context_length=context_length,
        prediction_length=prediction_length,
        device = _device,

        # estimator args
        input_size=estimator_args["input_size"],
        n_layer=estimator_args["n_layer"],
        n_embd_per_head=estimator_args["n_embd_per_head"],
        n_head=estimator_args["n_head"],
        scaling=estimator_args["scaling"],
        time_feat=estimator_args["time_feat"],

        nonnegative_pred_samples=nonnegative_pred_samples,

        # linear positional encoding scaling
        rope_scaling={
              "type": "linear",
              "factor": max(1.0, (context_length + prediction_length) / estimator_args["context_length"]),
          },

        batch_size=batch_size,
        num_parallel_samples=num_samples,
    )

    lightning_module = estimator.create_lightning_module().to(device)
    transformation = estimator.create_transformation()
    predictor = estimator.create_predictor(transformation, lightning_module)

    forecast_it, ts_it = make_evaluation_predictions(
        dataset=dataset,
        predictor=predictor,
        num_samples=num_samples
    )
    forecasts = list(tqdm(forecast_it, total=len(dataset[0]['target']), desc="Forecasting batches"))
    tss = list(tqdm(ts_it, total=len(dataset[0]['target']), desc="Ground truth"))

    return forecasts, tss

In [None]:
def recursive_forecast(
    model_ckpt,
    context_df,
    test_df,
    prediction_length=7,
    context_length=32,
    device='cuda',
    num_samples=100):
    
    """
    Generates recursive forecasts using the LagLlama model.
    
    Parameters:
    model_ckpt (str): Path to the model checkpoint.
    context_df (pandas.DataFrame): DataFrame containing the context data for the initial forecast.
    test_df (pandas.DataFrame): DataFrame containing the test data to be used for recursive forecasting.
    prediction_length (int): Number of time steps to predict at each recursive step. Default is 7.
    context_length (int): Number of time steps used for context. Default is 32.
    device (str): Device to run the model on ('cuda' or 'cpu'). Default is 'cuda'.
    num_samples (int): Number of prediction samples. Default is 100.
    
    Returns:
    dict: A dictionary containing:
        - 'prediction' (list): List of point forecasted values.
        - 'p10' (list): List of 10th percentile forecasted values.
        - 'p90' (list): List of 90th percentile forecasted values.
    """
    
    context_gdf = create_gluonts_dataset(context_df, freq='D', target_column='wave_height')
    _device = torch.device(device)
    _ckpt = torch.load(model_ckpt, map_location=_device)
    estimator_args = _ckpt["hyper_parameters"]["model_kwargs"]

    batch_size = prediction_length  # Number of observations to add to context at each step of the forecast.
    total_length = len(test_df)  # Total length of the test DataFrame
    if context_length is None:
        context_length = estimator_args['context_length']

    all_point_forecasts = []
    q10_forecasts = []
    q90_forecasts = []

    # Append batches sequentially to the target DataFrame
    for i in range(0, total_length, batch_size):

        # Prediction based on context dataset
        forecasts, tss = get_lag_llama_predictions(
            model_ckpt=model_ckpt,
            dataset=context_gdf,
            prediction_length=prediction_length,
            num_samples=num_samples,
            context_length=context_length,
            device=_device
        )

        point_forecast = list(forecasts[0]['p50'])
        q10 = list(forecasts[0]['p10'])
        q90 = list(forecasts[0]['p90'])

        # Ensure to append only `batch_size` predictions each iteration
        all_point_forecasts.extend(point_forecast[:batch_size])
        q10_forecasts.extend(q10[:batch_size])
        q90_forecasts.extend(q90[:batch_size])

        # Update context dataset
        batch_start_index = i
        batch_end_index = min(i + batch_size, total_length)  # Ensure not to exceed the length of test_df
        batch = test_df.iloc[batch_start_index:batch_end_index]
        context_df = pd.concat([context_df, batch], ignore_index=False)
        context_gdf = create_gluonts_dataset(context_df, freq='D', target_column='wave_height')

    # Put forecast and quantiles together
    preds_dict = {
        'prediction': all_point_forecasts[:total_length],
        'p10': q10_forecasts[:total_length],
        'p90': q90_forecasts[:total_length]
    }

    return preds_dict


In [None]:
def plot_prob_forecasts(df_forecasts):
    """
    Plots real values and predicted values with confidence intervals.

    Parameters:
    df_forecasts (pd.DataFrame): DataFrame containing the real values, predicted values,
                                 and confidence intervals with columns ['wave_height', 'p10', 'p90'].
    """
    # Plot shaded area
    fig = go.Figure([
        go.Scatter(name='Actual', x=df_forecasts.index, y=df_forecasts['wave_height'], mode='lines'),
        go.Scatter(name='Prediction', x=df_forecasts.index, y=df_forecasts['prediction'], mode='lines'),
        go.Scatter(
            name='Upper Bound', x=df_forecasts.index, y=df_forecasts['p90'],
            mode='lines', marker=dict(color="#444"), line=dict(width=0), showlegend=False
        ),
        go.Scatter(
            name='Lower Bound', x=df_forecasts.index, y=df_forecasts['p10'],
            marker=dict(color="#444"), line=dict(width=0), mode='lines',
            fillcolor='rgba(68, 68, 68, 0.3)', fill='tonexty', showlegend=False
        )
    ])
    fig.update_layout(
        xaxis_title="Date",
        yaxis_title="Wave height (Meters)",
        width=800,
        height=400,
        margin=dict(l=20, r=20, t=35, b=20),
        hovermode="x",
        legend=dict(
            orientation="h",
            yanchor="top",
            y=1.1,
            xanchor="left",
            x=0.001
        )
    )
    fig.show()

In [None]:
def empirical_coverage(y, lower_bound, upper_bound):
    """
    Calculates the empirical coverage of a given interval.
    
    Parameters:
    y (array-like): Array of true values.
    lower_bound (array-like): Array of lower bounds of the interval.
    upper_bound (array-like): Array of upper bounds of the interval.
    
    Returns:
    float: The empirical coverage, which is the proportion of true values that lie within the given interval.
    """
    return np.mean(np.logical_and(y >= lower_bound, y <= upper_bound))

## 2. Read and split data

In [None]:
# Read data
# ==============================================================================
data_dir = 'Data/spain/four years'

df = pd.read_csv(data_dir + '/spain_clean_daily.csv')
df['datetime'] = pd.to_datetime(df['datetime'])
df.set_index(keys = 'datetime', inplace=True)
df = df.asfreq('D')

# Train-test split
# ==============================================================================
one_month = (-1)*30 # One month
two_months = (-1)*60 # Two months

end_val = two_months + two_months + one_month
end_train = end_val + two_months + two_months + one_month

df_train = df.iloc[:end_train].copy()
df_val = df.iloc[end_train:end_val].copy()
df_test = df.iloc[end_val:].copy()

print(f"Train dates      : {df_train.index.min()} --- {df_train.index.max()}  (n={len(df_train)})")
print(f"Validation dates : {df_val.index.min()} --- {df_val.index.max()}  (n={len(df_val)})")
print(f"Test dates       : {df_test.index.min()} --- {df_test.index.max()}  (n={len(df_test)})")

# Create the GluonTS dataset
# ==============================================================================
train = create_gluonts_dataset(df_train, freq='D', target_column='wave_height')
test = create_gluonts_dataset(df_test, freq='D', target_column='wave_height')
val = create_gluonts_dataset(df_val, freq='D', target_column='wave_height')

print('Train dataset:', train)
print('Validation dataset:',val)
print('Test dataset:',test)

Train dates      : 2020-06-19 00:00:00 --- 2023-08-24 00:00:00  (n=1162)
Validation dates : 2023-08-25 00:00:00 --- 2024-01-21 00:00:00  (n=150)
Test dates       : 2024-01-22 00:00:00 --- 2024-06-19 00:00:00  (n=150)
Train dataset: [{'start': Period('2020-06-19', 'D'), 'target': array([0.87, 0.9 , 0.7 , ..., 1.09, 1.08, 0.95], dtype=float32)}]
Validation dataset: [{'start': Period('2023-08-25', 'D'), 'target': array([1.26     , 0.97     , 1.22     , 1.6      , 1.52     , 1.24     ,
       1.15     , 0.97     , 0.91     , 1.23     , 2.49     , 1.95     ,
       1.73     , 1.47     , 0.92     , 0.52     , 0.61     , 0.67     ,
       1.17     , 1.07     , 1.06     , 1.15     , 0.85     , 0.93     ,
       1.09     , 1.61     , 1.98     , 1.99     , 2.82     , 2.76     ,
       1.64     , 1.35     , 1.72     , 1.72     , 1.45     , 1.92     ,
       1.24     , 1.08     , 0.93     , 1.07     , 1.46     , 1.61     ,
       1.14     , 0.91     , 1.07     , 0.72     , 0.5      , 0.37     ,
  

## 3. Zero-shot forecasting

In this section, we generate three sets of forecasts with different context lengths: 32, 64 and 128 tokens.

In [None]:
# Forecast parameters
checkpoint = 'lag-llama/lag-llama.ckpt'
prediction_length = 7
num_samples = 150
device = torch.device('cuda')

### Context length: 32

In [None]:
context_length = 32

forecasts_dict_32 = recursive_forecast(model_ckpt = checkpoint, context_df=df_train[-context_length:], test_df = df_test, prediction_length=7, context_length = context_length)

assert len(forecasts_dict_32['prediction']) == len(df_test) == len(forecasts_dict_32['p10']) == len(forecasts_dict_32['p90'])

Forecasting batches:   0%|          | 0/48 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/48 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/55 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/55 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/62 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/62 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/69 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/69 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/76 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/76 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/83 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/83 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/90 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/90 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/97 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/97 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/104 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/104 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/111 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/111 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/118 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/118 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/125 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/125 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/132 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/132 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/139 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/139 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/146 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/146 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/153 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/153 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/160 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/160 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/167 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/167 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/174 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/174 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/181 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/181 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/188 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/188 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/195 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/195 [00:00<?, ?it/s]

In [None]:
df_forecasts_32 = df_test.copy()
df_forecasts_32['prediction'] = forecasts_dict_32['prediction']
df_forecasts_32['p10'] = forecasts_dict_32['p10']
df_forecasts_32['p90'] = forecasts_dict_32['p90']

In [None]:
plot_prob_forecasts(df_forecasts_32)


In [None]:
# MAE
# ==============================================================================
metric = mean_absolute_error(df_forecasts_32['wave_height'], df_forecasts_32['prediction'])
print(f"Backtest error (MAE): {metric}")

# Predicted interval coverage (on test data)
# ==============================================================================
coverage = empirical_coverage(
    y = df_forecasts_32['wave_height'],
    lower_bound = df_forecasts_32['p10'],
    upper_bound = df_forecasts_32['p90']
)
print(f"Predicted interval coverage: {round(100*coverage, 2)} %")

# Area of the interval
# ==============================================================================
area = (df_forecasts_32['p90'] - df_forecasts_32['p10']).sum()
print(f"Area of the interval: {round(area, 2)}")

Backtest error (MAE): 0.760708513991038
Predicted interval coverage: 70.67 %
Area of the interval: 292.1700134277344


### Context length: 64

In [None]:
context_length = 64

forecasts_dict_64 = recursive_forecast(model_ckpt = checkpoint, context_df=df_train[-context_length:], test_df = df_test, prediction_length=7, context_length = context_length)

assert len(forecasts_dict_64['prediction']) == len(df_test) == len(forecasts_dict_64['p10']) == len(forecasts_dict_64['p90'])

Forecasting batches:   0%|          | 0/48 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/48 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/55 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/55 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/62 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/62 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/69 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/69 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/76 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/76 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/83 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/83 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/90 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/90 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/97 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/97 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/104 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/104 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/111 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/111 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/118 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/118 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/125 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/125 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/132 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/132 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/139 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/139 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/146 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/146 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/153 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/153 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/160 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/160 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/167 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/167 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/174 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/174 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/181 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/181 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/188 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/188 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/195 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/195 [00:00<?, ?it/s]

In [None]:
df_forecasts_64 = df_test.copy()
df_forecasts_64['prediction'] = forecasts_dict_64['prediction']
df_forecasts_64['p10'] = forecasts_dict_64['p10']
df_forecasts_64['p90'] = forecasts_dict_64['p90']

In [None]:
plot_prob_forecasts(df_forecasts_64)


In [None]:
# MAE
# ==============================================================================
metric = mean_absolute_error(df_forecasts_64['wave_height'], df_forecasts_64['prediction'])
print(f"Backtest error (MAE): {metric}")

# Predicted interval coverage (on test data)
# ==============================================================================
coverage = empirical_coverage(
    y = df_forecasts_64['wave_height'],
    lower_bound = df_forecasts_64['p10'],
    upper_bound = df_forecasts_64['p90']
)
print(f"Predicted interval coverage: {round(100*coverage, 2)} %")

# Area of the interval
# ==============================================================================
area = (df_forecasts_64['p90'] - df_forecasts_64['p10']).sum()
print(f"Area of the interval: {round(area, 2)}")

Backtest error (MAE): 0.7675400104840597
Predicted interval coverage: 80.67 %
Area of the interval: 361.760009765625


### Context length: 128

In [None]:
context_length = 128

forecasts_dict_128 = recursive_forecast(model_ckpt = checkpoint, context_df=df_train[-context_length:], test_df = df_test, prediction_length=7, context_length = context_length)

assert len(forecasts_dict_128['prediction']) == len(df_test) == len(forecasts_dict_128['p10']) == len(forecasts_dict_128['p90'])

Forecasting batches:   0%|          | 0/48 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/48 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/55 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/55 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/62 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/62 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/69 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/69 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/76 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/76 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/83 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/83 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/90 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/90 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/97 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/97 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/104 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/104 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/111 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/111 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/118 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/118 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/125 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/125 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/132 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/132 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/139 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/139 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/146 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/146 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/153 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/153 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/160 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/160 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/167 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/167 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/174 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/174 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/181 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/181 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/188 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/188 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/195 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/195 [00:00<?, ?it/s]

In [None]:
df_forecasts_128 = df_test.copy()
df_forecasts_128['prediction'] = forecasts_dict_128['prediction']
df_forecasts_128['p10'] = forecasts_dict_128['p10']
df_forecasts_128['p90'] = forecasts_dict_128['p90']

In [None]:
plot_prob_forecasts(df_forecasts_128)


In [None]:
# MAE
# ==============================================================================
metric = mean_absolute_error(df_forecasts_128['wave_height'], df_forecasts_128['prediction'])
print(f"Backtest error (MAE): {metric}")

# Predicted interval coverage (on test data)
# ==============================================================================
coverage = empirical_coverage(
    y = df_forecasts_128['wave_height'],
    lower_bound = df_forecasts_128['p10'],
    upper_bound = df_forecasts_128['p90']
)
print(f"Predicted interval coverage: {round(100*coverage, 2)} %")

# Area of the interval
# ==============================================================================
area = (df_forecasts_128['p90'] - df_forecasts_128['p10']).sum()
print(f"Area of the interval: {round(area, 2)}")

Backtest error (MAE): 0.8524319249471028
Predicted interval coverage: 88.67 %
Area of the interval: 477.04998779296875


## 4. Forecasting with fine-tuned model
In this section we produce forecasts with the version of the model that we fine-tuned with our data. 

In [None]:
ckpt_dir = '...'

model_6 = ckpt_dir + '/cl128_lr10e3/checkpoints/epoch=0-step=50.ckpt'

In [None]:
# Produce forecasts
context_length = 128

forecasts = recursive_forecast(model_ckpt = model_6, context_df=df_train[-context_length:], test_df = df_test, context_length=128)
df_forecasts = df_test.copy()
df_forecasts['prediction'] = forecasts['prediction']
df_forecasts['p10'] = forecasts['p10']
df_forecasts['p90'] = forecasts['p90']

Forecasting batches:   0%|          | 0/48 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/48 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/55 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/55 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/62 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/62 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/69 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/69 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/76 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/76 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/83 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/83 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/90 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/90 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/97 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/97 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/104 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/104 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/111 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/111 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/118 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/118 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/125 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/125 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/132 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/132 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/139 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/139 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/146 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/146 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/153 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/153 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/160 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/160 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/167 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/167 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/174 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/174 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/181 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/181 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/188 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/188 [00:00<?, ?it/s]

Forecasting batches:   0%|          | 0/195 [00:00<?, ?it/s]

Ground truth:   0%|          | 0/195 [00:00<?, ?it/s]

In [None]:
plot_prob_forecasts(df_forecasts)

In [None]:
# Point forecast metric - MAE
# ==============================================================================
metric = mean_absolute_error(df_forecasts['wave_height'], df_forecasts['prediction'])
print(f"Backtest error (MAE): {metric}")

# Predicted interval coverage
# ==============================================================================
coverage = empirical_coverage(
    y = df_forecasts['wave_height'],
    lower_bound = df_forecasts['p10'],
    upper_bound = df_forecasts['p90']
)
print(f"Predicted interval coverage: {round(100*coverage, 2)} %")

# Area of the interval
# ==============================================================================
area = (df_forecasts['p90'] - df_forecasts['p10']).sum()
print(f"Area of the interval: {round(area, 2)}")

Backtest error (MAE): 0.6342880547046661
Predicted interval coverage: 69.33 %
Area of the interval: 204.85000610351562
