# 01 — Build training dataset + Kronos embeddings (512d)

Runs on Colab T4. Pulls daily bars (AlphaVantage when `ALPHAVANTAGE_API_KEY` is set, otherwise yfinance), builds TF-align/SMC/TA vectors via in-repo preprocessors, encodes Kronos embeddings, and saves `training_data/v1/dataset.parquet`.

In [None]:
!pip -q install yfinance pandas numpy pyarrow duckdb torch huggingface_hub tqdm

In [None]:
import os, sys, json, pathlib

# Locate repo root (works from root or notebooks/)
REPO_URL = os.getenv("REPO_URL", "https://github.com/RishiKarthikeyan07/ai-trader-saas")
REPO_DIR_NAME = os.getenv("REPO_DIR", "AI_TRADER")

cwd = pathlib.Path().resolve()
repo_root = None

# Try to find repo root by looking for 'apps' or 'backend' directory
for p in [cwd, *cwd.parents]:
    if (p / "apps").exists() or (p / "backend").exists():
        repo_root = p
        break

# If not found, clone the repository
if repo_root is None:
    target = cwd / REPO_DIR_NAME
    if not target.exists():
        !git clone $REPO_URL $target.name
    repo_root = target
    os.chdir(repo_root)
else:
    os.chdir(repo_root)

print(f"Repo root: {pathlib.Path().resolve()}")

# Add apps/api to Python path so we can import 'app' module
api_path = repo_root / "apps" / "api"
if api_path.exists():
    sys.path.insert(0, str(api_path))
    print(f"✓ Added to Python path: {api_path}")
else:
    # Fallback to backend if apps/api doesn't exist
    backend_path = repo_root / "backend"
    if backend_path.exists():
        sys.path.insert(0, str(backend_path))
        print(f"✓ Added to Python path: {backend_path}")
    else:
        print(f"⚠ Warning: Neither {api_path} nor {backend_path} found")

# Create output directory
os.makedirs(repo_root / "training_data/v1", exist_ok=True)
print(f"✓ Output directory ready: {repo_root / 'training_data/v1'}")

In [None]:
# Show the head of sys.path so a wrong import root is obvious before the
# project imports below are attempted.
print("Python path (first 3 entries):")
for position, entry in enumerate(sys.path[:3]):
    print(f"  [{position}] {entry}")

import numpy as np
import pandas as pd
import yfinance as yf
import torch
import requests
from tqdm import tqdm

# Import from local app module
from app.ml.preprocess.normalize import (
    normalize_ohlcv_120,
    build_tf_align_vec,
    build_smc_vec,
    build_ta_vec,
)
from app.services.kronos_loader import load_kronos_hf
from app.services.feature_engine import compute_ta_features, compute_smc_features

# Inference only: switch autograd off globally for the rest of the session.
torch.set_grad_enabled(False)

# Use the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"\n✓ Using device: {device}")
print("✓ All imports successful!")

In [None]:
# Config: window length, label horizons, date range and ticker universe.
from pathlib import Path

LOOKBACK = 120
HORIZONS = [3, 5, 10]
START = os.getenv("DATA_START", "2015-01-01")  # 9y+ history
END = os.getenv("DATA_END", None)  # None == today
TICKER_FILE = Path(os.getenv("TICKER_FILE", repo_root / "config/nifty100_yfinance.txt"))

# TICKER_FILE is always a Path at this point; Path() of a Path is an equal Path.
ticker_path = Path(TICKER_FILE)
if not ticker_path.exists():
    raise FileNotFoundError(
        f"Ticker file {ticker_path} not found. Provide a yfinance symbol list (one per line, e.g., RELIANCE.NS)."
    )

# One symbol per line; surrounding whitespace is stripped and blank lines are dropped.
with ticker_path.open() as handle:
    TICKERS = [symbol for symbol in map(str.strip, handle) if symbol]

if not TICKERS:
    raise ValueError("No tickers loaded; check TICKER_FILE")

