In [None]:
!pip install dotenv
!pip install gluonts
!pip install --upgrade datasets
!pip install utilsforecast

In [None]:
# Fetch the evaluation repository and work from its `src/` directory: later cells use
# paths and imports relative to it (`data/dataset_properties.json`, `utils.load_data`).
# NOTE(review): not idempotent — on re-run `git clone` refuses the existing directory,
# and the relative `%cd` no longer resolves once the kernel is already inside `src/`.
!git clone https://github.com/GiuliaGhisolfi/TSFM-ZeroShotEval
%cd TSFM-ZeroShotEval/src

In [None]:
import json
from dotenv import load_dotenv

# Load environment variables from a local .env file, if present.
load_dotenv()

# GIFT-Eval dataset configurations ("<name>" or "<name>/<freq>") evaluated on the
# short-term horizon.
short_datasets = (
    "solar/10T solar/H solar/D solar/W jena_weather/10T jena_weather/H jena_weather/D "
    "bitbrains_fast_storage/5T bitbrains_fast_storage/H bitbrains_rnd/5T bitbrains_rnd/H bizitobs_application "
    "bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
)

# Subset that is additionally evaluated on the medium- and long-term horizons.
med_long_datasets = (
    "solar/10T solar/H jena_weather/10T jena_weather/H "
    "bitbrains_fast_storage/5T bitbrains_rnd/5T bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
)

# Union of both lists without duplicates. dict.fromkeys keeps first-seen order, so the
# processing order is reproducible; list(set(...)) of strings is not, because str
# hashing is randomised per interpreter process.
all_datasets = list(dict.fromkeys(short_datasets.split() + med_long_datasets.split()))

# Per-dataset metadata (frequency, domain, number of variates, ...) keyed by dataset key.
# The context manager closes the file instead of leaking the handle.
with open("data/dataset_properties.json", encoding="utf-8") as properties_file:
    dataset_properties_map = json.load(properties_file)

In [None]:
# NOTE(review): presumably downloads/prepares the GIFT-Eval benchmark data that
# `gift_eval.data.Dataset` reads in later cells — confirm in utils/load_data.py.
from utils.load_data import load_gift_data

load_gift_data()

In [None]:
from gluonts.ev.metrics import (
    MSE,
    MAE,
    MASE,
    MAPE,
    SMAPE,
    MSIS,
    RMSE,
    NRMSE,
    ND,
    MeanWeightedSumQuantileLoss,
)

# Evaluation metrics handed to gluonts `evaluate_model`. The forecast a metric is
# scored against appears as the bracketed suffix of its result column
# (e.g. "MSE[mean]" vs "MSE[0.5]").

# Quantile grid 0.1, 0.2, ..., 0.9 for the weighted quantile loss.
QUANTILE_LEVELS = [k / 10 for k in range(1, 10)]

# MSE is requested twice: once on the mean forecast, once on the median.
mse_metrics = [MSE(forecast_type=target) for target in ("mean", 0.5)]

# Remaining metrics use their constructor defaults.
default_metrics = [
    metric_cls()
    for metric_cls in (MAE, MASE, MAPE, SMAPE, MSIS, RMSE, NRMSE, ND)
]

metrics = [
    *mse_metrics,
    *default_metrics,
    MeanWeightedSumQuantileLoss(quantile_levels=QUANTILE_LEVELS),
]

In [None]:
import timesfm as timesfm

# Model under evaluation: the TimesFM 2.0 (500M, PyTorch) checkpoint.
# NOTE(review): these hparams (num_layers=50, context_len=2048,
# use_positional_embedding=False) are for the 2.0 checkpoint. Switching to
# "google/timesfm-1.0-200m-pytorch" presumably also requires the 1.0 hparams
# (e.g. fewer layers, shorter context) — confirm against the TimesFM README.
tfm_hparams = timesfm.TimesFmHparams(
    backend="gpu",
    per_core_batch_size=32,
    num_layers=50,
    horizon_len=128,
    context_len=2048,
    use_positional_embedding=False,
    output_patch_len=128,
)
tfm_checkpoint = timesfm.TimesFmCheckpoint(
    huggingface_repo_id="google/timesfm-2.0-500m-pytorch",
)
tfm = timesfm.TimesFm(hparams=tfm_hparams, checkpoint=tfm_checkpoint)

# Tag written into every results row and used in the results CSV filename.
model_name = "timesfm2"

