In [None]:
from experiments.backtest import *
import matplotlib.pyplot as plt

In [None]:
import numpy as np
import pandas as pd
import yfinance as yf

# Backtest window shared by every download below.
START_DATE, END_DATE = "2015-01-01", "2025-01-15"

# Alternative universe (disabled): Hang Seng constituents from a local Excel file,
# benchmarked against ^HSI on a common calendar.
#
# hsi_stocks = pd.read_excel("/Users/henrywzh/Desktop/hk_stock_ejfq.xlsx")
#
# def code_int_to_hk(code: int) -> str:
#     return f"{int(code):04d}.HK"
#
# UNIVERSE = hsi_stocks["代码"].apply(code_int_to_hk).tolist()
#
# price_df = yf.download(UNIVERSE, start=START_DATE, end=END_DATE)
# hsi = yf.download("^HSI", START_DATE, END_DATE)
#
# # Extract close matrices
# close = price_df["Close"].sort_index()
# benchmark_close = hsi["Close"].sort_index()
#
# # Common calendar (optional but recommended)
# common_dates = close.index.intersection(benchmark_close.index)
# close = close.loc[common_dates]
# benchmark_close = benchmark_close.loc[common_dates]

In [None]:
# Index constituent lists, one sheet per index. Set the US_INDICES_XLSX environment
# variable to point at the workbook instead of editing a hard-coded user path.
import os
from pathlib import Path

US_INDICES_XLSX = Path(
    os.environ.get("US_INDICES_XLSX", "/Users/henrywzh/Desktop/us_indices.xlsx")
).expanduser()

# Parse the workbook once; a list for sheet_name returns {sheet_name: DataFrame}.
index_sheets = pd.read_excel(US_INDICES_XLSX, sheet_name=["sp500", "sp400"])
sp500_stocks = index_sheets["sp500"]
sp400_stocks = index_sheets["sp400"]

In [None]:
# S&P 400 constituents -> Yahoo symbols: drop blanks and duplicates, strip whitespace,
# and map share-class dots to the dash form Yahoo uses (e.g. "ABC.B" -> "ABC-B").
universe_tickers = (
    sp400_stocks["Ticker"]
    .dropna()
    .astype(str)
    .str.strip()
    .str.replace(".", "-", regex=False)
    .drop_duplicates()
    .tolist()
)
price_df = yf.download(universe_tickers, start=START_DATE, end=END_DATE)

In [None]:
# Benchmark index used for comparisons against the stock universe.
# NOTE(review): the universe above is the 'sp400' constituent sheet, while '^GSPC' is
# the S&P 500 index — confirm this is the intended benchmark.
ticker = '^GSPC'

benchmark = yf.download(ticker, start=START_DATE, end=END_DATE)

# yfinance may return (Price, Ticker) MultiIndex columns even for a single symbol, in
# which case benchmark['Close'] is a one-column DataFrame; reduce it to a Series so
# downstream arithmetic and merges see a plain price series.
benchmark_close = benchmark['Close']
if isinstance(benchmark_close, pd.DataFrame):
    benchmark_close = benchmark_close.iloc[:, 0]

In [None]:
def build_ma_signal(price_df, windows=(5, 20, 50, 120)):
    """
    Moving-average distance signal.

    For each window w: Close / SMA_w(Close) - 1 (NaN until w observations exist),
    then the plain average of those gaps across all windows.
    """
    px = price_df["Close"].sort_index()
    gaps = [px / px.rolling(n, min_periods=n).mean() - 1.0 for n in windows]
    # Accumulate left-to-right starting from 0.0, then average over the window count.
    return sum(gaps, 0.0) / len(windows)

def build_rsi_signal(price_df: pd.DataFrame, period: int = 14) -> pd.DataFrame:
    """
    Per-ticker RSI from Close prices, in [0, 100] (higher = stronger recent gains).

    Gains and losses are smoothed Wilder-style (EMA with alpha = 1/period, no
    adjustment, at least `period` observations). A row with zero average loss gives
    RSI 100; zero gain and zero loss gives NaN.
    """
    px = price_df["Close"].sort_index()
    change = px.diff()

    wilder = dict(alpha=1 / period, adjust=False, min_periods=period)
    avg_up = change.clip(lower=0.0).ewm(**wilder).mean()
    avg_down = (-change).clip(lower=0.0).ewm(**wilder).mean()

    ratio = avg_up / avg_down
    return 100.0 - (100.0 / (1.0 + ratio))


