# 03 — Модель и бэктест (валидация + диагностика)

Порог, полярность и фильтр режима (regime) подбираются только на ВАЛИДАЦИИ; тест — финальная честная оценка.
Диагностика trade_count и экспозиции помогает понять, почему сделок мало.

Полярность (pred vs -pred) рассматриваем как гиперпараметр,
поэтому выбираем её на валидации. Порог также фиксируется по валидации.
Это защищает от утечки информации из теста.

In [None]:
from pathlib import Path
import json
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def find_project_root(start: Path) -> Path:
    """Return the nearest directory at or above *start* that contains ``src``.

    The notebook may be started from the repository root, from a notebooks
    folder, or from something nested deeper; checking every ancestor covers all
    of these (the cwd and its parent are checked first, as before). When no
    ancestor has ``src``, *start* itself is returned so the ``from src...``
    imports below fail with the usual ModuleNotFoundError.
    """
    for candidate in (start, *start.parents):
        if (candidate / "src").exists():
            return candidate
    return start


PROJECT_ROOT = Path.cwd().resolve()
ROOT = find_project_root(PROJECT_ROOT)

# Make the project's ``src`` package importable regardless of the kernel's cwd.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.features import get_feature_columns, drop_na_for_training
from src.model import TrainConfig, StandardScaler, train_mlp_model, predict, compute_regression_metrics
from src.backtest import backtest_long_short, backtest_long_short_horizon

# Inputs and run artifacts both live under <ROOT>/data.
DATA_PATH = ROOT / "data" / "eurusd_features.parquet"
ARTIFACT_DIR = ROOT / "data" / "artifacts"
ARTIFACT_DIR.mkdir(exist_ok=True, parents=True)

# Load features, drop rows unusable for training, then order by time: the
# train/val/test split further down is purely positional.
df = (
    drop_na_for_training(pd.read_parquet(DATA_PATH))
    .sort_values("time")
    .reset_index(drop=True)
)


ModuleNotFoundError: No module named 'src'

In [None]:
feature_cols = get_feature_columns()
X = df[feature_cols].values
y = df["target"].values

# Explicit checks instead of ``assert`` so they also run under ``python -O``.
if not np.isfinite(X).all():
    raise ValueError("NaNs or infs in features")
# The target feeds the training loss and the regression metrics; a single
# non-finite value would silently poison both.
if not np.isfinite(y).all():
    raise ValueError("NaNs or infs in target")
if not df["time"].is_monotonic_increasing:
    raise ValueError("Time must be sorted")


In [None]:
# Chronological 70% / 15% / 15% split - no shuffling, this is a time series.
n = len(df)
train_end, val_end = int(n * 0.70), int(n * 0.85)

# Model inputs: rows [0, train_end), [train_end, val_end), [val_end, n).
X_train, X_val, X_test = X[:train_end], X[train_end:val_end], X[val_end:]
y_train, y_val, y_test = y[:train_end], y[train_end:val_end], y[val_end:]

# Matching frames for the backtests, re-indexed from 0.
df_val, df_test = (
    df.iloc[train_end:val_end].reset_index(drop=True),
    df.iloc[val_end:].reset_index(drop=True),
)


In [None]:
# ``fit_transform`` sees only the training slice; validation and test are
# merely ``transform``-ed (presumably with the training statistics - the
# scaler is defined in src.model).
scaler = StandardScaler()
X_train_s = scaler.fit_transform(X_train)
X_val_s, X_test_s = scaler.transform(X_val), scaler.transform(X_test)


In [None]:
# MLP hyperparameters. The validation split is handed to the trainer too
# (presumably for early stopping, given ``patience`` - see src.model).
cfg = TrainConfig(
    epochs=200,
    batch_size=1024,
    lr=1e-3,
    weight_decay=1e-4,
    patience=5,
)
model, history = train_mlp_model(X_train_s, y_train, X_val_s, y_val, cfg)

# Raw model outputs on the two held-out splits.
pred_val, pred_test = predict(model, X_val_s), predict(model, X_test_s)


In [None]:
def _val_sharpe(bt) -> float:
    """Return a backtest's Sharpe ratio, mapping NaN (undefined) to -inf."""
    value = float(bt.metrics["sharpe"])
    return -np.inf if np.isnan(value) else value


# Polarity (pred vs -pred) is a hyperparameter, so it is chosen on VALIDATION.
# Both candidates use the holding period and cost of the later validation grid
# and test backtest, so the choice is made under the same trading assumptions.
bt_pos = backtest_long_short_horizon(
    df_val.assign(pred=pred_val), threshold=0.0, hold_bars=3, cost_bps=0.5
)
bt_neg = backtest_long_short_horizon(
    df_val.assign(pred=-pred_val), threshold=0.0, hold_bars=3, cost_bps=0.5
)

# A NaN Sharpe compares False with anything, which would silently select -1
# even when the positive side is well defined; scoring NaN as -inf avoids
# that and keeps the un-flipped model when neither side is defined.
polarity = 1 if _val_sharpe(bt_pos) >= _val_sharpe(bt_neg) else -1
print("VAL Sharpe pos:", bt_pos.metrics["sharpe"], "neg:", bt_neg.metrics["sharpe"])
print("Chosen polarity:", polarity)

pred_val_final = polarity * pred_val
pred_test_final = polarity * pred_test


In [None]:
# Candidate thresholds: quantiles of |pred| on VALIDATION, deduplicated and
# ascending. Every threshold is crossed with every regime filter.
abs_pred = np.abs(pred_val_final)
quantiles = [0.50, 0.60, 0.70, 0.80, 0.90, 0.95]
thresholds = sorted({float(np.quantile(abs_pred, q)) for q in quantiles})

