<a href="https://colab.research.google.com/github/Shrey576/HFT-BackEngine-for-Statistical-Arbitrage/blob/main/HFTBackEngine.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Statistical Arbitrage Backtesting Engine

A high-performance Python-based backtesting system for statistical arbitrage strategies with MySQL integration and Linux deployment capabilities.

**Author:** Shreya Sharma  

**GitHub:** Shrey576

**LinkedIn:** Shreya Sharma

---

##  Table of Contents

1. [Overview](#overview)
2. [Architecture](#architecture)
3. [Installation](#installation)
4. [Core Components](#core-components)
5. [Usage Examples](#usage-examples)
6. [Deployment](#deployment)
7. [Performance Metrics](#performance-metrics)
8. [Contributing](#contributing)

---

##  Overview

This backtesting engine implements multiple statistical arbitrage strategies with production-grade features including MySQL data persistence, comprehensive performance analytics, and Linux deployment scripts. It is built to simulate high-frequency trading strategies with realistic transaction costs and slippage modeling.

**Key Features:**
- Multiple strategy implementations (Pairs Trading, Mean Reversion, Momentum)
- MySQL database integration for persistent data storage
- Historical data ingestion from Yahoo Finance (via `yfinance`)
- Commission and slippage modeling
- Comprehensive performance metrics (Sharpe, Sortino, Max Drawdown)
- Interactive visualizations with Plotly
- Production-ready Linux deployment

---

##  Architecture

The system follows a modular architecture with clear separation of concerns:

```
┌─────────────────────────────────────────────────────────┐
│                    Web Dashboard (Flask)                 │
│         Strategy Configuration & Visualization           │
└────────────────────────┬────────────────────────────────┘
                         │
┌────────────────────────▼────────────────────────────────┐
│              Backtesting Orchestrator                    │
│    - Strategy Management                                 │
│    - Performance Analytics                               │
│    - Results Aggregation                                 │
└────────────────────────┬────────────────────────────────┘
                         │
        ┌────────────────┼────────────────┐
        │                │                │
┌───────▼──────┐  ┌─────▼──────┐  ┌─────▼─────────┐
│ Data Handler │  │   Strategy  │  │ MySQL Database│
│              │  │   Execution │  │               │
│ - Data Fetch │  │   Engine    │  │ - Tick Data   │
│ - Indicators │  │ - P&L Calc  │  │ - Results     │
└──────────────┘  └─────────────┘  └───────────────┘
```

---

##  Installation

### For Google Colab (Quick Start)

```python
# Run this cell first in Colab
!pip install -q yfinance pandas numpy matplotlib seaborn plotly scipy statsmodels mysql-connector-python flask
```

### For Linux Production Environment

```bash
# System dependencies
sudo apt-get update
sudo apt-get install -y python3-pip python3-venv mysql-server

# Create virtual environment
python3 -m venv trading_env
source trading_env/bin/activate

# Install Python dependencies
pip install -r requirements.txt

# Setup MySQL
sudo mysql_secure_installation
sudo mysql -u root -p < setup_database.sql
```

**requirements.txt:**
```
pandas==2.0.3
numpy==1.24.3
yfinance==0.2.28
matplotlib==3.7.2
seaborn==0.12.2
plotly==5.15.0
scipy==1.11.1
statsmodels==0.14.0
mysql-connector-python==8.1.0
flask==2.3.2
flask-cors==4.0.0
gunicorn==21.2.0
```
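
After installing, it can help to confirm that the active environment matches the pins. The snippet below is a minimal sketch using only the standard library; `check_pins` and the `PINS` dictionary are hypothetical helpers, not part of this repository, and only two of the pinned packages are listed for brevity.

```python
from importlib import metadata

# Hypothetical subset of the pins from requirements.txt
PINS = {"pandas": "2.0.3", "numpy": "1.24.3"}


def check_pins(pins):
    """Return {package: (expected, installed_or_None)} for mismatches only."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    # An empty dict means every listed package matches its pin
    print(check_pins(PINS))
```

An empty result means the environment agrees with the listed pins; any entry shows the expected version next to what is actually installed.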

---

##  Core Components

### 1. Database Schema

The MySQL database stores historical market data, strategy configurations, and backtest results for persistent analysis.

```sql
-- setup_database.sql

CREATE DATABASE IF NOT EXISTS trading_db;
USE trading_db;

-- Historical tick data table
-- MySQL requires every unique key (including the primary key) to contain
-- the partitioning column, so `timestamp` is part of the primary key.
CREATE TABLE IF NOT EXISTS tick_data (
    id BIGINT AUTO_INCREMENT,
    symbol VARCHAR(10) NOT NULL,
    timestamp BIGINT NOT NULL,
    open_price DECIMAL(12, 4),
    high_price DECIMAL(12, 4),
    low_price DECIMAL(12, 4),
    close_price DECIMAL(12, 4),
    adj_close DECIMAL(12, 4),
    volume BIGINT,
    PRIMARY KEY (id, timestamp),
    INDEX idx_symbol_timestamp (symbol, timestamp),
    UNIQUE KEY unique_symbol_timestamp (symbol, timestamp)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

-- Partition by timestamp for better query performance
ALTER TABLE tick_data
PARTITION BY RANGE (timestamp) (
    PARTITION p2020 VALUES LESS THAN (UNIX_TIMESTAMP('2021-01-01')),
    PARTITION p2021 VALUES LESS THAN (UNIX_TIMESTAMP('2022-01-01')),
    PARTITION p2022 VALUES LESS THAN (UNIX_TIMESTAMP('2023-01-01')),
    PARTITION p2023 VALUES LESS THAN (UNIX_TIMESTAMP('2024-01-01')),
    PARTITION p2024 VALUES LESS THAN (UNIX_TIMESTAMP('2025-01-01')),
    PARTITION pfuture VALUES LESS THAN MAXVALUE
);

-- Strategy configurations table
CREATE TABLE IF NOT EXISTS strategies (
    id INT AUTO_INCREMENT PRIMARY KEY,
    name VARCHAR(100) NOT NULL,
    strategy_type VARCHAR(50) NOT NULL,
    parameters JSON,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    INDEX idx_strategy_type (strategy_type)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

-- Backtest results table
CREATE TABLE IF NOT EXISTS backtest_results (
    id INT AUTO_INCREMENT PRIMARY KEY,
    strategy_id INT,
    run_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    start_date DATE,
    end_date DATE,
    initial_capital DECIMAL(15, 2),
    final_capital DECIMAL(15, 2),
    total_return DECIMAL(10, 4),
    annual_return DECIMAL(10, 4),
    annual_volatility DECIMAL(10, 4),
    sharpe_ratio DECIMAL(10, 4),
    sortino_ratio DECIMAL(10, 4),
    max_drawdown DECIMAL(10, 4),
    win_rate DECIMAL(10, 4),
    total_trades INT,
    FOREIGN KEY (strategy_id) REFERENCES strategies(id) ON DELETE CASCADE,
    INDEX idx_strategy_timestamp (strategy_id, run_timestamp)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

-- Trade log table
CREATE TABLE IF NOT EXISTS trade_log (
    id BIGINT AUTO_INCREMENT PRIMARY KEY,
    backtest_id INT,
    timestamp BIGINT,
    symbol VARCHAR(10),
    action VARCHAR(10), -- 'BUY', 'SELL', 'CLOSE'
    quantity DECIMAL(15, 4),
    price DECIMAL(12, 4),
    commission DECIMAL(10, 4),
    pnl DECIMAL(15, 4),
    FOREIGN KEY (backtest_id) REFERENCES backtest_results(id) ON DELETE CASCADE,
    INDEX idx_backtest_timestamp (backtest_id, timestamp)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
```

**Explanation:** This schema is optimized for time-series data with partitioning by timestamp for efficient queries. The `tick_data` table stores OHLCV data, while `backtest_results` and `trade_log` tables maintain comprehensive backtest history for analysis and comparison.
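
To make the partition layout concrete, the sketch below maps an epoch-second timestamp to the yearly partition it would fall into. It is an illustration only: `partition_for` and `to_epoch` are hypothetical helpers, and the boundaries assume the MySQL session time zone is UTC, since `UNIX_TIMESTAMP('2021-01-01')` is evaluated in the session time zone.

```python
from datetime import datetime, timezone

# Exclusive upper bounds mirroring the PARTITION ... VALUES LESS THAN clauses.
# Assumption: the MySQL session time zone is UTC.
BOUNDS = [
    (f"p{year}", int(datetime(year + 1, 1, 1, tzinfo=timezone.utc).timestamp()))
    for year in range(2020, 2025)
]


def to_epoch(date_str: str) -> int:
    """Convert a YYYY-MM-DD string to UTC epoch seconds."""
    parsed = datetime.strptime(date_str, "%Y-%m-%d")
    return int(parsed.replace(tzinfo=timezone.utc).timestamp())


def partition_for(epoch_seconds: int) -> str:
    """Return the name of the first partition whose bound exceeds the value."""
    for name, upper in BOUNDS:
        if epoch_seconds < upper:
            return name
    return "pfuture"  # catch-all MAXVALUE partition


if __name__ == "__main__":
    for day in ("2019-03-01", "2021-06-15", "2025-01-01"):
        print(day, "->", partition_for(to_epoch(day)))
```

Note that `p2020` also receives every row dated before 2020, because a range partition only has an upper bound.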

### 2. Database Handler

```python
import mysql.connector
from mysql.connector import Error
import pandas as pd
from typing import Dict, List, Optional
from datetime import datetime
import json

class DatabaseHandler:
    """
    Manages MySQL database connections and operations for the backtesting engine.
    Handles data persistence, retrieval, and query optimization.
    """

    def __init__(self, host: str = 'localhost', user: str = 'root',
                 password: str = 'your_password', database: str = 'trading_db'):
        """
        Initialize database connection.

        Args:
            host: MySQL server hostname
            user: MySQL username
            password: MySQL password
            database: Database name
        """
        self.host = host
        self.user = user
        self.password = password
        self.database = database
        self.connection = None
        self.connect()

    def connect(self):
        """Establish connection to MySQL database"""
        try:
            self.connection = mysql.connector.connect(
                host=self.host,
                user=self.user,
                password=self.password,
                database=self.database
            )
            if self.connection.is_connected():
                print(f"✓ Connected to MySQL database: {self.database}")
        except Error as e:
            print(f"✗ Error connecting to MySQL: {e}")
            self.connection = None

    def disconnect(self):
        """Close database connection"""
        if self.connection and self.connection.is_connected():
            self.connection.close()
            print("✓ Database connection closed")

    def store_tick_data(self, symbol: str, df: pd.DataFrame):
        """
        Store historical tick data in database.
        Uses batch inserts for performance.

        Args:
            symbol: Stock symbol
            df: DataFrame with OHLCV data
        """
        if not self.connection or not self.connection.is_connected():
            self.connect()

        cursor = self.connection.cursor()

        # Prepare data for insertion
        insert_query = """
            INSERT INTO tick_data
            (symbol, timestamp, open_price, high_price, low_price,
             close_price, adj_close, volume)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
            open_price = VALUES(open_price),
            high_price = VALUES(high_price),
            low_price = VALUES(low_price),
            close_price = VALUES(close_price),
            adj_close = VALUES(adj_close),
            volume = VALUES(volume)
        """

        # Convert DataFrame to list of tuples
        data = []
        for idx, row in df.iterrows():
            timestamp = int(idx.timestamp())
            data.append((
                symbol,
                timestamp,
                float(row.get('Open', 0)),
                float(row.get('High', 0)),
                float(row.get('Low', 0)),
                float(row.get('Close', 0)),
                float(row.get('Adj Close', row.get('Close', 0))),
                int(row.get('Volume', 0))
            ))

        try:
            # Batch insert
            cursor.executemany(insert_query, data)
            self.connection.commit()
            print(f"✓ Stored {len(data)} records for {symbol}")
        except Error as e:
            print(f"✗ Error storing data: {e}")
            self.connection.rollback()
        finally:
            cursor.close()

    def fetch_tick_data(self, symbols: List[str], start_date: str,
                       end_date: str) -> pd.DataFrame:
        """
        Retrieve tick data from database.

        Args:
            symbols: List of stock symbols
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)

        Returns:
            DataFrame with multi-column data (one column per symbol)
        """
        if not self.connection or not self.connection.is_connected():
            self.connect()

        # Use UTC epoch seconds, matching how store_tick_data() converts the index
        start_ts = int(pd.Timestamp(start_date).timestamp())
        end_ts = int(pd.Timestamp(end_date).timestamp())

        # One %s placeholder per symbol, plus two for the timestamp range
        placeholders = ','.join(['%s'] * len(symbols))
        query = f"""
            SELECT symbol, FROM_UNIXTIME(timestamp) as date, adj_close
            FROM tick_data
            WHERE symbol IN ({placeholders})
            AND timestamp BETWEEN %s AND %s
            ORDER BY timestamp
        """

        try:
            df = pd.read_sql(query, self.connection,
                           params=tuple(symbols) + (start_ts, end_ts))

            # DECIMAL columns arrive as decimal.Decimal; convert for numeric work
            df['adj_close'] = df['adj_close'].astype(float)

            # Pivot to have symbols as columns
            df_pivot = df.pivot(index='date', columns='symbol',
                              values='adj_close')
            df_pivot.index = pd.to_datetime(df_pivot.index)

            print(f"✓ Fetched {len(df_pivot)} records from database")
            return df_pivot

        except Exception as e:
            # pandas wraps driver errors in its own exception type, so catch broadly
            print(f"✗ Error fetching data: {e}")
            return pd.DataFrame()

    def save_strategy(self, name: str, strategy_type: str,
                     parameters: Dict) -> int:
        """
        Save strategy configuration to database.

        Returns:
            Strategy ID
        """
        if not self.connection or not self.connection.is_connected():
            self.connect()

        cursor = self.connection.cursor()

        insert_query = """
            INSERT INTO strategies (name, strategy_type, parameters)
            VALUES (%s, %s, %s)
        """

        try:
            cursor.execute(insert_query,
                         (name, strategy_type, json.dumps(parameters)))
            self.connection.commit()
            strategy_id = cursor.lastrowid
            print(f"✓ Strategy saved with ID: {strategy_id}")
            return strategy_id
        except Error as e:
            print(f"✗ Error saving strategy: {e}")
            self.connection.rollback()
            return -1
        finally:
            cursor.close()

    def save_backtest_results(self, strategy_id: int, results: Dict) -> int:
        """
        Save backtest results to database.

        Returns:
            Backtest ID
        """
        if not self.connection or not self.connection.is_connected():
            self.connect()

        cursor = self.connection.cursor()

        insert_query = """
            INSERT INTO backtest_results
            (strategy_id, start_date, end_date, initial_capital, final_capital,
             total_return, annual_return, annual_volatility, sharpe_ratio,
             sortino_ratio, max_drawdown, win_rate, total_trades)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """

        metrics = results['metrics']

        try:
            cursor.execute(insert_query, (
                strategy_id,
                results.get('start_date'),
                results.get('end_date'),
                results.get('initial_capital'),
                metrics.get('Final Capital'),
                metrics.get('Total Return (%)'),
                metrics.get('Annual Return (%)'),
                metrics.get('Annual Volatility (%)'),
                metrics.get('Sharpe Ratio'),
                metrics.get('Sortino Ratio', 0),
                metrics.get('Max Drawdown (%)'),
                metrics.get('Win Rate (%)'),
                metrics.get('Total Trades')
            ))
            self.connection.commit()
            backtest_id = cursor.lastrowid
            print(f"✓ Backtest results saved with ID: {backtest_id}")
            return backtest_id
        except Error as e:
            print(f"✗ Error saving results: {e}")
            self.connection.rollback()
            return -1
        finally:
            cursor.close()

    def get_strategy_performance_history(self, strategy_id: int) -> pd.DataFrame:
        """
        Retrieve historical performance for a strategy.

        Returns:
            DataFrame with backtest history
        """
        if not self.connection or not self.connection.is_connected():
            self.connect()

        query = """
            SELECT
                run_timestamp,
                total_return,
                sharpe_ratio,
                max_drawdown,
                win_rate
            FROM backtest_results
            WHERE strategy_id = %s
            ORDER BY run_timestamp DESC
        """

        try:
            df = pd.read_sql(query, self.connection, params=(strategy_id,))
            return df
        except Error as e:
            print(f"✗ Error fetching performance history: {e}")
            return pd.DataFrame()
```

**Explanation:** The `DatabaseHandler` class encapsulates all MySQL operations behind a single connection that is re-established on demand, with error handling, rollback on failure, and batched upserts via `executemany`. It provides methods for storing tick data, saving strategy configurations, and persisting backtest results for historical analysis.
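
The batched upsert pattern can be tried without a MySQL server. The sketch below uses the standard-library `sqlite3` module as a stand-in, so the syntax differs from the handler above: SQLite writes `ON CONFLICT ... DO UPDATE` where MySQL writes `ON DUPLICATE KEY UPDATE`, and uses `?` placeholders instead of `%s`. The table is a cut-down, hypothetical version of `tick_data`, and the prices are made-up round numbers.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE tick_data (
        symbol TEXT NOT NULL,
        timestamp INTEGER NOT NULL,
        close_price REAL,
        UNIQUE (symbol, timestamp)
    )
""")

# Insert a row, or overwrite the price if (symbol, timestamp) already exists
UPSERT = """
    INSERT INTO tick_data (symbol, timestamp, close_price)
    VALUES (?, ?, ?)
    ON CONFLICT(symbol, timestamp) DO UPDATE SET
        close_price = excluded.close_price
"""


def store(rows):
    """Write all rows in one batch; the context manager commits or rolls back."""
    with conn:
        conn.executemany(UPSERT, rows)


# Toy values, not market data
store([("AAPL", 1609459200, 100.0), ("AAPL", 1609545600, 101.0)])
store([("AAPL", 1609459200, 100.5)])  # same key: updates instead of duplicating

rows = conn.execute(
    "SELECT symbol, timestamp, close_price FROM tick_data ORDER BY timestamp"
).fetchall()
print(rows)
```

Re-running an ingest for an overlapping date range therefore refreshes existing rows rather than failing on the unique key.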


### 3. Data Handler with Hybrid Approach

```python
import yfinance as yf
import pandas as pd
import numpy as np
from typing import Dict, List, Optional

class DataHandler:
    """
    Handles data acquisition from both Yahoo Finance (real-time) and MySQL (cached).
    Implements intelligent caching strategy to minimize API calls.
    """

    def __init__(self, db_handler: Optional[DatabaseHandler] = None):
        """
        Initialize data handler.

        Args:
            db_handler: Optional DatabaseHandler for caching
        """
        self.db_handler = db_handler
        # Start with an empty DataFrame so get_data_summary() works before any fetch
        self.data = pd.DataFrame()

    def fetch_data(self, symbols: List[str], start_date: str,
                   end_date: str, interval: str = '1d',
                   use_cache: bool = True) -> pd.DataFrame:
        """
        Fetch historical data with intelligent caching.

        Strategy:
        1. Check MySQL cache first (if enabled)
        2. If not in cache or use_cache=False, fetch from Yahoo Finance
        3. Store fetched data in cache for future use

        Args:
            symbols: List of stock symbols
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            interval: Data interval (1d, 1h, 15m, etc.)
            use_cache: Whether to use cached data

        Returns:
            DataFrame with price data
        """
        print(f"\n{'='*60}")
        print(f"Fetching data for {symbols}")
        print(f"Period: {start_date} to {end_date}")
        print(f"{'='*60}")

        # Try cache first
        if use_cache and self.db_handler:
            print("Checking MySQL cache...")
            cached_data = self.db_handler.fetch_tick_data(
                symbols, start_date, end_date
            )
            if not cached_data.empty:
                print("✓ Using cached data from MySQL")
                self.data = cached_data
                return cached_data

        # Fetch from Yahoo Finance
        print("Fetching from Yahoo Finance...")
        data_frames = {}

        for symbol in symbols:
            try:
                df = yf.download(
                    symbol,
                    start=start_date,
                    end=end_date,
                    interval=interval,
                    progress=False
                )

                if not df.empty:
                    data_frames[symbol] = df['Adj Close']
                    print(f"✓ {symbol}: {len(df)} bars")

                    # Cache in database
                    if self.db_handler:
                        self.db_handler.store_tick_data(symbol, df)
                else:
                    print(f"✗ {symbol}: No data available")

            except Exception as e:
                print(f"✗ {symbol}: {e}")

        # Combine into single DataFrame
        self.data = pd.DataFrame(data_frames)
        self.data.index.name = 'timestamp'

        print(f"\n✓ Data fetch complete: {len(self.data)} rows")
        return self.data

    def add_technical_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Add common technical indicators to the dataset.

        Indicators include:
        - Simple Moving Averages (SMA)
        - Exponential Moving Averages (EMA)
        - Bollinger Bands
        - RSI (Relative Strength Index)
        - MACD (Moving Average Convergence Divergence)
        """
        for col in df.columns:
            # Returns
            df[f'{col}_returns'] = df[col].pct_change()

            # Moving averages
            df[f'{col}_sma_20'] = df[col].rolling(20).mean()
            df[f'{col}_sma_50'] = df[col].rolling(50).mean()
            df[f'{col}_ema_12'] = df[col].ewm(span=12).mean()
            df[f'{col}_ema_26'] = df[col].ewm(span=26).mean()

            # Bollinger Bands
            sma_20 = df[col].rolling(20).mean()
            std_20 = df[col].rolling(20).std()
            df[f'{col}_bb_upper'] = sma_20 + 2 * std_20
            df[f'{col}_bb_middle'] = sma_20
            df[f'{col}_bb_lower'] = sma_20 - 2 * std_20

            # RSI
            delta = df[col].diff()
            gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
            loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
            rs = gain / loss
            df[f'{col}_rsi'] = 100 - (100 / (1 + rs))

            # MACD
            ema_12 = df[col].ewm(span=12).mean()
            ema_26 = df[col].ewm(span=26).mean()
            df[f'{col}_macd'] = ema_12 - ema_26
            df[f'{col}_macd_signal'] = df[f'{col}_macd'].ewm(span=9).mean()

        return df

    def get_data_summary(self) -> Dict:
        """
        Get summary statistics of the loaded data.

        Returns:
            Dictionary with data statistics
        """
        if self.data.empty:
            return {}

        return {
            'symbols': list(self.data.columns),
            'start_date': self.data.index[0],
            'end_date': self.data.index[-1],
            'total_rows': len(self.data),
            'missing_values': self.data.isnull().sum().to_dict()
        }
```

**Explanation:** The `DataHandler` implements a hybrid caching strategy: it checks MySQL for existing data before calling Yahoo Finance, which reduces latency and exposure to API rate limits. Any non-empty cache result is returned as-is, so the handler does not verify that the cache covers the full requested date range; pass `use_cache=False` to force a fresh download. Technical indicators are calculated on demand.
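
As a sanity check on the RSI formula used in `add_technical_indicators`, the sketch below pulls the same calculation into a standalone function and runs it on synthetic price series. `rolling_rsi` is a hypothetical helper name, and the calculation uses simple rolling means as in the code above, not Wilder's exponential smoothing.

```python
import pandas as pd


def rolling_rsi(prices: pd.Series, window: int = 14) -> pd.Series:
    """RSI from simple rolling means of gains and losses."""
    delta = prices.diff()
    gain = delta.where(delta > 0, 0).rolling(window=window).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=window).mean()
    rs = gain / loss  # infinite when there are no losses in the window
    return 100 - (100 / (1 + rs))


# Synthetic series: steadily rising, steadily falling, and alternating
rising = pd.Series(range(100, 130), dtype=float)
falling = pd.Series(range(130, 100, -1), dtype=float)
choppy = pd.Series([100.0 + (i % 2) for i in range(30)])

print(rolling_rsi(rising).iloc[-1])   # all gains, no losses
print(rolling_rsi(falling).iloc[-1])  # all losses, no gains
print(rolling_rsi(choppy).iloc[-1])   # equal gains and losses
```

The first `window - 1` values are `NaN` because the rolling mean needs a full window, and a completely flat window also yields `NaN` (zero divided by zero).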


### 4. Strategy Base Class and Implementations

```python
from abc import ABC, abstractmethod
from typing import Dict
import pandas as pd
import numpy as np
from scipy import stats
from statsmodels.tsa.stattools import coint

class Strategy(ABC):
    """
    Abstract base class for all trading strategies.
    Enforces consistent interface across different strategy types.
    """

    def __init__(self, name: str):
        self.name = name
        self.positions = {}

    @abstractmethod
    def generate_signals(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Generate trading signals based on market data.

        Must return DataFrame with 'position' column:
        - 1: Long position
        - 0: No position
        - -1: Short position
        """
        pass

    def __repr__(self):
        return f"{self.__class__.__name__}(name='{self.name}')"


class PairsTradingStrategy(Strategy):
    """
    Implements statistical arbitrage via pairs trading.

    Theory:
    - Identifies cointegrated stock pairs
    - Trades mean reversion of the spread
    - Uses z-score to determine entry/exit points

    Parameters:
    - lookback: Rolling window for calculating spread statistics
    - entry_z: Z-score threshold for entering positions
    - exit_z: Z-score threshold for exiting positions
    """

    def __init__(self, stock1: str, stock2: str,
                 lookback: int = 60, entry_z: float = 2.0,
                 exit_z: float = 0.5):
        super().__init__(f"PairsTrading_{stock1}_{stock2}")
        self.stock1 = stock1
        self.stock2 = stock2
        self.lookback = lookback
        self.entry_z = entry_z
        self.exit_z = exit_z

    def test_cointegration(self, series1: pd.Series,
                          series2: pd.Series) -> Dict:
        """
        Test for cointegration between two price series.

        Returns:
            Dictionary with test results
        """
        score, pvalue, _ = coint(series1, series2)
        return {
            'cointegrated': pvalue < 0.05,
            'p_value': pvalue,
            'test_statistic': score
        }

    def generate_signals(self, data: pd.DataFrame) -> pd.DataFrame:
        """Generate pairs trading signals"""
        df = data.copy()

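        # Test cointegration on rows where both prices are present,
        # so the two series stay aligned and equal in length
        pair = df[[self.stock1, self.stock2]].dropna()
        coint_result = self.test_cointegration(
            pair[self.stock1],
            pair[self.stock2]
        )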

        print(f"\nCointegration Test Results:")
        print(f"  Cointegrated: {coint_result['cointegrated']}")
        print(f"  P-value: {coint_result['p_value']:.4f}")

        # Calculate spread (price difference)
        df['spread'] = df[self.stock1] - df[self.stock2]

        # Calculate rolling z-score
        df['spread_mean'] = df['spread'].rolling(self.lookback).mean()
        df['spread_std'] = df['spread'].rolling(self.lookback).std()
        df['z_score'] = (df['spread'] - df['spread_mean']) / df['spread_std']

        # Generate signals: NaN means "keep the current position"
        df['signal'] = np.nan

        # Entry signals
        # When spread is too high (z > entry_z), short the spread
        df.loc[df['z_score'] > self.entry_z, 'signal'] = -1
        # When spread is too low (z < -entry_z), long the spread
        df.loc[df['z_score'] < -self.entry_z, 'signal'] = 1

        # Exit signals when spread returns to mean (explicit 0 = go flat)
        df.loc[df['z_score'].abs() < self.exit_z, 'signal'] = 0

        # Forward fill positions (hold until the next entry or exit signal)
        df['position'] = df['signal'].ffill().fillna(0)

        return df


class MeanReversionStrategy(Strategy):
    """
    Bollinger Bands mean reversion strategy.

    Theory:
    - Prices tend to revert to their mean over time
    - Bollinger Bands identify overbought/oversold conditions
    - Buy when price touches lower band, sell at upper band
    """

    def __init__(self, symbol: str, window: int = 20, num_std: float = 2.0):
        super().__init__(f"MeanReversion_{symbol}")
        self.symbol = symbol
        self.window = window
        self.num_std = num_std

    def generate_signals(self, data: pd.DataFrame) -> pd.DataFrame:
        """Generate mean reversion signals"""
        df = data[[self.symbol]].copy()

        # Calculate Bollinger Bands
        df['sma'] = df[self.symbol].rolling(self.window).mean()
        df['std'] = df[self.symbol].rolling(self.window).std()
        df['upper_band'] = df['sma'] + (self.num_std * df['std'])
        df['lower_band'] = df['sma'] - (self.num_std * df['std'])

        # Calculate distance from bands (normalized)
        df['upper_distance'] = (df[self.symbol] - df['upper_band']) / df['std']
        df['lower_distance'] = (df['lower_band'] - df[self.symbol]) / df['std']

        # Generate signals: NaN means "keep the current position"
        df['signal'] = np.nan

        # Exit when price crosses the middle band (SMA) in either direction
        above_mid = (df[self.symbol] >= df['sma']).astype(int)
        crossed_mid = (above_mid.diff().abs() == 1) & df['sma'].notna()
        df.loc[crossed_mid, 'signal'] = 0

        # Buy when price falls below the lower band (oversold)
        df.loc[df[self.symbol] < df['lower_band'], 'signal'] = 1
        # Sell when price rises above the upper band (overbought)
        df.loc[df[self.symbol] > df['upper_band'], 'signal'] = -1

        # Forward fill positions (hold until the next entry or exit signal)
        df['position'] = df['signal'].ffill().fillna(0)

        return df


class MomentumStrategy(Strategy):
    """
    Moving average crossover momentum strategy.

    Theory:
    - Trend following: buy strength, sell weakness
    - Golden cross (short MA > long MA): Buy signal
    - Death cross (short MA < long MA): Sell signal
    """

    def __init__(self, symbol: str, short_window: int = 20,
                 long_window: int = 50):
        super().__init__(f"Momentum_{symbol}")
        self.symbol = symbol
        self.short_window = short_window
        self.long_window = long_window

    def generate_signals(self, data: pd.DataFrame) -> pd.DataFrame:
        """Generate momentum signals"""
        df = data[[self.symbol]].copy()

        # Calculate moving averages
        df['sma_short'] = df[self.symbol].rolling(self.short_window).mean()
        df['sma_long'] = df[self.symbol].rolling(self.long_window).mean()

        # Calculate crossover signals
        df['signal'] = 0
        df.loc[df['sma_short'] > df['sma_long'], 'signal'] = 1   # Golden cross
        df.loc[df['sma_short'] < df['sma_long'], 'signal'] = -1  # Death cross

        df['position'] = df['signal']

        # Identify crossover points for visualization
        df['crossover'] = df['signal'].diff()

        return df
```

**Explanation:** These strategy classes implement specific trading logic while adhering to the `Strategy` interface. Each strategy encapsulates its parameters, signal generation logic, and statistical tests. The pairs trading strategy includes cointegration testing to validate the statistical relationship between stocks.
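
The hold-until-exit behaviour of the z-score rules is easiest to see on a tiny hand-made series. The sketch below is one possible encoding, under the assumption that `NaN` stands for "no new instruction" so that an explicit `0` closes the position; `positions_from_zscore` is a hypothetical helper and the z-scores are invented for illustration.

```python
import numpy as np
import pandas as pd


def positions_from_zscore(z: pd.Series, entry_z: float = 2.0,
                          exit_z: float = 0.5) -> pd.Series:
    """Turn a z-score series into held positions (-1 short, 0 flat, 1 long)."""
    signal = pd.Series(np.nan, index=z.index)  # NaN = keep current position
    signal[z > entry_z] = -1.0                 # spread too high: short it
    signal[z < -entry_z] = 1.0                 # spread too low: long it
    signal[z.abs() < exit_z] = 0.0             # back near the mean: go flat
    return signal.ffill().fillna(0.0)          # hold until the next instruction


# Invented z-scores: a short entry, an exit, a long entry, an exit
z = pd.Series([0.0, 2.5, 1.0, 0.3, -2.2, -1.0, 0.1])
positions = positions_from_zscore(z)
print(positions.tolist())
# -> [0.0, -1.0, -1.0, 0.0, 1.0, 1.0, 0.0]
```

The values `1.0` and `-1.0` at indices 2 and 5 sit between the entry and exit thresholds, so the position opened on the previous bar is carried forward unchanged.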




### 5. Backtesting Engine

```python
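from typing import Dict, Optional

import pandas as pd

# DatabaseHandler, Strategy, and PairsTradingStrategy are defined in the sections above


class BacktestEngine: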
    """
    Core backtesting engine with realistic transaction modeling.

    Features:
    - Commission and slippage modeling
    - Position sizing
    - Risk management
    - Performance analytics
    - Results persistence
    """

    def __init__(self, initial_capital: float = 100000,
                 commission: float = 0.001,
                 slippage: float = 0.0005,
                 db_handler: Optional[DatabaseHandler] = None):
        """
        Initialize backtesting engine.

        Args:
            initial_capital: Starting capital in dollars
            commission: Commission rate (0.001 = 0.1%)
            slippage: Slippage rate (0.0005 = 0.05%)
            db_handler: Optional database handler for persisting results
        """
        self.initial_capital = initial_capital
        self.commission = commission
        self.slippage = slippage
        self.db_handler = db_handler
        self.results = {}

    def run(self, strategy: Strategy, data: pd.DataFrame,
            save_to_db: bool = True) -> Dict:
        """
        Execute backtest for a strategy.

        Process:
        1. Generate trading signals
        2. Calculate returns with realistic costs
        3. Compute performance metrics
        4. Persist results to database

        Args:
            strategy: Strategy instance to backtest
            data: Historical price data
            save_to_db: Whether to save results to database

        Returns:
            Dictionary with complete backtest results
        """
        print(f"\n{'='*60}")
        print(f"Running Backtest: {strategy.name}")
        print(f"{'='*60}")
        print(f"Initial Capital: ${self.initial_capital:,.2f}")
        print(f"Commission: {self.commission*100:.2f}%")
        print(f"Slippage: {self.slippage*100:.3f}%")

        # Generate signals
        signals = strategy.generate_signals(data)

        if 'position' not in signals.columns:
            print("✗ Error: Strategy must generate 'position' column")
            return {}

        # Calculate returns based on strategy type
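        if isinstance(strategy, PairsTradingStrategy):
            # For pairs trading: return = position * (stock1_return - stock2_return)
            # The position is lagged one bar so a signal earns the next bar's
            # return, which avoids look-ahead bias
            returns = (signals['position'].shift(1) *
                      (data[strategy.stock1].pct_change() -
                       data[strategy.stock2].pct_change()))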
        else:
            # For single-stock strategies: return = position * stock_return
            returns = (signals['position'].shift(1) *
                      data[strategy.symbol].pct_change())

        # Apply transaction costs
        # Commission charged on position changes
        position_changes = signals['position'].diff().abs()
        commission_costs = position_changes * self.commission

        # Slippage on all trades
        slippage_costs = position_changes * self.slippage

        # Net returns after costs (NaNs from the first bar count as zero return)
        returns = (returns - commission_costs - slippage_costs).fillna(0)

        # Calculate equity curve
        equity_curve = (1 + returns).cumprod() * self.initial_capital

        # Calculate comprehensive metrics
        metrics = self._calculate_metrics(returns, equity_curve, signals)

        # Store results
        result = {
            'strategy': strategy,
            'signals': signals,
            'returns': returns,
            'equity_curve': equity_curve,
            'metrics': metrics,
            'start_date': data.index[0].strftime('%Y-%m-%d'),
            'end_date': data.index[-1].strftime('%Y-%m-%d'),
            'initial_capital': self.initial_capital
        }

        self.results[strategy.name] = result

        # Print results
        self._print_metrics(metrics)

        # Save to database
        if save_to_db and self.db_handler:
            # Save strategy configuration
            strategy_params = {
                'commission': self.commission,
                'slippage': self.slippage
            }
            if isinstance(strategy, PairsTradingStrategy):
                strategy_params.update({
                    'stock1': strategy.stock1,
                    'stock2': strategy.stock2,
                    'lookback': strategy.lookback,
                    'entry_z': strategy.entry_z,
                    'exit_z': strategy.exit_z
                })

            strategy_id = self.db_handler.save_strategy(
                strategy.name,
                strategy.__class__.__name__,
                strategy_params
            )

            # Save backtest results
            if strategy_id > 0:
                self.db_handler.save_backtest_results(strategy_id, result)

        return result

    def _calculate_metrics(self, returns: pd.Series,
                          equity_curve: pd.Series,
                          signals: pd.DataFrame) -> Dict:
        """
        Calculate comprehensive performance metrics.

        Metrics include:
        - Returns (total, annual)
        - Volatility (annual)
        - Risk-adjusted returns (Sharpe, Sortino)
        - Drawdown analysis
        - Trade statistics
        """
        # Basic returns
        total_return = (equity_curve.iloc[-1] / self.initial_capital - 1) * 100

        # Annualized metrics
        trading_days = 252
        total_days = len(returns)
        years = total_days / trading_days

        annual_return = returns.mean() * trading_days * 100
        annual_std = returns.std() * np.sqrt(trading_days) * 100

        # Sharpe Ratio (assuming 2% risk-free rate)
        risk_free_rate = 2.0
        excess_return = annual_return - risk_free_rate
        sharpe_ratio = (excess_return / annual_std) if annual_std != 0 else 0

        # Sortino Ratio (downside deviation). std() is NaN when there are
        # fewer than two losing days, so require a strictly positive value.
        downside_returns = returns[returns < 0]
        downside_std = downside_returns.std() * np.sqrt(trading_days) * 100
        sortino_ratio = (excess_return / downside_std) if downside_std > 0 else 0

        # Drawdown analysis
        running_max = equity_curve.expanding().max()
        drawdown = (equity_curve - running_max) / running_max * 100
        max_drawdown = drawdown.min()

        # Calculate average drawdown duration
        is_drawdown = drawdown < -1  # More than 1% drawdown
        drawdown_periods = is_drawdown.astype(int).diff()
        num_drawdowns = (drawdown_periods == 1).sum()
        avg_drawdown_duration = is_drawdown.sum() / num_drawdowns if num_drawdowns > 0 else 0

        # Trade statistics (day-based): 'Total Trades' counts days with a
        # non-zero return, not round-trip trades
        winning_days = (returns > 0).sum()
        losing_days = (returns < 0).sum()
        total_trades = (returns != 0).sum()
        win_rate = (winning_days / total_trades * 100) if total_trades > 0 else 0

        # Average win/loss
        avg_win = returns[returns > 0].mean() * 100 if winning_days > 0 else 0
        avg_loss = returns[returns < 0].mean() * 100 if losing_days > 0 else 0
        profit_factor = abs(avg_win * winning_days / (avg_loss * losing_days)) if avg_loss != 0 else 0

        # Calmar Ratio (Annual Return / Max Drawdown)
        calmar_ratio = abs(annual_return / max_drawdown) if max_drawdown != 0 else 0

        return {
            'Total Return (%)': total_return,
            'Annual Return (%)': annual_return,
            'Annual Volatility (%)': annual_std,
            'Sharpe Ratio': sharpe_ratio,
            'Sortino Ratio': sortino_ratio,
            'Calmar Ratio': calmar_ratio,
            'Max Drawdown (%)': max_drawdown,
            'Avg Drawdown Duration (days)': avg_drawdown_duration,
            'Win Rate (%)': win_rate,
            'Profit Factor': profit_factor,
            'Avg Win (%)': avg_win,
            'Avg Loss (%)': avg_loss,
            'Total Trades': total_trades,
            'Winning Days': winning_days,
            'Losing Days': losing_days,
            'Final Capital': equity_curve.iloc[-1],
            'Total Days': total_days
        }

    def _print_metrics(self, metrics: Dict):
        """Pretty print performance metrics"""
        print("\n" + "="*60)
        print("PERFORMANCE METRICS")
        print("="*60)

        # Returns section
        print("\n RETURNS")
        print("-" * 60)
        print(f"{'Total Return':.<40} {metrics['Total Return (%)']:>15.2f}%")
        print(f"{'Annual Return':.<40} {metrics['Annual Return (%)']:>15.2f}%")
        print(f"{'Annual Volatility':.<40} {metrics['Annual Volatility (%)']:>15.2f}%")

        # Risk-adjusted returns
        print("\n RISK-ADJUSTED RETURNS")
        print("-" * 60)
        print(f"{'Sharpe Ratio':.<40} {metrics['Sharpe Ratio']:>15.2f}")
        print(f"{'Sortino Ratio':.<40} {metrics['Sortino Ratio']:>15.2f}")
        print(f"{'Calmar Ratio':.<40} {metrics['Calmar Ratio']:>15.2f}")

        # Drawdown
        print("\n DRAWDOWN ANALYSIS")
        print("-" * 60)
        print(f"{'Max Drawdown':.<40} {metrics['Max Drawdown (%)']:>15.2f}%")
        print(f"{'Avg Drawdown Duration':.<40} {metrics['Avg Drawdown Duration (days)']:>15.0f} days")

        # Trade statistics
        print("\n TRADE STATISTICS")
        print("-" * 60)
        print(f"{'Total Trades':.<40} {metrics['Total Trades']:>15.0f}")
        print(f"{'Winning Days':.<40} {metrics['Winning Days']:>15.0f}")
        print(f"{'Losing Days':.<40} {metrics['Losing Days']:>15.0f}")
        print(f"{'Win Rate':.<40} {metrics['Win Rate (%)']:>15.2f}%")
        print(f"{'Profit Factor':.<40} {metrics['Profit Factor']:>15.2f}")
        print(f"{'Avg Win':.<40} {metrics['Avg Win (%)']:>15.2f}%")
        print(f"{'Avg Loss':.<40} {metrics['Avg Loss (%)']:>15.2f}%")

        # Capital
        print("\n CAPITAL")
        print("-" * 60)
        print(f"{'Final Capital':.<40} ${metrics['Final Capital']:>14,.2f}")

        print("="*60 + "\n")

    def compare_strategies(self, metric: str = 'Sharpe Ratio') -> pd.DataFrame:
        """
        Compare all backtested strategies by a specific metric.

        Args:
            metric: Metric to compare (default: Sharpe Ratio)

        Returns:
            DataFrame with strategy comparison
        """
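        comparison = []

        for name, result in self.results.items():
            metrics = result['metrics']
            row = {
                'Strategy': name,
                'Total Return (%)': metrics['Total Return (%)'],
                'Sharpe Ratio': metrics['Sharpe Ratio'],
                'Max Drawdown (%)': metrics['Max Drawdown (%)'],
                'Win Rate (%)': metrics['Win Rate (%)'],
                'Total Trades': metrics['Total Trades']
            }
            # Include the requested metric even if it is not a default column
            # (e.g. 'Sortino Ratio'), so the sort below cannot fail on it
            if metric not in row:
                row[metric] = metrics[metric]
            comparison.append(row)

        df = pd.DataFrame(comparison)
        if df.empty:
            # No backtests have been run yet; nothing to sort
            return df

        df = df.sort_values(by=metric, ascending=False)

        return df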
```



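**Explanation:** The `BacktestEngine` is the core execution component. It converts a strategy's `position` series into daily returns, deducts commission and slippage in proportion to each change in position, and compounds the net returns into an equity curve. From that curve it derives annualized return and volatility, the Sharpe, Sortino, and Calmar ratios, drawdown statistics, and day-based win/loss counts (note that `Total Trades` counts days with a non-zero return). Results can optionally be persisted to MySQL for later comparison.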
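The cost and equity arithmetic described above can be sketched on a toy series. This is a minimal illustration, independent of `BacktestEngine`: the prices, positions, and cost rates are made-up values, and it assumes positions are lagged by one bar before being applied to returns.

```python
import pandas as pd

# Toy inputs (hypothetical values, for illustration only)
prices = pd.Series([100.0, 101.0, 99.0, 102.0, 103.0])
position = pd.Series([0, 1, 1, 0, 0])      # target position at each close
commission, slippage = 0.001, 0.0005       # 0.1% and 0.05% per unit of turnover

# Yesterday's position earns today's price return
gross = position.shift(1) * prices.pct_change()

# Costs are charged on the absolute change in position
turnover = position.diff().abs()
net = (gross - turnover * (commission + slippage)).fillna(0)

# Compound net returns into an equity curve and measure drawdown from the peak
equity = (1 + net).cumprod() * 100_000
running_max = equity.expanding().max()
drawdown = (equity - running_max) / running_max * 100

print(net.round(5).tolist())
print(f"Final equity: {equity.iloc[-1]:,.2f}, max drawdown: {drawdown.min():.2f}%")
```

Entering and exiting the position each cost 0.15% here, so a round trip costs 0.30% before any price move is earned.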



### 6. Visualization Module

```python
from typing import Dict, Optional

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import seaborn as sns
import matplotlib.pyplot as plt

class ResultsVisualizer:
    """
    Advanced visualization for backtesting results.
    Creates interactive Plotly charts for analysis.
    """

    @staticmethod
    def plot_equity_curves(results: Dict, benchmark_data: Optional[pd.Series] = None):
        """
        Plot equity curves for multiple strategies with optional benchmark.

        Args:
            results: Dictionary of backtest results
            benchmark_data: Optional benchmark price series (e.g., data['SPY'])
                for comparison
        """
        fig = go.Figure()

        # Plot strategy equity curves
        for name, result in results.items():
            fig.add_trace(go.Scatter(
                x=result['equity_curve'].index,
                y=result['equity_curve'],
                mode='lines',
                name=name,
                hovertemplate='%{y:$,.2f}<extra></extra>'
            ))

        # Add benchmark if provided, rescaled to start at the first strategy's
        # initial capital (the legend label below assumes SPY)
        if benchmark_data is not None:
            benchmark_norm = (benchmark_data / benchmark_data.iloc[0]) * \
                           list(results.values())[0]['initial_capital']
            fig.add_trace(go.Scatter(
                x=benchmark_norm.index,
                y=benchmark_norm,
                mode='lines',
                name='Benchmark (SPY)',
                line=dict(dash='dash', color='gray'),
                hovertemplate='%{y:$,.2f}<extra></extra>'
            ))

        fig.update_layout(
            title='Equity Curves Comparison',
            xaxis_title='Date',
            yaxis_title='Portfolio Value ($)',
            hovermode='x unified',
            template='plotly_dark',
            height=600,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            )
        )

        fig.show()

    @staticmethod
    def plot_drawdown(result: Dict):
        """
        Plot underwater (drawdown) chart.
        Shows periods when strategy was below peak equity.
        """
        equity = result['equity_curve']
        running_max = equity.expanding().max()
        drawdown = (equity - running_max) / running_max * 100

        fig = go.Figure()

        # Drawdown area
        fig.add_trace(go.Scatter(
            x=drawdown.index,
            y=drawdown,
            fill='tozeroy',
            name='Drawdown',
            line=dict(color='rgba(255,0,0,0.5)'),
            fillcolor='rgba(255,0,0,0.2)',
            hovertemplate='%{y:.2f}%<extra></extra>'
        ))

        # Mark maximum drawdown
        max_dd_idx = drawdown.idxmin()
        max_dd_value = drawdown.min()

        fig.add_trace(go.Scatter(
            x=[max_dd_idx],
            y=[max_dd_value],
            mode='markers',
            name='Max Drawdown',
            marker=dict(color='red', size=12, symbol='x'),
            hovertemplate=f'Max DD: {max_dd_value:.2f}%<extra></extra>'
        ))

        fig.update_layout(
            title=f"Drawdown Analysis - {result['strategy'].name}",
            xaxis_title='Date',
            yaxis_title='Drawdown (%)',
            template='plotly_dark',
            height=500,
            showlegend=True
        )

        fig.update_yaxes(range=[min(drawdown.min() * 1.1, -1), 1])

        fig.show()

    @staticmethod
    def plot_signals(result: Dict, data: pd.DataFrame):
        """
        Plot price action with trading signals overlaid.
        Different visualizations for different strategy types.
        """
        signals = result['signals']
        strategy = result['strategy']

        if isinstance(strategy, PairsTradingStrategy):
            # Pairs trading visualization: spread and z-score
            fig = make_subplots(
                rows=3, cols=1,
                subplot_titles=('Stock Prices', 'Spread', 'Z-Score & Signals'),
                vertical_spacing=0.08,
                row_heights=[0.35, 0.35, 0.30]
            )

            # Stock prices
            fig.add_trace(go.Scatter(
                x=data.index, y=data[strategy.stock1],
                name=strategy.stock1, line=dict(color='cyan')
            ), row=1, col=1)

            fig.add_trace(go.Scatter(
                x=data.index, y=data[strategy.stock2],
                name=strategy.stock2, line=dict(color='orange')
            ), row=1, col=1)

            # Spread
            fig.add_trace(go.Scatter(
                x=signals.index, y=signals['spread'],
                name='Spread', line=dict(color='white')
            ), row=2, col=1)

            fig.add_trace(go.Scatter(
                x=signals.index, y=signals['spread_mean'],
                name='Mean', line=dict(color='yellow', dash='dash')
            ), row=2, col=1)

            # Z-score with entry/exit levels
            fig.add_trace(go.Scatter(
                x=signals.index, y=signals['z_score'],
                name='Z-Score', line=dict(color='cyan')
            ), row=3, col=1)

            # Entry thresholds
            fig.add_hline(y=strategy.entry_z, line_dash="dash",
                         line_color="red", opacity=0.5, row=3, col=1)
            fig.add_hline(y=-strategy.entry_z, line_dash="dash",
                         line_color="green", opacity=0.5, row=3, col=1)
            fig.add_hline(y=strategy.exit_z, line_dash="dot",
                         line_color="gray", opacity=0.3, row=3, col=1)
            fig.add_hline(y=-strategy.exit_z, line_dash="dot",
                         line_color="gray", opacity=0.3, row=3, col=1)

            # Mark positions
            long_positions = signals[signals['position'] == 1]
            short_positions = signals[signals['position'] == -1]

            fig.add_trace(go.Scatter(
                x=long_positions.index,
                y=long_positions['z_score'],
                mode='markers',
                name='Long',
                marker=dict(color='green', size=8, symbol='triangle-up')
            ), row=3, col=1)

            fig.add_trace(go.Scatter(
                x=short_positions.index,
                y=short_positions['z_score'],
                mode='markers',
                name='Short',
                marker=dict(color='red', size=8, symbol='triangle-down')
            ), row=3, col=1)

        else:
            # Single-stock strategy visualization
            symbol = strategy.symbol

            fig = make_subplots(
                rows=2, cols=1,
                subplot_titles=('Price & Signals', 'Returns Distribution'),
                vertical_spacing=0.12,
                row_heights=[0.7, 0.3]
            )

            # Price chart
            fig.add_trace(go.Scatter(
                x=data.index, y=data[symbol],
                name='Price', line=dict(color='white', width=2)
            ), row=1, col=1)

            # Add moving averages if available
            if hasattr(strategy, 'window'):
                fig.add_trace(go.Scatter(
                    x=signals.index, y=signals['sma'],
                    name='SMA', line=dict(color='yellow', dash='dash')
                ), row=1, col=1)

                if 'upper_band' in signals.columns:
                    fig.add_trace(go.Scatter(
                        x=signals.index, y=signals['upper_band'],
                        name='Upper BB', line=dict(color='red', dash='dot')
                    ), row=1, col=1)

                    fig.add_trace(go.Scatter(
                        x=signals.index, y=signals['lower_band'],
                        name='Lower BB', line=dict(color='green', dash='dot')
                    ), row=1, col=1)

            # Buy/Sell signals
            buys = signals[signals['signal'] == 1]
            sells = signals[signals['signal'] == -1]

            if len(buys) > 0:
                fig.add_trace(go.Scatter(
                    x=buys.index, y=data.loc[buys.index, symbol],
                    mode='markers', name='Buy',
                    marker=dict(color='green', size=12, symbol='triangle-up')
                ), row=1, col=1)

            if len(sells) > 0:
                fig.add_trace(go.Scatter(
                    x=sells.index, y=data.loc[sells.index, symbol],
                    mode='markers', name='Sell',
                    marker=dict(color='red', size=12, symbol='triangle-down')
                ), row=1, col=1)

            # Returns histogram
            returns = result['returns'].dropna()
            fig.add_trace(go.Histogram(
                x=returns * 100,
                name='Returns Distribution',
                marker=dict(color='cyan'),
                nbinsx=50
            ), row=2, col=1)

        fig.update_layout(
            title=f"Trading Signals - {strategy.name}",
            template='plotly_dark',
            height=800,
            showlegend=True
        )

        fig.update_xaxes(title_text="Date", row=3 if isinstance(strategy, PairsTradingStrategy) else 2, col=1)

        fig.show()

    @staticmethod
    def plot_monthly_returns(result: Dict):
        """
        Create a heatmap of monthly returns.
        Useful for identifying seasonal patterns.
        """
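        equity = result['equity_curve']

        # Calculate monthly returns from month-end equity
        # ('M' is the month-end alias; pandas >= 2.2 prefers 'ME')
        monthly = equity.resample('M').last().pct_change() * 100
        # Name the column explicitly: the equity Series is usually unnamed,
        # so equity.name cannot be relied on as the pivot value column
        monthly_df = monthly.to_frame(name='Return')
        monthly_df['Year'] = monthly_df.index.year
        monthly_df['Month'] = monthly_df.index.month

        # Pivot for heatmap
        heatmap_data = monthly_df.pivot(index='Year', columns='Month',
                                        values='Return')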

        # Create heatmap
        plt.figure(figsize=(12, 6))
        sns.heatmap(heatmap_data, annot=True, fmt='.1f', cmap='RdYlGn',
                   center=0, cbar_kws={'label': 'Return (%)'})
        plt.title(f'Monthly Returns Heatmap - {result["strategy"].name}')
        plt.xlabel('Month')
        plt.ylabel('Year')
        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_rolling_metrics(result: Dict, window: int = 60):
        """
        Plot rolling Sharpe ratio and other metrics over time.
        Shows strategy stability and regime changes.
        """
        returns = result['returns']

        # Calculate rolling metrics (the rolling Sharpe below omits the
        # risk-free rate, unlike the summary Sharpe ratio)
        rolling_return = returns.rolling(window).mean() * 252 * 100
        rolling_vol = returns.rolling(window).std() * np.sqrt(252) * 100
        rolling_sharpe = rolling_return / rolling_vol

        fig = make_subplots(
            rows=3, cols=1,
            subplot_titles=('Rolling Annual Return', 'Rolling Volatility',
                          'Rolling Sharpe Ratio'),
            vertical_spacing=0.08
        )

        fig.add_trace(go.Scatter(
            x=rolling_return.index, y=rolling_return,
            name='Rolling Return', line=dict(color='cyan')
        ), row=1, col=1)

        fig.add_trace(go.Scatter(
            x=rolling_vol.index, y=rolling_vol,
            name='Rolling Volatility', line=dict(color='orange')
        ), row=2, col=1)

        fig.add_trace(go.Scatter(
            x=rolling_sharpe.index, y=rolling_sharpe,
            name='Rolling Sharpe', line=dict(color='green')
        ), row=3, col=1)

        # Add zero line for Sharpe
        fig.add_hline(y=0, line_dash="dash", line_color="gray",
                     opacity=0.5, row=3, col=1)

        fig.update_layout(
            title=f'Rolling Metrics ({window}-day window) - {result["strategy"].name}',
            template='plotly_dark',
            height=900,
            showlegend=False
        )

        fig.update_yaxes(title_text="Return (%)", row=1, col=1)
        fig.update_yaxes(title_text="Volatility (%)", row=2, col=1)
        fig.update_yaxes(title_text="Sharpe Ratio", row=3, col=1)
        fig.update_xaxes(title_text="Date", row=3, col=1)

        fig.show()
```



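**Explanation:** The `ResultsVisualizer` builds interactive Plotly charts for equity curves, drawdowns, trading signals, and rolling metrics, and uses seaborn for a monthly-returns heatmap. `plot_signals` adapts to the strategy type: pairs trades show prices, spread, and z-score with entry and exit thresholds, while single-stock strategies show price with buy/sell markers and a returns histogram. The rolling metrics and the heatmap help reveal regime changes and seasonal patterns.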
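The Year x Month table behind the monthly-returns heatmap can be sketched without any plotting. This is a minimal example on a synthetic equity curve (hypothetical data growing 0.1% per business day). It groups by calendar year and month instead of calling `resample`, which is a substitution made here to avoid the month-end alias differing between pandas versions; it is not how `plot_monthly_returns` is written.

```python
import numpy as np
import pandas as pd

# Synthetic equity curve (hypothetical data): 0.1% growth per business day
idx = pd.bdate_range("2023-01-02", "2023-04-28")
equity = pd.Series(100_000 * 1.001 ** np.arange(len(idx)), index=idx, name="equity")

# Last equity value of each calendar month
month_end = equity.groupby([equity.index.year, equity.index.month]).last()
month_end.index.names = ["Year", "Month"]

# Month-over-month return in percent; the first month has no prior value
monthly_ret = month_end.pct_change() * 100

# Rows are years, columns are months: the shape sns.heatmap expects
heatmap_data = monthly_ret.unstack("Month")
print(heatmap_data.round(2))
```

The first month is NaN because there is no prior month-end to compare against; seaborn leaves such cells blank.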




### 7. Complete Usage Example

```python
# ============================================================================
# COMPLETE EXAMPLE: From Data Fetching to Results
# ============================================================================

import warnings
warnings.filterwarnings('ignore')

def main():
    """
    Complete workflow demonstration:
    1. Setup database
    2. Fetch and cache data
    3. Run multiple strategies
    4. Analyze and compare results
    5. Visualize performance
    """
    
    print("="*60)
    print("STATISTICAL ARBITRAGE BACKTESTING ENGINE")
    print("="*60)
    

   

    # ========================================
    # STEP 1: Initialize Components
    # ========================================
    print("\n Initializing components...")

    # Database handler (optional - comment out if not using MySQL)
    try:
        db_handler = DatabaseHandler(
            host='localhost',
            user='root',
            password='your_password',  # Change this!
            database='trading_db'
        )
        print("✓ Database connected")
    except Exception:
        print("⚠ Database not available, using memory only")
        db_handler = None

    # Data handler with caching
    data_handler = DataHandler(db_handler=db_handler)

    # Backtest engine
    engine = BacktestEngine(
        initial_capital=100000,
        commission=0.001,  # 0.1%
        slippage=0.0005,   # 0.05%
        db_handler=db_handler
    )

    # ========================================
    # STEP 2: Fetch Data
    # ========================================
    print("\n Fetching market data...")

    symbols = ['AAPL', 'MSFT', 'GOOGL', 'AMZN', 'SPY']
    start_date = '2020-01-01'
    end_date = '2024-01-01'

    # Fetch data (will use cache if available)
    data = data_handler.fetch_data(
        symbols=symbols,
        start_date=start_date,
        end_date=end_date,
        use_cache=True
    )

    # Print data summary
    summary = data_handler.get_data_summary()
    print(f"\n Data loaded: {summary['total_rows']} rows")
    print(f"  Period: {summary['start_date']} to {summary['end_date']}")
    print(f"  Symbols: {', '.join(summary['symbols'])}")


    # ========================================
    # STEP 3: Define Strategies
    # ========================================
    print("\n Setting up strategies...")

    strategies = [
        # Pairs Trading: AAPL vs MSFT
        PairsTradingStrategy(
            stock1='AAPL',
            stock2='MSFT',
            lookback=60,
            entry_z=2.0,
            exit_z=0.5
        ),

        # Mean Reversion: GOOGL
        MeanReversionStrategy(
            symbol='GOOGL',
            window=20,
            num_std=2.0
        ),

        # Momentum: AMZN
        MomentumStrategy(
            symbol='AMZN',
            short_window=20,
            long_window=50
        ),

        # Additional pairs trade
        PairsTradingStrategy(
            stock1='GOOGL',
            stock2='AMZN',
            lookback=90,
            entry_z=1.5,
            exit_z=0.3
        )
    ]

    print(f"✓ {len(strategies)} strategies configured")


    # ========================================
    # STEP 4: Run Backtests
    # ========================================
    print("\n Running backtests...")

    results = {}
    for strategy in strategies:
        try:
            result = engine.run(
                strategy=strategy,
                data=data,
                save_to_db=(db_handler is not None)
            )
            results[strategy.name] = result
        except Exception as e:
            print(f"✗ Error running {strategy.name}: {e}")

    # ========================================
    # STEP 5: Compare Strategies
    # ========================================
    print("\n Strategy Comparison")
    print("="*60)

    comparison = engine.compare_strategies(metric='Sharpe Ratio')
    print(comparison.to_string(index=False))


    # ========================================
    # STEP 6: Visualize Results
    # ========================================
    print("\n Generating visualizations...")

    viz = ResultsVisualizer()

    # Equity curves (with SPY benchmark)
    viz.plot_equity_curves(results, benchmark_data=data['SPY'])

    # Detailed analysis for each strategy
    for name, result in results.items():
        print(f"\nAnalyzing: {name}")
        viz.plot_signals(result, data)
        viz.plot_drawdown(result)
        viz.plot_rolling_metrics