In [None]:
from typing import List
import numpy as np
from tqdm.auto import tqdm
from gluonts.itertools import batcher
from gluonts.model import Forecast
from gluonts.model.forecast import QuantileForecast

class TimesFmPredictor:
    """Adapter that lets gluonts' ``evaluate_model`` drive a TimesFM model.

    ``predict`` batches the input series, runs ``tfm.forecast`` and wraps the
    quantile outputs for each series in a gluonts ``QuantileForecast``.

    Args:
        tfm: A loaded TimesFM model object. NOTE: its ``horizon_len`` is mutated
            (grown, never shrunk) when a longer prediction length is requested.
        prediction_length: Number of future steps to forecast per series.
        ds_freq: Frequency string of the dataset (e.g. "H", "5T"), mapped to
            TimesFM's frequency input via ``timesfm.freq_map``.
        *args, **kwargs: Accepted for interface compatibility and ignored.
    """

    def __init__(
        self,
        tfm,
        prediction_length: int,
        ds_freq: str,
        *args,
        **kwargs,
    ):
        self.tfm = tfm
        self.prediction_length = prediction_length
        if self.prediction_length > self.tfm.horizon_len:
            # Round the horizon up to the next multiple of the output patch length.
            patch_len = self.tfm.output_patch_len
            self.tfm.horizon_len = (
                (self.prediction_length + patch_len - 1) // patch_len
            ) * patch_len
            print('Jitting for new prediction length.')
        self.freq = timesfm.freq_map(ds_freq)

    def predict(self, test_data_input, batch_size: int = 1024) -> List[Forecast]:
        """Forecast every entry of ``test_data_input``.

        Args:
            test_data_input: Iterable of dicts with at least ``"target"`` (the
                observed history) and ``"start"`` (first timestamp; must support
                ``+ int``).
            batch_size: Number of series passed to ``tfm.forecast`` per call.

        Returns:
            One ``QuantileForecast`` per input entry, in input order, keyed by
            ``str(q)`` for each ``q`` in ``tfm.quantiles``. Empty input yields ``[]``.
        """
        # The input is traversed twice (batching, then pairing with predictions);
        # materialise it so a one-shot iterator does not silently yield no forecasts.
        entries = list(test_data_input)
        if not entries:
            # np.concatenate([]) would raise ValueError.
            return []

        quantile_batches = []
        for batch in tqdm(batcher(entries, batch_size=batch_size)):
            context = [np.array(entry["target"]) for entry in batch]
            freqs = [self.freq] * len(context)
            _, full_preds = self.tfm.forecast(context, freqs, normalize=True)
            # Assumed layout (batch, horizon, 1 + n_quantiles): channel 0 is the point
            # forecast, the rest follow tfm.quantiles — confirm against TimesFM docs.
            quantile_preds = full_preds[:, : self.prediction_length, 1:]
            # Per series -> (n_quantiles, prediction_length), the layout passed to
            # QuantileForecast's forecast_arrays below.
            quantile_batches.append(quantile_preds.transpose((0, 2, 1)))
        forecast_arrays = np.concatenate(quantile_batches)

        forecast_keys = [str(q) for q in self.tfm.quantiles]
        return [
            QuantileForecast(
                forecast_arrays=item,
                # Fresh list per forecast, as in the original, so no object shares it.
                forecast_keys=list(forecast_keys),
                # The forecast window starts right after the observed history.
                start_date=entry["start"] + len(entry["target"]),
            )
            for item, entry in zip(forecast_arrays, entries)
        ]

In [None]:
import logging

class WarningFilter(logging.Filter):
    """Drop log records whose rendered message contains ``text_to_filter``."""

    def __init__(self, text_to_filter):
        logging.Filter.__init__(self)
        self.text_to_filter = text_to_filter

    def filter(self, record):
        # A falsy return value tells the logging framework to discard the record.
        return record.getMessage().find(self.text_to_filter) == -1


# TimesFmPredictor stores only quantile keys, so gluonts' forecast module presumably
# emits this warning whenever a "[mean]" metric is computed; silence it to keep the
# evaluation output readable.
gts_logger = logging.getLogger("gluonts.model.forecast")
gts_logger.addFilter(WarningFilter("The mean prediction is not stored in the forecast data"))

In [None]:
import csv
import os

from gluonts.model import evaluate_model
from gluonts.time_feature import get_seasonality

from gift_eval.data import Dataset

all_ds_tuples = []

