In [1]:
# ─── Cell 1: Patch the mean warning and silence logging ───
import warnings, numpy as np, logging
from gluonts.model import forecast as _fm

# Ignore only the warning whose message matches the GluonTS
# "mean prediction is not stored" text (the pattern is a regex, hence the
# escaped dots).
warnings.filterwarnings(
    action="ignore",
    message=(
        r"The mean prediction is not stored in the forecast data; "
        r"the median is being returned instead\. "
        r"This behaviour may change in the future\."
    ),
)

# Globally suppress every log record at WARNING level or lower, for all
# loggers in the process (not only GluonTS).
logging.disable(logging.WARNING)

# Override .mean on the original classes (in-place)
def _silent_mean(self):
    fd = getattr(self, "_forecast_dict", {})
    if "mean" in fd:
        return fd["mean"]
    if hasattr(self, "samples"):
        return np.median(self.samples, axis=0)
    return self.quantile("p50")

# Replace the ``mean`` attribute of both forecast classes with a read-only
# property backed by ``_silent_mean``; the classes are patched in place, so
# every instance sees the override.
for _forecast_cls in (_fm.SampleForecast, _fm.QuantileForecast):
    setattr(_forecast_cls, "mean", property(_silent_mean))
# ────────────────────────────────────────────────────────────

In [2]:
import os
import time
from pathlib import Path

import torch
from gluonts.evaluation import Evaluator
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.model.predictor import Predictor
# ─── Cell 2: Core imports ───
import pandas as pd
from gluonts.dataset.common import ListDataset
from gluonts.dataset.field_names import FieldName
from local.gluonts.torch.model.tft import TemporalFusionTransformerEstimator
import optuna
from optuna.samplers import TPESampler

from pytorch_lightning.utilities.model_summary import ModelSummary
# ─────────────────────────


In [3]:
from sklearn.preprocessing import StandardScaler, LabelEncoder

# Forecast horizon, in time steps of `freq`.
prediction_length = 24
# Number of past time steps included in each window as history
# (168 steps = 7 * 24).
context_length = 168
# Total length of one sliding window: history followed by the forecast horizon.
window_length = context_length + prediction_length
# Frequency string handed to ListDataset for every series.
freq = "1h"

