In [1]:
# ==================================
# Cell 1: Notebook Setup and Imports
# ==================================

import sys
import os
from pathlib import Path

# Add parent directory to Python path so we can import config, signals, and backtest engine.
# NOTE: the path is resolved from the current working directory, so the notebook has to be
# started from its own folder for the project imports in later cells to resolve.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import copy
import warnings

# NOTE(review): this silences every warning category for the whole session, which also
# hides pandas deprecation / chained-assignment warnings — consider narrowing the filter.
warnings.filterwarnings('ignore')

# Display options for easier debugging
pd.set_option('display.max_columns', 100)
pd.set_option('display.width', 120)

In [2]:
# ==================================
# Cell 2: Core Modules
# ==================================

import json
import pickle
from datetime import datetime

In [4]:
import dolphindb

In [5]:
# Connect to DolphinDB and run the specified query for Soybean Oil (Y) contracts

# DolphinDB connection settings. Every value can be overridden through an environment
# variable of the same name so that real credentials never have to be stored in the
# notebook; the fallbacks are the original placeholder values.
DDB_HOST = os.environ.get("DDB_HOST", "192.168.91.91")
DDB_PORT = int(os.environ.get("DDB_PORT", "8848"))
DDB_USER = os.environ.get("DDB_USER", "admin")
DDB_PASSWORD = os.environ.get("DDB_PASSWORD", "123456")

ddb = dolphindb.session()
ddb.connect(DDB_HOST, DDB_PORT, DDB_USER, DDB_PASSWORD)

# Compose the DolphinDB query: all columns of the minute table, restricted to contracts
# whose maturity_month ends in 01M / 05M / 09M.
# NOTE(review): the recorded run of this cell failed with a server-side "Out of memory"
# error — presumably because every column of every matching minute bar is requested.
# Narrow the selection (explicit columns and/or a date range) before re-running.
ddb_query = '''
select * from loadTable("dfs://rq_cn_futures_minute_Y", "k_minute") 
where string(maturity_month).split(".")[1] in ["01M","05M","09M"]
'''

# Run the query and get the result as a pandas DataFrame
ddb_df = ddb.run(ddb_query)

# Preview the data
print("DolphinDB Soybean Oil (Y) contracts (Jan, May, Sep maturity months):")
display(ddb_df.head())


RuntimeError: <Exception> in run: Server response: 'select * from loadTable("dfs://rq_cn_futures_minute_Y", "k_minute") where split(string(maturity_month), ".")[1] in ["01M","05M","09M"] => Out of memory' script: '
select * from loadTable("dfs://rq_cn_futures_minute_Y", "k_minute") 
where string(maturity_month).split(".")[1] in ["01M","05M","09M"]
'

In [3]:
# ==================================
# Cell 3: Load Raw 1-min Futures Data for Both Legs Separately
# ==================================

def load_leg_data(leg_name: str, data_dir=None) -> pd.DataFrame:
    """
    Loads raw 1-minute futures bars for a given leg from a CSV file.

    The file is expected at ``<data_dir>/data_<leg_name>_1min.csv`` and must contain
    at least the columns 'order_book_id' and 'datetime'.

    Parameters:
        leg_name (str): Identifier for the futures leg (e.g., 'Y', 'M').
        data_dir (str | Path | None): Folder holding the CSV files. Defaults to the
            ``../data`` folder relative to the current working directory.

    Returns:
        pd.DataFrame: Bars with 'datetime' parsed to datetime64 and rows sorted by
        ('order_book_id', 'datetime') under a fresh 0..n-1 index.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
    """
    base_dir = Path("..") / "data" if data_dir is None else Path(data_dir)
    path = base_dir / f"data_{leg_name}_1min.csv"

    # Fail early with the resolved location: the default is relative to the working
    # directory, which makes the bare relative path in pandas' own error hard to act on.
    if not path.is_file():
        raise FileNotFoundError(
            f"No 1-minute data file for leg '{leg_name}': {path} (resolved: {path.resolve()})"
        )

    df = pd.read_csv(path)
    df['datetime'] = pd.to_datetime(df['datetime'])
    return df.sort_values(by=['order_book_id', 'datetime']).reset_index(drop=True)


# Load each leg separately.
# NOTE: the recorded run of this cell raised FileNotFoundError for
# '../data/data_Y_1min.csv' — the CSV files must exist in ../data before running.
leg1_df = load_leg_data("Y")  # e.g. Soybean Oil or 'Y'
leg2_df = load_leg_data("M")  # e.g. Soybean Meal or 'M'

# Preview the first rows of each leg
print("Leg 1 Preview:")
display(leg1_df.head())
print("\nLeg 2 Preview:")
display(leg2_df.head())

FileNotFoundError: [Errno 2] No such file or directory: '../data/data_Y_1min.csv'

In [None]:
# ==================================
# Cell 4: Enrich Contract Metadata for Each Leg
# ==================================

import re

def parse_contract_id(contract_id: str):
    """
    Parses a contract ID of format 'PRODUCTYYMM' into product, year, and month.
    E.g., 'M2001' → ('M', 2020, 1)

    The product is the leading run of upper-case letters; the next four digits are
    read as a two-digit year (always mapped to 20YY) followed by a two-digit month.
    Anything after those four digits is ignored.

    Returns:
        tuple: (product, year, month), or (None, None, None) when the value is not a
        string, does not start with letters followed by four digits, or carries a
        month outside 1-12.
    """
    # Missing values (None / NaN) would make re.match raise TypeError
    if not isinstance(contract_id, str):
        return None, None, None

    match = re.match(r"([A-Z]+)(\d{4})", contract_id)
    if match is None:
        return None, None, None

    product, yymm = match.groups()
    month = int(yymm[2:])               # '01' → 1 (January)
    if not 1 <= month <= 12:
        # e.g. 'Y2013' is not a real delivery month; treat it as unparseable
        return None, None, None
    year = 2000 + int(yymm[:2])         # '20' → 2020
    return product, year, month

