In [None]:
import os
from ast import literal_eval
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from gluonts.dataset.pandas import PandasDataset, is_uniform, infer_freq
from gluonts.dataset.split import split, TestData
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
from uni2ts.eval_util.plot import plot_single, plot_next_multi
from uni2ts.eval_util.evaluation import evaluate_model
from gluonts.ev.metrics import MAE, MAPE
from gluonts.evaluation.metrics import quantile_loss

from fts_explore.common.benchmark_func import get_model_data, get_eval_foreasts

In [None]:
# --- Experiment configuration ---
# Sub-directory name shared by the checkpoint paths and the results paths below.
models = "weather_temperature_data"
# Number of time steps evaluated per model; also the x-axis extent of the plots.
TOTAL_TIME_STEPS = 42
# Patch size forwarded to `get_metrics_v2` for every run.
PSZ = 16

# --- Plot labelling ---
TITLE = "Stage 2 Finetuning: FT vs PT"
XLABEL = "Data available in 2-month patches"

In [None]:
# Load the temperature CSV, parse the "DateUTC" column into datetimes and
# use it as the index of the frame.
data = (
    pd.read_csv("../data/weather_Smurni_daily_temp_14_24.csv", delimiter=",")
    .assign(DateUTC=lambda frame: pd.to_datetime(frame["DateUTC"]))
    .set_index("DateUTC")
)

In [None]:
def _abs_pct_error(point_forecast, target_flat):
    """Return the element-wise absolute percentage error |forecast - target| / |target|.

    `point_forecast` is flattened before the comparison; `target_flat` must
    already be flat. Targets equal to zero yield inf/nan, as no masking is
    applied.
    """
    return np.abs((point_forecast.flatten() - target_flat) / target_flat)


def get_metrics_v2(
    model_folder: str,
    data: pd.DataFrame,
    total_time_steps: int,
    prediction_lenght: int,
    context_length: int,
    patch_size: int,
) -> dict:
    """Collect forecast error metrics for time steps 1..total_time_steps.

    For every time step ``i`` the model and test data are obtained through
    ``get_model_data`` (called with ``time_steps=i``), forecasts are produced
    with ``get_eval_foreasts`` and the following are recorded:

    * ``"mean"`` / ``"median"`` / ``"lower_0025"`` / ``"upper_0975"``: list of
      absolute percentage errors of the mean, median, 2.5% quantile and 97.5%
      quantile forecast (statistics taken over axis 1 of the forecast samples,
      assumed to be the sample axis -- TODO confirm against
      ``get_eval_foreasts``).
    * ``"mean_wql"``: quantile loss normalised by the sum of absolute targets,
      averaged over the quantile levels from ``np.arange(0.1, 1.0, 0.1)``.

    Args:
        model_folder: Forwarded to ``get_model_data`` (a checkpoint path prefix
            or the literal ``"pretrained"`` in this notebook).
        data: Frame holding the ``"temperature_2m_mean"`` target column.
        total_time_steps: Highest time step to evaluate (inclusive).
        prediction_lenght: Forecast horizon forwarded to ``get_model_data``.
        context_length: Forwarded as ``max_context_length``.
        patch_size: Forwarded to ``get_model_data``.

    Returns:
        Dict mapping each metric name to a list with one entry per time step
        that was evaluated successfully. A failing time step is reported on
        stdout and skipped, so the lists can be shorter than
        ``total_time_steps`` and their positions then no longer equal the
        time step number.
    """
    metrics = {
        "mean": [],
        "median": [],
        "lower_0025": [],
        "upper_0975": [],
        "mean_wql": [],
    }
    for i in range(1, total_time_steps + 1):
        try:
            model, test_data = get_model_data(
                model_folder=model_folder,
                prediction_lenght=prediction_lenght,
                time_steps=i,
                time_step_size=60,
                max_context_length=context_length,
                target_var="temperature_2m_mean",
                data=data,
                patch_size=patch_size,
                num_samples=200,
                freq="D",
            )

            forecast_samples, target_values = get_eval_foreasts(model, test_data)

            # Flatten the targets once; every metric below compares against them.
            target_flat = target_values.flatten()

            # absolute percentage error of the mean forecast
            mean_error_ts = _abs_pct_error(
                np.mean(forecast_samples, axis=1), target_flat
            )

            # absolute percentage error of the median forecast
            median_error_ts = _abs_pct_error(
                np.quantile(forecast_samples, 0.5, axis=1), target_flat
            )

            # absolute percentage error of the 97.5th percentile forecast
            upper_error_ts = _abs_pct_error(
                np.quantile(forecast_samples, 0.975, axis=1), target_flat
            )

            # absolute percentage error of the 2.5th percentile forecast
            lower_error_ts = _abs_pct_error(
                np.quantile(forecast_samples, 0.025, axis=1), target_flat
            )

            # Quantile loss per level, normalised by the total absolute target.
            abs_target_sum = np.sum(np.abs(target_flat))
            w_q_loss = [
                np.sum(
                    quantile_loss(
                        target_flat,
                        np.quantile(forecast_samples, q, axis=1).flatten(),
                        q=q,
                    )
                )
                / abs_target_sum
                for q in np.arange(0.1, 1.0, 0.1)
            ]
            mean_w_q_loss = np.mean(w_q_loss)

            # Append all metrics together, only after everything succeeded,
            # so the five lists always stay the same length.
            metrics["mean"].append(mean_error_ts.tolist())
            metrics["median"].append(median_error_ts.tolist())
            metrics["lower_0025"].append(lower_error_ts.tolist())
            metrics["upper_0975"].append(upper_error_ts.tolist())
            metrics["mean_wql"].append(mean_w_q_loss.tolist())

        except Exception as e:
            # Best-effort: keep evaluating the remaining time steps, but say
            # which one was dropped so gaps in the results can be traced.
            print(f"[{model_folder}] time step {i} failed and was skipped: {e!r}")

    return metrics

