In [22]:
# ─── Cell 1: Patch the mean warning and silence logging ───
import warnings, numpy as np, logging
from gluonts.model import forecast as _fm

# Silence the specific UserWarning
warnings.filterwarnings(
    "ignore",
    message=r"The mean prediction is not stored in the forecast data; the median is being returned instead\. This behaviour may change in the future\."
)
# Disable all WARNING and below from all loggers (including GluonTS internals)
logging.disable(logging.WARNING)

# Override .mean on the original classes (in-place)
def _silent_mean(self):
    fd = getattr(self, "_forecast_dict", {})
    if "mean" in fd:
        return fd["mean"]
    if hasattr(self, "samples"):
        return np.median(self.samples, axis=0)
    return self.quantile("p50")

_fm.SampleForecast.mean   = property(_silent_mean)
_fm.QuantileForecast.mean = property(_silent_mean)
# ────────────────────────────────────────────────────────────

In [23]:
# ─── Cell 2: Core imports ───
import torch
import time
import pandas as pd
from gluonts.dataset.common import ListDataset
from gluonts.dataset.field_names import FieldName
from gluonts.evaluation import Evaluator
from gluonts.evaluation.backtest import make_evaluation_predictions
from local.gluonts.torch.model.tft import TemporalFusionTransformerEstimator
import optuna
from optuna.samplers import TPESampler

from pytorch_lightning.utilities.model_summary import ModelSummary
# ─────────────────────────

In [24]:
from sklearn.preprocessing import StandardScaler, LabelEncoder

prediction_length = 24
context_length = 168
window_length = context_length + prediction_length
freq = "1h"

def get_electricity_dataset(csv_path: str, total_samples=500_000):
    df = pd.read_csv(csv_path, index_col=0)
    df["date"] = pd.to_datetime(df["date"])

    # Encode categorical ID
    label_encoder = LabelEncoder()
    df["categorical_id"] = label_encoder.fit_transform(df["categorical_id"].astype(str))

    # Limit data to Jan 1 – Sep 1, 2014 (i.e., days_from_start < 1339)
    full_range_df = df[df["days_from_start"] < 1339]

    # Sample sliding windows over full range
    def sample_windows(subset_df):
        samples = []
        for entity_id, group in subset_df.groupby("id"):
            group = group.sort_values("date")
            if len(group) < window_length:
                continue

            scaler = StandardScaler().fit(group[["power_usage", "hour", "day_of_week", "t"]].values)
            target_scaler = StandardScaler().fit(group[["power_usage"]].values)

            features = scaler.transform(group[["power_usage", "hour", "day_of_week", "t"]].values)
            targets = target_scaler.transform(group[["power_usage"]].values).flatten().astype(np.float32)

            feat_hour = features[:, 1].astype(np.float32)
            feat_dow = features[:, 2].astype(np.float32)
            feat_time = features[:, 3].astype(np.float32)

            static_cat = [group["categorical_id"].iloc[0]]
            dates = group["date"].values

            for i in range(0, len(group) - window_length + 1):
                samples.append({
                    FieldName.START: dates[i],
                    FieldName.TARGET: targets[i:i + window_length],
                    FieldName.FEAT_STATIC_CAT: static_cat,
                    FieldName.FEAT_DYNAMIC_REAL: [
                        feat_hour[i:i + window_length],
                        feat_dow[i:i + window_length],
                        feat_time[i:i + window_length],
                    ],
                })

        return samples

    # Step 1: All possible windows up to Sep 1
    all_samples = sample_windows(full_range_df)

    # Step 2: Shuffle and take 500,000 total
    np.random.shuffle(all_samples)
    all_samples = all_samples[:total_samples]

    # Step 3: Split into 450k train / 50k val
    train_samples = all_samples[:450_000]
    val_samples = all_samples[450_000:]

    # Step 4: Test set = fixed last 7 days (same as official code)
    test_df = df[df["days_from_start"] >= 1332]
    test_samples = sample_windows(test_df)

    train_ds = ListDataset(train_samples, freq=freq)
    val_ds = ListDataset(val_samples, freq=freq)
    test_ds = ListDataset(test_samples, freq=freq)

    return train_ds, val_ds, test_ds, freq, prediction_length


