In [1]:
%pip install -U yfinance pandas



In [3]:
import numpy as np
import pandas as pd
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Tuple, Any, Union
from collections import defaultdict, deque
import itertools
import time
import warnings
warnings.filterwarnings('ignore')

import numpy as np
from abc import ABC, abstractmethod
from typing import Tuple, Dict, List
from sklearn.linear_model import LinearRegression
from scipy import stats

import importlib
from pathlib import Path
from typing import List, Tuple, Dict
import numpy as np
import pandas as pd
import io



import numpy as np
from dataclasses import dataclass
from typing import Dict, List, Tuple
from abc import ABC, abstractmethod
from enum import Enum
from sklearn.linear_model import LinearRegression




class OrderSide(Enum):
    BUY = "BUY"
    SELL = "SELL"

class OrderType(Enum):
    MARKET = "MARKET"

@dataclass
class Order:
    symbol: str
    side: OrderSide
    quantity: float
    order_type: OrderType = OrderType.MARKET

@dataclass
class Position:
    symbol: str
    quantity: float
    entry_price: float

class QuantStrategy(ABC):
    def __init__(self, symbol_pair: Tuple[str, str], **params):
        self.symbol_pair = symbol_pair
        self.symbol1, self.symbol2 = symbol_pair
        self.params = params
        self.position_state = 0  # 0: flat, 1: long spread, -1: short spread
        self.positions: Dict[str, Position] = {}
        self.entry_hedge_ratio: Optional[float] = None

    @abstractmethod
    def calculate_signals(self, prices_dict: Dict[str, np.ndarray],
                         current_time: int) -> List[Order]:
        pass

    @abstractmethod
    def get_strategy_name(self) -> str:
        pass

    def safe_divide(self, a, b):
        return a / (b + 1e-10)

    def safe_std(self, arr):
        return max(np.std(arr), 1e-8)

    def calculate_hedge_ratio(self, p1: np.ndarray, p2: np.ndarray) -> float:
        try:
            if len(p1) < 10 or len(p2) < 10:
                return 1.0
            X = p2.reshape(-1, 1)
            y = p1
            model = LinearRegression().fit(X, y)
            hedge_ratio = model.coef_[0]
            if np.isnan(hedge_ratio) or np.isinf(hedge_ratio):
                return 1.0
            return max(0.1, min(5.0, abs(hedge_ratio)))
        except Exception as e:
            return 1.0

    def is_flat(self) -> bool:
        return len(self.positions) == 0

class EnhancedMeanReversionStrategy(QuantStrategy):
    def __init__(self, symbol_pair: Tuple[str, str], **params):
        super().__init__(symbol_pair, **params)
        self.lookback = params.get('lookback', 60)
        self.entry_threshold = params.get('entry_threshold', 2.0)
        self.exit_threshold = params.get('exit_threshold', 0.5)
        self.stop_loss = params.get('stop_loss', 0.03)
        self.max_holding = params.get('max_holding', 200)
        self.entry_price = 0
        self.entry_time = 0
        self.min_correlation = params.get('min_correlation', 0.5)
        self.cointegration_check = params.get('cointegration_check', True)
        self.dynamic_hedge_ratio = params.get('dynamic_hedge_ratio', True)

    def check_cointegration(self, p1: np.ndarray, p2: np.ndarray, hedge_ratio: float) -> bool:
        try:
            if len(p1) < 30:
                return True
            residuals = p1 - hedge_ratio * p2
            adf_stat, p_value, _, _, critical_values, _ = stats.adfuller(residuals, maxlag=10)
            return adf_stat < critical_values.get('10%', -2.58)
        except:
            return True

    def calculate_signals(self, prices_dict: Dict[str, np.ndarray], current_time: int) -> List[Order]:
        p1 = prices_dict.get(self.symbol1, np.array([]))
        p2 = prices_dict.get(self.symbol2, np.array([]))

        if len(p1) < self.lookback or len(p2) < self.lookback:
            return []

        p1_window = p1[-self.lookback:]
        p2_window = p2[-self.lookback:]

        hedge_ratio = self.calculate_hedge_ratio(p1_window, p2_window)

        correlation = np.corrcoef(p1_window, p2_window)[0, 1]
        if abs(correlation) < self.min_correlation:
            if self.position_state != 0:
                return self._generate_exit_orders(hedge_ratio)
            return []

        if self.cointegration_check:
            if not self.check_cointegration(p1_window, p2_window, hedge_ratio):
                if self.position_state != 0:
                    return self._generate_exit_orders(hedge_ratio)
                return []

        spread = p1_window - hedge_ratio * p2_window
        if len(spread) > 20:
            stat_window = min(40, len(spread))
            mean_spread = np.mean(spread[-stat_window:])
            std_spread = self.safe_std(spread[-stat_window:])
        else:
            mean_spread = np.mean(spread)
            std_spread = self.safe_std(spread)

        current_spread = spread[-1]
        z_score = (current_spread - mean_spread) / std_spread

        orders = []
        if self.position_state == 0:
            if abs(z_score) > self.entry_threshold:
                self.entry_price = current_spread
                self.entry_time = current_time
                base_size = 100
                hedge_size = max(1, int(base_size * hedge_ratio))
                if z_score > 0:
                    orders = [
                        Order(self.symbol1, OrderSide.SELL, base_size),
                        Order(self.symbol2, OrderSide.BUY, hedge_size)
                    ]
                    self.position_state = -1
                else:
                    orders = [
                        Order(self.symbol1, OrderSide.BUY, base_size),
                        Order(self.symbol2, OrderSide.SELL, hedge_size)
                    ]
                    self.position_state = 1
        elif self.position_state != 0:
            current_spread_val = current_spread
            if abs(self.entry_price) > 1e-10:
                pnl_pct = (current_spread_val - self.entry_price) / abs(self.entry_price) * self.position_state
            else:
                pnl_pct = 0
            time_held = current_time - self.entry_time
            should_exit = False
            if abs(z_score) < self.exit_threshold:
                should_exit = True
            elif abs(pnl_pct) > self.stop_loss and pnl_pct * self.position_state < 0:
                should_exit = True
            elif time_held > self.max_holding:
                should_exit = True
            elif abs(z_score) > self.entry_threshold * 2:
                should_exit = True
            if should_exit:
                orders = self._generate_exit_orders(hedge_ratio)
                self.position_state = 0

        return orders

    def _generate_exit_orders(self, hedge_ratio: float) -> List[Order]:
        base_size = 100
        hedge_size = max(1, int(base_size * hedge_ratio))

        if self.position_state == 1:
            return [
                Order(self.symbol1, OrderSide.SELL, base_size),
                Order(self.symbol2, OrderSide.BUY, hedge_size)
            ]
        elif self.position_state == -1:
            return [
                Order(self.symbol1, OrderSide.BUY, base_size),
                Order(self.symbol2, OrderSide.SELL, hedge_size)
            ]
        return []

    def get_strategy_name(self) -> str:
        return f"MeanRev_{self.entry_threshold}_{self.exit_threshold}_{self.stop_loss}"