def build_macd_signal(
    price_df: pd.DataFrame,
    fast: int = 12,
    slow: int = 26,
    signal: int = 9,
    output: str = "hist",
) -> pd.DataFrame:
    """
    Per-ticker MACD from Close prices.

    All EMAs are unadjusted with min_periods equal to their span.

    output:
      - "line":   EMA(fast) - EMA(slow)
      - "signal": EMA(MACD line, signal)
      - "hist":   MACD line - signal line (the usual trading signal)

    Raises ValueError for any other `output`.
    """
    px = price_df["Close"].sort_index()

    def _ema(frame: pd.DataFrame, span: int) -> pd.DataFrame:
        return frame.ewm(span=span, adjust=False, min_periods=span).mean()

    line = _ema(px, fast) - _ema(px, slow)
    sig_line = _ema(line, signal)

    candidates = (("line", line), ("signal", sig_line), ("hist", line - sig_line))
    for label, frame in candidates:
        if output == label:
            return frame
    raise ValueError("output must be one of: 'line', 'signal', 'hist'")


def cs_winsorize_zscore(
    sig: pd.DataFrame,
    lower_q: float = 0.01,
    upper_q: float = 0.99,
    min_count: int = 5,
) -> pd.DataFrame:
    """
    Cross-sectional winsorize + z-score per date (row), vectorised over all dates.

    For each row of ``sig`` (values cast to float, NaNs left as NaN):
      - fewer than ``min_count`` non-NaN values -> the row is returned unchanged;
      - otherwise values are clipped to the row's [lower_q, upper_q] quantiles
        (linear interpolation over the non-NaN values);
      - the clipped row is standardised with its mean and sample std (ddof=1),
        unless that std is 0 or non-finite, in which case the clipped row is returned.

    Parameters
    ----------
    sig : DataFrame
        dates x tickers signal.
    lower_q, upper_q : float
        quantiles used as winsorisation bounds.
    min_count : int
        minimum number of non-NaN values needed to process a row (default 5, the
        previously hard-coded threshold).

    Returns
    -------
    DataFrame with the same index and columns as ``sig``.
    """
    x = sig.astype(float)
    n_valid = x.notna().sum(axis=1)

    # Row-wise quantiles skip NaN, matching Series.quantile on the non-NaN values.
    lo = x.quantile(lower_q, axis=1)
    hi = x.quantile(upper_q, axis=1)
    clipped = x.clip(lower=lo, upper=hi, axis=0)  # NaN stays NaN

    mu = clipped.mean(axis=1)
    sd = clipped.std(axis=1, ddof=1)
    z = clipped.sub(mu, axis=0).div(sd, axis=0)

    # Per-row decisions, broadcast across columns.
    too_few = (n_valid < min_count).to_numpy()[:, None]
    bad_sd = (~(sd.ne(0) & np.isfinite(sd))).to_numpy()[:, None]

    values = np.where(
        too_few,
        x.to_numpy(),
        np.where(bad_sd, clipped.to_numpy(), z.to_numpy()),
    )
    return pd.DataFrame(values, index=sig.index, columns=sig.columns)


In [None]:
from typing import Dict, Optional, Tuple, Any, List


def perf_summary(r: pd.Series, freq: float) -> Dict[str, float]:
    """
    Headline performance metrics for a periodic return series ``r``.

    ``freq`` is passed through to the annualising helpers (annualised_return_geo,
    annualised_vol, sharpe_ratio — presumably from experiments.backtest, not shown here).
    """
    summary: Dict[str, float] = {}
    summary["ann_return_geo"] = annualised_return_geo(r, freq=freq)
    summary["ann_vol"] = annualised_vol(r, freq=freq)
    summary["sharpe"] = sharpe_ratio(r, freq=freq)
    summary["max_dd"] = max_drawdown(r)
    summary["n_obs"] = int(len(r.dropna()))  # non-NaN observations
    return summary

# =======================================================
# Tear sheet
# =======================================================

def _compute_ic_spearman(
    signal: pd.DataFrame,
    ret_fwd: pd.DataFrame,
    dates: pd.Index,
    min_assets: int,
) -> pd.Series:
    """Daily (rebalance-date) Spearman IC, cross-sectional."""
    # Align on dates first
    signal = signal.reindex(index=ret_fwd.index)
    dates = pd.Index(dates).intersection(signal.index).intersection(ret_fwd.index)

    ic_vals = []
    for dt in dates:
        x = signal.loc[dt]
        y = ret_fwd.loc[dt]
        m = x.notna() & y.notna()
        if m.sum() < min_assets:
            ic_vals.append(np.nan)
        else:
            ic_vals.append(x[m].corr(y[m], method="spearman"))
    return pd.Series(ic_vals, index=dates, name="IC_spearman")


def _compute_coverage_and_nvalid(
    signal: pd.DataFrame,
    ret_fwd: pd.DataFrame,
    dates: pd.Index,
) -> Tuple[pd.Series, pd.Series]:
    """Coverage (%) and valid count on rebalance dates."""
    dates = pd.Index(dates).intersection(signal.index).intersection(ret_fwd.index)
    universe_size = signal.shape[1]

    cov = []
    nvalid = []
    for dt in dates:
        m = signal.loc[dt].notna() & ret_fwd.loc[dt].notna()
        n = int(m.sum())
        nvalid.append(n)
        cov.append(n / universe_size if universe_size > 0 else np.nan)

    return (
        pd.Series(cov, index=dates, name="coverage"),
        pd.Series(nvalid, index=dates, name="n_valid"),
    )


