In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import interpolate
from arch import arch_model
import warnings
# NOTE(review): this silences *every* warning for the whole kernel, including pandas
# deprecation notices and any fit warnings emitted during GARCH fitting. Consider narrowing
# it (e.g. by `category=`) so real problems are not hidden.
warnings.filterwarnings('ignore')

In [None]:
# 1. MISSING-DATA HANDLING
def _session_minute_grid(index, session_start='09:15', session_end='15:30'):
    """Return every weekday 1-minute timestamp inside the trading session spanned by `index`."""
    grid = pd.date_range(index.min().floor('min'),
                         index.max().ceil('min'),
                         freq='min', tz=index.tz)
    grid = grid[grid.dayofweek < 5]
    # Restrict to session minutes so that consecutive rows of the grid are consecutive
    # *trading* minutes; the overnight gap then costs no forward-fill budget.
    return grid[grid.indexer_between_time(session_start, session_end)]


def _simulate_garch_path(start_level, n_steps, init_vol, omega, alpha, beta,
                         long_term_mean, rng, mean_reversion=0.999, floor=0.01):
    """Simulate `n_steps` levels: GARCH(1,1) percentage shocks plus a mild pull toward `long_term_mean`."""
    path = np.empty(n_steps)
    level = start_level
    vol = init_vol
    shock = 0.0
    for i in range(n_steps):
        if i > 0:
            # GARCH(1,1) variance recursion driven by the previous simulated shock.
            vol = np.sqrt(omega + alpha * shock ** 2 + beta * vol ** 2)
        shock = rng.normal(0.0, vol)          # shock is a percentage return
        level = level * mean_reversion + (1 - mean_reversion) * long_term_mean
        level = max(level * (1 + shock / 100), floor)   # keep the level strictly positive
        path[i] = level
    return path


def _garch_fill(series, rng):
    """Fill every NaN run of `series` with a simulated GARCH(1,1) path; raises if the fit fails."""
    observed = series.dropna()
    returns = observed.pct_change().dropna() * 100
    fitted = arch_model(returns, vol='GARCH', p=1, q=1, rescale=False).fit(disp='off')
    omega = fitted.params['omega']
    alpha = fitted.params['alpha[1]']
    beta = fitted.params['beta[1]']

    # Only the one-step-ahead conditional volatility is needed as the starting vol of
    # each gap, so it is computed once rather than once per gap.
    one_step_var = fitted.forecast(horizon=1, reindex=False).variance.values[-1]
    init_vol = np.sqrt(one_step_var[0]) if len(one_step_var) > 0 else np.std(returns)
    # NOTE: mean over the full sample -> look-ahead bias when used in a backtest.
    long_term_mean = observed.mean()

    filled = series.copy()
    # ffill of the un-imputed series gives, at each gap's first row, the last observation before it.
    last_seen = series.ffill()
    missing = series.isna()
    run_ids = (missing != missing.shift()).cumsum()[missing]
    for gap_idx in series[missing].groupby(run_ids).groups.values():
        start_level = last_seen.loc[gap_idx[0]]
        if pd.isna(start_level):
            # Leading gap: nothing observed before it, so back-fill from the first observation.
            filled.loc[gap_idx] = observed.iloc[0]
            continue
        filled.loc[gap_idx] = _simulate_garch_path(
            start_level, len(gap_idx), init_vol, omega, alpha, beta, long_term_mean, rng)
    return filled


