In [7]:
import time
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import mean_absolute_error, mean_squared_error, mean_absolute_percentage_error
from statsforecast import StatsForecast
from statsforecast.models import AutoCES

In [8]:
# Load wide-format data
df_wide = pd.read_csv("outbreaks_disease_location.csv")
value_cols = [str(i) for i in range(60)]
start_dates = pd.to_datetime(df_wide["start_date"])
series_values = df_wide[value_cols].astype(float).fillna(0)

# Convert to long format
records = []
for i, (start_date, row) in enumerate(zip(start_dates, series_values.values), start=1):
    adjusted_start = start_date - pd.Timedelta(weeks=4)
    dates = pd.date_range(start=adjusted_start, periods=60, freq="W-SAT")
    for t, value in enumerate(row):
        records.append({"unique_id": f"Y_{i}", "ds": dates[t], "y": value})

df_long = pd.DataFrame(records)

In [9]:
df_long

Unnamed: 0,unique_id,ds,y
0,Y_1,2024-05-25,0.0
1,Y_1,2024-06-01,0.0
2,Y_1,2024-06-08,0.0
3,Y_1,2024-06-15,0.0
4,Y_1,2024-06-22,0.0
...,...,...,...
647935,Y_10799,2025-06-21,0.0
647936,Y_10799,2025-06-28,0.0
647937,Y_10799,2025-07-05,0.0
647938,Y_10799,2025-07-12,0.0


In [10]:
class FixedCESProcessor:
    def __init__(self):
        self.forecasts = []
        self.eval_pairs = []
        self.dates = []
        self.unique_ids = []

        self.maes = []
        self.mses = []
        self.mapes = []
        self.nmses = []

        self.metrics_df = pd.DataFrame(columns=["Reference Date", "MAE", "MSE", "MAPE", "NMSE"])
        self.display_df = pd.DataFrame(columns=["Unique_id", "Reference Date", "Target End Date", "GT", "Quantile", "Prediction"])

    def create_fixed_model(self, df_long, h, freq="W-SAT", level=[80, 95], season_length=1):
        df_fit = df_long.groupby("unique_id").apply(lambda g: g.iloc[:-h]).reset_index(drop=True)
        df_truth = df_long.groupby("unique_id").apply(lambda g: g.iloc[-h:]).reset_index(drop=True)

        start = time.time()
        self.sf = StatsForecast(models=[AutoCES(season_length=season_length)], freq=freq, n_jobs=-1)
        self.sf.fit(df_fit)
        forecast = self.sf.predict(h=h, level=level)
        print(f"CES fit time: {time.time() - start:.2f} sec")

        forecast.set_index(["unique_id", "ds"], inplace=True)
        df_truth.set_index(["unique_id", "ds"], inplace=True)

        print("Processing forecasts per series...")
        for uid in tqdm(df_fit["unique_id"].unique(), desc="Fitting per series"):
            f = forecast.loc[uid].copy()
            f["unique_id"] = uid
            t = df_truth.loc[uid]
            self.forecasts.append(f)
            self.eval_pairs.append((f, t))
            self.dates.append(df_fit[df_fit["unique_id"] == uid]["ds"].max().strftime("%Y-%m-%d"))
            self.unique_ids.append(uid)

    def calculate_metrics(self):
        for forecast_df, truth_df in self.eval_pairs:
            y_true = truth_df.iloc[:, 0]
            y_pred = forecast_df.iloc[:, 0]
            self.maes.append(mean_absolute_error(y_true, y_pred))
            self.mses.append(mean_squared_error(y_true, y_pred))
            self.mapes.append(mean_absolute_percentage_error(y_true, y_pred))
            self.nmses.append(self.mses[-1] / np.var(y_true))

    def create_metrics_df(self):
        self.metrics_df = pd.DataFrame({
            "Reference Date": self.dates,
            "MAE": self.maes,
            "MSE": self.mses,
            "MAPE": self.mapes,
            "NMSE": self.nmses,
        })

    def create_display_df(self):
        records = []
        print("Generating display DataFrame...")
        for i in tqdm(range(len(self.forecasts)), desc="Building display_df"):
            forecast_df = self.forecasts[i]
            reference_date = self.dates[i]
            unique_id = self.unique_ids[i]
            truth_series = self.eval_pairs[i][1].iloc[:, 0]

            for col in forecast_df.columns:
                if col == "unique_id":
                    continue
                if "lo" in col or "hi" in col:
                    number = int(col.split("-")[-1])
                    alpha = 1 - (number / 100)
                    quantile = 1 - (alpha / 2) if "hi" in col else alpha / 2
                elif col == "CES":
                    quantile = 0.5
                else:
                    continue

                preds = forecast_df[col]
                for idx, pred in preds.items():
                    records.append({
                        "Unique_id": unique_id,
                        "Reference Date": reference_date,
                        "Target End Date": idx,
                        "GT": truth_series.get(idx, np.nan),
                        "Quantile": quantile,
                        "Prediction": pred
                    })

        self.display_df = pd.DataFrame(records).sort_values(
            by=["Unique_id", "Reference Date", "Target End Date", "GT", "Quantile"]
        ).reset_index(drop=True)

    def compute_wis(self):
        df = self.display_df.sort_values(by=["Unique_id", "Reference Date", "Target End Date", "Quantile"])
        records = []
        grouped = df.groupby(["Unique_id", "Reference Date", "Target End Date"])

        print("Computing WIS for each forecasted point...")
        for (uid, ref_date, tgt_date), group in tqdm(grouped, desc="Computing WIS"):
            gt = group["GT"].iloc[0]
            preds = group.set_index("Quantile")["Prediction"]

            if 0.5 not in preds.index:
                continue

            ae = abs(preds[0.5] - gt)
            quantiles = sorted(q for q in preds.index if q != 0.5)
            n = len(quantiles) // 2
            interval_scores = []

            for i in range(n):
                lo_q = quantiles[i]
                hi_q = quantiles[-(i + 1)]
                lo = preds[lo_q]
                hi = preds[hi_q]
                alpha = hi_q - lo_q

                interval_score = (
                    (hi - lo)
                    + (2 / alpha) * max(lo - gt, 0)
                    + (2 / alpha) * max(gt - hi, 0)
                )
                interval_scores.append(interval_score)

            wis = (ae + np.sum(interval_scores)) / (1 + len(interval_scores))
            records.append({
                "Unique_id": uid,
                "Reference Date": ref_date,
                "Target End Date": tgt_date,
                "GT": gt,
                "WIS": wis
            })

        return pd.DataFrame(records)