OUT_PATH = repo_root / "training_data/v1/dataset.parquet"
print(f"Using {len(TICKERS)} tickers; saving to {OUT_PATH}")

In [None]:
ALPHA_KEY = os.getenv('ALPHAVANTAGE_API_KEY')

# Column order every fetcher returns.
_OHLCV_COLUMNS = ['open', 'high', 'low', 'close', 'volume']


def _alphavantage_symbols(sym: str) -> list[str]:
    """Return the AlphaVantage spellings to try, in order, for a yfinance-style ticker.

    A trailing ``.NS`` / ``.BSE`` exchange suffix is removed, e.g.
    ``RELIANCE.NS`` -> ``['NSE:RELIANCE', 'BSE:RELIANCE', 'RELIANCE']``.
    """
    # NOTE(review): Yahoo spells Bombay listings with '.BO'; such tickers are passed
    # through unchanged here, exactly as before - confirm whether they should be mapped.
    base = sym.split('.')[0] if sym.endswith(('.NS', '.BSE')) else sym
    return [f'NSE:{base}', f'BSE:{base}', base]


def _fetch_alphavantage(av_sym: str, retries: int = 4, pause: float = 15.0) -> pd.DataFrame | None:
    """Download daily bars for one AlphaVantage symbol, clipped to START/END.

    Args:
        av_sym: Symbol in AlphaVantage spelling.
        retries: Maximum number of HTTP attempts.
        pause: Seconds to wait after an attempt that may succeed if repeated
            (throttle 'Note', exception, unexpected payload).

    Returns:
        A non-empty frame indexed by ``date`` with open/high/low/close/volume
        columns, or ``None`` when this symbol cannot supply data.
    """
    import time

    params = {
        'function': 'TIME_SERIES_DAILY_ADJUSTED',
        'symbol': av_sym,
        'outputsize': 'full',
        'apikey': ALPHA_KEY,
    }
    for attempt in range(retries):
        try:
            resp = requests.get('https://www.alphavantage.co/query', params=params, timeout=30)
            data = resp.json()
            if 'Error Message' in data:
                # The API rejected the call itself (e.g. unknown symbol); repeating the
                # identical request cannot succeed, so move on to the next spelling.
                print(f'[warn] AlphaVantage rejected {av_sym}: {data["Error Message"]}')
                return None
            if 'Note' not in data:  # a 'Note' payload is the throttle notice -> wait and retry
                series = data.get('Time Series (Daily)', {})
                if series:
                    records = [
                        {
                            'date': pd.to_datetime(date),
                            'open': float(vals['1. open']),
                            'high': float(vals['2. high']),
                            'low': float(vals['3. low']),
                            'close': float(vals['4. close']),
                            'volume': float(vals['6. volume']),
                        }
                        for date, vals in series.items()
                    ]
                    df = pd.DataFrame(records).sort_values('date').set_index('date')
                    in_range = df.index >= pd.to_datetime(START)
                    if END:
                        in_range &= df.index <= pd.to_datetime(END)
                    df = df.loc[in_range]
                    # Data arrived; if none of it is inside the window a retry would
                    # return the same rows, so report "no data" immediately.
                    return None if df.empty else df
        except Exception as exc:  # best-effort source: log and retry, Yahoo is the fallback
            print(f'[warn] AlphaVantage failed for {av_sym} attempt {attempt+1}: {exc}')
        time.sleep(pause)
    return None


def _fetch_yahoo(sym: str, retries: int = 6) -> pd.DataFrame:
    """Download raw daily bars from Yahoo Finance with linear back-off.

    Returns the frame exactly as yfinance produced it, or an empty DataFrame
    when every attempt failed or came back empty.
    """
    import time

    last_exc = None
    for attempt in range(retries):
        try:
            df = yf.download(
                sym,
                start=START,
                end=END,
                interval='1d',
                auto_adjust=False,
                progress=False,
                threads=False,
            )
        except Exception as exc:
            last_exc = exc
        else:
            if df is not None and not df.empty:
                return df
        if attempt < retries - 1:  # no point sleeping after the final attempt
            time.sleep(2.0 * (attempt + 1))
    if last_exc:
        print(f'[warn] {sym} failed after retries: {last_exc}')
    return pd.DataFrame()


