# Pair Selection - Statistical Arbitrage RL

This notebook performs a grid search to identify the stock pairs with the minimum Empirical Mean Reversion Time (EMRT).

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import yaml
import sys

sys.path.append('..')

from data_acquisition import DataAcquisition
from emrt_calculator import EMRTCalculator
from pair_selection import PairSelector

# Notebook-wide plotting defaults: seaborn dark grid and a wide default figure.
sns.set_style('darkgrid')
plt.rcParams.update({'figure.figsize': (14, 8)})

print("Libraries imported successfully")

## 1. Load Data

In [None]:
# Build the acquisition helper from the shared project config and pull the
# full dataset.
data_acq = DataAcquisition('../config.yaml')
dataset = data_acq.fetch_full_dataset()

# Partition the price table into the training and testing frames.
price_table = dataset['prices']
train_prices, test_prices = data_acq.split_train_test(price_table)

# Report split sizes (row counts) and the number of columns in the training frame.
print(f"Training data: {len(train_prices)} days")
print(f"Testing data: {len(test_prices)} days")
print(f"Number of stocks: {train_prices.shape[1]}")

## 2. Correlation Analysis

In [None]:
# Create the pair selector from the shared config and compute the correlation
# matrix from the training prices only.
selector = PairSelector('../config.yaml')
corr_matrix = selector.calculate_correlations(train_prices)

corr_shape = corr_matrix.shape
print(f"Correlation matrix shape: {corr_shape}")

In [None]:
# Heatmap for a readable subset of the universe: the 20 tickers with the
# highest mean training price. Price is only used here as a rough stand-in
# for market cap, so "top 20" means "20 highest-priced".
mean_price_by_ticker = train_prices.mean()
top_20_tickers = list(mean_price_by_ticker.nlargest(20).index)
corr_subset = corr_matrix.loc[top_20_tickers, top_20_tickers]

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    corr_subset,
    ax=ax,
    annot=False,
    cmap='RdYlGn',
    center=0,
    vmin=-1,
    vmax=1,
    square=True,
    linewidths=0.5,
)
ax.set_title('Correlation Matrix (Top 20 Stocks)', fontsize=14, fontweight='bold')
fig.tight_layout()
plt.show()

In [None]:
# Distribution of pairwise correlations
# Single source for the reference threshold so the plotted line, its legend
# label and the printed count cannot drift apart.
CORR_THRESHOLD = 0.7

# Strict upper triangle (k=1): every unordered pair once, diagonal excluded.
corr_values = np.asarray(corr_matrix.values, dtype=float)
pairwise = corr_values[np.triu_indices_from(corr_values, k=1)]

# Undefined correlations (NaN) would make the histogram's bin range non-finite
# and turn the mean/median below into NaN, so they are dropped and reported.
correlations = pairwise[~np.isnan(pairwise)]
n_undefined = int(pairwise.size - correlations.size)
if n_undefined:
    print(f"Ignoring {n_undefined:,} pairs with undefined (NaN) correlation")

if correlations.size == 0:
    # Nothing to summarise (e.g. fewer than two columns, or all values NaN).
    print("No defined pairwise correlations to plot")
