# In [1]:
from typing import List, Dict, Tuple, Any, Optional
from datetime import datetime
import pandas as pd

class TradeManager:
    """
    Manages the execution of trades and trade book maintenance.

    Responsible for:
    - Placing new trade orders
    - Closing open positions
    - Tracking active trades
    - Building and maintaining the trade book

    NOTE: ``build_tradebook`` replaces ``self.tradebook`` (a list of trade
    dicts) with the processed DataFrame. ``place_order``, ``square_off`` and
    ``active_trades`` operate on the list form, so call them before the
    tradebook is built.
    """

    def __init__(self) -> None:
        """Initialize the TradeManager with an empty tradebook."""
        print("TradeManager initialized")
        self.tradebook: List[Dict[str, Any]] = []
        self.tradebook_built: bool = False

    def place_order(self, entry_data: Dict[str, Any]) -> None:
        """
        Place a new trade order and add it to the tradebook.

        The new trade is stored with status ``'open'`` and all exit fields
        set to ``None``.

        Args:
            entry_data: Dictionary containing trade entry details. Must provide
                the keys ``strategy_id``, ``position_id``, ``leg_id``,
                ``symbol``, ``entry_date``, ``entry_time``, ``entry_price``,
                ``qty``, ``entry_type``, ``entry_spot``, ``stop_loss``,
                ``take_profit`` and ``entry_reason``.

        Raises:
            KeyError: If any required key is missing from ``entry_data``.
        """
        trade = {
            'strategy_id': entry_data['strategy_id'],
            'position_id': entry_data['position_id'],
            'leg_id': entry_data['leg_id'],
            'symbol': entry_data['symbol'],
            'entry_date': entry_data['entry_date'],
            'entry_time': entry_data['entry_time'],
            'exit_date': None,
            'exit_time': None,
            'entry_price': entry_data['entry_price'],
            'exit_price': None,
            'qty': entry_data['qty'],
            'entry_type': entry_data['entry_type'],
            'entry_spot': entry_data['entry_spot'],
            'exit_spot': None,
            'stop_loss': entry_data['stop_loss'],
            'take_profit': entry_data['take_profit'],
            'entry_reason': entry_data['entry_reason'],
            'exit_reason': None,
            'status': 'open',
        }
        self.tradebook.append(trade)

    def square_off(self, trade_index: int, exit_data: Dict[str, Any]) -> None:
        """
        Close an open position.

        Args:
            trade_index: Index of the trade in tradebook
            exit_data: Dictionary containing trade exit details. Must provide
                the keys ``exit_date``, ``exit_time``, ``exit_price``,
                ``exit_spot`` and ``exit_reason``.
        """
        trade = self.tradebook[trade_index]
        trade['status'] = 'closed'
        trade['exit_date'] = exit_data["exit_date"]
        trade['exit_time'] = exit_data["exit_time"]
        trade['exit_price'] = exit_data["exit_price"]
        trade['exit_spot'] = exit_data["exit_spot"]
        trade['exit_reason'] = exit_data["exit_reason"]

    def active_trades(self) -> List[Tuple[int, Dict[str, Any]]]:
        """
        Return all active trades.

        Returns:
            List of tuples containing trade index and trade details for every
            trade whose status is ``'open'``.
        """
        return [(index, trade) for index, trade in enumerate(self.tradebook) if trade['status'] == 'open']

    @staticmethod
    def _parse_expiry(symbol: Any) -> Any:
        """
        Extract the expiry date embedded in a contract symbol.

        The seven characters at ``symbol[-14:-7]`` are parsed with the
        ``%d%b%y`` format (e.g. ``25JAN24``).

        Returns:
            A ``datetime`` on success, ``pd.NaT`` when the symbol is not a
            string, is shorter than 14 characters, or cannot be parsed.
        """
        # Imported locally under an alias: the bare name ``datetime`` is bound to the
        # class by ``from datetime import datetime`` and to the module by a later
        # ``import datetime`` in this file, so ``datetime.datetime`` is only valid
        # depending on execution order.
        from datetime import datetime as _datetime

        if not isinstance(symbol, str) or len(symbol) < 14:
            print(f"Invalid symbol format: {symbol}")
            return pd.NaT
        try:
            return _datetime.strptime(symbol[-14:-7], "%d%b%y")
        except (ValueError, TypeError) as e:
            print(f"Error parsing expiry from symbol {symbol}: {e}")
            return pd.NaT

    def build_tradebook(self) -> pd.DataFrame:
        """
        Convert tradebook to DataFrame and add calculated columns.

        Added columns: ``expiry``, ``instrument_type`` (last two characters of
        the symbol), ``strike`` (``symbol[-7:-2]``, kept as a string),
        ``entry_datetime``, ``exit_datetime`` (closed trades only) and ``pnl``
        (per-unit, closed trades only; ``exit - entry`` for ``BUY`` entries,
        ``entry - exit`` otherwise).

        Side effect: on success ``self.tradebook`` is replaced by the resulting
        DataFrame and ``self.tradebook_built`` is set, so later calls return
        the cached DataFrame.

        Returns:
            DataFrame containing processed tradebook, or an empty DataFrame
            when there are no trades.
        """
        # If tradebook is already built, return it
        if self.tradebook_built and isinstance(self.tradebook, pd.DataFrame) and not self.tradebook.empty:
            print("Returning existing tradebook DataFrame")
            return self.tradebook

        # Check if tradebook is empty (list or DataFrame)
        if isinstance(self.tradebook, list) and not self.tradebook:
            print("Tradebook is empty. No trades to build.")
            return pd.DataFrame()
        elif isinstance(self.tradebook, pd.DataFrame) and self.tradebook.empty:
            print("Tradebook DataFrame is empty. No trades to build.")
            return pd.DataFrame()

        # Convert list to DataFrame if necessary
        if isinstance(self.tradebook, list):
            df = pd.DataFrame(self.tradebook)
        else:
            df = self.tradebook.copy()  # Work on a copy to avoid modifying the original

        # Log raw tradebook for debugging
        expected_cols = ['symbol', 'entry_date', 'exit_date', 'entry_time', 'exit_time']
        missing_cols = [col for col in expected_cols if col not in df.columns]
        if not missing_cols:
            print("Raw tradebook before processing:")
            print(df[expected_cols].head())
        else:
            print(f"Missing expected columns: {missing_cols}")
            print("Available columns:", df.columns.tolist())

        # Derive contract attributes from the symbol text
        df["expiry"] = pd.to_datetime(df["symbol"].apply(self._parse_expiry), errors='coerce')
        df["instrument_type"] = df["symbol"].apply(lambda x: x[-2:] if isinstance(x, str) else None)
        df["strike"] = df["symbol"].apply(lambda x: x[-7:-2] if isinstance(x, str) else None)

        # Convert entry_date and exit_date to datetime (exit fields only exist on closed trades)
        df["entry_date"] = pd.to_datetime(df["entry_date"], errors='coerce')
        closed_trades = df['status'] == 'closed'
        exit_dates = pd.to_datetime(df.loc[closed_trades, "exit_date"], errors='coerce')
        df.loc[closed_trades, "exit_date"] = exit_dates

        # Debug: Check dtypes and nulls
        print(f"entry_date null count: {df['entry_date'].isna().sum()}")
        print(f"exit_date null count (closed trades): {df.loc[closed_trades, 'exit_date'].isna().sum()}")
        print(f"expiry null count: {df['expiry'].isna().sum()}")

        if df['entry_date'].isna().any():
            print("Rows with null entry_date:")
            print(df[df['entry_date'].isna()][['symbol', 'entry_date', 'entry_time']])
        if df.loc[closed_trades, 'exit_date'].isna().any():
            print("Rows with null exit_date (closed trades):")
            print(df.loc[closed_trades & df['exit_date'].isna()][['symbol', 'exit_date', 'exit_time']])
        if df['expiry'].isna().any():
            print("Rows with null expiry:")
            print(df[df['expiry'].isna()][['symbol', 'expiry']])

        # Compute entry/exit datetimes: date + time-of-day expressed as a timedelta
        df["entry_datetime"] = df["entry_date"] + pd.to_timedelta(df["entry_time"].astype(str), errors='coerce')
        # Use the converted datetime series directly so the addition is a proper
        # datetime64 + timedelta64 operation rather than object-dtype arithmetic.
        df.loc[closed_trades, "exit_datetime"] = (
            exit_dates +
            pd.to_timedelta(df.loc[closed_trades, "exit_time"].astype(str), errors='coerce')
        )

        # Compute PnL for closed trades
        df["pnl"] = df.apply(
            lambda row: (
                row["exit_price"] - row["entry_price"] if row["entry_type"] == "BUY"
                else row["entry_price"] - row["exit_price"]
            ) if row["status"] == "closed" and pd.notna(row["exit_price"]) and pd.notna(row["entry_price"]) else None,
            axis=1
        )

        self.tradebook = df
        self.tradebook_built = True
        return df

