# 02 — Train StockFormer (direction + returns)

Loads the dataset from `training_data/v1/dataset.parquet`, trains StockFormer (context_dim=29, 512-dimensional Kronos embeddings) to predict returns and up/down direction, and saves the weights and config to `artifacts/v1/stockformer/`.

In [None]:
# Install the notebook's dependencies quietly (pyarrow is the engine pandas uses for read_parquet).
!pip -q install torch numpy pandas pyarrow scikit-learn tqdm

In [None]:
# Notebook bootstrap: make sure the working directory is the AI_TRADER checkout
# so that `backend/` can be put on sys.path.
import os, sys, json, pathlib
# Repository URL and checkout directory; both can be overridden via environment variables.
REPO_URL = os.getenv("REPO_URL", "https://github.com/your-org/AI_TRADER.git")
REPO_DIR = os.getenv("REPO_DIR", "AI_TRADER")
# `backend/` not visible from the CWD means we are not at the repo root:
# clone on first run (only if REPO_DIR is missing), then change into the checkout.
if not pathlib.Path("backend").exists():
    if not pathlib.Path(REPO_DIR).exists():
        !git clone $REPO_URL $REPO_DIR
    %cd $REPO_DIR
# Make `backend/` importable; presumably the `app` package imported in the next cell lives there.
sys.path.append(os.path.abspath("backend"))
# Output directory for weights.pt and config.json written at the end of training.
os.makedirs("artifacts/v1/stockformer", exist_ok=True)

In [None]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from app.ml.stockformer.model import StockFormer

# Training needs autograd; re-enable it in case an earlier cell in this kernel
# turned it off globally.
torch.set_grad_enabled(True)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the RNGs the run depends on (weight init, dropout, DataLoader shuffling)
# so repeated runs of this notebook are comparable. Bit-exact GPU results are
# not guaranteed by seeding alone.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the generators of all devices, CUDA included

In [None]:
DATASET_PATH = pathlib.Path("training_data/v1/dataset.parquet")
# Stable sort: rows sharing an `asof` keep their on-disk relative order, so the
# split below is reproducible from run to run.
df = (
    pd.read_parquet(DATASET_PATH)
    .sort_values("asof", kind="stable")
    .reset_index(drop=True)
)
print(df.shape, df.head())

HORIZONS = [3, 5, 10]
TRAIN_FRAC = 0.8
# Number of distinct `asof` values purged between the train and validation windows.
# NOTE(review): assumes each y_ret / y_up target looks at most max(HORIZONS)
# `asof` steps ahead, so purging that many dates stops training labels from
# overlapping the validation period — confirm against the dataset builder.
EMBARGO = max(HORIZONS)

# Split on whole dates, never inside one: rows with the same `asof` must all
# land on one side, otherwise validation shares same-day information with train.
# factorize(sort=True) maps each row to the ordinal of its date (NaN -> -1).
date_codes, unique_dates = pd.factorize(df["asof"], sort=True)
split = int(TRAIN_FRAC * len(unique_dates))  # index of the first validation date
train_end = split - EMBARGO                  # exclusive end of the training dates
if train_end <= 0:
    raise ValueError(
        f"Not enough distinct asof dates ({len(unique_dates)}) for a "
        f"{TRAIN_FRAC:.0%} train split with a {EMBARGO}-date embargo"
    )
train_df = df[(date_codes >= 0) & (date_codes < train_end)].reset_index(drop=True)
val_df = df[date_codes >= split].reset_index(drop=True)

In [None]:
class SwingDataset(Dataset):
    """Map-style dataset that yields one training sample per DataFrame row.

    Each item is a 5-tuple of float32 tensors, in this order:
    ``(ohlcv_norm, kronos_emb, context, y_ret, y_up)``. The first three are the
    model inputs and the last two are the return and direction targets.
    Rows are addressed positionally (``iloc``), so the frame's index is ignored.
    """

    def __init__(self, frame: pd.DataFrame):
        self.df = frame

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        def as_f32(column):
            # Materialise the cell as a float32 ndarray, then wrap it as a tensor.
            return torch.tensor(np.array(row[column], dtype=np.float32), dtype=torch.float32)

        inputs = tuple(as_f32(col) for col in ("ohlcv_norm", "kronos_emb", "context"))
        targets = tuple(as_f32(col) for col in ("y_ret", "y_up"))
        return inputs + targets

# Training: reshuffle every epoch and drop the ragged final batch.
# Validation: fixed order, every sample kept.
train_dl = DataLoader(
    dataset=SwingDataset(train_df),
    batch_size=64,
    shuffle=True,
    drop_last=True,
)
val_dl = DataLoader(
    dataset=SwingDataset(val_df),
    batch_size=64,
    shuffle=False,
)