In [25]:
file_path = "../Dataset/Electricity/hourly_electricity.csv"  # Adjust if it's in a subfolder

In [26]:
# ─── Cell 4: Load data & set precision ───
torch.set_float32_matmul_precision("high")
train_ds, val_ds, test_ds, freq, prediction_length = get_electricity_dataset(file_path)

In [27]:
def objective(trial):
    # ── 1) Sample hyperparameters ──
    batch_size = trial.suggest_categorical("batch_size", [64, 128, 256])
    num_heads  = trial.suggest_categorical("num_heads",  [1,   4  ])
    # hidden_dim = trial.suggest_categorical("hidden_dim",[80, 160, 240])

    # ── 2) Build estimator with sampled batch_size ──
    print("Batch Size: ", batch_size)
    print("Num heads: ",num_heads)
    # print("Hidden dim: ",hidden_dim)
    
    estimator = TemporalFusionTransformerEstimator(
        freq=freq,
        prediction_length=prediction_length,
        context_length=168,
        static_cardinalities=[370],
        dynamic_dims=[3],
        quantiles=[0.1, 0.5, 0.9],
        hidden_dim=160,
        num_heads=num_heads,
        batch_size=batch_size,
        num_batches_per_epoch= int(len(train_ds) // batch_size),
        trainer_kwargs={
            "accelerator": "gpu",
            "devices": [0],
            "max_epochs": 3,
            "precision": "bf16-mixed",
        },
    )

    # Optional: print model summary
    module  = estimator.create_lightning_module()
    summary = ModelSummary(module, max_depth=1)
    print(summary)

    # ── 3) Time & memory for training ──
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
    t0 = time.time()
    predictor = estimator.train(
        training_data=train_ds,
        validation_data=val_ds
    )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        train_peak_mem = torch.cuda.max_memory_allocated() / (1024**2)
    train_time = time.time() - t0
    print(f"[bs={batch_size}] Training: {train_time:.3f}s, Peak GPU mem: {train_peak_mem:.1f} MB")

    # ── 4) Time & memory for inference ──
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
    t1 = time.time()
    # (keeps your existing precision setting)
    with torch.amp.autocast("cuda"):
        f_it, ts_it = make_evaluation_predictions(
            dataset=test_ds,
            predictor=predictor,
            num_samples=100,
        )
        forecasts = list(f_it)
        tss       = list(ts_it)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        inf_peak_mem = torch.cuda.max_memory_allocated() / (1024**2)
    inf_time = time.time() - t1
    print(f"[bs={batch_size}] Inference: {inf_time:.3f}s, Peak GPU mem: {inf_peak_mem:.1f} MB")

    # ── 5) Time evaluation and return metric ──
    t2 = time.time()
    evaluator    = Evaluator(quantiles=[0.5], num_workers=0)
    agg_metrics, _ = evaluator(tss, forecasts)
    eval_time = time.time() - t2
    print(f"[bs={batch_size}] Evaluation time: {eval_time:.3f}s")

    return agg_metrics["MASE"]


In [29]:
# ─── Cell 6: Run Optuna optimization ───
from optuna.samplers import GridSampler
# ─── Cell 6: Run Optuna optimization ───
search_space = {
    "batch_size": [64, 128, 256],
    "num_heads":  [1,   4],
}

# 2. Create the GridSampler
sampler = GridSampler(search_space)

study = optuna.create_study(direction="minimize", sampler=sampler)
start = time.time()
study.optimize(objective, n_trials=6, timeout=36000)
end = time.time() - start

print("total time taken:", end)
print("Best trial:")
print(study.best_trial)
# ─────────────────────────────────────────────

Batch Size:  256
Num heads:  4
  | Name  | Type                           | Params | Mode  | In sizes                                                                                 | Out sizes                     
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0 | model | TemporalFusionTransformerModel | 858 K  | train | [[1, 168], [1, 168], [1, 1], [1, 1], [1, 192, 7], [1, 192, 0], [1, 168, 0], [1, 168, 0]] | [[[1, 24, 3]], [1, 1], [1, 1]]
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
858 K     Trainable params
0         Non-trainable params
858 K     Total params
3.435     Total estimated model params size (MB)
250       Modules in train mode
0         Modules in eval mode


/home/akm9999/.local/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /share/apps/pyenv/py3.9/lib/python3.9/site-packages/ ...
/home/akm9999/.local/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /scratch/akm9999/Project/FlashAttention/lightning_logs/version_60858/checkpoints exists and is not empty.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[bs=256] Training: 1125.634s, Peak GPU mem: 871.8 MB
[bs=256] Inference: 41.955s, Peak GPU mem: 365.5 MB




Running evaluation: 0it [00:00, ?it/s][A[A

Running evaluation: 3179it [00:10, 317.89it/s][A[A

Running evaluation: 6358it [00:20, 317.73it/s][A[A

Running evaluation: 9542it [00:30, 318.04it/s][A[A

Running evaluation: 12737it [00:40, 318.61it/s][A[A

Running evaluation: 15932it [00:50, 318.76it/s][A[A

Running evaluation: 19133it [01:00, 319.20it/s][A[A

Running evaluation: 22334it [01:10, 319.24it/s][A[A

Running evaluation: 25128it [01:20, 305.67it/s][A[A

Running evaluation: 25128it [01:20, 305.67it/s][A[A

Running evaluation: 28336it [01:30, 310.36it/s][A[A

Running evaluation: 31558it [01:40, 314.00it/s][A[A

Running evaluation: 34780it [01:50, 316.25it/s][A[A

Running evaluation: 38003it [02:00, 318.08it/s][A[A

Running evaluation: 41226it [02:10, 319.17it/s][A[A

Running evaluation: 44445it [02:20, 319.98it/s][A[A

Running evaluation: 47664it [02:30, 320.01it/s][A[A

Running evaluation: 53505it [02:48, 317.66it/s][A[A


[bs=256] Evaluation time: 168.750s
Batch Size:  128
Num heads:  1
  | Name  | Type                           | Params | Mode  | In sizes                                                                                 | Out sizes                     
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0 | model | TemporalFusionTransformerModel | 858 K  | train | [[1, 168], [1, 168], [1, 1], [1, 1], [1, 192, 7], [1, 192, 0], [1, 168, 0], [1, 168, 0]] | [[[1, 24, 3]], [1, 1], [1, 1]]
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
858 K     Trainable params
0         Non-trainable params
858 K     Total params
3.435     Total estimated model params size (MB)
250       Modules in train mode
0         Modules in eva

/home/akm9999/.local/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /share/apps/pyenv/py3.9/lib/python3.9/site-packages/ ...
/home/akm9999/.local/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /scratch/akm9999/Project/FlashAttention/lightning_logs/version_60858/checkpoints exists and is not empty.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[bs=128] Training: 1400.961s, Peak GPU mem: 450.2 MB
[bs=128] Inference: 42.358s, Peak GPU mem: 183.6 MB




Running evaluation: 0it [00:00, ?it/s][A[A

Running evaluation: 3167it [00:10, 316.67it/s][A[A

Running evaluation: 6357it [00:20, 317.99it/s][A[A

Running evaluation: 9547it [00:30, 318.33it/s][A[A

Running evaluation: 12736it [00:40, 318.56it/s][A[A

Running evaluation: 15925it [00:50, 317.70it/s][A[A

Running evaluation: 19088it [01:01, 302.39it/s][A[A

Running evaluation: 22287it [01:11, 307.89it/s][A[A

Running evaluation: 25486it [01:21, 311.28it/s][A[A

Running evaluation: 28684it [01:31, 313.88it/s][A[A

Running evaluation: 31882it [01:41, 315.47it/s][A[A

Running evaluation: 35074it [01:51, 316.59it/s][A[A

Running evaluation: 38284it [02:01, 317.92it/s][A[A

Running evaluation: 41494it [02:11, 318.55it/s][A[A

Running evaluation: 44695it [02:21, 318.99it/s][A[A

Running evaluation: 47903it [02:31, 319.52it/s][A[A

Running evaluation: 53505it [02:49, 316.27it/s][A[A


[bs=128] Evaluation time: 169.494s
Batch Size:  64
Num heads:  4
  | Name  | Type                           | Params | Mode  | In sizes                                                                                 | Out sizes                     
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0 | model | TemporalFusionTransformerModel | 858 K  | train | [[1, 168], [1, 168], [1, 1], [1, 1], [1, 192, 7], [1, 192, 0], [1, 168, 0], [1, 168, 0]] | [[[1, 24, 3]], [1, 1], [1, 1]]
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
858 K     Trainable params
0         Non-trainable params
858 K     Total params
3.435     Total estimated model params size (MB)
250       Modules in train mode
0         Modules in eval

/home/akm9999/.local/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /share/apps/pyenv/py3.9/lib/python3.9/site-packages/ ...
/home/akm9999/.local/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /scratch/akm9999/Project/FlashAttention/lightning_logs/version_60858/checkpoints exists and is not empty.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[bs=64] Training: 2003.377s, Peak GPU mem: 240.5 MB
[bs=64] Inference: 46.063s, Peak GPU mem: 112.1 MB




Running evaluation: 0it [00:00, ?it/s][A[A

Running evaluation: 3202it [00:10, 320.14it/s][A[A

Running evaluation: 6404it [00:20, 319.68it/s][A[A

Running evaluation: 9537it [00:31, 303.22it/s][A[A

Running evaluation: 9537it [00:31, 303.22it/s][A[A

Running evaluation: 12666it [00:41, 306.91it/s][A[A

Running evaluation: 15812it [00:51, 309.62it/s][A[A

Running evaluation: 18958it [01:01, 311.02it/s][A[A

Running evaluation: 22149it [01:11, 313.61it/s][A[A

Running evaluation: 25336it [01:21, 314.50it/s][A[A

Running evaluation: 25336it [01:21, 314.50it/s][A[A

Running evaluation: 28500it [01:31, 315.04it/s][A[A

Running evaluation: 31673it [01:41, 315.73it/s][A[A

Running evaluation: 34846it [01:51, 316.15it/s][A[A

Running evaluation: 38019it [02:01, 316.48it/s][A[A

Running evaluation: 41192it [02:11, 316.43it/s][A[A

Running evaluation: 44362it [02:21, 316.59it/s][A[A

Running evaluation: 47140it [02:31, 302.28it/s][A[A

Running evaluation: 4

[bs=64] Evaluation time: 171.884s
Batch Size:  128
Num heads:  4
  | Name  | Type                           | Params | Mode  | In sizes                                                                                 | Out sizes                     
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0 | model | TemporalFusionTransformerModel | 858 K  | train | [[1, 168], [1, 168], [1, 1], [1, 1], [1, 192, 7], [1, 192, 0], [1, 168, 0], [1, 168, 0]] | [[[1, 24, 3]], [1, 1], [1, 1]]
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
858 K     Trainable params
0         Non-trainable params
858 K     Total params
3.435     Total estimated model params size (MB)
250       Modules in train mode
0         Modules in eval

/home/akm9999/.local/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /share/apps/pyenv/py3.9/lib/python3.9/site-packages/ ...
/home/akm9999/.local/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /scratch/akm9999/Project/FlashAttention/lightning_logs/version_60858/checkpoints exists and is not empty.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[bs=128] Training: 1402.072s, Peak GPU mem: 450.6 MB
[bs=128] Inference: 40.753s, Peak GPU mem: 193.6 MB




Running evaluation: 0it [00:00, ?it/s][A[A

Running evaluation: 3190it [00:10, 318.92it/s][A[A

Running evaluation: 6380it [00:20, 318.69it/s][A[A

Running evaluation: 9570it [00:30, 318.80it/s][A[A

Running evaluation: 12772it [00:40, 319.32it/s][A[A

Running evaluation: 15974it [00:50, 319.24it/s][A[A

Running evaluation: 19168it [01:00, 319.29it/s][A[A

Running evaluation: 22374it [01:10, 319.71it/s][A[A

Running evaluation: 25580it [01:20, 319.39it/s][A[A

Running evaluation: 28767it [01:31, 305.95it/s][A[A

Running evaluation: 31958it [01:41, 309.86it/s][A[A

Running evaluation: 35149it [01:51, 312.62it/s][A[A

Running evaluation: 38340it [02:01, 314.46it/s][A[A

Running evaluation: 41528it [02:11, 315.45it/s][A[A

Running evaluation: 44706it [02:21, 316.03it/s][A[A

Running evaluation: 47881it [02:31, 316.33it/s][A[A

Running evaluation: 53505it [02:49, 316.15it/s][A[A


[bs=128] Evaluation time: 169.565s
Batch Size:  64
Num heads:  1
  | Name  | Type                           | Params | Mode  | In sizes                                                                                 | Out sizes                     
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0 | model | TemporalFusionTransformerModel | 858 K  | train | [[1, 168], [1, 168], [1, 1], [1, 1], [1, 192, 7], [1, 192, 0], [1, 168, 0], [1, 168, 0]] | [[[1, 24, 3]], [1, 1], [1, 1]]
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
858 K     Trainable params
0         Non-trainable params
858 K     Total params
3.435     Total estimated model params size (MB)
250       Modules in train mode
0         Modules in eval

/home/akm9999/.local/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /share/apps/pyenv/py3.9/lib/python3.9/site-packages/ ...
/home/akm9999/.local/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /scratch/akm9999/Project/FlashAttention/lightning_logs/version_60858/checkpoints exists and is not empty.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[bs=64] Training: 1987.662s, Peak GPU mem: 240.4 MB
[bs=64] Inference: 48.903s, Peak GPU mem: 112.1 MB




Running evaluation: 0it [00:00, ?it/s][A[A

Running evaluation: 3166it [00:10, 316.60it/s][A[A

Running evaluation: 6346it [00:20, 317.41it/s][A[A

Running evaluation: 9527it [00:30, 317.72it/s][A[A

Running evaluation: 12710it [00:40, 317.94it/s][A[A

Running evaluation: 15895it [00:50, 318.13it/s][A[A

Running evaluation: 19080it [01:00, 318.11it/s][A[A

Running evaluation: 22261it [01:10, 318.10it/s][A[A

Running evaluation: 25442it [01:20, 317.94it/s][A[A

Running evaluation: 28619it [01:31, 304.64it/s][A[A

Running evaluation: 31799it [01:41, 308.62it/s][A[A

Running evaluation: 34981it [01:51, 311.49it/s][A[A

Running evaluation: 38163it [02:01, 313.50it/s][A[A

Running evaluation: 41347it [02:11, 314.94it/s][A[A

Running evaluation: 44531it [02:21, 315.54it/s][A[A

Running evaluation: 47719it [02:31, 316.52it/s][A[A

Running evaluation: 53505it [02:49, 315.35it/s][A[A


[bs=64] Evaluation time: 169.990s
Batch Size:  256
Num heads:  1
  | Name  | Type                           | Params | Mode  | In sizes                                                                                 | Out sizes                     
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0 | model | TemporalFusionTransformerModel | 858 K  | train | [[1, 168], [1, 168], [1, 1], [1, 1], [1, 192, 7], [1, 192, 0], [1, 168, 0], [1, 168, 0]] | [[[1, 24, 3]], [1, 1], [1, 1]]
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
858 K     Trainable params
0         Non-trainable params
858 K     Total params
3.435     Total estimated model params size (MB)
250       Modules in train mode
0         Modules in eval

/home/akm9999/.local/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /share/apps/pyenv/py3.9/lib/python3.9/site-packages/ ...
/home/akm9999/.local/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /scratch/akm9999/Project/FlashAttention/lightning_logs/version_60858/checkpoints exists and is not empty.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[bs=256] Training: 1126.620s, Peak GPU mem: 860.8 MB
[bs=256] Inference: 39.522s, Peak GPU mem: 355.4 MB




Running evaluation: 0it [00:00, ?it/s][A[A

Running evaluation: 3197it [00:10, 319.66it/s][A[A

Running evaluation: 6407it [00:20, 320.42it/s][A[A

Running evaluation: 9617it [00:30, 320.57it/s][A[A

Running evaluation: 12834it [00:40, 320.99it/s][A[A

Running evaluation: 16051it [00:50, 321.02it/s][A[A

Running evaluation: 19262it [01:00, 320.96it/s][A[A

Running evaluation: 22486it [01:10, 321.42it/s][A[A

Running evaluation: 25710it [01:20, 321.36it/s][A[A

Running evaluation: 28923it [01:31, 308.49it/s][A[A

Running evaluation: 32135it [01:41, 312.30it/s][A[A

Running evaluation: 35353it [01:51, 315.14it/s][A[A

Running evaluation: 38571it [02:01, 316.67it/s][A[A

Running evaluation: 41774it [02:11, 317.70it/s][A[A

Running evaluation: 44999it [02:21, 319.12it/s][A[A

Running evaluation: 48224it [02:31, 319.98it/s][A[A

Running evaluation: 53505it [02:47, 318.69it/s][A[A


[bs=256] Evaluation time: 168.204s
total time taken: 10327.435197591782
Best trial:
FrozenTrial(number=3, state=TrialState.COMPLETE, values=[0.8949014631189218], datetime_start=datetime.datetime(2025, 5, 9, 4, 2, 20, 407256), datetime_complete=datetime.datetime(2025, 5, 9, 4, 29, 13, 422558), params={'batch_size': 128, 'num_heads': 4}, user_attrs={}, system_attrs={'search_space': {'batch_size': [64, 128, 256], 'num_heads': [1, 4]}, 'grid_id': 3}, intermediate_values={}, distributions={'batch_size': CategoricalDistribution(choices=(64, 128, 256)), 'num_heads': CategoricalDistribution(choices=(1, 4))}, trial_id=3, value=None)


In [None]:
# 1) Grab the best hyperparameters
best = study.best_trial.params
num_heads    = best["num_heads"]
# hidden_dim   = best["hidden_dim"]
batch_size   = best["batch_size"]


# collect into a dict
best_params = {
    "num_heads":    num_heads,
    # "hidden_dim":   hidden_dim,
    "batch_size":   batch_size,
}

# pretty‑print each name and value
for name, value in best_params.items():
    print(f"{name:12s}: {value}")

# # 2) Build a new estimator with full‑training epochs
# best_estimator = TemporalFusionTransformerEstimator(
#     freq="1H",
#     prediction_length=24,
#     context_length=168,
#     hidden_dim=hidden_dim,         # 6 * 20 = 120
#     num_heads=num_heads,           # 6
#     dropout_rate=dropout_rate,     # ~0.26
#     lr=lr,
#     weight_decay=weight_decay,
#     batch_size=64,
#     num_batches_per_epoch=50,
#     trainer_kwargs={
#         "accelerator": "gpu",
#         "max_epochs": 1,           # full training run
#     },
# )

# # 3) Train on train_ds (and val_ds if you want early stopping)
# best_predictor = best_estimator.train(
#     training_data=train_ds,
#     validation_data=val_ds
# )

# forecast_it, ts_it = make_evaluation_predictions(
#     dataset=test_ds,
#     predictor=best_predictor,
#     num_samples=100,
# )
# agg_metrics, item_metrics = Evaluator()(ts_it, forecast_it)

# print("Test set metrics:", agg_metrics)


In [None]:
from pathlib import Path

# choose an output directory
save_path = Path("model_dir")

# this will create model_dir/ with all the predictor files inside
best_predictor.serialize(save_path)