# In [None]:
import duckdb
import pandas as pd
import math
import time
from datetime import date
import re

# Module-level scratch state written by backtest() (declared ``global`` there).
# iv: bar timestamp -> (iv_slope, spot_price), recorded whenever the slope is computed.
iv = {}
# time_arr: table name -> wall-clock seconds spent processing that table.
time_arr = {}
# data: the most recent per-contract slice looked up by backtest(); overwritten on every leg lookup.
data = None

def calculate_performance(trader):
    """
    Compute an annualized Sharpe-style ratio from a trader's closed trades.

    Per-trade profit is ``pnl * qty`` for every row with a non-null
    ``exit_price``. The ratio is ``mean / std * sqrt(252)``, i.e. each trade's
    profit is treated as if it were one daily return.

    Args:
        trader: Object exposing ``build_tradebook()`` that returns a DataFrame
            with ``exit_price``, ``pnl`` and ``qty`` columns.

    Returns:
        The ratio as a float, or ``0.0`` when it is undefined: no trades,
        fewer than two usable profits, zero dispersion (all profits equal), or
        a non-finite result.
    """
    trades = trader.build_tradebook()
    if trades.empty:
        return 0.0
    # Filter for closed trades (non-null exit_price)
    closed_trades = trades[trades['exit_price'].notna()]
    if closed_trades.empty:
        return 0.0
    # Calculate per-trade profits; rows without a usable pnl contribute nothing
    returns = (closed_trades['pnl'] * closed_trades['qty']).astype(float).dropna()
    if len(returns) < 2:
        return 0.0  # Need at least 2 trades for standard deviation
    spread = returns.std()
    # Zero dispersion would make the ratio +/-inf (or NaN), which would then
    # dominate any parameter search that maximizes this value.
    if pd.isna(spread) or math.isclose(spread, 0.0, abs_tol=1e-12):
        return 0.0
    # Calculate annualized Sharpe Ratio (assuming daily returns)
    sharpe_ratio = float(returns.mean() / spread * (252 ** 0.5))
    return sharpe_ratio if math.isfinite(sharpe_ratio) else 0.0

