# Running the CleanTS Foundation Model on the GIFT-Eval Benchmark

This notebook contains the code needed to reproduce the CleanTS results on the GIFT-Eval benchmark.

## Install CleanTS
First, install CleanTS from its GitHub repository:
`pip install git+https://github.com/Taihuachen-cfair/CleanTS.git`


## Run the evaluation with CleanTS

In [None]:
import logging
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
# skip the no mean value warning from QuantileForecast
logging.getLogger("gluonts").setLevel(logging.ERROR)

import json
from dotenv import load_dotenv

from gluonts.ev.metrics import (
    MAE,
    MAPE,
    MASE,
    MSE,
    MSIS,
    ND,
    NRMSE,
    RMSE,
    SMAPE,
    MeanWeightedSumQuantileLoss,
)

# Configuration
# Dataset lists (whitespace-separated names, optionally "name/frequency").
# Every listed dataset is evaluated on the "short" term; the evaluation loop
# additionally runs "medium" and "long" for names in med_long_datasets.
short_datasets = "m4_yearly m4_quarterly m4_monthly m4_weekly m4_daily m4_hourly electricity/15T electricity/H electricity/D electricity/W solar/10T solar/H solar/D solar/W hospital covid_deaths us_births/D us_births/M us_births/W saugeenday/D saugeenday/M saugeenday/W temperature_rain_with_missing kdd_cup_2018_with_missing/H kdd_cup_2018_with_missing/D car_parts_with_missing restaurant hierarchical_sales/D hierarchical_sales/W LOOP_SEATTLE/5T LOOP_SEATTLE/H LOOP_SEATTLE/D SZ_TAXI/15T SZ_TAXI/H M_DENSE/H M_DENSE/D ett1/15T ett1/H ett1/D ett1/W ett2/15T ett2/H ett2/D ett2/W jena_weather/10T jena_weather/H jena_weather/D bitbrains_fast_storage/5T bitbrains_fast_storage/H bitbrains_rnd/5T bitbrains_rnd/H bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
med_long_datasets = "electricity/15T electricity/H solar/10T solar/H kdd_cup_2018_with_missing/H LOOP_SEATTLE/5T LOOP_SEATTLE/H SZ_TAXI/15T M_DENSE/H ett1/15T ett1/H ett2/15T ett2/H jena_weather/10T jena_weather/H bitbrains_fast_storage/5T bitbrains_rnd/5T bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
dataset_properties_map_path = "dataset_properties.json"
# Evaluation settings
# max_batch_size is the starting batch size (halved on out-of-memory errors);
# context_length is passed to the forecaster when it is constructed.
max_batch_size = 512
context_length = 3000

# Load environment variables
load_dotenv()
# De-duplicate while keeping first-seen order. A plain set() would give a
# hash-dependent order that changes between runs, making the evaluation order
# and the row order of the results CSV non-reproducible.
all_datasets = list(dict.fromkeys(short_datasets.split() + med_long_datasets.split()))
# Use a context manager so the file handle is closed right after parsing.
with open(dataset_properties_map_path, encoding="utf-8") as properties_file:
    dataset_properties_map = json.load(properties_file)

# Instantiate the evaluation metrics. Order: MSE on the mean forecast and on
# the 0.5 quantile, then the metrics constructed with default arguments, and
# finally the weighted quantile loss over the nine deciles.
metrics = [MSE(forecast_type="mean"), MSE(forecast_type=0.5)]
metrics.extend(
    metric_cls()
    for metric_cls in (MAE, MASE, MAPE, SMAPE, MSIS, RMSE, NRMSE, ND)
)
metrics.append(
    MeanWeightedSumQuantileLoss(
        quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    )
)
from uni2ts.model.CleanTS import CleanTSForecast, CleanTSModule
# Fetch the pretrained CleanTS weights, then wrap them in the forecaster.
# prediction_length=1 is only an initial value: the evaluation loop overwrites
# model.hparams.prediction_length for every dataset before predicting.
pretrained_module = CleanTSModule.from_pretrained("EINK/CleanTS-65M")
model = CleanTSForecast(
    module=pretrained_module,
    prediction_length=1,
    context_length=context_length,
    feat_dynamic_real_dim=0,
    past_feat_dynamic_real_dim=0,
)

from gluonts.model import evaluate_model
import csv
import os
from gluonts.time_feature import get_seasonality
from gift_eval.data import Dataset

# Iterate over all available datasets

# Directory that receives the results; create it up front if it is missing.
output_dir = "CleanTS/results"
os.makedirs(output_dir, exist_ok=True)

# Lower-cased dataset keys mapped to the shorter names used in the result
# rows and for lookups in the dataset properties map.
pretty_names = {
    "saugeenday": "saugeen",
    "temperature_rain_with_missing": "temperature_rain",
    "kdd_cup_2018_with_missing": "kdd_cup_2018",
    "car_parts_with_missing": "car_parts",
}

# Location of the CSV file that collects all results.
results_filename = "all_results.csv"
csv_file_path = os.path.join(output_dir, results_filename)