def fetch_daily(sym: str) -> pd.DataFrame:
    """Fetch daily OHLCV bars for *sym* between START and END.

    AlphaVantage is tried first when ALPHA_KEY is set (NSE tickers such as
    ``RELIANCE.NS`` are mapped to ``NSE:RELIANCE`` and friends); Yahoo Finance
    is the fallback.

    Args:
        sym: yfinance-style ticker, e.g. ``RELIANCE.NS``.

    Returns:
        DataFrame indexed by a ``date`` DatetimeIndex with lowercase
        open/high/low/close/volume columns; empty when no source returned data.
    """
    # Prefer AlphaVantage if key present
    if ALPHA_KEY:
        for av_sym in _alphavantage_symbols(sym):
            df = _fetch_alphavantage(av_sym)
            if df is not None:
                return df

    # Fallback to Yahoo Finance
    df = _fetch_yahoo(sym)
    if df.empty:
        return df

    # yfinance can return two-level (field, ticker) columns even for a single
    # symbol; keep only the field level so the lowercase selection below works.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.get_level_values(0)

    df = df.rename(columns=str.lower)[_OHLCV_COLUMNS].dropna()
    if df.empty:
        return df
    # Normalise the index directly instead of relying on the name of the column
    # that reset_index() would have produced.
    df.index = pd.to_datetime(df.index)
    df.index.name = 'date'
    return df


