# In [None]:
import pandas as pd
import numpy as np
import yfinance as yf
import matplotlib.pyplot as plt

def calculate_zscore(series, window):
    """Return the rolling Z-score of ``series``.

    Each value is centred on the rolling mean of the trailing ``window``
    observations and divided by the rolling standard deviation of the same
    observations. Positions without a full window come out as NaN.
    """
    trailing = series.rolling(window=window)
    centred = series - trailing.mean()
    return centred / trailing.std()

class QuantileZScoreStrategy:
    """Pairs-trading strategy on the price spread of two tickers.

    The spread is ``price1 - price2``. A position is opened only when two
    conditions agree: the spread lies outside its rolling 10th/90th
    percentile band AND its rolling Z-score exceeds ``entry_z`` in the same
    direction. The position is closed once the Z-score has moved back inside
    ``exit_z``.

    Intended call order: ``fetch_data`` -> ``calculate_indicators`` ->
    ``generate_signals`` -> ``backtest`` / ``plot_results``.
    """

    def __init__(self, ticker1, ticker2, window=20, entry_z=1.5, exit_z=1):
        """
        Initialize strategy parameters.

        Args:
            ticker1: Symbol of the first leg (the spread is price1 - price2).
            ticker2: Symbol of the second leg.
            window: Number of observations in every rolling statistic.
            entry_z: Absolute Z-score that must be exceeded to open a trade.
            exit_z: Absolute Z-score inside which an open trade is closed.
        """
        self.ticker1 = ticker1
        self.ticker2 = ticker2
        self.window = window
        self.entry_z = entry_z
        self.exit_z = exit_z
        self.data = None    # DataFrame built by fetch_data()
        self.signals = []   # list of (signal_type, timestamp) tuples

    @staticmethod
    def _download_close(ticker, period, interval):
        """Download one ticker and return its 'Close' prices as a Series."""
        frame = yf.download(ticker, period=period, interval=interval)
        close = frame['Close']
        if isinstance(close, pd.DataFrame):
            # Some yfinance versions return (field, ticker) MultiIndex
            # columns, in which case ['Close'] is a one-column DataFrame.
            close = close.iloc[:, 0]
        return close

    def fetch_data(self, period='6mo', interval='1d'):
        """Download and align historical price data for both assets.

        Populates ``self.data`` with the columns ``price1``, ``price2`` and
        ``spread`` (price1 - price2), restricted to timestamps present in
        both downloads and with no missing price.
        """
        close1 = self._download_close(self.ticker1, period, interval)
        close2 = self._download_close(self.ticker2, period, interval)

        # Align on the index only. Aligning whole frames would also intersect
        # the column labels, which can differ between the two downloads.
        close1, close2 = close1.align(close2, join='inner')

        prices = pd.DataFrame({'price1': close1, 'price2': close2}).dropna()
        prices['spread'] = prices['price1'] - prices['price2']
        self.data = prices

    def calculate_indicators(self):
        """Calculate technical indicators for trading signals.

        Adds ``spread_q10``, ``spread_q90``, ``spread_median`` and
        ``z_score`` columns to ``self.data``, all computed over ``window``
        trailing observations of the spread.
        """
        rolling_spread = self.data['spread'].rolling(self.window)

        # Rolling quantile band and median of the spread
        self.data['spread_q10'] = rolling_spread.quantile(0.10)
        self.data['spread_q90'] = rolling_spread.quantile(0.90)
        self.data['spread_median'] = rolling_spread.median()

        # Z-score of spread
        self.data['z_score'] = calculate_zscore(self.data['spread'], self.window)

    def generate_signals(self):
        """Generate long/short signals based on quantile and Z-score thresholds.

        Rebuilds ``self.signals`` from scratch as a chronological list of
        ``(signal_type, timestamp)`` tuples, where ``signal_type`` is one of
        ``'long'``, ``'short'``, ``'exit_long'`` or ``'exit_short'``. At most
        one position is open at a time. A position still open at the end of
        the data has no matching exit signal.
        """
        # Start clean so that calling this method again does not append a
        # second copy of every signal.
        self.signals = []
        open_side = None  # 'long', 'short', or None when flat

        frame = self.data
        rows = zip(
            frame.index,
            frame['spread'].to_numpy(),
            frame['z_score'].to_numpy(),
            frame['spread_q10'].to_numpy(),
            frame['spread_q90'].to_numpy(),
        )

        for position, (timestamp, spread, z_score, q10, q90) in enumerate(rows):
            # Skip initial window period and NaN values
            if position < self.window or np.isnan(z_score):
                continue

            if open_side is None:
                # Short entry: spread is high and Z-score confirms divergence
                if spread > q90 and z_score > self.entry_z:
                    open_side = 'short'
                # Long entry: spread is low and Z-score confirms divergence
                elif spread < q10 and z_score < -self.entry_z:
                    open_side = 'long'
                if open_side is not None:
                    self.signals.append((open_side, timestamp))
            elif open_side == 'short' and z_score < self.exit_z:
                # Exit once the Z-score has come back down towards the mean
                self.signals.append(('exit_short', timestamp))
                open_side = None
            elif open_side == 'long' and z_score > -self.exit_z:
                # Exit once the Z-score has come back up towards the mean
                self.signals.append(('exit_long', timestamp))
                open_side = None

    def backtest(self, initial_capital=1000):
        """Backtest the strategy and calculate performance metrics.

        Args:
            initial_capital: Starting portfolio value.

        Returns:
            ``None`` when there are no signals; otherwise a dict with
            ``initial_capital``, ``final_value``, ``portfolio_history``
            (value after each closed trade, starting with the initial
            capital) and ``trade_history`` (profit of each closed trade).
            A position that is never exited contributes nothing.
        """
        if not self.signals:
            return None

        portfolio = initial_capital
        position = None  # Current position details
        portfolio_history = [portfolio]  # Track portfolio value over time
        trade_history = []  # Track individual trade profits

        for sig_type, timestamp in self.signals:
            price1 = self.data.loc[timestamp, 'price1']
            price2 = self.data.loc[timestamp, 'price2']

            if sig_type in ('long', 'short'):
                # Position sizing: one "unit" is one share of each leg, and
                # the combined value of both legs is 50% of the portfolio.
                position = {
                    'type': sig_type,
                    'entry_spread': price1 - price2,
                    'units': portfolio * 0.5 / (price1 + price2),
                }
                continue

            if position is None:
                # Exit without a matching entry: nothing to close.
                continue

            # A long profits when the spread widens, a short when it narrows.
            spread_change = (price1 - price2) - position['entry_spread']
            if position['type'] != 'long':
                spread_change = -spread_change
            profit = position['units'] * spread_change

            portfolio += profit
            portfolio_history.append(portfolio)
            trade_history.append(profit)
            position = None

        return {
            'initial_capital': initial_capital,
            'final_value': portfolio,
            'portfolio_history': portfolio_history,
            'trade_history': trade_history,
        }

    def plot_results(self):
        """Visualize price series, spread, and trading signals with enhanced labels.

        Draws three stacked panels sharing the x axis: both price series,
        the spread with its rolling median, quantile band and signal
        markers, and the Z-score with the entry/exit thresholds. Requires
        ``calculate_indicators`` to have been run.
        """
        plt.figure(figsize=(16, 12))
        plt.suptitle(f'{self.ticker1}-{self.ticker2} Strategy Results', y=0.95)

        # Price plot
        ax1 = plt.subplot(3, 1, 1)
        ax1.plot(self.data['price1'], label=self.ticker1, color='blue')
        ax1.plot(self.data['price2'], label=self.ticker2, color='orange')
        ax1.legend()
        ax1.set_title('Price Series')
        ax1.set_ylabel('Price ($)')

        # Spread plot with signals
        ax2 = plt.subplot(3, 1, 2, sharex=ax1)
        ax2.plot(self.data['spread'], label='Spread', color='blue')
        ax2.plot(self.data['spread_median'],
                 label=f'Rolling Median ({self.window}D)',
                 linestyle='--',
                 color='black')

        # Quantile band labeling
        ax2.fill_between(self.data.index,
                         self.data['spread_q10'],
                         self.data['spread_q90'],
                         color='gray',
                         alpha=0.3,
                         label=f'{self.window}-Day 10th/90th Percentile Band')

        # (color, marker, legend label) per signal type. Exits use their own
        # marker so they can be told apart from entries on the chart.
        signal_styles = {
            'long': ('green', '^',
                     'Long Entry (Buy {} / Sell {})'.format(self.ticker1, self.ticker2)),
            'short': ('red', 'v',
                      'Short Entry (Sell {} / Buy {})'.format(self.ticker1, self.ticker2)),
            'exit_long': ('green', 'o', 'Exit Long'),
            'exit_short': ('red', 'o', 'Exit Short'),
        }
        fallback_style = ('purple', 'X', 'Unknown Signal')

        # Signal labeling with legend deduplication
        handled_labels = set()
        for sig, t in self.signals:
            color, marker, label = signal_styles.get(sig, fallback_style)
            # Only the first marker of each kind gets a legend entry;
            # '_nolegend_' keeps the others out of the legend.
            legend_label = '_nolegend_' if label in handled_labels else label
            handled_labels.add(label)
            ax2.scatter(t, self.data.loc[t, 'spread'],
                        color=color,
                        marker=marker,
                        s=100,
                        label=legend_label)

        ax2.legend()
        ax2.set_ylabel('Spread Value')
        ax2.set_title('Spread Analysis with Trading Signals')

        # Z-score plot
        ax3 = plt.subplot(3, 1, 3, sharex=ax1)
        ax3.plot(self.data['z_score'], label='Z-Score', color='purple')

        # Threshold labeling: only the positive line of each pair is labelled
        ax3.axhline(self.entry_z,
                    color='red',
                    linestyle='--',
                    label=f'Entry Threshold (±{self.entry_z})')
        ax3.axhline(-self.entry_z, color='red', linestyle='--')
        ax3.axhline(self.exit_z,
                    color='green',
                    linestyle=':',
                    label=f'Exit Threshold (±{self.exit_z})')
        ax3.axhline(-self.exit_z, color='green', linestyle=':')

        ax3.legend()
        ax3.set_xlabel('Date')
        ax3.set_ylabel('Z-Score')
        ax3.set_title('Standardized Spread Analysis')

        plt.tight_layout()
        plt.show()

if __name__ == "__main__":
    # End-to-end demo on the KO/PEP pair: download prices, build the
    # indicators, derive the signals, then report and chart the outcome.
    strategy = QuantileZScoreStrategy(
        'KO',
        'PEP',
        window=20,
        entry_z=1.2,
        exit_z=0.5,
    )
    # Daily bars are used for more stable signals.
    strategy.fetch_data(period='6mo', interval='1d')
    strategy.calculate_indicators()
    strategy.generate_signals()

    results = strategy.backtest()
    print("Backtest Results:")
    if not results:
        # backtest() returned nothing, i.e. no signals were generated.
        print("No trades executed.")
    else:
        print(f"Initial Capital: ${results['initial_capital']}")
        print(f"Final Value: ${results['final_value']:.2f}")

    strategy.plot_results()