In [None]:
from __future__ import annotations

import json
from dataclasses import asdict
from datetime import datetime, timezone, timedelta
from pathlib import Path

import numpy as np

from trading.contexts.backtest.application.dto import RunBacktestTemplate
from trading.contexts.backtest.application.services.close_fill_scorer_v1 import (
    CloseFillBacktestStagedScorerV1,
)
from trading.contexts.backtest.application.services.staged_runner_v1 import BacktestStagedRunnerV1
from trading.contexts.backtest.adapters.outbound.config.backtest_runtime_config import (
    load_backtest_runtime_config,
)
from trading.contexts.indicators.application.dto import CandleArrays, ComputeRequest
from trading.contexts.indicators.domain.entities import AxisDef, IndicatorId, Layout
from trading.contexts.indicators.domain.specifications import ExplicitValuesSpec, GridSpec
from trading.contexts.indicators.domain.errors import GridValidationError, UnknownIndicatorError
from trading.shared_kernel.primitives import (
    InstrumentId,
    MarketId,
    Symbol,
    Timeframe,
    TimeRange,
    UtcTimestamp,
)

# Set True to let the golden-check cell write GOLDEN_PATH from the current run.
UPDATE_GOLDEN: bool = False
# Relative path: resolved against the kernel's working directory (presumably the repo root,
# like the "configs/dev/backtest.yaml" path used further down) — TODO confirm.
GOLDEN_PATH = Path("notebooks") / "_goldens" / "backtest_smoke_ma_sma.json"