In [11]:
processor = FixedCESProcessor()

In [12]:
processor.create_fixed_model(df_long=df_long, h=4, freq="W-SAT",season_length = 60, level=[10,20,30,40,50,60,70,80,85,90,95])

CES fit time: 75.72 sec
Processing forecasts per series...


Fitting per series: 100%|██████████| 10799/10799 [07:17<00:00, 24.66it/s]


In [13]:
processor.create_display_df()

Generating display DataFrame...


Building display_df: 100%|██████████| 10799/10799 [00:19<00:00, 563.39it/s]


In [14]:
wis_df = processor.compute_wis()

Computing WIS for each forecasted point...


Computing WIS: 100%|██████████| 43196/43196 [00:16<00:00, 2644.36it/s]


In [15]:
wis_df

Unnamed: 0,Unique_id,Reference Date,Target End Date,GT,WIS
0,Y_1,2025-06-14,2025-06-21,0.0,9.245807
1,Y_1,2025-06-14,2025-06-28,0.0,9.233771
2,Y_1,2025-06-14,2025-07-05,0.0,8.952459
3,Y_1,2025-06-14,2025-07-12,0.0,8.883299
4,Y_10,2016-10-01,2016-10-08,0.0,0.235136
...,...,...,...,...,...
43191,Y_9998,1954-12-25,1955-01-22,0.0,0.701033
43192,Y_9999,1955-03-19,1955-03-26,0.0,8.625921
43193,Y_9999,1955-03-19,1955-04-02,0.0,8.634544
43194,Y_9999,1955-03-19,1955-04-09,0.0,9.159494


In [16]:
np.mean(wis_df['WIS'].values)

1166.2658776360595

In [17]:
wis_dfs = [wis_df.iloc[i::4].reset_index(drop=True) for i in range(4)]

In [21]:
np.mean(wis_dfs[3]['WIS'].values)

1151.076970465885

In [22]:
wis_df.to_csv('CES_DF_WIS.csv')