def add_contract_metadata(df: pd.DataFrame) -> pd.DataFrame:
    """
    Adds parsed contract metadata columns: product, year, and month to a DataFrame
    containing 'order_book_id'.

    Each distinct contract id is parsed once with parse_contract_id and the result is
    mapped back onto the rows, so the cost scales with the number of contracts rather
    than the number of bars. An empty input simply gains three empty columns.

    Parameters:
        df (pd.DataFrame): Raw leg data with 'order_book_id' column.

    Returns:
        pd.DataFrame: The same DataFrame object, with the columns 'product', 'year'
        and 'month' added (or overwritten) in place.
    """
    meta_cols = ['product', 'year', 'month']
    unique_ids = df['order_book_id'].drop_duplicates()

    # One row of metadata per distinct contract id; naming the columns explicitly keeps
    # the frame well-formed even when there are no contracts at all.
    lookup = pd.DataFrame(
        [parse_contract_id(contract_id) for contract_id in unique_ids],
        index=unique_ids.to_numpy(),
        columns=meta_cols,
    )

    for col in meta_cols:
        df[col] = df['order_book_id'].map(lookup[col])
    return df

In [None]:
# ==================================
# Cell 5: Enrich leg1_df and leg2_df with product, year, month
# ==================================

# add_contract_metadata writes the new columns onto the frame it receives and returns
# that same frame, so these assignments keep the existing names bound to the enriched data.
leg1_df = add_contract_metadata(leg1_df)
leg2_df = add_contract_metadata(leg2_df)

# Preview enriched data: one row per distinct contract
print("Leg 1 Metadata Preview:")
display(leg1_df[['order_book_id', 'product', 'year', 'month']].drop_duplicates().head())

print("\nLeg 2 Metadata Preview:")
display(leg2_df[['order_book_id', 'product', 'year', 'month']].drop_duplicates().head())

Leg 1 Metadata Preview:


Unnamed: 0,order_book_id,product,year,month
0,Y2001,Y,2020,1
3330,Y2003,Y,2020,3
15480,Y2005,Y,2020,5
38385,Y2007,Y,2020,7
74280,Y2008,Y,2020,8



Leg 2 Metadata Preview:


Unnamed: 0,order_book_id,product,year,month
0,M2001,M,2020,1
3330,M2003,M,2020,3
15480,M2005,M,2020,5
38385,M2007,M,2020,7
74280,M2008,M,2020,8


In [None]:
# ==================================
# Cell 6: Retain Only Jan, May, Sep Contracts
# ==================================

# Expiry months retained for the analysis: January, May and September
valid_expiry_months = {1, 5, 9}

def filter_valid_expiries(df: pd.DataFrame) -> pd.DataFrame:
    """
    Keep only the rows whose contract expires in one of `valid_expiry_months`.

    Parameters:
        df (pd.DataFrame): Leg data carrying a numeric 'month' column.

    Returns:
        pd.DataFrame: Independent copy of the matching rows; the original index
        labels are preserved.
    """
    keep_mask = df['month'].isin(valid_expiry_months)
    return df.loc[keep_mask].copy()

# Apply to each leg (rebinding the names: the unfiltered frames are no longer reachable
# afterwards, so re-run the loading cells to get them back)
leg1_df = filter_valid_expiries(leg1_df)
leg2_df = filter_valid_expiries(leg2_df)

# Preview contract coverage: first ten distinct contracts in expiry order
print("Leg 1 Contracts After Filtering:")
display(leg1_df[['order_book_id', 'year', 'month']].drop_duplicates().sort_values(['year', 'month']).head(10))

print("\nLeg 2 Contracts After Filtering:")
display(leg2_df[['order_book_id', 'year', 'month']].drop_duplicates().sort_values(['year', 'month']).head(10))

Leg 1 Contracts After Filtering:


Unnamed: 0,order_book_id,year,month
0,Y2001,2020,1
15480,Y2005,2020,5
118110,Y2009,2020,9
304560,Y2101,2021,1
459285,Y2105,2021,5
708975,Y2109,2021,9
958080,Y2201,2022,1
1124070,Y2205,2022,5
1371330,Y2209,2022,9
1619850,Y2301,2023,1



Leg 2 Contracts After Filtering:


Unnamed: 0,order_book_id,year,month
0,M2001,2020,1
15480,M2005,2020,5
118110,M2009,2020,9
304560,M2101,2021,1
459285,M2105,2021,5
708975,M2109,2021,9
958080,M2201,2022,1
1124070,M2205,2022,5
1371330,M2209,2022,9
1619850,M2301,2023,1


In [None]:
# ==================================
# Cell 7: Utility to Compute Roll Dates
# ==================================

from calendar import monthrange

def compute_scheduled_roll_date(year: int, month: int) -> pd.Timestamp:
    """
    Return the scheduled roll date of a contract: the last calendar day of the month
    that lies two months before the expiry month.

    For January and February expiries the roll month falls in the previous year
    (November and December respectively). This is the planned cut-off used when
    stitching contracts, not the empirically observed last active date.
    """
    wraps_into_previous_year = month <= 2
    roll_year = year - 1 if wraps_into_previous_year else year
    roll_month = month + 10 if wraps_into_previous_year else month - 2

    # monthrange returns (weekday of the 1st, number of days in the month)
    _, days_in_month = monthrange(roll_year, roll_month)
    return pd.Timestamp(roll_year, roll_month, days_in_month)

In [None]:
# ==================================
# Cell 8: Generate Contract Schedule with Roll Dates
# ==================================