def _turnover_proxy_from_labels(
    bucket_lbl: pd.DataFrame,
    rebalance_dates: pd.Index,
    bucket_k: int,
) -> pd.Series:
    """
    Turnover proxy for a given bucket: 1 - overlap_ratio of bucket membership.
    Computed on rebalance dates using labels at those dates.
    """
    dts = pd.Index(rebalance_dates).intersection(bucket_lbl.index)
    prev_members = None
    vals = []

    for dt in dts:
        lbl = bucket_lbl.loc[dt]
        members = set(lbl.index[lbl == bucket_k])
        if prev_members is None:
            vals.append(np.nan)
        else:
            if len(members) == 0:
                vals.append(np.nan)
            else:
                overlap = len(members.intersection(prev_members))
                vals.append(1.0 - overlap / len(members))
        prev_members = members

    return pd.Series(vals, index=dts, name=f"turnover_bucket_{bucket_k}")


def _equity_curve_from_returns(r: pd.Series) -> pd.Series:
    """Equity curve on the same index as r (no resampling)."""
    r = r.dropna()
    if len(r) == 0:
        return pd.Series(dtype=float)
    return (1.0 + r).cumprod()

def _drawdown_from_equity(eq: pd.Series) -> pd.Series:
    """Drawdown series from equity curve."""
    if eq is None or len(eq) == 0:
        return pd.Series(dtype=float)
    peak = eq.cummax()
    return eq / peak - 1.0

def _yearly_returns_from_returns(r: pd.Series) -> pd.Series:
    """
    Calendar-year compounded returns from periodic returns r.
    Works for H-day returns too (compounds all observations within each year).
    """
    r = r.dropna()
    if len(r) == 0:
        return pd.Series(dtype=float)
    # group by calendar year and compound
    return (1.0 + r).groupby(r.index.year).prod() - 1.0


def _compute_benchmark_ret_fwd(
    bench_price: pd.DataFrame | pd.Series,
    H: int,
    entry_mode: str,
) -> pd.Series:
    """
    Return series aligned to formation date t:
      ret_fwd[t] = Px[t+H+1] / Px[t+1] - 1
    bench_price can be:
      - Series of Close (or Open) already chosen, OR
      - DataFrame with columns including 'Close' and/or 'Open'
    """
    if isinstance(bench_price, pd.Series):
        px = bench_price.sort_index()
    else:
        bench_price = bench_price.sort_index()
        if entry_mode in {"next_open"}:
            if "Open" not in bench_price.columns:
                raise ValueError("benchmark DataFrame must contain 'Open' for entry_mode='next_open'")
            px = bench_price["Open"]
        else:
            if "Close" not in bench_price.columns:
                raise ValueError("benchmark DataFrame must contain 'Close' for entry_mode='next_close' or 'open_to_close'")
            px = bench_price["Close"]

    # open_to_close benchmark: enter open[t+1], exit close[t+H+1]
    if entry_mode == "open_to_close":
        if isinstance(bench_price, pd.Series):
            raise ValueError("For entry_mode='open_to_close', benchmark must be a DataFrame with Open/Close.")
        entry_px = bench_price["Open"].shift(-1)
        exit_px  = bench_price["Close"].shift(-(H + 1))
        return (exit_px / entry_px - 1).rename("benchmark_ret_fwd")

    # next_close / next_open
    entry_px = px.shift(-1)
    exit_px  = px.shift(-(H + 1))
    return (exit_px / entry_px - 1).rename("benchmark_ret_fwd")


def _tearsheet_ic_stats(ic: pd.Series) -> Dict[str, float]:
    """
    Summary statistics of a rebalance-date IC series.

    NaN ICs (dates with too few valid assets) are excluded from every statistic,
    including the hit-rate denominator (previously they counted as misses).
    """
    ic_valid = ic.dropna()
    ic_mean = float(ic_valid.mean())
    ic_std = float(ic_valid.std(ddof=1))
    icir = ic_mean / ic_std if (np.isfinite(ic_std) and ic_std > 0) else np.nan
    return {
        "ic_mean": ic_mean,
        "ic_std": ic_std,
        "icir": float(icir),
        "hit_rate": float((ic_valid > 0).mean()),
        "n_obs": int(ic_valid.shape[0]),
    }


def _tearsheet_bucket_summary(bucket_ret: pd.DataFrame, freq: float) -> pd.DataFrame:
    """Per-bucket mean/std/t-stat of per-period returns plus perf_summary metrics, indexed by bucket."""
    rows = []
    for col in bucket_ret.columns:
        r = bucket_ret[col].dropna()
        n = len(r)
        mu = float(r.mean()) if n else np.nan
        sd = float(r.std(ddof=1)) if n >= 2 else np.nan
        tstat = (mu / (sd / np.sqrt(n))) if (n >= 2 and sd > 0) else np.nan
        rows.append({
            "bucket": col,
            "mean_ret_per_period": mu,
            "std_ret_per_period": sd,
            "t_stat_mean": tstat,
            **perf_summary(bucket_ret[col], freq=freq),
        })
    return pd.DataFrame(rows).set_index("bucket")