In [None]:
def make_synth_candles(*, n: int = 3000, seed: int = 7) -> CandleArrays:
    """Build a deterministic synthetic 1-minute BTCUSDT candle series.

    Close is a cumulative random walk starting near 100: Gaussian noise plus a constant
    drift plus a small sine component. Open is the previous bar's close (the first bar
    opens at its own close). High/low widen the open/close envelope by a random spread.
    Bars are 60 s apart starting at 2024-01-01 00:00 UTC.

    Args:
        n: Number of bars to generate; must be >= 1.
        seed: Seed for ``numpy.random.default_rng`` so the series is reproducible.

    Returns:
        CandleArrays with int64 millisecond ``ts_open`` and float32 OHLCV arrays.

    Raises:
        ValueError: If ``n < 1`` (the first bar's open is seeded from its own close).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    rng = np.random.default_rng(seed)

    start = datetime(2024, 1, 1, tzinfo=timezone.utc)
    # Epoch milliseconds, one bar per minute.
    ts_open = (np.arange(n, dtype=np.int64) * 60_000) + int(start.timestamp() * 1000)

    # NOTE: the order of RNG draws (noise, spread, volume) is part of the golden output;
    # do not reorder them.
    noise = rng.normal(loc=0.0, scale=0.15, size=n).astype(np.float32)
    drift = np.full(n, 0.01, dtype=np.float32)
    periodic = (0.8 * np.sin(np.linspace(0, 30, n, dtype=np.float32))).astype(np.float32)

    close = (100.0 + np.cumsum(drift + noise + 0.02 * periodic)).astype(np.float32)
    open_ = np.roll(close, 1)
    open_[0] = close[0]

    spread = (0.05 + 0.02 * rng.random(n, dtype=np.float32)).astype(np.float32)
    high = np.maximum(open_, close) + spread
    low = np.minimum(open_, close) - spread
    volume = (1000.0 + 50.0 * rng.random(n, dtype=np.float32)).astype(np.float32)

    end = start + timedelta(minutes=int(n))
    time_range = TimeRange(UtcTimestamp(start), UtcTimestamp(end))

    return CandleArrays(
        market_id=MarketId(1),
        symbol=Symbol("BTCUSDT"),
        time_range=time_range,
        timeframe=Timeframe("1m"),
        ts_open=ts_open.astype(np.int64),
        open=open_.astype(np.float32),
        high=high.astype(np.float32),
        low=low.astype(np.float32),
        close=close.astype(np.float32),
        volume=volume.astype(np.float32),
    )

# The smoke test's input series: 3000 one-minute bars from a fixed seed (the helper's defaults),
# spelled out here so the scenario is visible at the call site.
candles = make_synth_candles(n=3000, seed=7)
candles


In [None]:
# Engine preference: True tries the real Numba compute engine first; False forces the
# NumPy SMA-only fallback defined below.
USE_NUMBA: bool = True

class _NumpySmaOnlyCompute:
    """Fallback IndicatorCompute for ma.sma only (sanity mode).

    Supports exactly ``indicator_id == "ma.sma"`` with ``source == "close"`` and a
    positive integer ``window`` grid. ``estimate`` and ``compute`` share the same
    validation, so an estimate never promises variants that ``compute`` would reject.

    Golden checks are intended to run with the real Numba engine.
    """

    @staticmethod
    def _materialize_sma_grid(grid: GridSpec) -> tuple[list[int], list[str]]:
        """Validate ``grid`` for this fallback and return its ``(windows, sources)``.

        Raises:
            UnknownIndicatorError: If the grid is not for ``ma.sma``.
            GridValidationError: If the source is not exactly ``close`` or any window <= 0.
        """
        if grid.indicator_id.value != "ma.sma":
            raise UnknownIndicatorError(grid.indicator_id)
        windows = [int(w) for w in grid.params["window"].materialize()]
        raw_sources = ("close",) if grid.source is None else grid.source.materialize()
        sources = [str(s) for s in raw_sources]
        if sources != ["close"]:
            raise GridValidationError("fallback only supports source=close")
        if any(w <= 0 for w in windows):
            raise GridValidationError("window must be > 0")
        return windows, sources

    def estimate(self, grid: GridSpec, *, max_variants_guard: int):
        """Return the variant count and axes for ``grid`` without computing values.

        Raises:
            UnknownIndicatorError: If the grid is not for ``ma.sma``.
            GridValidationError: If the grid is unsupported by this fallback or its
                variant count exceeds ``max_variants_guard``.
        """
        windows, sources = self._materialize_sma_grid(grid)
        variants = len(windows) * len(sources)
        if variants > max_variants_guard:
            raise GridValidationError("variants exceed guard")
        axes = (
            AxisDef(name="source", values_enum=tuple(sources)),
            AxisDef(name="window", values_int=tuple(windows)),
        )
        # Local import keeps this DTO lookup on the fallback path only.
        from trading.contexts.indicators.application.dto import EstimateResult

        return EstimateResult(
            indicator_id=grid.indicator_id,
            axes=axes,
            variants=variants,
            max_variants_guard=max_variants_guard,
        )

    def compute(self, req: ComputeRequest):
        """Compute simple moving averages of ``close`` for every window in the grid.

        Returns:
            IndicatorTensor in TIME_MAJOR layout where ``values[i, j]`` is the mean of
            ``close[i - w_j + 1 : i + 1]``. The first ``w_j - 1`` rows of column ``j`` are
            NaN (warmup), and the whole column is NaN when ``w_j`` exceeds the series length.

        Raises:
            UnknownIndicatorError / GridValidationError: As for ``estimate``.
        """
        grid = req.grid
        windows, _ = self._materialize_sma_grid(grid)

        close = req.candles.close.astype(np.float32, copy=False)
        t = int(close.shape[0])

        # float64 prefix sums with a leading 0, computed once for all windows:
        # sum(close[a:b]) == prefix[b] - prefix[a].
        prefix = np.concatenate(
            (np.zeros(1, dtype=np.float64), np.cumsum(close, dtype=np.float64))
        )

        # TIME_MAJOR: values[t, variants]; NaN warmup matches nan_policy="propagate".
        out = np.full((t, len(windows)), np.nan, dtype=np.float32)
        for j, w in enumerate(windows):
            if w > t:
                continue  # not enough bars: the whole column stays NaN
            out[w - 1 :, j] = ((prefix[w:] - prefix[:-w]) / float(w)).astype(np.float32)

        from trading.contexts.indicators.application.dto import IndicatorTensor, TensorMeta

        axes = (
            AxisDef(name="source", values_enum=("close",)),
            AxisDef(name="window", values_int=tuple(windows)),
        )
        meta = TensorMeta(t=t, variants=int(out.shape[1]), nan_policy="propagate")
        return IndicatorTensor(
            indicator_id=grid.indicator_id,
            layout=Layout.TIME_MAJOR,
            axes=axes,
            values=out,
            meta=meta,
        )

    def warmup(self) -> None:
        """No-op: the NumPy fallback has nothing to compile or cache."""
        return

# Select the indicator engine. Only a missing engine (ImportError, including the forced
# USE_NUMBA=False switch) selects the fallback. Config/construction errors from an
# installed Numba engine propagate instead of silently downgrading the smoke test.
try:
    if not USE_NUMBA:
        raise ImportError("forced fallback")
    from trading.contexts.indicators.adapters.outbound.compute_numba.engine import NumbaIndicatorCompute
    from trading.contexts.indicators.domain.definitions import all_defs
    from trading.platform.config.indicators_compute_numba import IndicatorsComputeNumbaConfig
except ImportError as e:
    print(f"Numba compute not available ({e!r}); using fallback SMA-only compute.")
    indicator_compute = _NumpySmaOnlyCompute()
    numba_cfg = None
    USING_NUMBA = False
else:
    numba_cfg = IndicatorsComputeNumbaConfig(
        numba_num_threads=1,  # single thread keeps runs reproducible and cheap
        numba_cache_dir=".cache/numba/notebooks",
        max_compute_bytes_total=512 * 1024**2,
        max_variants_per_compute=600_000,
    )
    indicator_compute = NumbaIndicatorCompute(defs=all_defs(), config=numba_cfg)
    USING_NUMBA = True

USING_NUMBA


In [None]:
# Sweep five SMA windows over the close price (the fallback compute only supports source=close).
indicator_grid = GridSpec(
    indicator_id=IndicatorId("ma.sma"),
    source=ExplicitValuesSpec(name="source", values=("close",)),
    params={
        "window": ExplicitValuesSpec(name="window", values=(5, 10, 20, 40, 80)),
    },
)

template = RunBacktestTemplate(
    instrument_id=InstrumentId(candles.market_id, candles.symbol),
    timeframe=candles.timeframe,
    indicator_grids=(indicator_grid,),
    indicator_selections=(),
    signal_grids=None,
    risk_grid=None,
    direction_mode="long-short",
    sizing_mode="all_in",
    risk_params=None,
    execution_params=None,
)

rt = load_backtest_runtime_config("configs/dev/backtest.yaml")

# Resource guards: read from the Numba config when that engine is active; otherwise use the
# same limits this notebook configures for it.
max_variants_per_compute = 600_000 if numba_cfg is None else numba_cfg.max_variants_per_compute
max_compute_bytes_total = 512 * 1024**2 if numba_cfg is None else numba_cfg.max_compute_bytes_total

TOP_K = 5  # variants the staged runner returns; the assertions below expect exactly this many

scorer = CloseFillBacktestStagedScorerV1(
    indicator_compute=indicator_compute,
    direction_mode=template.direction_mode,
    sizing_mode=template.sizing_mode,
    execution_params=template.execution_params or {},
    market_id=candles.market_id.value,
    # Score bars [200, end); presumably leaves room for the longest (80-bar) SMA warmup.
    target_slice=slice(200, candles.close.shape[0]),
    init_cash_quote_default=rt.execution.init_cash_quote_default,
    fixed_quote_default=rt.execution.fixed_quote_default,
    safe_profit_percent_default=rt.execution.safe_profit_percent_default,
    slippage_pct_default=rt.execution.slippage_pct_default,
    fee_pct_default_by_market_id=rt.execution.fee_pct_default_by_market_id,
    max_variants_guard=max_variants_per_compute,
)

runner = BacktestStagedRunnerV1(parallel_workers=1)


def _run_smoke_backtest():
    """Run the staged backtest once with this cell's fixed settings."""
    return runner.run(
        template=template,
        candles=candles,
        preselect=10,
        top_k=TOP_K,
        indicator_compute=indicator_compute,
        scorer=scorer,
        defaults_provider=None,
        max_variants_per_compute=max_variants_per_compute,
        max_compute_bytes_total=max_compute_bytes_total,
        requested_time_range=candles.time_range,
        top_trades_n=3,
    )


