# 2026 Forward Test — Out-of-Sample Evaluation

Uses the **last trained models** (L1+L2 window 31, RL agent window 31) from the
2018-2025 backtest to trade on unseen 2026 data (Jan–Feb 2026).

**No retraining** — this is a pure out-of-sample test.

**Prerequisites:**
- Completed 2018-2025 backtest (`run_id=20260228_144143`) with models on Drive
- 2026 OHLCV data on Drive (if it is missing or does not reach 2026, Cell 3 fetches it via baostock)

**Runtime:** T4 GPU recommended — Cell 4 sets XGBoost `device='cuda'`, so select a GPU runtime (XGBoost inference is also faster on GPU).

In [None]:
# Cell 1: Setup — mount Google Drive and make sure the project folder exists
import os

from google.colab import drive

# The Drive mount has to happen before anything under /content/drive is touched.
drive.mount('/content/drive')

# Project root on Drive; later cells build their data/output paths from it.
DRIVE_ROOT = '/content/drive/MyDrive/kronos'
os.makedirs(DRIVE_ROOT, exist_ok=True)
print(f'Drive root: {DRIVE_ROOT}')

In [None]:
# Cell 2: Clone/update repo + install deps
import sys

REPO_DIR = '/content/tradingagent'

if not os.path.exists(REPO_DIR):
    !git clone https://github.com/Yuxiaoliu12/tradingagent.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull origin master

!pip install -q -r {REPO_DIR}/requirements.txt

sys.path.insert(0, REPO_DIR)
print(f'Repo at: {REPO_DIR}')

In [None]:
# Cell 3: Ensure 2026 OHLCV data exists on Drive
# If the existing pickle doesn't cover 2026, fetch fresh data via baostock.

import pickle
import pandas as pd

OHLCV_PATH = os.path.join(DRIVE_ROOT, 'data/ohlcv_all_a.pkl')
BENCH_PATH = os.path.join(DRIVE_ROOT, 'data/benchmark_000905.pkl')

# Check if 2026 data already exists
need_update = True
if os.path.exists(OHLCV_PATH):
    with open(OHLCV_PATH, 'rb') as f:
        ohlcv_check = pickle.load(f)
    sample_sym = next(iter(ohlcv_check))
    latest = ohlcv_check[sample_sym].index.max()
    print(f'Existing OHLCV: {len(ohlcv_check)} symbols, latest date: {latest.date()}')
    if latest >= pd.Timestamp('2026-01-15'):
        need_update = False
        print('2026 data already present — skipping download.')
    else:
        print('No 2026 data found — will fetch via baostock.')
    del ohlcv_check
else:
    print(f'No OHLCV file at {OHLCV_PATH} — will fetch via baostock.')

if need_update:
    !pip install -q baostock
    import baostock as bs
    lg = bs.login()
    print(f'baostock login: {lg.error_msg}')

    sys.path.insert(0, REPO_DIR)
    from screener.download_ohlcv import (
        get_all_a_shares, download_universe_ohlcv,
        download_benchmark, download_industry_mapping,
    )

    symbols = get_all_a_shares()
    # Filter out index codes (sz.399xxx bug)
    symbols = [s for s in symbols if not s.startswith('sz.399')]
    print(f'After filtering indices: {len(symbols)} stocks')

    download_universe_ohlcv(
        symbols, start_date='2015-01-01', end_date='2026-03-01',
        save_path=OHLCV_PATH,
    )
    download_benchmark(
        start_date='2015-01-01', end_date='2026-03-01',
        save_path=BENCH_PATH,
    )
    download_industry_mapping(
        save_path=os.path.join(DRIVE_ROOT, 'data/industry_mapping.pkl'),
    )
    bs.logout()
    print('Data download complete.')

In [None]:
# Cell 4: Config + load data
from screener.config import ScreenerConfig

