In [None]:
# Extend the import path before the project-local import further down.
# NOTE(review): assumes the project is available at /app (e.g. inside the
# Docker container) — confirm for other environments.
import sys
import os
sys.path.append('/app')

# Numerical and plotting stack used by the cells below.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Project-local gRPC client wrapper; presumably resolved through the
# sys.path entry added above — verify against the deployment layout.
from python.pair_discovery_client import PairDiscoveryClient

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Apply the seaborn dark-grid style to every figure created afterwards.
plt.style.use('seaborn-v0_8-darkgrid')

## 1. Initialize gRPC Client

Connect to the Rust gRPC server, which must already be running (e.g., in Docker).

In [None]:
# Create the client; host and port are taken from the GRPC_HOST and
# GRPC_PORT environment variables.
client = PairDiscoveryClient()
# The status marker was mis-encoded (UTF-8 bytes shown as cp1252); restored to the intended check mark.
print(f"✅ Connected to gRPC server at {client.host}:{client.port}")

## 2. Generate Synthetic Cointegrated Data

Let's create a pair of cointegrated series for testing.

In [None]:
# Fixed seed so the synthetic data is the same on every run of this cell.
np.random.seed(42)
n = 1000

# x is a random walk shifted to start near 100; y is a linear function of x
# plus stationary Gaussian noise, so y - 1.5 * x is stationary by construction.
x = 100 + np.cumsum(np.random.randn(n))
spread = 2 * np.random.randn(n)  # Stationary spread
y = 1.5 * x + spread + 10  # y cointegrates with x

# Spread recovered with the known hedge ratio used above.
spread_series = y - 1.5 * x

fig, axes = plt.subplots(2, 1, figsize=(12, 6))
price_ax, spread_ax = axes

# Top panel: both price series.
for series, name in ((x, 'Asset X'), (y, 'Asset Y')):
    price_ax.plot(series, label=name, alpha=0.7)
price_ax.set_title('Cointegrated Price Series')

# Bottom panel: the spread and its sample mean.
spread_ax.plot(spread_series, label='Spread', color='green')
spread_ax.axhline(y=spread_series.mean(), color='r', linestyle='--', label='Mean')
spread_ax.set_title('Spread (Should be Stationary)')

# Shared cosmetics for both panels.
for panel in axes:
    panel.legend()
    panel.grid(True)

plt.tight_layout()
plt.show()

print(f"Generated {n} data points")

## 3. Test Single Pair

Use the `test_pair` method to get comprehensive statistics.

In [None]:
# Run the full single-pair analysis on the synthetic series.
result = client.test_pair(x, y, pair_name="X-Y")

banner = "=" * 60
ou = result['ou_params']  # nested dict holding theta / mu / sigma

# Header
print(banner, "PAIR TEST RESULTS", banner, sep="\n")

# Cointegration summary (trailing "" reproduces the blank separator line)
print(
    f"Pair Name:          {result['pair_name']}",
    f"Is Cointegrated:    {result['is_cointegrated']}",
    f"P-Value:            {result['p_value']:.6f}",
    f"Hedge Ratio:        {result['hedge_ratio']:.4f}",
    "",
    sep="\n",
)

# Ornstein-Uhlenbeck parameters
print(
    "OU Parameters:",
    f"  Theta (speed):    {ou['theta']:.6f}",
    f"  Mu (mean):        {ou['mu']:.6f}",
    f"  Sigma (vol):      {ou['sigma']:.6f}",
    "",
    sep="\n",
)

# Mean-reversion diagnostics
print(
    f"Hurst Exponent:     {result['hurst_exponent']:.4f}",
    f"Half-Life:          {result['half_life']:.2f} periods",
    "",
    sep="\n",
)

# Trading thresholds and footer
print(
    "Optimal Thresholds:",
    f"  Entry:            {result['entry_threshold']:.4f}",
    f"  Exit:             {result['exit_threshold']:.4f}",
    banner,
    sep="\n",
)

## 4. Backtest Strategy

Test the optimal strategy with the calculated thresholds.

In [None]:
# Spread implied by the hedge ratio returned from the pair test.
spread = y - result['hedge_ratio'] * x

# Collect the strategy settings in one place, then forward them as keywords.
backtest_kwargs = dict(
    spread=spread,
    entry_threshold=result['entry_threshold'],
    exit_threshold=result['exit_threshold'],
    transaction_cost=0.001,  # 0.1%
    initial_capital=100000.0,
)
backtest_result = client.backtest_strategy(**backtest_kwargs)

# Summary table
rule = "=" * 60
print(
    rule,
    "BACKTEST RESULTS",
    rule,
    f"Total Return:       {backtest_result['total_return']:.2f}%",
    f"Sharpe Ratio:       {backtest_result['sharpe_ratio']:.4f}",
    f"Max Drawdown:       {backtest_result['max_drawdown']:.2f}%",
    f"Number of Trades:   {backtest_result['num_trades']}",
    f"Win Rate:           {backtest_result['win_rate']:.2f}%",
    rule,
    sep="\n",
)

# PnL curve, drawn through the explicit Axes interface.
pnl_fig, pnl_ax = plt.subplots(figsize=(12, 4))
pnl_ax.plot(backtest_result['pnl'])
pnl_ax.set_title('Strategy Profit & Loss Over Time')
pnl_ax.set_xlabel('Time')
pnl_ax.set_ylabel('PnL ($)')
pnl_ax.grid(True)
plt.show()