def get_contract_schedule(df: pd.DataFrame) -> pd.DataFrame:
    """
    Returns contract schedule for a given leg with unique contracts.
    Assumes df has already been filtered to desired expiry months.

    Parameters:
        df (pd.DataFrame): Leg data with 'order_book_id', 'product', 'year' and
            'month' columns.

    Returns:
        pd.DataFrame: One row per distinct contract with an added datetime64
        'scheduled_roll_date' column (from compute_scheduled_roll_date), sorted by
        (year, month) under a fresh 0..n-1 index. An empty input yields an empty
        schedule that still carries the 'scheduled_roll_date' column.
    """
    schedule = df[["order_book_id", "product", "year", "month"]].drop_duplicates()

    # Build the column from a plain list instead of a row-wise apply: apply(axis=1)
    # does not return a Series for an empty frame, which broke the column assignment.
    roll_dates = [
        compute_scheduled_roll_date(year, month)
        for year, month in zip(schedule["year"], schedule["month"])
    ]
    schedule = schedule.assign(
        scheduled_roll_date=pd.Series(roll_dates, index=schedule.index, dtype="datetime64[ns]")
    )
    return schedule.sort_values(["year", "month"]).reset_index(drop=True)

# Build and preview the roll schedule of each leg
leg1_schedule = get_contract_schedule(leg1_df)
display(leg1_schedule.head())
leg2_schedule = get_contract_schedule(leg2_df)
display(leg2_schedule.head())

Unnamed: 0,order_book_id,product,year,month,scheduled_roll_date
0,Y2001,Y,2020,1,2019-11-30
1,Y2005,Y,2020,5,2020-03-31
2,Y2009,Y,2020,9,2020-07-31
3,Y2101,Y,2021,1,2020-11-30
4,Y2105,Y,2021,5,2021-03-31


Unnamed: 0,order_book_id,product,year,month,scheduled_roll_date
0,M2001,M,2020,1,2019-11-30
1,M2005,M,2020,5,2020-03-31
2,M2009,M,2020,9,2020-07-31
3,M2101,M,2021,1,2020-11-30
4,M2105,M,2021,5,2021-03-31


In [None]:
# ==================================
# Cell 9: Stitch Leg Based on Roll Schedule
# ==================================

def stitch_leg(df: pd.DataFrame, schedule: pd.DataFrame) -> pd.DataFrame:
    """
    Stitches a continuous intraday time series based on a roll schedule.

    Contract i contributes the bars whose calendar day lies in
    (scheduled_roll_date of contract i-1, scheduled_roll_date of contract i];
    the first contract in the schedule has no lower bound. Contracts are taken in
    the row order of `schedule`, whatever its index labels are.

    Parameters:
        df (pd.DataFrame): Cleaned 1-minute data for a single product, with
            'order_book_id' and datetime64 'datetime' columns.
        schedule (pd.DataFrame): Roll schedule from get_contract_schedule()
            ('order_book_id' and 'scheduled_roll_date' columns, in roll order).

    Returns:
        pd.DataFrame: Stitched time series sorted by 'datetime' with the added
        columns 'scheduled_roll_date' and 'stitched_order_book_id'. When no bar
        falls into any window, an empty frame with those columns is returned.
    """
    contract_ids = schedule["order_book_id"].tolist()
    roll_dates = schedule["scheduled_roll_date"].tolist()
    # Each window opens right after the previous contract's roll date. Pairing by
    # position (not by index label) keeps this correct for any schedule index.
    window_starts = [None] + roll_dates[:-1]

    pieces = []
    for ob_id, start_date, end_date in zip(contract_ids, window_starts, roll_dates):
        contract_df = df[df["order_book_id"] == ob_id]
        bar_day = contract_df["datetime"].dt.floor("D")

        # Clip time range based on start and end of contract window
        in_window = bar_day <= end_date
        if start_date is not None:
            in_window &= bar_day > start_date

        contract_df = contract_df[in_window]
        if contract_df.empty:
            continue

        pieces.append(
            contract_df
            .sort_values("datetime")  # critical for any resampling/sync
            .assign(scheduled_roll_date=end_date, stitched_order_book_id=ob_id)
        )

    if not pieces:
        # pd.concat([]) raises; hand back an empty frame with the expected columns instead
        return (
            df.iloc[0:0]
            .assign(
                scheduled_roll_date=pd.Series(dtype="datetime64[ns]"),
                stitched_order_book_id=pd.Series(dtype=object),
            )
            .reset_index(drop=True)
        )

    return pd.concat(pieces).sort_values("datetime").reset_index(drop=True)

# Stitch each leg with its own schedule and preview the first bars
stitched_leg1 = stitch_leg(leg1_df, leg1_schedule)
display(stitched_leg1.head())
stitched_leg2 = stitch_leg(leg2_df, leg2_schedule)
display(stitched_leg2.head())

Unnamed: 0,order_book_id,datetime,trading_date,open_interest,open,volume,total_turnover,low,close,high,parent_order_book_id,UTC_datetime,maturity_month,product,year,month,scheduled_roll_date,stitched_order_book_id
0,Y2005,2020-01-02 09:01:00,2020.01.02,662395,6722,17735,1193869000.0,6712,6746,6750,Y,2020.01.02T01:01:00,2020.05M,Y,2020,5,2020-03-31,Y2005
1,Y2005,2020-01-02 09:02:00,2020.01.02,662604,6746,6095,411496100.0,6746,6748,6756,Y,2020.01.02T01:02:00,2020.05M,Y,2020,5,2020-03-31,Y2005
2,Y2005,2020-01-02 09:03:00,2020.01.02,661900,6750,6516,439340000.0,6738,6746,6750,Y,2020.01.02T01:03:00,2020.05M,Y,2020,5,2020-03-31,Y2005
3,Y2005,2020-01-02 09:04:00,2020.01.02,661794,6746,3432,231744600.0,6744,6758,6758,Y,2020.01.02T01:04:00,2020.05M,Y,2020,5,2020-03-31,Y2005
4,Y2005,2020-01-02 09:05:00,2020.01.02,662319,6758,4491,303569000.0,6758,6760,6762,Y,2020.01.02T01:05:00,2020.05M,Y,2020,5,2020-03-31,Y2005