def parse_table_name(table_name):
    """
    Convert table name like 'YYYY-MM-DD' or 'nifty_YYYY_MM_DD' to datetime.date.

    Args:
        table_name: Table name string in one of the two supported formats.

    Returns:
        The corresponding ``datetime.date``, or ``None`` (after printing a
        diagnostic) when the argument is not a string, matches neither format,
        or encodes an invalid calendar date.
    """
    # Imported locally under an alias: the bare name ``datetime`` is bound to the
    # class by ``from datetime import datetime`` and to the module by
    # ``import datetime`` in different parts of this file, so which one is
    # visible here depends on execution order.
    import datetime as _dt

    if not isinstance(table_name, str):
        print(f"Invalid table name: {table_name} (type: {type(table_name)})")
        return None

    # Try parsing 'YYYY-MM-DD' format
    if re.match(r'\d{4}-\d{2}-\d{2}', table_name):
        try:
            return _dt.datetime.strptime(table_name, "%Y-%m-%d").date()
        except ValueError as e:
            print(f"Error parsing date from {table_name}: {e}")
            return None

    # Try parsing 'nifty_YYYY_MM_DD' format
    match = re.match(r'nifty_(\d{4})_(\d{2})_(\d{2})', table_name)
    if match:
        year, month, day = match.groups()
        print(f"Parsed {table_name}: year={year}, month={month}, day={day}")  # Debug
        try:
            year, month, day = int(year), int(month), int(day)
            # Validate date components
            if not (1 <= month <= 12):
                print(f"Invalid month in {table_name}: {month}")
                return None
            if not (1 <= day <= 31):  # Basic check; date() below rejects e.g. Feb 30
                print(f"Invalid day in {table_name}: {day}")
                return None
            return _dt.date(year, month, day)
        except ValueError as e:
            print(f"Error parsing date from {table_name}: {e}")
            return None

    print(f"Cannot parse table name: {table_name}")
    return None


def _place_leg_orders(trader, legs, entry_type_dict, position_id, row, entry_reason):
    """Place one order per leg whose ``expiry_type`` maps to a side in ``entry_type_dict``."""
    bar_ts = pd.Timestamp(row.Index)
    for leg_id, leg in legs.items():
        entry_type = entry_type_dict.get(leg["expiry_type"])
        if entry_type is None:
            # No side configured for this expiry type: the leg is not traded
            continue
        trader.place_order({
            'strategy_id': 'strat1',
            'position_id': position_id,
            'leg_id': leg_id,
            'symbol': leg["contract"],
            'entry_date': bar_ts.date(),
            'entry_time': bar_ts.time(),
            'entry_price': leg["data"]["close"],
            'qty': 1,
            'entry_type': entry_type,
            'entry_spot': row.spot_price,
            'stop_loss': None,
            'take_profit': None,
            'entry_reason': entry_reason,
        })


