In [1]:
import os


"""Configuration settings for TFT model."""
WORKING_DIR = "/kaggle/input/demandforecasting/demandForecasting/"
# WORKING_DIR = ""
TFT_CHECKPOINTS_DIR = os.path.join("TFT", "checkpoints")
RAW = f"{WORKING_DIR}/data_raw/"
TFT_DATA_DIR = f"{WORKING_DIR}TFT/data"

# encoder features
ENC_VARS = [
    "sales",
    "transactions",
    "dcoilwtico",
    "onpromotion",
    "dow",
    "month",
    "weekofyear",
    "is_holiday",
    "is_workday"
]
# known future features
DEC_VARS = [
    "onpromotion",
    "dow",
    "month",
    "weekofyear",
    "is_holiday",
    "is_workday"
]
# static features
STATIC_COLS = [
    "store_nbr",
    "family",
    "state",
    "cluster"
    ]

REALS_TO_SCALE = [
    "transactions",
    "dcoilwtico"
    ]


In [2]:
!nvidia-smi

Fri Jan  2 16:29:51 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P0             25W /  250W |       0MiB /  16384MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                     

In [3]:
import torch
print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
import tensorflow as tf
print(f"TensorFlow GPU list: {tf.config.list_physical_devices('GPU')}")


PyTorch CUDA available: True


2026-01-02 16:29:57.579088: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1767371397.769660      25 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1767371397.824926      25 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1767371398.279547      25 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767371398.279608      25 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767371398.279611      25 computation_placer.cc:177] computation placer alr

TensorFlow GPU list: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [4]:
import sys
sys.path.insert(1, '/kaggle/input/demandforecasting/demandForecasting')

In [5]:
torch.device("cuda" if torch.cuda.is_available() else "cpu")

device(type='cuda')

In [6]:
import argparse
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
from TFT.architecture.tft import TemporalFusionTransformer, QuantileLoss
from TFT.tft_dataset import TFTWindowDataset, tft_collate
from utils.utils import set_seed, build_onehot_maps
from utils.utils import compute_metrics
from utils.utils import get_date_splits


def save_results_csv(rows):
    if rows:
        test_forecasts_df = (
            pd.DataFrame(rows)
            .sort_values(["family", "store_nbr", "date"])
        )
        out_csv = os.path.join("tft_test_forecasts.csv")
        test_forecasts_df.to_csv(out_csv, index=False)
        print(f"Saved test forecasts CSV -> {out_csv}")


def get_data_split(dec_len, enc_len, batch_size, stride):
    # Load data
    panel_path = os.path.join(TFT_DATA_DIR, "panel.csv")
    assert os.path.exists(panel_path), (
        "Run data preprocessing first: "
        "python src/data/preprocess_favorita.py"
    )
    df = pd.read_csv(panel_path, parse_dates=["date"])

    # Scale continuous features (fit on train period only)
    train_end, val_end, test_end = get_date_splits(df, dec_len)

    scaler = StandardScaler()
    train_mask = df["date"] <= train_end
    df.loc[train_mask, REALS_TO_SCALE] = scaler.fit_transform(
        df.loc[train_mask, REALS_TO_SCALE]
    )
    df.loc[~train_mask, REALS_TO_SCALE] = scaler.transform(
        df.loc[~train_mask, REALS_TO_SCALE]
    )

    # One-hot maps for static features
    static_maps = build_onehot_maps(df, STATIC_COLS)
    static_dims = [len(static_maps[c]) for c in STATIC_COLS]

    # Dataset and loaders
    split_bounds = (train_end, val_end, test_end)
    train_ds = TFTWindowDataset(
        df, enc_len, dec_len, ENC_VARS, DEC_VARS, STATIC_COLS,
        split_bounds, split="train", stride=stride,
        static_onehot_maps=static_maps,
    )
    val_ds = TFTWindowDataset(
        df, enc_len, dec_len, ENC_VARS, DEC_VARS, STATIC_COLS,
        split_bounds, split="val", stride=stride,
        static_onehot_maps=static_maps,
    )
    test_ds = TFTWindowDataset(
        df, enc_len, dec_len, ENC_VARS, DEC_VARS, STATIC_COLS,
        split_bounds, split="test", stride=stride,
        static_onehot_maps=static_maps,
    )

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, collate_fn=tft_collate,
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True, collate_fn=tft_collate,
    )
    test_loader = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True, collate_fn=tft_collate,
    )
    print(
        f"Train samples: {len(train_ds)} | "
        f"Val: {len(val_ds)} | Test: {len(test_ds)}"
    )
    return (train_loader, val_loader, test_loader,
            static_dims,
            len(train_ds), len(val_ds), len(test_ds)
            )


