In [1]:
import yfinance as yf

import os
import pickle
import joblib
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
from datetime import timedelta, datetime, date
from itertools import product
from tabulate import tabulate

import pandas as pd
pd.set_option('display.max_columns', None)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, train_test_split, TimeSeriesSplit
from sklearn.base import BaseEstimator, TransformerMixin, clone, ClassifierMixin
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder, FunctionTransformer
from sklearn.compose import ColumnTransformer, make_column_transformer
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay, 
    roc_curve, RocCurveDisplay, auc, roc_auc_score, f1_score,
    precision_recall_curve, PrecisionRecallDisplay, 
    precision_score, recall_score, average_precision_score
)
from sklearn.utils.class_weight import compute_class_weight
from xgboost import XGBClassifier, plot_importance


import ta
from ta.trend import MACD, ADXIndicator
from scipy.signal import argrelextrema
from arch import arch_model

In [20]:
# ===============================
# ===== serializing objects =====
# ===============================
def save_pkl(path, data):
    """Pickle ``data`` to ``path``, creating the parent directory if needed.

    Parameters
    ----------
    path : str
        Destination file. May be a bare filename with no directory part.
    data : object
        Any picklable object.
    """
    # os.path.dirname() returns '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when there is one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(path, 'wb') as f:
        pickle.dump(data, f)
    print(f"Saved data to {path}.")

def read_pkl(path):
    """Load and return the object pickled at ``path``.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.

    Note: unpickling executes code embedded in the file, so only use this on
    files this project wrote itself.
    """
    if os.path.exists(path):
        with open(path, 'rb') as fh:
            obj = pickle.load(fh)
        print(f"Loaded data from {path}.")
        return obj
    raise FileNotFoundError(f"{path} does not exist.")


# =======================
# ===== data loader =====
# =======================

def get_data(
        ticker, 
        start_date='1999-01-01', 
        end_date=None, 
        save_csv=False
):
    """Download auto-adjusted price history for one ticker via yfinance.

    Parameters
    ----------
    ticker : str
        Symbol; sent to yfinance upper-cased, stored lower-cased.
    start_date : str
        First date requested.
    end_date : date-like or None
        Last date to include (inclusive). None leaves the end open.
    save_csv : bool
        When True, also write the frame to ``data/<ticker>.csv``.

    Returns
    -------
    pd.DataFrame
        Lower-cased columns plus a ``ticker`` column, indexed by datetime.
        An empty DataFrame when the download returns no rows.
    """
    os.makedirs("data", exist_ok=True)
    symbol_lower = ticker.lower()
    csv_path = f"data/{symbol_lower}.csv"

    # yf.download treats `end` as exclusive, so request one extra day
    if end_date is not None:
        inclusive_end = pd.to_datetime(end_date).date() + timedelta(days=1)
        end_date = inclusive_end.isoformat()

    prices = yf.download(
        ticker.upper(),
        start=start_date,
        end=end_date,
        auto_adjust=True,
        progress=False
    )

    # Nothing came back: report it and hand the caller an empty frame
    if prices.empty:
        print(f"No data found for {ticker}")
        return pd.DataFrame()

    # Keep only the first header level when the columns are a MultiIndex
    if isinstance(prices.columns, pd.MultiIndex):
        prices.columns = prices.columns.get_level_values(0)

    prices['Ticker'] = symbol_lower
    prices.index = pd.to_datetime(prices.index)
    prices.columns = [name.lower() for name in prices.columns]

    if save_csv:
        prices.to_csv(csv_path)
        print(f"Saved data for {ticker} to {csv_path}")

    return prices

# get list of tickers in s&p500
def get_sp500(): # accurate as of 2025-07-31; may need update in the future
    """Return the S&P 500 constituent symbols as a list of lower-case strings.

    Symbols are read from the ``Symbol`` column of the first table on the
    Wikipedia constituents page. Dots are replaced with dashes (presumably to
    match Yahoo Finance symbols), and GOOG, FOXA and NWS are left out.
    """
    # should exclude goog, foxa, nws (as of 2025-07-31)
    skipped = {'GOOG', 'FOXA', 'NWS'}
    wiki_url = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
    constituents = pd.read_html(wiki_url, header=0)[0]

    tickers = []
    for raw_symbol in constituents['Symbol'].tolist():
        symbol = raw_symbol.replace('.', '-')
        if symbol in skipped:
            continue
        # return lower case for consistency
        tickers.append(symbol.lower())
    return tickers # a list of string


# get metadata
def get_metadata(ticker, sp500_tickers=None, sector_only=False):
    """Fetch descriptive metadata for one ticker from yfinance.

    Parameters
    ----------
    ticker : str
        Symbol; upper-cased for the yfinance lookup, returned as passed in.
    sp500_tickers : iterable of str or None
        Symbols used for the ``is_sp500`` flag. Looked up with ``get_sp500()``
        when None (only needed when ``sector_only`` is False).
    sector_only : bool
        When True, return only ``ticker`` and ``sector``.

    Returns
    -------
    dict or None
        The metadata, or None when anything in the lookup raised (the error is
        printed). Callers must handle the None case.
    """
    try:
        info = yf.Ticker(ticker.upper()).info
        if sector_only:
            # reuse the info already fetched instead of building a second
            # Ticker object and reading .info again
            return {
                'ticker': ticker,
                'sector': info.get('sector'),
            }

        if sp500_tickers is None:
            sp500_tickers = get_sp500()
        # compare case-insensitively: get_sp500() returns lower-case symbols,
        # but callers may pass the ticker in either case
        members = {str(t).lower() for t in sp500_tickers}
        return {
            'ticker': ticker,
            'sector': info.get('sector'),
            # 'industry' is deliberately left out for now: it is much more
            # granular than sector and would be costly to one-hot encode
            'market_cap': info.get('marketCap'),
            'avg_volume': info.get('averageVolume'),
            'beta': info.get('beta'),
            'is_sp500': ticker.lower() in members
        }
    except Exception as e:
        print(f"Failed for {ticker}: {e}")
        return None
    

# get a df that contains tickers from s&p500 and their sectors
def get_sp500_tickers_sectors_mindates(
        path='files/sp500_tickers_sectors_mindates.pkl', 
        sp500_tickers=None, 
        redo=False
):
    """Build (or load from cache) a table of ticker, sector and first date.

    Parameters
    ----------
    path : str
        Pickle cache location. Loaded when it exists and ``redo`` is False.
    sp500_tickers : list of str or None
        Symbols to process; fetched with ``get_sp500()`` when None.
    redo : bool
        Rebuild and overwrite the cache even if it exists.

    Returns
    -------
    pd.DataFrame
        Columns ``ticker``, ``sector`` and ``mindate`` (a ``datetime.date``).
        Tickers with no price history are skipped; tickers whose metadata
        lookup failed are kept with ``sector`` set to None.
    """
    if os.path.exists(path) and not redo:
        return read_pkl(path) # a pd.DataFrame

    if sp500_tickers is None:
        sp500_tickers = get_sp500()

    rows = []
    for symbol in tqdm(sp500_tickers, desc="Processing tickers"):
        row = get_metadata(symbol, sp500_tickers, sector_only=True)
        if row is None:
            # get_metadata returns None (after printing the error) on failure;
            # keep the ticker rather than crashing on item assignment
            row = {'ticker': symbol, 'sector': None}

        history = yf.Ticker(symbol.upper()).history(period='max')
        if history.empty:
            # index.min() would be NaT, which has no usable .date(); a missing
            # mindate would also break later date comparisons on the column
            print(f"No price history for {symbol}; skipping")
            continue

        row['mindate'] = history.index.min().date()
        rows.append(row)

    df = pd.DataFrame(rows)
    save_pkl(path, df)
    return df
    