def handle_missing_data(df, max_ffill_days: int = 3, seed: int = 42):
    """
    Reindex to a complete 1-minute trading-session grid and fill the holes.

    IV IMPUTATION using GARCH(1,1)
    Method: Generalized Autoregressive Conditional Heteroskedasticity
    Reference: Engle (2001), "GARCH 101: The Use of ARCH/GARCH Models in Applied Econometrics"

    Steps
    -----
    1. Build a weekday, 09:15-15:30 (inclusive) 1-minute grid and reindex `df` onto it.
    2. Forward-fill every column for at most `max_ffill_days` sessions' worth of rows.
    3. Any 'nifty' / 'banknifty' values still missing are filled with a simulated
       GARCH(1,1) path that starts at the last observation before each gap. This falls
       back to linear interpolation when fewer than 100 points are available or the fit
       raises.

    Parameters
    ----------
    df : pandas.DataFrame
        Minute-level data indexed by timestamp; must contain 'nifty' and 'banknifty'.
    max_ffill_days : int, default 3
        Forward-fill budget in trading sessions (376 rows each). 0 disables forward-fill.
    seed : int, default 42
        Seed of the private random generator used for the simulated paths.

    Returns
    -------
    pandas.DataFrame
        The session-grid frame plus an int8 'is_interpolated' flag. The flag is 1 where
        any column was missing before filling.

    Notes
    -----
    - Prints progress and summary statistics.
    - Columns other than 'nifty'/'banknifty' are only forward-filled, so they can still
      hold NaN inside gaps longer than the forward-fill budget.
    - The GARCH fit and its long-term mean use the full sample (look-ahead bias).
    """
    if df.empty:
        out = df.copy()
        out['is_interpolated'] = np.zeros(0, dtype=np.int8)
        return out

    full_idx = _session_minute_grid(df.index)
    df_full = df.reindex(full_idx)

    raw_nan = df_full.isna().any(axis=1)

    rows_per_session = 376          # 09:15..15:30 inclusive, one row per minute
    ffill_limit = max_ffill_days * rows_per_session
    # pandas rejects limit=0, so "no forward fill" is handled explicitly.
    df_filled = df_full.ffill(limit=ffill_limit) if ffill_limit > 0 else df_full.copy()

    # One generator for all columns/gaps: reproducible, but each gap gets its own shocks.
    rng = np.random.default_rng(seed)

    for col in ['nifty', 'banknifty']:
        series = df_filled[col]
        if not series.isna().any():
            continue
        print(f"Applying GARCH imputation to {col.upper()}...")

        if series.count() < 100:
            df_filled[col] = series.interpolate(method='linear')
            continue

        try:
            # Every NaN left after forward-fill is imputed; none is silently skipped.
            df_filled[col] = _garch_fill(series, rng)
        except Exception as e:
            print(f"GARCH fitting failed for {col}, using linear interpolation: {e}")
            df_filled[col] = series.interpolate(method='linear')

    spread = df_filled['banknifty'] - df_filled['nifty']
    imputed_spread = spread[raw_nan]

    print(f"\nGARCH Imputation Results:")
    print(f"Imputed {raw_nan.sum():,} rows")
    if len(imputed_spread) > 0:
        print(f"Imputed spread range: {imputed_spread.min():.4f} to {imputed_spread.max():.4f}")
        print(f"Full data spread range: {spread.min():.4f} to {spread.max():.4f}")

    # The grid is already session-only, so no further between_time filter is needed.
    df_filled['is_interpolated'] = raw_nan.astype(np.int8)
    return df_filled

In [None]:
# 2. DATA LOADING
def load_and_preprocess_data(file_path='data.parquet'):
    """
    Read the raw parquet file and return a gap-free, session-only frame.

    The data is first clipped to 09:15-15:30 and then passed through
    `handle_missing_data`, which reindexes and imputes missing minutes.
    """
    raw = pd.read_parquet(file_path)
    in_session = raw.between_time('09:15:00', '15:30:00')
    return handle_missing_data(in_session)

In [None]:
# 3. SPREAD  &  Z-SCORE
def calculate_spread_and_zscore(df, lookback_window=200):
    """
    Add the BankNifty-Nifty spread and its rolling z-score to `df` (in place).

    Columns written, in order: 'spread', 'spread_mean', 'spread_std', 'z_score'.
    The rolling statistics need a full `lookback_window` of observations.
    Wherever the z-score is undefined (warm-up rows, zero std, NaN inputs), it is set to 0.0.
    Returns the same (mutated) DataFrame.
    """
    spread = df['banknifty'] - df['nifty']
    window = spread.rolling(lookback_window, min_periods=lookback_window)
    rolling_mean = window.mean()
    rolling_std = window.std()

    df['spread'] = spread
    df['spread_mean'] = rolling_mean
    df['spread_std'] = rolling_std

    # A zero std would give +/-inf; treat it as "undefined" like the warm-up period.
    safe_std = rolling_std.mask(rolling_std == 0)
    df['z_score'] = ((spread - rolling_mean) / safe_std).fillna(0.0)

    return df