def add_labels(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with forward-return labels for every horizon in HORIZONS.

    For each horizon ``h`` two columns are added:
      * ``ret_h`` - close ``h`` rows ahead divided by the current close, minus 1.
      * ``up_h``  - 1 where ``ret_h`` is strictly positive, else 0 (int32).

    The last ``h`` rows have no future close, so their ``ret_h`` is NaN and
    their ``up_h`` is 0 (NaN > 0 is False).
    """
    labelled = df.copy()
    close = labelled['close']
    for horizon in HORIZONS:
        forward_return = close.shift(-horizon) / close - 1.0
        labelled[f'ret_{horizon}'] = forward_return
        labelled[f'up_{horizon}'] = (forward_return > 0).astype(np.int32)
    return labelled

In [None]:
# Feature extraction helpers aligned with backend
def _prep_window(window: pd.DataFrame) -> pd.DataFrame | None:
    if window.empty:
        return None
    df = window.copy()
    df.columns = [str(c).lower() for c in df.columns]
    if not isinstance(df.index, pd.DatetimeIndex):
        df.index = pd.to_datetime(df.index)
    df = df.sort_index()
    required = ['open', 'high', 'low', 'close', 'volume']
    missing = [c for c in required if c not in df.columns]
    if missing:
        return None
    return df


def compute_alignment(window: pd.DataFrame) -> dict:
    """Compute an EMA-based trend sign for several timeframes of *window*.

    The window is resampled to weekly (W-FRI), monthly (ME) and 4-hour bars;
    the cleaned window itself serves as both the daily and the "1h" frame.
    Each frame is run through compute_ta_features and scored from its last row:
    +1.0 when ``ema_fast`` is greater than ``ema_slow``, otherwise -1.0, or 0.0
    when the enriched frame is empty.

    Returns:
        Dict with keys monthly_bias, weekly_bias, daily_bias, h4_align and
        h1_align; all 0.0 when the window is unusable.
    """
    bias_keys = ('monthly_bias', 'weekly_bias', 'daily_bias', 'h4_align', 'h1_align')

    base = _prep_window(window)
    if base is None:
        return dict.fromkeys(bias_keys, 0.0)

    ohlcv_agg = {'open': 'first', 'high': 'max', 'low': 'min', 'close': 'last', 'volume': 'sum'}
    core = base[['open', 'high', 'low', 'close', 'volume']]

    def _resampled(rule: str) -> pd.DataFrame:
        # Bars with no source rows come out with NaN prices and are dropped.
        return core.resample(rule).agg(ohlcv_agg).dropna()

    def _trend_sign(frame: pd.DataFrame) -> float:
        enriched = compute_ta_features(frame)
        if enriched.empty:
            return 0.0
        last_row = enriched.iloc[-1]
        # NOTE(review): missing or NaN EMAs make this comparison False, so the
        # result is -1.0 rather than neutral - confirm that matches the backend.
        fast_above_slow = last_row.get('ema_fast', 0) > last_row.get('ema_slow', 0)
        return 1.0 if fast_above_slow else -1.0

    weekly = _resampled('W-FRI')
    monthly = _resampled('ME')
    four_hour = _resampled('4h')
    hourly = core.copy()  # already 1H if provided; with daily data it's sparse but harmless

    frames = (monthly, weekly, core, four_hour, hourly)
    return {key: _trend_sign(frame) for key, frame in zip(bias_keys, frames)}


def compute_feature_dict(window: pd.DataFrame) -> dict:
    """Return the latest-bar feature dict for *window*.

    The cleaned window is passed through compute_ta_features and then
    compute_smc_features; the last enriched row becomes the dict, and its
    open/high/low/close/volume entries are overwritten with the raw values of
    the window's last bar as floats.

    Returns:
        The feature dict, or ``{}`` when the window is unusable or the
        enriched frame is empty.
    """
    base = _prep_window(window)
    if base is None:
        return {}

    enriched = compute_smc_features(compute_ta_features(base))
    if enriched.empty:
        return {}

    features = enriched.iloc[-1].to_dict()
    last_bar = base.iloc[-1]
    for field in ('open', 'high', 'low', 'close', 'volume'):
        features[field] = float(last_bar[field])
    return features

In [None]:
# Kronos 512d encoder
# Load the model once per session through the project loader; `kronos_embed`
# below only calls its `tokenizer.embed` method.
# NOTE(review): max_context=512 is presumably a sequence-length limit rather than
# the 512-d output width enforced in kronos_embed - confirm in kronos_loader.
kronos = load_kronos_hf(device=device, max_context=512)

def kronos_embed(batch_norm: np.ndarray) -> np.ndarray:
    """Encode a batch of normalised windows into fixed-width float32 embeddings.

    Args:
        batch_norm: Array converted to a float32 tensor on ``device``; the
            caller passes (B, 120, 5). When the last axis has 5 channels, a
            sixth all-zero channel is appended before encoding.

    Returns:
        float32 array with 512 columns: the tokenizer embedding averaged over
        axis 1, zero-padded on the right when narrower than 512 and truncated
        when wider.
    """
    target_dim = 512

    feats = torch.tensor(batch_norm, dtype=torch.float32, device=device)
    if feats.shape[-1] == 5:  # pad amount channel if tokenizer expects 6
        zero_channel = feats.new_zeros((feats.shape[0], feats.shape[1], 1))
        feats = torch.cat([feats, zero_channel], dim=-1)

    encoded = kronos.tokenizer.embed(feats)
    if isinstance(encoded, tuple):
        encoded = encoded[0]

    pooled = encoded.mean(dim=1).detach().cpu().numpy().astype(np.float32)
    width = pooled.shape[1]
    if width > target_dim:
        return pooled[:, :target_dim]
    if width < target_dim:
        filler = np.zeros((pooled.shape[0], target_dim - width), dtype=np.float32)
        return np.concatenate([pooled, filler], axis=1)
    return pooled

In [None]:
import time

BATCH_SIZE = 64  # windows per Kronos forward pass
OHLCV_COLS = ["open", "high", "low", "close", "volume"]
RET_COLS = [f"ret_{h}" for h in HORIZONS]
UP_COLS = [f"up_{h}" for h in HORIZONS]


def _flush_batch(batch_meta: list, batch_ohlcv: list, sink: list) -> None:
    """Embed the pending windows with Kronos and append one record per window to *sink*."""
    if not batch_ohlcv:
        return
    kron = kronos_embed(np.stack(batch_ohlcv, axis=0))
    for (m_sym, m_date, m_ctx, m_ret, m_up), m_emb, m_ohlcv in zip(batch_meta, kron, batch_ohlcv):
        sink.append(
            {"symbol": m_sym, "asof": m_date, "ohlcv_norm": m_ohlcv, "kronos_emb": m_emb, "context": m_ctx, "y_ret": m_ret, "y_up": m_up}
        )


rows = []

for sym in tqdm(TICKERS):
    # Throttle API calls for Yahoo Finance
    time.sleep(2.0)

    df = fetch_daily(sym)
    if df.empty or len(df) < LOOKBACK + max(HORIZONS) + 10:
        continue
    df = add_labels(df)

    # Features are computed from price/volume columns only, so the forward-looking
    # ret_*/up_* label columns can never reach the feature extractors.
    bars = df[OHLCV_COLS]
    # Pull the labels out once per ticker instead of building a row Series
    # for every sample and horizon.
    ret_matrix = df[RET_COLS].to_numpy(dtype=np.float32)
    up_matrix = df[UP_COLS].to_numpy(dtype=np.float32)

    batch_ohlcv = []
    batch_meta = []

    # Stop max(HORIZONS) rows before the end: later rows have no complete label.
    for i in range(LOOKBACK - 1, len(df) - max(HORIZONS)):
        y_ret = ret_matrix[i].copy()
        if np.any(np.isnan(y_ret)):
            continue  # skip before paying for feature extraction
        y_up = up_matrix[i].copy()

        window = bars.iloc[i - LOOKBACK + 1 : i + 1]
        ohlcv = window.values.astype(np.float32)
        if ohlcv.shape[0] != LOOKBACK:
            continue
        norm = normalize_ohlcv_120(ohlcv)
        alignment = build_tf_align_vec(compute_alignment(window))
        feat_dict = compute_feature_dict(window)
        smc_vec = build_smc_vec(feat_dict)
        ta_vec = build_ta_vec(feat_dict)
        context = np.concatenate([alignment, smc_vec, ta_vec]).astype(np.float32)

        batch_ohlcv.append(norm)
        batch_meta.append((sym, df.index[i], context, y_ret, y_up))

        # Process in batches for GPU efficiency
        if len(batch_ohlcv) >= BATCH_SIZE:
            _flush_batch(batch_meta, batch_ohlcv, rows)
            batch_ohlcv, batch_meta = [], []

    # Process remaining samples
    _flush_batch(batch_meta, batch_ohlcv, rows)

print(f"\nTotal samples: {len(rows)}")
df_out = pd.DataFrame(rows)
if rows:
    # NOTE(review): if normalize_ohlcv_120 returns 2-D arrays, the parquet engine may
    # refuse the object column (1-D values expected) - confirm on a small run.
    df_out.to_parquet(OUT_PATH, index=False)
    print(f"Saved to {OUT_PATH}")
else:
    # An empty frame has no columns; writing it would only clobber a previous dataset.
    print(f"[warn] No samples generated; nothing written to {OUT_PATH}")

# Display summary statistics (every line degrades gracefully when there are no rows)
first_row = rows[0] if rows else None
n_symbols = df_out["symbol"].nunique() if rows else 0
date_range = f"{df_out['asof'].min()} to {df_out['asof'].max()}" if rows else "N/A"
print("\n=== Dataset Summary ===")
print(f"Total samples: {len(rows)}")
print(f"Unique symbols: {n_symbols}")
print(f"Date range: {date_range}")
print(f"OHLCV shape: {first_row['ohlcv_norm'].shape if rows else 'N/A'}")
print(f"Kronos embedding shape: {first_row['kronos_emb'].shape if rows else 'N/A'}")
print(f"Context vector shape: {first_row['context'].shape if rows else 'N/A'}")
print(f"Target horizons: {HORIZONS}")