# =======================
# ===== feature eng =====
# =======================

def feature_engineering(df): # for trading at open tomorrow
    """Add technical-indicator and calendar features to one instrument's OHLC frame.

    Reads the ``open``, ``high``, ``low`` and ``close`` columns and the index
    (``.year``, ``.month``, ``.isocalendar()`` and ``.dayofweek`` are used, so
    the index must be datetime-like). Works on a copy and returns it with the
    new columns appended in the order they are computed below. Rows are not
    dropped, so indicator warm-up periods stay as NaN.

    Columns marked "wont be used" are intermediates for other features; they
    are still present in the returned frame.
    """
    # bollinger bands and rsi
    df = df.copy()
    df['sma30'] = df['close'].rolling(30).mean() # wont be used 
    df['sma10'] = df['close'].rolling(10).mean() # wont be used
    df['sma_diff'] = df['sma10'] - df['sma30']
    df['sma_slope'] = df['sma10'].diff()
    df['std30'] = df['close'].rolling(30).std() # wont be used 
    df['bollinger_upper'] = df['sma30'] + 2 * df['std30'] # wont be used 
    df['bollinger_lower'] = df['sma30'] - 2 * df['std30'] # wont be used 
    # position of close inside the band: 0 at the lower band, 1 at the upper
    df['percent_b'] = (df['close'] - df['bollinger_lower']) / (df['bollinger_upper'] - df['bollinger_lower'])
    # distance of close from sma30 in units of the 30-day std
    df['bollinger_z'] = (df['close'] - df['sma30']) / df['std30']
    # 1 when close is at or below 101% of the lower band
    df['price_near_lower_bb'] = (df['close'] <= df['bollinger_lower'] * 1.01).astype(int)
    df['rsi14'] = ta.momentum.RSIIndicator(df['close'], window=14).rsi()    
    # NOTE(review): the column name says bollinger_z but the product uses
    # percent_b -- confirm which one was intended
    df['prod_bollingerz_rsi'] = df['percent_b'] * df['rsi14']

    # Detect local lows
    df['rsi_smooth'] = df['rsi14'].rolling(3).mean() # wont be used 
    rsi_vals = df['rsi_smooth'].values
    # NOTE(review): argrelextrema with order=5 compares each point with 5
    # neighbours on each side, so this flag appears to depend on up to 5
    # future rows -- confirm that is acceptable for a feature used to trade
    # at the next open
    local_lows = argrelextrema(rsi_vals, np.less, order=5)[0]
    df['rsi_local_low'] = 0
    df.iloc[local_lows, df.columns.get_loc('rsi_local_low')] = 1

    # some other useful features  
    # open-to-open percentage change
    df['daily_return'] = df['open'].pct_change()
    df['rolling_volatility14'] = df['daily_return'].rolling(window=14).std()
    df['atr'] = ta.volatility.AverageTrueRange(high=df['high'], low=df['low'], close=df['close']).average_true_range()

    # GARCH(1,1) on returns
    # NOTE(review): the model is fitted once on the whole series passed in, so
    # early rows get volatilities from parameters estimated with later data --
    # confirm this is intended
    returns = df['open'].pct_change().dropna() * 100 # in percent
    am = arch_model(returns, vol='Garch', p=1, q=1, mean='constant')
    res = am.fit(disp='off')
    # assigned by index alignment; the first row (dropped above) stays NaN
    df['garch_vol'] = res.conditional_volatility # in percent

    # time related features
    # df['year'] = df.index.year
    # years elapsed since the first year present in this frame
    year_min = df.index.year.min()
    df['year_since'] = df.index.year-year_min
    
    df['month'] = df.index.month
    df['week'] = df.index.isocalendar().week
    df['dayofweek'] = df.index.dayofweek

    # trend following contextual features
    # sma_slope (already added)
    macd = MACD(close=df['close'], window_slow=26, window_fast=12, window_sign=9)
    # df['macd'] = macd.macd()                   # EMA12 - EMA26
    # df['macd_signal'] = macd.macd_signal()     # 9-day EMA of MACD
    df['macd_diff'] = macd.macd_diff()         # Histogram: MACD - Signal

    adx = ADXIndicator(high=df['high'], low=df['low'], close=df['close'], window=14)

    df['adx'] = adx.adx()              # Trend strength
    df['adx_pos'] = adx.adx_pos()      # +DI; wont be used
    df['adx_neg'] = adx.adx_neg()      # -DI; wont be used

    # df['macd_uptrend'] = (df['macd_diff'] > 0).astype(int)
    # ADX above 25 is treated as a strong trend; +DI vs -DI gives the direction
    df['strong_trend'] = (df['adx'] > 25).astype(int)
    df['up_trend_context'] = ((df['adx'] > 25) & (df['adx_pos'] > df['adx_neg'])).astype(int)
    df['down_trend_context'] = ((df['adx'] > 25) & (df['adx_neg'] > df['adx_pos'])).astype(int)

    # df.dropna(inplace=True)
    return df