Unnamed: 0,order_book_id,datetime,trading_date,open_interest,open,volume,total_turnover,low,close,high,parent_order_book_id,UTC_datetime,maturity_month,product,year,month,scheduled_roll_date,stitched_order_book_id
0,M2005,2020-01-02 09:01:00,2020.01.02,1688619,2780,20247,562833510.0,2778,2780,2782,M,2020.01.02T01:01:00,2020.05M,M,2020,5,2020-03-31,M2005
1,M2005,2020-01-02 09:02:00,2020.01.02,1688123,2780,7854,218305110.0,2778,2779,2781,M,2020.01.02T01:02:00,2020.05M,M,2020,5,2020-03-31,M2005
2,M2005,2020-01-02 09:03:00,2020.01.02,1688507,2779,6526,181441130.0,2779,2782,2783,M,2020.01.02T01:03:00,2020.05M,M,2020,5,2020-03-31,M2005
3,M2005,2020-01-02 09:04:00,2020.01.02,1688491,2783,8699,242043320.0,2782,2784,2784,M,2020.01.02T01:04:00,2020.05M,M,2020,5,2020-03-31,M2005
4,M2005,2020-01-02 09:05:00,2020.01.02,1688722,2783,12681,353090420.0,2783,2785,2786,M,2020.01.02T01:05:00,2020.05M,M,2020,5,2020-03-31,M2005


In [None]:
# ==================================
# Cell 10: Merge Legs and Compute gap_days_to_next
# ==================================

def compute_gap_days_to_next(merged_df: pd.DataFrame) -> pd.DataFrame:
    """
    Compute calendar-day gap to the next trade date and attach it to every row.

    'trade_date' is the calendar date of each bar's 'datetime' (midnight-normalized).
    For each distinct date the gap is the number of days until the next distinct date
    present in the data; the last date gets np.inf.

    Parameters:
        merged_df (pd.DataFrame): Frame with a datetime64 'datetime' column. It is
            not modified.

    Returns:
        pd.DataFrame: New frame (fresh 0..n-1 index) with 'trade_date' and
        'gap_days_to_next' columns. A pre-existing 'gap_days_to_next' column is
        replaced rather than duplicated, so the function can be applied repeatedly.
    """
    # Derive a new frame instead of writing 'trade_date' onto the caller's object, and
    # drop a stale gap column so the merge below cannot produce _x/_y duplicates.
    out = merged_df.drop(columns=["gap_days_to_next"], errors="ignore")
    out = out.assign(trade_date=out["datetime"].dt.normalize())

    # Step 1: Extract unique trade dates in chronological order
    df_day = (
        out[["trade_date"]]
        .drop_duplicates("trade_date")
        .sort_values("trade_date")
        .reset_index(drop=True)
    )

    # Step 2: Compute gap to next trade date (no successor for the last date → inf)
    next_trade_date = df_day["trade_date"].shift(-1)
    df_day["gap_days_to_next"] = (next_trade_date - df_day["trade_date"]).dt.days.fillna(np.inf)

    # Step 3: Merge gap info back onto every bar
    return out.merge(
        df_day[["trade_date", "gap_days_to_next"]],
        on="trade_date",
        how="left"
    )


# Step 1: Merge both legs on timestamp using merge_asof with 30s tolerance.
# merge_asof keeps every leg-1 row; leg-2 columns are left empty (NaN) on rows where no
# leg-2 bar lies within the tolerance. Columns present in both legs get the suffixes below.
merged_df = pd.merge_asof(
    stitched_leg1.sort_values("datetime"),
    stitched_leg2.sort_values("datetime"),
    on="datetime",
    direction="nearest",
    tolerance=pd.Timedelta(seconds=30),
    suffixes=("_1", "_2")  # explicit leg suffixes instead of the default _x/_y
)

# Step 2: Compute and attach gap_days_to_next
merged_df = compute_gap_days_to_next(merged_df)

# Optional preview
display(merged_df[["datetime", "trade_date", "gap_days_to_next"]].head())

Unnamed: 0,datetime,trade_date,gap_days_to_next
0,2020-01-02 09:01:00,2020-01-02,1.0
1,2020-01-02 09:02:00,2020-01-02,1.0
2,2020-01-02 09:03:00,2020-01-02,1.0
3,2020-01-02 09:04:00,2020-01-02,1.0
4,2020-01-02 09:05:00,2020-01-02,1.0


In [None]:
# ==================================
# Cell 11: Enrich merged_df with Contract Info and Roll Detection
# ==================================

# Step 1: Rename for clarity (a no-op when the cell is executed again)
merged_df = merged_df.rename(columns={
    "stitched_order_book_id_1": "contract_1",
    "stitched_order_book_id_2": "contract_2"
})

# Step 2: Compute actual roll date = calendar day of the last datetime seen per contract
actual_roll_1 = (
    merged_df.groupby("contract_1", observed=True)["datetime"]
    .max().dt.floor("D")
    .reset_index()
    .rename(columns={"datetime": "actual_roll_date_1"})
)

actual_roll_2 = (
    merged_df.groupby("contract_2", observed=True)["datetime"]
    .max().dt.floor("D")
    .reset_index()
    .rename(columns={"datetime": "actual_roll_date_2"})
)

# Step 3: Merge actual roll dates.
# Drop copies left behind by an earlier execution of this cell first; otherwise the
# merges would emit suffixed duplicates (actual_roll_date_1_x / _y) and the column
# lookups below would raise KeyError.
merged_df = merged_df.drop(columns=["actual_roll_date_1", "actual_roll_date_2"], errors="ignore")
merged_df = merged_df.merge(actual_roll_1, on="contract_1", how="left")
merged_df = merged_df.merge(actual_roll_2, on="contract_2", how="left")

# Step 4: Define is_roll_date flags per leg
bar_day = merged_df["datetime"].dt.floor("D")
merged_df["is_roll_date_1"] = bar_day == merged_df["actual_roll_date_1"]
merged_df["is_roll_date_2"] = bar_day == merged_df["actual_roll_date_2"]

# Step 5: Assign group id based on either contract changing (robust leg pairing).
# A contract id can be missing on rows where the as-of merge found no partner bar.
# Because NaN != NaN, a plain inequality would start a new group on every such row;
# two consecutive missing values are therefore treated as "unchanged".
contract_pair = merged_df[["contract_1", "contract_2"]]
previous_pair = contract_pair.shift()
pair_changed = contract_pair.ne(previous_pair) & ~(contract_pair.isna() & previous_pair.isna())
merged_df["contract_pair_group"] = pair_changed.any(axis=1).cumsum()

