In [1]:
# Sanity check: print the path of the Python interpreter backing this kernel,
# so it is clear which environment the notebook is running in.
import sys
print(sys.executable)


C:\Users\dhruv\anaconda3\envs\lag-llama-env\python.exe


In [7]:
import pandas as pd 
import matplotlib.pyplot as plt 
import matplotlib.dates as mdates 
import torch

from itertools import islice

from gluonts.evaluation import make_evaluation_predictions, Evaluator 
from gluonts.dataset.repository.datasets import get_dataset 
from lag_llama.gluon.estimator import LagLlamaEstimator

In [5]:
# Ask PyTorch whether a CUDA device is usable. The recorded output is False,
# which matches the later cells that pin the checkpoint and the estimator to CPU.
torch.cuda.is_available()

False

In [4]:
# Fetch the benchmark dataset; the recorded output shows it being downloaded
# and converted to JSON files on this run.
dataset = get_dataset("australian_electricity_demand")

# Forecast horizon comes from the dataset metadata; the model is given a
# history window three horizons long.
prediction_length = dataset.metadata.prediction_length
context_length = prediction_length * 3

# The test split is what every model in this notebook is backtested on.
backtest_dataset = dataset.test

Download australian_electricity_demand_dataset.zip:: 5.51MB [00:04, 1.31MB/s]
creating json files: 100%|███████████████████████████████████████████████████████████████████████| 5/5 [00:00<?, ?it/s]


In [9]:
# Open the pretrained checkpoint only to recover the hyper-parameters the model
# was built with; they are forwarded to the estimator in the next cell.
#
# SECURITY: torch.load unpickles arbitrary Python objects. Only load a
# checkpoint file obtained from a source you trust.
#
# weights_only is passed explicitly: the default changed to True in
# PyTorch 2.6, and a full training checkpoint (more than bare tensors) may be
# rejected under that setting. Spelling it out keeps the cell working across
# PyTorch versions.
ckpt = torch.load(
    "lag-llama.ckpt",
    map_location=torch.device("cpu"),  # no GPU required just to read the metadata
    weights_only=False,
)
estimator_args = ckpt["hyper_parameters"]["model_kwargs"]

In [11]:
# Build the Lag-Llama estimator. Horizon and context window come from the
# dataset cell; the architecture hyper-parameters are copied unchanged from the
# checkpoint's stored model kwargs so the network matches the saved weights.
estimator = LagLlamaEstimator(
    ckpt_path="lag-llama.ckpt",
    prediction_length=prediction_length,
    context_length=context_length,
    device=torch.device("cpu"),
    **{
        arg_name: estimator_args[arg_name]
        for arg_name in (
            "input_size",
            "n_layer",
            "n_embd_per_head",
            "n_head",
            "scaling",
            "time_feat",
        )
    },
)

# No training step: the predictor is assembled directly from the estimator's
# Lightning module and its input transformation.
lightning_module = estimator.create_lightning_module()
transformation = estimator.create_transformation()
predictor = estimator.create_predictor(transformation, lightning_module)

In [12]:
# Pair the Lag-Llama predictor with the backtest split. This yields two
# iterators (forecasts and matching ground-truth series) that the next cell
# turns into lists.
forecast_it, ts_it = make_evaluation_predictions(dataset=backtest_dataset, predictor=predictor)

In [15]:
# Drain both iterators into lists so the results can be reused by the metrics
# and plotting cells below.
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt;
# presumably the forecasts are generated lazily here and this is slow on CPU — confirm.
forecasts = list(forecast_it)
tss = list(ts_it)

KeyboardInterrupt: 

In [None]:
# One evaluator instance with default settings; it is reused for Lag-Llama
# here and for the TFT / DeepAR baselines later in the notebook.
evaluator = Evaluator()


In [None]:
# Score the Lag-Llama forecasts against the ground-truth series.
# agg_metrics is indexed by metric name (e.g. 'RMSE' in the final comparison);
# ts_metrics presumably holds the per-series breakdown — verify against GluonTS docs.
agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))


In [None]:
# Display the aggregate Lag-Llama metrics as this cell's output.
agg_metrics

In [None]:
# Plot the first four test series in a 2x2 grid: the tail of the observed
# target (four forecast horizons long) with the Lag-Llama forecast on top.
fig = plt.figure(figsize=(20, 15))
date_formater = mdates.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})

# Length of observed history shown in each panel.
history_length = 4 * dataset.metadata.prediction_length

for panel, (forecast, ts) in enumerate(islice(zip(forecasts, tss), 4), start=1):
    ax = plt.subplot(2, 2, panel)

    ax.plot(ts[-history_length:].to_timestamp(), label="target")
    forecast.plot(color='g')
    plt.xticks(rotation=60)
    ax.xaxis.set_major_formatter(date_formater)
    ax.set_title(forecast.item_id)
    # Add the legend per panel: a single plt.legend() after the loop only
    # decorates the current (last) axes and leaves the other panels without one.
    ax.legend()

# Lay out once every panel is complete, so the legends are accounted for.
fig.tight_layout()
plt.show()
     


In [None]:
from gluonts.torch import TemporalFusionTransformerEstimator, DeepAREstimator

# Two baselines to compare against Lag-Llama. Both get the same horizon,
# context window, sampling frequency and a short 5-epoch training budget.
# The arguments are spelled out inside the generator so each estimator
# receives its own fresh trainer_kwargs dict.
tft_estimator, deepar_estimator = (
    estimator_cls(
        prediction_length=prediction_length,
        context_length=context_length,
        freq="30min",
        trainer_kwargs={"max_epochs": 5},
    )
    for estimator_cls in (TemporalFusionTransformerEstimator, DeepAREstimator)
)


In [None]:
# Fit both baselines on the training split of the dataset; each train() call
# returns the predictor used by the evaluation cells below.
tft_predictor = tft_estimator.train(dataset.train)
deepar_predictor = deepar_estimator.train(dataset.train)

In [None]:

# Same backtest protocol as for Lag-Llama: pair each trained baseline with the
# test split to obtain forecast and ground-truth iterators.
tft_forecast_it, tft_ts_it = make_evaluation_predictions(dataset=backtest_dataset, predictor=tft_predictor)

deepar_forecast_it, deepar_ts_it = make_evaluation_predictions(dataset=backtest_dataset, predictor=deepar_predictor)

In [None]:
# Drain the baseline iterators into lists, TFT first and then DeepAR, so the
# evaluator cell below can consume them.
tft_forecasts = list(tft_forecast_it)
tft_tss = list(tft_ts_it)

deepar_forecasts = list(deepar_forecast_it)
deepar_tss = list(deepar_ts_it)

In [None]:
# Score both baselines with the same `evaluator` instance used for Lag-Llama.
# Only the aggregate results are read afterwards (the 'RMSE' entries in the
# final comparison).
tft_agg_metrics, tft_ts_metrics = evaluator(iter(tft_tss), iter(tft_forecasts))
deepar_agg_metrics, deepar_ts_metrics = evaluator(iter(deepar_tss), iter(deepar_forecasts))

In [None]:
# Final comparison: aggregate RMSE of the three models on the same backtest
# split. The first label previously read "Lag-LlamaL" (typo) and is corrected.
print(f'''

Lag-Llama: {agg_metrics['RMSE']}
TFT: {tft_agg_metrics['RMSE']}
DeepAR: {deepar_agg_metrics['RMSE']}

''')