# useful for adding information from market trend e.g. spy, qqq
def market_trend(etf:str, end_date = None, load_csv = True, save_csv = False):
    """Return moving-average trend flags for a market ETF (e.g. spy, qqq).

    Parameters
    ----------
    etf : str
        ETF symbol; also names the cache file ``data/<etf>.csv``.
    end_date : date-like or None
        Passed to ``get_data`` when downloading. Ignored when the result comes
        from the CSV cache.
    load_csv : bool
        Use ``data/<etf>.csv`` when it exists instead of downloading.
    save_csv : bool
        Write the returned trend flags to ``data/<etf>.csv``.

    Returns
    -------
    pd.DataFrame
        Columns ``trend_10_50``, ``trend_20_50`` and ``trend_50_200``; each is
        1 when the shorter SMA of close is above the longer one, else 0.
    """
    trend_cols = ['trend_10_50', 'trend_20_50', 'trend_50_200']
    csv_path = f"data/{etf}.csv"

    if load_csv and os.path.exists(csv_path):
        df_etf = pd.read_csv(csv_path, index_col=0, parse_dates=True)
        # save_csv writes only the trend flags to this same path, so a cache
        # produced by this function has no 'close' column to recompute from.
        # Return the stored flags instead of failing with KeyError: 'close'.
        if 'close' not in df_etf.columns and set(trend_cols).issubset(df_etf.columns):
            return df_etf[trend_cols]
    else:
        df_etf = get_data(etf, end_date=end_date)

    close = df_etf['close']
    df_etf['sma10'] = close.rolling(10).mean()
    df_etf['sma20'] = close.rolling(20).mean()
    df_etf['sma50'] = close.rolling(50).mean()
    df_etf['sma200'] = close.rolling(200).mean()
    # comparisons against a NaN SMA (warm-up rows) are False, i.e. flag 0
    df_etf['trend_10_50'] = (df_etf['sma10'] > df_etf['sma50']).astype(int)
    df_etf['trend_20_50'] = (df_etf['sma20'] > df_etf['sma50']).astype(int)
    df_etf['trend_50_200'] = (df_etf['sma50'] > df_etf['sma200']).astype(int)

    df = df_etf[trend_cols]
    if save_csv:
        df.to_csv(csv_path)
    return df


def feature_engineering_market(df, etf:str, etfs=None):
    """Left-join an ETF's trend flags onto ``df`` by index.

    Parameters
    ----------
    df : pd.DataFrame
        Instrument frame; its index is the join key.
    etf : str
        ETF symbol; used as the prefix of the added columns.
    etfs : dict or None
        Optional mapping of symbol to a precomputed trend frame. When the
        symbol is missing (or ``etfs`` is None) the trends are obtained with
        ``market_trend(etf)``.

    Returns
    -------
    pd.DataFrame
        Copy of ``df`` with ``<etf>_trend_10_50``, ``<etf>_trend_20_50`` and
        ``<etf>_trend_50_200`` added. Rows with no matching ETF date get NaN.
    """
    df = df.copy()

    # `etf not in None` raises TypeError, so treat None as "nothing cached"
    if etfs is None or etf not in etfs:
        df_etf = market_trend(etf)
    else:
        # no need to carry ETF rows later than the instrument's last row
        df_etf = etfs[etf][:df.index.max()]

    trend_cols = ['trend_10_50', 'trend_20_50', 'trend_50_200']
    prefixed = df_etf[trend_cols].rename(columns=lambda col: f'{etf}_{col}')
    return df.merge(prefixed, left_index=True, right_index=True, how='left')
        

def feature_engineering_metadata(df, sp500_tickers):
    """Broadcast the ticker's metadata onto every row of ``df``.

    Parameters
    ----------
    df : pd.DataFrame
        Frame for exactly one instrument; must have a ``ticker`` column.
    sp500_tickers : list of str or None
        Forwarded to ``get_metadata`` for the ``is_sp500`` flag.

    Returns
    -------
    pd.DataFrame
        Copy of ``df`` with one constant column per metadata key.

    Raises
    ------
    KeyError
        If ``df`` has no ``ticker`` column.
    ValueError
        If ``df`` does not contain exactly one distinct ticker, or if the
        metadata lookup failed.
    """
    if 'ticker' not in df.columns:
        raise KeyError("df does not contain ticker info")
    n_tickers = df['ticker'].nunique()
    if n_tickers != 1:
        raise ValueError(f"Expected exactly 1 unique ticker, but got {n_tickers}")

    ticker = df['ticker'].iloc[0]
    md = get_metadata(ticker, sp500_tickers=sp500_tickers)
    if md is None:
        # get_metadata returns None after printing the underlying error;
        # fail here with a clear message instead of on md.items()
        raise ValueError(f"Could not retrieve metadata for {ticker}")

    df = df.copy()
    for k, v in md.items():
        df[k] = v
    return df

# ===========================
# ===== target creation =====
# ===========================

# trading at open tomorrow (t+1)
# based on today's data (ohlcv and more), decide whether to long tomorrow
def create_target_long(df_orig, timing = 'open', lookahead=5, strategy='static', **kwargs) -> pd.DataFrame:
    """
    Label each row with whether a long entered at the next open reaches a
    profit threshold within the following ``lookahead`` rows.

    entry: always at the next row's open
    exit: ``timing`` is the exit price column, either 'open' or 'high'
    use 'high' when your platform auto-exits once price exceeds the profit threshold
    use 'open' for simplification (only trade at open)

    strategy: either 'static' or 'dynamic'
    static: one fixed return threshold for every row
        kwargs: threshold (default 0.01)
    dynamic: threshold scaled by volatility, on the same scale entry_exit uses
        kwargs: vol (required; 'atr' or 'garch_vol'), upper (default 0.95),
        take_profit (default 1)
        'atr' is converted to a fraction of price (atr / open); 'garch_vol' is
        in percent and is divided by 100. The relative volatility is capped at
        its ``upper`` quantile so extreme volatility cannot produce an
        extremely high threshold.
        NOTE(review): the cap is a quantile over the whole frame, so it uses
        rows later than the one being labelled -- confirm this is acceptable.
    (no stop loss here; everything that does not reach the threshold is 0)

    Returns a copy of ``df_orig`` with one extra nullable-integer column
    ``target_long``: 1 if the max of ``timing`` over rows i+1 .. i+lookahead,
    relative to the open of row i+1, is >= threshold, else 0. The last
    ``lookahead`` rows have no full window and are <NA>.

    Raises ValueError for an invalid timing/strategy/lookahead, an unexpected
    keyword for the chosen strategy, or an invalid 'vol'; KeyError when the
    volatility column is missing.
    """
    if timing not in ['open','high']: 
        raise ValueError(f"Invalid timing value: {timing}. Expected 'open' or 'high'.") 
    if strategy == 'static':
        allowed = {'threshold'}
    elif strategy == 'dynamic':
        allowed = {'vol', 'upper', 'take_profit'}
    else: 
        raise ValueError(f"Invalid strategy value: {strategy}. Expected 'static' or 'dynamic'.") 
    if lookahead < 1:
        raise ValueError(f"Invalid lookahead value: {lookahead}. Expected an integer >= 1.")

    for k in kwargs:
        if k not in allowed:
            raise ValueError(f"Unexpected keyword argument: '{k}' for strategy='{strategy}'")

    df = df_orig.copy()
    df['price_tmrw'] = df['open'].shift(-1) # trade at open next day

    # The rolling max ending at row i+lookahead covers rows i+1 .. i+lookahead;
    # shifting it back by `lookahead` lines it up with row i. min_periods=1
    # makes NaNs inside a window get skipped rather than poisoning the max.
    df['future_max'] = (
        df[timing]
        .rolling(lookahead, min_periods=1)
        .max()
        .shift(-lookahead)
    )
    df['future_return'] = (df['future_max'] - df['price_tmrw']) / df['price_tmrw']

    if strategy == 'static':
        df['threshold'] = kwargs.get('threshold', 0.01)
    else:
        vol = kwargs.get('vol')
        if vol not in ['atr', 'garch_vol']:
            raise ValueError("Invalid or missing 'vol'. Must be one of: 'atr', 'garch_vol'")
        if vol not in df.columns:
            raise KeyError(f"Volatility column '{vol}' not found in DataFrame")
        upper = kwargs.get('upper', 0.95)
        take_profit = kwargs.get('take_profit', 1)

        # Express volatility as a fraction of price, as entry_exit does:
        # ATR is in price units (so divide by price), GARCH vol is in percent.
        if vol == 'atr':
            relative_vol = df['atr'] / df['open']
        else:
            relative_vol = df['garch_vol'] / 100
        vol_cap = relative_vol.quantile(upper)
        df['effective_vol'] = relative_vol.clip(upper=vol_cap)
        df['threshold'] = df['effective_vol'] * take_profit

    # rows without a full look-ahead window become <NA> instead of 0
    df['target_long'] = (df['future_return'] >= df['threshold']) \
        .where(df['future_return'].notna()) \
        .astype('Int64')  
    
    return df_orig.merge(df[['target_long']], how='left', left_index=True, right_index=True)