def _bucket_number(labels: pd.Index) -> pd.Index:
    """Map 'bucket_k' labels to the integer k."""
    return labels.str.replace("bucket_", "", regex=False).astype(int)


def _monotonic_spearman(mean_by_bucket: pd.Series) -> float:
    """Spearman correlation between bucket number and mean per-period return (quick monotonicity check)."""
    by_num = mean_by_bucket.copy()
    by_num.index = _bucket_number(by_num.index)
    by_num = by_num.sort_index()  # 1..n_buckets
    return float(pd.Series(by_num.index, index=by_num.index).corr(by_num, method="spearman"))


def _plot_tearsheet(
    *,
    ic: pd.Series,
    ic_roll: pd.Series,
    rolling_window_obs: int,
    mean_by_bucket: pd.Series,
    bucket_ret: pd.DataFrame,
    n_buckets: int,
    coverage: pd.Series,
    turnover_bottom: pd.Series,
    turnover_top: pd.Series,
    H: int,
    entry_mode: str,
    bench_dd: Optional[pd.Series],
    bench_yearly: Optional[pd.Series],
    benchmark_name: str,
    plot_diagnostics: bool,
) -> None:
    """
    Render the tear-sheet figures.

    Every figure is created with plt.subplots() and pandas draws into it via ``ax=``.
    DataFrame.plot without ``ax`` opens its own new figure, which previously left an
    empty figure behind and dropped the requested figsize.
    """
    # Top = highest bucket number, bottom = bucket 1 (same convention as turnover_top / turnover_bottom).
    top_col = f"bucket_{n_buckets}"
    bot_col = "bucket_1"

    if plot_diagnostics:
        fig, ax = plt.subplots()
        coverage.plot(ax=ax)
        ax.set(title=f"Coverage on Rebalance Dates (H={H}, entry={entry_mode})",
               ylabel="Coverage (valid fraction)", xlabel="Date")
        plt.show()

    # Cumulative IC + rolling IC
    fig, ax = plt.subplots()
    ic.cumsum().plot(ax=ax)
    ax.set(title="Cumulative IC (Spearman)", ylabel="Cumulative IC", xlabel="Date")
    plt.show()

    fig, ax = plt.subplots()
    ic_roll.plot(ax=ax)
    ax.set(title=f"Rolling Mean IC (window={rolling_window_obs} obs)", ylabel="Rolling IC mean", xlabel="Date")
    plt.show()

    # Bucket mean returns, in numeric bucket order
    fig, ax = plt.subplots(figsize=(10, 6))
    mean_by_bucket.sort_index(key=_bucket_number).plot(kind="bar", ax=ax)
    ax.set(title="Mean Return per Bucket (per holding period)", ylabel="Mean H-day return", xlabel="Bucket")
    fig.tight_layout()
    plt.show()

    # Bucket cumulative curves (missing periods treated as flat)
    fig, ax = plt.subplots()
    (1.0 + bucket_ret.fillna(0.0)).cumprod().plot(ax=ax, legend=True)
    ax.set(title="Bucket Cumulative Curves", ylabel="Cumulative growth", xlabel="Date")
    fig.tight_layout()
    plt.show()

    if plot_diagnostics:
        fig, ax = plt.subplots()
        pd.DataFrame({"turnover_bottom": turnover_bottom, "turnover_top": turnover_top}).plot(ax=ax)
        ax.set(title="Turnover Proxy (Bucket Membership Change)", ylabel="1 - overlap ratio", xlabel="Date")
        fig.tight_layout()
        plt.show()

    # Drawdowns: top / bottom / benchmark
    fig, ax = plt.subplots()
    for col in (top_col, bot_col):
        dd = _drawdown_from_equity(_equity_curve_from_returns(bucket_ret[col]))
        if len(dd):
            dd.plot(ax=ax, label=f"{col} (max_dd={dd.min():.2%})")
    if bench_dd is not None and len(bench_dd):
        bench_dd.plot(ax=ax, label=f"{benchmark_name} (max_dd={bench_dd.min():.2%})")
    ax.set(title="Drawdown Curves (Top / Bottom / Benchmark)", ylabel="Drawdown", xlabel="Date")
    ax.legend()
    fig.tight_layout()
    plt.show()

    # Calendar-year returns: top / bottom / benchmark
    parts = [_yearly_returns_from_returns(bucket_ret[c]).rename(c) for c in (top_col, bot_col)]
    if bench_yearly is not None and len(bench_yearly):
        parts.append(bench_yearly)
    yr_tbl = pd.concat(parts, axis=1).sort_index()
    if len(yr_tbl):
        fig, ax = plt.subplots(figsize=(10, 5))
        yr_tbl.plot(kind="bar", ax=ax)
        ax.set(title="Calendar-Year Returns (Top / Bottom / Benchmark)", ylabel="Return", xlabel="Year")
        fig.tight_layout()
        plt.show()