else:
    plt.figure(figsize=(12, 6))
    plt.hist(correlations, bins=50, edgecolor='black', alpha=0.7, color='steelblue')
    plt.axvline(x=CORR_THRESHOLD, color='red', linestyle='--', linewidth=2,
                label=f'Threshold ({CORR_THRESHOLD})')
    plt.xlabel('Correlation Coefficient', fontsize=12)
    plt.ylabel('Frequency', fontsize=12)
    plt.title('Distribution of Pairwise Correlations', fontsize=14, fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print(f"Mean correlation: {correlations.mean():.3f}")
    print(f"Median correlation: {np.median(correlations):.3f}")
    print(f"Pairs above {CORR_THRESHOLD} threshold: {(correlations > CORR_THRESHOLD).sum():,}")

## 3. EMRT Calculation - Sample Pairs

In [None]:
# Try the EMRT calculation on a few hand-picked pairs before the full search.
emrt_calc = EMRTCalculator('../config.yaml')

sample_pairs = [('MSFT', 'GOOGL'), ('CVS', 'JNJ'), ('PG', 'KO')]

sample_results = []

for ticker1, ticker2 in sample_pairs:
    # A pair is skipped when either ticker has no column in the training prices.
    if ticker1 not in train_prices.columns or ticker2 not in train_prices.columns:
        continue

    emrt, details = emrt_calc.calculate_emrt(train_prices[ticker1], train_prices[ticker2])

    # 'num_events' is required in the details; 'std_reversion_time' is optional
    # and falls back to NaN when absent.
    result_row = {
        'pair': f"{ticker1}-{ticker2}",
        'emrt': emrt,
        'num_events': details['num_events'],
        'std_reversion_time': details.get('std_reversion_time', np.nan),
    }
    sample_results.append(result_row)

sample_df = pd.DataFrame(sample_results)
print("\n=== Sample EMRT Results ===")
print(sample_df)

In [None]:
# Visualize spread and z-score for the sample pair with the lowest EMRT.
#
# The pair is recovered from the row's 'pair' label rather than by position:
# sample_df only holds pairs whose tickers were both present in train_prices,
# so its row numbers need not line up with positions in sample_pairs.
pair_by_label = {f"{t1}-{t2}": (t1, t2) for t1, t2 in sample_pairs}

# An empty frame has no 'emrt' column at all, so test emptiness before touching it.
rankable = sample_df if sample_df.empty else sample_df.dropna(subset=['emrt'])

if rankable.empty:
    print("No sample pair produced an EMRT value; skipping spread/z-score plot.")
else:
    best_pair_idx = rankable['emrt'].idxmin()  # index label, not a list position
    best_pair = pair_by_label[sample_df.loc[best_pair_idx, 'pair']]

    ticker1, ticker2 = best_pair
    spread = emrt_calc.calculate_spread(train_prices[ticker1], train_prices[ticker2])
    zscore = emrt_calc.calculate_zscore(spread)

    fig, axes = plt.subplots(2, 1, figsize=(14, 10))

    # Spread
    # NOTE(review): the y-label assumes calculate_spread returns a log price
    # ratio — confirm against EMRTCalculator.
    axes[0].plot(spread.index, spread, linewidth=1.5, color='navy')
    axes[0].set_title(f'Price Spread: {ticker1} vs {ticker2}', fontsize=12, fontweight='bold')
    axes[0].set_ylabel('Log Price Ratio')
    axes[0].grid(True, alpha=0.3)

    # Z-score with fixed +/-2 reference lines and the band between them shaded
    axes[1].plot(zscore.index, zscore, linewidth=1.5, color='darkgreen')
    axes[1].axhline(y=2, color='red', linestyle='--', linewidth=2, label='Upper Threshold (+2σ)')
    axes[1].axhline(y=-2, color='red', linestyle='--', linewidth=2, label='Lower Threshold (-2σ)')
    axes[1].axhline(y=0, color='black', linestyle='-', linewidth=1, alpha=0.5)
    axes[1].fill_between(zscore.index, -2, 2, alpha=0.2, color='green', label='Mean Reversion Zone')
    axes[1].set_title('Spread Z-Score', fontsize=12, fontweight='bold')
    axes[1].set_xlabel('Date')
    axes[1].set_ylabel('Z-Score')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # .loc because best_pair_idx is an index label.
    print(f"\nEMRT for {ticker1}-{ticker2}: {sample_df.loc[best_pair_idx, 'emrt']:.2f} days")

## 4. Grid Search - Full Pair Selection

In [None]:
# Full selection pipeline on the training prices; the dataset's constituents
# entry is passed alongside them.
print("Running grid search...")
selection_results = selector.run_selection(train_prices, dataset['constituents'])

# Pull out the two result tables that the following cells work with.
selected_pairs = selection_results['selected_pairs']
all_metrics = selection_results['all_metrics']

num_candidates = selection_results['num_candidates']
num_selected = len(selected_pairs)

print("\n=== Selection Complete ===")
print(f"Candidate pairs evaluated: {num_candidates}")
print(f"Pairs selected: {num_selected}")

In [None]:
# Print the headline columns of the selected pairs table.
summary_columns = ['pair_id', 'sector', 'emrt', 'correlation', 'num_events']

print("\n=== Selected Pairs ===")
print(selected_pairs[summary_columns])

## 5. EMRT Distribution Analysis

In [None]:
# Keep rows whose EMRT is strictly below +inf. NaN rows are dropped as well,
# because any comparison with NaN evaluates to False.
valid_metrics = all_metrics[all_metrics['emrt'] < np.inf].copy()
valid_emrt = valid_metrics['emrt']
emrt_p10 = valid_emrt.quantile(0.1)

fig, (ax_hist, ax_bar) = plt.subplots(1, 2, figsize=(16, 6))

# Left panel: histogram of EMRT over the retained rows, with the 10th
# percentile marked.
ax_hist.hist(valid_emrt, bins=50, edgecolor='black', alpha=0.7, color='steelblue')
ax_hist.axvline(
    x=emrt_p10,
    color='red',
    linestyle='--',
    linewidth=2,
    label='10th Percentile (Selection Threshold)',
)
ax_hist.set_xlabel('EMRT (days)', fontsize=12)
ax_hist.set_ylabel('Frequency', fontsize=12)
ax_hist.set_title('EMRT Distribution (All Candidate Pairs)', fontsize=14, fontweight='bold')
ax_hist.legend()
ax_hist.grid(True, alpha=0.3)

# Right panel: one bar per selected pair, in table order.
bar_positions = range(len(selected_pairs))
ax_bar.bar(bar_positions, selected_pairs['emrt'], color='coral', edgecolor='black')
ax_bar.set_xlabel('Pair Index', fontsize=12)
ax_bar.set_ylabel('EMRT (days)', fontsize=12)
ax_bar.set_title('EMRT of Selected Pairs', fontsize=14, fontweight='bold')
ax_bar.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\nEMRT Statistics (Valid Pairs):")
print(f"  Mean: {valid_emrt.mean():.2f} days")
print(f"  Median: {valid_emrt.median():.2f} days")
print(f"  10th percentile: {emrt_p10:.2f} days")
print(f"  Min: {valid_emrt.min():.2f} days")

## 6. Sector Analysis of Selected Pairs

In [None]:
# How many selected pairs fall into each sector (most frequent first).
sector_dist = selected_pairs['sector'].value_counts()

fig, (ax_counts, ax_box) = plt.subplots(1, 2, figsize=(16, 6))

# Left panel: pair count per sector.
sector_dist.plot(kind='bar', ax=ax_counts, color='teal', edgecolor='black')
ax_counts.set_title('Selected Pairs by Sector', fontsize=14, fontweight='bold')
ax_counts.set_xlabel('Sector')
ax_counts.set_ylabel('Number of Pairs')
ax_counts.tick_params(axis='x', rotation=45)
ax_counts.grid(True, alpha=0.3, axis='y')

# Right panel: EMRT spread within each sector.
selected_pairs.boxplot(column='emrt', by='sector', ax=ax_box)
ax_box.set_title('EMRT Distribution by Sector', fontsize=14, fontweight='bold')
ax_box.set_xlabel('Sector')
ax_box.set_ylabel('EMRT (days)')
# Grouped boxplots add a figure-level suptitle; blank it so only the axes
# titles remain.
ax_box.get_figure().suptitle('')

plt.tight_layout()
plt.show()

print("\nPairs per sector:")
for sector, count in sector_dist.items():
    in_sector = selected_pairs['sector'] == sector
    avg_emrt = selected_pairs.loc[in_sector, 'emrt'].mean()
    print(f"  {sector}: {count} pairs (Avg EMRT: {avg_emrt:.2f} days)")

## 7. Correlation vs EMRT Relationship

In [None]:
# Correlation against EMRT for the selected pairs, one scatter series per
# sector (sectors taken in order of first appearance).
fig, ax = plt.subplots(figsize=(12, 8))

for sector in selected_pairs['sector'].unique():
    in_sector = selected_pairs['sector'] == sector
    sector_data = selected_pairs[in_sector]
    ax.scatter(
        sector_data['correlation'],
        sector_data['emrt'],
        label=sector,
        alpha=0.7,
        s=150,
        edgecolors='black',
    )

ax.set_xlabel('Correlation', fontsize=12)
ax.set_ylabel('EMRT (days)', fontsize=12)
ax.set_title('Correlation vs EMRT for Selected Pairs', fontsize=14, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# Off-diagonal entry of the 2x2 correlation table between the two columns.
metric_corr_table = selected_pairs[['correlation', 'emrt']].corr()
corr_emrt_corr = metric_corr_table.loc['correlation', 'emrt']
print(f"\nCorrelation between Correlation and EMRT: {corr_emrt_corr:.3f}")
print("(Negative suggests higher correlation pairs tend to have faster mean reversion)")

## 8. Save Selected Pairs

In [None]:
# Persist the selection next to the notebook (path is relative to the current
# working directory); the index is not written.
output_path = 'selected_pairs.csv'
selected_pairs.to_csv(output_path, index=False)
print("Selected pairs saved to: selected_pairs.csv")

# Headline averages over the selected pairs.
mean_emrt = selected_pairs['emrt'].mean()
mean_correlation = selected_pairs['correlation'].mean()
mean_events = selected_pairs['num_events'].mean()

print("\n=== Final Selection Summary ===")
print(f"Total pairs selected: {len(selected_pairs)}")
print(f"Average EMRT: {mean_emrt:.2f} days")
print(f"Average correlation: {mean_correlation:.3f}")
print(f"Average deviation events: {mean_events:.1f}")

## Summary

This notebook performed comprehensive pair selection:

- **Correlation Analysis**: Identified candidate pairs with correlation > 0.7
- **EMRT Calculation**: Computed empirical mean reversion time for all candidates
- **Grid Search**: Selected the 10 pairs with the lowest EMRT
- **Sector Distribution**: Pairs span Technology, Healthcare, Consumer Goods, Financials
- **Key Finding**: Selected pairs have EMRT in the lowest 10th percentile of candidate pairs (fastest mean reversion)

**Next**: Train RL agent on selected pairs to learn optimal trading strategies.