# ===================================================
# ===== wrangling (for all features and target) =====
# ===================================================
def wrangling(
        df_ohlcv, 
        etfs=None, 
        sp500_tickers=None,
        **kwargs
):
    """Run the full preparation pipeline on one instrument's OHLCV frame.

    Steps, in order: ``feature_engineering``, ``feature_engineering_market``
    for 'spy' then 'qqq', ``feature_engineering_metadata``, and finally
    ``create_target_long`` (which receives ``**kwargs``). No rows are dropped
    along the way. Returns the resulting DataFrame.
    """
    features = feature_engineering(df_ohlcv)
    for market_etf in ('spy', 'qqq'):
        features = feature_engineering_market(features, market_etf, etfs)
    features = feature_engineering_metadata(features, sp500_tickers=sp500_tickers)
    return create_target_long(features, **kwargs)
    

# ============================================================
# ===== generate pkl file for data from multiple tickers =====
# ============================================================

def get_data_global_model(
        list_of_tickers, 
        path='files/data_global.pkl', 
        end_date=None, 
        redo=False, 
        **kwargs
):
    """Download and prepare data for many tickers, cached in one pickle.

    Parameters
    ----------
    list_of_tickers : iterable of str
        Tickers to process. Tickers already in the loaded cache are skipped.
    path : str
        Pickle file holding the result dict; loaded first unless ``redo``.
    end_date : date-like or None
        Forwarded to ``get_data``.
    redo : bool
        Ignore any existing pickle and rebuild from scratch.
    **kwargs
        Forwarded to ``wrangling``.

    Returns
    -------
    dict
        ``{ticker: {'ohlcv': ..., 'prepped_data': ..., 'as_of': date}}``.
        Tickers that fail or have no data are reported and left out. The dict
        is written to ``path`` once, after the loop finishes.
    """
    cache_exists = os.path.exists(path)
    if cache_exists and not redo:
        d = read_pkl(path) # a dict
        print(f'loaded existing {path} with {len(d)} tickers')
    else: 
        d = {}
        if cache_exists:
            # the file is there but the caller asked for a rebuild
            print(f'redo=True; ignoring existing {path} and starting fresh')
        else:
            print('no existing pkl found; starting fresh')

    for i in tqdm(list_of_tickers, desc="Processing tickers"):
        if i in d:
            continue
        try:
            ohlcv = get_data(i, end_date=end_date, save_csv=False)
            if ohlcv.empty:
                # get_data already reported the empty download; skip here so
                # the failure does not surface later as a bare KeyError
                print(f"Skipping {i}: no price data", flush=True)
                continue
            df = wrangling(ohlcv, **kwargs)
            d[i] = {
                'ohlcv': ohlcv,
                'prepped_data': df,
                'as_of': ohlcv.index.max().date()
            }
        except Exception as e:
            # best effort: one bad ticker must not abort the whole batch
            print(f"Failed for {i}: {e}", flush=True)

    save_pkl(path, d)
    return d


# =======================
# ===== XGB Wrapper =====
# =======================

class XGB(BaseEstimator, ClassifierMixin):
    """Scikit-learn style wrapper around ``XGBClassifier``.

    At fit time ``scale_pos_weight`` is set to (#negatives / #positives) of
    the training labels, so class imbalance is handled per fit (e.g. per CV
    fold). Any keyword given to the constructor is forwarded to
    ``XGBClassifier`` and overrides the wrapper's defaults.
    """

    def __init__(self, **kwargs):
        # xgb_params: keyword arguments forwarded to XGBClassifier
        self.xgb_params = kwargs
        # model: the fitted XGBClassifier, None until fit() is called
        self.model = None

    def fit(self, X, y):
        """Fit a fresh XGBClassifier on (X, y) and return self."""
        y = np.asarray(y)
        pos = np.sum(y == 1)
        neg = np.sum(y == 0)
        # with no positives the ratio is undefined; fall back to no reweighting
        scale_pos_weight = neg / pos if pos != 0 else 1.0

        # Defaults first, then user parameters on top: passing e.g.
        # eval_metric or scale_pos_weight to the constructor now overrides the
        # default instead of raising "got multiple values for keyword argument".
        params = {
            'scale_pos_weight': scale_pos_weight,
            'use_label_encoder': False,
            'eval_metric': 'logloss',
            'objective': 'binary:logistic',
        }
        params.update(self.xgb_params)

        self.model = XGBClassifier(**params)
        self.model.fit(X, y)
        # scikit-learn utilities look up classes_ on a fitted classifier
        self.classes_ = getattr(self.model, 'classes_', np.unique(y))
        return self

    def _fitted_model(self):
        """Return the underlying model, or raise if fit() was never called."""
        if self.model is None:
            # NotFittedError subclasses both ValueError and AttributeError, so
            # callers that caught the old AttributeError keep working
            from sklearn.exceptions import NotFittedError
            raise NotFittedError("This XGB instance is not fitted yet. Call 'fit' first.")
        return self.model

    def predict(self, X):
        """Return predicted class labels for X."""
        return self._fitted_model().predict(X)

    def predict_proba(self, X):
        """Return class probabilities for X, one column per class."""
        return self._fitted_model().predict_proba(X)

    def get_params(self, deep=True):
        """Return the XGBClassifier keyword arguments (``deep`` is ignored)."""
        # hand out a shallow copy so callers cannot mutate the estimator's state
        return dict(self.xgb_params)

    def set_params(self, **params):
        """Update the XGBClassifier keyword arguments and return self."""
        self.xgb_params.update(params)
        return self


# =======================
# ===== backtesting =====
# =======================

def pred_proba_to_signal(y_proba, threshold=0.5):
    """Turn probabilities into 0/1 signals: 1 where ``y_proba >= threshold``."""
    fires = y_proba >= threshold
    return fires.astype(int)