# Step 6: Flag roll mismatches based on is_roll_date_1 vs is_roll_date_2.
# Note that n_mismatches counts rows (minute bars), not distinct calendar dates.
roll_mismatch_mask = merged_df["is_roll_date_1"] ^ merged_df["is_roll_date_2"]
n_mismatches = roll_mismatch_mask.sum()

if n_mismatches > 0:
    print(f"⚠️ Warning: {n_mismatches} mismatched roll dates found (only one leg rolled).")
    print("🧠 These are handled via spread groupings (contract_pair_group).")

# Optional: Store for inspection
merged_df["roll_date_mismatch"] = roll_mismatch_mask

# Preview key fields
preview_cols = [
    "datetime", "contract_1", "contract_2",
    "actual_roll_date_1", "actual_roll_date_2", "roll_date_mismatch",
    "is_roll_date_1", "is_roll_date_2", "contract_pair_group"
]
display(merged_df[preview_cols].head(10))

Unnamed: 0,datetime,contract_1,contract_2,actual_roll_date_1,actual_roll_date_2,roll_date_mismatch,is_roll_date_1,is_roll_date_2,contract_pair_group
0,2020-01-02 09:01:00,Y2005,M2005,2020-03-31,2020-03-31,False,False,False,1
1,2020-01-02 09:02:00,Y2005,M2005,2020-03-31,2020-03-31,False,False,False,1
2,2020-01-02 09:03:00,Y2005,M2005,2020-03-31,2020-03-31,False,False,False,1
3,2020-01-02 09:04:00,Y2005,M2005,2020-03-31,2020-03-31,False,False,False,1
4,2020-01-02 09:05:00,Y2005,M2005,2020-03-31,2020-03-31,False,False,False,1
5,2020-01-02 09:06:00,Y2005,M2005,2020-03-31,2020-03-31,False,False,False,1
6,2020-01-02 09:07:00,Y2005,M2005,2020-03-31,2020-03-31,False,False,False,1
7,2020-01-02 09:08:00,Y2005,M2005,2020-03-31,2020-03-31,False,False,False,1
8,2020-01-02 09:09:00,Y2005,M2005,2020-03-31,2020-03-31,False,False,False,1
9,2020-01-02 09:10:00,Y2005,M2005,2020-03-31,2020-03-31,False,False,False,1


In [None]:
# ==================================
# Cell 12: Compute Price Ratio and Log Price Ratio
# ==================================

# Step 1: Make sure both legs' close prices survived the merge
required_price_cols = {"close_1", "close_2"}
if required_price_cols - set(merged_df.columns):
    raise ValueError("Missing close_1 or close_2 columns for price ratio computation.")

# Step 2: Ratio of leg-1 close to leg-2 close, and the difference of their logs
leg1_close = merged_df["close_1"]
leg2_close = merged_df["close_2"]
merged_df["price_ratio"] = leg1_close / leg2_close
merged_df["log_price_ratio"] = np.log(leg1_close) - np.log(leg2_close)

# Step 3: Show the first rows of the relevant columns
preview_cols = [
    "datetime",
    "contract_1", "contract_2",
    "close_1", "close_2",
    "price_ratio", "log_price_ratio",
    "contract_pair_group",
]
display(merged_df[preview_cols].head(10))

In [None]:
# ==================================
# Cell 13: Import Parameter Grid and Metrics
# ==================================

# Import the parameter grid and the metric specification from the project's config
# package (resolvable because the parent directory was appended to sys.path earlier).
from config.config_params import param_grid, evaluation_metrics

print(f"Total parameter combinations: {len(param_grid)}")
from pprint import pprint
pprint(param_grid[:3])

# Inspect keys of one sample config. The lookups below require each entry of
# param_grid to be a mapping with at least the keys 'signal' and 'execution'.
print("Sample keys:")
print(param_grid[0].keys())
print("Signal config:", param_grid[0]['signal'])
print("Execution config:", param_grid[0]['execution'])

In [None]:
# ==================================
# Cell 14: Define run_backtest_for_config Function
# ==================================
"""
This helper function takes a merged minute-level DataFrame and a configuration (config) dictionary containing signal and execution parameters,
runs the signal generation + backtest loop + metrics, and returns results.
"""

from spread_backtest_engine import (
    compute_rolling_zscore_grouped,
    run_backtest_loop,
    compute_trading_metrics
)

from signals.signal_generators import generate_trading_signals_directional

def run_backtest_for_config(merged_df, config):
    """
    Execute the z-score → signal → backtest → metrics pipeline for one configuration.

    The columns 'zscore', 'position_signal' and 'raw_signal' are written onto
    `merged_df` itself, so pass a copy if the caller's frame must stay untouched.

    Parameters:
        merged_df (pd.DataFrame): Preprocessed merged data; must contain
            'log_price_ratio' and 'contract_pair_group'.
        config (dict): Mapping with a 'signal' sub-dict (zscore_window plus the four
            entry/exit thresholds) and an 'execution' sub-dict.

    Returns:
        dict: Keys 'config', 'metrics', 'df_result' and 'trade_data'.
    """
    signal_params, exec_params = config["signal"], config["execution"]
    zscore_window = signal_params["zscore_window"]

    # 1. Rolling z-score of the log price ratio, computed within each contract pair
    #    group (groups follow generic pair changes, not specific contract codes)
    assert "log_price_ratio" in merged_df.columns, "Missing log_price_ratio column"
    merged_df["zscore"] = compute_rolling_zscore_grouped(
        merged_df,
        value_col="log_price_ratio",
        window=zscore_window,
        group_col="contract_pair_group",
    )

    # 2. Turn the z-score into position / raw signals using the configured thresholds
    threshold_names = (
        "entry_threshold_long",
        "entry_threshold_short",
        "exit_threshold_long",
        "exit_threshold_short",
    )
    thresholds = {name: signal_params[name] for name in threshold_names}
    position_signal, raw_signal = generate_trading_signals_directional(
        merged_df["zscore"], **thresholds
    )
    merged_df["position_signal"] = position_signal
    merged_df["raw_signal"] = raw_signal

    # 3. Simulate execution of the raw signal
    df_result, trade_data = run_backtest_loop(
        merged_df,
        signal_col="raw_signal",
        execution_config=exec_params,
    )

    # 4. Direction of each trade = first executed position recorded under its trade_id
    traded_rows = df_result[df_result["trade_id"].notna()]
    if traded_rows.empty:
        trade_directions = []
    else:
        first_positions = traded_rows.groupby("trade_id", observed=True)["executed_position"].first()
        trade_directions = first_positions.astype(int).tolist()

    # 5. Summarise performance
    metrics = compute_trading_metrics(df_result, trade_data, trade_directions)

    return dict(
        config=config,
        metrics=metrics,
        df_result=df_result,
        trade_data=trade_data,
    )