def make_tearsheet(
    price_df: pd.DataFrame,
    signal: pd.DataFrame,
    H: int = 5,
    n_buckets: int = 20,
    entry_mode: str = "next_close",
    min_assets_ic: int = 50,
    plot: bool = True,
    rolling_window_obs: Optional[int] = None,
    benchmark_price: Optional[pd.DataFrame | pd.Series] = None,
    benchmark_name: str = "Benchmark",
    plot_diagnostics: bool = False,
) -> Dict[str, Any]:
    """
    Generate a standardized tear sheet for a signal under the bucket backtest framework.

    Parameters
    ----------
    price_df : DataFrame
        yfinance multiindex columns, must contain 'Close' and (if next_open/open_to_close) 'Open'
    signal : DataFrame
        dates x tickers, computed at close[t]
    H : int
        holding horizon (in trading days after entry), and rebalance step
    n_buckets : int
        number of buckets ('bucket_1' .. f'bucket_{n_buckets}'; the highest number is
        reported as "top")
    entry_mode : str
        'next_close' | 'next_open' | 'open_to_close'
    min_assets_ic : int
        minimum valid assets required to compute IC on a rebalance date
    plot : bool
        whether to display plots
    rolling_window_obs : int | None
        rolling IC window in rebalance observations; if None,
        max(10, round(TRADING_DAYS / H)) (~1 year of rebalances)
    benchmark_price : DataFrame | Series | None
        optional benchmark prices (see _compute_benchmark_ret_fwd for accepted shapes);
        its forward returns are aligned to the rebalance dates and added to the
        drawdown / yearly plots and to report["benchmark_ret"]
    benchmark_name : str
        label used for the benchmark in plots and the yearly table
    plot_diagnostics : bool
        also plot coverage and top/bottom turnover proxies (only when plot=True)

    Returns
    -------
    report : dict
        Contains bucket_ret, ic series, summaries, coverage, turnover proxies, and summary tables.
    """
    if rolling_window_obs is None:
        rolling_window_obs = max(10, int(round(TRADING_DAYS / H)))

    # Run backtest (uses strict NaN logic, next-day entry alignment)
    bucket_ret, bucket_lbl, ret_fwd = bucket_backtest(
        price_df=price_df,
        signal=signal,
        H=H,
        n_buckets=n_buckets,
        entry_mode=entry_mode,
    )

    rebalance_dates = bucket_ret.index
    freq = TRADING_DAYS / H  # annualisation frequency for H-day non-overlapping returns

    # Coverage & valid count
    signal_aligned = signal.reindex(index=ret_fwd.index)
    coverage, n_valid = _compute_coverage_and_nvalid(signal_aligned, ret_fwd, rebalance_dates)

    # IC series + stats
    ic = _compute_ic_spearman(signal_aligned, ret_fwd, rebalance_dates, min_assets=min_assets_ic)
    ic_stats = _tearsheet_ic_stats(ic)
    ic_roll = ic.rolling(rolling_window_obs).mean()

    # Bucket summary table + monotonicity diagnostic
    bucket_summary = _tearsheet_bucket_summary(bucket_ret, freq)
    mean_by_bucket = bucket_summary["mean_ret_per_period"].copy()
    monotonic_spearman = _monotonic_spearman(mean_by_bucket)

    # Turnover proxies for bottom (bucket 1) and top (bucket n_buckets)
    turnover_bottom = _turnover_proxy_from_labels(bucket_lbl, rebalance_dates, bucket_k=1)
    turnover_top = _turnover_proxy_from_labels(bucket_lbl, rebalance_dates, bucket_k=n_buckets)

    bench_ret = None
    bench_dd = None
    bench_yearly = None
    if benchmark_price is not None:
        bench_ret_full = _compute_benchmark_ret_fwd(benchmark_price, H=H, entry_mode=entry_mode)
        # Align to rebalance dates (bucket_ret lives on the rebalance grid)
        bench_ret = bench_ret_full.reindex(rebalance_dates)
        bench_dd = _drawdown_from_equity(_equity_curve_from_returns(bench_ret))
        bench_yearly = _yearly_returns_from_returns(bench_ret).rename(benchmark_name)

    if plot:
        _plot_tearsheet(
            ic=ic,
            ic_roll=ic_roll,
            rolling_window_obs=rolling_window_obs,
            mean_by_bucket=mean_by_bucket,
            bucket_ret=bucket_ret,
            n_buckets=n_buckets,
            coverage=coverage,
            turnover_bottom=turnover_bottom,
            turnover_top=turnover_top,
            H=H,
            entry_mode=entry_mode,
            bench_dd=bench_dd,
            bench_yearly=bench_yearly,
            benchmark_name=benchmark_name,
            plot_diagnostics=plot_diagnostics,
        )

    report = {
        "meta": {
            "H": H,
            "n_buckets": n_buckets,
            "entry_mode": entry_mode,
            "freq": freq,
            "rolling_window_obs": rolling_window_obs,
            "universe_size": int(signal.shape[1]),
            "n_rebalance_obs": int(len(rebalance_dates)),
        },
        "bucket_ret": bucket_ret,
        "benchmark_ret": bench_ret,
        "bucket_labels": bucket_lbl,
        "ret_fwd": ret_fwd,
        "coverage": coverage,
        "n_valid": n_valid,
        "ic": ic,
        "ic_stats": ic_stats,
        "bucket_summary": bucket_summary,
        "monotonic_spearman_bucket_vs_mean": monotonic_spearman,
        "turnover_bottom": turnover_bottom,
        "turnover_top": turnover_top,
    }
    return report