def _summary_rows(result):
    """Project each reported variant to (variant_key, indicator_variant_key, total_return_pct)."""
    return [
        (v.variant_key, v.indicator_variant_key, float(v.total_return_pct))
        for v in result.variants
    ]


# Two identical runs must agree exactly: the pipeline is expected to be deterministic.
res1 = _run_smoke_backtest()
res2 = _run_smoke_backtest()
rows1 = _summary_rows(res1)
rows2 = _summary_rows(res2)

assert rows1 == rows2, "non-deterministic results"
assert len(rows1) == TOP_K
assert len({r[0] for r in rows1}) == TOP_K, "variant keys are not unique"

rows1


In [None]:
# This run's summary in the golden-file format; "meta" records the scenario that produced it.
payload = {
    "meta": {
        "indicator_id": "ma.sma",
        "windows": [5, 10, 20, 40, 80],
        "bars": int(candles.close.shape[0]),
        "target_slice": [200, int(candles.close.shape[0])],
        "using_numba": USING_NUMBA,
    },
    "variants": [
        {
            "variant_key": vk,
            "indicator_variant_key": ik,
            "total_return_pct": tr,
        }
        for (vk, ik, tr) in rows1
    ],
}

if UPDATE_GOLDEN:
    # (Re)generate the golden from this run, overwriting any existing file.
    if not USING_NUMBA:
        raise RuntimeError("Golden generation is intended to run with Numba compute.")
    GOLDEN_PATH.parent.mkdir(parents=True, exist_ok=True)
    # allow_nan=False refuses to persist NaN/inf returns, which would make a meaningless golden.
    GOLDEN_PATH.write_text(
        json.dumps(payload, indent=2, sort_keys=True, allow_nan=False) + "\n",
        encoding="utf-8",
    )
    print(f"Wrote golden: {GOLDEN_PATH}")