In [None]:
# 4. SIGNALS
def generate_trading_signals(df, entry_threshold=2.0, exit_threshold=0.5,
                             flatten_time='15:30:00'):
    """
    Turn the z-score into a held position (+1 long spread, -1 short, 0 flat).

    Rules, evaluated per bar in this precedence order:
      |z| < exit_threshold      -> flat (0)
      z < -entry_threshold      -> long (+1)
      z >  entry_threshold      -> short (-1)
      otherwise                 -> hold the previous state
    The bar whose time equals `flatten_time` is forced flat. The held state is reset there
    as well, so a flattened position is NOT silently re-opened at the next session's open.
    A new entry signal is required first.

    Writes 'position' (int8), 'signal' (int8 change in position; the first row equals
    the first position) and 'position_change' (copy of 'signal') into `df` and returns it.
    """
    z = df['z_score']

    # NaN marks "no new instruction" and is later forward-filled (hold).
    raw_dir = np.select(
        [z.abs() < exit_threshold, z < -entry_threshold, z > entry_threshold],
        [0, 1, -1],
        default=np.nan,
    )

    # Flatten *before* forward-filling. Otherwise the pre-close state is carried
    # straight into the next morning and the position reopens without a signal.
    eod = df.index.time == pd.to_datetime(flatten_time).time()
    raw_dir[eod] = 0

    df['position'] = (pd.Series(raw_dir, index=df.index)
                        .ffill()
                        .fillna(0)
                        .astype(np.int8))

    df['signal']          = df['position'].diff().fillna(df['position']).astype(np.int8)
    df['position_change'] = df['signal']                                         # kept for metrics
    return df

In [None]:
# 5. P/L  
def calculate_pnl(df, trade_cost_bps=0.0325):
    """
    Realise P/L for each round-trip trade implied by df['position'].

    A trade opens when the position leaves 0 (or flips sign) and closes when it
    returns to 0 (or flips sign); a flip closes and reopens on the same bar.
    On the closing bar:

        pnl = (spread_exit - spread_entry) * ((tte_entry + tte_exit) / 2) ** 0.7
              * entry_direction - trade_cost_bps

    NOTE: despite its name, `trade_cost_bps` is subtracted as an absolute amount
    per round trip; it is not scaled by notional.
    NOTE: a position still open on the last row is never closed, so its P/L is
    not realised.

    Reads 'position', 'spread' and 'tte'; writes 'trade_pnl' (non-zero only on
    closing bars) and 'cumulative_pnl' into `df` and returns it.
    """
    # Positional numpy access instead of iterrows(): no per-row Series construction,
    # and no dependence on the index labels being unique.
    positions = df['position'].to_numpy()
    spreads = df['spread'].to_numpy()
    ttes = df['tte'].to_numpy()
    trade_pnl = np.zeros(len(df), dtype=float)

    open_pos = 0
    entry_spread = entry_tte = entry_dir = None

    for i in range(len(positions)):
        curr_pos = positions[i]

        # True on a sign flip, and also when going flat (sign 0 != sign of open).
        sign_changed = open_pos != 0 and np.sign(curr_pos) != np.sign(open_pos)
        enter = (open_pos == 0 and curr_pos != 0) or sign_changed
        exit_ = (open_pos != 0 and curr_pos == 0) or sign_changed

        if exit_:
            spread_change = spreads[i] - entry_spread
            tte_weight = ((entry_tte + ttes[i]) / 2.) ** 0.7
            trade_pnl[i] = spread_change * tte_weight * entry_dir - trade_cost_bps
            open_pos = 0
        if enter:
            # When going flat this records a direction of 0 and leaves open_pos at 0; harmless.
            entry_spread = spreads[i]
            entry_tte = ttes[i]
            entry_dir = curr_pos
            open_pos = curr_pos

    df['trade_pnl'] = trade_pnl
    df['cumulative_pnl'] = df['trade_pnl'].cumsum()
    return df