class AdaptiveMeanReversionStrategy(EnhancedMeanReversionStrategy):
    def __init__(self, symbol_pair: Tuple[str, str], **params):
        super().__init__(symbol_pair, **params)
        self.volatility_lookback = params.get('volatility_lookback', 20)
        self.adaptive_thresholds = params.get('adaptive_thresholds', True)

    def calculate_adaptive_thresholds(self, spread: np.ndarray) -> Tuple[float, float]:
        if len(spread) < self.volatility_lookback:
            return self.entry_threshold, self.exit_threshold
        recent_volatility = np.std(spread[-self.volatility_lookback:])
        long_term_volatility = np.std(spread)
        if long_term_volatility == 0:
            vol_ratio = 1.0
        else:
            vol_ratio = recent_volatility / long_term_volatility
        if vol_ratio > 1.5:
            entry_thresh = self.entry_threshold * 0.8
            exit_thresh = self.exit_threshold * 1.2
        elif vol_ratio < 0.7:
            entry_thresh = self.entry_threshold * 1.2
            exit_thresh = self.exit_threshold * 0.8
        else:
            entry_thresh = self.entry_threshold
            exit_thresh = self.exit_threshold
        return entry_thresh, exit_thresh

    def calculate_signals(self, prices_dict: Dict[str, np.ndarray], current_time: int) -> List[Order]:
        if self.adaptive_thresholds:
            p1 = prices_dict.get(self.symbol1, np.array([]))
            p2 = prices_dict.get(self.symbol2, np.array([]))

            if len(p1) >= self.lookback and len(p2) >= self.lookback:
                hedge_ratio = self.calculate_hedge_ratio(p1[-self.lookback:], p2[-self.lookback:])
                spread = p1[-self.lookback:] - hedge_ratio * p2[-self.lookback:]
                original_entry = self.entry_threshold
                original_exit = self.exit_threshold
                self.entry_threshold, self.exit_threshold = self.calculate_adaptive_thresholds(spread)
                orders = super().calculate_signals(prices_dict, current_time)
                self.entry_threshold = original_entry
                self.exit_threshold = original_exit
                return orders
        return super().calculate_signals(prices_dict, current_time)

    def get_strategy_name(self) -> str:
        return f"AdaptiveMeanRev_{self.entry_threshold}_{self.exit_threshold}_{self.stop_loss}"

class MomentumStrategy(QuantStrategy):
    def __init__(self, symbol_pair: Tuple[str, str], **params):
        super().__init__(symbol_pair, **params)
        self.lookback = params.get('lookback', 40)
        self.momentum_window = params.get('momentum_window', 10)
        self.entry_threshold = params.get('entry_threshold', 0.02)
        self.exit_threshold = params.get('exit_threshold', 0.005)
        self.base_quantity = params.get('base_quantity', 100)

    def calculate_signals(self, prices_dict: Dict[str, np.ndarray],
                         current_time: int) -> List[Order]:
        p1 = prices_dict.get(self.symbol1, np.array([]))
        p2 = prices_dict.get(self.symbol2, np.array([]))
        if len(p1) < self.lookback or len(p2) < self.lookback:
            return []
        current_hedge_ratio = self.calculate_hedge_ratio(p1[-self.lookback:], p2[-self.lookback:])
        spread = p1[-self.lookback:] - current_hedge_ratio * p2[-self.lookback:]
        if len(spread) < self.momentum_window + 1:
            return []
        recent_returns = np.diff(spread[-self.momentum_window-1:])
        if len(recent_returns) == 0:
            return []
        momentum_mean = np.mean(recent_returns)
        momentum_std = max(np.std(recent_returns), 1e-8)
        momentum_score = momentum_mean / momentum_std
        orders = []
        if self.is_flat() and abs(momentum_score) > self.entry_threshold:
            self.entry_hedge_ratio = current_hedge_ratio
            hedge_quantity = int(self.base_quantity * self.entry_hedge_ratio)

            if momentum_score > 0:
                orders = [
                    Order(self.symbol1, OrderSide.BUY, self.base_quantity),
                    Order(self.symbol2, OrderSide.SELL, hedge_quantity)
                ]
                self.positions[self.symbol1] = Position(self.symbol1, self.base_quantity, p1[-1])
                self.positions[self.symbol2] = Position(self.symbol2, -hedge_quantity, p2[-1])

            else:
                orders = [
                    Order(self.symbol1, OrderSide.SELL, self.base_quantity),
                    Order(self.symbol2, OrderSide.BUY, hedge_quantity)
                ]
                self.positions[self.symbol1] = Position(self.symbol1, -self.base_quantity, p1[-1])
                self.positions[self.symbol2] = Position(self.symbol2, hedge_quantity, p2[-1])
        elif not self.is_flat() and abs(momentum_score) < self.exit_threshold:
            if self.entry_hedge_ratio is not None:
                pos1 = self.positions.get(self.symbol1)
                pos2 = self.positions.get(self.symbol2)

                if pos1 and pos2:
                    if pos1.quantity > 0:
                        orders = [
                            Order(self.symbol1, OrderSide.SELL, abs(pos1.quantity)),
                            Order(self.symbol2, OrderSide.BUY, abs(pos2.quantity))
                        ]
                    else:
                        orders = [
                            Order(self.symbol1, OrderSide.BUY, abs(pos1.quantity)),
                            Order(self.symbol2, OrderSide.SELL, abs(pos2.quantity))
                        ]
                    self.positions.clear()
                    self.entry_hedge_ratio = None
        return orders

    def get_strategy_name(self) -> str:
        return f"FixedMomentum_{self.momentum_window}_{self.entry_threshold}"