def backtest(legs, iv_slope_thresolds, duckdb, trader, dates):
    """
    Run the IV-slope strategy bar by bar over a sequence of daily tables.

    For every table the function loads all rows ordered by ``timestamp``,
    resolves each leg's expiry and ATM contract, updates the IV slope
    (``log10((iv1 + iv2) / (iv3 + iv4))``) on bars between 15:29 and 15:30,
    maps the slope onto a signal in {-3, -2, -1, 1, 2, 3}, and opens or closes
    trades through ``trader``. Any trade still open after the last table is
    force-closed.

    Args:
        legs: Dict of leg configurations; the keys ``leg1``..``leg4`` are read
            directly. Each leg provides ``type``, ``expiry_type``,
            ``expiry_range`` (``[lower, upper]``) and ``target_strike``. Legs
            are mutated in place: ``expiry``, ``strike``, ``contract`` and
            ``data`` are written on them.
        iv_slope_thresolds: Dict with ``upper_gamma``, ``upper_buffer``,
            ``lower_buffer`` and ``lower_gamma``.
        duckdb: Connection-like object; ``execute(sql).fetchdf()`` is called on
            it. (The parameter shadows the ``duckdb`` module inside this
            function.)
        trader: Object providing ``active_trades``, ``place_order``,
            ``square_off`` and ``build_tradebook``.
        dates: Iterable of table names accepted by ``parse_table_name``.

    Returns:
        The value of ``calculate_performance(trader)``, or ``0.0`` when none
        of ``dates`` can be parsed.

    Side effects:
        Updates the module-level ``iv``, ``time_arr`` and ``data`` variables
        and prints progress/diagnostic messages.
    """
    global iv, time_arr, data
    # Imported locally under an alias: the bare name ``datetime`` is bound to both
    # the class and the module in different parts of this file.
    import datetime as _dt

    signal = 0
    position_id = 1
    counter = 1
    iv_slope = 0
    # Pre-bind names that are otherwise only assigned inside conditional paths,
    # so a bar that reaches the adjustment branch before any entry, or a run in
    # which every table query fails, cannot raise NameError.
    entry_type_dict = {}
    ticker_map = {}
    empty_df_template = pd.DataFrame()
    spot = pd.DataFrame()

    # Convert table names to dates
    parsed_dates = [parse_table_name(date_str) for date_str in dates]
    if any(parsed is None for parsed in parsed_dates):
        print("Warning: Some table names could not be parsed into dates")
    parsed_dates = [parsed for parsed in parsed_dates if parsed is not None]
    if not parsed_dates:
        print("Error: No valid dates parsed from table names")
        return 0.0

    # Time-of-day boundaries used on every bar
    entry_cutoff = pd.Timestamp("15:00:00").time()
    iv_window_start = pd.Timestamp("15:29:00").time()
    iv_window_end = pd.Timestamp("15:30:00").time()

    for current_date in parsed_dates:
        # The table name is rebuilt from the parsed date, so it only ever
        # contains digits and underscores after the fixed prefix.
        date_str = f"nifty_{current_date.strftime('%Y_%m_%d')}"
        start_time = time.time()
        print(f"Processing table: {date_str} at {counter}")
        counter += 1
        try:
            data_df = duckdb.execute(f"SELECT * FROM {date_str} ORDER BY timestamp").fetchdf()
        except Exception as e:
            print(f"Error querying table {date_str}: {e}")
            continue
        time_to_expiry = sorted(data_df["Time_to_expiry"].unique())
        data_df.set_index("timestamp", inplace=True)
        ticker_map = {ticker: group for ticker, group in data_df.groupby("ticker", sort=False)}
        empty_df_template = data_df.iloc[0:0]

        # Resolve each leg's expiry: the smallest time-to-expiry inside the leg's range
        for leg in legs.values():
            try:
                lower, upper = leg["expiry_range"]
                valid_tte = min(tte for tte in time_to_expiry if lower <= tte <= upper)
                leg["expiry"] = data_df.loc[data_df["Time_to_expiry"] == valid_tte, "expiry_date"].iloc[0]
                if not isinstance(leg["expiry"], (_dt.datetime, pd.Timestamp, _dt.date)):
                    print(f"Invalid expiry {leg['expiry']} (type: {type(leg['expiry'])}) for leg {leg['type']}")
                    leg["expiry"] = pd.NaT
            except (ValueError, IndexError) as e:
                print(f"Error setting expiry for leg {leg['type']}: {e}")
                continue

        # One row per distinct timestamp, carrying the spot price
        spot = data_df[["spot_price"]][~data_df.index.duplicated(keep="first")]
        for row in spot.itertuples():
            if row.spot_price is None or pd.isna(row.spot_price):
                continue
            bar_ts = pd.Timestamp(row.Index)
            bar_time = bar_ts.time()
            atm = round(row.spot_price / 50) * 50

            for leg in legs.values():
                if leg["target_strike"] == "ATM":
                    leg["strike"] = float(atm)
                try:
                    # ``.get`` covers a leg whose expiry was never resolved
                    if pd.isna(leg.get("expiry")):
                        raise ValueError("Expiry is NaT")
                    contract = f"NIFTY{pd.Timestamp(leg['expiry']).strftime('%d%b%y').upper()}{int(leg['strike'])}{leg['type']}"
                except (ValueError, TypeError, KeyError) as e:
                    print(f"Error forming contract for leg {leg['type']}: {e}")
                    # Drop any quote left over from an earlier bar so this bar
                    # is skipped instead of trading on stale data.
                    leg["data"] = None
                    continue
                leg["contract"] = contract
                subset_df = ticker_map.get(contract, empty_df_template)
                data = subset_df
                # Latest quote at or before this bar, if the contract has one
                avl_time = subset_df.index.asof(row.Index) if not subset_df.empty else None
                leg["data"] = subset_df.loc[avl_time] if not pd.isna(avl_time) else None

            # Every leg needs a quote for this bar
            if any(leg.get("data") is None for leg in legs.values()):
                continue

            if iv_window_start <= bar_time <= iv_window_end:
                try:
                    iv_slope = math.log((legs["leg1"]["data"]["iv"] + legs["leg2"]["data"]["iv"]) / (legs["leg3"]["data"]["iv"] + legs["leg4"]["data"]["iv"]), 10)
                    iv[row.Index] = (iv_slope, row.spot_price)
                except (ValueError, TypeError, ZeroDivisionError) as e:
                    print(f"Error calculating iv_slope: {e}")
                    continue

            # Bucket the slope: booleans act as 0/1 so exactly one term is non-zero
            # when the thresholds are ordered gamma > buffer > 0 > buffer > gamma.
            new_signal = (
                (iv_slope > iv_slope_thresolds["upper_gamma"]) * 3 +
                (iv_slope_thresolds["upper_gamma"] >= iv_slope > iv_slope_thresolds["upper_buffer"]) * 2 +
                (iv_slope_thresolds["upper_buffer"] >= iv_slope > 0) * 1 +
                (0 >= iv_slope > iv_slope_thresolds["lower_buffer"]) * -1 +
                (iv_slope_thresolds["lower_buffer"] >= iv_slope > iv_slope_thresolds["lower_gamma"]) * -2 +
                (iv_slope_thresolds["lower_gamma"] >= iv_slope) * -3
            )

            active_trades = trader.active_trades()
            if (not active_trades) and (bar_time < entry_cutoff):
                if new_signal == -2 or new_signal == 2:
                    continue
                elif new_signal == 1:
                    entry_type_dict = {'weekly': 'BUY', 'monthly': 'SELL'}
                elif new_signal == -1:
                    entry_type_dict = {'weekly': 'SELL', 'monthly': 'BUY'}
                elif new_signal == -3 or new_signal == 3:
                    entry_type_dict = {'weekly': 'BUY', 'monthly': None}

                _place_leg_orders(trader, legs, entry_type_dict, position_id, row, f'{new_signal} signal entry')
                position_id += 1
            else:
                # Earliest expiry among the open trades, parsed from their symbols
                near_expiry = None
                for index, trade in active_trades:
                    try:
                        expiry = _dt.datetime.strptime(trade["symbol"][-14:-7], "%d%b%y").date()
                        near_expiry = expiry if near_expiry is None else min(near_expiry, expiry)
                    except (ValueError, TypeError) as e:
                        print(f"Error parsing expiry from trade symbol {trade['symbol']}: {e}")
                        continue
                exit_reason = (
                    "End of Data reached" if ((current_date == parsed_dates[-1]) and (bar_time > entry_cutoff)) else
                    "Near Expiry reached" if (bar_ts.date() == near_expiry) else
                    "Signal changed" if (signal != new_signal) else
                    None
                )
                if exit_reason:
                    for index, trade in active_trades:
                        contract = trade["symbol"]
                        subset_df = ticker_map.get(contract, empty_df_template)
                        subset_df = subset_df.loc[:row.Index]
                        if not subset_df.empty:
                            close_price = subset_df.iloc[-1]["close"]
                        else:
                            # No quote up to this bar: exit flat at the entry price
                            close_price = trade["entry_price"]
                        exit_data = {
                            'exit_date': bar_ts.date(),
                            'exit_time': bar_time,
                            'exit_price': close_price,
                            'exit_spot': row.spot_price,
                            'exit_reason': exit_reason,
                        }
                        trader.square_off(index, exit_data)

                if signal == new_signal:
                    leg_strike = legs["leg2"]["strike"]
                    # Re-enter only when leg2's strike has drifted more than 1% from spot
                    if (row.spot_price * 0.99) <= leg_strike <= (row.spot_price * 1.01):
                        continue
                    else:
                        _place_leg_orders(trader, legs, entry_type_dict, position_id, row, 'Adjustment Calendar')
                        position_id += 1

            signal = new_signal
        time_arr[date_str] = time.time() - start_time

    # Force-close any remaining open trades
    active_trades = trader.active_trades()
    if active_trades:
        print(f"Force-closing {len(active_trades)} open trades")
        last_date = parsed_dates[-1]
        for index, trade in active_trades:
            contract = trade["symbol"]
            # Uses the quotes of the last table that loaded successfully
            subset_df = ticker_map.get(contract, empty_df_template)
            close_price = subset_df.iloc[-1]["close"] if not subset_df.empty else trade["entry_price"]
            exit_data = {
                'exit_date': last_date,
                'exit_time': pd.Timestamp("15:30:00").time(),
                'exit_price': close_price,
                'exit_spot': spot.iloc[-1]["spot_price"] if not spot.empty else trade["entry_spot"],
                'exit_reason': "End of backtest",
            }
            trader.square_off(index, exit_data)

    return calculate_performance(trader)