def entry_exit(df, use_vol=None, take_profit=1, stop_loss=1, min_return = 0.01): 
    '''
    FOR EACH INSTRUMENT (don't use the aggregated one that contains multiple tickers)

    Simulates trades from model signals. A signal of 1 on row i enters at the
    open of row i+1; the trade is then checked at each of the next 5 opens and
    exits at the first one that meets, in this order of priority:
      1. profit target: return >= max(take_profit * multiplier, min_return)
      2. stop loss:     return <= -stop_loss * multiplier
      3. revert to SMA30: that open is >= the 'sma30' value of the same row
    If none triggers, the trade exits at the 5th open ('max_holding_expired').
    The multiplier is taken from the signal row i.

    df: frame with 'open', 'sma30' and 'model_signal' columns, plus 'atr' or
        'garch_vol' when use_vol asks for it. Dates come from a 'date' column
        or, if the index is datetime-like, from the index (any index name).
    use_vol: None (multiplier fixed at 0.01), 'atr' (atr / open) or
        'garch_vol' (garch_vol capped at its 0.9 quantile, divided by 100)
    take_profit, stop_loss: multiples of the multiplier
    min_return: floor for the profit target

    returns a df with entry_date, exit_date, entry_price, exit_price, return,
    holding_days and exit_reason (empty, with those columns, when no trades)

    NOTE(review): the SMA30 exit compares an open with the sma30 of the same
    row; if sma30 includes that row's close it is not known at the open --
    confirm this is acceptable.
    '''
    if use_vol not in [None, 'atr', 'garch_vol']:
        raise ValueError("use_vol must be one of: None, 'atr' or 'garch_vol'")

    # if date as index: move it into a 'date' column. reset_index() puts the
    # old index first under whatever name it had ('Date', 'index', ...), so
    # rename by position instead of assuming it was called 'Date'.
    if pd.api.types.is_datetime64_any_dtype(df.index): 
        df = df.reset_index()
        if 'date' not in df.columns:
            df = df.rename(columns={df.columns[0]: 'date'})

    if use_vol == 'atr':
        multiplier = df['atr']/ df['open']
    elif use_vol == 'garch_vol':
        vol_cap = df['garch_vol'].quantile(0.9)
        multiplier = (df['garch_vol'].clip(upper=vol_cap)) / 100
    else:
        multiplier = pd.Series(0.01, index=df.index)
    effective_profit_threshold = (take_profit * multiplier).clip(lower=min_return)

    trade_columns = [
        'entry_date',
        'exit_date',
        'entry_price',
        'exit_price',
        'return',
        'holding_days',
        'exit_reason'
    ]
    trades = []
    n = len(df)
    i = 0

    # Row i is the signal day, i+1 the entry and i+6 the latest possible exit,
    # so i+6 must be a valid row: i < n - 6.
    while i < n - 6:
        if df['model_signal'].iloc[i] != 1:
            i += 1
            continue

        entry_row = i + 1
        entry_date = df['date'].iloc[entry_row]
        entry_price = df['open'].iloc[entry_row]
        profit_level = effective_profit_threshold.iloc[i]
        stop_level = -stop_loss * multiplier.iloc[i]

        for holding in range(1, 6):  # check up to 5 days ahead
            exit_row = entry_row + holding
            exit_price = df['open'].iloc[exit_row]
            trade_return = (exit_price - entry_price) / entry_price

            if trade_return >= profit_level:
                exit_reason = 'profit_target'
                break
            if trade_return <= stop_level:
                exit_reason = 'stop_loss'
                break
            if exit_price >= df['sma30'].iloc[exit_row]:
                exit_reason = 'revert_to_sma30'
                break
        else:
            # no condition fired: holding, exit_row, exit_price and
            # trade_return already describe the 5th open
            exit_reason = 'max_holding_expired'

        trades.append({
            'entry_date': entry_date,
            'exit_date': df['date'].iloc[exit_row],
            'entry_price': entry_price,
            'exit_price': exit_price,
            'return': trade_return,
            'holding_days': holding,
            'exit_reason': exit_reason
        })

        # Resume at the row before the exit: a signal there enters at the
        # same open this trade exited on, so trades never overlap.
        i = i + holding

    return pd.DataFrame(trades, columns=trade_columns)


def get_cagr(df_trades):
    """Compound annual growth rate of compounding every trade's return.

    The period runs from the earliest ``entry_date`` to the latest
    ``exit_date`` (calendar days / 365.25). Returns 0 when there are no
    trades or when that period is not positive.
    """
    if len(df_trades) == 0:
        return 0

    start_date = df_trades['entry_date'].min()
    end_date = df_trades['exit_date'].max()
    n_years = (end_date - start_date).days / 365.25
    if n_years <= 0:
        # a zero-length span cannot be annualised (1 / n_years would raise)
        return 0

    capital = 1
    for r in df_trades['return']:
        capital *= (1 + r)
    return capital ** (1 / n_years) - 1

def get_sharpe(df_trades, total_trading_days):
    """Annualised per-trade Sharpe ratio (no risk-free rate).

    Computed as mean(return) / std(return) * sqrt(trades per year), where
    trades per year is ``len(df_trades) / total_trading_days * 252``.

    Returns 0 when there are no trades, when ``total_trading_days`` is not
    positive, or when the ratio is undefined because the standard deviation
    is NaN (a single trade) or zero.
    """
    if len(df_trades) == 0 or total_trading_days <= 0:
        return 0

    ret_std = df_trades['return'].std()
    if not np.isfinite(ret_std) or ret_std == 0:
        # sample std needs at least two distinct returns to be usable
        return 0

    ret_mean = df_trades['return'].mean()
    trades_per_year = (len(df_trades) / total_trading_days) * 252
    return (ret_mean / ret_std) * np.sqrt(trades_per_year)

def get_expectancy(df_trades):
    """Expected return per trade: P(win) * avg win - P(loss) * avg |loss|.

    A win is a trade with return > 0; everything else counts towards the loss
    rate. Returns 0 for an empty trade table.
    """
    n_total = len(df_trades)
    if n_total == 0:
        return 0

    returns = df_trades['return']
    gains = returns[returns > 0]
    setbacks = returns[returns <= 0]

    p_win = len(gains) / n_total
    p_loss = 1 - p_win

    mean_gain = 0 if gains.empty else gains.mean()
    mean_setback = 0 if setbacks.empty else abs(setbacks.mean())

    return p_win * mean_gain - p_loss * mean_setback