In [None]:
# NOTE(review): the config.json written after training repeats these
# architecture values by hand — keep the two in sync.
model = StockFormer(
    lookback=120, price_dim=5, kronos_dim=512, context_dim=29,
    d_model=128, n_heads=4, n_layers=4, ffn_dim=256, dropout=0.1,
)
model.to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
huber = torch.nn.SmoothL1Loss()     # applied to the model's "ret" output
bce = torch.nn.BCEWithLogitsLoss()  # applied to the model's "up_logits" output
patience = 5             # epochs without validation improvement before stopping
best_val = float("inf")  # any finite first validation loss counts as an improvement
bad = 0                  # consecutive epochs without improvement

def evaluate():
    """Return the mean validation loss of the global ``model`` over ``val_dl``.

    The objective matches the training loop: 0.6 * Huber(ret) + 0.4 * BCE(up_logits).
    Both criteria are built with the default ``reduction='mean'``, so each batch
    loss is weighted by its batch size. The result is therefore a per-sample
    mean, and the smaller final batch is not overweighted. The model's
    train/eval mode is restored on exit, even if an exception occurs.

    Raises:
        RuntimeError: if ``val_dl`` yields no samples.
    """
    was_training = model.training
    model.eval()
    total, count = 0.0, 0
    try:
        with torch.no_grad():
            for x_price, x_kron, x_ctx, y_ret, y_up in val_dl:
                x_price = x_price.to(device)
                x_kron = x_kron.to(device)
                x_ctx = x_ctx.to(device)
                y_ret = y_ret.to(device)
                y_up = y_up.to(device)
                out = model(x_price, x_kron, x_ctx)
                loss = 0.6 * huber(out["ret"], y_ret) + 0.4 * bce(out["up_logits"], y_up)
                batch_size = x_price.shape[0]
                total += loss.item() * batch_size
                count += batch_size
    finally:
        model.train(was_training)
    if count == 0:
        # A mean over nothing is undefined; surfacing it beats a silent NaN that
        # would block every checkpoint.
        raise RuntimeError("validation loader produced no samples")
    return total / count

In [None]:
# Epoch loop with early stopping on validation loss. Relies on the globals set
# up earlier: model, opt, huber, bce, train_dl, evaluate(), patience, best_val, bad.
MAX_EPOCHS = 30
WEIGHTS_PATH = pathlib.Path("artifacts/v1/stockformer/weights.pt")
saved_best = False  # whether *this* run wrote a checkpoint (a stale file may exist)

model.train()  # start explicitly in train mode
for epoch in range(MAX_EPOCHS):
    for x_price, x_kron, x_ctx, y_ret, y_up in train_dl:
        x_price = x_price.to(device)
        x_kron = x_kron.to(device)
        x_ctx = x_ctx.to(device)
        y_ret = y_ret.to(device)
        y_up = y_up.to(device)

        out = model(x_price, x_kron, x_ctx)
        # Weighted multi-task objective: return regression (Huber) + direction (BCE).
        loss = 0.6 * huber(out["ret"], y_ret) + 0.4 * bce(out["up_logits"], y_up)
        opt.zero_grad()
        loss.backward()
        opt.step()

    val_loss = evaluate()
    print(f"epoch {epoch} val_loss {val_loss:.4f}")
    if val_loss < best_val:
        best_val = val_loss
        bad = 0
        torch.save(model.state_dict(), WEIGHTS_PATH)
        saved_best = True
    else:
        bad += 1
        if bad >= patience:
            print("Early stop")
            break

print("Best val:", best_val)

# The in-memory model can be up to `patience` epochs past the best checkpoint;
# reload the best weights so later cells use exactly what was saved.
if saved_best:
    model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
else:
    print("WARNING: validation loss never improved; no weights were saved in this run")

# Architecture description written next to weights.pt (presumably read back to
# rebuild the model at inference time).
# NOTE(review): the architecture values are duplicated by hand from the
# StockFormer(...) call that built `model` — keep the two in sync.
cfg = {
    "name": "stockformer_v1",
    "lookback": 120,
    "ohlcv_features": 5,
    "kronos_dim": 512,
    "context_dim": 29,
    # Taken from the notebook's HORIZONS so the recorded value cannot drift.
    "horizons": list(HORIZONS),
    "d_model": 128,
    "n_heads": 4,
    "n_layers": 4,
    "ffn_dim": 256,
    "dropout": 0.1,
}
config_path = pathlib.Path("artifacts/v1/stockformer/config.json")
config_path.write_text(json.dumps(cfg, indent=2), encoding="utf-8")
print("Saved weights + config to artifacts/v1/stockformer")