def get_electricity_dataset(
    csv_path: str,
    total_samples=500_000,
    train_fraction: float = 0.9,
    seed=None,
):
    """Build train/validation/test ``ListDataset`` objects of sliding windows.

    The CSV is read with its first column as index and must provide the
    columns ``date``, ``id``, ``categorical_id``, ``days_from_start``,
    ``power_usage``, ``hour``, ``day_of_week`` and ``t``.

    For every ``id`` group, ``power_usage`` (target) and the covariates
    ``hour``, ``day_of_week`` and ``t`` are standardized per group, and every
    contiguous window of ``window_length`` rows becomes one dataset entry.

    Parameters
    ----------
    csv_path:
        Path of the CSV file.
    total_samples:
        Maximum number of shuffled train+validation windows to keep.
    train_fraction:
        Share of the kept windows assigned to training; the remainder is
        validation. With the defaults (500_000 windows, 0.9) this is the
        450k / 50k split. The split now scales with the number of windows
        actually kept instead of being fixed at index 450_000, so the
        validation set is not empty when fewer windows are available or a
        smaller ``total_samples`` is requested.
    seed:
        Optional seed for the shuffle. ``None`` (default) shuffles with
        NumPy's global random state, as before; an integer gives a
        reproducible shuffle.

    Returns
    -------
    tuple
        ``(train_ds, val_ds, test_ds, freq, prediction_length)``.

    Raises
    ------
    ValueError
        If ``train_fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError(
            f"train_fraction must be in (0, 1), got {train_fraction!r}"
        )

    df = pd.read_csv(csv_path, index_col=0)
    df["date"] = pd.to_datetime(df["date"])

    # Encode categorical ID as consecutive integers
    label_encoder = LabelEncoder()
    df["categorical_id"] = label_encoder.fit_transform(df["categorical_id"].astype(str))

    # Train/validation range: days_from_start < 1339
    # (per the original notebook comment: Jan 1 – Sep 1, 2014)
    full_range_df = df[df["days_from_start"] < 1339]

    def sample_windows(subset_df):
        """Return every sliding window of ``window_length`` rows per ``id``."""
        samples = []
        for _, group in subset_df.groupby("id"):
            group = group.sort_values("date")
            if len(group) < window_length:
                # Too short to yield even a single window
                continue

            # Scalers are fit on the rows of *this subset* only.
            # NOTE(review): test windows are therefore standardized with
            # statistics from the test rows, not from the training range —
            # confirm this is intended.
            features = StandardScaler().fit_transform(
                group[["hour", "day_of_week", "t"]].values
            )
            targets = (
                StandardScaler()
                .fit_transform(group[["power_usage"]].values)
                .flatten()
                .astype(np.float32)
            )

            feat_hour = features[:, 0].astype(np.float32)
            feat_dow = features[:, 1].astype(np.float32)
            feat_time = features[:, 2].astype(np.float32)

            static_cat = [group["categorical_id"].iloc[0]]
            dates = group["date"].values

            for i in range(0, len(group) - window_length + 1):
                samples.append({
                    FieldName.START: dates[i],
                    FieldName.TARGET: targets[i:i + window_length],
                    FieldName.FEAT_STATIC_CAT: static_cat,
                    FieldName.FEAT_DYNAMIC_REAL: [
                        feat_hour[i:i + window_length],
                        feat_dow[i:i + window_length],
                        feat_time[i:i + window_length],
                    ],
                })

        return samples

    # Step 1: all possible windows in the train/validation range
    all_samples = sample_windows(full_range_df)

    # Step 2: shuffle and keep at most `total_samples` windows
    if seed is None:
        np.random.shuffle(all_samples)
    else:
        np.random.default_rng(seed).shuffle(all_samples)
    all_samples = all_samples[:total_samples]

    # Step 3: split the kept windows into train / validation
    n_train = int(round(len(all_samples) * train_fraction))
    train_samples = all_samples[:n_train]
    val_samples = all_samples[n_train:]

    # Step 4: test set = all windows over rows with days_from_start >= 1332
    test_df = df[df["days_from_start"] >= 1332]
    test_samples = sample_windows(test_df)

    train_ds = ListDataset(train_samples, freq=freq)
    val_ds = ListDataset(val_samples, freq=freq)
    test_ds = ListDataset(test_samples, freq=freq)

    return train_ds, val_ds, test_ds, freq, prediction_length


In [4]:
# Relative path of the hourly electricity CSV passed to
# get_electricity_dataset(); it is resolved against the kernel's working
# directory.
file_path = "../Dataset/Electricity/hourly_electricity.csv"  # Adjust if it's in a subfolder

In [5]:
# ─── Load data & set precision ───
# Select PyTorch's "high" float32 matmul precision mode.
torch.set_float32_matmul_precision("high")
# Build the three datasets; note that this rebinds the module-level `freq`
# and `prediction_length` to the values returned by the function.
train_ds, val_ds, test_ds, freq, prediction_length = get_electricity_dataset(file_path)

In [27]:
import shutil
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# Folder from which the full-precision predictor is deserialized.
orig_dir = Path("saved_model_flashNew")       # your full-precision predictor folder
# Folder that receives the quantized predictor (it is deleted and recreated
# further down).
int8_dir = Path("saved_model_flashNew_int8")  # where to write the quantized predictor
# NOTE(review): presumably this only takes effect if CUDA has not been
# initialized yet in this process — torch is already imported at this point;
# confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = ""       # force CPU-only

# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def safe_load_predictor(folder: Path) -> Predictor:
    """
    Load a GluonTS Predictor from 'folder' on CPU, tolerating missing or
    unexpected state-dict keys.

    While ``Predictor.deserialize`` runs, two things are patched temporarily:

    * ``torch.load`` always maps tensors to the CPU and unpickles with
      ``weights_only=False``; any other arguments given by the caller are
      forwarded unchanged.
    * ``torch.nn.Module.load_state_dict`` is always called with
      ``strict=False``. This only relaxes the check for missing/unexpected
      keys; a size mismatch on a matching key still raises.

    Both originals are restored afterwards, even if deserialization fails.

    SECURITY: ``weights_only=False`` performs full unpickling, which can
    execute arbitrary code. Only use this on checkpoints you trust.
    """
    import torch as _torch

    orig_torch_load = _torch.load
    orig_load_state = _torch.nn.Module.load_state_dict

    def cpu_load(f, *args, **kwargs):
        # The second positional parameter of torch.load is map_location;
        # override it wherever the caller supplied it.
        if args:
            args = ("cpu",) + tuple(args[1:])
            kwargs.pop("map_location", None)
        else:
            kwargs["map_location"] = "cpu"
        kwargs["weights_only"] = False
        return orig_torch_load(f, *args, **kwargs)

    def loose_load_state(self, state_dict, strict=True, **kwargs):
        # Forward extra keyword arguments (e.g. ``assign`` on newer PyTorch)
        # so callers using them do not hit a TypeError.
        return orig_load_state(self, state_dict, strict=False, **kwargs)

    try:
        # Install the patches inside the try so they are always undone.
        _torch.load = cpu_load
        _torch.nn.Module.load_state_dict = loose_load_state
        pred = Predictor.deserialize(folder)
    finally:
        _torch.load = orig_torch_load
        _torch.nn.Module.load_state_dict = orig_load_state

    return pred



In [28]:
def quantize_predictor(predictor: Predictor) -> Predictor:
    """
    Apply dynamic post-training quantization to all Linear and LSTM layers
    in the predictor's PyTorch model.

    The model at ``predictor.prediction_net.model`` is moved to the CPU,
    quantized to ``torch.qint8``, and the quantized model is assigned back
    onto the same predictor. The predictor passed in is therefore modified
    in place; the same object is returned for convenience.
    """
    # Imported locally because no `tq` alias is defined by the notebook's
    # import cells; without this the function fails with NameError on a
    # fresh kernel.
    import torch.ao.quantization as tq

    model_fp = predictor.prediction_net.model.cpu()
    quantized_model = tq.quantize_dynamic(
        model_fp,
        {torch.nn.Linear, torch.nn.LSTM},
        dtype=torch.qint8
    )
    predictor.prediction_net.model = quantized_model
    return predictor



In [29]:
def evaluate_predictor(predictor: Predictor, test_ds, num_samples: int = 100):
    """
    Run backtest + evaluation and return key metrics and timings.

    Parameters
    ----------
    predictor:
        Predictor to evaluate.
    test_ds:
        Dataset handed to ``make_evaluation_predictions``.
    num_samples:
        Forwarded to ``make_evaluation_predictions``.

    Returns
    -------
    dict
        ``"RMSE"``, ``"MASE"`` and ``"sMAPE"`` taken from the aggregate
        metrics, plus ``"inf_time"`` (seconds spent materializing all
        forecasts and ground-truth series) and ``"eval_time"`` (seconds spent
        in the evaluator call).
    """
    # Inference timing. perf_counter() is monotonic, so the measured interval
    # cannot be distorted by system clock adjustments.
    inference_start = time.perf_counter()
    forecast_it, ts_it = make_evaluation_predictions(
        dataset=test_ds, predictor=predictor, num_samples=num_samples
    )
    # The iterators are consumed here, inside the timed section.
    forecasts = list(forecast_it)
    tss       = list(ts_it)
    inf_time  = time.perf_counter() - inference_start

    # Metrics timing (evaluator construction is excluded)
    evaluator = Evaluator(quantiles=[0.5], num_workers=0)
    evaluation_start = time.perf_counter()
    agg_metrics, _ = evaluator(iter(tss), iter(forecasts))
    eval_time = time.perf_counter() - evaluation_start

    return {
        "RMSE":     agg_metrics["RMSE"],
        "MASE":     agg_metrics["MASE"],
        "sMAPE":    agg_metrics["sMAPE"],
        "inf_time": inf_time,
        "eval_time": eval_time,
    }



In [30]:
def file_size_mb(path: Path) -> float:
    """Return the size of the file at ``path`` in mebibytes (bytes / 1024**2)."""
    bytes_per_mib = 1024 * 1024
    size_in_bytes = path.stat().st_size
    return size_in_bytes / bytes_per_mib

# -----------------------------------------------------------------------------
# Prepare output folder
# -----------------------------------------------------------------------------
# Start from an empty output folder: anything already in int8_dir is deleted.
if int8_dir.exists():
    shutil.rmtree(int8_dir)
int8_dir.mkdir()

# -----------------------------------------------------------------------------
# Load full-precision predictor
# -----------------------------------------------------------------------------
fp_pred = safe_load_predictor(orig_dir)

# -----------------------------------------------------------------------------
# Quantize a fresh copy of it
# -----------------------------------------------------------------------------
# The predictor is loaded a second time because quantize_predictor() replaces
# the model on the predictor it is given; fp_pred stays untouched.
int8_pred = safe_load_predictor(orig_dir)
int8_pred = quantize_predictor(int8_pred)

# -----------------------------------------------------------------------------
# Serialize the quantized predictor
# -----------------------------------------------------------------------------
int8_pred.serialize(int8_dir)
# Additionally save just the quantized model's state dict as a standalone
# file; this is the file measured in the size comparison below.
torch.save(
    int8_pred.prediction_net.model.state_dict(),
    int8_dir / "tft_flash_weights_int8.pt"
)
# -----------------------------------------------------------------------------
# Compare file sizes
# -----------------------------------------------------------------------------
# NOTE(review): assumes orig_dir already contains "tft_flash_weights.pt"
# (it is not written anywhere in this notebook) — confirm.
orig_size = file_size_mb(orig_dir  / "tft_flash_weights.pt")
int8_size = file_size_mb(int8_dir  / "tft_flash_weights_int8.pt")
print(f"Full-precision size: {orig_size:.2f} MB")
print(f"Quantized size:      {int8_size:.2f} MB")
# Relative reduction with respect to the full-precision file.
print(f"Size reduction:      {(orig_size - int8_size) / orig_size * 100:.1f}%\n")



Full-precision size: 3.40 MB
Quantized size:      1.07 MB
Size reduction:      68.5%



In [31]:
# -----------------------------------------------------------------------------
# Evaluate both predictors on your test_ds (must be in scope)
# -----------------------------------------------------------------------------
# Evaluate both predictors on the same test set. Each call returns a dict
# with RMSE, MASE, sMAPE and the inference/evaluation timings.
metrics_fp   = evaluate_predictor(fp_pred,   test_ds)
metrics_int8 = evaluate_predictor(int8_pred, test_ds)


Running evaluation: 53505it [02:46, 322.28it/s]
Running evaluation: 53505it [02:46, 321.22it/s]


In [44]:
# -----------------------------------------------------------------------------
# Print results
# -----------------------------------------------------------------------------
print("Full-precision metrics:", metrics_fp)
print("Quantized metrics:    ", metrics_int8)

# Relative time reductions of the quantized predictor with respect to the
# full-precision one: positive = quantized is faster, negative = slower.
print(f"Inference time ↓: {(metrics_fp['inf_time']   - metrics_int8['inf_time'])   / metrics_fp['inf_time']   * 100:.1f}%")
print(f"Eval time ↓:      {(metrics_fp['eval_time'] - metrics_int8['eval_time']) / metrics_fp['eval_time'] * 100:.1f}%")

# Absolute inference times in seconds. These were previously printed under
# the "Inference time ↓" label, which made them look like reductions.
print(f"Inference time (full-precision): {metrics_fp['inf_time']:.4f}s")
print(f"Inference time (quantized):      {metrics_int8['inf_time']:.4f}s")


Full-precision metrics: {'RMSE': np.float64(0.42567054745741684), 'MASE': np.float64(0.9383791327488374), 'sMAPE': np.float64(0.44552324287192097), 'inf_time': 99.15749335289001, 'eval_time': 166.3502960205078}
Quantized metrics:     {'RMSE': np.float64(0.44229397355988903), 'MASE': np.float64(1.0373053124277989), 'sMAPE': np.float64(0.47795135155647406), 'inf_time': 122.00386500358582, 'eval_time': 166.88988614082336}
Inference time ↓: -23.0%
Eval time ↓:      -0.3%
Inference time ↓: 99.1575s
Inference time ↓: 122.0039s