def backtest(
        X_backtest_input, 
        model, 
        proba_threshold = 0.5, 
        use_vol=None, 
        take_profit=1, 
        stop_loss=1, 
        min_return = 0.01,
        X_model_input=None
):
    """Turn model probabilities into trades and summarise their performance.

    Parameters
    ----------
    X_backtest_input : pd.DataFrame
        Single-instrument frame handed to ``entry_exit`` (after the signal
        columns are added).
    model : object
        Fitted classifier exposing ``predict_proba``; column 1 is used.
    proba_threshold : float
        Probability at or above which a long signal is raised.
    use_vol, take_profit, stop_loss, min_return
        Forwarded to ``entry_exit``.
    X_model_input : optional
        What the model predicts on; defaults to ``X_backtest_input``. Must
        have the same number of rows.

    Returns
    -------
    (df_trade, trade_stats)
        The trade table from ``entry_exit`` and a dict of summary statistics.
        ``max_proba`` and ``total_trading_days`` are reported whether or not
        any trade was taken.
    """
    if X_model_input is None:
        X_model_input = X_backtest_input

    y_proba = model.predict_proba(X_model_input)[:, 1]
    signals = pred_proba_to_signal(y_proba, threshold=proba_threshold)
    X_test_signal = X_backtest_input.assign(model_signal=signals, y_proba=y_proba)

    df_trade = entry_exit(X_test_signal, use_vol, take_profit, stop_loss, min_return)

    # These two do not depend on whether any trade happened. Previously
    # max_proba was dropped when the stats dict was rebuilt, and the
    # no-trade branch reported 0 trading days.
    max_proba = y_proba.max()
    total_trading_days = len(X_test_signal)

    if len(df_trade) > 0:
        total_holding_days = df_trade['holding_days'].sum()
        n_trades = len(df_trade)
        n_wins = len(df_trade[df_trade['return'] > 0])

        # NOTE(review): get_sharpe's second parameter is named
        # total_trading_days but receives total_holding_days here, which
        # annualises as if the strategy were always invested -- confirm intent.
        sharpe = get_sharpe(df_trade, total_holding_days)

        trade_stats = {
            'max_proba': max_proba,
            'exit_reason_spread': df_trade['exit_reason'].value_counts(),
            'holding_days_spread': df_trade['holding_days'].value_counts(),
            'total_trading_days': total_trading_days,
            'total_holding_days': total_holding_days,
            # trades exist, so both denominators below are positive
            'holding_time_percentage': total_holding_days / total_trading_days,
            'n_trades': n_trades,
            'n_wins': n_wins,
            'win_rate': n_wins / n_trades,
            'cagr': get_cagr(df_trade),
            'sharpe': sharpe,
            'expectancy': get_expectancy(df_trade),
        }
    else:
        trade_stats = {
            'max_proba': max_proba,
            'exit_reason_spread': 'not available',
            'holding_days_spread': 'not available',
            'total_trading_days': total_trading_days,
            'total_holding_days': 0,
            'holding_time_percentage': 0,
            'n_trades': 0,
            'n_wins': 0,
            'win_rate': 0,
            'cagr': 0,
            'sharpe': 0,
            'expectancy': 0,
        }

    return df_trade, trade_stats


# =================================
# ===== plotting & evaluation =====
# =================================

def plot_model_metrics(y_true, y_proba):
    """Show a confusion matrix, ROC curve and precision-recall curve side by side.

    The confusion matrix uses a fixed 0.5 cut-off on ``y_proba``; the two
    curves use the probabilities directly.
    """
    hard_labels = (y_proba >= 0.5).astype(int)
    fig, (ax_cm, ax_roc, ax_pr) = plt.subplots(1, 3, figsize=(14, 4))

    # left panel: confusion matrix at the 0.5 cut-off
    matrix = confusion_matrix(y_true, hard_labels)
    ConfusionMatrixDisplay(confusion_matrix=matrix).plot(ax=ax_cm, colorbar=False)

    # middle panel: ROC curve with its AUC
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_proba)
    area = auc(false_pos_rate, true_pos_rate)
    RocCurveDisplay(fpr=false_pos_rate, tpr=true_pos_rate, roc_auc=area).plot(ax=ax_roc)

    # right panel: precision-recall curve
    prec, rec, _ = precision_recall_curve(y_true, y_proba)
    PrecisionRecallDisplay(precision=prec, recall=rec).plot(ax=ax_pr)

    plt.tight_layout()
    plt.show()


def tscv_eval(X, y, pipe, n_split=5, predict_proba_threshold=0.5):
    '''
    Time-series cross-validation report for a classifier pipeline.

    For evaluation on the train set: make sure X and y contain nothing from
    the test set (future data). Rows must already be in time order, since
    TimeSeriesSplit splits by position.

    For each of the ``n_split`` folds, ``pipe`` is fitted on the training part
    (the passed object itself is refitted and is left fitted on the last
    fold), the classification report at ``predict_proba_threshold`` is
    printed, and the first fold is also plotted with ``plot_model_metrics``.
    Returns nothing.
    '''
    tscv = TimeSeriesSplit(n_splits=n_split)
    for i, (train_index, test_index) in enumerate(tscv.split(X)):
        X_train, X_test = X.iloc[train_index], X.iloc[test_index]
        y_train, y_test = y.iloc[train_index], y.iloc[test_index]

        pipe.fit(X_train, y_train)
        y_proba = pipe.predict_proba(X_test)[:, 1]
        y_pred = (y_proba >= predict_proba_threshold).astype(int)

        print(f'Fold {i+1}:\n{classification_report(y_test, y_pred)}')

        if i == 0:
            print('Fold 1 visualized below')
            # plot_model_metrics lays out and shows its own figure; an extra
            # tight_layout()/show() per fold only rendered an empty figure
            plot_model_metrics(y_test, y_proba)

def plot_prc_with_thresholds(precision, recall, thresholds, avg_precision, threshold_markers=[0, 0.5, 0.6, 0.7], title='Precision-Recall Curve'):
    """
    Draw a precision-recall curve and mark selected decision thresholds on it.

    Each value in ``threshold_markers`` that lies within the range of
    ``thresholds`` is drawn at the curve point whose threshold is closest to
    it; values outside that range are skipped. ``avg_precision`` only appears
    in the legend label.
    """
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, label=f"PR Curve (AP = {avg_precision:.4f})", lw=2)

    for marker in threshold_markers:
        in_range = marker <= np.max(thresholds) and marker >= np.min(thresholds)
        if not in_range:
            continue
        nearest = np.argmin(np.abs(thresholds - marker))
        plt.plot(recall[nearest], precision[nearest], 'o', label=f'Threshold = {marker}', markersize=8)

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(title)

    # list the markers before the curve by reversing the legend order
    legend_handles, legend_labels = plt.gca().get_legend_handles_labels()
    plt.legend(legend_handles[::-1], legend_labels[::-1], title="Legend")
    plt.grid(True)
    plt.show()

