# 03 — Train TFT (risk/bands) + CatBoost veto

Uses the dataset at `training_data/v1/dataset.parquet`. Trains the TFT on returns (its band/volatility heads remain auxiliary) and trains a CatBoost veto classifier on fused features from StockFormer, TFT, and context. Requires trained StockFormer weights at `artifacts/v1/stockformer/weights.pt`. Saves artifacts to `artifacts/v1/tft/` and `artifacts/v1/veto/`.

In [None]:
# Install this notebook's Python dependencies quietly (IPython `!` shell escape;
# only valid inside Jupyter/Colab, not as a plain .py script).
!pip -q install torch numpy pandas pyarrow scikit-learn catboost tqdm

In [None]:
import os, sys, json, pathlib
# Repo location; both can be overridden through environment variables.
REPO_URL = os.getenv("REPO_URL", "https://github.com/your-org/AI_TRADER.git")
REPO_DIR = os.getenv("REPO_DIR", "AI_TRADER")
# If the current directory is not already the repo root (no `backend/` here),
# clone the repo unless REPO_DIR already exists, then cd into it (IPython magics).
# NOTE(review): a failed `git clone` is not detected; `%cd` would then fail.
if not pathlib.Path("backend").exists():
    if not pathlib.Path(REPO_DIR).exists():
        !git clone $REPO_URL $REPO_DIR
    %cd $REPO_DIR
# Put `backend/` on sys.path so the `app.ml...` imports in the next cell resolve.
sys.path.append(os.path.abspath("backend"))
# Output directories for the TFT and veto artifacts written by this notebook.
os.makedirs("artifacts/v1/tft", exist_ok=True)
os.makedirs("artifacts/v1/veto", exist_ok=True)

In [None]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from catboost import CatBoostClassifier

from app.ml.tft.model import TFT
from app.ml.stockformer.model import StockFormer
from app.ml.preprocess.normalize import build_veto_vec

# Training below needs autograd; re-enable it globally in case an earlier
# cell or session switched it off.
torch.set_grad_enabled(True)
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [None]:
DATASET_PATH = pathlib.Path("training_data/v1/dataset.parquet")
HORIZONS = [3, 5, 10]
TRAIN_FRAC = 0.8
# Purge gap, in distinct `asof` values, between the end of train and the start of
# validation. Labels look up to max(HORIZONS) steps ahead, so without a gap the
# targets of the last train rows overlap the validation window (look-ahead leak).
# Assumes one `asof` step corresponds to one horizon step — TODO confirm.
EMBARGO = max(HORIZONS)


def _chronological_split(frame: pd.DataFrame, train_frac: float = 0.8, embargo: int = 0):
    """Split an `asof`-sorted frame into (train, val) on `asof` boundaries.

    Every row sharing an `asof` value lands on the same side, so one timestamp is
    never divided between train and validation. The last `embargo` distinct
    `asof` values before the cut are dropped from train (purged).

    Args:
        frame: rows already sorted by the ``asof`` column.
        train_frac: fraction of distinct ``asof`` values before the cut.
        embargo: number of distinct ``asof`` values to purge from the end of train.

    Returns:
        ``(train, val)`` frames; ``val`` is a contiguous suffix of ``frame``.
    """
    dates = frame["asof"].drop_duplicates()
    cut = int(train_frac * len(dates))
    train_dates = dates.iloc[: max(cut - embargo, 0)]
    val_dates = dates.iloc[cut:]
    return frame[frame["asof"].isin(train_dates)], frame[frame["asof"].isin(val_dates)]


# Stable sort keeps a deterministic row order among rows with equal `asof`.
df = pd.read_parquet(DATASET_PATH).sort_values("asof", kind="stable").reset_index(drop=True)
train_df, val_df = _chronological_split(df, TRAIN_FRAC, EMBARGO)
# Index of the first validation row, so that val_df == df.iloc[split:].
# With the embargo, train_df is a strict prefix of df.iloc[:split].
split = len(df) - len(val_df)
print(df.shape)
print(f"train rows {len(train_df)}, val rows {len(val_df)} (embargo {EMBARGO} asof steps)")

In [None]:
class TFTDataset(Dataset):
    """Map-style dataset yielding one TFT training example per frame row.

    Each item is ``(x_price, x_kron, x_ctx, y_ret)``: float32 tensors built from
    the ``ohlcv_norm``, ``kronos_emb``, ``context`` and ``y_ret`` columns.
    """

    def __init__(self, frame: pd.DataFrame):
        self.df = frame

    def __len__(self) -> int:
        return len(self.df)

    @staticmethod
    def _as_array(value) -> np.ndarray:
        """Return *value* as a dense float32 ndarray.

        Nested list columns read back from parquet can arrive as object-dtype
        arrays whose elements are themselves arrays. ``torch.tensor`` rejects
        those, so stack the elements (recursively) into one numeric array.
        Plain lists and numeric arrays are converted directly.
        """
        arr = np.asarray(value)
        if arr.dtype == object:
            return np.stack([TFTDataset._as_array(v) for v in arr])
        return arr.astype(np.float32, copy=False)

    @staticmethod
    def _as_tensor(value) -> torch.Tensor:
        """Convert one dataframe cell to a float32 tensor (the data is copied)."""
        return torch.tensor(TFTDataset._as_array(value), dtype=torch.float32)

    def __getitem__(self, idx):
        r = self.df.iloc[idx]
        return (
            self._as_tensor(r["ohlcv_norm"]),
            self._as_tensor(r["kronos_emb"]),
            self._as_tensor(r["context"]),
            self._as_tensor(r["y_ret"]),
        )