# Collect the constructor arguments first so the paths are easy to audit.
config_kwargs = {
    'ohlcv_pickle_path': os.path.join(DRIVE_ROOT, 'data/ohlcv_all_a.pkl'),
    'benchmark_pickle_path': os.path.join(DRIVE_ROOT, 'data/benchmark_000905.pkl'),
    'industry_pickle_path': os.path.join(DRIVE_ROOT, 'data/industry_mapping.pkl'),
    'drive_root': os.path.join(DRIVE_ROOT, 'output/screener'),
    # Pinned run_id so the already-trained models can be located.
    'run_id': '20260228_144143',
    # Extended backtest_end so load_data() includes 2026 in calendar + OHLCV.
    'backtest_end': '2026-03-31',
}
cfg = ScreenerConfig(**config_kwargs)

# Run both XGBoost layers on the GPU (Colab T4).
for xgb_params in (cfg.layer1_xgb_params, cfg.layer2_xgb_params):
    xgb_params['device'] = 'cuda'

# Echo the resolved configuration for a quick sanity check.
for label, value in (
    ('Run ID', cfg.run_id),
    ('Run dir', cfg.run_dir),
    ('Cache dir', cfg.cache_dir),
    ('Backtest end', cfg.backtest_end),
):
    print(f'{label}: {value}')

# Load data (OHLCV + Alpha158 + regime + calendar)
from screener.backtester import WalkForwardBacktester

bt = WalkForwardBacktester(cfg)
bt.load_data()

# Verify that the loaded calendar actually reaches into 2026.
in_2026 = bt._calendar >= '2026-01-01'
cal_2026 = bt._calendar[in_2026]
print(f'\n2026 trading days in calendar: {len(cal_2026)}')
if len(cal_2026):
    first_day, last_day = cal_2026[0], cal_2026[-1]
    print(f'  Range: {first_day.date()} → {last_day.date()}')

In [None]:
# Cell 5: Load last trained L1+L2 models + generate 2026 signals
import time

# The 2018-2025 backtest produced 32 quarterly windows (0-31).
# Window 31 = Q4 2025 — the most recent trained models.
LAST_WINDOW = 31

# Locations of the artefacts written by the backtest run.
l1l2_path = os.path.join(cfg.cache_dir, 'l1l2_models', f'window_{LAST_WINDOW}.pkl')
rl_model_path = os.path.join(cfg.run_dir, f'rl_model_window_{LAST_WINDOW}')

# Report both files before failing, so one run shows everything that is missing.
have_l1l2 = os.path.exists(l1l2_path)
have_rl = os.path.exists(rl_model_path + '.zip')
print(f'L1+L2 model path: {l1l2_path}')
print(f'  Exists: {have_l1l2}')
print(f'RL model path: {rl_model_path}.zip')
print(f'  Exists: {have_rl}')

if not have_l1l2:
    raise FileNotFoundError(
        f'No L1+L2 models at {l1l2_path}. '
        f'Run the full 2018-2025 backtest first (train_colab.ipynb).'
    )
if not have_rl:
    raise FileNotFoundError(
        f'No RL model at {rl_model_path}. '
        f'Run the full 2018-2025 backtest first (train_colab.ipynb).'
    )

# Restore the window's L1+L2 models inside the backtester.
bt._load_window_models(LAST_WINDOW)

# Ensure lagged IC covers 2026 (years 2025 and 2026 are requested).
bt.layer1.ensure_lagged_ic(range(2025, 2027))

# Forward-test window for signal generation.
TEST_START = '2026-01-01'
TEST_END = '2026-02-28'  # adjust as data becomes available

print(f'\nGenerating signals for {TEST_START} → {TEST_END}...')
t0 = time.time()
signals_2026 = bt._generate_daily_signals(TEST_START, TEST_END, verbose=True)
elapsed = time.time() - t0

print(f'\nGenerated {len(signals_2026)} daily signals in {elapsed:.0f}s')

