# 04 — Train LightGBM veto


In [None]:
!pip -q install torch numpy pandas pyarrow scikit-learn lightgbm tqdm


In [None]:
import os, sys, json, pathlib

# Repository location; both values can be overridden through environment variables.
REPO_URL = os.getenv('REPO_URL', 'https://github.com/RishiKarthikeyan07/ai-trader-saas')
REPO_DIR = os.getenv('REPO_DIR', 'AI_TRADER')

# If the current working directory has no `backend/` folder, clone the repo
# (only when REPO_DIR is not already present) and change into it.
if not pathlib.Path('backend').exists():
    if not pathlib.Path(REPO_DIR).exists():
        !git clone $REPO_URL $REPO_DIR
    %cd $REPO_DIR
# Put the absolute `backend/` path on sys.path so the `app.*` imports used
# later in this notebook can be resolved.
sys.path.append(str(pathlib.Path('backend').resolve()))

# Dataset path, relative to the current working directory.
DATASET_PATH = pathlib.Path('training_data/v1/dataset.parquet')

# Placeholders only; both names are reassigned when the dataset is loaded and split.
train_df = None
val_df = None

# Fail fast if the dataset file is missing.
assert DATASET_PATH.exists(), f"Dataset not found at {DATASET_PATH}"
print(f'Using dataset: {DATASET_PATH}')


In [None]:
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import lightgbm as lgb

from app.ml.preprocess.normalize import build_veto_vec
from app.ml.stockformer.model import StockFormer
from app.ml.tft.model import TFT

# Read the dataset, order it by the 'asof' column and renumber the rows.
DATASET_PATH = pathlib.Path('training_data/v1/dataset.parquet')
df = (
    pd.read_parquet(DATASET_PATH)
    .sort_values('asof')
    .reset_index(drop=True)
)

# Positional split: the first 80% of the sorted rows train, the remainder validates.
split = int(0.8 * len(df))
train_df = df.iloc[:split]
val_df = df.iloc[split:]
print(f'Train: {len(train_df)}, Val: {len(val_df)}')


In [None]:
# Rebuild StockFormer and TFT with fixed hyper-parameters, restore their saved
# weights onto the CPU, and call .eval() on each.
sf = StockFormer(
    lookback=120,
    price_dim=5,
    kronos_dim=512,
    context_dim=29,
    d_model=128,
    n_heads=4,
    n_layers=4,
    ffn_dim=256,
    dropout=0.1,
)
sf.load_state_dict(
    torch.load('artifacts/v1/stockformer/weights.pt', map_location='cpu')
)
sf.eval()

tft = TFT(
    lookback=120,
    price_dim=5,
    kronos_dim=512,
    context_dim=29,
    emb_dim=64,
    dropout=0.1,
)
tft.load_state_dict(
    torch.load('artifacts/v1/tft/weights.pt', map_location='cpu')
)
tft.eval()

def split_ctx(ctx: np.ndarray):
    """Cut a context vector into its three positional feature groups.

    The input is converted to float32 and flattened to one dimension, then
    sliced by position: elements 0-4, elements 5-16, and everything from
    element 17 onward.

    Returns:
        A tuple ``(tf_align, smc_vec, ta_vec)`` of 1-D float32 arrays.
    """
    flat = np.ravel(np.asarray(ctx, dtype=np.float32))
    first_cut, second_cut = 5, 17
    return flat[:first_cut], flat[first_cut:second_cut], flat[second_cut:]


In [None]:
def infer_vectors(frame: pd.DataFrame, max_rows: int | None = None):
    X, y = [], []
    iterator = frame if max_rows is None else frame.sample(min(max_rows, len(frame)), random_state=42)

    def _arr_price(x, idx=None):
        try:
            a = np.array(x, dtype=np.float32)
        except Exception:
            a = np.array(list(x), dtype=object)
            a = np.stack([np.array(row, dtype=np.float32).reshape(-1) for row in a], axis=0)
        if a.size != 120 * 5:
            raise ValueError(f"Bad ohlcv_norm size for idx {idx}: shape {a.shape}")
        return a.reshape(1, 120, 5)

    def _arr_flat(x, name, idx=None):
        try:
            a = np.array(x, dtype=np.float32).reshape(1, -1)
        except Exception as exc:
            raise ValueError(f"Bad array for {name} at idx {idx}: {x}") from exc
        return a

    for i, (_, r) in enumerate(iterator.iterrows()):
        x_price = torch.tensor(_arr_price(r['ohlcv_norm'], idx=i))
        x_kron = torch.tensor(_arr_flat(r['kronos_emb'], 'kronos_emb', idx=i))
        x_ctx = torch.tensor(_arr_flat(r['context'], 'context', idx=i))
        with torch.no_grad():
            sf_out = sf(x_price, x_kron, x_ctx)
            tft_out = tft(x_price, x_kron, x_ctx)
        tf_align, smc_vec, ta_vec = split_ctx(r['context'])
        veto_vec = build_veto_vec(
            sf_out={'prob': sf_out['up_prob'].numpy(), 'ret': sf_out['ret'].numpy()},
            tft_out={k: v.numpy() for k, v in tft_out.items()},
            smc_vec=smc_vec,
            tf_align=tf_align,
            ta_vec=ta_vec,
            raw_features={}
        )
        X.append(veto_vec.squeeze(0))
        y.append(int(np.array(r['y_up'])[1]))
    return np.stack(X, axis=0), np.array(y)


In [None]:
from pathlib import Path
import json

# Build veto features from the SF/TFT outputs, capped at 50k train / 10k validation rows.
X_train, y_train = infer_vectors(train_df, max_rows=50000)
X_val, y_val = infer_vectors(val_df, max_rows=10000)

lgbm = lgb.LGBMClassifier(
    n_estimators=400,
    learning_rate=0.05,
    num_leaves=63,
    # NOTE(review): per the LightGBM parameter docs, row subsampling only takes
    # effect when `subsample_freq` > 0; with the default this is likely a no-op.
    # Confirm whether bagging was intended before changing it.
    subsample=0.9,
    colsample_bytree=0.9,
    objective='binary',
    random_state=42,
)

lgbm.fit(
    X_train,
    y_train,
    eval_set=[(X_val, y_val)],
    eval_metric='logloss',
    # The `verbose` argument of fit() was removed in LightGBM 4.0; the
    # log_evaluation callback prints the eval metric every 50 iterations.
    callbacks=[lgb.log_evaluation(period=50)],
)

# Create the output directory BEFORE writing into it; saving the model first
# fails on a fresh checkout where artifacts/v1/veto does not exist yet.
Path('artifacts/v1/veto').mkdir(parents=True, exist_ok=True)

booster = lgbm.booster_
booster.save_model('artifacts/v1/veto/lightgbm.txt')

# Write the veto configuration next to the saved model.
with open('artifacts/v1/veto/config.json','w') as f:
    json.dump({'name':'lightgbm_veto_v1','threshold_block':0.65,'threshold_boost':0.35}, f, indent=2)
print('Saved LightGBM veto')


In [None]:
# Artifact summary for LightGBM veto
from pathlib import Path
import json, os

# Report which veto artifacts are present on disk and echo the saved config.
w_path = Path('artifacts/v1/veto/lightgbm.txt')
c_path = Path('artifacts/v1/veto/config.json')
print('Artifacts directory exists:', w_path.parent.exists())
for caption, artifact in (('LightGBM model exists:', w_path), ('Config exists:', c_path)):
    print(caption, artifact.exists(), artifact)
if c_path.exists():
    print(c_path.read_text())