In [None]:
# Checkpoint path prefixes of the two stage-two runs that are compared.
model_folder_ft = f"../model_checkpoints/{models}/stage_two_ft_temperature/checkpoints/weather_temp_Smurni_data_train_"
model_folder_pt = f"../model_checkpoints/{models}/stage_two_pt_temperature/checkpoints/weather_temp_Smurni_data_train_"

assert model_folder_ft != model_folder_pt, "The 2 folders must be different!"

# Create the output directory up front: otherwise a missing folder would only
# surface in `to_csv`, after all three models have already been evaluated.
results_dir = f"../experiment_results/stage_2_finetune/{models}"
os.makedirs(results_dir, exist_ok=True)

for CTX in [30, 60]:
    for PDT in [7, 14, 21, 30, 60]:
        # NOTE(review): this combination is skipped on purpose in the original
        # notebook; the reason is not documented here.
        if CTX == 30 and PDT == 60:
            continue

        print(f" CTX: {CTX} | PDT: {PDT}")

        # Settings shared by the three evaluations of this (CTX, PDT) pair.
        run_kwargs = dict(
            total_time_steps=TOTAL_TIME_STEPS,
            prediction_lenght=PDT,
            context_length=CTX,
            patch_size=PSZ,
        )

        stage_2_ft = get_metrics_v2(model_folder_ft, data, **run_kwargs)
        stage_2_pt = get_metrics_v2(model_folder_pt, data, **run_kwargs)
        pretrained = get_metrics_v2("pretrained", data, **run_kwargs)

        stage_2_ft = pd.DataFrame.from_dict(stage_2_ft)
        stage_2_pt = pd.DataFrame.from_dict(stage_2_pt)
        pretrained = pd.DataFrame.from_dict(pretrained)

        # One CSV per model variant and (PDT, CTX) pair.
        stage_2_ft.to_csv(
            f"{results_dir}/ft_{PDT}_{CTX}_performance.csv",
            index=False,
        )
        stage_2_pt.to_csv(
            f"{results_dir}/pt_{PDT}_{CTX}_performance.csv",
            index=False,
        )
        pretrained.to_csv(
            f"{results_dir}/pretrained_{PDT}_{CTX}_performance.csv",
            index=False,
        )

In [None]:
# Prediction length / context length pair whose results are plotted below.
PDT, CTX = 7, 30


def _load_performance(variant: str) -> pd.DataFrame:
    """Read the stored performance CSV of one model variant for the current PDT/CTX."""
    return pd.read_csv(
        f"../experiment_results/stage_2_finetune/{models}/{variant}_{PDT}_{CTX}_performance.csv",
    )


stage_2_ft = _load_performance("ft")
stage_2_pt = _load_performance("pt")
pretrained = _load_performance("pretrained")

In [None]:
# MAPE of the average forecast
# Draw on an explicit figure/axes instead of the implicit pyplot "current
# figure", so the saved image only contains the curves created in this cell
# regardless of the matplotlib backend in use.
fig, ax = plt.subplots()
ax.set_title(TITLE)
ax.set_ylabel("MAPE of the mean forecast")
ax.set_xlabel(XLABEL)
for frame, label in (
    (stage_2_ft, "from FT"),
    (stage_2_pt, "from PT"),
    (pretrained, "Pretrained"),
):
    # Each CSV cell holds a stringified list of errors; parse it and average
    # it into a single value per row. No yerr/xerr is passed, so only the
    # dotted line with markers is drawn.
    ax.errorbar(
        x=range(1, TOTAL_TIME_STEPS + 1),
        y=[np.mean(literal_eval(ts)) for ts in frame["mean"]],
        ecolor="red",
        barsabove=True,
        linestyle="dotted",
        marker=".",
        errorevery=2,
        label=label,
    )