# Quick signal quality check: how many days produced any L2 candidates.
candidate_counts = [len(day['l2_ranking']) for day in signals_2026]
non_empty = sum(1 for count in candidate_counts if count > 0)
print(f'Days with L2 candidates: {non_empty}/{len(signals_2026)}')
if signals_2026:
    avg_candidates = sum(candidate_counts) / len(signals_2026)
    print(f'Avg candidates per day: {avg_candidates:.0f}')

In [None]:
# Cell 6: RL inference on 2026
#
# Replays the 2026 signals through PortfolioEnv with the last trained RL
# model, then derives NAV-based metrics (total return, Sharpe, max drawdown).
# Outputs used by later cells: benchmark_df, nav_series, trade_log,
# daily_rets, drawdown, total_return, sharpe, max_dd.
import numpy as np
from screener.portfolio_env import PortfolioEnv
from screener.rl_trader import RLTrader
from screener.data_pipeline import _get_benchmark_cache
from sb3_contrib.common.wrappers import ActionMasker
from sb3_contrib.common.maskable.utils import get_action_masks

# Fail early with an actionable message instead of an opaque error later on.
if len(signals_2026) == 0:
    raise ValueError(
        'signals_2026 is empty — nothing to trade. '
        'Check TEST_START/TEST_END and the 2026 data coverage.'
    )

benchmark_df = _get_benchmark_cache(cfg)

# Load the last trained RL model
rl_trader = RLTrader(cfg)
model = rl_trader.load(rl_model_path)
print(f'RL model loaded from {rl_model_path}')

# Create test environment
test_env = PortfolioEnv(
    cfg, signals_2026, bt._ohlcv,
    benchmark_df, training_mode=False,
    candidate_mode='top',
)
masked_env = ActionMasker(test_env, lambda e: e.action_masks())

# Run inference
obs, _ = masked_env.reset()
total_reward = 0.0
blocked_count = 0
sub_count = 0
# Explicit counter: the loop variable would be unbound if the loop body
# never ran (e.g. a single signal day), breaking the summary print below.
n_steps = 0

print(f'\nRunning RL inference on {len(signals_2026)} trading days...')
for _ in range(len(signals_2026) - 1):
    masks = get_action_masks(masked_env)
    action, _ = model.predict(obs, deterministic=True, action_masks=masks)
    obs, reward, terminated, truncated, info = masked_env.step(int(action))
    n_steps += 1
    total_reward += reward
    blocked_count += len(info.get('blocked_trades', []))
    sub_count += len(info.get('substituted_trades', []))
    if terminated or truncated:
        break

print(f'\nInference complete: {n_steps} steps')
print(f'  Total reward: {total_reward:.2f}')
print(f'  Blocked trades: {blocked_count}')
print(f'  Substituted trades: {sub_count}')

# Extract results
nav_history = test_env._nav_history
trade_log = test_env.trade_log
dates_2026 = [s['date'] for s in signals_2026]

# Build NAV series; zip stops at the shorter of the two sequences, so a
# NAV history shorter/longer than the signal dates is handled implicitly.
nav_series = pd.Series(
    dict(zip(dates_2026, nav_history)),
    name='nav',
).sort_index()

if nav_series.empty:
    raise RuntimeError(
        'The environment recorded no NAV history — cannot compute metrics.'
    )

# Compute metrics
total_return = nav_series.iloc[-1] / nav_series.iloc[0] - 1
daily_rets = nav_series.pct_change().dropna()
ret_std = daily_rets.std()
# std is NaN with fewer than two returns; NaN > 0 is False, so Sharpe is 0.0.
sharpe = (daily_rets.mean() / ret_std * np.sqrt(252)
          if ret_std > 0 else 0.0)
drawdown = (nav_series / nav_series.cummax() - 1)
max_dd = drawdown.min()