In [None]:
# =============================================
# Cell 15: Helper — Optimize Config on Training Group
# =============================================
"""
This function runs a grid search over the parameter grid on a given training group (train_df),
evaluates each config using specified trading metrics, ranks them, and returns the best configuration.

It ensures:
- All configs are attempted robustly with exception handling.
- Metric-based ranking that supports multiple objective metrics.
- Raw config and per-metric details are retained for full traceability.
"""

def optimize_config_on_group(train_df, param_grid, evaluation_metrics):
    """
    Grid-search `param_grid` on one training group and pick the best configuration.

    Every config is backtested on a fresh copy of `train_df`. Configs whose backtest
    raises are reported on stdout and skipped. The survivors are ranked per metric and
    the config with the lowest sum of ranks wins; ties go to the config that appears
    first in `param_grid`.

    Parameters:
        train_df (pd.DataFrame): Training data passed to run_backtest_for_config.
        param_grid (iterable of dict): Configs with 'signal' and 'execution' sub-dicts.
        evaluation_metrics (iterable of (str, bool)): Metric name and whether a
            higher value is better.

    Returns:
        tuple: (best_config, ranked_df, best_metrics) where ranked_df holds one row
        per successful config sorted by 'total_rank_score', and best_metrics maps
        each evaluated metric name to the winner's value.

    Raises:
        ValueError: If no config could be backtested successfully.
    """
    # Run Grid Search Over Parameter Grid (deliberately best-effort per config)
    grid_results = []
    for config in param_grid:
        try:
            grid_results.append(run_backtest_for_config(train_df.copy(), config))
        except Exception as e:
            print(f"❌ Error with config: {config}")
            print(e)

    if not grid_results:
        raise ValueError("No successful backtest runs in training group.")

    # Flatten results into DataFrame: raw config dict + signal params + execution
    # params + metrics (later keys win if names collide)
    results_df = pd.DataFrame([
        {
            "config": res["config"],
            **res["config"]["signal"],
            **res["config"]["execution"],
            **res["metrics"],
        }
        for res in grid_results
    ])

    # Rank-based voting across all specified metrics.
    # na_option="bottom" gives an undefined (NaN) metric the worst rank. With the
    # default ("keep") the rank itself is NaN, the row-wise sum below skips it, and a
    # config with missing metrics would end up with the *best* total score.
    ranked_df = results_df.copy()
    for metric, higher_is_better in evaluation_metrics:
        ranked_df[f"{metric}_rank"] = ranked_df[metric].rank(
            ascending=not higher_is_better, method="min", na_option="bottom"
        )

    # Sum of ranks = total rank score (lower is better); a stable sort keeps the
    # param_grid order among equal scores so the selection is deterministic.
    rank_cols = [f"{metric}_rank" for metric, _ in evaluation_metrics]
    ranked_df["total_rank_score"] = ranked_df[rank_cols].sum(axis=1)
    ranked_df = ranked_df.sort_values("total_rank_score", kind="stable").reset_index(drop=True)

    # Select best config and its metrics
    best_config = ranked_df.loc[0, "config"]
    best_metrics = ranked_df.loc[0, [m for m, _ in evaluation_metrics]].to_dict()

    return best_config, ranked_df, best_metrics

In [None]:
# ================================================================
# Cell 16: Walk-Forward Execution — Grouping, Snapshot, Evaluation Loop
# ================================================================

import os
import shutil
from datetime import datetime
from tqdm.notebook import tqdm
from pprint import pprint

# === Step 1: Identify Stitched Contract Groups (generic leg version) ===
group_ids = sorted(merged_df["contract_pair_group"].unique())
print(f"🔢 Found {len(group_ids)} stitched contract groups: {group_ids}")

# === Step 2: Create Timestamped Folder to Save Run Outputs ===
# The folder lives under <parent of cwd>/backtest_runs/walkforward_runs/<YYYYmmdd_HHMMSS>,
# so every execution of this cell creates a new output directory.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
run_folder = os.path.join(project_root, "backtest_runs", "walkforward_runs", timestamp)
os.makedirs(run_folder, exist_ok=True)
print(f"📂 Saving walk-forward outputs to: {run_folder}")

# === Step 3: Snapshot Config and Signal Generator Files ===
# Copies the current source files next to the results; shutil.copy raises if either
# source file is missing.
config_src = os.path.join(project_root, "config", "config_params.py")
config_dst = os.path.join(run_folder, "config_params_snapshot.py")
shutil.copy(config_src, config_dst)

signal_src = os.path.join(project_root, "signals", "signal_generators.py")
signal_dst = os.path.join(run_folder, "signal_generators_snapshot.py")
shutil.copy(signal_src, signal_dst)

# === Step 4: Walk-Forward Loop with Optimization and Backtest ===
# Each iteration trains on one group and tests on the group that follows it in
# sorted order; the first group is therefore only ever used for training.
walkforward_summary = []
all_walkforward_df_results = []
all_walkforward_trade_data = []