ax.legend()
fig.savefig(
    f"../experiment_results/stage_2_finetune/{models}/stage_2_main_comp_mean.jpeg"
);

In [None]:
# MAPE of the 97.5th percentile forecast
# Draw on an explicit figure/axes instead of the implicit pyplot "current
# figure", so the saved image only contains the curves created in this cell
# regardless of the matplotlib backend in use.
# NOTE(review): unlike the other plots, this one leaves out the first row
# (x starts at 2) -- confirm that this is intentional.
fig, ax = plt.subplots()
ax.set_title(TITLE)
ax.set_ylabel("MAPE of the 97.5th percentile forecast")
ax.set_xlabel(XLABEL)
for frame, label in (
    (stage_2_ft, "from FT"),
    (stage_2_pt, "from PT"),
    (pretrained, "Pretrained"),
):
    # Each CSV cell holds a stringified list of errors; parse it and average
    # it into a single value per row.
    ax.errorbar(
        x=range(2, TOTAL_TIME_STEPS + 1),
        y=[np.mean(literal_eval(ts)) for ts in frame["upper_0975"].iloc[1:]],
        ecolor="red",
        barsabove=True,
        linestyle="dotted",
        marker=".",
        errorevery=2,
        label=label,
    )
ax.legend()
fig.savefig(
    f"../experiment_results/stage_2_finetune/{models}/stage_2_main_comp_975.jpeg"
);

In [None]:
# MAPE of the 2.5th percentile forecast
# Draw on an explicit figure/axes instead of the implicit pyplot "current
# figure", so the saved image only contains the curves created in this cell
# regardless of the matplotlib backend in use.
fig, ax = plt.subplots()
ax.set_title(TITLE)
# Label fixed: the original read "MAPE of the 2.5th forecast".
ax.set_ylabel("MAPE of the 2.5th percentile forecast")
ax.set_xlabel(XLABEL)
for frame, label in (
    (stage_2_ft, "from FT"),
    (stage_2_pt, "from PT"),
    (pretrained, "Pretrained"),
):
    # Each CSV cell holds a stringified list of errors; parse it and average
    # it into a single value per row.
    ax.errorbar(
        x=range(1, TOTAL_TIME_STEPS + 1),
        y=[np.mean(literal_eval(ts)) for ts in frame["lower_0025"]],
        ecolor="red",
        barsabove=True,
        linestyle="dotted",
        marker=".",
        errorevery=2,
        label=label,
    )
ax.legend()
fig.savefig(
    f"../experiment_results/stage_2_finetune/{models}/stage_2_main_comp_025.jpeg"
);

In [None]:
# MAPE of the median forecast
# Draw on an explicit figure/axes instead of the implicit pyplot "current
# figure", so the saved image only contains the curves created in this cell
# regardless of the matplotlib backend in use.
fig, ax = plt.subplots()
ax.set_title(TITLE)
ax.set_ylabel("MAPE of the median forecast")
ax.set_xlabel(XLABEL)
for frame, label in (
    (stage_2_ft, "from FT"),
    (stage_2_pt, "from PT"),
    (pretrained, "Pretrained"),
):
    # Each CSV cell holds a stringified list of errors; parse it and average
    # it into a single value per row.
    ax.errorbar(
        x=range(1, TOTAL_TIME_STEPS + 1),
        y=[np.mean(literal_eval(ts)) for ts in frame["median"]],
        ecolor="red",
        barsabove=True,
        linestyle="dotted",
        marker=".",
        errorevery=2,
        label=label,
    )
ax.legend()
fig.savefig(
    f"../experiment_results/stage_2_finetune/{models}/stage_2_main_comp_median.jpeg"
);

In [None]:
# Mean weighted quantile loss
# Draw on an explicit figure/axes instead of the implicit pyplot "current
# figure", so the saved image only contains the curves created in this cell
# regardless of the matplotlib backend in use.
fig, ax = plt.subplots()
ax.set_title(TITLE)
ax.set_ylabel("Mean Weighted Quantile Loss")
ax.set_xlabel(XLABEL)
for frame, label in (
    (stage_2_ft, "from FT"),
    (stage_2_pt, "from PT"),
    (pretrained, "Pretrained"),
):
    # "mean_wql" already stores one scalar per row, so no parsing is needed.
    ax.errorbar(
        x=range(1, TOTAL_TIME_STEPS + 1),
        y=frame["mean_wql"].tolist(),
        ecolor="red",
        barsabove=True,
        linestyle="dotted",
        marker=".",
        errorevery=2,
        label=label,
    )
ax.legend()
fig.savefig(
    f"../experiment_results/stage_2_finetune/{models}/stage_2_main_comp_mean_wql.jpeg"
);