# Foundation Models for Time Series Forecasting: Lag-Llama

This SageMaker Studio notebook has been tested with the following configuration:

* Kernel: Python 3
* Image: Data Science 3.0
* Instance Type: ml.g5.2xlarge (8 vCPU + 1 GPU + 32 GiB)
* Start-up Script: No script

This notebook consists of the following sections:

1. Introduction
2. Prepare Model and Libraries
3. Zero-Shot with Lag-Llama Foundation Model
4. Comparison to GluonTS SimpleFeedForwardEstimator
5. Fine Tune Lag-Llama on Sample Data
6. Evaluate Fine-Tuned Lag-Llama
7. Summary/Conclusions

## 1. Introduction

This is a sample notebook for educational purposes.

Time series analysis is important in many real-world industries and applications. Historically, numerous statistical and machine learning methods have been developed for time series analysis, and libraries such as [GluonTS](https://ts.gluon.ai/stable/) or [Darts](https://unit8co.github.io/darts/) have been used for tasks such as time series forecasting, classification, imputation, anomaly detection, and event prediction.

Time series data comes in many forms and from many industries, such as healthcare, finance, and retail. This variety increases the complexity of domain-specific model training. In addition, real-world time series are often non-stationary, meaning that the statistical characteristics of the data change over time. This can lead to the concept drift problem, where the patterns and relationships that the ML model has learned no longer hold because the statistical properties of the target have changed.
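
As a rough illustration of non-stationarity, the sketch below builds a synthetic series with an upward trend and compares the mean of its first and second halves. The data is made up (it is not the electricity dataset used later), and the helper name `half_means` is hypothetical.

```python
from statistics import mean


def half_means(series):
    """Return the mean of the first half and of the second half of a series."""
    mid = len(series) // 2
    return mean(series[:mid]), mean(series[mid:])


# Synthetic, made-up series: a linear trend plus a short repeating pattern.
series = [0.5 * t + (t % 4) for t in range(40)]

first, second = half_means(series)
print(f"first-half mean: {first:.2f}, second-half mean: {second:.2f}")
```

A model fitted only on the first half would see a level that no longer matches the second half, which is the kind of shift that concept drift refers to.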

There is a growing trend toward using Foundation Models (FMs) and Large Language Models (LLMs) in time series applications, with several papers published in late 2023 and early 2024. One benefit of foundation models is that they provide a single framework for handling diverse tasks, in contrast to the conventional approach in which each task requires a specially designed algorithm. In this notebook, we look at [Lag-Llama (October 2023 paper)](https://arxiv.org/abs/2310.08278), a decoder-only transformer model designed for zero-shot probabilistic time series forecasting.

In this notebook, we test a zero-shot Lag-Llama, compare it against a SimpleFeedForwardEstimator, and finally evaluate a fine-tuned Lag-Llama model.

## 2. Prepare Model and Libraries

In [None]:
%cd ~

!git clone https://github.com/time-series-foundation-models/lag-llama/

!pip install -r lag-llama/requirements.txt --quiet
!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir lag-llama

%cd lag-llama

In [None]:
import torch

from lag_llama.gluon.estimator import LagLlamaEstimator

from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.mx import SimpleFeedForwardEstimator, Trainer

from gluonts.dataset.repository.datasets import get_dataset

from tqdm.autonotebook import tqdm

from matplotlib import pyplot as plt
from matplotlib import dates as mpld

from itertools import islice

GluonTS comes with a number of publicly available datasets, one of which is the *electricity* dataset.

In [None]:
dataset = get_dataset("electricity")

## 3. Zero-Shot with Lag-Llama Foundation Model

In [None]:
ckpt = torch.load("lag-llama.ckpt", map_location=torch.device('cuda'))
estimator_args = ckpt["hyper_parameters"]["model_kwargs"]

In [None]:
estimator = LagLlamaEstimator(
    ckpt_path="lag-llama.ckpt",
    prediction_length=dataset.metadata.prediction_length,
    
    #
    # This is specific to Lag-Llama due to training;
    # do not change this value.
    #
    context_length=32,

    #
    # estimator args
    #
    input_size=estimator_args["input_size"],
    n_layer=estimator_args["n_layer"],
    n_embd_per_head=estimator_args["n_embd_per_head"],
    n_head=estimator_args["n_head"],
    scaling=estimator_args["scaling"],
    time_feat=estimator_args["time_feat"],
)

predictor = estimator.create_predictor(
    estimator.create_transformation(),
    estimator.create_lightning_module()
)

In [None]:
#
# Setup for zero-shot inference.
#
forecast_it, ts_it = make_evaluation_predictions(
    dataset=dataset.test,
    predictor=predictor,
)

In [None]:
forecasts = list(forecast_it)
tss = list(ts_it)

In [None]:
#
# The Continuous Ranked Probability Score (CRPS) generalizes the MAE to probabilistic forecasts.
#   CRPS = 0 means the forecast is perfect; larger values are worse, and there is no fixed upper bound.
#   GluonTS reports an approximation of the CRPS as 'mean_wQuantileLoss'.
#
evaluator = Evaluator()

agg_metrics, ts_metrics = evaluator(
    iter(tss), iter(forecasts)
)

print("CRPS:", agg_metrics['mean_wQuantileLoss'])
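
To make the reported metric more concrete, here is a minimal sketch of a mean weighted quantile loss. It assumes the metric is the average, over a set of quantile levels, of the pinball (quantile) loss normalised by the sum of absolute target values. The function names and numbers are made up for illustration; the exact definition used by the GluonTS `Evaluator` should be checked against its documentation.

```python
def quantile_loss(targets, quantile_forecasts, q):
    """Pinball loss at level q, summed over time steps (with a factor of 2)."""
    return 2.0 * sum(
        abs((y - f) * ((1.0 if y <= f else 0.0) - q))
        for y, f in zip(targets, quantile_forecasts)
    )


def mean_weighted_quantile_loss(targets, forecasts_by_quantile):
    """Average over quantile levels of quantile loss divided by sum |target|."""
    abs_target_sum = sum(abs(y) for y in targets)
    losses = [
        quantile_loss(targets, f, q) / abs_target_sum
        for q, f in forecasts_by_quantile.items()
    ]
    return sum(losses) / len(losses)


# Made-up numbers for illustration only.
targets = [10.0, 12.0, 11.0]
perfect = {q: targets for q in (0.1, 0.5, 0.9)}
biased = {q: [y + 20.0 for y in targets] for q in (0.1, 0.5, 0.9)}

loss_perfect = mean_weighted_quantile_loss(targets, perfect)
loss_biased = mean_weighted_quantile_loss(targets, biased)
print(loss_perfect, loss_biased)
```

In this toy case the perfect forecast scores zero, while the heavily biased forecast scores above 1, which shows that the value is not capped at 1.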

In [None]:
plt.figure(figsize=(20, 15))
date_formater = mpld.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})

for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 9):
    ax = plt.subplot(3, 3, idx+1)

    plt.plot(ts[-4 * dataset.metadata.prediction_length:].to_timestamp(), label="target", )
    forecast.plot( color='g')
    plt.xticks(rotation=60)
    ax.xaxis.set_major_formatter(date_formater)
    ax.set_title(forecast.item_id)

plt.gcf().tight_layout()
plt.legend()
plt.show()

## 4. Comparison to GluonTS SimpleFeedForwardEstimator

In [None]:
estimator = SimpleFeedForwardEstimator(
    num_hidden_dimensions=[10],
    prediction_length=dataset.metadata.prediction_length,
    context_length=100,
    trainer=Trainer(ctx="cpu", epochs=5, learning_rate=1e-3, num_batches_per_epoch=100),
)

predictor = estimator.train(dataset.train)

forecast_it, ts_it = make_evaluation_predictions(
    dataset=dataset.test,
    predictor=predictor,
    num_samples=100,
)

forecasts = list(forecast_it)
tss = list(ts_it)

In [None]:
#
# first entry of the forecast list
#
forecast_entry = forecasts[0]

print(f"Number of sample paths: {forecast_entry.num_samples}")
print(f"Dimension of samples: {forecast_entry.samples.shape}")
print(f"Start date of the forecast window: {forecast_entry.start_date}")
print(f"Frequency of the time series: {forecast_entry.freq}")

print(f"Mean of the future window:\n {forecast_entry.mean}")
print(f"0.5-quantile (median) of the future window:\n {forecast_entry.quantile(0.5)}")
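
The sample-based forecast inspected above holds one row per sample path and one column per forecast step, and point summaries such as the mean or median are taken across paths at each step. The following is a minimal sketch of that idea with made-up numbers, written in plain Python rather than with the GluonTS API.

```python
from statistics import mean, median

# Made-up sample paths: 5 paths, each covering a 3-step forecast window.
samples = [
    [10.0, 11.0, 12.0],
    [12.0, 13.0, 11.0],
    [11.0, 12.0, 13.0],
    [9.0, 10.0, 14.0],
    [13.0, 14.0, 10.0],
]

# Transpose so that each entry holds all sampled values for one time step.
per_step = list(zip(*samples))

mean_forecast = [mean(step) for step in per_step]
median_forecast = [median(step) for step in per_step]

print("mean per step:  ", mean_forecast)
print("median per step:", median_forecast)
```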

In [None]:
evaluator = Evaluator()

agg_metrics, ts_metrics = evaluator(
    iter(tss), iter(forecasts)
)

print("CRPS:", agg_metrics['mean_wQuantileLoss'])

## 5. Fine Tune Lag-Llama on Sample Data

In [None]:
ckpt = torch.load("lag-llama.ckpt", map_location=torch.device('cuda'))
estimator_args = ckpt["hyper_parameters"]["model_kwargs"]

estimator = LagLlamaEstimator(
    ckpt_path="lag-llama.ckpt",
    prediction_length=dataset.metadata.prediction_length,
    
    #
    # This is specific to Lag-Llama due to training;
    # do not change this value.
    #
    context_length=32,

    #
    # estimator args
    #
    input_size=estimator_args["input_size"],
    n_layer=estimator_args["n_layer"],
    n_embd_per_head=estimator_args["n_embd_per_head"],
    n_head=estimator_args["n_head"],
    scaling=estimator_args["scaling"],
    time_feat=estimator_args["time_feat"],
    
    nonnegative_pred_samples=True,
    aug_prob=0,
    lr=5e-4,

    batch_size=64,
    num_parallel_samples=20,
    trainer_kwargs={"max_epochs": 50},
)

In [None]:
predictor = estimator.train(
    dataset.train,
    cache_data=True,
    shuffle_buffer_length=1000
)

## 6. Evaluate Fine-Tuned Lag-Llama

In [None]:
forecast_it, ts_it = make_evaluation_predictions(
    dataset=dataset.test,
    predictor=predictor,
    num_samples=20
)

In [None]:
# Use the number of test series as the progress-bar total.
forecasts = list(tqdm(forecast_it, total=len(dataset.test), desc="Forecasting Batches"))

In [None]:
# Use the number of test series as the progress-bar total.
tss = list(tqdm(ts_it, total=len(dataset.test), desc="Ground Truth"))

In [None]:
plt.figure(figsize=(20, 15))
date_formater = mpld.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})

# Iterate through the first 9 series, and plot the predicted samples
for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 9):
    ax = plt.subplot(3, 3, idx+1)

    plt.plot(ts[-4 * dataset.metadata.prediction_length:].to_timestamp(), label="target", )
    forecast.plot( color='g')
    plt.xticks(rotation=60)
    ax.xaxis.set_major_formatter(date_formater)
    ax.set_title(forecast.item_id)

plt.gcf().tight_layout()
plt.legend()
plt.show()

In [None]:
evaluator = Evaluator()

agg_metrics, ts_metrics = evaluator(
    iter(tss), iter(forecasts)
)

print("CRPS:", agg_metrics['mean_wQuantileLoss'])

## 7. Summary/Conclusions

In one run of this notebook, we obtained the following Continuous Ranked Probability Score (CRPS) values:

* Zero-shot Lag-Llama: CRPS = 0.0489
* SimpleFeedForwardEstimator: CRPS = 0.0747
* Fine-Tuned Lag-Llama: CRPS = 0.0434

where a lower CRPS number indicates better performance.
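
To put these numbers side by side, the short sketch below computes the relative CRPS reduction of the best model against the other two. It uses only the values listed above from a single run, so the percentages are indicative rather than a general benchmark. The helper name `relative_reduction` is hypothetical.

```python
# CRPS values reported in the summary above (single run).
crps = {
    "Zero-shot Lag-Llama": 0.0489,
    "SimpleFeedForwardEstimator": 0.0747,
    "Fine-Tuned Lag-Llama": 0.0434,
}


def relative_reduction(baseline, candidate):
    """Fractional reduction in CRPS of `candidate` relative to `baseline`."""
    return (baseline - candidate) / baseline


# Lower CRPS is better, so the best model has the smallest value.
best = min(crps, key=crps.get)

for name, value in crps.items():
    if name != best:
        pct = 100 * relative_reduction(value, crps[best])
        print(f"{best} vs {name}: {pct:.1f}% lower CRPS")
```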