for i in tqdm(range(1, len(group_ids)), desc="Walk-Forward Iterations"):
    train_group = group_ids[i - 1]
    test_group = group_ids[i]

    print(f"\n▶️ Walk-forward iteration: Train on group {train_group}, test on group {test_group}")

    train_df = merged_df[merged_df["contract_pair_group"] == train_group].copy()
    test_df = merged_df[merged_df["contract_pair_group"] == test_group].copy()

    if train_df.empty or test_df.empty:
        print(f"⚠️ Skipping group {test_group} due to empty train/test split.")
        continue

    # === Step 4.1: Optimize on training group ===
    # NOTE: optimize_config_on_group raises ValueError when every config fails on the
    # training group; that exception is not caught here and stops the whole loop.
    best_config, ranked_df_train, best_train_metrics = optimize_config_on_group(
        train_df, param_grid, evaluation_metrics
    )
    print("✅ Selected config:")
    pprint(best_config)

    # === Step 4.2: Run backtest on test group with best config ===
    test_result = run_backtest_for_config(test_df.copy(), best_config)

    # Collect results and trade data for stitching
    all_walkforward_df_results.append(test_result["df_result"])
    all_walkforward_trade_data.append(test_result["trade_data"])

    # Store summary metrics
    walkforward_summary.append({
        "train_group": train_group,
        "test_group": test_group,
        "best_config": best_config,
        "train_metrics": best_train_metrics,
        "test_metrics": test_result["metrics"],
    })

In [None]:
# =====================================================
# Cell 17: Ensure Global Trade ID Uniqueness Across Groups
# =====================================================

# Shift the trade_id values of each walk-forward result so that ids are unique across
# groups. The frames in all_walkforward_df_results are modified in place, so this cell
# must be executed exactly once per walk-forward run (running it again shifts the ids again).
trade_id_offset = 0
last_assigned_trade_id = None  # largest id handed out so far, after shifting

for df_result in all_walkforward_df_results:
    if "trade_id" not in df_result.columns:
        continue
    mask = df_result["trade_id"].notna()
    if not mask.any():
        continue

    raw_trade_ids = df_result.loc[mask, "trade_id"]
    lowest_raw_id = raw_trade_ids.min()
    n_trades = raw_trade_ids.nunique()

    # Default shift: the number of trades seen in earlier groups. That alone is only
    # collision-free when each group's ids are contiguous; if the shifted ids would
    # overlap ids already handed out, push this group just past the largest one.
    shift = trade_id_offset
    if last_assigned_trade_id is not None and lowest_raw_id + shift <= last_assigned_trade_id:
        shift = last_assigned_trade_id + 1 - lowest_raw_id

    df_result.loc[mask, "trade_id"] += shift
    last_assigned_trade_id = df_result.loc[mask, "trade_id"].max()
    trade_id_offset = shift + n_trades

In [None]:
# ====================================================================
# Cell 18: Aggregate Walk-Forward Results and Compute Final Metrics
# ====================================================================

# === Step 1: Stack every group's df_result and order the rows chronologically ===
combined_df_result = (
    pd.concat(all_walkforward_df_results, ignore_index=True)
    .sort_values("datetime")
    .reset_index(drop=True)
)

# === Step 2: Chain the per-group trade-level lists, field by field ===
# Extend this tuple if more trade-level fields need to be aggregated.
trade_data_fields = (
    "entry_times",
    "exit_times",
    "trade_real_returns",
    "holding_durations",
)
combined_trade_data = {
    field: [
        item
        for group_trade_data in all_walkforward_trade_data
        if field in group_trade_data
        for item in group_trade_data[field]
    ]
    for field in trade_data_fields
}

# === Step 3: Trade direction = first executed_position recorded per trade_id ===
has_trades = (
    "trade_id" in combined_df_result.columns
    and combined_df_result["trade_id"].notna().any()
)
if has_trades:
    traded_rows = combined_df_result[combined_df_result["trade_id"].notna()]
    first_positions = traded_rows.groupby("trade_id", observed=True)["executed_position"].first()
    trade_directions = first_positions.astype(int).tolist()
else:
    trade_directions = []

# === Step 4: Metrics over the stitched walk-forward result ===
final_metrics = compute_trading_metrics(combined_df_result, combined_trade_data, trade_directions)

# === Step 5: Print one aligned line per metric (floats with 4 decimals) ===
print("\n✅ Aggregated Walk-Forward Metrics:")
for metric_name, metric_value in final_metrics.items():
    rendered = f"{metric_value:,.4f}" if isinstance(metric_value, float) else f"{metric_value}"
    print(f"{metric_name:<30}: {rendered}")

In [None]:
# ================================================================
# Cell 19: Plot Aggregated Equity Curve with Key Annotations
# ================================================================

import matplotlib.pyplot as plt
import matplotlib.dates as mdates

# === Step 1: Compute equity curve using real returns ===
combined_df_result["equity_curve"] = (1 + combined_df_result["strategy_real_return"]).cumprod()

# === Step 2: Extract max drawdown and largest single-bar drop points ===
# NOTE(review): the type/resolution of 'date_of_max_drawdown' is set by
# compute_trading_metrics, which is not visible here — it is assumed to be comparable
# with the 'datetime' column.
max_drawdown_time = final_metrics["date_of_max_drawdown"]
max_drop_idx = combined_df_result["strategy_log_return"].idxmin()
max_drop_time = combined_df_result.loc[max_drop_idx, "datetime"]

# Equity values at the max-drawdown timestamp. This selection can hold zero, one or
# several rows, so a single scalar is picked for the annotation further down.
drawdown_equity = combined_df_result.loc[
    combined_df_result["datetime"] == max_drawdown_time, "equity_curve"
]

# === Step 3: Identify contract group boundary transitions ===
# First row of every contract_pair_group, in chronological order
contract_boundaries = (
    combined_df_result[["datetime", "contract_pair_group"]]
    .drop_duplicates("contract_pair_group")
    .sort_values("datetime")
)

# === Step 4: Plot ===
fig, ax = plt.subplots(figsize=(16, 6))

# Plot equity curve
ax.plot(
    combined_df_result["datetime"],
    combined_df_result["equity_curve"],
    label="Equity Curve",
    color="royalblue",
    linewidth=1.5
)