if not USING_NUMBA:
    # Goldens are only ever written with the Numba engine, so comparing fallback output
    # against them would report spurious mismatches.
    print("Fallback SMA-only compute in use; skipping golden comparison (goldens are Numba-only).")
    golden_status = "SKIPPED (fallback compute)"
else:
    if not GOLDEN_PATH.exists():
        raise FileNotFoundError(
            f"Golden not found: {GOLDEN_PATH}. Set UPDATE_GOLDEN=True to generate."
        )

    golden = json.loads(GOLDEN_PATH.read_text(encoding="utf-8"))
    golden_rows = [
        (v["variant_key"], v["indicator_variant_key"], float(v["total_return_pct"]))
        for v in golden["variants"]
    ]

    # A golden recorded for a different scenario (windows, bar count, slice) is stale.
    assert golden.get("meta") == payload["meta"], (
        f"golden meta {golden.get('meta')!r} does not match this run {payload['meta']!r}; "
        "set UPDATE_GOLDEN=True to regenerate"
    )

    # Exact match is expected for keys; returns use a tight tolerance.
    assert [r[:2] for r in rows1] == [r[:2] for r in golden_rows], "variant keys differ from golden"
    for a, b in zip([r[2] for r in rows1], [r[2] for r in golden_rows], strict=True):
        if not np.isfinite(a) or not np.isfinite(b):
            raise AssertionError("non-finite total_return_pct")
        if abs(a - b) > 1e-6:
            raise AssertionError(f"total_return_pct mismatch: {a} vs {b}")
    golden_status = "OK"

golden_status
