# TFT model

## Preliminaries

You can run the notebook in two ways:

1. **Google Colab**: place the project folder `heat-forecast` in **MyDrive**. The setup cell below will mount Drive and automatically add `MyDrive/heat-forecast/src` to `sys.path` so `import heat_forecast` works out of the box.

2. **Local machine**:

   * **Installing our package:** from the project root, run `pip install -e .` once (editable install). Then you can open the notebook anywhere and import the package normally.
   * **Alternative:** if you’re running the notebook from `.../heat-forecast/notebooks/` without installing the package, the setup cell will detect `../src` and automatically add it to `sys.path`.
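
Before running the setup cell, it can help to check whether the package is already importable in the current environment. The snippet below is a minimal sketch using only the standard library; the helper name `is_importable` is hypothetical and not part of `heat_forecast`.

```python
import importlib.util

def is_importable(package_name: str) -> bool:
    """Return True if a top-level package can be found on the current sys.path."""
    return importlib.util.find_spec(package_name) is not None

# The result depends on your environment (editable install or src/ on sys.path).
print("heat_forecast importable:", is_importable("heat_forecast"))
```

If this prints `False` on a local machine, either run `pip install -e .` from the project root or launch the notebook from the `notebooks/` folder so the setup cell can find `../src`.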

In [1]:
# --- Detect if running on Google Colab & Set base dir ---
# %cd /home/giovanni.lombardi/heat-forecast/notebooks
import subprocess
from pathlib import Path
import sys

def in_colab() -> bool:
    try:
        import google.colab  # type: ignore
        return True
    except Exception:
        return False

# Install a package quietly with pip (pip itself skips requirements that are already satisfied)
def pip_install(pkg: str):
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])

# Set base directory and handle environment
added_src = False  # defined up front so the logging block below also works on Colab
if in_colab():
    # Make sure IPython is modern (avoids the old %autoreload/imp issue if you ever use it)
    pip_install("ipython>=8.25")
    pip_install("ipykernel>=6.29")

    for pkg in ["statsmodels", "statsforecast", "mlforecast"]:
        pip_install(pkg)

    # Mount Google Drive
    from google.colab import drive  # type: ignore
    drive.mount('/content/drive')

    # Set base directory to your Drive project folder
    BASE_DIR = Path('/content/drive/MyDrive/heat-forecast')

    # Add `src/` to sys.path for custom package imports
    SRC_PATH = BASE_DIR / 'src'
    if str(SRC_PATH) not in sys.path:
        sys.path.append(str(SRC_PATH))
        added_src = True

    # Sanity checks (helpful error messages if path is wrong)
    assert SRC_PATH.exists(), f"Expected '{SRC_PATH}' to exist. Fix BASE_DIR."
    pkg_dir = SRC_PATH / "heat_forecast"
    assert pkg_dir.exists(), f"Expected '{pkg_dir}' package directory."
    init_file = pkg_dir / "__init__.py"
    assert init_file.exists(), f"Missing '{init_file}'. Add it so Python treats this as a package."

else:
    # Local: either rely on editable install (pip install -e .) or add src/ when running from repo
    # Assume notebook lives in PROJECT_ROOT/notebooks/
    BASE_DIR = Path.cwd().resolve().parent
    SRC_PATH = BASE_DIR / "src"

    added_src = False
    if (SRC_PATH / "heat_forecast").exists() and str(SRC_PATH) not in sys.path:
        sys.path.append(str(SRC_PATH))
        added_src = True

# --- Logging setup ---
import logging
from zoneinfo import ZoneInfo
from datetime import datetime

LOG_DIR  = (BASE_DIR / "logs")
LOG_DIR.mkdir(parents=True, exist_ok=True)
LOG_FILE = LOG_DIR / "run.log"
PREV_LOG = LOG_DIR / "run.prev.log"

# If there's a previous run.log with content, archive it to run.prev.log
if LOG_FILE.exists() and LOG_FILE.stat().st_size > 0:
    try:
        # Replace old run.prev.log if present
        if PREV_LOG.exists():
            PREV_LOG.unlink()
        LOG_FILE.rename(PREV_LOG)
    except Exception as e:
        # Fall back to truncating if rename fails (e.g., file locked)
        print(f"[warn] Could not archive previous log: {e}. Truncating current run.log.")
        LOG_FILE.write_text("")

# Configure logging: fresh file for this run + echo to notebook/stdout
file_handler   = logging.FileHandler(LOG_FILE, mode="w", encoding="utf-8")
stream_handler = logging.StreamHandler(sys.stdout)

fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s",
                        datefmt="%m-%d %H:%M:%S")
file_handler.setFormatter(fmt)
stream_handler.setFormatter(fmt)

root = logging.getLogger()
root.handlers[:] = [file_handler, stream_handler]  # replace handlers (important in notebooks)
root.setLevel(logging.INFO)

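# Use Rome time, converting each record's own timestamp (args[-1]) instead of the time of formatting
logging.Formatter.converter = lambda *args: datetime.fromtimestamp(args[-1], ZoneInfo("Europe/Rome")).timetuple()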

logging.captureWarnings(True)
logging.info("=== Logging started (fresh current run) ===")
logging.info("Previous run (if any): %s", PREV_LOG if PREV_LOG.exists() else "none")

if added_src:
    logging.info("heat_forecast not installed; added src/ to sys.path")
else:
    logging.info("heat_forecast imported without modifying sys.path (likely installed)")

OPTUNA_DIR = BASE_DIR / "results" / "finetuning" / "lstm"
OPTUNA_DIR.mkdir(parents=True, exist_ok=True)
logging.info("BASE_DIR (make sure it's '*/heat-forecast/', else cd and re-run): %s", BASE_DIR)
logging.info("LOG_DIR: %s", LOG_DIR)
logging.info("OPTUNA_DIR: %s", OPTUNA_DIR)

10-21 16:26:10 | INFO | root | === Logging started (fresh current run) ===
10-21 16:26:10 | INFO | root | Previous run (if any): C:\Users\giolo\OneDrive\DocumentiOneDrive\Phython scripts\Tirocinio Optit\heat-forecast\logs\run.prev.log
10-21 16:26:10 | INFO | root | heat_forecast imported without modifying sys.path (likely installed)
10-21 16:26:10 | INFO | root | BASE_DIR (make sure it's '*/heat-forecast/', else cd and re-run): C:\Users\giolo\OneDrive\DocumentiOneDrive\Phython scripts\Tirocinio Optit\heat-forecast
10-21 16:26:10 | INFO | root | LOG_DIR: C:\Users\giolo\OneDrive\DocumentiOneDrive\Phython scripts\Tirocinio Optit\heat-forecast\logs
10-21 16:26:10 | INFO | root | OPTUNA_DIR: C:\Users\giolo\OneDrive\DocumentiOneDrive\Phython scripts\Tirocinio Optit\heat-forecast\results\finetuning\lstm


Check that the installed Python and NumPy versions are [compatible with Numba](https://numba.readthedocs.io/en/stable/user/installing.html#numba-support-info).

In [2]:
import sys, numpy, numba
logging.info("=== Current Environment ===")
logging.info("Python : %s", sys.version.split()[0])
logging.info("NumPy  : %s", numpy.__version__)
logging.info("Numba  : %s", numba.__version__)

10-21 16:26:11 | INFO | root | === Current Environment ===
10-21 16:26:11 | INFO | root | Python : 3.12.9
10-21 16:26:11 | INFO | root | NumPy  : 2.2.6
10-21 16:26:11 | INFO | root | Numba  : 0.61.2
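
The versions logged above still have to be compared by hand against the Numba support table. A small helper can automate that comparison once you have read the bounds off the table. This is a sketch under stated assumptions: the function names are hypothetical, and the bounds in the example call are placeholders, not Numba's actual limits.

```python
def parse_version(version: str) -> tuple[int, ...]:
    """Turn '2.2.6' into (2, 2, 6), dropping non-numeric suffixes such as 'rc1'."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)

def in_range(version: str, minimum: str, maximum: str) -> bool:
    """True if minimum <= version <= maximum, comparing only major.minor."""
    v = parse_version(version)[:2]
    return parse_version(minimum)[:2] <= v <= parse_version(maximum)[:2]

# Placeholder bounds: replace them with the values from the Numba support table.
print(in_range("2.2.6", minimum="1.24", maximum="2.2"))
```

Comparing only `major.minor` keeps patch releases (such as `2.2.6`) inside a range whose upper bound is given as `2.2`.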


Imports:

In [3]:
# --- Magic Commands ---
%load_ext autoreload
%autoreload 2

# --- Standard Library ---
import os
os.environ["OPTUNA_LOGGING_DISABLE_DEFAULT_HANDLER"] = "1" # prevent Optuna from attaching its handler
import stat
import logging
import copy
from datetime import datetime
from itertools import product

# --- Third-Party Libraries ---
import torch
import optuna
import numpy as np
import pandas as pd
pd.set_option('display.float_format', '{:.3f}'.format)

import yaml
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from tqdm.notebook import tqdm
from IPython.display import display, HTML

# --- Plotting Configuration ---
plt.style.use("seaborn-v0_8")
plt.rcParams['font.size'] = 14
plt.rcParams['axes.titlesize'] = 16
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 10
plt.rcParams['ytick.labelsize'] = 10
plt.rcParams['legend.fontsize'] = 10
plt.rcParams['figure.titlesize'] = 18
mpl.rcParams['axes.grid'] = True
mpl.rcParams['axes.grid.which'] = 'both'

# --- YAML Customization ---
from heat_forecast.utils.yaml import safe_dump_yaml

# --- Safe File Deletion Helper ---
from heat_forecast.utils.fileshandling import remove_tree

# --- Project-Specific Imports ---
from heat_forecast.utils.cv_utils import get_cv_params_for_test
from heat_forecast.pipeline.tft import (
    TFTModelConfig, DataConfig, FeatureConfig, NormalizeConfig, TFTRunConfig, TFTPipeline
)

from heat_forecast.utils.optuna import (
    OptunaStudyConfig, run_study, continue_study, describe_suggester, rename_study, clone_filtered_study
)

logging.info("All imports successful.")

10-21 16:26:16 | INFO | numexpr.utils | NumExpr defaulting to 12 threads.
10-21 16:26:23 | INFO | root | All imports successful.


Load the preprocessed data.

In [4]:
heat_path = BASE_DIR / 'data' / 'timeseries_preprocessed' / 'heat.csv'
aux_path = BASE_DIR / 'data' / 'timeseries_preprocessed' / 'auxiliary.csv'
heat_df = pd.read_csv(heat_path, parse_dates=['ds'])
aux_df = pd.read_csv(aux_path, parse_dates=['ds'])
logging.info("Loaded heat data: %s", heat_path.relative_to(BASE_DIR))
logging.info("Loaded auxiliary data: %s", aux_path.relative_to(BASE_DIR))

10-21 16:26:24 | INFO | root | Loaded heat data: data\timeseries_preprocessed\heat.csv
10-21 16:26:24 | INFO | root | Loaded auxiliary data: data\timeseries_preprocessed\auxiliary.csv


Set the device (the TFT pipeline automatically uses "cuda" if available):

In [5]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info("Using device: %s", DEVICE)

10-21 16:26:24 | INFO | root | Using device: cpu


Set a random seed (use `None` to skip seeding):

In [6]:
SEED = 42
logging.info("Using seed: %s", SEED)

10-17 17:46:12 | INFO | root | Using seed: 42
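
How the pipeline applies the seed internally is not shown in this notebook. As a rough illustration of what seeding usually involves, the sketch below seeds Python's `random` module and, if they are installed, NumPy and PyTorch. The helper name `seed_everything` is hypothetical and is not the pipeline's own function.

```python
import random

def seed_everything(seed):
    """Seed the common random number generators; do nothing when seed is None."""
    if seed is None:
        return
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed: nothing to seed
    try:
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass  # PyTorch not installed: nothing to seed

seed_everything(42)
first = random.random()
seed_everything(42)
second = random.random()
print(first == second)  # True: reseeding reproduces the same draw
```

Passing `None` leaves every generator untouched, which mirrors the "skip" behaviour described above.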


## Pipeline

In [7]:
series_id = 'F1'  # named `series_id` to avoid shadowing the built-in `id`
config = TFTRunConfig(
    model=TFTModelConfig(),
    data=DataConfig(),
    features=FeatureConfig(),
    norm=NormalizeConfig(),
    seed=SEED
)
tft_pipe = TFTPipeline(
    config=config,
    target_df=heat_df[heat_df['unique_id'] == series_id],
    aux_df=aux_df[aux_df['unique_id'] == series_id]
)

10-17 17:46:16 | INFO | heat_forecast.pipeline.tft | [pipe init] setting random seed: 42.


In [8]:
tft_pipe.generate_vars()

10-17 17:46:18 | INFO | heat_forecast.pipeline.tft | [gvars] features ready: past_covs=0 (endog=0, climate=0) | future_covs=9 (endog=0, climate=1, time=8)


In [12]:
tft_pipe.fit(end_train=pd.Timestamp('2023-11-01'))

TFTModel(hidden_size=64, input_chunk_length=168, output_chunk_length=24, batch_size=64, n_epochs=20, pl_trainer_kwargs={'accelerator': 'auto'})
10-17 17:47:22 | INFO | darts.models.forecasting.torch_forecasting_model | Train dataset contains 37826 samples.
10-17 17:47:22 | INFO | darts.models.forecasting.torch_forecasting_model | Time series values are 64-bits; casting model to float64.
10-17 17:47:22 | INFO | pytorch_lightning.utilities.rank_zero | GPU available: False, used: False
10-17 17:47:22 | INFO | pytorch_lightning.utilities.rank_zero | TPU available: False, using: 0 TPU cores
10-17 17:47:22 | INFO | pytorch_lightning.utilities.rank_zero | HPU available: False, using: 0 HPUs
10-17 17:47:22 | INFO | pytorch_lightning.callbacks.model_summary | 
   | Name                              | Type                             | Params | Mode 
------------------------------------------------------------------------------------------------
0  | train_metrics                     | MetricCol

Training: |          | 0/? [00:00<?, ?it/s]

10-17 18:01:15 | INFO | pytorch_lightning.utilities.rank_zero | 
Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined
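
The training log above reports 37826 samples with `input_chunk_length=168` and `output_chunk_length=24`. Assuming a single series and one training sample per sliding window with a stride of one step (an assumption about how the dataset is built, not something the log states), the sample count and the series length are related by simple arithmetic. The function names below are hypothetical.

```python
def n_training_windows(series_length: int, input_chunk_length: int, output_chunk_length: int) -> int:
    """Number of stride-1 windows of length input + output that fit in the series."""
    window = input_chunk_length + output_chunk_length
    return max(series_length - window + 1, 0)

def series_length_from_windows(n_windows: int, input_chunk_length: int, output_chunk_length: int) -> int:
    """Invert the relation above to recover the implied series length."""
    return n_windows + input_chunk_length + output_chunk_length - 1

length = series_length_from_windows(37826, 168, 24)
print(length)                                # 38017
print(n_training_windows(length, 168, 24))   # 37826
```

Under these assumptions the training series holds 38017 time steps; each sample pairs 168 input steps with the 24 steps that follow them.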