# Raw dataset keys -> names used as keys in dataset_properties_map and in results.
pretty_names = {
    "saugeenday": "saugeen",
    "temperature_rain_with_missing": "temperature_rain",
    "kdd_cup_2018_with_missing": "kdd_cup_2018",
    "car_parts_with_missing": "car_parts",
}

# Loop-invariant values, hoisted out of the loops below.
terms = ["short", "medium", "long"]
med_long_set = set(med_long_datasets.split())

for ds_num, ds_name in enumerate(all_datasets):
    print(f"Processing dataset: {ds_name} ({ds_num + 1} of {len(all_datasets)})")

    # Resolve the properties key and frequency once; neither depends on the term.
    if "/" in ds_name:
        parts = ds_name.split("/")
        raw_key = parts[0].lower()
        ds_freq = parts[1]
        ds_key = pretty_names.get(raw_key, raw_key)
    else:
        raw_key = ds_name.lower()
        ds_key = pretty_names.get(raw_key, raw_key)
        ds_freq = dataset_properties_map[ds_key]["frequency"]

    for term in terms:
        # Medium/long horizons are only evaluated for the med_long subset.
        if term in ("medium", "long") and ds_name not in med_long_set:
            continue

        ds_config = f"{ds_key}/{ds_freq}/{term}"
        # Multivariate datasets are flattened into univariate series.
        probe = Dataset(name=ds_name, term=term, to_univariate=False)
        to_univariate = probe.target_dim != 1
        # Rebuild only when a conversion is needed; otherwise the probe already is
        # the dataset we would construct.
        dataset = (
            Dataset(name=ds_name, term=term, to_univariate=True)
            if to_univariate
            else probe
        )
        all_ds_tuples.append(
            (dataset.prediction_length, ds_config, ds_name, to_univariate))

# Sort by prediction length first, so the shared model's horizon only ever grows.
all_ds_tuples = sorted(all_ds_tuples)
all_ds_tuples[0:10]

In [None]:
# Results are written incrementally: the header first, then one row per
# (dataset, term) as soon as it has been evaluated, so partial runs keep their rows.
output_dir = "results/timesfm"
os.makedirs(output_dir, exist_ok=True)
csv_file_path = os.path.join(output_dir, f"{model_name}_results.csv")

# Metric columns read from the DataFrame returned by `evaluate_model`. A single list
# drives both the CSV header ("eval_metrics/<key>") and every row, so the two cannot
# drift apart.
METRIC_KEYS = [
    "MSE[mean]",
    "MSE[0.5]",
    "MAE[0.5]",
    "MASE[0.5]",
    "MAPE[0.5]",
    "sMAPE[0.5]",
    "MSIS",
    "RMSE[mean]",
    "NRMSE[mean]",
    "ND[0.5]",
    "mean_weighted_sum_quantile_loss",
]

# Mode "w" starts a fresh file, discarding results from any earlier run.
with open(csv_file_path, "w", newline="") as csvfile:
    csv.writer(csvfile).writerow(
        ["dataset", "model"]
        + [f"eval_metrics/{key}" for key in METRIC_KEYS]
        + ["domain", "num_variates"]
    )

for entry in all_ds_tuples:
    ds_config, ds_name, to_univariate = entry[1:]
    ds_key, ds_freq, term = ds_config.split("/")
    dataset = Dataset(name=ds_name, term=term, to_univariate=to_univariate)
    season_length = get_seasonality(dataset.freq)
    print(f"Processing entry: {entry}")
    print(f"Dataset size: {len(dataset.test_data)}")

    predictor = TimesFmPredictor(
        tfm=tfm,
        prediction_length=dataset.prediction_length,
        ds_freq=ds_freq,
    )
    res = evaluate_model(
        predictor,
        test_data=dataset.test_data,
        metrics=metrics,
        batch_size=1024,
        axis=None,
        mask_invalid_label=True,
        allow_nan_forecast=False,
        seasonality=season_length,
    )

    properties = dataset_properties_map[ds_key]
    row = (
        [ds_config, model_name]
        # .iloc[0] reads the first row by position instead of relying on the
        # result index containing the label 0.
        + [res[key].iloc[0] for key in METRIC_KEYS]
        + [properties["domain"], properties["num_variates"]]
    )
    # Append immediately so completed datasets survive a later failure.
    with open(csv_file_path, "a", newline="") as csvfile:
        csv.writer(csvfile).writerow(row)

    print(f"Results for {ds_name} have been written to {csv_file_path}")