# CSV columns: run identity, one column per metric, then dataset properties.
# NOTE(review): upstream GIFT-Eval result files appear to name the second
# column "model" - confirm that "model_small" is intended here.
identity_columns = ["dataset", "model_small"]
metric_columns = [
    "eval_metrics/MSE[mean]",
    "eval_metrics/MSE[0.5]",
    "eval_metrics/MAE[0.5]",
    "eval_metrics/MASE[0.5]",
    "eval_metrics/MAPE[0.5]",
    "eval_metrics/sMAPE[0.5]",
    "eval_metrics/MSIS",
    "eval_metrics/RMSE[mean]",
    "eval_metrics/NRMSE[mean]",
    "eval_metrics/ND[0.5]",
    "eval_metrics/mean_weighted_sum_quantile_loss",
]
property_columns = ["domain", "num_variates"]
header = identity_columns + metric_columns + property_columns
# Evaluate the model on every dataset/term combination and collect one CSV
# row per combination in `results` as (config name, row) pairs.

# Imported once here rather than inside the retry handler; torch is only
# needed to release cached GPU memory after an out-of-memory failure.
import torch

# Keys of the metrics read from each evaluation result, in CSV column order.
metric_keys = [
    "MSE[mean]",
    "MSE[0.5]",
    "MAE[0.5]",
    "MASE[0.5]",
    "MAPE[0.5]",
    "sMAPE[0.5]",
    "MSIS",
    "RMSE[mean]",
    "NRMSE[mean]",
    "ND[0.5]",
    "mean_weighted_sum_quantile_loss",
]
# Build the membership set once instead of re-splitting the string per term.
med_long_names = set(med_long_datasets.split())

results = []
for ds_name in all_datasets:
    print(f"Processing dataset: {ds_name}")

    # "name/freq" entries carry their own frequency; bare names take it from
    # the properties map. Neither depends on the term, so resolve them once.
    # The properties lookup happens before any evaluation so that a missing
    # entry fails fast instead of after an expensive run.
    name_parts = ds_name.split("/")
    ds_key = name_parts[0].lower()
    ds_key = pretty_names.get(ds_key, ds_key)
    ds_props = dataset_properties_map[ds_key]
    ds_freq = name_parts[1] if len(name_parts) > 1 else ds_props["frequency"]

    # Only datasets listed in med_long_datasets get the medium/long terms.
    terms = ["short", "medium", "long"] if ds_name in med_long_names else ["short"]
    for term in terms:
        ds_config = f"{ds_key}/{ds_freq}/{term}"

        # Probe the dataset's target dimension first; anything other than a
        # single target is then loaded with to_univariate=True.
        native_target_dim = Dataset(
            name=ds_name, term=term, to_univariate=False
        ).target_dim
        dataset = Dataset(
            name=ds_name, term=term, to_univariate=native_target_dim != 1
        )

        # set the hyperparameter according to each dataset, then create the predictor
        model.hparams.prediction_length = dataset.prediction_length
        model.hparams.target_dim = dataset.target_dim
        model.hparams.past_feat_dynamic_real_dim = dataset.past_feat_dynamic_real_dim

        # Does not change between retries, so compute it outside the loop.
        season_length = get_seasonality(dataset.freq)

        # Start from max_batch_size and halve after every out-of-memory error.
        batch_size = max_batch_size
        while batch_size >= 1:
            try:
                predictor = model.create_predictor(batch_size=batch_size)
                res = evaluate_model(
                    predictor,
                    test_data=dataset.test_data,
                    metrics=metrics,
                    axis=None,
                    mask_invalid_label=True,
                    allow_nan_forecast=False,
                    seasonality=season_length,
                )
                break  # evaluation succeeded
            except RuntimeError as e:
                # Anything that is not an out-of-memory error is a real failure.
                if "out of memory" not in str(e).lower():
                    raise
                print(f"CUDA out of memory with batch_size={batch_size}, halving...")
                if batch_size // 2 < 1:
                    raise RuntimeError("Batch size reduced below 1; cannot proceed.") from e
            # Reached only after an out-of-memory failure. The cache is
            # released here, outside the except clause, because inside it the
            # exception still references the frames of the failed call.
            batch_size //= 2
            torch.cuda.empty_cache()
        else:
            # The loop body never ran, i.e. max_batch_size itself is below 1.
            raise RuntimeError(f"max_batch_size must be at least 1, got {max_batch_size}")

        result_row = [
            ds_config,
            "CleanTS-65M",
            *(res[key][0] for key in metric_keys),
            ds_props["domain"],
            ds_props["num_variates"],
        ]

        results.append((ds_config, result_row))


# Write the header followed by every collected row. Each entry of `results`
# is a (config name, row) pair; only the row is written to the file.
with open(csv_file_path, "w", newline="") as out_file:
    csv_writer = csv.writer(out_file)
    csv_writer.writerow(header)
    csv_writer.writerows(row for _config, row in results)

print(f"Results written to {csv_file_path}")