In [None]:
n_buckets = 10  # number of signal buckets used by every tear sheet below (10 => deciles)

In [None]:
# Multi-horizon MA distance signal (5/20/50/120), 5-day holding, entry at the next open.
signal = build_ma_signal(price_df, windows=(5, 20, 50, 120))

rep = make_tearsheet(
    price_df=price_df,
    signal=signal,
    H=5,
    n_buckets=n_buckets,
    entry_mode="next_open",
    min_assets_ic=50,
    plot=True,
)

# Key tables/series: IC stats, per-bucket summary, coverage distribution.
display(rep["ic_stats"], rep["bucket_summary"], rep["coverage"].describe())


In [None]:
# Oscillator signals from Close: 14-day RSI and the MACD histogram (12/26/9).
rsi = build_rsi_signal(price_df, period=14)
macd = build_macd_signal(price_df, fast=12, slow=26, signal=9, output="hist")

# Hygiene: cross-sectional winsorize + z-score so both signals share a common scale.
rsi_z, macd_z = cs_winsorize_zscore(rsi), cs_winsorize_zscore(macd)

In [None]:
# RSI tear sheet: signal formed at close[t], position entered at open[t+1]
# (entry_mode="next_open"), 10-day holding period.
rep_rsi = make_tearsheet(
    price_df=price_df,
    signal=rsi_z,
    H=10,
    n_buckets=n_buckets,
    entry_mode="next_open",
    min_assets_ic=50,
    plot=True,  # IC, bucket mean bar, bucket curves, drawdown and yearly-return charts
)

display(rep_rsi["ic_stats"], rep_rsi["bucket_summary"], rep_rsi["coverage"].describe())


In [None]:
# Shorter-horizon variant: MA(5, 20) distance signal with a 2-day holding period.
signal = build_ma_signal(price_df, windows=(5, 20))

rep = make_tearsheet(
    price_df=price_df,
    signal=signal,
    H=2,
    n_buckets=n_buckets,
    entry_mode="next_open",
    min_assets_ic=50,
    plot=True,
)

# Key tables/series: IC stats, per-bucket summary, coverage distribution.
display(rep["ic_stats"], rep["bucket_summary"], rep["coverage"].describe())


In [None]:
# Correlate each bucket's H-day forward return with the benchmark's forward return over
# the *same* holding window (same H and entry convention as the tear sheet in `rep`).
# The previous version joined a 1-day trailing pct_change, mismatching horizon and timing.
bench_ohlc = benchmark.copy()
if isinstance(bench_ohlc.columns, pd.MultiIndex):
    # Single-ticker download with (Price, Ticker) columns: keep the price-field level.
    bench_ohlc.columns = bench_ohlc.columns.get_level_values(0)

bench_ret_fwd = _compute_benchmark_ret_fwd(
    bench_ohlc, H=rep["meta"]["H"], entry_mode=rep["meta"]["entry_mode"]
).rename("benchmark")

bucket_returns = rep["bucket_ret"].join(bench_ret_fwd, how="inner")
bucket_returns.corr()

In [None]:
# Very short-horizon variant: distance from the 5-day MA, 1-day holding period.
signal = build_ma_signal(price_df, windows=(5,))

rep = make_tearsheet(
    price_df=price_df,
    signal=signal,
    H=1,
    n_buckets=n_buckets,
    entry_mode="next_open",
    min_assets_ic=50,
    plot=True,
)

# Key tables/series: IC stats, per-bucket summary, coverage distribution.
display(rep["ic_stats"], rep["bucket_summary"], rep["coverage"].describe())