# In [None]:
import os
import re
import numpy as np
import matplotlib.pyplot as plt
from openpyxl import Workbook
from itertools import product
import itertools
import datetime
from openpyxl.drawing.image import Image
import json
import multiprocessing as mp
from functools import partial

class HyperParameterOptimizer:
    """Grid-search driver that scores parameter sets with ``backtest`` and reports heatmaps."""

    def __init__(self, legs: Dict, conn: duckdb.DuckDBPyConnection, dates: pd.Series):
        """
        Store the inputs shared by every backtest run.

        Args:
            legs: Dictionary of trade legs
            conn: DuckDB connection object
            dates: Series of dates for backtesting
        """
        self.legs = legs
        self.conn = conn
        self.dates = dates
        self.trader = TradeManager()

    def optimize(self, hyperparameter_grid: Dict[str, List[float]], maximize: str = 'Sharpe Ratio', 
                 method: str = 'grid', max_tries: Optional[int] = None, constraint: Optional[callable] = None) -> Tuple[Dict, float, pd.DataFrame]:
        """
        Run a grid search and return the best-scoring parameter set.

        Every admissible combination is backtested with a fresh TradeManager.
        Afterwards a timestamped output directory is created in the current
        working directory and a ``summary.xlsx`` with the results and heatmaps
        is written to it; the intermediate PNG files are removed.

        Args:
            hyperparameter_grid: Dictionary of hyperparameters with their possible values
            maximize: Metric to maximize ('Sharpe Ratio'); not consulted by the
                implementation, which always ranks by the backtest's return value
            method: Optimization method ('grid' only for now)
            max_tries: Maximum number of parameter combinations to try
            constraint: Function to filter admissible parameter combinations

        Returns:
            Tuple of (best parameters, best Sharpe ratio, results DataFrame)

        Raises:
            ValueError: If ``method`` is not 'grid' or no combination is admissible.
        """
        if method != 'grid':
            raise ValueError("Only 'grid' method is supported in this implementation")

        # Cartesian product of the grid, one dict per combination
        names = list(hyperparameter_grid.keys())
        value_lists = [hyperparameter_grid[name] for name in names]
        candidates = [dict(zip(names, values)) for values in product(*value_lists)]

        # Drop combinations rejected by the constraint, if one was given
        if constraint:
            candidates = [candidate for candidate in candidates if constraint(candidate)]

        # Honour the cap on the number of combinations
        if max_tries and max_tries < len(candidates):
            candidates = candidates[:max_tries]

        if not candidates:
            raise ValueError("No admissible parameter combinations to test")

        records = []
        best_sharpe = -np.inf
        best_params = None

        for candidate in candidates:
            print(f"Testing parameters: {candidate}")
            self.trader = TradeManager()  # each run starts from an empty tradebook
            score = backtest(self.legs, candidate, self.conn, self.trader, self.dates)
            records.append({**candidate, 'Sharpe Ratio': score})

            print(f"Sharpe ratio : {score}")

            if score > best_sharpe:
                best_sharpe, best_params = score, candidate

        results_df = pd.DataFrame(records)

        # Write the summary workbook into a fresh timestamped directory
        stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        output_directory_path = os.path.join(os.getcwd(), f'hyperparameter_optimizer_output_{stamp}')
        os.makedirs(output_directory_path, exist_ok=True)
        sheet_path = os.path.join(output_directory_path, 'summary.xlsx')

        self.generate_heatmaps(results_df, sheet_path, output_directory_path)
        self.delete_png_files(output_directory_path)

        return best_params, best_sharpe, results_df

    def generate_heatmaps(self, results_df: pd.DataFrame, output_file: str, png_directory: str) -> None:
        """
        Write the results table and one heatmap per parameter pair to an Excel file.

        Args:
            results_df: DataFrame containing optimization results
            output_file: Path to save the Excel file
            png_directory: Directory to save temporary heatmap images
        """
        workbook = Workbook()
        sheet = workbook.create_sheet(title="Sharpe Ratio")

        # Header row followed by the data rows, starting at cell A1
        table = [results_df.columns.tolist()] + results_df.values.tolist()
        for row_number, cells in enumerate(table, start=1):
            for column_number, cell_value in enumerate(cells, start=1):
                sheet.cell(row=row_number, column=column_number, value=cell_value)

        # Images are stacked below the table, 30 rows apart
        anchor_row = len(results_df) + 3
        parameter_columns = [name for name in results_df.columns if name != "Sharpe Ratio"]

        for param_x, param_y in itertools.combinations(parameter_columns, 2):
            # Mean score per (x, y) cell; combinations never tested stay NaN
            mean_by_pair = results_df.groupby([param_x, param_y])["Sharpe Ratio"].mean()
            pivot = mean_by_pair.unstack(fill_value=np.nan)

            cmap = plt.cm.viridis.copy()
            cmap.set_bad(color="white")  # NaN cells are drawn white

            plt.figure(figsize=(5, 4))
            ax = plt.gca()
            heatmap = ax.imshow(
                pivot.values, cmap=cmap, aspect="auto", interpolation="nearest"
            )
            ax.set_xticks(np.arange(len(pivot.columns)))
            ax.set_yticks(np.arange(len(pivot.index)))
            ax.set_xticklabels(pivot.columns, rotation=45)
            ax.set_yticklabels(pivot.index)
            ax.set_xlabel(param_y)
            ax.set_ylabel(param_x)
            ax.set_title(f"Heatmap of Sharpe Ratio: {param_x} vs {param_y}")
            plt.colorbar(heatmap, ax=ax)

            image_file = os.path.join(png_directory, f"{param_x}_{param_y}_Sharpe_Ratio_heatmap.png")
            plt.savefig(image_file, bbox_inches="tight", dpi=150)
            plt.close()

            picture = Image(image_file)
            picture.anchor = f"A{anchor_row}"
            sheet.add_image(picture)
            anchor_row += 30

        # Remove the default sheet that Workbook() creates
        if "Sheet" in workbook.sheetnames:
            workbook.remove(workbook["Sheet"])
        workbook.save(output_file)
        print(f"Heatmaps saved to {output_file}")

    def delete_png_files(self, png_directory: str) -> None:
        """
        Remove every ``.png`` file found directly inside a directory.

        Args:
            png_directory: Directory containing PNG files
        """
        for entry in os.listdir(png_directory):
            if not entry.endswith(".png"):
                continue
            os.remove(os.path.join(png_directory, entry))