def train_model(model, quantiles, args, train_loader, val_loader,
                train_len, val_len
                ):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    criterion = QuantileLoss(quantiles=quantiles)
    model = model.to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=1e-5
    )
    best_val = float("inf")
    best_path = os.path.join("tft_best.pt")
    median_idx = int(np.argmin([abs(q - 0.5) for q in quantiles]))
    # Early stopping state
    patience = int(getattr(args, "early_stopping_patience", 7))
    min_delta = float(getattr(args, "early_stopping_min_delta", 0.0))
    no_improve_epochs = 0

    # Training loop
    for epoch in range(1, args.epochs + 1):
        model.train()
        train_loss = 0.0
        train_ys, train_preds = [], []
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs} [train]")
        for batch in pbar:
            optimizer.zero_grad()

            past = batch["past_inputs"].to(device)     # [B, L_enc, E]
            future = batch["future_inputs"].to(device)  # [B, L_dec, D]
            static = batch["static_inputs"].to(device)  # [B, S]
            y = batch["target"].to(device)             # [B, L_dec]

            out = model(past, future, static)
            loss = criterion(out["prediction"].to(device), y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item() * past.size(0)
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
            yhat = out["prediction"][..., median_idx]
            train_ys.append(y.detach().cpu().numpy())
            train_preds.append(yhat.detach().cpu().numpy())

        train_loss /= max(train_len, 1)

        metrics_train = compute_metrics(train_ys, train_preds)
        
        print(f"Epoch {epoch} Train Loss: {train_loss:.6f}, \
              Train Metrics: {metrics_train}"
              )

        # Validation
        model.eval()
        val_loss = 0.0
        valid_ys, valid_preds = [], []
        with torch.no_grad():
            for batch in val_loader:
                past = batch["past_inputs"].to(device)
                future = batch["future_inputs"].to(device)
                static = batch["static_inputs"].to(device)
                y = batch["target"].to(device)
                out = model(past, future, static)
                loss = criterion(out["prediction"].to(device), y)
                val_loss += loss.item() * past.size(0)

                yhat = out["prediction"][..., median_idx]
                valid_ys.append(y.detach().cpu().numpy())
                valid_preds.append(yhat.detach().cpu().numpy())

        val_loss /= max(val_len, 1)
        
        metrics_val = compute_metrics(valid_ys, valid_preds)
        print(f"Epoch {epoch} Validation Loss: {val_loss:.6f}, \
              Validation Metrics: {metrics_val}"
              )
        

        improved = (best_val - val_loss) > min_delta
        if improved:
            best_val = val_loss
            no_improve_epochs = 0
            torch.save(
                {
                    "model_state": model.state_dict(),
                    "cfg": vars(args),
                    "quantiles": quantiles,
                },
                best_path,
            )
            print(
                f"Saved TFT model to {best_path} at {epoch} \
                    epochs (val_loss={val_loss:.6f})"
            )
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print(
                    f"Early stopping at epoch {epoch} (patience={patience},\
                    min_delta={min_delta})"
                )
                break
    
    return model


# Evaluation on test set
def eval_loader(model, data_loader, quantiles, test_len):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    median_idx = int(np.argmin([abs(q - 0.5) for q in quantiles]))
    # Load saved TFT model and evaluate on test set
    best_path = os.path.join("tft_best.pt")
    if os.path.exists(best_path):
        ckpt = torch.load(best_path, map_location=device)
        model.load_state_dict(ckpt["model_state"])
        print(f"Loaded stored TFT model for evaluation {best_path}")
    model.eval()

    criterion = QuantileLoss(quantiles=quantiles)
    rows = []
    total_loss = 0.0
    test_ys, test_preds = [], []
    with torch.no_grad():
        for batch in data_loader:
            past = batch["past_inputs"].to(device)
            future = batch["future_inputs"].to(device)
            static = batch["static_inputs"].to(device)
            y = batch["target"].to(device)

            out = model(past, future, static)
            preds_med = out["prediction"][..., median_idx]  # [B, L_dec]
            preds = preds_med.cpu().numpy()
            loss = criterion(out["prediction"].to(device), y)
            total_loss += loss.item() * past.size(0)
            yhat = out["prediction"][..., median_idx]
            test_ys.append(y.detach().cpu().numpy())
            test_preds.append(yhat.detach().cpu().numpy())

            metas = batch.get("meta", [])
            for i, meta in enumerate(metas):
                store_nbr = meta["store_nbr"]
                family = meta["family"]
                fut_dates = meta["future_dates"]
                targets = batch["target"].cpu().numpy()   # [B, L_dec]
                for d_idx, date in enumerate(fut_dates):
                    rows.append({
                        "date": pd.to_datetime(date),
                        "store_nbr": store_nbr,
                        "family": family,
                        "y_true": float(targets[i, d_idx]),
                        "y_pred": float(preds[i, d_idx]),
                    })
                # Append encoder history (past sales) before
                #   forecast horizon
                # Use the 'sales' feature from encoder inputs
                sales_idx = ENC_VARS.index("sales")
                past_dates = meta["past_dates"]
                for d_idx, date in enumerate(past_dates):
                    rows.append(
                        {
                            "date": pd.to_datetime(date),
                            "store_nbr": store_nbr,
                            "family": family,
                            "y_past": float(
                                past[i, d_idx, sales_idx].cpu()
                                ),
                        }
                    )
    save_results_csv(rows)
    total_loss /= max(test_len, 1)
    test_metrics = compute_metrics(test_ys, test_preds)
    print(f"Test Loss: {total_loss:.6f}. Test Metrics: {test_metrics}")
    return test_metrics


def main_train():
    parser = argparse.ArgumentParser()
    parser.add_argument("--enc-len", type=int, default=56)
    parser.add_argument("--dec-len", type=int, default=28)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--hidden-dim", type=int, default=64)
    parser.add_argument("--d-model", type=int, default=32)
    parser.add_argument("--heads", type=int, default=2)
    parser.add_argument("--lstm-hidden", type=int, default=32)
    parser.add_argument("--lstm-layers", type=int, default=1)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--quantiles", type=str, default="0.1,0.5,0.9")
    parser.add_argument("--stride", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--early-stopping-patience", type=int, default=5,
                        help="Stop if no val loss improvement for N epochs")
    parser.add_argument("--early-stopping-min-delta", type=float, default=0.5,
                        help="Minimum val loss improvement to reset patience")
    
    parser.add_argument("--train-flag", type=bool, default=True)
    # args = parser.parse_args()
    args, unknown = parser.parse_known_args()

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(TFT_CHECKPOINTS_DIR, exist_ok=True)

    (
        train_loader,
        val_loader,
        test_loader,
        static_dims,
        train_len,
        val_len,
        test_len,
    ) = get_data_split(
        args.dec_len,
        args.enc_len,
        args.batch_size,
        args.stride
        )
    
    # Model
    past_input_dims = [1] * len(ENC_VARS)
    future_input_dims = [1] * len(DEC_VARS)
    static_input_dims = static_dims
    quantiles = [float(x) for x in args.quantiles.split(",")]

    model = TemporalFusionTransformer(
        static_input_dims=static_input_dims,
        past_input_dims=past_input_dims,
        future_input_dims=future_input_dims,
        d_model=args.d_model,
        hidden_dim=args.hidden_dim,
        n_heads=args.heads,
        lstm_hidden_size=args.lstm_hidden,
        lstm_layers=args.lstm_layers,
        dropout=args.dropout,
        num_quantiles=len(quantiles),
    ).to(device)

    # train
    if args.train_flag:
        train_model(model, quantiles, args,
                    train_loader, val_loader,
                    train_len, val_len
                    )

    test_metrics = eval_loader(model, test_loader, quantiles, test_len)
    print(f"Test matrics: {test_metrics}")



main_train()

Train samples: 2760318 | Val: 49896 | Test: 48114


Epoch 1/100 [train]: 100%|██████████| 10783/10783 [35:13<00:00,  5.10it/s, loss=21.4858]


Epoch 1 Train Loss: 103.893941,               Train Metrics: {'mae': 216.47283935546875, 'wape': 0.6030859351158142, 'smape': 1.0205391645431519}
Epoch 1 Validation Loss: 76.832314,               Validation Metrics: {'mae': 170.27674865722656, 'wape': 0.35288771986961365, 'smape': 0.5617756843566895}
Saved TFT model to tft_best.pt at 1                     epochs (val_loss=76.832314)


Epoch 2/100 [train]: 100%|██████████| 10783/10783 [34:40<00:00,  5.18it/s, loss=41.6034]


Epoch 2 Train Loss: 36.059187,               Train Metrics: {'mae': 90.56014251708984, 'wape': 0.2522982656955719, 'smape': 0.9193856120109558}
Epoch 2 Validation Loss: 33.964924,               Validation Metrics: {'mae': 90.49281311035156, 'wape': 0.18754059076309204, 'smape': 0.6436980962753296}
Saved TFT model to tft_best.pt at 2                     epochs (val_loss=33.964924)


Epoch 3/100 [train]: 100%|██████████| 10783/10783 [34:43<00:00,  5.18it/s, loss=17.5162]


Epoch 3 Train Loss: 24.003759,               Train Metrics: {'mae': 68.6312255859375, 'wape': 0.19120442867279053, 'smape': 0.9168866872787476}
Epoch 3 Validation Loss: 27.722538,               Validation Metrics: {'mae': 78.5736083984375, 'wape': 0.16283880174160004, 'smape': 0.5497755408287048}
Saved TFT model to tft_best.pt at 3                     epochs (val_loss=27.722538)


Epoch 4/100 [train]: 100%|██████████| 10783/10783 [35:14<00:00,  5.10it/s, loss=16.9944]


Epoch 4 Train Loss: 21.883278,               Train Metrics: {'mae': 63.77964401245117, 'wape': 0.1776883453130722, 'smape': 0.9105744361877441}
Epoch 4 Validation Loss: 26.441577,               Validation Metrics: {'mae': 75.13134765625, 'wape': 0.15570494532585144, 'smape': 0.6119471192359924}
Saved TFT model to tft_best.pt at 4                     epochs (val_loss=26.441577)


Epoch 5/100 [train]: 100%|██████████| 10783/10783 [35:26<00:00,  5.07it/s, loss=22.5254]


Epoch 5 Train Loss: 21.137168,               Train Metrics: {'mae': 61.76539611816406, 'wape': 0.17207714915275574, 'smape': 0.907207190990448}
Epoch 5 Validation Loss: 26.081718,               Validation Metrics: {'mae': 74.63662719726562, 'wape': 0.15467967092990875, 'smape': 0.59743332862854}


Epoch 6/100 [train]: 100%|██████████| 10783/10783 [35:10<00:00,  5.11it/s, loss=21.0937]


Epoch 6 Train Loss: 20.606297,               Train Metrics: {'mae': 60.315086364746094, 'wape': 0.16803622245788574, 'smape': 0.90375816822052}
Epoch 6 Validation Loss: 26.022258,               Validation Metrics: {'mae': 74.93453979492188, 'wape': 0.15529707074165344, 'smape': 0.5397894382476807}


Epoch 7/100 [train]: 100%|██████████| 10783/10783 [35:00<00:00,  5.13it/s, loss=20.9993]


Epoch 7 Train Loss: 20.274323,               Train Metrics: {'mae': 59.36318588256836, 'wape': 0.16538487374782562, 'smape': 0.9015188217163086}
Epoch 7 Validation Loss: 25.230773,               Validation Metrics: {'mae': 72.15264129638672, 'wape': 0.149531751871109, 'smape': 0.6680188179016113}
Saved TFT model to tft_best.pt at 7                     epochs (val_loss=25.230773)


Epoch 8/100 [train]: 100%|██████████| 10783/10783 [35:40<00:00,  5.04it/s, loss=10.8687]


Epoch 8 Train Loss: 19.929114,               Train Metrics: {'mae': 58.39336395263672, 'wape': 0.16268233954906464, 'smape': 0.8996798396110535}
Epoch 8 Validation Loss: 25.597094,               Validation Metrics: {'mae': 73.80294799804688, 'wape': 0.1529519110918045, 'smape': 0.5599561929702759}


Epoch 9/100 [train]: 100%|██████████| 10783/10783 [35:53<00:00,  5.01it/s, loss=21.5622]


Epoch 9 Train Loss: 19.695725,               Train Metrics: {'mae': 57.7153434753418, 'wape': 0.1607932299375534, 'smape': 0.8979678750038147}
Epoch 9 Validation Loss: 24.584762,               Validation Metrics: {'mae': 70.2391586303711, 'wape': 0.1455661952495575, 'smape': 0.6113264560699463}
Saved TFT model to tft_best.pt at 9                     epochs (val_loss=24.584762)


Epoch 10/100 [train]: 100%|██████████| 10783/10783 [34:59<00:00,  5.14it/s, loss=20.4659]


Epoch 10 Train Loss: 19.457836,               Train Metrics: {'mae': 57.02552032470703, 'wape': 0.15887156128883362, 'smape': 0.8971431851387024}
Epoch 10 Validation Loss: 24.907079,               Validation Metrics: {'mae': 71.48946380615234, 'wape': 0.14815737307071686, 'smape': 0.6432635188102722}


Epoch 11/100 [train]: 100%|██████████| 10783/10783 [35:20<00:00,  5.08it/s, loss=25.1889]


Epoch 11 Train Loss: 19.261267,               Train Metrics: {'mae': 56.45172882080078, 'wape': 0.1572730988264084, 'smape': 0.8967845439910889}
Epoch 11 Validation Loss: 24.661042,               Validation Metrics: {'mae': 70.7729721069336, 'wape': 0.14667248725891113, 'smape': 0.5381738543510437}


Epoch 12/100 [train]: 100%|██████████| 10783/10783 [35:03<00:00,  5.13it/s, loss=23.7477]


Epoch 12 Train Loss: 19.046882,               Train Metrics: {'mae': 55.85866928100586, 'wape': 0.1556205302476883, 'smape': 0.8956518173217773}
Epoch 12 Validation Loss: 24.750305,               Validation Metrics: {'mae': 71.186279296875, 'wape': 0.14752903580665588, 'smape': 0.6136394143104553}


Epoch 13/100 [train]: 100%|██████████| 10783/10783 [35:04<00:00,  5.12it/s, loss=29.1722]


Epoch 13 Train Loss: 18.861352,               Train Metrics: {'mae': 55.32265853881836, 'wape': 0.15412776172161102, 'smape': 0.8956164717674255}
Epoch 13 Validation Loss: 24.520793,               Validation Metrics: {'mae': 70.5121078491211, 'wape': 0.14613185822963715, 'smape': 0.6042813658714294}


Epoch 14/100 [train]: 100%|██████████| 10783/10783 [35:30<00:00,  5.06it/s, loss=14.1599]


Epoch 14 Train Loss: 18.732336,               Train Metrics: {'mae': 54.920928955078125, 'wape': 0.153008371591568, 'smape': 0.8948394060134888}
Epoch 14 Validation Loss: 24.620863,               Validation Metrics: {'mae': 71.01564025878906, 'wape': 0.1471754014492035, 'smape': 0.6221717596054077}
Early stopping at epoch 14 (patience=5,                    min_delta=0.5)
Loaded stored TFT model for evaluation tft_best.pt
Saved test forecasts CSV -> tft_test_forecasts.csv
Test Loss: 24.662834. Test Metrics: {'mae': 70.62957763671875, 'wape': 0.14813287556171417, 'smape': 0.6175957918167114}
Test matrics: {'mae': 70.62957763671875, 'wape': 0.14813287556171417, 'smape': 0.6175957918167114}


In [7]:
# !pip uninstall optuna -y
# !pip install --upgrade optuna

In [8]:
# import os
# import sys
# import subprocess
# import argparse
# import re
# import json
# import optuna
# optuna.logging.set_verbosity(optuna.logging.INFO)


# PROJECT_ROOT = "/kaggle/input/demandforecasting/demandForecasting"
# # os.path.abspath(os.path.join(os.getcwd(), ".."))
# # sys.path.insert(1, '/kaggle/input/demandforecasting/demandForecasting')

# def run_train_cli(hparams, fixed_epochs):
#     """
#     Launch TFT/train_tft.py as a module with mapped CLI args.
#     Parse WAPE from stdout; prefer Validation line, fallback to Test metrics.
#     """
#     cmd = [
#         sys.executable, "-m", "TFT.train_tft",
#         "--enc-len", str(56),
#         "--dec-len", str(28),
#         "--batch-size", str(512),
#         "--epochs", str(fixed_epochs),
#         "--lr", str(1e-3),
#         "--hidden-dim", str(hparams["hidden_dim"]),
#         "--d-model", str(hparams["d_model"]),
#         "--heads", str(hparams["heads"]),
#         "--dropout", str(0.1),
#         "--stride", "1",
#         "--seed", "42",
#         "--lstm-hidden", str(hparams["lstm_hidden"]),
#         "--lstm-layers", str(2),
#     ]
#     env = os.environ.copy()
#     env["PYTHONPATH"] = PROJECT_ROOT + (os.pathsep + env.get("PYTHONPATH", ""))
#     proc = subprocess.run(
#         cmd, cwd=PROJECT_ROOT, env=env, capture_output=True, text=True
#         )
#     if proc.returncode != 0:
#         raise RuntimeError(
#             f"Training failed:\nSTDERR:\n{proc.stderr}\nSTDOUT:\n{proc.stdout}"
#             )

#     wape = None
#     # Parse validation print (dict) or test metrics line
#     for line in proc.stdout.splitlines():
#         if "Validation" in line and "WAPE:" in line:
#             try:
#                 wape = float(line.split("WAPE:")[1].split("|")[0].strip())
#                 break
#             except Exception:
#                 pass
#         if "Test Metrics:" in line and "'wape':" in line:
#             m = re.search(r"'wape':\s*([0-9\.eE+-]+)", line)
#             if m:
#                 wape = float(m.group(1))
#     if wape is None:
#         # Last resort: scan any dict-looking line for wape
#         for line in proc.stdout.splitlines():
#             m = re.search(r"'wape':\s*([0-9\.eE+-]+)", line)
#             if m:
#                 wape = float(m.group(1))
#                 break
#     if wape is None:
#         raise RuntimeError(
#             "Could not parse WAPE from train output.\n" + proc.stdout
#             )
#     return wape


# def objective(trial: optuna.Trial):
#     # Simple search space (no epoch search)
#     hparams = {
#         "hidden_dim": trial.suggest_categorical("hidden_dim", [32, 64]),
#         "d_model": trial.suggest_categorical("d_model", [32, 128]),
#         "heads": trial.suggest_categorical("heads", [2, 8]),
#         "lstm_hidden": trial.suggest_categorical("lstm_hidden", [16, 32]),
#     }
#     fixed_epochs = 1
#     wape = run_train_cli(hparams, fixed_epochs=fixed_epochs)
#     return wape


# def main_tuning():
#     ap = argparse.ArgumentParser()
#     ap.add_argument("--trials", type=int, default=2)
#     ap.add_argument("--study_name", type=str, default="tft_tuning4")
#     ap.add_argument("--direction", type=str, default="minimize")
#     ap.add_argument("--storage", type=str, default=None,
#                     help="Optuna storage log file name")
#     # args = ap.parse_args()
#     args, unknown = ap.parse_known_args()

#     # Create study with optional storage for persistent tracking
#     # create a file if doesn't exist optuna_journal.log
#     if not os.path.exists("optuna_journal.log"):
#         open("optuna_journal.log", 'a').close()

#     study = optuna.create_study(
#         study_name=args.study_name,
#         direction=args.direction,
#         load_if_exists=False,
#     )

#     study.optimize(
#         objective,
#         n_trials=args.trials,
#         gc_after_trial=True,
#         n_jobs=5,
#         show_progress_bar=True
#     )

#     print("Best WAPE:", study.best_value)
#     print("Best params:", study.best_params)

#     out_json = os.path.join(
#         PROJECT_ROOT, "tft_best_params.json"
#         )
#     with open(out_json, "w") as f:
#         json.dump(
#             {"best_value": study.best_value, "best_params": study.best_params},
#             f,
#             indent=2
#             )
#     print(f"Saved best params -> {out_json}")

#     # Also export full trials dataframe if pandas is available
#     try:
#         df = study.trials_dataframe()
#         out_csv = os.path.join(
#             PROJECT_ROOT, "tuning_trials_full.csv"
#             )
#         df.to_csv(out_csv, index=False)
#         print(f"Saved all trial results -> {out_csv}")
#     except Exception as e:
#         print(f"Could not write full trials CSV: {e}")


# main_tuning()