# Annotate max drawdown. The vertical line is always drawn; the arrow needs a scalar
# y-coordinate and is skipped when no bar carries exactly that timestamp.
ax.axvline(max_drawdown_time, color="firebrick", linestyle="--", alpha=0.8, linewidth=1)
if not drawdown_equity.empty:
    ax.annotate("Max Drawdown",
                xy=(max_drawdown_time, drawdown_equity.iloc[0]),
                xytext=(0, -30), textcoords="offset points",
                arrowprops=dict(arrowstyle="->", color="firebrick"),
                ha='center', color="firebrick", fontsize=9)

# Annotate largest single-period drop
ax.axvline(max_drop_time, color="darkorange", linestyle="--", alpha=0.8, linewidth=1)
ax.annotate("Largest Drop",
            xy=(max_drop_time, combined_df_result.loc[max_drop_idx, "equity_curve"]),
            xytext=(0, 30), textcoords="offset points",
            arrowprops=dict(arrowstyle="->", color="darkorange"),
            ha='center', color="darkorange", fontsize=9)

# Vertical lines for each contract group boundary
for boundary_time in contract_boundaries["datetime"]:
    ax.axvline(boundary_time, color="slategray", linestyle="--", linewidth=0.8, alpha=0.6)

# Plot aesthetics
ax.set_title("Aggregated Walk-Forward Equity Curve with Key Markers", fontsize=14)
ax.set_ylabel("Equity Curve", fontsize=12)
ax.set_xlabel("Datetime", fontsize=12)
ax.legend(loc="upper left")

# Improve x-axis formatting
ax.xaxis.set_major_locator(mdates.MonthLocator(interval=2))
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
fig.autofmt_xdate()

plt.grid(True, linestyle='--', linewidth=0.3, alpha=0.7)
plt.tight_layout()

# Save to run folder before showing: with some backends the figure is released once
# plt.show() returns, which would leave an empty image on disk.
fig.savefig(os.path.join(run_folder, "equity_curve.png"), dpi=300)
plt.show()

In [None]:
# ==================================
# Cell 20: Construct walkforward_df from walkforward_summary
# ==================================

# One flat record per walk-forward iteration. Key order: the two group ids, then the
# training metrics (train_*), the test metrics (test_*) and finally the chosen
# signal and execution parameters (config_*); later keys overwrite equal earlier ones.
walkforward_records = [
    {
        "train_group": entry["train_group"],
        "test_group": entry["test_group"],
        **{f"train_{name}": value for name, value in entry["train_metrics"].items()},
        **{f"test_{name}": value for name, value in entry["test_metrics"].items()},
        **{f"config_{name}": value for name, value in entry["best_config"]["signal"].items()},
        **{f"config_{name}": value for name, value in entry["best_config"]["execution"].items()},
    }
    for entry in walkforward_summary
]

# Tabulate the records
walkforward_df = pd.DataFrame(walkforward_records)

print(f"\n✅ Constructed walkforward_df with shape: {walkforward_df.shape}")
display(walkforward_df.head())

In [None]:
# ==================================
# Cell 21: Save walk-forward outputs to run_folder
# ==================================

def _in_run_folder(filename):
    """Return the path of `filename` inside the current run folder."""
    return os.path.join(run_folder, filename)

# 1. Combined bar-level result as CSV
combined_df_result.to_csv(_in_run_folder("combined_df_result.csv"), index=False)

# 2. Combined trade-level data as pickle
with open(_in_run_folder("combined_trade_data.pkl"), "wb") as fh:
    pickle.dump(combined_trade_data, fh)

# 3. Final aggregated metrics as JSON; default=str stringifies values that JSON
#    cannot represent natively (e.g. timestamps)
with open(_in_run_folder("final_metrics.json"), "w") as fh:
    json.dump(final_metrics, fh, indent=2, default=str)

# 4. Raw walk-forward summary as pickle
with open(_in_run_folder("walkforward_summary.pkl"), "wb") as fh:
    pickle.dump(walkforward_summary, fh)

# 5. Flattened summary table as CSV
walkforward_df.to_csv(_in_run_folder("walkforward_df.csv"), index=False)

print("\n✅ Saved all aggregated outputs to disk.")

In [None]:
# ==================================
# Cell 22: Plot Train vs Test Metric Stability across groups
# ==================================

import seaborn as sns
from config.config_params import base_metrics

plt.style.use("seaborn-v0_8-whitegrid")
n_rows = len(base_metrics)

# One panel per metric, stacked vertically. squeeze=False makes subplots() return a
# 2-D array even for a single metric (otherwise it returns a bare Axes and indexing
# it fails); the single column is then taken as a 1-D array of axes.
fig, axes_grid = plt.subplots(n_rows, 1, figsize=(10, 3 * n_rows), sharex=True, squeeze=False)
axes = axes_grid[:, 0]

# x positions = walk-forward iterations, labelled with the test group id
x = range(len(walkforward_df))

for i, metric in enumerate(base_metrics):
    ax = axes[i]

    # Plot both train and test curves
    sns.lineplot(x=x, y=walkforward_df[f"train_{metric}"], marker="o", label="Train", ax=ax)
    sns.lineplot(x=x, y=walkforward_df[f"test_{metric}"], marker="o", label="Test", ax=ax)

    ax.set_ylabel(metric.replace("_", " ").title())
    ax.set_xticks(x)
    ax.set_xticklabels(walkforward_df["test_group"], rotation=0)
    ax.set_title(f"Train vs Test {metric.replace('_', ' ').title()}")
    ax.axhline(0, color="black", linestyle="--", linewidth=0.8)
    ax.legend()

plt.tight_layout()
plt.suptitle("Train vs Test Metric Stability", fontsize=14, y=1.02)
plt.show()

# Save the plot. The suptitle sits above the figure area (y=1.02), so the bounding box
# has to be expanded or the title is cut off in the saved file.
fig.savefig(os.path.join(run_folder, "train_vs_test_metrics.png"), dpi=300, bbox_inches="tight")