## 5. Discover Multiple Pairs (Streaming)

Test multiple pairs in parallel with streaming results.

In [None]:
# Generate multiple synthetic series
n_assets = 10
n_points = 500

price_matrix = np.zeros((n_assets, n_points))
# Despite the name, these are per-asset labels: one entry per row of price_matrix.
pair_names = [f"ASSET_{i}" for i in range(n_assets)]

# Fill the rows two at a time so that (0, 1), (2, 3), ... are cointegrated.
# Stepping by 2 is essential: visiting every index would revisit each odd row
# and overwrite the cointegrated partner with an independent random walk,
# leaving no cointegrated pair in the matrix.
for i in range(0, n_assets - 1, 2):
    base = np.cumsum(np.random.randn(n_points)) + 100
    price_matrix[i] = base
    price_matrix[i + 1] = 1.2 * base + np.random.randn(n_points) * 0.5

# With an odd number of assets the last row has no partner: plain random walk.
if n_assets % 2 == 1:
    price_matrix[-1] = np.cumsum(np.random.randn(n_points)) + 100

print(f"Testing {n_assets} assets ({n_assets * (n_assets - 1) // 2} pairs)...")
print("Streaming results as they arrive:\n")

# Use a dedicated loop variable so the single-pair `result` dict from the
# earlier cell is not clobbered (the backtest cell reads it on re-run).
discovered_pairs = []
for pair_result in client.discover_pairs(
    price_matrix,
    pair_names,
    max_p_value=0.05,
    min_half_life=1.0,
    max_half_life=100.0
):
    discovered_pairs.append(pair_result)
    # Status markers below were mis-encoded (UTF-8 shown as cp1252); restored.
    print(f"✅ {pair_result['pair_name']:20s} | p={pair_result['p_value']:.4f} | H={pair_result['hurst_exponent']:.3f} | HL={pair_result['half_life']:.1f}")

print(f"\n🎯 Discovered {len(discovered_pairs)} cointegrated pairs")

## 6. Individual Components

Test individual methods for more control.

In [None]:
# --- Cointegration test on the raw series ---
coint_result = client.test_cointegration(x, y)
print(
    "Cointegration Test:",
    f"  Is Cointegrated: {coint_result['is_cointegrated']}",
    f"  P-Value:         {coint_result['p_value']:.6f}",
    f"  Hedge Ratio:     {coint_result['hedge_ratio']:.4f}",
    "",  # blank separator line
    sep="\n",
)

# --- Ornstein-Uhlenbeck fit on the spread ---
ou_result = client.estimate_ou_params(spread)
print(
    "OU Parameters:",
    f"  Theta:    {ou_result['theta']:.6f}",
    f"  Mu:       {ou_result['mu']:.6f}",
    f"  Sigma:    {ou_result['sigma']:.6f}",
    f"  Half-Life: {ou_result['half_life']:.2f}",
    "",  # blank separator line
    sep="\n",
)

# --- Hurst exponent of the spread ---
hurst_result = client.calculate_hurst(spread)
print(
    "Hurst Exponent:",
    f"  H:               {hurst_result['hurst_exponent']:.4f}",
    f"  Mean Reverting:  {hurst_result['is_mean_reverting']}",
    f"  95% CI:          {hurst_result['confidence_interval']}",
    sep="\n",
)

## 7. Solve HJB PDE for Optimal Control

Use the generic optimal control module to solve the Hamilton-Jacobi-Bellman equation.

In [None]:
# State-grid bounds, shared by the solver call and the plotting grid below.
x_lo, x_hi = -3.0, 3.0

# Solve the HJB equation using the OU parameters estimated in the previous cell.
hjb_result = client.solve_hjb(
    theta=ou_result['theta'],
    mu=ou_result['mu'],
    sigma=ou_result['sigma'],
    x_min=x_lo,
    x_max=x_hi,
    n_points=201,
    dt=0.01,
    max_iterations=10000,
)

policy = hjb_result['policy']
entry_level = policy['entry_threshold']
exit_level = policy['exit_threshold']
value_fn = hjb_result['value_function']

print(
    "HJB Solution:",
    f"  Converged:        {hjb_result['converged']}",
    f"  Iterations:       {hjb_result['iterations']}",
    f"  Entry Threshold:  {entry_level:.4f}",
    f"  Exit Threshold:   {exit_level:.4f}",
    sep="\n",
)

# Plot value function on a grid with one abscissa per returned value.
x_grid = np.linspace(x_lo, x_hi, len(value_fn))
value_fig, value_ax = plt.subplots(figsize=(10, 5))
value_ax.plot(x_grid, value_fn)
value_ax.axvline(entry_level, color='g', linestyle='--', label='Entry')
value_ax.axvline(exit_level, color='r', linestyle='--', label='Exit')
value_ax.set_title('Optimal Value Function V(x)')
value_ax.set_xlabel('State (spread deviation)')
value_ax.set_ylabel('Value')
value_ax.legend()
value_ax.grid(True)
plt.show()

## 8. Clean Up

Close the gRPC connection.

In [None]:
# Release the gRPC connection held by the client.
client.close()
# The status marker was mis-encoded (UTF-8 bytes shown as cp1252); restored to the intended check mark.
print("✅ gRPC connection closed")