In [4]:
import numpy as np
import pandas as pd
import yfinance as yf
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az



def inference_gbm(prices, time_window, plot =False):
    """Infer annualised GBM drift and volatility from a price series with PyMC.

    Log returns over ``time_window`` steps are modelled as
    ``Normal(mu * dt, sigma * sqrt(dt))`` with ``dt = time_window / 252``,
    a ``Normal(0, 10)`` prior on ``mu`` and a ``HalfNormal(10)`` prior on
    ``sigma``.

    Parameters:
        prices: price series; must support ``/`` and ``.shift`` (the caller
            in this file passes slices of a pandas price series).
        time_window: number of rows between the two prices of each return.
        plot: when true, show the sampler progress bar and the trace plot.

    Returns:
        Tuple ``(mu_mean, mu_std, sigma_mean, sigma_std)`` read from the
        ``mean`` and ``sd`` columns of ``az.summary`` (rounded to 4 decimals).
    """
    year_fraction = time_window / 252

    # Log return of each price against the price `time_window` rows earlier;
    # the leading rows that have no earlier price are dropped.
    price_ratio = prices / prices.shift(time_window)
    log_returns = np.log(price_ratio).dropna().values

    with pm.Model():
        drift = pm.Normal("mu", mu=0.0, sigma=10)          # per year
        volatility = pm.HalfNormal("sigma", sigma=10)      # per sqrt(year)

        # Likelihood of the observed returns under the scaled GBM parameters.
        pm.Normal(
            "returns",
            mu=drift * year_fraction,
            sigma=volatility * np.sqrt(year_fraction),
            observed=log_returns,
        )

        posterior = pm.sample(
            draws=2000,
            tune=2000,
            target_accept=0.95,
            return_inferencedata=True,
            progressbar=plot,
        )

    parameter_names = ["mu", "sigma"]

    if plot:
        az.plot_trace(posterior, var_names=parameter_names)
        plt.tight_layout()
        plt.show()

    stats = az.summary(posterior, var_names=parameter_names, round_to=4)

    return (
        stats.loc["mu", "mean"],
        stats.loc["mu", "sd"],
        stats.loc["sigma", "mean"],
        stats.loc["sigma", "sd"],
    )

def load_price_data(ticker, start="2013-01-01", end="2023-01-01"):
    """Download the ``Close`` prices for ``ticker`` via ``yf.download``.

    Parameters:
        ticker: symbol (or symbols) passed straight to ``yf.download``.
        start: first date requested, passed through as ``start``.
        end: end date requested, passed through as ``end``.

    Returns:
        The ``Close`` data with missing rows dropped. When the result has
        exactly one column it is returned as a ``pd.Series``; with several
        columns the DataFrame is returned unchanged.
    """
    close = yf.download(ticker, start=start, end=end)["Close"].dropna()

    # Depending on the installed yfinance version, ["Close"] for a single
    # ticker is either a Series or a one-column DataFrame. The rest of this
    # file iterates over prices value by value, which only works on a Series,
    # so collapse the single-column case. Multi-column frames pass through.
    if isinstance(close, pd.DataFrame):
        close = close.squeeze("columns")
    return close

def dynamic_hedge_ratio(volatility_A, volatility_B):
    """Return the hedge ratio, defined as volatility_A divided by volatility_B."""
    ratio = volatility_A / volatility_B
    return ratio

def calculate_spread(price_A, price_B, hedge_ratio):
    """Return the spread: price_A minus hedge_ratio times price_B."""
    hedged_leg = hedge_ratio * price_B
    return price_A - hedged_leg

def zscore(series, window=20):
    """Return the rolling z-score of ``series``.

    Each value is centred on the rolling mean over ``window`` rows and
    divided by the rolling standard deviation over the same window.
    """
    rolling = series.rolling(window)
    centred = series - rolling.mean()
    return centred / rolling.std()

