In [1]:
from itertools import islice

from matplotlib import pyplot as plt
import matplotlib.dates as mdates
from tqdm import tqdm
import pandas as pd

import torch
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset

from lag_llama.gluon.estimator import LagLlamaEstimator

In [2]:
# Select PyTorch's 'medium' float32 matmul precision mode; per the PyTorch
# docs this permits faster, lower-precision internal math for float32 matmuls.
torch.set_float32_matmul_precision('medium')

In [3]:
# Checkpoint file handed to the Lag-Llama estimator (relative path).
ckpt_path = "lag-llama.ckpt"

# Datasets evaluated by the benchmark loops in this notebook.
# missing: Beijing PM2.5, ETT M2
test_dataset_names = [
    "weather",
    "pedestrian_counts",
    "exchange_rate",
]

In [4]:
def plot_results(forecasts, tss, prediction_length):
    """Draw up to nine forecast-vs-target panels on a 3x3 grid.

    Args:
        forecasts: Iterable of forecast objects; each must provide
            ``plot(color=...)`` and an ``item_id`` attribute.
        tss: Iterable of ground-truth series, paired positionally with
            ``forecasts``. Each is sliced and converted with
            ``to_timestamp()`` (presumably a period-indexed pandas object —
            TODO confirm).
        prediction_length: Forecast horizon; the last ``4 * prediction_length``
            observations of each series are plotted.

    Side effects:
        Sets the global matplotlib font size to 15 and shows the figure.
    """
    plt.figure(figsize=(20, 15))
    day_formatter = mdates.DateFormatter('%b, %d')
    plt.rcParams.update({'font.size': 15})

    # Only the first nine (forecast, series) pairs are drawn; panels are
    # numbered from 1 to match matplotlib's subplot indexing.
    paired = zip(forecasts, tss)
    for panel, (forecast, ts) in enumerate(islice(paired, 9), start=1):
        ax = plt.subplot(3, 3, panel)

        recent_target = ts[-4 * prediction_length:].to_timestamp()
        plt.plot(recent_target, label="target")
        forecast.plot(color='g')
        plt.xticks(rotation=60)
        ax.xaxis.set_major_formatter(day_formatter)
        ax.set_title(forecast.item_id)

    plt.gcf().tight_layout()
    # Called once after the loop, so the legend lands on the current (last) axes.
    plt.legend()
    plt.show()

# Create Model

In [5]:
# Settings read by create_model when it builds the estimator.
nonnegative_pred_samples = True

# 256 in paper
batch_size = 64

# from paper
num_samples = 100

In [6]:
def create_model(ckpt_path, prediction_length, context_length):
    """Build a Lag-Llama estimator and a matching predictor from a checkpoint.

    Architecture settings (``input_size``, ``n_layer``, ``n_embd_per_head``,
    ``n_head``, ``scaling``, ``time_feat``) are read from the checkpoint's
    ``hyper_parameters["model_kwargs"]`` entry. The notebook-level settings
    ``nonnegative_pred_samples``, ``batch_size`` and ``num_samples`` are read
    from the enclosing scope.

    Args:
        ckpt_path: Path to the checkpoint file.
        prediction_length: Number of steps to forecast.
        context_length: Number of past steps fed to the model.

    Returns:
        A ``(estimator, predictor)`` tuple.
    """
    # Prefer the first GPU, but fall back to the CPU so the checkpoint can
    # still be loaded on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    ckpt = torch.load(ckpt_path, map_location=device)
    estimator_args = ckpt["hyper_parameters"]["model_kwargs"]

    # Linear positional-encoding scaling: the factor is never below 1.0 and
    # grows once the requested window exceeds the checkpoint's context length.
    rope_factor = max(
        1.0,
        (context_length + prediction_length) / estimator_args["context_length"],
    )

    estimator = LagLlamaEstimator(
        ckpt_path=ckpt_path,
        prediction_length=prediction_length,
        context_length=context_length,

        # estimator args
        input_size=estimator_args["input_size"],
        n_layer=estimator_args["n_layer"],
        n_embd_per_head=estimator_args["n_embd_per_head"],
        n_head=estimator_args["n_head"],
        scaling=estimator_args["scaling"],
        time_feat=estimator_args["time_feat"],

        nonnegative_pred_samples=nonnegative_pred_samples,
        aug_prob=0,
        lr=5e-4,

        rope_scaling={
            "type": "linear",
            "factor": rope_factor,
        },

        batch_size=batch_size,
        num_parallel_samples=num_samples,
        trainer_kwargs={"max_epochs": 50},  # <- lightning trainer arguments
    )

    lightning_module = estimator.create_lightning_module()
    transformation = estimator.create_transformation()
    predictor = estimator.create_predictor(transformation, lightning_module)

    return estimator, predictor

# Zero-shot

In [9]:
import warnings
warnings.filterwarnings(action='ignore', category=FutureWarning, message=r".*Use a DatetimeIndex.*")