In [None]:
def build_past_vol_signal(
    price_df: pd.DataFrame,
    window: int = 20,
    use_log_returns: bool = True,
) -> pd.DataFrame:
    """
    Trailing volatility per ticker: rolling sample std (ddof=1) of daily Close returns
    over `window` rows, NaN until a full window is available. Only past data is used.

    - use_log_returns=True: log returns; otherwise simple returns (no gap filling).
    - window=20 is roughly one trading month.
    """
    px = price_df["Close"].sort_index()
    daily = np.log(px).diff() if use_log_returns else px.pct_change(fill_method=None)
    return daily.rolling(window, min_periods=window).std(ddof=1)

def cs_winsorize_zscore(
    sig: pd.DataFrame,
    lower_q: float = 0.01,
    upper_q: float = 0.99,
    min_count: int = 10,
) -> pd.DataFrame:
    """
    Cross-sectional winsorize + z-score per date (row).

    NOTE(review): this re-defines the ``cs_winsorize_zscore`` defined earlier in this
    notebook and shadows it for every cell executed afterwards, with a stricter default
    row threshold (10 valid values instead of 5). Prefer a single definition.

    Per row (values cast to float, NaNs left as NaN):
      - fewer than ``min_count`` non-NaN values -> row returned unchanged;
      - otherwise clip to the row's [lower_q, upper_q] quantiles, then standardise with
        the row mean and sample std (ddof=1); if that std is 0 or non-finite the
        clipped row is returned as-is.
    """
    def _proc_row(x: pd.Series) -> pd.Series:
        # Work in float so clipped / standardised values can be written back into rows
        # that arrive with an integer dtype.
        x = x.astype(float)
        valid = x.notna()
        if valid.sum() < min_count:
            return x  # too few names to standardise
        lo = x[valid].quantile(lower_q)
        hi = x[valid].quantile(upper_q)
        out = x.copy()
        out[valid] = x[valid].clip(lo, hi)
        mu = out[valid].mean()
        sd = out[valid].std(ddof=1)
        if sd == 0 or not np.isfinite(sd):
            return out  # degenerate cross-section: winsorized only
        out[valid] = (out[valid] - mu) / sd
        return out

    return sig.apply(_proc_row, axis=1)

# Low-volatility factor: 40-day trailing vol of log returns, cross-sectionally
# winsorized + z-scored, then sign-flipped so the calmest names score highest.
vol = build_past_vol_signal(price_df, window=40, use_log_returns=True)
vol_z = cs_winsorize_zscore(vol)

signal_lowvol = -vol_z  # low vol => high score

rep = make_tearsheet(
    price_df=price_df,
    signal=signal_lowvol,
    H=5,
    n_buckets=n_buckets,
    entry_mode="next_open",
    min_assets_ic=50,
    plot=True
)

# Rich display, consistent with the other tear-sheet cells.
display(rep["ic_stats"])
display(rep["bucket_summary"].tail())  # highest-numbered (top) buckets if monotonic

In [None]:
def bucket_constituents(
    bucket_lbl: pd.DataFrame,
    date,
    k: int,
) -> list[str]:
    """
    Tickers in bucket ``k`` on ``date``.

    If ``date`` is not a label date, the most recent earlier label date is used
    (as-of lookup; ``bucket_lbl.index`` must be sorted).

    Raises
    ------
    KeyError
        if ``date`` precedes the first label date. The ffill lookup returns -1 there,
        which previously selected the *last* row silently.
    """
    d = pd.to_datetime(date)
    if d not in bucket_lbl.index:
        pos = bucket_lbl.index.get_indexer([d], method="ffill")[0]
        if pos < 0:
            raise KeyError(f"{d} precedes the first bucket label date {bucket_lbl.index.min()}")
        d = bucket_lbl.index[pos]
    members = bucket_lbl.loc[d]
    return members.index[(members == k).fillna(False)].tolist()

def bucket_snapshot_table(
    bucket_lbl: pd.DataFrame,
    signal: pd.DataFrame,
    ret_fwd: pd.DataFrame,
    date,
    k: int,
    sort_by: str = "signal",
    ascending: bool = False,
) -> pd.DataFrame:
    """
    Table of the members of bucket ``k`` on ``date`` with columns:
      - bucket  : bucket id
      - signal  : signal[t, i]
      - ret_fwd : forward return ret_fwd[t, i] (as used by backtest)

    If ``date`` is not a label date, the most recent earlier label date is used
    (as-of lookup; ``bucket_lbl.index`` must be sorted). Rows are sorted by ``sort_by``
    when it names one of the columns.

    Raises
    ------
    KeyError
        if ``date`` precedes the first label date. The ffill lookup returns -1 there,
        which previously selected the *last* row silently.
    """
    d = pd.to_datetime(date)
    if d not in bucket_lbl.index:
        pos = bucket_lbl.index.get_indexer([d], method="ffill")[0]
        if pos < 0:
            raise KeyError(f"{d} precedes the first bucket label date {bucket_lbl.index.min()}")
        d = bucket_lbl.index[pos]

    members = bucket_lbl.loc[d]
    tickers = members.index[(members == k).fillna(False)]

    df = pd.DataFrame({
        "bucket": members.loc[tickers].astype(int),
        "signal": signal.loc[d, tickers],
        "ret_fwd": ret_fwd.loc[d, tickers],
    }, index=tickers)

    if sort_by in df.columns:
        df = df.sort_values(sort_by, ascending=ascending)
    return df