# Training batches are shuffled and drop the ragged tail batch; validation keeps
# row order and every row.
train_dl = DataLoader(
    TFTDataset(train_df),
    batch_size=128,
    shuffle=True,
    drop_last=True,
)
val_dl = DataLoader(TFTDataset(val_df), batch_size=128, shuffle=False)

# These hyper-parameters should agree with the config saved after training.
tft = TFT(
    lookback=120,
    price_dim=5,
    kronos_dim=512,
    context_dim=29,
    emb_dim=64,
    dropout=0.1,
).to(device)
opt = torch.optim.AdamW(tft.parameters(), lr=1e-3)
huber = torch.nn.SmoothL1Loss()

# Early-stopping state: stop after `patience` consecutive epochs without a
# validation improvement over `best_val`; `bad` counts those epochs.
patience, best_val, bad = 5, 1e9, 0

def eval_tft():
    """Return the mean Huber loss of ``tft`` on ``val_dl``.

    Each batch loss is weighted by its batch size, so a smaller final batch does
    not skew the average. Runs in eval mode under ``torch.no_grad()``. The
    model's previous train/eval mode is restored afterwards, even if an error
    occurs.

    Returns:
        float: sample-weighted mean of ``huber(out["ret"], y_ret)``.

    Raises:
        RuntimeError: if ``val_dl`` yields no samples. An empty validation split
            would otherwise give NaN, and no checkpoint would ever be saved.
    """
    was_training = tft.training
    tft.eval()
    total_loss = 0.0
    n_seen = 0
    try:
        with torch.no_grad():
            for x_price, x_kron, x_ctx, y_ret in val_dl:
                x_price = x_price.to(device)
                x_kron = x_kron.to(device)
                x_ctx = x_ctx.to(device)
                y_ret = y_ret.to(device)
                out = tft(x_price, x_kron, x_ctx)
                batch_n = y_ret.shape[0]
                total_loss += huber(out["ret"], y_ret).item() * batch_n
                n_seen += batch_n
    finally:
        tft.train(was_training)
    if n_seen == 0:
        raise RuntimeError("eval_tft: validation loader is empty; cannot compute val loss")
    return total_loss / n_seen

In [None]:
# Train with early stopping on validation Huber loss. The best checkpoint goes to
# disk; the veto cell later reloads it from there.
TFT_WEIGHTS_PATH = "artifacts/v1/tft/weights.pt"
saved_this_run = False
for epoch in range(20):
    # Ensure training mode (dropout active) at the start of every epoch.
    tft.train()
    for x_price, x_kron, x_ctx, y_ret in train_dl:
        x_price = x_price.to(device)
        x_kron = x_kron.to(device)
        x_ctx = x_ctx.to(device)
        y_ret = y_ret.to(device)
        out = tft(x_price, x_kron, x_ctx)
        loss = huber(out["ret"], y_ret)
        opt.zero_grad()
        loss.backward()
        opt.step()
    v = eval_tft()
    print(f"epoch {epoch} val_loss {v:.4f}")
    if v < best_val:
        best_val = v
        bad = 0
        torch.save(tft.state_dict(), TFT_WEIGHTS_PATH)
        saved_this_run = True
    else:
        bad += 1
        if bad >= patience:
            print("Early stop")
            break

# Fail here, not at load_state_dict later, if no checkpoint exists at all
# (e.g. NaN / diverged validation losses never beat best_val).
if not saved_this_run:
    if not os.path.exists(TFT_WEIGHTS_PATH):
        raise RuntimeError(
            f"No TFT checkpoint at {TFT_WEIGHTS_PATH}: val loss never beat best_val={best_val}"
        )
    print(f"WARNING: no val improvement this run; keeping existing {TFT_WEIGHTS_PATH}")

cfg = {
    "name": "tft_v1",
    "lookback": 120,
    "ohlcv_features": 5,
    "kronos_dim": 512,
    "context_dim": 29,
    # Keep the saved horizons in sync with the labels defined with the dataset.
    "horizons": list(HORIZONS),
    "hidden_size": 64,
    "attention_heads": 4,
    "dropout": 0.1,
}
with open("artifacts/v1/tft/config.json", "w") as f:
    json.dump(cfg, f, indent=2)
print("Saved TFT weights + config")

In [None]:
# CatBoost veto
# The upstream models are used only as frozen feature extractors: load each
# checkpoint onto the CPU and switch the model to eval mode.
def _load_frozen(model, weights_path):
    """Load a CPU-mapped state dict into *model*, set eval mode, return *model*."""
    state = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    return model