def backtest_pairs_strategy(price_A, price_B, window_vol=20, z_entry=2.0, z_exit=0.5, window_z=20):
    """
    Backtest volatility-aware pairs trading with Bayesian volatilities and dynamic hedge ratio.

    For every day after the first ``window_vol`` days, the volatility of each
    stock is inferred with ``inference_gbm`` from the preceding ``window_vol``
    prices (the current day is excluded). The hedge ratio is the ratio of
    the two volatilities, the spread is ``price_A - hedge_ratio * price_B``
    and trades are driven by the rolling z-score of that spread.

    Trading rules:
        * flat and z >  z_entry  -> short the spread (position -1)
        * flat and z < -z_entry  -> long the spread  (position +1)
        * long  and z >= -z_exit -> close
        * short and z <=  z_exit -> close

    PnL is realised only on the day a position is closed and is measured with
    the hedge ratio that was in force at entry. A position still open on the
    last day contributes nothing to the PnL.

    Parameters:
        price_A, price_B: pd.Series of stock prices (a one-column DataFrame
            is also accepted and collapsed to a Series). Only dates present
            in both series are used.
        window_vol: int, rolling window for volatility inference (days)
        z_entry: float, z-score threshold to enter a position
        z_exit: float, z-score threshold to exit position
        window_z: int, rolling window for the spread z-score (days)

    Returns:
        Tuple ``(positions, pnl, cumulative_pnl)`` of pd.Series that all share
        the index of the dates after the first ``window_vol`` days.

    Side effects:
        Shows a two-panel matplotlib figure (z-score with thresholds and
        cumulative PnL).

    Note:
        One MCMC fit per stock per day is run, so runtime grows linearly with
        the length of the price history and is dominated by sampling.
    """
    # Collapse one-column DataFrames so the row-wise loop below iterates over
    # prices rather than over column labels.
    if isinstance(price_A, pd.DataFrame):
        price_A = price_A.squeeze("columns")
    if isinstance(price_B, pd.DataFrame):
        price_B = price_B.squeeze("columns")

    # Keep only dates on which both legs have a price so that positional
    # slicing below refers to the same day in both series.
    price_A, price_B = price_A.align(price_B, join="inner")

    # Step 1: Estimate volatility for each stock over rolling windows using the Bayesian GBM inference
    vol_A_list = []
    vol_B_list = []

    for i in range(window_vol, len(price_A)):
        # The window ends at i-1, so the estimate used on day i has no look-ahead.
        _, _, sigma_A, _ = inference_gbm(price_A.iloc[i - window_vol:i], time_window=1, plot=False)
        _, _, sigma_B, _ = inference_gbm(price_B.iloc[i - window_vol:i], time_window=1, plot=False)

        vol_A_list.append(sigma_A)
        vol_B_list.append(sigma_B)

    traded_index = price_A.index[window_vol:]
    vol_A = pd.Series(vol_A_list, index=traded_index, dtype=float)
    vol_B = pd.Series(vol_B_list, index=traded_index, dtype=float)

    # Step 2: Calculate dynamic hedge ratio based on volatilities
    hedge_ratios = dynamic_hedge_ratio(vol_A, vol_B)

    # Drop the warm-up days that have no volatility estimate.
    price_A = price_A.iloc[window_vol:]
    price_B = price_B.iloc[window_vol:]

    # Step 3: Calculate spread and z-score
    spread = calculate_spread(price_A, price_B, hedge_ratios)
    spread_z = zscore(spread, window=window_z)

    # Step 4: Generate trading signals
    position = 0  # +1 = long spread, -1 = short spread, 0 = flat
    positions = []
    pnl = []
    entry_spread = 0.0
    entry_ratio = 0.0

    for z, pA, pB, h_ratio in zip(spread_z, price_A, price_B, hedge_ratios):
        realised = 0.0
        if position == 0:
            # NaN z-scores (z-score warm-up) fail both comparisons, so no trade.
            if z > z_entry:
                position = -1  # Short spread: sell A, buy B * hedge_ratio
            elif z < -z_entry:
                position = 1   # Long spread: buy A, sell B * hedge_ratio
            if position != 0:
                # Lock in the hedge ratio: the position holds this many
                # units of B until it is closed.
                entry_ratio = h_ratio
                entry_spread = pA - entry_ratio * pB
        else:
            # Value the held position with the ratio fixed at entry.
            current_spread = pA - entry_ratio * pB
            if position == 1 and z >= -z_exit:
                realised = current_spread - entry_spread
                position = 0
            elif position == -1 and z <= z_exit:
                realised = entry_spread - current_spread
                position = 0

        # One PnL entry per day (zero unless a trade closes) keeps the PnL
        # series aligned with the dates it belongs to.
        pnl.append(realised)
        positions.append(position)

    pnl = pd.Series(pnl, index=spread_z.index, dtype=float)
    positions = pd.Series(positions, index=spread_z.index)

    # Calculate cumulative PnL
    cumulative_pnl = pnl.cumsum()

    # Plot results
    plt.figure(figsize=(12,6))
    plt.subplot(2,1,1)
    plt.plot(spread_z.index, spread_z, label="Spread z-score")
    plt.axhline(z_entry, color='red', linestyle='--')
    plt.axhline(-z_entry, color='green', linestyle='--')
    plt.axhline(z_exit, color='orange', linestyle='--')
    plt.axhline(-z_exit, color='orange', linestyle='--')
    plt.legend()
    plt.title("Spread Z-Score and Thresholds")

    plt.subplot(2,1,2)
    plt.plot(cumulative_pnl.index, cumulative_pnl, label="Cumulative PnL")
    plt.legend()
    plt.title("Strategy Cumulative PnL")

    plt.tight_layout()
    plt.show()

    return positions, pnl, cumulative_pnl

if __name__ == "__main__":
    # Pair under test: MSFT is leg A, AAPL is leg B.
    ticker_A, ticker_B = "MSFT", "AAPL"

    # Fetch leg A first, then leg B, using the default date range.
    price_A, price_B = map(load_price_data, (ticker_A, ticker_B))

    # Run the backtest with its default windows and thresholds.
    positions, pnl, cumulative_pnl = backtest_pairs_strategy(price_A, price_B)


[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma]
Sampling 4 chains for 2_000 tune and 2_000 draw iterations (8_000 + 8_000 draws total) took 3 seconds.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma]
Sampling 4 chains for 2_000 tune and 2_000 draw iterations (8_000 + 8_000 draws total) took 3 seconds.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma]
Sampling 4 chains for 2_000 tune and 2_000 draw iterations (8_000 + 8_000 draws total) took 2 seconds.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma]
Sampling 4 chains for 2_000 tune and 2_000 draw iterations (8_000 + 8_000 draws total) took 2 seconds.
Initializing NUTS using ji