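# 03 — Train TFT (returns/vol/bands)

Note: the training loss in this notebook supervises only the model's `ret` (returns) head; weights and config are written to `artifacts/v1/tft/`.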

In [None]:
!pip -q install torch numpy pandas pyarrow scikit-learn tqdm

In [None]:
import os, sys, json, pathlib, subprocess

# Repository to bootstrap from when the notebook is not started at the repo root.
REPO_URL = os.getenv('REPO_URL', 'https://github.com/your-org/AI_TRADER.git')
REPO_DIR = os.getenv('REPO_DIR', 'AI_TRADER')

if not pathlib.Path('backend').exists():
    if not pathlib.Path(REPO_DIR).exists():
        # Argument list (no shell) passes the env-provided values verbatim; check=True stops
        # here on a failed clone instead of failing later with a confusing chdir/import error.
        subprocess.run(['git', 'clone', REPO_URL, REPO_DIR], check=True)
    os.chdir(REPO_DIR)
    if not pathlib.Path('backend').is_dir():
        raise FileNotFoundError(f"No 'backend' directory found in {os.getcwd()}")

# Make `app.*` importable; the membership check keeps re-runs from stacking duplicate entries.
BACKEND_DIR = os.path.abspath('backend')
if BACKEND_DIR not in sys.path:
    sys.path.append(BACKEND_DIR)


In [None]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from app.ml.tft.model import TFT

torch.set_grad_enabled(True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the RNGs this notebook relies on (weight init, dropout, DataLoader shuffling) so that
# Restart & Run All reproduces the same model. torch.manual_seed seeds all devices.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

DATASET_PATH = pathlib.Path('training_data/v1/dataset.parquet')
# Stable sort: rows sharing an `asof` value keep their on-disk order.
df = pd.read_parquet(DATASET_PATH).sort_values('asof', kind='stable').reset_index(drop=True)

# Chronological 80/20 split. If several rows share the boundary timestamp, move the cut back to the
# first of them so that no timestamp appears in both train and validation.
split = int(0.8 * len(df))
if 0 < split < len(df):
    split = int(df['asof'].searchsorted(df['asof'].iloc[split], side='left'))
train_df, val_df = df.iloc[:split], df.iloc[split:]
if train_df.empty or val_df.empty:
    raise ValueError(
        f"Chronological split produced an empty set (train={len(train_df)}, val={len(val_df)}); "
        "need rows spanning more than one 'asof' timestamp"
    )
# NOTE(review): if y_ret holds forward returns (the saved config lists horizons [3, 5, 10]), the last
# train rows' labels may overlap the validation period — consider an embargo gap. TODO confirm.
print(df.shape)


In [None]:
# TFT dataset helper with robust array coercion
class TFTDataset(Dataset):
    """Map-style dataset yielding ``(x_price, x_kron, x_ctx, y_ret)`` float32 tensors per row.

    Array-valued cells may be nested Python lists, numeric numpy arrays, or 1-D object arrays
    whose elements are per-row arrays. All of these are coerced to float32.

    Args:
        frame: DataFrame with columns ``ohlcv_norm``, ``kronos_emb``, ``context`` and ``y_ret``.
        lookback: expected number of time steps in each ``ohlcv_norm`` window (default 120).
        price_dim: expected number of features per ``ohlcv_norm`` time step (default 5).
        kronos_dim, context_dim, target_dim: optional expected lengths of the flattened
            ``kronos_emb`` / ``context`` / ``y_ret`` vectors; ``None`` skips the length check.

    ``__getitem__`` raises ValueError when a cell is missing (None / scalar NaN), cannot be
    converted to float32, or has the wrong number of elements.
    """

    def __init__(self, frame: pd.DataFrame, lookback: int = 120, price_dim: int = 5,
                 kronos_dim=None, context_dim=None, target_dim=None):
        self.df = frame.reset_index(drop=True)
        self.lookback = lookback
        self.price_dim = price_dim
        # Expected flat lengths keyed by column name; None disables the check for that column.
        self._flat_sizes = {'kronos_emb': kronos_dim, 'context': context_dim, 'y_ret': target_dim}

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _is_missing(x):
        """True for None or a scalar NA; these would otherwise coerce silently to a NaN array."""
        return x is None or (np.isscalar(x) and bool(pd.isna(x)))

    def _arr_price(self, x, idx):
        """Coerce an ``ohlcv_norm`` cell to a ``(lookback, price_dim)`` float32 array."""
        if self._is_missing(x):
            raise ValueError(f"Missing ohlcv_norm at idx {idx}")
        try:
            a = np.array(x, dtype=np.float32)
        except (TypeError, ValueError):
            # Object array of per-row arrays: numpy cannot cast it directly, so stack row by row.
            try:
                a = np.stack([np.array(row, dtype=np.float32).reshape(-1) for row in x], axis=0)
            except (TypeError, ValueError) as exc:
                raise ValueError(f"Bad ohlcv_norm rows for idx {idx}") from exc
        expected = self.lookback * self.price_dim
        if a.size != expected:
            raise ValueError(f"Bad ohlcv_norm size for idx {idx}: shape {a.shape}, expected {expected} values")
        return a.reshape(self.lookback, self.price_dim)

    def _arr_flat(self, x, name, idx):
        """Coerce a vector-valued cell to a 1-D float32 array, checking its length if configured."""
        if self._is_missing(x):
            raise ValueError(f"Missing {name} at idx {idx}")
        try:
            a = np.array(x, dtype=np.float32).reshape(-1)
        except (TypeError, ValueError) as exc:
            # Truncated repr: embeddings can be hundreds of values long.
            raise ValueError(f"Bad array for {name} at idx {idx}: {x!r:.200}") from exc
        expected = self._flat_sizes.get(name)
        if expected is not None and a.size != expected:
            raise ValueError(f"Bad {name} size for idx {idx}: got {a.size}, expected {expected}")
        return a

    def __getitem__(self, idx):
        r = self.df.iloc[idx]
        x_price = torch.tensor(self._arr_price(r['ohlcv_norm'], idx), dtype=torch.float32)
        x_kron = torch.tensor(self._arr_flat(r['kronos_emb'], 'kronos_emb', idx), dtype=torch.float32)
        x_ctx = torch.tensor(self._arr_flat(r['context'], 'context', idx), dtype=torch.float32)
        y_ret = torch.tensor(self._arr_flat(r['y_ret'], 'y_ret', idx), dtype=torch.float32)
        return x_price, x_kron, x_ctx, y_ret

def make_loaders(train_df, val_df, batch_size=128, num_workers=0):
    """Build train/validation DataLoaders over ``TFTDataset`` views of the two frames.

    Args:
        train_df: training rows; reshuffled every epoch.
        val_df: validation rows; iterated in their given order.
        batch_size: rows per batch for both loaders (default 128).
        num_workers: DataLoader worker processes (default 0 = load in the main process).

    Returns:
        ``(train_dl, val_dl)``.
    """
    train_ds = TFTDataset(train_df)
    val_ds = TFTDataset(val_df)
    # drop_last avoids a ragged final training batch. With fewer rows than one batch it would
    # yield *zero* batches and the epoch loop would silently train nothing, so only drop when a
    # full batch exists.
    drop_last = len(train_ds) >= batch_size
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          drop_last=drop_last, num_workers=num_workers)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_dl, val_dl


train_dl, val_dl = make_loaders(train_df, val_df)


In [None]:
# Model, optimiser and loss. NOTE(review): these hyperparameters are repeated by hand in the `cfg`
# dict saved next to the weights (emb_dim -> hidden_size, context_dim, dropout, ...); keep them in
# sync. That cfg also records horizons and attention_heads, which are not passed here — presumably
# TFT's defaults; TODO confirm against app.ml.tft.model.TFT.
tft = TFT(lookback=120, price_dim=5, kronos_dim=512, context_dim=29, emb_dim=64, dropout=0.1).to(device)
opt = torch.optim.AdamW(tft.parameters(), lr=1e-3)
huber = torch.nn.SmoothL1Loss()

# Early-stopping state consumed by the training loop.
patience = 5             # consecutive non-improving epochs tolerated before stopping
best_val = float('inf')  # any finite val loss improves on this (a 1e9 sentinel would not)
bad = 0                  # current run of non-improving epochs


def eval_tft(model=None, loader=None, loss_fn=None, target_device=None):
    """Return the sample-weighted mean validation loss of the model's ``'ret'`` output.

    Every argument defaults to the notebook global (``tft``, ``val_dl``, ``huber``, ``device``),
    so the zero-argument call in the training loop is unchanged.

    Each batch's loss is weighted by its batch size. Without this, the final partial batch (the
    validation loader keeps it) would count as much as a full batch. Assumes ``loss_fn``
    averages over elements (``reduction='mean'``, SmoothL1Loss's default) and that every target
    row has the same width.

    The model's previous train/eval mode is restored afterwards, even if evaluation raises.

    Raises:
        ValueError: if the loader yields no samples (previously this produced NaN).
    """
    model = tft if model is None else model
    loader = val_dl if loader is None else loader
    loss_fn = huber if loss_fn is None else loss_fn
    target_device = device if target_device is None else target_device

    was_training = model.training
    model.eval()
    total, count = 0.0, 0
    try:
        with torch.no_grad():
            for x_price, x_kron, x_ctx, y_ret in loader:
                x_price = x_price.to(target_device)
                x_kron = x_kron.to(target_device)
                x_ctx = x_ctx.to(target_device)
                y_ret = y_ret.to(target_device)
                out = model(x_price, x_kron, x_ctx)
                n = y_ret.shape[0]
                total += loss_fn(out['ret'], y_ret).item() * n
                count += n
    finally:
        model.train(was_training)
    if count == 0:
        raise ValueError('validation loader yielded no samples')
    return total / count


In [None]:
# Train with early stopping on validation loss; the best checkpoint is written as soon as it is found.
# Uses state from earlier cells: tft/opt/huber/device, train_dl, eval_tft(), best_val/bad/patience.
ARTIFACT_DIR = pathlib.Path('artifacts/v1/tft')
# Create the directory up front. config.json is written after the loop even when no epoch
# improved (e.g. NaN val loss), and the old in-loop mkdir would then never have run.
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
MAX_EPOCHS = 20
saved_this_run = False

for epoch in range(MAX_EPOCHS):
    for x_price, x_kron, x_ctx, y_ret in train_dl:
        x_price, x_kron, x_ctx, y_ret = (t.to(device) for t in (x_price, x_kron, x_ctx, y_ret))
        out = tft(x_price, x_kron, x_ctx)
        loss = huber(out['ret'], y_ret)
        opt.zero_grad()
        loss.backward()
        opt.step()
    v = eval_tft()
    print(f"epoch {epoch} val_loss {v:.4f}")
    if v < best_val:
        best_val = v
        bad = 0
        torch.save(tft.state_dict(), ARTIFACT_DIR / 'weights.pt')
        saved_this_run = True
    else:
        bad += 1
        if bad >= patience:
            print('Early stop')
            break
# NOTE: after the loop `tft` holds the *last* epoch's weights; reload weights.pt for the best ones.

# Hand-maintained mirror of the TFT(...) constructor arguments — keep in sync with the setup cell.
cfg = {
    'name': 'tft_v1',
    'lookback': 120,
    'ohlcv_features': 5,
    'kronos_dim': 512,
    'context_dim': 29,
    'horizons': [3,5,10],
    'hidden_size': 64,
    'attention_heads': 4,
    'dropout': 0.1,
}
with open(ARTIFACT_DIR / 'config.json', 'w') as f:
    json.dump(cfg, f, indent=2)
if saved_this_run:
    print('Saved TFT weights + config')
else:
    print(f'Saved TFT config; no epoch beat best_val={best_val:.4f}, so weights.pt was not written this run')


In [None]:
# Artifact summary for TFT
from pathlib import Path
import json, os

# Locations written by the training cell; the config sits next to the weights.
w_path = Path('artifacts/v1/tft/weights.pt')
c_path = w_path.with_name('config.json')
print('Artifacts directory:', w_path.parent.resolve())

# Checkpoint presence and size in bytes ('n/a' when it was never written).
weights_present = w_path.exists()
weight_size = w_path.stat().st_size if weights_present else 'n/a'
print('Weight exists:', weights_present, 'size:', weight_size)

# Config presence; when present, load it (rebinding `cfg`) and list its top-level keys.
config_present = c_path.exists()
print('Config exists:', config_present)
if config_present:
    with c_path.open() as f:
        cfg = json.load(f)
    print('Config keys:', [key for key in cfg])
