<a href="https://colab.research.google.com/github/NeilMitra/2WD-ObstacleAvoidingRobot/blob/master/Screener.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# =============================================================================
# 📦 CELL 1 – Install / upgrade required packages
# =============================================================================
!pip install --quiet yfinance statsmodels scikit-learn seaborn scipy


# =============================================================================
# 📚 CELL 2 – Imports, global style, utilities
# =============================================================================
import os, warnings, pickle
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import yfinance as yf
import statsmodels.api as sm
from statsmodels.tsa.stattools import coint
from statsmodels.stats.diagnostic import het_breuschpagan
from scipy.stats import pearsonr, spearmanr                       # ← NEW
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import TimeSeriesSplit
from sklearn.ensemble import RandomForestClassifier

# Suppress all Python warnings so notebook output stays readable.
warnings.filterwarnings("ignore")
# Apply matplotlib's bundled "seaborn-v0_8" style to every subsequent plot.
plt.style.use("seaborn-v0_8")

In [3]:
# =============================================================================
# ⚙️ CELL 3 – PairsTradingBacktester class
# =============================================================================
class PairsTradingBacktester:
    """
    Machine-learning-augmented statistical-arbitrage back-tester
    (cointegration + z-score + optional RandomForest filter).

    ``run_backtest`` walks forward through the downloaded price history in
    consecutive, non-overlapping trading windows of ``trading_period`` rows.
    Before each window the preceding ``formation_period`` rows are screened
    for cointegrated pairs; every selected pair's log-price spread is
    z-scored with a rolling window and traded mean-reversion style, with
    entries optionally vetoed by a RandomForest classifier trained only on
    data available at the end of the formation window.
    """

    # Rows ahead used to label the ML target.  Training rows whose label would
    # peek past the end of the formation window are excluded (no look-ahead).
    TARGET_HORIZON = 5
    # Minimum observations for rolling statistics and for the cointegration test.
    MIN_OBS = 20

    # ---------- init ----------
    def __init__(
        self,
        symbols,
        start_date,
        end_date,
        formation_period=252,
        trading_period=63,
        entry_threshold=2.0,
        exit_threshold=0.5,
        stop_loss_threshold=3.0,
        transaction_cost=0.001,
        use_ml=True,
    ):
        """
        Args:
            symbols: Tickers passed to ``yf.download``.
            start_date, end_date: Date range passed to ``yf.download``.
            formation_period: Rows used to select pairs before each trading
                window; a quarter of it is the rolling z-score window.
            trading_period: Rows per trading window (also the walk-forward step).
            entry_threshold: |z| beyond which a flat pair opens a position.
            exit_threshold: |z| within which an open position is closed.
            stop_loss_threshold: |z| beyond which an open position is cut.
            transaction_cost: Cost per leg; charged for both legs on every
                entry and every exit, in return units.
            use_ml: Gate entries with a RandomForest classifier.
        """
        self.symbols = symbols
        self.start_date = start_date
        self.end_date = end_date
        self.formation_period = formation_period
        self.trading_period = trading_period
        self.entry_threshold = entry_threshold
        self.exit_threshold = exit_threshold
        self.stop_loss_threshold = stop_loss_threshold
        self.transaction_cost = transaction_cost
        self.use_ml = use_ml
        self.data = None
        self.returns = None
        self.pairs = []
        self.portfolio_log = []           # list of dicts while running; DataFrame afterwards

    # ---------- data ----------
    def download_data(self):
        """Download prices for ``self.symbols`` into ``self.data``.

        Uses the "Adj Close" field when present, otherwise "Close".  Gaps are
        forward-filled and tickers with no data at all are dropped;
        ``self.symbols`` is narrowed to the surviving tickers.

        Returns:
            True on success, False if the download failed or produced no
            usable prices.
        """
        print(f"Downloading {len(self.symbols)} tickers …")
        try:
            data = yf.download(self.symbols, start=self.start_date, end=self.end_date, progress=False)
            if "Adj Close" in data.columns:
                self.data = data["Adj Close"]
            elif "Close" in data.columns:
                self.data = data["Close"]
            else:
                print("No price column found!")
                return False

            # Forward-fill gaps, drop columns (tickers) that are entirely NaN
            self.data = self.data.ffill().dropna(axis=1, how="all")
            if self.data.empty:
                print("No usable data downloaded.")
                return False

            # Keep only symbols that survived cleaning
            self.symbols = self.data.columns.tolist()
            print("Data OK:", self.data.shape)
            return True
        except Exception as e:  # network / API / parsing failure: report, don't crash the notebook
            print("Download error:", e)
            return False

    # ---------- pair search ----------
    def find_cointegrated_pairs(self, data_period):
        """Return the ticker pairs whose log prices are cointegrated.

        Every unordered pair of columns in *data_period* is tested with the
        Engle-Granger test (``coint``) on log prices, using only the rows
        where both tickers have prices.  Pairs with p-value < 0.05 are
        returned as ``(ticker_a, ticker_b)`` (alphabetical within the pair),
        most significant first.
        """
        from itertools import combinations

        found = []
        for a, b in combinations(sorted(data_period.columns), 2):
            # Test on the overlap only: NaN-padded series (late listings)
            # would otherwise make the regression fail or return NaN.
            both = data_period[[a, b]].dropna()
            if len(both) < self.MIN_OBS:
                continue
            try:
                _, pval, _ = coint(np.log(both[a]), np.log(both[b]))
            except (ValueError, np.linalg.LinAlgError):
                continue  # degenerate input (e.g. constant series): skip the pair
            if pval < 0.05:
                found.append((a, b, pval))
        found.sort(key=lambda rec: rec[2])
        return [(a, b) for a, b, _ in found]

    # ---------- feature engineering ----------
    def calculate_spread_and_features(self, pair_data):
        """Build the log-price spread, its rolling z-score and ML features.

        Args:
            pair_data: Two-column price frame ``[s1, s2]``.

        Returns:
            ``(spread, z, feats)``: ``spread = log(s1) - log(s2)``; ``z`` its
            rolling z-score over ``formation_period // 4`` rows; ``feats`` a
            frame of rolling features plus a binary ``target`` that is 1 when
            the spread moves back toward its mean over the next
            ``TARGET_HORIZON`` rows (only for |z| > 0.5).  The last
            ``TARGET_HORIZON`` rows have no future data and are labelled 0, so
            callers must not train on them.
        """
        s1, s2 = pair_data.columns
        spread = np.log(pair_data[s1]) - np.log(pair_data[s2])

        window = self.formation_period // 4
        # pandas requires min_periods <= window; short formation periods
        # would otherwise raise ValueError.
        min_periods = min(self.MIN_OBS, window)
        rolling = spread.rolling(window=window, min_periods=min_periods)
        mean, std = rolling.mean(), rolling.std()
        # A numerically flat spread can give std == 0 -> ±inf; treat as undefined.
        z = ((spread - mean) / std).replace([np.inf, -np.inf], np.nan)

        feats = pd.DataFrame(index=pair_data.index)
        feats["z_score"] = z
        feats["spread_volatility"] = std
        feats["pair_correlation"] = (
            pair_data[s1]
            .rolling(window, min_periods=min_periods)
            .corr(pair_data[s2])
        )
        feats["spread_lag1"] = spread.diff(1)
        feats["spread_lag5"] = spread.diff(5)

        fwd_change = spread.shift(-self.TARGET_HORIZON) - spread
        feats["target"] = np.where(
            z > 0.5, (fwd_change < 0).astype(int),
            np.where(z < -0.5, (fwd_change > 0).astype(int), 0)
        )
        feats.dropna(inplace=True)
        return spread, z, feats

    # ---------- ML ----------
    def train_ml_model(self, feats):
        """Fit a scaled RandomForest on *feats*.

        Returns:
            ``(model, scaler)``, or ``(None, None)`` when there are fewer than
            50 rows or fitting fails.
        """
        if len(feats) < 50:
            return None, None
        X, y = feats.drop("target", axis=1), feats["target"]
        scaler = StandardScaler().fit(X)
        rf = RandomForestClassifier(n_estimators=100, random_state=42, class_weight="balanced")
        try:
            rf.fit(scaler.transform(X), y)
        except ValueError as e:
            print("RF fit error:", e)
            return None, None
        return rf, scaler

    def _ml_vote(self, ml, scaler, feats_tr, date):
        """Return the classifier's entry vote for *date* (1 = allow).

        Without ML (disabled or untrained) every entry is allowed; rows with
        missing features are vetoed (0).
        """
        if not (self.use_ml and ml is not None and date in feats_tr.index):
            return 1
        feat = feats_tr.loc[[date]].drop("target", axis=1, errors="ignore")
        if feat.isnull().values.any():
            return 0
        return ml.predict(scaler.transform(feat))[0]

    # ---------- back-test ----------
    def _simulate_pair(self, pair, tr_data, z_tr, feats_tr, ml, scaler):
        """Trade one pair through a trading window and return its daily P&L.

        Position +1 is long s1 / short s2 (opened when z < -entry_threshold),
        -1 the reverse.  An open position earns half the difference of the
        two legs' simple daily returns.  Entries and exits are appended to
        ``self.portfolio_log``; exit rows carry ``PnL`` (exit-day P&L) and
        ``TradePnL`` (summed daily P&L of the whole trade, costs included).
        """
        s1, s2 = pair
        key = tuple(sorted(pair))
        dates = tr_data.index
        last_date = dates[-1]
        cost = self.transaction_cost * 2  # both legs
        daily_pnl = pd.Series(0.0, index=dates)
        pos, trade_pnl = 0, 0.0

        for i, date in enumerate(dates):
            z_now = z_tr[date]
            if pd.isna(z_now):
                continue

            # ----- entry -----
            if pos == 0:
                # An entry on the final day could never be managed or closed
                # inside this window; it would only pay its entry cost.
                if date == last_date:
                    continue
                if z_now < -self.entry_threshold:
                    side, action = 1, "Enter Long"
                elif z_now > self.entry_threshold:
                    side, action = -1, "Enter Short"
                else:
                    continue
                ml_pred = self._ml_vote(ml, scaler, feats_tr, date)
                if ml_pred != 1:
                    continue
                pos = side
                daily_pnl[date] -= cost
                trade_pnl = -cost
                self.portfolio_log.append(
                    dict(Date=date, Pair=key, Action=action, Z=z_now, ML=ml_pred)
                )
                continue

            # ----- manage / exit -----
            if i > 0:
                prev = dates[i - 1]
                ret = pos * (
                    (tr_data.at[date, s1] / tr_data.at[prev, s1] - 1)
                    - (tr_data.at[date, s2] / tr_data.at[prev, s2] - 1)
                ) / 2.0
                daily_pnl[date] += ret
                trade_pnl += ret

            if pos == 1 and z_now >= -self.exit_threshold:
                reason = "Exit Long"
            elif pos == -1 and z_now <= self.exit_threshold:
                reason = "Exit Short"
            elif pos == 1 and z_now < -self.stop_loss_threshold:
                reason = "Stop-loss Long"
            elif pos == -1 and z_now > self.stop_loss_threshold:
                reason = "Stop-loss Short"
            elif date == last_date:
                reason = "End of period"
            else:
                continue

            daily_pnl[date] -= cost
            trade_pnl -= cost
            self.portfolio_log.append(
                dict(Date=date, Pair=key, Action=reason, Z=z_now,
                     PnL=daily_pnl[date], TradePnL=trade_pnl)
            )
            pos = 0

        return daily_pnl

    def run_backtest(self):
        """Run the walk-forward back-test.

        Returns:
            Daily portfolio returns from the first tradable row onward (pairs
            active in a window are equally weighted), or None when the data
            download fails.  ``self.portfolio_log`` becomes a DataFrame of the
            trades made in this run.
        """
        if not self.download_data():
            return None
        # Fresh log per run (a previous run converted it to a DataFrame).
        self.portfolio_log = []
        idx = self.data.index
        all_rets = pd.Series(0.0, index=idx)

        # "+ 1" so the last full window, ending exactly on the final row, is traded.
        for t in range(self.formation_period,
                       len(idx) - self.trading_period + 1,
                       self.trading_period):

            f_start, f_end = idx[t - self.formation_period], idx[t - 1]
            tr_start, tr_end = idx[t], idx[t + self.trading_period - 1]
            print(f"\nWindow  {f_start.date()} → {f_end.date()}  |  trade {tr_start.date()} → {tr_end.date()}")

            f_data = self.data.loc[f_start:f_end]
            tr_data = self.data.loc[tr_start:tr_end]
            current_pairs = self.find_cointegrated_pairs(f_data)
            if not current_pairs:
                continue

            # Last row whose ML label (TARGET_HORIZON rows ahead) still lies
            # inside the formation window: training beyond it leaks the future.
            train_end = idx[max(t - 1 - self.TARGET_HORIZON, 0)]

            period_rets = pd.Series(0.0, index=tr_data.index)
            active_pairs = 0

            for pair in current_pairs:
                s1, s2 = pair
                combo = self.data.loc[f_start:tr_end, [s1, s2]]
                if combo.isnull().values.any() or len(combo) < self.formation_period:
                    continue

                _, z, feats = self.calculate_spread_and_features(combo)

                ml, scaler = None, None
                if self.use_ml:
                    ml, scaler = self.train_ml_model(feats.loc[:train_end])

                daily_pnl = self._simulate_pair(
                    pair,
                    tr_data,
                    z.reindex(tr_data.index).ffill(),
                    feats.reindex(tr_data.index).ffill(),
                    ml,
                    scaler,
                )

                if not daily_pnl.eq(0).all():
                    period_rets = period_rets.add(daily_pnl, fill_value=0)
                    active_pairs += 1

            if active_pairs:
                all_rets.loc[tr_start:tr_end] = all_rets.loc[tr_start:tr_end].add(
                    period_rets / active_pairs, fill_value=0
                )
            print("Active pairs:", active_pairs)

        # iloc: yields an empty series (not IndexError) when data is too short.
        self.returns = all_rets.iloc[self.formation_period:]
        self.portfolio_log = pd.DataFrame(self.portfolio_log)
        print("Back-test complete.")
        return self.returns

    # ---------- simple plots & metrics ----------
    def plot_performance(self, rets, title="Cumulative Returns", fname=None):
        """Plot compounded cumulative returns; save and close when *fname* is given."""
        if rets is None or rets.empty:
            return
        plt.figure(figsize=(12, 6))
        ((1 + rets).cumprod() - 1).plot()
        plt.title(title)
        plt.ylabel("Cumulative return")
        plt.grid(True)
        if fname:
            plt.savefig(fname)
            plt.close()

    def calculate_performance_metrics(self, rets):
        """Compute summary statistics for daily returns (252-day year).

        Returns:
            Dict with CumRet, AnnRet, AnnVol, Sharpe, Sortino and MaxDD, or an
            empty dict for None/empty input.  Sharpe/Sortino are 0 when their
            risk denominator is zero or undefined.
        """
        if rets is None or rets.empty:
            return {}
        periods = 252
        cum_ret = (1 + rets).prod() - 1
        ann_ret = (1 + rets.mean()) ** periods - 1
        ann_vol = rets.std() * np.sqrt(periods)
        # "> 0" is False for both 0 and NaN (std of a single observation).
        sharpe = ann_ret / ann_vol if ann_vol > 0 else 0
        wealth = (1 + rets).cumprod()
        max_dd = (wealth / wealth.cummax() - 1).min()
        # Downside deviation: RMS of below-zero returns over *all* periods
        # (std of the losses alone is 0 when all losses are equal -> ±inf).
        downside = np.sqrt((rets.clip(upper=0) ** 2).mean()) * np.sqrt(periods)
        sortino = ann_ret / downside if downside > 0 else 0
        print(f"Sharpe {sharpe:.2f} | Sortino {sortino:.2f} | Max-DD {max_dd:.2%}")
        return dict(CumRet=cum_ret, AnnRet=ann_ret, AnnVol=ann_vol,
                    Sharpe=sharpe, Sortino=sortino, MaxDD=max_dd)