class VolatilityMeanReversionStrategy(QuantStrategy):
    def __init__(self, symbol_pair: Tuple[str, str], **params):
        super().__init__(symbol_pair, **params)
        self.lookback = params.get('lookback', 100)
        self.vol_window_short = params.get('vol_window_short', 10)
        self.vol_window_long = params.get('vol_window_long', 30)
        self.vol_threshold = params.get('vol_threshold', 1.8)
        self.price_threshold = params.get('price_threshold', 2.0)
        self.exit_vol_threshold = params.get('exit_vol_threshold', 1.2)
        self.base_quantity = params.get('base_quantity', 100)

    def calculate_realized_volatility(self, prices: np.ndarray, window: int) -> float:
        if len(prices) < window + 1:
            return 0.0
        log_returns = np.diff(np.log(prices[-window-1:]))
        if len(log_returns) == 0:
            return 0.0
        return np.std(log_returns) * np.sqrt(252)

    def calculate_ewma_volatility(self, returns: np.ndarray, lambda_param: float = 0.94) -> float:
        if len(returns) < 5:
            return np.std(returns) if len(returns) > 1 else 0.0
        weights = np.array([(1-lambda_param) * (lambda_param**i)
                           for i in range(len(returns))][::-1])
        weights = weights / weights.sum()
        weighted_var = np.sum(weights * returns**2)
        return np.sqrt(weighted_var * 252)

    def calculate_signals(self, prices_dict: Dict[str, np.ndarray],
                         current_time: int) -> List[Order]:
        p1 = prices_dict.get(self.symbol1, np.array([]))
        p2 = prices_dict.get(self.symbol2, np.array([]))

        if len(p1) < self.lookback or len(p2) < self.lookback:
            return []
        current_hedge_ratio = self.calculate_hedge_ratio(p1[-self.lookback:], p2[-self.lookback:])
        spread = p1[-self.lookback:] - current_hedge_ratio * p2[-self.lookback:]

        if len(spread) < max(self.vol_window_long, self.vol_window_short) + 10:
            return []
        spread_returns = np.diff(np.log(np.abs(spread) + 1e-8))
        recent_vol = self.calculate_ewma_volatility(spread_returns[-self.vol_window_short:])
        historical_vol = np.std(spread_returns[-self.vol_window_long:-self.vol_window_short]) * np.sqrt(252)
        if historical_vol < 1e-8:
            return []

        vol_ratio = recent_vol / historical_vol
        rolling_window = min(50, len(spread) // 2)
        rolling_mean = np.mean(spread[-rolling_window:])
        rolling_std = max(np.std(spread[-rolling_window:]), 1e-8)
        z_score = (spread[-1] - rolling_mean) / rolling_std
        orders = []
        if self.is_flat() and  vol_ratio > self.vol_threshold and abs(z_score) > self.price_threshold:
            self.entry_hedge_ratio = current_hedge_ratio
            hedge_quantity = int(self.base_quantity * self.entry_hedge_ratio)
            if z_score > 0:
                orders = [
                    Order(self.symbol1, OrderSide.SELL, self.base_quantity),
                    Order(self.symbol2, OrderSide.BUY, hedge_quantity)
                ]
                self.positions[self.symbol1] = Position(self.symbol1, -self.base_quantity, p1[-1])
                self.positions[self.symbol2] = Position(self.symbol2, hedge_quantity, p2[-1])
            else:
                orders = [
                    Order(self.symbol1, OrderSide.BUY, self.base_quantity),
                    Order(self.symbol2, OrderSide.SELL, hedge_quantity)
                ]
                self.positions[self.symbol1] = Position(self.symbol1, self.base_quantity, p1[-1])
                self.positions[self.symbol2] = Position(self.symbol2, -hedge_quantity, p2[-1])
        elif not self.is_flat() and (
            vol_ratio < self.exit_vol_threshold or
            abs(z_score) < 0.5
        ):
            if self.entry_hedge_ratio is not None:
                pos1 = self.positions.get(self.symbol1)
                pos2 = self.positions.get(self.symbol2)

                if pos1 and pos2:
                    if pos1.quantity > 0:
                        orders = [
                            Order(self.symbol1, OrderSide.SELL, abs(pos1.quantity)),
                            Order(self.symbol2, OrderSide.BUY, abs(pos2.quantity))
                        ]
                    else:
                        orders = [
                            Order(self.symbol1, OrderSide.BUY, abs(pos1.quantity)),
                            Order(self.symbol2, OrderSide.SELL, abs(pos2.quantity))
                        ]
                    self.positions.clear()
                    self.entry_hedge_ratio = None
        return orders

    def get_strategy_name(self) -> str:
        return f"VolMeanRev_{self.vol_threshold}_{self.price_threshold}"

class VolatilityBreakoutStrategy(QuantStrategy):
    def __init__(self, symbol_pair: Tuple[str, str], **params):
        super().__init__(symbol_pair, **params)
        self.lookback = params.get('lookback', 100)
        self.vol_window = params.get('vol_window', 15)
        self.vol_threshold = params.get('vol_threshold', 1.3)
        self.momentum_threshold = params.get('momentum_threshold', 0.01)
        self.base_quantity = params.get('base_quantity', 100)

    def calculate_signals(self, prices_dict: Dict[str, np.ndarray],
                         current_time: int) -> List[Order]:
        p1 = prices_dict.get(self.symbol1, np.array([]))
        p2 = prices_dict.get(self.symbol2, np.array([]))

        if len(p1) < self.lookback or len(p2) < self.lookback:
            return []

        current_hedge_ratio = self.calculate_hedge_ratio(p1[-self.lookback:], p2[-self.lookback:])
        spread = p1[-self.lookback:] - current_hedge_ratio * p2[-self.lookback:]

        if len(spread) < self.vol_window * 2:
            return []

        spread_returns = np.diff(spread)
        recent_vol = np.std(spread_returns[-self.vol_window:])
        historical_vol = np.std(spread_returns[-self.vol_window*2:-self.vol_window])

        if historical_vol < 1e-8:
            return []

        vol_ratio = recent_vol / historical_vol
        momentum = np.mean(spread_returns[-5:])

        orders = []
        if self.is_flat() and vol_ratio > self.vol_threshold and abs(momentum) > self.momentum_threshold:
            self.entry_hedge_ratio = current_hedge_ratio
            hedge_quantity = int(self.base_quantity * self.entry_hedge_ratio)
            if momentum > 0:
                orders = [
                    Order(self.symbol1, OrderSide.BUY, self.base_quantity),
                    Order(self.symbol2, OrderSide.SELL, hedge_quantity)
                ]
                self.positions[self.symbol1] = Position(self.symbol1, self.base_quantity, p1[-1])
                self.positions[self.symbol2] = Position(self.symbol2, -hedge_quantity, p2[-1])
            else:
                orders = [
                    Order(self.symbol1, OrderSide.SELL, self.base_quantity),
                    Order(self.symbol2, OrderSide.BUY, hedge_quantity)
                ]
                self.positions[self.symbol1] = Position(self.symbol1, -self.base_quantity, p1[-1])
                self.positions[self.symbol2] = Position(self.symbol2, hedge_quantity, p2[-1])
        elif not self.is_flat() and vol_ratio < 1.1:
            pos1 = self.positions.get(self.symbol1)
            pos2 = self.positions.get(self.symbol2)

            if pos1 and pos2:
                if pos1.quantity > 0:
                    orders = [
                        Order(self.symbol1, OrderSide.SELL, abs(pos1.quantity)),
                        Order(self.symbol2, OrderSide.BUY, abs(pos2.quantity))
                    ]
                else:
                    orders = [
                        Order(self.symbol1, OrderSide.BUY, abs(pos1.quantity)),
                        Order(self.symbol2, OrderSide.SELL, abs(pos2.quantity))
                    ]
                self.positions.clear()
                self.entry_hedge_ratio = None

        return orders

    def get_strategy_name(self) -> str:
        return f"VolBreakout_{self.vol_threshold}_{self.momentum_threshold}"


class FastPortfolio:
    def __init__(self, initial_capital: float = 1000000):
        self.initial_capital = initial_capital
        self.cash = initial_capital
        self.positions = defaultdict(float)
        self.trades = []
        self.equity_history = []

    def execute_order(self, order: Order, price: float, timestamp: int):
        """Execute order with transaction costs"""
        cost_rate = 0.001  # 10 bps

        if order.side == OrderSide.BUY:
            total_cost = order.quantity * price * (1 + cost_rate)
            if self.cash >= total_cost:
                self.cash -= total_cost
                self.positions[order.symbol] += order.quantity
                self.trades.append({
                    'time': timestamp,
                    'symbol': order.symbol,
                    'side': 'BUY',
                    'qty': order.quantity,
                    'price': price
                })
        else:  # SELL
            if self.positions[order.symbol] >= order.quantity:
                proceeds = order.quantity * price * (1 - cost_rate)
                self.cash += proceeds
                self.positions[order.symbol] -= order.quantity
                self.trades.append({
                    'time': timestamp,
                    'symbol': order.symbol,
                    'side': 'SELL',
                    'qty': order.quantity,
                    'price': price
                })

    def update_value(self, prices: Dict[str, float], timestamp: int):
        """Update portfolio value"""
        total_value = self.cash
        for symbol, position in self.positions.items():
            if symbol in prices and position != 0:
                total_value += position * prices[symbol]

        self.equity_history.append((timestamp, total_value))

    def get_metrics(self) -> Dict[str, float]:
        """Calculate performance metrics"""
        if len(self.equity_history) < 10:
            return {
                'total_return': 0.0,
                'sharpe_ratio': 0.0,
                'max_drawdown': 0.0,
                'final_value': self.initial_capital,
                'num_trades': len(self.trades)
            }

        values = [v[1] for v in self.equity_history]
        returns = [(values[i] - values[i-1]) / values[i-1] for i in range(1, len(values))]

        # Performance calculations
        total_return = (values[-1] - values[0]) / values[0]

        if len(returns) > 1:
            mean_ret = np.mean(returns)
            std_ret = np.std(returns) + 1e-8
            sharpe_ratio = mean_ret / std_ret * np.sqrt(252)
        else:
            sharpe_ratio = 0.0

        # Max drawdown calculation
        peak = values[0]
        max_dd = 0.0
        for value in values:
            peak = max(peak, value)
            drawdown = (value - peak) / peak
            max_dd = min(max_dd, drawdown)

        return {
            'total_return': total_return,
            'sharpe_ratio': sharpe_ratio,
            'max_drawdown': max_dd,
            'final_value': values[-1],
            'num_trades': len(self.trades)
        }

def run_backtest(strategy: QuantStrategy, data: Dict[str, np.ndarray],
                timestamps: np.ndarray) -> Dict[str, Any]:
    """Run backtest for a single strategy"""
    portfolio = FastPortfolio()
    lookback_buffer = max(200, strategy.lookback * 2)
    logged_errors = 0
    max_log_errors = 5

    for i in range(lookback_buffer, len(timestamps), 5):  # Skip every 5 points for speed
        # Current prices
        current_prices = {symbol: float(data[symbol][i]) for symbol in strategy.symbol_pair if symbol in data}

        # Price windows for strategy
        price_windows = {
            symbol: np.asarray(data[symbol][max(0, i-lookback_buffer):i+1], dtype=float)
            for symbol in strategy.symbol_pair if symbol in data
        }

        # Generate signals
        try:
            orders = strategy.calculate_signals(price_windows, i)

            # Execute orders
            for order in orders:
                if order.symbol in current_prices:
                    portfolio.execute_order(order, current_prices[order.symbol], i)
        except Exception as e:
            # Log a few errors (first few only) for debugging instead of silent skipping
            if logged_errors < max_log_errors:
                print(f"[run_backtest] Exception in calculate_signals for {strategy.get_strategy_name()} "
                      f"pair={strategy.symbol_pair} i={i}: {e}")
                logged_errors += 1
            continue  # keep running other time-steps / strategies

        # Update portfolio value
        portfolio.update_value(current_prices, i)

    # Get final metrics
    metrics = portfolio.get_metrics()
    metrics['strategy_name'] = strategy.get_strategy_name()
    metrics['params'] = strategy.params

    return metrics



# Optimization Engine
class StrategyOptimizer:
    def __init__(self, data: Dict[str, np.ndarray], timestamps: np.ndarray):
        self.data = data
        self.timestamps = timestamps
        self.symbols = list(data.keys())

    def run_optimization(self) -> pd.DataFrame:
        print(f"Optimizing strategies for {len(self.symbols)} assets...")

        pairs = list(itertools.combinations(self.symbols, 2))
        print(f"Testing {len(pairs)} pairs...")
        param_grids = self._get_parameter_grids()

        results = []
        total_tests = sum(len(params) * len(pairs) for params in param_grids.values())
        print(f"Total tests: {total_tests}")

        test_count = 0
        start_time = time.time()

        for strategy_class, param_list in param_grids.items():
            for params in param_list:
                for pair in pairs:
                    test_count += 1

                    if test_count % 50 == 0:
                        elapsed = time.time() - start_time
                        progress = test_count / total_tests
                        eta = elapsed / progress * (1 - progress) if progress > 0 else 0
                        print(f"Progress: {test_count}/{total_tests} ({progress:.1%}) - ETA: {eta/60:.1f}min")

                    try:
                        strategy = strategy_class(pair, **params)
                        result = run_backtest(strategy, self.data, self.timestamps)

                        result['strategy_class'] = strategy_class.__name__
                        result['symbol_pair'] = f"{pair[0]}_{pair[1]}"
                        result['pair_symbols'] = pair

                        result['combined_score'] = (
                            result['sharpe_ratio'] * 0.6 +
                            result['total_return'] * 0.4 +
                            max(0, result['max_drawdown']) * 0.2
                        )

                        results.append(result)

                    except Exception as e:
                        results.append({
                            'strategy_class': strategy_class.__name__,
                            'symbol_pair': f"{pair[0]}_{pair[1]}",
                            'total_return': 0.0,
                            'sharpe_ratio': 0.0,
                            'max_drawdown': 0.0,
                            'combined_score': -10.0,
                            'final_value': 1000000.0,
                            'num_trades': 0,
                            'error': str(e)
                        })
        df = pd.DataFrame(results)
        df = df.sort_values('combined_score', ascending=False)

        elapsed_total = time.time() - start_time
        print(f"\nOptimization completed in {elapsed_total/60:.1f} minutes")

        return df

    def _get_parameter_grids(self) -> Dict[type, List[Dict]]:
        mean_reversion_params = []
        for entry in [1.8, 2.2, 2.6]:
            for exit_thresh in [0.4, 0.6, 0.8]:
                for stop_loss in [0.02, 0.04]:
                    mean_reversion_params.append({
                        'lookback': 60,
                        'entry_threshold': entry,
                        'exit_threshold': exit_thresh,
                        'stop_loss': stop_loss,
                        'max_holding': 200,
                        'cointegration_check': True,
                        'min_correlation': 0.6
                    })

        adaptive_params = []
        for entry in [1.6, 2.0, 2.4]:
            for exit_thresh in [0.5, 0.7]:
                for vol_lookback in [15, 25]:
                    adaptive_params.append({
                        'lookback': 80,
                        'entry_threshold': entry,
                        'exit_threshold': exit_thresh,
                        'stop_loss': 0.03,
                        'max_holding': 150,
                        'volatility_lookback': vol_lookback,
                        'adaptive_thresholds': True,
                        'cointegration_check': True,
                        'min_correlation': 0.5
                    })
        momentum_params = []
        for mom_window in [8, 12, 16]:
            for entry in [0.015, 0.025, 0.035]:
                momentum_params.append({
                    'lookback': 50,
                    'momentum_window': mom_window,
                    'entry_threshold': entry,
                    'exit_threshold': entry * 0.3,
                    'base_quantity': 100
                })
        vol_mean_reversion_params = []
        for vol_thresh in [1.5, 1.8, 2.2]:
            for price_thresh in [1.8, 2.2]:
                vol_mean_reversion_params.append({
                    'lookback': 100,
                    'vol_window_short': 10,
                    'vol_window_long': 30,
                    'vol_threshold': vol_thresh,
                    'price_threshold': price_thresh,
                    'exit_vol_threshold': 1.2,
                    'base_quantity': 100
                })
        vol_breakout_params = []
        for vol_thresh in [1.2, 1.4, 1.6]:
            for mom_thresh in [0.008, 0.015, 0.025]:
                vol_breakout_params.append({
                    'lookback': 100,
                    'vol_window': 15,
                    'vol_threshold': vol_thresh,
                    'momentum_threshold': mom_thresh,
                    'base_quantity': 100
                })

        return {
            EnhancedMeanReversionStrategy: mean_reversion_params[:12],
            AdaptiveMeanReversionStrategy: adaptive_params[:8],
            MomentumStrategy: momentum_params[:9],
            VolatilityMeanReversionStrategy: vol_mean_reversion_params[:8],
            VolatilityBreakoutStrategy: vol_breakout_params[:6]
        }

# Analysis and Reporting
def analyze_results(df: pd.DataFrame) -> str:
    """Generate comprehensive results analysis"""
    report = []
    report.append("=" * 80)
    report.append("QUANTITATIVE STRATEGY OPTIMIZATION RESULTS")
    report.append("=" * 80)

    # Summary statistics
    total_strategies = len(df)
    profitable_strategies = len(df[df['total_return'] > 0])
    high_sharpe_strategies = len(df[df['sharpe_ratio'] > 1.0])

    report.append(f"\nSUMMARY:")
    report.append(f"Total strategies tested: {total_strategies}")
    report.append(f"Profitable strategies: {profitable_strategies} ({profitable_strategies/total_strategies:.1%})")
    report.append(f"High Sharpe (>1.0) strategies: {high_sharpe_strategies} ({high_sharpe_strategies/total_strategies:.1%})")

    # Top 10 strategies
    report.append(f"\nTOP 10 STRATEGIES:")
    report.append("-" * 80)
    top_10 = df.head(10)

    for idx, (_, row) in enumerate(top_10.iterrows()):
        report.append(f"\n#{idx + 1}. {row['strategy_class']} | {row['symbol_pair']}")
        report.append(f"    Return: {row['total_return']:.2%} | Sharpe: {row['sharpe_ratio']:.2f} | MaxDD: {row['max_drawdown']:.2%}")
        report.append(f"    Trades: {row['num_trades']} | Score: {row['combined_score']:.2f}")

        # Show key parameters
        if 'params' in row and isinstance(row['params'], dict):
            key_params = {k: v for k, v in list(row['params'].items())[:3]}  # Show first 3 params
            report.append(f"    Key Params: {key_params}")

    # Best by strategy type
    report.append(f"\nBEST STRATEGY OF EACH TYPE:")
    report.append("-" * 50)

    for strategy_type in df['strategy_class'].unique():
        best_of_type = df[df['strategy_class'] == strategy_type].iloc[0]
        report.append(f"\n{strategy_type}:")
        report.append(f"  Pair: {best_of_type['symbol_pair']}")
        report.append(f"  Return: {best_of_type['total_return']:.2%} | Sharpe: {best_of_type['sharpe_ratio']:.2f}")
        report.append(f"  MaxDD: {best_of_type['max_drawdown']:.2%} | Trades: {best_of_type['num_trades']}")

    # Performance distribution
    report.append(f"\nPERFORMANCE DISTRIBUTION:")
    report.append("-" * 30)

    returns = df['total_return']
    sharpes = df['sharpe_ratio']

    report.append(f"Returns - Mean: {returns.mean():.2%}, Std: {returns.std():.2%}")
    report.append(f"         Best: {returns.max():.2%}, Worst: {returns.min():.2%}")
    report.append(f"Sharpe  - Mean: {sharpes.mean():.2f}, Std: {sharpes.std():.2f}")
    report.append(f"         Best: {sharpes.max():.2f}, Worst: {sharpes.min():.2f}")

    # Best pairs
    report.append(f"\nBEST PERFORMING PAIRS:")
    report.append("-" * 25)

    pair_performance = df.groupby('symbol_pair').agg({
        'combined_score': 'max',
        'total_return': 'max',
        'sharpe_ratio': 'max'
    }).sort_values('combined_score', ascending=False).head(5)

    for pair, row in pair_performance.iterrows():
        report.append(f"{pair}: Score={row['combined_score']:.2f}, Return={row['total_return']:.2%}, Sharpe={row['sharpe_ratio']:.2f}")

    return "\n".join(report)



def load_market_data_strict_v2(symbols: List[str],
                               data_dir: str = "data",
                               start: str = "2018-01-01",
                               end: str = None,
                               use_adj_close: bool = True,
                               min_rows: int = 250
                              ) -> Tuple[Dict[str, np.ndarray], np.ndarray]:
    """
    Robust strict loader. Tries CSVs in data_dir, otherwise downloads via yfinance.
    On CSV parse issues it prints the CSV header / first rows to help debug.
    Raises RuntimeError when any requested symbol cannot be loaded with >= min_rows.
    """
    if end is None:
        end = pd.Timestamp.today().strftime("%Y-%m-%d")

    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    # Import yfinance safely
    try:
        yfinance = importlib.import_module("yfinance")
    except Exception as e:
        raise RuntimeError("yfinance is required to download missing symbols. Install with 'pip install yfinance'. "
                           f"Import error: {e}")

    loaded_series: Dict[str, pd.Series] = {}
    failures = []

    def inspect_csv(path: Path, nrows: int = 6) -> str:
        try:
            raw = pd.read_csv(path, dtype=str, nrows=nrows, header=0)
            buf = io.StringIO()
            raw.to_csv(buf, index=False)
            return buf.getvalue()
        except Exception as e:
            return f"<failed to read preview: {e}>"

    for sym in symbols:
        csv_path = data_dir / f"{sym}.csv"
        success = False

        # Try CSV first (if present)
        if csv_path.exists():
            # Attempt to read with parse_dates on index
            try:
                df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
            except Exception:
                # Try without parse_dates to inspect structure
                preview = inspect_csv(csv_path)
                failures.append((sym, f"CSV parse error; preview:\n{preview}"))
                df = None
            if isinstance(df, pd.DataFrame):
                # Pick best price column
                col = None
                if use_adj_close and "Adj Close" in df.columns:
                    col = "Adj Close"
                elif "Close" in df.columns:
                    col = "Close"
                else:
                    # choose last numeric column
                    numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
                    if numeric_cols:
                        col = numeric_cols[-1]
                if col is None:
                    preview = inspect_csv(csv_path)
                    failures.append((sym, f"No numeric price column found in CSV; preview:\n{preview}"))
                else:
                    series = df[col].dropna()
                    # Force index -> datetime safely
                    coerced_idx = pd.to_datetime(series.index, errors="coerce")
                    valid_mask = ~coerced_idx.isna()
                    if valid_mask.sum() < min_rows:
                        preview = inspect_csv(csv_path)
                        failures.append((sym, f"CSV parsed but only {valid_mask.sum()} parseable date rows (<{min_rows}); preview:\n{preview}"))
                    else:
                        series.index = coerced_idx
                        series = series.sort_index()
                        loaded_series[sym] = series.rename(sym)
                        success = True

        # If CSV either missing or failed, try download
        if not success:
            try:
                print(f"Downloading {sym} from {start} to {end}...")
                df_down = yfinance.download(sym, start=start, end=end, progress=False)
                if not isinstance(df_down, pd.DataFrame) or df_down.empty:
                    failures.append((sym, "download returned empty dataframe"))
                else:
                    # Save a cached CSV for inspection
                    try:
                        df_down.to_csv(csv_path)
                    except Exception:
                        pass
                    # choose column
                    if use_adj_close and "Adj Close" in df_down.columns:
                        col = "Adj Close"
                    elif "Close" in df_down.columns:
                        col = "Close"
                    else:
                        numeric_cols = [c for c in df_down.columns if pd.api.types.is_numeric_dtype(df_down[c])]
                        if not numeric_cols:
                            failures.append((sym, "downloaded but no numeric price column"))
                            continue
                        col = numeric_cols[-1]
                    series = df_down[col].dropna()
                    coerced_idx = pd.to_datetime(series.index, errors="coerce")
                    valid_mask = ~coerced_idx.isna()
                    if valid_mask.sum() < min_rows:
                        failures.append((sym, f"downloaded but only {valid_mask.sum()} parseable date rows (<{min_rows})"))
                    else:
                        series.index = coerced_idx
                        series = series.sort_index()
                        loaded_series[sym] = series.rename(sym)
                        success = True
            except Exception as e:
                failures.append((sym, f"download error: {repr(e)}"))

    # If any failures, raise a helpful error listing which symbols failed and why
    if failures:
        msgs = "\n".join([f" - {s}: {reason}" for s, reason in failures])
        raise RuntimeError(
            "One or more symbols failed to load. Details:\n" + msgs +
            "\n\nFix the CSV files in the data/ folder or ensure yfinance can download the tickers. "
            "Each CSV must have a date-like index (first column) and a 'Close' or 'Adj Close' numeric column."
        )

    # Align via inner join over the datetime index
    merged = pd.concat(list(loaded_series.values()), axis=1, join="inner").dropna(how="any")
    merged.index = pd.to_datetime(merged.index, errors="coerce")
    merged = merged[~merged.index.isna()]  # drop any rows that still couldn't parse

    if merged.shape[0] < min_rows:
        raise RuntimeError(f"Aligned data has only {merged.shape[0]} rows after intersection; need >= {min_rows} rows.")

    # Final return format
    data = {col: merged[col].to_numpy() for col in merged.columns}
    timestamps = np.arange(merged.shape[0])
    print(f"Loaded {len(data)} symbols with {merged.shape[0]} aligned rows (from {merged.index[0].date()} to {merged.index[-1].date()})")
    return data, timestamps



# ----------------- Example usage: replace your earlier data generation block -----------------

if __name__ == "__main__":
    print("=" * 80)
    print("ADVANCED QUANTITATIVE STRATEGY OPTIMIZATION (with real-data loader)")
    print("=" * 80)
    print("Running without JIT compilation for maximum stability")

    tickers = ["AAPL","MSFT","GOOG","AMZN","TSLA","NVDA"]
    data, timestamps = load_market_data_strict_v2(tickers, data_dir="data", start="2018-01-01", end=None, min_rows=250)
    # print(data)

    optimizer = StrategyOptimizer(data, timestamps)
    results_df = optimizer.run_optimization()

    # Generate and print report
    report = analyze_results(results_df)
    print("\n" + report)

    # Additional insights
    print("\n" + "=" * 80)
    print("KEY INSIGHTS")
    print("=" * 80)

    best_strategy = results_df.iloc[0]
    print(f"\n🏆 BEST OVERALL STRATEGY:")
    print(f"   Strategy: {best_strategy['strategy_class']}")
    print(f"   Pair: {best_strategy['symbol_pair']}")
    print(f"   Annual Return: {best_strategy['total_return']:.2%}")
    print(f"   Sharpe Ratio: {best_strategy['sharpe_ratio']:.2f}")
    print(f"   Max Drawdown: {best_strategy['max_drawdown']:.2%}")
    print(f"   Number of Trades: {best_strategy['num_trades']}")

    if 'params' in best_strategy:
        print(f"   Optimal Parameters: {best_strategy['params']}")

    # Strategy type analysis
    print(f"\n📊 STRATEGY TYPE PERFORMANCE:")
    strategy_summary = results_df.groupby('strategy_class').agg({
        'total_return': ['mean', 'max', 'std'],
        'sharpe_ratio': ['mean', 'max', 'std'],
        'combined_score': ['mean', 'max']
    }).round(3)

    for strategy in results_df['strategy_class'].unique():
        subset = results_df[results_df['strategy_class'] == strategy]
        avg_return = subset['total_return'].mean()
        avg_sharpe = subset['sharpe_ratio'].mean()
        best_score = subset['combined_score'].max()

        print(f"   {strategy}:")
        print(f"     Avg Return: {avg_return:.2%} | Avg Sharpe: {avg_sharpe:.2f} | Best Score: {best_score:.2f}")

    # Save results
    results_df.to_csv('optimization_results.csv', index=False)
    print(f"\n💾 Results saved to 'optimization_results.csv'")

    # Final recommendations
    print(f"\n🎯 TRADING RECOMMENDATIONS:")

    # Get top 3 strategies with different types
    top_diverse = []
    seen_types = set()

    for _, row in results_df.iterrows():
        if row['strategy_class'] not in seen_types and len(top_diverse) < 3:
            top_diverse.append(row)
            seen_types.add(row['strategy_class'])

    for i, strategy in enumerate(top_diverse):
        print(f"\n   Strategy #{i+1}: {strategy['strategy_class']}")
        print(f"   📈 Trade: {strategy['symbol_pair']} pair")
        print(f"   💰 Expected Return: {strategy['total_return']:.2%}")
        print(f"   ⚖️  Risk-Adjusted Return (Sharpe): {strategy['sharpe_ratio']:.2f}")
        print(f"   📉 Maximum Drawdown: {strategy['max_drawdown']:.2%}")

        if strategy['sharpe_ratio'] > 2.0:
            print(f"   ⭐ EXCELLENT: Sharpe > 2.0 indicates outstanding risk-adjusted returns")
        elif strategy['sharpe_ratio'] > 1.5:
            print(f"   ✅ GOOD: Sharpe > 1.5 indicates strong risk-adjusted returns")
        elif strategy['sharpe_ratio'] > 1.0:
            print(f"   👍 DECENT: Sharpe > 1.0 indicates acceptable risk-adjusted returns")

    print("\n" + "=" * 80)
    print("OPTIMIZATION COMPLETE! 🚀")
    print("=" * 80)

ADVANCED QUANTITATIVE STRATEGY OPTIMIZATION (with real-data loader)
Running without JIT compilation for maximum stability
Loaded 6 symbols with 1930 aligned rows (from 2018-01-02 to 2025-09-05)
Optimizing strategies for 6 assets...
Testing 15 pairs...
Total tests: 615
Progress: 50/615 (8.1%) - ETA: 3.5min
Progress: 100/615 (16.3%) - ETA: 3.2min
Progress: 150/615 (24.4%) - ETA: 3.4min
Progress: 200/615 (32.5%) - ETA: 3.3min
Progress: 250/615 (40.7%) - ETA: 3.5min
Progress: 300/615 (48.8%) - ETA: 3.2min
Progress: 350/615 (56.9%) - ETA: 2.5min
Progress: 400/615 (65.0%) - ETA: 2.0min
Progress: 450/615 (73.2%) - ETA: 1.5min
Progress: 500/615 (81.3%) - ETA: 1.1min
Progress: 550/615 (89.4%) - ETA: 0.6min
Progress: 600/615 (97.6%) - ETA: 0.1min

Optimization completed in 5.6 minutes

QUANTITATIVE STRATEGY OPTIMIZATION RESULTS

SUMMARY:
Total strategies tested: 615
Profitable strategies: 614 (99.8%)
High Sharpe (>1.0) strategies: 464 (75.4%)

TOP 10 STRATEGIES:
---------------------------------