# TFT model

## Preliminaries

You can run the notebook in two ways:

1. **Google Colab**: place the project folder `heat-forecast` in **MyDrive**. The setup cell below will mount Drive and automatically add `MyDrive/heat-forecast/src` to `sys.path` so `import heat_forecast` works out of the box.

2. **Local machine**:

   * **Installing our package:** from the project root, run `pip install -e .` once (editable install). Then you can open the notebook anywhere and import the package normally.
   * **Alternative:** if you’re running the notebook from `.../heat-forecast/notebooks/` without installing the package, the setup cell will detect `../src` and automatically add it to `sys.path`.

In [None]:
# --- Detect if running on Google Colab & Set base dir ---
# %cd /home/giovanni.lombardi/heat-forecast/notebooks
import subprocess
from pathlib import Path
import sys

def in_colab() -> bool:
    """Return True when the `google.colab` module can be imported, else False."""
    try:
        import google.colab  # type: ignore  # noqa: F401
    except Exception:
        # Any failure while importing is treated as "not running on Colab"
        return False
    else:
        return True

# Quietly run `pip install` for one requirement, using the current interpreter's pip
def pip_install(pkg: str):
    """Run `<python> -m pip install -q <pkg>` with the running interpreter.

    Args:
        pkg: Requirement string handed to pip (e.g. a name or `name>=version`).

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    pip_cmd = [sys.executable, "-m", "pip", "install", "-q"]
    pip_cmd.append(pkg)
    subprocess.check_call(pip_cmd)

# Set base directory and handle environment.
# `added_src` records whether this cell had to extend sys.path. It is initialised
# before branching so that it is defined on Colab as well as locally: it is read
# further down in this cell when logging how the package was made importable.
added_src = False

if in_colab():
    # Make sure IPython is modern (avoids the old %autoreload/imp issue if you ever use it)
    pip_install("ipython>=8.25")
    pip_install("ipykernel>=6.29")

    # Non-quiet variant of `pip_install`; not called in this cell, kept so the name stays defined
    def install(package):
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

    for pkg in ["statsmodels", "statsforecast", "mlforecast"]:
        pip_install(pkg)

    # Mount Google Drive
    from google.colab import drive  # type: ignore
    drive.mount('/content/drive')

    # Set base directory to your Drive project folder
    BASE_DIR = Path('/content/drive/MyDrive/heat-forecast')

    # Add `src/` to sys.path for custom package imports
    SRC_PATH = BASE_DIR / 'src'
    if str(SRC_PATH) not in sys.path:
        sys.path.append(str(SRC_PATH))
        added_src = True

    # Sanity checks (helpful error messages if path is wrong)
    assert SRC_PATH.exists(), f"Expected '{SRC_PATH}' to exist. Fix BASE_DIR."
    pkg_dir = SRC_PATH / "heat_forecast"
    assert pkg_dir.exists(), f"Expected '{pkg_dir}' package directory."
    init_file = pkg_dir / "__init__.py"
    assert init_file.exists(), f"Missing '{init_file}'. Add it so Python treats this as a package."

else:
    # Local: either rely on editable install (pip install -e .) or add src/ when running from repo
    # Assume notebook lives in PROJECT_ROOT/notebooks/
    BASE_DIR = Path.cwd().resolve().parent
    SRC_PATH = BASE_DIR / "src"

    # Only touch sys.path when the package sources are actually present under src/
    if (SRC_PATH / "heat_forecast").exists() and str(SRC_PATH) not in sys.path:
        sys.path.append(str(SRC_PATH))
        added_src = True

# --- Logging setup ---
# Configures the root logger to write to BASE_DIR/logs/run.log (fresh per run, the
# previous run kept as run.prev.log) and to echo to stdout, then creates and logs
# the working directories used by the rest of the notebook.
import logging
from zoneinfo import ZoneInfo
from datetime import datetime

LOG_DIR  = (BASE_DIR / "logs")
LOG_DIR.mkdir(parents=True, exist_ok=True)
LOG_FILE = LOG_DIR / "run.log"
PREV_LOG = LOG_DIR / "run.prev.log"

root = logging.getLogger()

# Detach and close any handlers already attached to the root logger (e.g. from an
# earlier execution of this cell) so their file handles are released before
# run.log is archived and reopened.
for stale_handler in list(root.handlers):
    root.removeHandler(stale_handler)
    stale_handler.close()

# If there's a previous run.log with content, archive it to run.prev.log
if LOG_FILE.exists() and LOG_FILE.stat().st_size > 0:
    try:
        # Replace old run.prev.log if present
        if PREV_LOG.exists():
            PREV_LOG.unlink()
        LOG_FILE.rename(PREV_LOG)
    except Exception as e:
        # Fall back to truncating if rename fails (e.g., file locked)
        print(f"[warn] Could not archive previous log: {e}. Truncating current run.log.")
        LOG_FILE.write_text("")

# Configure logging: fresh file for this run + echo to notebook/stdout
file_handler   = logging.FileHandler(LOG_FILE, mode="w", encoding="utf-8")
stream_handler = logging.StreamHandler(sys.stdout)

fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s",
                        datefmt="%m-%d %H:%M:%S")
file_handler.setFormatter(fmt)
stream_handler.setFormatter(fmt)

root.addHandler(file_handler)
root.addHandler(stream_handler)
root.setLevel(logging.INFO)

# Use Rome time. The converter is invoked with the record's creation timestamp as
# its last positional argument; convert that timestamp (rather than "now") so the
# logged time is when the record was created.
logging.Formatter.converter = lambda *args: datetime.fromtimestamp(
    args[-1], tz=ZoneInfo("Europe/Rome")
).timetuple()

logging.captureWarnings(True)
logging.info("=== Logging started (fresh current run) ===")
logging.info("Previous run (if any): %s", PREV_LOG if PREV_LOG.exists() else "none")

if added_src:
    logging.info("heat_forecast not installed; added src/ to sys.path")
else:
    logging.info("heat_forecast imported without modifying sys.path (likely installed)")

# Directory holding the Optuna databases for the TFT tuning studies
OPTUNA_DIR = BASE_DIR / "results" / "finetuning" / "tft"
OPTUNA_DIR.mkdir(parents=True, exist_ok=True)
logging.info("BASE_DIR (make sure it's '*/heat-forecast/', else cd and re-run): %s", BASE_DIR)
logging.info("LOG_DIR: %s", LOG_DIR)
logging.info("OPTUNA_DIR: %s", OPTUNA_DIR)

Ensure [compatibility with Numba](https://numba.readthedocs.io/en/stable/user/installing.html#numba-support-info).

In [None]:
# Log the interpreter, NumPy and Numba versions of the running environment
import sys
import numpy
import numba

logging.info("=== Current Environment ===")
python_version = sys.version.split()[0]  # first token of sys.version is the version number
logging.info("Python : %s", python_version)
logging.info("NumPy  : %s", numpy.__version__)
logging.info("Numba  : %s", numba.__version__)

Imports:

In [None]:
# --- Magic Commands ---
%load_ext autoreload
%autoreload 2

# --- Standard Library ---
import os
os.environ["OPTUNA_LOGGING_DISABLE_DEFAULT_HANDLER"] = "1" # prevent Optuna from attaching its handler
import logging
from datetime import datetime
from itertools import product
import torch
import optuna
import copy
from typing import Dict, Tuple

# --- Third-Party Libraries ---
import numpy as np
import pandas as pd
pd.set_option('display.float_format', '{:.3f}'.format)

import yaml
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from tqdm.notebook import tqdm
from IPython.display import display, HTML

# --- Plotting Configuration ---
plt.style.use("seaborn-v0_8")

# Font sizes for text, titles, axis labels, ticks and legends
plt.rcParams.update({
    'font.size': 14,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 18,
})

# Draw grid lines on both major and minor ticks
mpl.rcParams.update({
    'axes.grid': True,
    'axes.grid.which': 'both',
})

# --- Helper to detect errors caused by a keyboard interrupt ---
# PyTorch Lightning sometimes wraps KeyboardInterrupt in another exception
def is_keyboard_interrupt_like(exc: BaseException) -> bool:
    """
    Return True if `exc` is a KeyboardInterrupt or has one anywhere in its
    exception chain (e.g. NameError raised while handling KeyboardInterrupt).

    For every exception reached, both `__cause__` (explicit `raise ... from ...`)
    and `__context__` (implicit chaining) are followed, so an interrupt on either
    link is found. Passing None returns False.
    """
    pending = [exc]
    # Track object ids rather than the objects themselves: this guards against
    # cycles in the chain and also works for exceptions that are not hashable.
    seen_ids = set()
    while pending:
        cur = pending.pop()
        if cur is None or id(cur) in seen_ids:
            continue
        seen_ids.add(id(cur))
        if isinstance(cur, KeyboardInterrupt):
            return True
        pending.append(cur.__cause__)
        pending.append(cur.__context__)
    return False

# --- YAML Customization ---
from heat_forecast.utils.yaml import safe_dump_yaml

# --- Safe File Deletion Helper ---
from heat_forecast.utils.fileshandling import remove_tree

# --- Project-Specific Imports ---
from heat_forecast.utils.cv_utils import get_cv_params_for_test
from heat_forecast.pipeline.tft import (
    TFTModelConfig, DataConfig, FeatureConfig, NormalizeConfig, TFTRunConfig, TFTPipeline,
    TrainConfig
)

from heat_forecast.utils.optuna import (
    OptunaStudyConfig, run_study, continue_study, describe_suggester, rename_study, clone_filtered_study
)

from heat_forecast.utils.plotting import plotly_cutoffs_with_exog

logging.info("All imports successful.")

Load the preprocessed data.

In [None]:
# Read the preprocessed target (heat) and auxiliary tables, parsing `ds` as datetimes
preprocessed_dir = BASE_DIR / 'data' / 'timeseries_preprocessed'
heat_path = preprocessed_dir / 'heat.csv'
aux_path = preprocessed_dir / 'auxiliary.csv'

heat_df = pd.read_csv(heat_path, parse_dates=['ds'])
aux_df = pd.read_csv(aux_path, parse_dates=['ds'])

# Log the locations relative to the project root
heat_rel = heat_path.relative_to(BASE_DIR)
aux_rel = aux_path.relative_to(BASE_DIR)
logging.info("Loaded heat data: %s", heat_rel)
logging.info("Loaded auxiliary data: %s", aux_rel)

Set device:

In [None]:
# Use the CUDA device when torch reports it as available, otherwise the CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logging.info(f"Using device: {DEVICE}")

Set a seed (set to None to skip):

In [None]:
# Seed passed to the run configurations and to the Optuna study config in later cells
SEED = 42
logging.info(f"Using seed: {SEED}")

## Pipeline

In [None]:
# Build an example pipeline for a single series.
# NOTE(review): `id` shadows the Python builtin; it is also read by the plotting cell
# further down, so renaming it requires updating both cells together.
id = 'F1'
config = TFTRunConfig(
    model=TFTModelConfig(
        hidden_size=64, 
        dropout=0.3, 
        input_chunk_length=168, output_chunk_length=168,
        torch_device_str="auto"
    ),
    data=DataConfig(stride=1), # NOTE(review): stride is 1 here; a higher stride would presumably make training faster — TODO confirm intent
    train=TrainConfig(
        n_epochs=25,
        es_rel_min_delta=0.1, es_warmup_epochs=5,
    ),
    features=FeatureConfig(),
    norm=NormalizeConfig(),
    seed=SEED
)

# Restrict the target and auxiliary tables to the selected series
tft_pipe = TFTPipeline(
    config=config, 
    target_df=heat_df[heat_df['unique_id'] == id], 
    aux_df=aux_df[aux_df['unique_id'] == id]
)

In [None]:
# Generate the pipeline's variables, then describe the dataset (return value discarded)
tft_pipe.generate_vars()
_ = tft_pipe.describe_dataset()

In [None]:
# Describe the model; the return value is discarded
_ = tft_pipe.describe_model()

In [None]:
# Boundaries handed to `fit`: training ends at `end_train`, validation at `end_val`.
# start_train=None presumably lets the pipeline pick the start of training — TODO confirm
start_train = None
end_train = pd.Timestamp("2023-10-31 23:00")  
end_val = pd.Timestamp("2024-04-01 23:00")
out = tft_pipe.fit(end_train=end_train, end_val=end_val, start_train=start_train)

In [None]:
# Predict from one hour after `end_train` up to `end_val` (n=168, stride_hours=168),
# then plot the predictions together with the target and auxiliary data of series `id`
preds_df = tft_pipe.predict_many(n=168, start=end_train+pd.Timedelta(hours=1), end=end_val, stride_hours=168)
plotly_cutoffs_with_exog(
    target_df=heat_df[heat_df['unique_id'] == id],
    cv_df=preds_df,
    aux_df=aux_df[aux_df['unique_id'] == id],
    start_offset=24*3,  # presumably hours shown before each cutoff — TODO confirm
    end_offset=24*3,    # presumably hours shown after each window — TODO confirm
)

## Optuna tuning

We tune the model using **Optuna**, a Python framework for hyperparameter optimization that supports both **grid search** and more efficient, adaptive methods. With Optuna we can easily run evaluations over a fixed hyperparameter grid, or use its **TPE (Tree-structured Parzen Estimator)** sampler to explore continuous, large and / or conditional search spaces in a data-driven way. The TPE sampler fits separate probabilistic models to "good" and "bad" parameter sets and selects new trials that maximize the ratio of the two (i.e. promising configurations).

For more details on TPE and Optuna, see:

* [Optuna main site](https://optuna.org/)
* [TPESampler documentation](https://optuna.readthedocs.io/en/stable/reference/samplers/generated/optuna.samplers.TPESampler.html) 
* [This article, for a deeper understanding of the TPESampler](https://arxiv.org/abs/2304.11127)

Overview of the following subsections:
* **"Start a new study"**: how to create and run a new optimization study from scratch.
* **"Review past studies"**: how to load a completed Optuna study in order to delete, copy, continue, or inspect its results in detail.

A detailed review and commentary of the tuning results is provided in the next section.

### Start a new study

As a first step to start a study, select the time series (`unique_id`) you want to optimize. Next, choose a **suggester function** (or define a new one in `heat_forecast/suggesters/lstm.py`) that specifies the hyperparameter search space. Finally, configure the Optuna study by specifying the sampler, pruner, study name, optimization objective, and other parameters that control the search process.

In [None]:
# --- Step 1: Select series ---
unique_id = 'F1'

# --- Step 2: Define search space ---
# Name of the suggester; it is described below and passed to `run_study` in a later cell
suggester_name = "tft_v4_F1"

# === Do not edit below ===
# Set path based on unique_id
db_path = OPTUNA_DIR / f"optuna_{unique_id}.db"  # single DB file for each id
storage_url = f"sqlite:///{db_path.as_posix()}"  # SQLite URL built from the DB path

# Describe suggester
desc = describe_suggester(suggester_name)
logging.info(f"Suggester used:\n{desc}")

In [None]:
# --- Step 3: Configure study ---
optuna_cfg = OptunaStudyConfig(
    # General
    study_name='study_v4_F1',
    objective="avg_near_best", # "avg_near_best" is used here; "best" and "last" are the other values mentioned (based on val metric) — TODO confirm full list
    n_trials=None,               # number of trials, use None for grid search
    timeout=None,              # timeout in seconds; NOTE(review): confirm whether it bounds the whole study or each trial
    seed=SEED,                 # seed for reproducibility of the optuna sampler
    storage=storage_url,       # storage URL for the study
    pruner="nop",              # type of pruner: "percentile", "median", "nop"
    sampler="grid",             # type of sampler: "tpe", "grid"
)

Run the study:

In [None]:
# --- Step 4: Run the study ---
do_run = True

# === Do not edit below ===
if do_run:
    # Set path based on unique_id (recomputed here in case `unique_id` changed
    # after the study configuration cell was run)
    db_path = OPTUNA_DIR / f"optuna_{unique_id}.db"  # single DB file for each id
    storage_url = f"sqlite:///{db_path.as_posix()}"

    # Set base configuration of the pipeline
    base_cfg = TFTRunConfig(
        model = TFTModelConfig(),
        data = DataConfig(),
        features = FeatureConfig(),
        train = TrainConfig(),
        norm = NormalizeConfig(),
        seed = SEED,
    )

    start_train = pd.Timestamp("2019-09-30 23:00")  
    start_val = None # -> pd.Timestamp("2023-11-01 00:00") 
    end_train = pd.Timestamp("2023-10-31 23:00")  
    end_val = pd.Timestamp("2024-04-01 23:00")

    # Keep the study config pointing at the storage computed above
    optuna_cfg.storage = storage_url

    study = run_study(
        unique_id,
        heat_df, 
        aux_df, 
        base_cfg,
        start_train=start_train, end_train=end_train, 
        start_val=start_val, end_val=end_val,
        optuna_cfg=optuna_cfg,
        suggest_config_name=suggester_name,
    )

    # `study.best_value` / `study.best_trial` raise ValueError when no trial has
    # completed (e.g. every trial failed or was pruned), so check first.
    n_complete = sum(
        t.state == optuna.trial.TrialState.COMPLETE for t in study.trials
    )
    if n_complete:
        print("Best value (val loss):", study.best_value)
        print("Best params:", study.best_trial.params)
    else:
        logging.warning(
            "Study '%s' has no completed trials; no best value to report.",
            study.study_name,
        )

### Review past studies

In this section, we load an existing study and choose whether to continue, delete, rename, copy, or inspect it.

#### Load and/or modify a study

See available studies for a fixed ID:

In [None]:
# Choose unique_id
unique_id = 'F1'

# === Do not edit below ===
# One SQLite database file per series id
db_path = OPTUNA_DIR / f"optuna_{unique_id}.db"
storage_url = f"sqlite:///{db_path.as_posix()}"

# Fetch the summaries of every study stored there, ordered by study name
study_summaries = optuna.study.get_all_study_summaries(storage=storage_url)
study_summaries = sorted(study_summaries, key=lambda summary: summary.study_name)

# One report line per study
lines = []
for summary in study_summaries:
    lines.append(f"Study name: {summary.study_name}, trials: {summary.n_trials}")
report = "Available studies:\n\n" + "\n".join(lines)
logging.info(report)

Choose a study to continue/delete/review by selecting its name below. Then view a description of the search space used for that study.

In [None]:
study_name = "study_v4_F1"

# === Do not edit below ===
# View detailed description of the search space for the study (the suggester documentation)
study = optuna.load_study(study_name=study_name, storage=storage_url)
desc = describe_suggester(study.user_attrs.get("suggest_config_name", ""))

# A non-empty "_parent_study" attribute marks this study as a sub-study of another one
parent_study = study.user_attrs.get("_parent_study", "")
txt = f"parent study ({parent_study}):" if parent_study else 'this study:'
logging.info(f"Suggester used in {txt} \n{desc}\n")
if parent_study:
    # Named `substudy_filter` (not `filter`) so the builtin is not shadowed for
    # the rest of the notebook session
    substudy_filter = yaml.dump(study.user_attrs.get('_substudy_filter', ''), indent=4, sort_keys=False)
    logging.info(f"This substudy was obtained through the filter: \n{substudy_filter}")

Optionally continue the study:

In [None]:
# NOTE(review): this flag is True, so executing the cell adds trials to the loaded study
do_continue = True

# Choose how many trials to add
n_new_trials = 15

# Choose combos to enqueue (or None to skip)
# e.g. [{"model.hidden_size": 50, "model.num_layers": 2}, {"model.hidden_size": 80, "model.num_layers": 3}]
combos_to_enqueue = None
# NOTE(review): the value is 0 although the comment states the default is 1; presumably
# irrelevant while combos_to_enqueue is None — confirm before enqueuing combos
trials_per_combo = 0  # how many times to repeat each combo (default is 1)

# === Do not edit below ===
if do_continue:
    # Continue the loaded study with additional trials
    study = continue_study(
        study_name,
        storage_url,
        n_new_trials=n_new_trials,
        target_df=heat_df,
        aux_df=aux_df,
        combos_to_enqueue=combos_to_enqueue,
        trials_per_combo=trials_per_combo,
    )

Optionally create a sub-study selecting trials based on a condition on a parameter:

In [None]:
do_create = False

# NOTE(review): `create_substudy_by_param` is not imported in any visible cell of this
# notebook (the imports bring in `clone_filtered_study` instead), so enabling
# `do_create` would raise NameError unless it is defined elsewhere — confirm which
# helper is intended.
if do_create:
    study = create_substudy_by_param(
        storage_url=storage_url,
        src_study_name=study_name,
        dst_study_name="substudy_lags_24_168_preliminary_v3_F1",
        param_name="features.lags_key",
        equals="7days_1day"
    )

Optionally delete the study:

In [None]:
do_delete = False

# Deletion is guarded twice: by the flag above and by the hard-coded study name below,
# which is deliberately not tied to the `study_name` variable
if do_delete:
    optuna.delete_study(
        study_name="study_v4_F1", # change to `study_name` if you are really sure you want to proceed 
        storage=storage_url,  
    )

Optionally rename the study:

In [None]:
do_rename = False

# Fill in `old_name` / `new_name` before enabling; both are empty placeholders here.
# dry_run=True presumably only reports what would change — TODO confirm in `rename_study`
if do_rename:
    study = rename_study(
        old_name="",
        new_name="",
        storage_url=storage_url,
        keep_old=False,
        dry_run=True
    )

#### View study results

In [None]:
from heat_forecast.utils.optuna import (
    trials_df, trials_df_for_display, summarize_params_coverage, 
    plot_intermediate_values, plot_optimization_history
)

View best trials in the study:

In [None]:
# Create DataFrame of all trials
df, val_name = trials_df(study)

# Filter if needed
#df = df[df['params_model.num_layers'] == 2]
#df = df[df['params_model.hidden_size'] == 64]

# === Do not edit below ===
# Count the trials found in each state
state_counts = {
    state: sum(df['state'] == state)
    for state in ('COMPLETE', 'PRUNED', 'FAIL', 'RUNNING', 'WAITING')
}
logging.info(
    f"Trials total={len(df)}, complete={state_counts['COMPLETE']}, pruned={state_counts['PRUNED']}, "
    f"fail={state_counts['FAIL']}, running={state_counts['RUNNING']}, waiting={state_counts['WAITING']}"
)

# Show up to 100 rows without pandas truncating rows or columns
with pd.option_context("display.max_columns", None, "display.max_rows", None):
    display_table = trials_df_for_display(df, val_name)
    display(display_table.head(100))

View optimization history:

In [None]:
# Plot the optimization history of the currently loaded `study`
plot_optimization_history(study)

View the intermediate validation losses for each trial. The function also supports filtering curves based on specific trial parameters or attributes.

In [None]:
# Plot per-trial intermediate values; uncomment `include_params` / `predicate` to filter.
# `dim_excluded` / `dim_factor` presumably dim the curves that are filtered out — TODO confirm
plot_intermediate_values(
    study, 
    # --- Apply custom filtering here if needed ---
    #include_params={'features.lags_key': 'none'}, 
    #predicate=lambda t: t.params.get("train.learning_rate") > 5e-4 and t.params.get("model.dropout") < 0.1,
    dim_excluded=True,
    dim_factor=0.1,
    semilogy=False
)

Analysis of coverage of the parameters space:

In [None]:
# === Do not edit below ===
# Two coverage summaries are returned: one for numeric, one for categorical parameters
num_sty, cat_sty = summarize_params_coverage(study, df, val_name)

# Display floats with thousands separators and 4 decimals inside this block only
with pd.option_context("display.float_format", lambda v: f"{v:,.4f}"):
    # A falsy summary is reported as "no parameters of that kind"
    if num_sty:
        logging.info("Coverage summary of numeric parameters:")
        display(num_sty)
    else:
        logging.info("No numeric parameters found for the study.")
    if cat_sty:
        logging.info("Coverage summary of categorical parameters:")
        display(cat_sty)
    else:
        logging.info("No categorical parameters found for the study.")

Counts of best epochs:

In [None]:
from matplotlib.ticker import MaxNLocator

# === Do not edit below ===
# Collect the "best_epoch" user attribute of every completed trial that has one
best_epochs = [
    t.user_attrs["best_epoch"]
    for t in study.trials
    if t.state == optuna.trial.TrialState.COMPLETE and "best_epoch" in t.user_attrs
]

# Coerce to integers, dropping values that are not numeric
s = pd.Series(pd.to_numeric(best_epochs, errors="coerce")).dropna().astype(int)

# Test for emptiness explicitly: `s.any()` would also be False when every best
# epoch is 0, wrongly reporting that no best_epoch was found.
if s.empty:
    logging.info("No best_epoch found.")
else:
    # count each integer and include missing integers with count=0
    lo, hi = int(s.min()), int(s.max())
    counts = s.value_counts().sort_index()
    counts = counts.reindex(range(lo, hi + 1), fill_value=0)

    # plot
    fig, ax = plt.subplots(figsize=(9, 3.5), constrained_layout=True)
    ax.bar(counts.index, counts.values, width=0.8)
    ax.set_xlabel("Best epoch")
    ax.set_ylabel("Count of trials")
    ax.set_title("Best epoch counts")
    ax.set_xticks(counts.index)
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))  # integer ticks only

    # log best
    logging.info(f"Best epoch by median: {s.median() :.0f}")

    plt.show()


#### View results marginalized on single hyperparameters

In [None]:
from heat_forecast.utils.optuna import (
    plot_marginals_1d, plot_param_importances, display_marginals_1d
)
import optuna.importance as imp

View fANOVA importances:

In [None]:
# === Do not edit below ===
# get importances
imps = imp.get_param_importances(study)
imps = pd.Series(imps, dtype=float).sort_values(ascending=False)  # most important first

# plot
# `imps.any()` is False when the series is empty or every importance is 0,
# so the plot is skipped in both cases
if imps.any():
    fig, ax = plot_param_importances(imps)

Compute and plot 1D marginal distributions. Many of the following functions also accept arguments such as `non_params_to_allow`, which lets you include selected trial user attributes in the marginal computations (treating them as parameters), and `objective`, which allows you to replace the objective value with any other numeric user attribute.

In [None]:
# Plot 1D marginals for the trials table, also allowing the `user_attrs_n_params`
# column; uncomment `objective` to use a different value column
plot_marginals_1d(
    df, val_name,
    bins_numeric=7,
    non_params_to_allow=["user_attrs_n_params"],
    #objective="user_attrs_avg_near_best"
)

Plot only a subset of marginals (e.g. most important):

In [None]:
# `imps` is sorted in descending order, so the first entry is the most important
# parameter; widen the slice to include more
most_imp_params = imps[:1].index.tolist()

plot_marginals_1d(
    df, val_name,
    params=most_imp_params,
    bins_numeric=12,
    #non_params_to_allow=["user_attrs_n_params"]
)

Detailed summaries for each parameter:

In [None]:
top_k = 10          # Will show the fraction of trials for each parameter choice that belongs to the top_k trials
top_frac = 0.20     # Will show the fraction of trials for each parameter choice that belongs to the top_frac trials
# NOTE(review): `params` is assigned but not passed to the call below (the `params=`
# argument is commented out)
params = None

tbls_sty = display_marginals_1d(
    df, val_name,
    #params=[],
    non_params_to_allow=["user_attrs_n_params"],
    #objective="user_attrs_avg_near_best",
    top_k=top_k,
    top_frac=top_frac,
    bins_numeric=7,
)

#### Study interactions between hyperparameters

In [None]:
from optuna.visualization import plot_parallel_coordinate
from heat_forecast.utils.optuna import marginal_2d

Show a parallel coordinate plot (mostly useful when using continuous params):

In [None]:
top_frac = 0.2   # fraction (0-1) of trials to highlight, taken from the lowest objective values
params = None    # choose params to plot
# params = imps[:6].index.tolist()  # alternative: pick only the most important

# === Do not edit below ===
# Threshold below which a trial falls in the lowest `top_frac` share of values
vals = df[val_name].dropna().to_list()
th = np.quantile(vals, top_frac)

# Build the figure once (it was previously created twice, discarding the first)
fig = plot_parallel_coordinate(study, params=params)

# Restrict the first dimension to [min, threshold] so only those trials are
# highlighted (presumably that dimension is the objective axis — TODO confirm),
# and span the color scale over all observed values
fig.data[0].dimensions[0].constraintrange = [min(vals), th]
fig.update_coloraxes(cmin=min(vals), cmax=max(vals))

# Render as inline HTML with the plotly.js bundle embedded
html = fig.to_html(include_plotlyjs="inline", full_html=False)
display(HTML(html))

Below we can visualize pairwise relationships between parameters or user-defined attributes by creating 2D marginal plots.

In [None]:
from heat_forecast.utils.optuna import plot_marginals_2d

# Returns the figure plus `pivots`, which the table cell below iterates over.
# Uncomment `objective`, `params` or `as_first` to customise the plot.
fig, pivots = plot_marginals_2d(
    df, val_name,
    #objective="user_attrs_avg_near_best",
    #params=["model.hidden_size", "model.num_layers"], #imps.index[:3].tolist(),
    #as_first="model.num_layers", #imps.index[0],
    statistic="median",
    binning="quantile",
    show_text=True,
    bins_a=7,
    bins_b=7,
    non_params_to_allow=["user_attrs_n_params"]
)
fig.show()


Display the tables plotted above, or choose a different statistic.

In [None]:
# Distinct `user_attrs_n_params` values for each (hidden_size, num_layers) pair.
# NOTE(review): the TFT configs in this notebook use `lstm_layers`; confirm that a
# `params_model.num_layers` column exists for the loaded study
df.groupby(['params_model.hidden_size', 'params_model.num_layers'])['user_attrs_n_params'].unique()

In [None]:
statistic = "std"
n_max = 20 # max number of tables to display

# === Do not edit below ===
# Walk the pivot collections in order, stopping once `n_max` of them were shown
for shown, (key, pivs) in enumerate(pivots.items()):
    if shown >= n_max:
        break
    logging.info(f"2D marginal for {key}, statistic = '{statistic}':")
    piv = pivs.get(statistic)
    if piv is None:
        # No table stored under the requested statistic
        display("(Not found)")
    else:
        display(piv)

## Test

Final configurations

In [None]:
# --- Define final configurations for each series and horizon ---
# Shared baseline: `final_cfgs` deep-copies it and overrides a few fields
# per series / horizon type
base_final_cfgs = TFTRunConfig(
    model = TFTModelConfig(
        input_chunk_length=168+72,
        output_chunk_length=24,  # overridden per horizon type in `final_cfgs`
        dropout=0.15,
        hidden_size=40,
        num_attention_heads=1,
        lstm_layers=1,
    ),
    data = DataConfig(
        stride=1,
        batch_size=64
    ),
    features = FeatureConfig(),
    train = TrainConfig(
        lr = 1e-3,
        n_epochs = 7,
        gradient_clip_val=10.0,
        loss_fn_str="L1"
    ),
    norm = NormalizeConfig(),
    seed = SEED,
)

def final_cfgs(unique_id: str, horizon_type: str) -> TFTRunConfig:
    """Return the final run configuration for one series and horizon type.

    The result is an independent deep copy of `base_final_cfgs` with the
    series-specific overrides applied.

    Args:
        unique_id: One of 'F1', 'F2', 'F3', 'F4', 'F5'.
        horizon_type: Either 'day' or 'week'.

    Raises:
        ValueError: If `unique_id` or `horizon_type` is not an accepted value.
        NotImplementedError: If no final configuration exists yet for `unique_id`
            (currently every series other than 'F1').
    """
    known_ids = ('F1', 'F2', 'F3', 'F4', 'F5')
    known_horizons = ('day', 'week')
    if not (unique_id in known_ids and horizon_type in known_horizons):
        raise ValueError("Invalid unique_id or horizon_type.")

    cfg = copy.deepcopy(base_final_cfgs)
    if unique_id != 'F1':
        raise NotImplementedError("Final configuration not defined for this unique_id yet.")

    # Overrides for F1, chosen by horizon type
    weekly = horizon_type == 'week'
    if weekly:
        cfg.model.output_chunk_length = 168
        cfg.model.dropout = 0.15
        cfg.train.lr = 3e-3
        cfg.train.n_epochs = 4
    else:
        cfg.model.output_chunk_length = 24
        cfg.model.dropout = 0.05
        cfg.train.lr = 2.5e-3
        cfg.train.n_epochs = 6
    return cfg

Code for final testing with the tuned models:

In [None]:
do_test = True
grid = list(product(['F1'], ['day']))  # (series id, horizon type) pairs to test

# Set a run_id here for any grid element if you want to resume that run from the last existing checkpoint
run_ids_dict: Dict[Tuple[str, str], str] = {
    #('F1', 'day'): 
}

# === Do not edit below ===
# For every grid element: create (or retrieve) the results directory, run
# cross-validation with the final configuration, then save the predictions and
# a metadata file. Errors are logged and re-raised; nothing is cleaned up so
# that a run can be resumed from its checkpoint.
if do_test:
    # `series_id` rather than `id`, so the builtin is not shadowed
    for series_id, horizon_type in tqdm(grid, desc="Test", leave=True):

        metadata = {}

        # --- Create directory for test results ---
        timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
        provided_run_id = run_ids_dict.get((series_id, horizon_type))
        run_id_was_provided = bool(provided_run_id)
        run_id = provided_run_id or f"{series_id}_{horizon_type}_test_tft_{timestamp}"
        path = BASE_DIR / "results" / "test" / "tft" / run_id
        checkpoint_path = path / "checkpoint"
        metadata['run_id'] = run_id

        try:
            if run_id_was_provided:
                # Resuming: the run directory must already exist
                if not path.exists():
                    raise FileNotFoundError(
                        f"Expected existing directory for provided run_id '{run_id}', but none was found at: {path}"
                    )
                logging.info(f"Retrieving directory for test results: {path.relative_to(BASE_DIR)}")
            else:
                # New run: fail rather than reuse an existing directory
                path.mkdir(parents=True, exist_ok=False)
                logging.info(f"Created directory for test results: {path.relative_to(BASE_DIR)}")

            # --- Set params for cv ---
            out = get_cv_params_for_test(horizon_type)
            metadata['for_cv'] = {
                'step_size': out['step_size'],
                'test_hours': out['test_hours'],
                'end_test_cv': str(out['end_test_actual']),
                'n_windows': out['n_windows'],
                'refit': out['refit'],
                'n_fits': out['n_fits'],
            }

            # ------------- Run cross-validation with the optimal parameters -------------
            # create pipeline and generate futures
            heat_id_df = heat_df[heat_df['unique_id'] == series_id]
            aux_id_df = aux_df[aux_df['unique_id'] == series_id]
            config = final_cfgs(series_id, horizon_type)
            pipe = TFTPipeline(
                target_df=heat_id_df, 
                config=config, 
                aux_df=aux_id_df, 
            )
            metadata['model_config'] = config.to_dict()
            metadata['device'] = DEVICE.type

            # Run cv
            t0 = pd.Timestamp.now()
            cv_df = pipe.cross_validation(
                test_size=out['test_hours'],  # Test size in hours
                end_test=out['end_test_actual'],  # End of the test period
                step_size=out['step_size'],   # Step size in hours
                refit=out['refit'],  # Refit setting as returned by get_cv_params_for_test
                verbose=True,
                checkpoint_path=checkpoint_path
            )
            t1 = pd.Timestamp.now()

            # Wall-clock time of the whole cross-validation divided by the number of fits
            avg_elapsed = (t1 - t0).total_seconds() / out['n_fits']
            metadata['avg_el_per_fit'] = avg_elapsed

            cv_df.to_parquet(path / "cv_df.parquet", compression="snappy")

            # Explicit encoding so the file does not depend on the platform default
            metadata_path = path / 'metadata.yaml'
            with open(metadata_path, 'w', encoding="utf-8") as f:
                safe_dump_yaml(
                    metadata,
                    f,
                    indent=4, 
                )

            logging.info(f"✓ Artifacts saved successfully for id={series_id}, horizon={horizon_type}.")

        except BaseException as e:
            # BaseException so that KeyboardInterrupt is seen too; always re-raised
            if is_keyboard_interrupt_like(e):
                logging.warning("✗ Detected KeyboardInterrupt. Not cleaning to allow later resumption.")
            else:
                logging.exception("✗ Error during test for id=%s, horizon=%s. Not cleaning to allow later resumption.", series_id, horizon_type)
            raise

    logging.info("✓ Test completed.")

In [None]:
run_id = "F1_day_cpu_times_tft_20251119T152713"
n_warmup = 2
n = 8

# === Do not edit below ===
import pickle
import numpy as np

results_dir = BASE_DIR / "results" / "times" / "tft" / run_id
times_path = results_dir / "times.pkl"

# NOTE: unpickling can execute arbitrary code; only load files written by trusted runs
with open(times_path, "rb") as f:
    data = pickle.load(f)

train_raw = data.get("train", [])
training_times = np.array(train_raw[:])
inference_times = np.array(data.get("inference", [])) * 1000  # Convert to milliseconds

# Both series need `n` measured runs on top of the `n_warmup` skipped ones
required = n + n_warmup
if min(len(training_times), len(inference_times)) < required:
    raise ValueError(f"Not enough runs recorded for stats: "
                     f"training_times={len(training_times)}, inference_times={len(inference_times)}, required={n+n_warmup}")

# Statistics use the `n` runs that follow the warm-up runs
measured = slice(n_warmup, n_warmup + n)
train_sample = training_times[measured]
infer_sample = inference_times[measured]

logging.info(f"Using n_warmup={n_warmup} to skip initial runs for stats.")
logging.info(f"Training times (s): \n"
             f"n={n}, \tmean={train_sample.mean():.3f}, \tstd={train_sample.std():.3f}")
logging.info(f"Inference times (ms): \n"
             f"n={n}, \tmean={infer_sample.mean():.3f}, \tstd={infer_sample.std():.3f}")