regimes = [None, "adx", "h1_align", "adx_and_h1"]


def _evaluate_on_val(th, reg):
    """Backtest one (threshold, regime) pair on VALIDATION; return a summary row."""
    result = backtest_long_short_horizon(
        df_val.assign(pred=pred_val_final),
        threshold=th,
        hold_bars=3,
        cost_bps=0.5,
        regime=reg,
    )
    metrics, debug = result.metrics, result.debug
    return {
        "threshold": th,
        "regime": reg,
        "sharpe": metrics["sharpe"],
        "total_return": metrics["total_return"],
        "max_drawdown": metrics["max_drawdown"],
        "trade_count": metrics["trade_count"],
        "signal_counts": debug["signal_counts"],
    }


# Thresholds vary slowest, regimes fastest - one row per combination.
val_rows = [_evaluate_on_val(th, reg) for th in thresholds for reg in regimes]
val_table = pd.DataFrame(val_rows)
val_table


In [None]:
# Keep only configurations with at least 20 validation trades. Rank them by
# Sharpe, then max drawdown, then total return, all descending. Drawdown
# appears to be reported as a negative fraction, so descending puts the
# shallowest drawdown first.
valid = val_table[val_table["trade_count"] >= 20]
if len(valid) > 0:
    valid = valid.sort_values(
        by=["sharpe", "max_drawdown", "total_return"],
        ascending=False,
    )
    best_row = valid.iloc[0]
else:
    print("WARNING: No valid thresholds with trade_count >= 20; using first row")
    best_row = val_table.iloc[0]

best_threshold, best_regime = float(best_row["threshold"]), best_row["regime"]
print("Selected (VAL):", best_threshold, best_regime)


In [None]:
# Regression diagnostics on the polarity-adjusted predictions.
metrics_val = compute_regression_metrics(pred_val_final, y_val)
metrics_test = compute_regression_metrics(pred_test_final, y_test)

# Single TEST backtest with polarity, threshold and regime all frozen from
# VALIDATION.
bt_test = backtest_long_short_horizon(
    df_test.assign(pred=pred_test_final),
    hold_bars=3,
    cost_bps=0.5,
    threshold=best_threshold,
    regime=best_regime,
)

for label, value in (
    ("VAL metrics:", metrics_val),
    ("TEST metrics:", metrics_test),
    ("TEST backtest metrics:", bt_test.metrics),
    ("TEST debug:", bt_test.debug),
):
    print(label, value)

bt_test.trades.head(10)


In [None]:
# Reference run with a constant zero prediction. It uses the SAME backtest
# engine and settings as the model's TEST run (holding period, cost, regime),
# so the two sets of metrics can be compared directly.
baseline_bt = backtest_long_short_horizon(
    df_test.assign(pred=0.0),
    threshold=0.0,
    hold_bars=3,
    cost_bps=0.5,
    regime=best_regime,
)
print("Baseline pred=0 metrics:", baseline_bt.metrics)


Печатаем критерии приёмки (PASS/FAIL) и возможные причины провала.
Если trade_count низкий, вероятно порог слишком высокий или предсказания почти нулевые.

In [None]:
# Minimum TEST trade count; the same floor is used when selecting the
# configuration on VALIDATION.
MIN_TEST_TRADES = 20

test_bt_metrics = bt_test.metrics
# ``debug`` entries are read defensively: the backtest may not report them all.
test_signal_counts = bt_test.debug.get("signal_counts", {})
test_pred_abs_p95 = bt_test.debug.get("pred_abs_p95")

criteria = {
    "sharpe_ok": test_bt_metrics["sharpe"] >= 0.30,
    "maxdd_ok": test_bt_metrics["max_drawdown"] >= -0.12,
    "trades_ok": test_bt_metrics["trade_count"] >= MIN_TEST_TRADES,
    "diracc_ok": metrics_test["dir_acc"] >= 0.51,
}
print("PASS/FAIL:", criteria)

if not all(criteria.values()):
    causes = []
    if test_bt_metrics["trade_count"] < MIN_TEST_TRADES:
        causes.append("1) Too few trades: threshold too high or predictions near zero.")
    # Compare only when p95 was actually reported. Treating a missing value
    # as 0 would always look "below the threshold" and print a false cause.
    if test_pred_abs_p95 is not None and test_pred_abs_p95 < best_threshold:
        causes.append("2) Threshold above prediction scale; decrease quantiles.")
    if test_signal_counts.get("long", 0) == 0 or test_signal_counts.get("short", 0) == 0:
        causes.append("3) Model predicts mostly one sign; consider polarity or regularization.")

    print("Likely causes:")
    # Never leave the header dangling with nothing under it.
    for cause in causes or ["No specific diagnostic matched; the signal may lack edge on TEST."]:
        print(cause)

def _json_default(obj):
    """Convert NumPy values that ``json`` cannot encode into builtin types.

    Metrics and comparison results here may be NumPy scalars (e.g. ``int64``
    trade counts or ``bool_`` criteria). Those are not JSON serializable
    natively.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


results_payload = {
    "polarity": polarity,
    "threshold": best_threshold,
    "regime": best_regime,
    "val_metrics": metrics_val,
    "test_metrics": metrics_test,
    "test_backtest_metrics": bt_test.metrics,
    "criteria": criteria,
}
# Serialize fully before touching the file, so a serialization error cannot
# leave a truncated results.json behind.
results_text = json.dumps(results_payload, ensure_ascii=False, indent=2, default=_json_default)
(ARTIFACT_DIR / "results.json").write_text(results_text, encoding="utf-8")