print(f'\n{"="*60}')
print(f'2026 FORWARD TEST RESULTS')
print(f'{"="*60}')
print(f'  Period:       {nav_series.index[0].date()} → {nav_series.index[-1].date()}')
print(f'  Trading days: {len(nav_series)}')
print(f'  Total return: {total_return*100:+.2f}%')
print(f'  Sharpe ratio: {sharpe:.2f}')
print(f'  Max drawdown: {max_dd*100:.2f}%')
print(f'  Total trades: {len(trade_log)}')
print(f'  Final NAV:    {nav_series.iloc[-1]:,.0f}')

In [None]:
# Cell 7: Visualise results
import matplotlib.pyplot as plt

up_color, down_color = '#4CAF50', '#F44336'

# Three stacked panels: NAV (tall), daily returns, drawdown.
fig, (ax_nav, ax_ret, ax_dd) = plt.subplots(
    3, 1, figsize=(14, 12), gridspec_kw={'height_ratios': [3, 1, 1]},
)

# --- Panel 1: NAV curve vs benchmark, both rebased to 1.0 at the start ---
first_day, last_day = nav_series.index[0], nav_series.index[-1]
norm_nav = nav_series / nav_series.iloc[0]
norm_nav.plot(ax=ax_nav, label='RL Agent', linewidth=2, color='#2196F3')

# Overlay benchmark (CSI 500), restricted to the NAV date span.
in_window = (benchmark_df.index >= first_day) & (benchmark_df.index <= last_day)
bench_2026 = benchmark_df.loc[in_window, 'close']
if len(bench_2026) > 0:
    norm_bench = bench_2026 / bench_2026.iloc[0]
    norm_bench.plot(ax=ax_nav, label='CSI 500', linewidth=1.5, color='#FF9800', alpha=0.7)

ax_nav.set_title('2026 Forward Test: RL Agent vs CSI 500', fontsize=14)
ax_nav.set_ylabel('Growth of $1')
ax_nav.legend(fontsize=11)
ax_nav.grid(True, alpha=0.3)
ax_nav.axhline(1.0, color='gray', linestyle='--', alpha=0.5)

# --- Panel 2: Daily returns, green for gains and red for losses ---
bar_colors = [up_color if ret >= 0 else down_color for ret in daily_rets]
ax_ret.bar(daily_rets.index, daily_rets.values * 100, color=bar_colors, width=1.5)
ax_ret.set_ylabel('Daily Return (%)')
ax_ret.set_title('Daily Returns', fontsize=11)
ax_ret.grid(True, alpha=0.3)

# --- Panel 3: Drawdown in percent ---
drawdown_pct = drawdown * 100
ax_dd.fill_between(drawdown_pct.index, drawdown_pct.values, 0, color=down_color, alpha=0.3)
ax_dd.plot(drawdown_pct.index, drawdown_pct.values, color=down_color, linewidth=1)
ax_dd.set_ylabel('Drawdown (%)')
ax_dd.set_title('Drawdown', fontsize=11)
ax_dd.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# --- Trade summary (skipped entirely when no trades were logged) ---
if trade_log:
    trades_df = pd.DataFrame(trade_log)
    rule = '=' * 60
    print(f'\n{rule}')
    print('TRADE SUMMARY')
    print(f'{rule}')

    is_buy = trades_df['action'] == 'buy'
    is_sell = trades_df['action'] == 'sell'
    buys = trades_df[is_buy]
    sells = trades_df[is_sell]
    print(f'  Buy trades:  {len(buys)}')
    print(f'  Sell trades: {len(sells)}')

    # P&L statistics are only available on sell rows that carry a 'pnl' column.
    if len(sells) > 0 and 'pnl' in sells.columns:
        pnl = sells['pnl']
        winners = (pnl > 0).sum()
        losers = (pnl <= 0).sum()
        print(f'  Win rate:    {winners/(winners+losers)*100:.1f}%')
        print(f'  Avg P&L:     {pnl.mean():+.2f}')
        print(f'  Avg hold:    {sells["hold_days"].mean():.1f} days')

    symbol_counts = trades_df['symbol'].value_counts()
    print(f'\n  Unique stocks traded: {trades_df["symbol"].nunique()}')
    print(f'  Most traded: {symbol_counts.head(5).to_string()}')