sf = _load_frozen(
    StockFormer(lookback=120, price_dim=5, kronos_dim=512, context_dim=29, d_model=128, n_heads=4, n_layers=4, ffn_dim=256, dropout=0.1),
    "artifacts/v1/stockformer/weights.pt",
)
tft_cpu = _load_frozen(
    TFT(lookback=120, price_dim=5, kronos_dim=512, context_dim=29, emb_dim=64, dropout=0.1),
    "artifacts/v1/tft/weights.pt",
)

def split_context(ctx: np.ndarray):
    """Cut a flat context vector into ``(tf_align, smc_vec, ta_vec)``.

    ``tf_align`` is positions 0-4 and ``smc_vec`` is positions 5-16.
    ``ta_vec`` is everything from position 17 on (12 values for a 29-dim context).
    """
    tf_end, smc_end = 5, 17
    return ctx[:tf_end], ctx[tf_end:smc_end], ctx[smc_end:]

def infer_vectors(frame: pd.DataFrame, max_rows: int | None = None, label_idx: int = 1):
    """Build CatBoost veto features and labels from frozen StockFormer + TFT outputs.

    For every selected row, run ``sf`` and ``tft_cpu`` on the row's ``ohlcv_norm``,
    ``kronos_emb`` and ``context`` inputs. Fuse their outputs with the context slices
    via ``build_veto_vec``, and collect the binary label ``y_up[label_idx]``.

    Args:
        frame: rows with ``ohlcv_norm``, ``kronos_emb``, ``context`` and ``y_up`` columns.
        max_rows: if given, use a reproducible random sample (``random_state=42``)
            of at most this many rows; ``None`` uses every row in order.
        label_idx: position in ``y_up`` of the target label. Defaults to 1, the
            5-step label if ``y_up`` follows ``HORIZONS = [3, 5, 10]`` (assumed — TODO confirm).

    Returns:
        ``(X, y)``: veto vectors stacked one per row, and an int label array.

    Raises:
        ValueError: if no rows are selected (empty frame or ``max_rows=0``).

    NOTE(review): ``tft_cpu`` was fit on ``train_df`` in this notebook, so features
    computed on ``train_df`` are in-sample. They are likely over-confident relative
    to live inference. Consider out-of-fold predictions for the veto's training set.
    """

    def _to_array(value) -> np.ndarray:
        # Nested parquet list columns can load as object arrays of arrays, which
        # torch.tensor rejects; stack them (recursively) into one float32 array.
        arr = np.asarray(value)
        if arr.dtype == object:
            return np.stack([_to_array(v) for v in arr])
        return arr.astype(np.float32, copy=False)

    def _to_batch(value) -> torch.Tensor:
        # Single-row batch: prepend a batch dimension of size 1.
        return torch.tensor(_to_array(value), dtype=torch.float32).unsqueeze(0)

    rows = frame if max_rows is None else frame.sample(min(max_rows, len(frame)), random_state=42)
    if len(rows) == 0:
        raise ValueError("infer_vectors: no rows selected (empty frame or max_rows=0)")

    X, y = [], []
    with torch.no_grad():
        for _, r in tqdm(rows.iterrows(), total=len(rows)):
            x_price = _to_batch(r["ohlcv_norm"])
            x_kron = _to_batch(r["kronos_emb"])
            x_ctx = _to_batch(r["context"])
            sf_out = sf(x_price, x_kron, x_ctx)
            tft_out = tft_cpu(x_price, x_kron, x_ctx)
            tf_align, smc_vec, ta_vec = split_context(r["context"])
            veto_vec = build_veto_vec(
                sf_out={"prob": sf_out["up_prob"].numpy(), "ret": sf_out["ret"].numpy()},
                tft_out={k: v.numpy() for k, v in tft_out.items()},
                smc_vec=smc_vec,
                tf_align=tf_align,
                ta_vec=ta_vec,
                raw_features={},
            )
            X.append(veto_vec.squeeze(0))
            y.append(int(r["y_up"][label_idx]))
    return np.stack(X, axis=0), np.array(y)

# Fused feature matrices from capped, reproducible samples of each split.
X_train, y_train = infer_vectors(train_df, max_rows=50000)
X_val, y_val = infer_vectors(val_df, max_rows=10000)

# Binary veto classifier on the fused vectors.
# NOTE(review): CatBoost defaults to use_best_model=True when an eval_set is given,
# so X_val is also used for iteration selection — confirm this is intended.
cb = CatBoostClassifier(
    loss_function="Logloss",
    iterations=400,
    depth=6,
    learning_rate=0.05,
    verbose=100,
)
cb.fit(X_train, y_train, eval_set=(X_val, y_val))

cb.save_model("artifacts/v1/veto/catboost.cbm")
# Probability cut-offs persisted for the consumer of this model; presumably
# block above threshold_block and boost below threshold_boost — confirm downstream.
with open("artifacts/v1/veto/config.json", "w") as f:
    json.dump(
        {
            "name": "catboost_veto_v1",
            "threshold_block": 0.65,
            "threshold_boost": 0.35,
        },
        f,
        indent=2,
    )
print("Saved CatBoost veto")