# In [None]:
class WalkForwardOptimizer:
    """
    Performs walk-forward optimization for trading strategy hyperparameters using HyperParameterOptimizer.

    Attributes:
        legs (dict): Trading legs configuration.
        hyperparameter_grid (dict): Grid of hyperparameters to optimize.
        duckdb: DuckDB connection object.
        dates (list): Dates for backtesting, in order.
        in_sample_ratio (float): Proportion of data for in-sample testing.
        out_sample_ratio (float): Proportion of data for out-of-sample testing.
        total_dates (int): Number of dates.
        in_sample_size (int): Dates per in-sample window (``int(total * ratio)``).
        out_sample_size (int): Dates per out-sample window; also the step
            between consecutive windows.
    """
    def __init__(self, legs: Dict, hyperparameter_grid: Dict[str, List[float]], duckdb: duckdb.DuckDBPyConnection, 
                 dates: pd.Series, in_sample_ratio: float = 0.6/4, out_sample_ratio: float = 0.2/4):
        """
        Initialize the WalkForwardOptimizer with necessary parameters.

        Raises:
            ValueError: If either window would contain no dates, or the two
                windows together exceed the number of dates.
        """
        self.legs = legs
        self.hyperparameter_grid = hyperparameter_grid
        self.duckdb = duckdb
        # Accept a Series (as documented) or any other iterable of dates
        self.dates = dates.tolist() if hasattr(dates, "tolist") else list(dates)
        self.in_sample_ratio = in_sample_ratio
        self.out_sample_ratio = out_sample_ratio
        self.total_dates = len(self.dates)
        self.in_sample_size = int(self.total_dates * in_sample_ratio)
        self.out_sample_size = int(self.total_dates * out_sample_ratio)

        # A zero-sized out-sample window would become a zero step in range(),
        # and a zero-sized in-sample window would leave nothing to optimize on;
        # fail here with a clear message instead of deep inside optimize().
        if self.in_sample_size < 1 or self.out_sample_size < 1:
            raise ValueError(
                f"Not enough dates ({self.total_dates}) for in-sample ratio {in_sample_ratio} "
                f"and out-sample ratio {out_sample_ratio}: each window needs at least one date"
            )

        if self.in_sample_size + self.out_sample_size > self.total_dates:
            raise ValueError("In-sample + out-sample size exceeds total dates")

    def optimize(self) -> Tuple[Dict, float, List[Dict]]:
        """
        Perform walk-forward optimization using HyperParameterOptimizer.

        Windows advance by ``out_sample_size`` dates. For each window a grid
        search runs on the in-sample dates and the winning parameters are
        backtested on the following out-sample dates. Per-window CSV files and
        a final ``wfo_results.json`` are written to the current directory.

        Returns:
            tuple: (best_params, avg_out_sample_performance, results)
                - best_params: In-sample parameters of the window with the
                  highest out-sample performance.
                - avg_out_sample_performance: Average Sharpe ratio from out-sample tests.
                - results: List of dictionaries with window, in-sample, and out-sample results.

        Raises:
            ValueError: If no window produced a result.
        """
        results = []
        out_sample_performances = []

        # Define the constraint function for hyperparameters
        def constraint(params: Dict) -> bool:
            return params["upper_gamma"] > params["upper_buffer"] > 0 > params["lower_buffer"] > params["lower_gamma"]

        window_span = self.in_sample_size + self.out_sample_size
        for start in range(0, self.total_dates - window_span + 1, self.out_sample_size):
            split = start + self.in_sample_size
            in_sample_dates = self.dates[start:split]
            out_sample_dates = self.dates[split:split + self.out_sample_size]

            print(f"Processing window: In-sample {in_sample_dates[0]} to {in_sample_dates[-1]}, "
                  f"Out-sample {out_sample_dates[0]} to {out_sample_dates[-1]}")

            # Initialize HyperParameterOptimizer for in-sample data
            optimizer = HyperParameterOptimizer(self.legs, self.duckdb, pd.Series(in_sample_dates))

            try:
                # Perform grid search on in-sample data
                best_in_sample_params, best_in_sample_performance, results_df = optimizer.optimize(
                    hyperparameter_grid=self.hyperparameter_grid,
                    maximize='Sharpe Ratio',
                    method='grid',
                    constraint=constraint
                )

                # Debug: Log in-sample results
                print(f"In-sample best params: {best_in_sample_params}, Sharpe Ratio: {best_in_sample_performance}")
                results_df.to_csv(f'in_sample_results_window_{start}.csv')

            except (duckdb.ConnectionException, ValueError) as e:
                print(f"Error during in-sample optimization: {e}")
                continue

            # Test best parameters on out-sample data
            trader = TradeManager()
            try:
                out_sample_performance = backtest(self.legs, best_in_sample_params, self.duckdb, trader, pd.Series(out_sample_dates))

                # Debug: Log out-sample tradebook
                tradebook = trader.build_tradebook()
                print(f"Out-sample: {len(tradebook)} trades, Sharpe Ratio: {out_sample_performance}")
                tradebook.to_csv(f'out_sample_tradebook_window_{start}.csv')

            except duckdb.ConnectionException as e:
                print(f"Connection error during out-sample testing: {e}")
                continue

            results.append({
                'window': (in_sample_dates[0], out_sample_dates[-1]),
                'in_sample_params': best_in_sample_params,
                'in_sample_performance': best_in_sample_performance,
                'out_sample_performance': out_sample_performance
            })
            out_sample_performances.append(out_sample_performance)

        if not results:
            raise ValueError("No valid results from walk-forward optimization")

        avg_out_sample_performance = sum(out_sample_performances) / len(out_sample_performances)
        # The reported parameters come from the window that did best out of sample
        best_result = max(results, key=lambda x: x['out_sample_performance'])
        best_params = best_result['in_sample_params']

        # Save results to JSON
        wfo_output = {
            'best_params': best_params,
            'avg_out_sample_sharpe_ratio': avg_out_sample_performance,
            'results': [
                {
                    'window': (str(result['window'][0]), str(result['window'][1])),
                    'in_sample_params': result['in_sample_params'],
                    'in_sample_sharpe_ratio': result['in_sample_performance'],
                    'out_sample_sharpe_ratio': result['out_sample_performance']
                } for result in results
            ]
        }
        with open('wfo_results.json', 'w') as f:
            json.dump(wfo_output, f, indent=4)

        return best_params, avg_out_sample_performance, results