In [None]:
# 6.  PERFORMANCE  METRICS
def calculate_performance_metrics(df):
    """
    Summarise a backtest produced by `calculate_pnl`.

    Reads 'cumulative_pnl' and 'trade_pnl' (datetime index required for the
    daily grouping) and returns a dict with:
      'Total P/L'          final cumulative P/L (0.0 for an empty frame)
      'Number of Trades'   count of bars with non-zero realised P/L
      'Sharpe Ratio'       annualised (sqrt(252)) mean/std of daily P/L, 0 when undefined
      'Max Drawdown (abs)' largest peak-to-trough fall of cumulative P/L (<= 0, P/L units)
      'Win Rate'           share of non-zero trades with positive P/L
    """
    equity = df['cumulative_pnl']
    trade_flags = df['trade_pnl'] != 0

    total_pnl = equity.iloc[-1] if len(equity) else 0.0
    num_trades = trade_flags.sum()

    # Daily P/L over days that actually contain data. resample('D') would insert NaN
    # weekend/holiday rows whose diff() wipes out the following trading day (every Monday).
    eod_equity = equity.groupby(df.index.normalize()).last()
    daily_pnl = eod_equity.diff()
    if len(daily_pnl):
        daily_pnl.iloc[0] = eod_equity.iloc[0]   # first day is measured from a flat start
    daily_pnl = daily_pnl.dropna()

    daily_std = daily_pnl.std()
    sharpe = (daily_pnl.mean() / daily_std * np.sqrt(252)
              if len(daily_pnl) > 1 and pd.notna(daily_std) and daily_std != 0 else 0)

    running_max = equity.cummax()
    max_dd = (equity - running_max).min()

    trade_pnls = df.loc[trade_flags, 'trade_pnl']
    win_rate = (trade_pnls > 0).mean() if len(trade_pnls) else 0

    return {
        'Total P/L'          : total_pnl,
        'Number of Trades'   : num_trades,
        'Sharpe Ratio'       : sharpe,
        'Max Drawdown (abs)' : max_dd,
        'Win Rate'           : win_rate
    }

In [None]:
# 7. PLOTS

def plot_results(df, entry_threshold=2.0):
    """
    Plot key results: spread with z-score, held position, and cumulative P/L.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'spread', 'z_score', 'position' and 'cumulative_pnl'; its
        index is used as the x axis.
    entry_threshold : float, default 2.0
        Level of the dashed +/- guide lines on the z-score axis. Pass the same
        value given to `generate_trading_signals` so the guides match the real
        entry bands.
    """
    fig, (ax_spread, ax_pos, ax_pnl) = plt.subplots(3, 1, figsize=(15, 12))

    # Panel 1: spread on the left axis, z-score on a twin right axis.
    ax_spread.plot(df.index, df['spread'], label='Spread', alpha=0.7)
    ax_spread.set_ylabel('Spread')
    ax_spread.set_title('Bank Nifty - Nifty IV Spread')

    ax_z = ax_spread.twinx()
    ax_z.plot(df.index, df['z_score'], color='red', label='Z-Score', alpha=0.7)
    for level in (entry_threshold, -entry_threshold):
        ax_z.axhline(y=level, color='r', linestyle='--', alpha=0.5)
    ax_z.set_ylabel('Z-Score')

    # A single combined legend: separate 'best'-placed legends on twin axes can overlap.
    spread_handles, spread_labels = ax_spread.get_legend_handles_labels()
    z_handles, z_labels = ax_z.get_legend_handles_labels()
    ax_z.legend(spread_handles + z_handles, spread_labels + z_labels, loc='upper left')

    # Panel 2: held position.
    ax_pos.plot(df.index, df['position'], label='Position', color='orange')
    ax_pos.set_ylabel('Position')
    ax_pos.set_title('Trading Positions')
    ax_pos.legend()

    # Panel 3: cumulative realised P/L.
    ax_pnl.plot(df.index, df['cumulative_pnl'], label='Cumulative P/L', color='green')
    ax_pnl.set_xlabel('Date')
    ax_pnl.set_ylabel('Cumulative P/L')
    ax_pnl.set_title('Strategy Performance')
    ax_pnl.legend()

    plt.tight_layout()
    plt.show()

In [None]:
# 8. MAIN
def main():
    """
    Run the full pipeline: load, z-score, signals, P/L, metrics, plot.

    Prints each metric (floats with 4 decimals) and returns a tuple of
    (results DataFrame, metrics dict).
    """
    data = load_and_preprocess_data('data.parquet')
    data = calculate_spread_and_zscore(data, lookback_window=100)
    data = generate_trading_signals(data, entry_threshold=2.0, exit_threshold=0.5)
    data = calculate_pnl(data)
    metrics = calculate_performance_metrics(data)

    for name, value in metrics.items():
        if isinstance(value, float):
            print(f'{name:22}: {value:,.4f}')
        else:
            print(f'{name:22}: {value}')

    plot_results(data)
    return data, metrics

if __name__ == '__main__':
    df_results, performance = main()