In [7]:
# Zero-shot benchmark: for each dataset, build a predictor directly from the
# checkpoint and score it on the test split. The evaluation is repeated ten
# times per dataset (presumably to average over sampling noise in the
# forecasts — TODO confirm) and the CRPS of every run is collected.
results = {ds_name: [] for ds_name in test_dataset_names}
for ds_name in test_dataset_names:
    dataset = get_dataset(ds_name)

    prediction_length = dataset.metadata.prediction_length
    context_length = prediction_length * 3
    # `dataset` is the container exposing .metadata/.test, so len(dataset) is
    # not the number of series; size the progress bars from the test split.
    num_test_series = len(dataset.test)

    _, predictor = create_model(ckpt_path, prediction_length, context_length)

    for _ in range(10):
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=dataset.test,
            predictor=predictor,
            num_samples=num_samples
        )
        forecasts = list(tqdm(forecast_it, total=num_test_series, desc="Forecasting batches"))
        tss = list(tqdm(ts_it, total=num_test_series, desc="Ground truth"))

        evaluator = Evaluator()
        agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))
        # The aggregate 'mean_wQuantileLoss' metric is what this notebook reports as CRPS.
        print(ds_name, "CRPS:", agg_metrics['mean_wQuantileLoss'])
        results[ds_name].append(agg_metrics['mean_wQuantileLoss'])
# One column per dataset, one row per repetition; show the per-dataset mean.
results = pd.DataFrame(results)
results.mean()

Download weather_dataset.zip:: 37.0MB [00:08, 4.37MB/s]
creating json files: 100%|██████████| 3010/3010 [00:00<00:00, 1003166.87it/s]
  y = F.scaled_dot_product_attention(
Forecasting batches: 3010it [05:23,  9.31it/s]                    
Ground truth: 3010it [00:03, 900.28it/s]           
Running evaluation: 3010it [00:08, 337.99it/s]
  return arr.astype(dtype, copy=True)


weather CRPS: 0.1601470350461294


Forecasting batches: 66it [01:33,  1.42s/it]                      
Ground truth: 66it [00:00, 289.60it/s]             
Running evaluation: 66it [00:00, 327.80it/s]


pedestrian_counts CRPS: 0.2729531559039424


  return pd.Period(val, freq)
  sliced_entry[FieldName.START] += offset
  index = pd.period_range(start, periods=length, freq=start.freq)
  index = pd.period_range(start, periods=length, freq=start.freq)
  entry[self.start_field] + idx + self.lead_time
  sliced_entry[FieldName.START] += offset
  index = pd.period_range(start, periods=length, freq=start.freq)
  index = pd.period_range(start, periods=length, freq=start.freq)
  entry[self.start_field] + idx + self.lead_time
  sliced_entry[FieldName.START] += offset
  index = pd.period_range(start, periods=length, freq=start.freq)
  index = pd.period_range(start, periods=length, freq=start.freq)
  entry[self.start_field] + idx + self.lead_time
  sliced_entry[FieldName.START] += offset
  index = pd.period_range(start, periods=length, freq=start.freq)
  index = pd.period_range(start, periods=length, freq=start.freq)
  entry[self.start_field] + idx + self.lead_time
  sliced_entry[FieldName.START] += offset
  index = pd.period_range(start, per

exchange_rate CRPS: 0.011299894894485999



  metrics_per_ts = pd.DataFrame.from_records(rows)


# Fine-Tuning

In [None]:
# Fine-tuning benchmark: for each dataset, train the estimator on the train
# split, then score the resulting predictor on the test split. Training and
# evaluation are repeated ten times per dataset and the CRPS of every run is
# collected.
results = {ds_name: [] for ds_name in test_dataset_names}
for ds_name in test_dataset_names:
    dataset = get_dataset(ds_name)

    prediction_length = dataset.metadata.prediction_length
    context_length = prediction_length * 3
    # `dataset` is the container exposing .metadata/.train/.test, so
    # len(dataset) is not the number of series; size the progress bars from
    # the test split.
    num_test_series = len(dataset.test)

    estimator, _ = create_model(ckpt_path, prediction_length, context_length)

    for _ in range(10):
        # NOTE(review): whether each train() call restarts from the checkpoint
        # or continues from the previous run depends on LagLlamaEstimator — confirm.
        predictor = estimator.train(dataset.train, cache_data=True, shuffle_buffer_length=1000)

        forecast_it, ts_it = make_evaluation_predictions(
            dataset=dataset.test,
            predictor=predictor,
            num_samples=num_samples
        )
        forecasts = list(tqdm(forecast_it, total=num_test_series, desc="Forecasting batches"))
        tss = list(tqdm(ts_it, total=num_test_series, desc="Ground truth"))

        evaluator = Evaluator()
        agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))
        # The aggregate 'mean_wQuantileLoss' metric is what this notebook reports as CRPS.
        print(ds_name, "CRPS:", agg_metrics['mean_wQuantileLoss'])
        results[ds_name].append(agg_metrics['mean_wQuantileLoss'])
# One column per dataset, one row per repetition; show the per-dataset mean.
results = pd.DataFrame(results)
results.mean()