In [None]:
# Cell 8: Position history — holdings after each trading date, with stock names

# Reconstruct holdings from the trade log. Only dates that have at least one
# trade get an entry; each entry is a snapshot taken after that day's trades.
positions = {}  # date -> {symbol: {'shares': ..., 'price': ...}}
current_holdings = {}

if trade_log:
    trades_df = pd.DataFrame(trade_log)
    # One pass per date via groupby (sorted by date) instead of re-filtering
    # the whole frame for every unique date; rows keep their original order
    # within a date. For a datetime column the keys are pandas Timestamps,
    # so the dates print cleanly below.
    for date, day_trades in trades_df.groupby('date', sort=True):
        for _, t in day_trades.iterrows():
            sym = t['symbol']
            if t['action'] == 'buy':
                # A buy (re)sets the record for the symbol to this trade.
                current_holdings[sym] = {
                    'shares': t['shares'],
                    'price': t['price'],
                }
            elif t['action'] == 'sell':
                # A sell removes the symbol from the holdings entirely.
                current_holdings.pop(sym, None)
        # Shallow copy: the per-symbol dicts are never mutated afterwards.
        positions[date] = dict(current_holdings)

# Look up stock names via baostock
all_syms = set()
for pos in positions.values():
    all_syms.update(pos.keys())

# Default every symbol to its own code so a failed or partial lookup still
# leaves a usable label and keeps any names fetched before the failure.
sym_names = {s: s for s in all_syms}
if all_syms:
    try:
        import baostock as bs
        lg = bs.login()
        try:
            for sym in sorted(all_syms):
                rs = bs.query_stock_basic(code=sym)
                row = rs.get_data()
                if not row.empty:
                    sym_names[sym] = row.iloc[0].get('code_name', sym)
        finally:
            # Close the session even when a lookup raises mid-loop.
            bs.logout()
    except Exception as e:
        # Name lookup is best-effort: report and fall back to raw symbols.
        print(f'Could not fetch stock names: {e}')

# Display daily holdings
print(f'\n{"="*70}')
print('DAILY POSITION HISTORY')
print(f'{"="*70}')

for date in sorted(positions):
    pos = positions[date]
    if pos:
        holdings_str = ', '.join(
            f'{sym_names.get(s, s)} ({s}) x{p["shares"]:.0f}'
            for s, p in pos.items()
        )
    else:
        holdings_str = '(cash)'
    print(f'{date.date() if hasattr(date, "date") else date}: {holdings_str}')

In [None]:
# Cell 9: Save results to Drive
import json

# Everything for this forward test goes into a sub-folder of the run directory.
results_dir = os.path.join(cfg.run_dir, 'forward_test_2026')
os.makedirs(results_dir, exist_ok=True)
pkl_path = os.path.join(results_dir, 'forward_test_results.pkl')
json_path = os.path.join(results_dir, 'forward_test_summary.json')

# Scalar metrics are cast to plain Python floats so they serialise to JSON.
metrics = {
    'total_return': float(total_return),
    'sharpe': float(sharpe),
    'max_drawdown': float(max_dd),
    'n_days': len(nav_series),
    'n_trades': len(trade_log),
    'test_start': TEST_START,
    'test_end': TEST_END,
    'model_window': LAST_WINDOW,
}

# Full results: NAV series and trade log alongside the metrics.
results = {
    'nav_series': nav_series,
    'trade_log': trade_log,
    'metrics': metrics,
}

# Pickle keeps the pandas objects intact for later analysis.
with open(pkl_path, 'wb') as pkl_file:
    pickle.dump(results, pkl_file)

# Human-readable JSON summary containing only the metrics.
with open(json_path, 'w') as json_file:
    json.dump(results['metrics'], json_file, indent=2)

print(f'Results saved to:')
for saved_path in (pkl_path, json_path):
    print(f'  {saved_path}')