def plot_equity_curve(df_trade, ohlcv, start_date, end_date, title='Equity Curve'):
    """Plot the strategy's compounded return against holding the instrument.

    Parameters
    ----------
    df_trade : pd.DataFrame
        Trades with ``entry_date``, ``exit_date`` and ``return`` columns.
        Not modified.
    ohlcv : pd.DataFrame
        Price frame with a ``close`` column, indexed by date.
    start_date, end_date
        Label bounds of the ``ohlcv`` slice to plot.
    title : str
        Figure title.

    The strategy curve starts at 1.0 on the first entry date and steps to the
    compounded value on each exit date; dates before the first entry are
    left blank.
    """
    # parse into locals instead of writing the parsed dates back into the
    # caller's frame
    entry_dates = pd.to_datetime(df_trade['entry_date'])
    exit_dates = pd.to_datetime(df_trade['exit_date'])

    # compounded value after each trade, stamped with that trade's exit date
    compounded = (1 + df_trade['return']).cumprod()
    equity_steps = pd.Series(data=compounded.values, index=exit_dates)

    # starting value of 1.0 at the first entry date
    cumulative_series = pd.concat([
        pd.Series([1.0], index=[entry_dates.min()]),
        equity_steps
    ]).sort_index()

    # slice first, then copy only the rows that are plotted
    df = ohlcv.loc[start_date:end_date].copy()
    df['market_cumret'] = (1 + df['close'].pct_change().fillna(0)).cumprod()

    # align the strategy steps to the price dates and carry each step forward
    df['strategy_cumret'] = cumulative_series
    df['strategy_cumret'] = df['strategy_cumret'].ffill()

    # Plot both
    plt.figure(figsize=(12, 6))
    plt.plot(df.index, df['strategy_cumret'], label='Strategy Equity Curve', marker='o', markersize=2)
    plt.plot(df.index, df['market_cumret'], label='Market (Close) Curve', alpha=0.7)
    plt.title(title)
    plt.xlabel('Date')
    plt.ylabel('Cumulative Return')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()


# ===============================
# ===== WRC & MC Simulation =====
# ===============================

def white_reality_check(strategy_returns: pd.Series, all_strategy_returns: pd.DataFrame, n_bootstrap=1000, seed=42):
    """
    Perform a White Reality Check style bootstrap test for data snooping.

    The null hypothesis is that no candidate strategy has a positive expected
    return. Each candidate's returns are recentred on their own sample mean,
    rows are resampled with replacement, and the largest resampled mean across
    candidates is recorded. The p-value is the share of those maxima that are
    at least as large as the selected strategy's observed mean.

    Rows are resampled independently (i.i.d. bootstrap), so serial dependence
    in the returns is not preserved.

    Parameters:
        strategy_returns: Series of returns from the selected strategy (e.g., with model predictions).
        all_strategy_returns: DataFrame of returns from multiple strategies (same time index),
            one column per strategy.
        n_bootstrap: Number of bootstrap resamples.
        seed: Seed for the local random generator (global NumPy state is untouched).

    Returns:
        p_value: bootstrap p-value (also shows a histogram of the bootstrap maxima).

    Raises:
        ValueError: if all_strategy_returns has no rows.
    """
    n_obs = len(all_strategy_returns)
    if n_obs == 0:
        raise ValueError("all_strategy_returns must contain at least one row")

    rng = np.random.default_rng(seed)
    observed_stat = np.mean(strategy_returns)

    # Without recentring, the bootstrap maxima scatter around the best sample
    # mean itself, so the p-value says nothing about the null. Subtracting
    # each column's mean imposes "expected return = 0" on every candidate.
    centered = all_strategy_returns - all_strategy_returns.mean(axis=0, skipna=True)

    max_statistics = []
    for _ in range(n_bootstrap):
        # row positions are drawn from the frame actually being resampled
        idx = rng.integers(0, n_obs, n_obs)
        boot_sample = centered.iloc[idx]
        max_statistics.append(boot_sample.mean(axis=0, skipna=True).max())

    p_value = np.mean([s >= observed_stat for s in max_statistics])

    plt.hist(max_statistics, bins=50, alpha=0.7)
    plt.axvline(observed_stat, color='red', linestyle='--', label='Observed Mean')
    plt.title("White Reality Check")
    plt.legend()
    plt.show()

    return p_value


def monte_carlo_test(strategy_returns: pd.Series, n_sim=1000, seed=42):
    """
    Monte Carlo (bootstrap) test of whether the strategy's mean return exceeds zero.

    Missing values are dropped, the returns are centered on their own mean so
    that the resampling population has mean zero (the null hypothesis), and
    `n_sim` samples of the same length are drawn from it with replacement.
    The p-value is the share of simulated means that reach the observed mean.

    Resampling the uncentered returns would produce simulated means scattered
    around the observed mean itself (p-value near 0.5 for any input), and a
    plain shuffle leaves the mean unchanged, so neither can serve as a null
    distribution for the mean.

    Parameters:
        strategy_returns: Series of returns from selected strategy.
        n_sim: Number of simulations.
        seed: Seed for a local random generator; the global NumPy random state is left untouched.

    Returns:
        p_value: Monte Carlo p-value. NaN when no non-missing returns are available.
    """
    returns = strategy_returns.dropna().to_numpy(dtype=float)
    if returns.size == 0:
        # Nothing to test; a p-value of 0.0 here would look spuriously significant.
        return np.nan

    rng = np.random.default_rng(seed)
    observed_mean = returns.mean()
    null_returns = returns - observed_mean  # zero-mean population representing the null

    sim_means = np.array([
        rng.choice(null_returns, size=null_returns.size, replace=True).mean()
        for _ in range(n_sim)
    ])

    p_value = np.mean(sim_means >= observed_mean)

    # density=True puts the histogram on the same scale as the KDE overlay.
    plt.hist(sim_means, bins=50, alpha=0.7, density=True)
    sns.kdeplot(sim_means, fill=True)
    plt.axvline(observed_mean, color='red', linestyle='--', label='Observed Mean')
    plt.title("Monte Carlo Test")
    plt.legend()
    plt.show()

    return p_value



In [3]:
# Shared inputs for the rest of the notebook.
# One market_trend() result per benchmark ETF, keyed by lowercase symbol.
# QQQ's first date is ~2 months later than SPY's; that gap does not matter for this project.
# NOTE(review): end_date is pinned to 2025-06-27 while the rddt data loaded further down
# runs to 2025-08-20, and the spy_/qqq_trend columns are blank in that wrangled tail —
# confirm the cutoff is intentional.
etfs = {
    symbol: market_trend(symbol, end_date='2025-06-27', load_csv=False, save_csv=True)
    for symbol in ('spy', 'qqq')
}

# Ticker universe plus the per-ticker table exposing 'sector' and 'mindate' columns.
sp500_tickers = get_sp500()
df1 = get_sp500_tickers_sectors_mindates(sp500_tickers=sp500_tickers)

Loaded data from files/sp500_tickers_sectors_mindates.pkl.


In [4]:
# Sector distribution across every ticker in the table.
print(df1['sector'].value_counts())
# Same distribution, restricted to tickers whose 'mindate' is before 2016-01-01.
print(df1.loc[df1['mindate'] < datetime(2016, 1, 1).date(), 'sector'].value_counts())