# In [None]:
import traceback

if __name__ == "__main__":
    # Script entry point: open the DuckDB file, derive the list of trading
    # dates from its per-day table names, run walk-forward optimization over
    # those dates and print the outcome. All failures are reported on stdout;
    # the connection is always closed in the `finally` clause.
    db_path = "nifty_1min_desiquant.duckdb"

    # Positional window (start inclusive, stop exclusive) into the sorted
    # table list. Named once so the slice and its error message stay in sync.
    table_slice_start, table_slice_stop = 90, 703

    conn = None
    try:
        if not os.path.exists(db_path):
            raise FileNotFoundError(f"Database file '{db_path}' not found in {os.getcwd()}")
        conn = duckdb.connect(db_path)

        # Fetch table names.
        # SQL gives no row-order guarantee without ORDER BY, and the window
        # below is positional, so sort explicitly to make it deterministic.
        table_names = conn.execute(
            "SELECT table_name FROM information_schema.tables "
            "WHERE table_schema = 'main' ORDER BY table_name"
        ).fetchdf()
        if table_names.empty:
            raise ValueError("No tables found in the database")

        # Debug: Print table names
        print("Table names DataFrame columns:", table_names.columns.tolist())
        print("First few rows of table_names:")
        print(table_names.head())

        # Restrict to the configured positional window.
        # NOTE: the window is applied before the regex filter, so any table
        # that does not match the pattern still occupies a position in it.
        table_names = table_names.iloc[table_slice_start:table_slice_stop]

        # Keep only tables whose name starts with nifty_YYYY_MM_DD
        # (str.match anchors at the start of the string only).
        pattern = re.compile(r'nifty_\d{4}_\d{2}_\d{2}')
        table_names = table_names[table_names['table_name'].str.match(pattern)]
        if table_names.empty:
            raise ValueError(
                f"No dates available in the specified range [{table_slice_start}:{table_slice_stop}]"
            )

        # Convert table names to dates; the original row labels are kept as
        # the Series index.
        dates = pd.Series(
            [parse_table_name(name) for name in table_names['table_name']],
            index=table_names.index,
        )
        # Presumably parse_table_name yields a null value for names it cannot
        # parse -- verify against its definition; such entries are dropped.
        dates = dates.dropna()
        if dates.empty:
            raise ValueError("No valid dates parsed from table names")
        dates = dates.apply(lambda x: x.strftime('%Y-%m-%d'))  # Convert to string format for consistency

        # Four at-the-money option legs: a CE/PE pair on the weekly expiry
        # and a CE/PE pair on the monthly expiry, none with stop-loss or
        # take-profit configured.
        legs = {
            'leg1': {'type': 'CE', 'expiry_type': 'weekly', 'expiry_range': [12, 20], 'target_strike': 'ATM', 'stop_loss': None, 'take_profit': None},
            'leg2': {'type': 'PE', 'expiry_type': 'weekly', 'expiry_range': [12, 20], 'target_strike': 'ATM', 'stop_loss': None, 'take_profit': None},
            'leg3': {'type': 'CE', 'expiry_type': 'monthly', 'expiry_range': [26, 34], 'target_strike': 'ATM', 'stop_loss': None, 'take_profit': None},
            'leg4': {'type': 'PE', 'expiry_type': 'monthly', 'expiry_range': [26, 34], 'target_strike': 'ATM', 'stop_loss': None, 'take_profit': None},
        }

        # Candidate values handed to the optimizer for each hyperparameter.
        hyperparameter_grid = {
            "upper_gamma": [0.0, 0.02, 0.04],
            "upper_buffer": [-0.02, 0.00, 0.02],
            "lower_buffer": [-0.03, -0.06, -0.09],
            "lower_gamma": [-0.05, -0.08, -0.11],
        }

        # Perform Walk-Forward Optimization
        print("Running Walk-Forward Optimization...")
        optimizer = WalkForwardOptimizer(
            legs,
            hyperparameter_grid,
            conn,
            dates,
            in_sample_ratio=0.6 / 4,
            out_sample_ratio=0.2 / 4,
        )
        best_params, avg_performance, wfo_results = optimizer.optimize()

        print(f"Best parameters: {best_params}")
        print(f"Average out-of-sample Sharpe Ratio: {avg_performance}")
        print("WFO results saved to 'wfo_results.json'")

    except FileNotFoundError as e:
        print(f"Error: {e}")
        print("Please ensure the database file exists in the current working directory or provide the correct path.")
    except PermissionError:
        print(f"Error: Permission denied accessing '{db_path}'")
        print("Please check file permissions and ensure you have read/write access.")
    except duckdb.IOException as e:
        print(f"Error: Failed to open database: {e}")
        print("The database file may be corrupted or locked. Try restoring from a backup or checking for open connections.")
    except duckdb.ConnectionException as e:
        print(f"Connection error: {e}")
        print("The database connection was closed unexpectedly. Check for connection timeouts or resource issues.")
    except Exception as e:
        # Top-level boundary: report anything unanticipated with a full
        # traceback instead of letting the script die silently.
        print(f"Unexpected error: {e}")
        traceback.print_exc()
    finally:
        # conn stays None when the failure happened before/during connect.
        if conn is not None:
            try:
                conn.close()
                print("Database connection closed successfully.")
            except Exception as e:
                print(f"Error closing connection: {e}")