def stock_bucket_history(
    bucket_lbl: pd.DataFrame,
    signal: pd.DataFrame,
    ret_fwd: pd.DataFrame,
    ticker: str,
    start=None,
    end=None,
) -> pd.DataFrame:
    """
    Date-by-date bucket label, signal and forward return for one ticker, with
    columns ['bucket', 'signal', 'ret_fwd'], optionally limited to [start, end] (inclusive).
    """
    columns = [
        bucket_lbl[ticker].rename("bucket"),
        signal[ticker].rename("signal"),
        ret_fwd[ticker].rename("ret_fwd"),
    ]
    frame = pd.concat(columns, axis=1)

    lo = None if start is None else pd.to_datetime(start)
    hi = None if end is None else pd.to_datetime(end)
    if lo is not None:
        frame = frame.loc[lo:]
    if hi is not None:
        frame = frame.loc[:hi]
    return frame

def stock_conditional_performance(
    bucket_lbl: pd.DataFrame,
    ret_fwd: pd.DataFrame,
    ticker: str,
    k: int,
    start=None,
    end=None,
) -> dict:
    """
    Forward-return statistics for ``ticker`` on the dates it was labelled bucket ``k``,
    optionally limited to [start, end] (inclusive).

    Returns a dict with keys ticker, bucket, n, mean, winrate, std. The same keys are
    present whether or not any observation exists (the empty case previously omitted
    'std'), so results can be stacked into a DataFrame.
    """
    df = pd.concat(
        [bucket_lbl[ticker].rename("bucket"), ret_fwd[ticker].rename("ret_fwd")],
        axis=1,
    )
    if start is not None:
        df = df.loc[pd.to_datetime(start):]
    if end is not None:
        df = df.loc[:pd.to_datetime(end)]

    cond = df.loc[df["bucket"] == k, "ret_fwd"].dropna()
    n = int(cond.shape[0])

    if n == 0:
        return {"ticker": ticker, "bucket": k, "n": 0, "mean": np.nan, "winrate": np.nan, "std": np.nan}

    return {
        "ticker": ticker,
        "bucket": k,
        "n": n,
        "mean": float(cond.mean()),
        "winrate": float((cond > 0).mean()),
        "std": float(cond.std(ddof=1)) if n > 1 else np.nan,
    }

def bucket_membership_frequency(
    bucket_lbl: pd.DataFrame,
    k: int,
    start=None,
    end=None,
) -> pd.Series:
    """
    For each ticker, the number of label rows (dates) it spends in bucket ``k``
    within [start, end] (inclusive, both optional), sorted most frequent first.
    """
    window = bucket_lbl
    if start is not None:
        window = window.loc[pd.to_datetime(start):]
    if end is not None:
        window = window.loc[:pd.to_datetime(end)]

    counts = window.eq(k).sum(axis=0)
    return counts.sort_values(ascending=False)

In [None]:
# Inspect bucket membership from the most recent tear sheet in notebook order
# (`rep` = the low-vol run). The top bucket is the highest bucket number of that run.
bucket_lbl = rep['bucket_labels']
ret_next = rep['ret_fwd']
top_k = rep["meta"]["n_buckets"]

d = "2023-06-30"
names = bucket_constituents(bucket_lbl, d, k=top_k)
names

In [None]:
# bucket_lbl / ret_next come from the low-vol tear sheet, so show the signal those
# labels were built from. The global `signal` was last set to the MA(5) signal.
snap = bucket_snapshot_table(bucket_lbl, signal_lowvol, ret_next, d, k=1)
snap

In [None]:
# Top bucket of the low-vol tear sheet on the same date, with the matching signal.
bucket_snapshot_table(bucket_lbl, signal_lowvol, ret_next, d, k=rep["meta"]["n_buckets"])

In [None]:
signal_lowvol.iloc[-5:]  # last five dates of the low-vol score

In [None]:
# Sanity check on the first rebalance date: compare low-vol scores in bucket 1 vs the
# top bucket of this tear sheet. If buckets are formed in ascending signal order,
# the top bucket's median score should be the higher one.
top_k = rep["meta"]["n_buckets"]
t = rep["bucket_ret"].index[0]  # first rebalance date
lbl_t = rep["bucket_labels"].loc[t]
sig_t = signal_lowvol.loc[t]

b1 = sig_t[lbl_t == 1].dropna()
b_top = sig_t[lbl_t == top_k].dropna()

print("bucket1 median signal:", b1.median(), "n=", len(b1))
print(f"bucket{top_k} median signal:", b_top.median(), "n=", len(b_top))