sector
Technology                82
Industrials               71
Financial Services        68
Healthcare                61
Consumer Cyclical         56
Consumer Defensive        37
Utilities                 31
Real Estate               31
Energy                    22
Communication Services    21
Basic Materials           20
Name: count, dtype: int64
sector
Technology                75
Financial Services        67
Industrials               65
Healthcare                58
Consumer Cyclical         54
Consumer Defensive        35
Utilities                 29
Real Estate               29
Energy                    21
Communication Services    19
Basic Materials           18
Name: count, dtype: int64


In [5]:
# Load the data for ticker 'rddt' and preview its most recent rows
# (the output shows close/high/low/open/volume plus a ticker column, indexed by Date).
rddt = get_data('rddt')
rddt.tail()


Unnamed: 0_level_0,close,high,low,open,volume,ticker
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2025-08-14,243.470001,244.850006,229.955002,229.955002,7579700,rddt
2025-08-15,246.5,248.014999,235.789993,242.574997,6434000,rddt
2025-08-18,241.759995,253.139999,237.539993,245.630005,6933800,rddt
2025-08-19,228.179993,238.839996,225.5,238.220001,8655700,rddt
2025-08-20,218.369995,225.0,205.373993,221.5,12437823,rddt


In [7]:
# Fetch the metadata record for 'rddt'; the displayed result is a dict with
# ticker, sector, market_cap, avg_volume, beta and is_sp500 (beta is None here).
rddtmd = get_metadata('rddt', sp500_tickers=sp500_tickers)
rddtmd


{'ticker': 'rddt',
 'sector': 'Communication Services',
 'market_cap': 40869253120,
 'avg_volume': 7892150,
 'beta': None,
 'is_sp500': False}

In [23]:
# Run the wrangling step on the raw rddt frame. Judging by the output, it adds the
# spy_/qqq_ trend columns, the metadata columns and target_long to the indicator features.
rddtwr = wrangling(rddt, etfs=etfs, sp500_tickers=sp500_tickers)

# Missing-value count per column, then the latest rows for a visual check.
print(rddtwr.isna().sum())
rddtwr.tail()

close                     0
high                      0
low                       0
open                      0
volume                    0
ticker                    0
sma30                    29
sma10                     9
sma_diff                 29
sma_slope                10
std30                    29
bollinger_upper          29
bollinger_lower          29
percent_b                29
bollinger_z              29
price_near_lower_bb       0
rsi14                    13
prod_bollingerz_rsi      29
rsi_smooth               15
rsi_local_low             0
daily_return              1
rolling_volatility14     14
atr                       0
garch_vol                 1
year_since                0
month                     0
week                      0
dayofweek                 0
macd_diff                33
adx                       0
adx_pos                   0
adx_neg                   0
strong_trend              0
up_trend_context          0
down_trend_context        0
spy_trend_10_50     

Unnamed: 0_level_0,close,high,low,open,volume,ticker,sma30,sma10,sma_diff,sma_slope,std30,bollinger_upper,bollinger_lower,percent_b,bollinger_z,price_near_lower_bb,rsi14,prod_bollingerz_rsi,rsi_smooth,rsi_local_low,daily_return,rolling_volatility14,atr,garch_vol,year_since,month,week,dayofweek,macd_diff,adx,adx_pos,adx_neg,strong_trend,up_trend_context,down_trend_context,spy_trend_10_50,spy_trend_20_50,spy_trend_50_200,qqq_trend_10_50,qqq_trend_20_50,qqq_trend_50_200,sector,market_cap,avg_volume,beta,is_sp500,target_long
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1
2025-08-14,243.470001,244.850006,229.955002,229.955002,7579700,rddt,170.392334,215.257001,44.864666,8.288,33.759198,237.910731,102.873938,1.041169,2.164674,0,87.984832,91.60704,86.407406,0,0.01786,0.046927,11.291711,4.642381,1,8,33,3,6.006136,44.456127,43.852888,6.177464,1,1,0,,,,,,,Communication Services,40869253120,7892150,,False,
2025-08-15,246.5,248.014999,235.789993,242.574997,6434000,rddt,173.374668,221.043001,47.668333,5.786,36.387671,246.150011,100.599325,1.002405,2.009618,0,88.42739,88.640021,87.677632,0,0.05488,0.046663,11.358375,4.518337,1,8,33,4,5.790437,46.732505,42.47189,5.70255,1,1,0,,,,,,,Communication Services,40869253120,7892150,,False,
2025-08-18,241.759995,253.139999,237.539993,245.630005,6933800,rddt,176.321001,225.043001,48.722,4.0,38.243212,252.807426,99.834576,0.927782,1.711127,0,83.260825,77.247873,86.557682,0,0.012594,0.046309,11.661348,4.674655,1,8,34,0,4.870647,48.959931,41.552725,5.15765,1,1,0,,,,,,,Communication Services,40869253120,7892150,,False,
2025-08-19,228.179993,238.839996,225.5,238.220001,8655700,rddt,179.061667,227.952,48.890333,2.909,38.932295,256.926256,101.197078,0.815409,1.261634,0,70.543946,57.522142,80.744053,0,-0.030167,0.045324,11.989823,4.532939,1,8,34,1,2.982513,49.181486,37.527603,11.830785,1,1,0,,,,,,,Communication Services,40869253120,7892150,,False,
2025-08-20,218.369995,225.0,205.373993,221.5,12437823,rddt,181.485667,228.508,47.022333,0.556,39.043969,259.573606,103.397729,0.736172,0.944687,0,63.052034,46.417125,72.285602,0,-0.070187,0.049152,12.762407,4.59851,1,8,34,2,0.820214,47.135,32.737561,21.58481,1,1,0,,,,,,,Communication Services,40869253120,7892150,,False,


In [17]:
# Run feature engineering on its own to check the row count and the per-column
# missing-value counts (the leading NaNs presumably come from indicator warm-up windows).
# A bare `rddtfe.tail()` in the middle of the cell was evaluated and discarded, because
# only a cell's last expression is displayed; put it last in its own cell to view the rows.
rddtfe = feature_engineering(rddt)
print(len(rddtfe))
rddtfe.isna().sum()

355


close                    0
high                     0
low                      0
open                     0
volume                   0
ticker                   0
sma30                   29
sma10                    9
sma_diff                29
sma_slope               10
std30                   29
bollinger_upper         29
bollinger_lower         29
percent_b               29
bollinger_z             29
price_near_lower_bb      0
rsi14                   13
prod_bollingerz_rsi     29
rsi_smooth              15
rsi_local_low            0
daily_return             1
rolling_volatility14    14
atr                      0
garch_vol                1
year_since               0
month                    0
week                     0
dayofweek                0
macd_diff               33
adx                      0
adx_pos                  0
adx_neg                  0
strong_trend             0
up_trend_context         0
down_trend_context       0
dtype: int64