In [None]:
# Standard library
import bisect
import datetime
import gc
import logging
import math
import pickle
from pathlib import Path
from textwrap import dedent
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

# Third-party
import matplotlib.pyplot as plt
from numba import njit, prange
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from tqdm.auto import tqdm
import wrds

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Sample period (inclusive, ISO dates) and WRDS login used by the data-pull cells.
START_DATE, END_DATE = '1995-01-01', '2024-12-31'  # TODO: CHANGE START_DATE
USERNAME = 'your_wrds_username'

# Data

## Data collection and saving

In [None]:
# Root logging config: INFO level with timestamped messages.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger used by WRDSDataCollector.
logger = logging.getLogger(__name__)


class WRDSDataCollector:
    """Collector for daily CRSP data pulled through a WRDS connection.

    Each ``fetch_*`` method runs one SQL query via ``self.conn.raw_sql`` and
    returns a pandas DataFrame. The constructor creates ``data_dir/monthly``
    and ``data_dir/daily``, but this class does not itself write files there.
    """

    def __init__(
        self,
        username: str = USERNAME,
        data_dir: str = "./data",
        wrds_conn: Optional[wrds.Connection] = None,
    ):
        """Initialize the collector.

        Parameters
        ----------
        username : str
            WRDS user ID (only used when a new connection is opened).
        data_dir : str
            Root folder; ``monthly`` and ``daily`` sub-folders are created.
        wrds_conn : Optional[wrds.Connection]
            If provided, reuse this WRDS connection instead of opening one.
        """
        # Explicit None check: don't depend on the connection object's truthiness.
        if wrds_conn is not None:
            self.conn = wrds_conn
            logger.info("Reusing provided WRDS connection")
        else:
            self.conn = wrds.Connection(wrds_username=username)
            logger.info(f"Connected to WRDS as {username}")

        self.data_dir = Path(data_dir)
        (self.data_dir / "monthly").mkdir(parents=True, exist_ok=True)
        (self.data_dir / "daily").mkdir(parents=True, exist_ok=True)
        logger.info(f"Data will be saved under {self.data_dir.resolve()}")

    def fetch_stock_universe(
        self,
        date: str,
        min_market_cap: float = 1e9,
        exchange_codes: Optional[Sequence[int]] = None,
    ) -> pd.DataFrame:
        """Universe of common stocks on selected exchanges at a given date.

        Market cap is computed as ``shrout * abs(prc) * 1000``.

        Parameters
        ----------
        date : str
            Date in "YYYY-MM-DD" format.
        min_market_cap : float, default 1e9
            Minimum market cap in USD.
        exchange_codes : Sequence[int], optional
            Exchanges to include (1=NYSE, 2=AMEX, 3=NASDAQ). Defaults to (1, 2, 3).

        Returns
        -------
        pd.DataFrame
            Columns: permno, ticker, comnam, market_cap, shrcd, exchcd, siccd,
            sorted by market_cap descending.

        Raises
        ------
        ValueError
            If ``exchange_codes`` is empty or contains non-integer values.
        """
        if exchange_codes is None:
            exchange_codes = (1, 2, 3)
        if len(exchange_codes) == 0:
            raise ValueError("exchange_codes must not be empty")
        # int() guarantees only integer literals are interpolated into the SQL.
        exch_list = ",".join(str(int(code)) for code in exchange_codes)

        query = f"""
        SELECT
          dsf.permno,
          sn.ticker,
          sn.comnam,
          (dsf.shrout * ABS(dsf.prc) * 1000) AS market_cap,
          sn.shrcd,
          sn.exchcd,
          sn.siccd
        FROM crsp.dsf AS dsf
        JOIN crsp.stocknames AS sn
          ON dsf.permno = sn.permno
         AND sn.namedt <= dsf.date
         AND (sn.nameenddt >= dsf.date OR sn.nameenddt IS NULL)
        WHERE dsf.date = '{date}'
          AND (dsf.shrout * ABS(dsf.prc) * 1000) >= {min_market_cap}
          AND sn.exchcd IN ({exch_list})
          AND sn.shrcd IN (10, 11)
        ORDER BY market_cap DESC
        """
        return self.conn.raw_sql(query)

    def fetch_returns_matrix(
        self,
        start_date: str,
        end_date: str,
    ) -> pd.DataFrame:
        """Daily total return matrix (T x N) for common shares on NYSE/AMEX/NASDAQ.

        ``total_ret = ret + coalesce(dlret, 0)``. This is an additive
        approximation of compounding ret and dlret. Because SQL ``NULL + x`` is
        NULL, a day whose ``ret`` is NULL yields NULL even if ``dlret`` exists.

        Returns
        -------
        pd.DataFrame
            Index ``date`` (sorted), columns ``permno``, values total_ret.
            ``DataFrame.pivot`` raises if a (date, permno) pair appears twice.
        """
        sql = f"""
            SELECT
            dsf.permno,
            dsf.date,
            dsf.ret + COALESCE(d.dlret, 0) AS total_ret
            FROM crsp.dsf AS dsf
            JOIN crsp.stocknames AS sn
              ON dsf.permno = sn.permno
             AND sn.namedt <= dsf.date
             AND (sn.nameenddt >= dsf.date OR sn.nameenddt IS NULL)
            LEFT JOIN crsp.dsedelist AS d
              ON dsf.permno = d.permno
             AND dsf.date = d.dlstdt  -- non-null only on a delisting day
            WHERE dsf.date BETWEEN '{start_date}' AND '{end_date}'
              AND sn.exchcd IN (1, 2, 3)
              AND sn.shrcd IN (10, 11)
        """
        # Pull into a long DataFrame.
        df = self.conn.raw_sql(sql, date_cols=['date'])

        # Pivot to wide: index=dates, columns=permnos, values=total_ret.
        ret_mat = (
            df.pivot(index='date', columns='permno', values='total_ret')
              .sort_index()
        )
        ret_mat.index.name = 'date'
        ret_mat.columns.name = 'permno'
        return ret_mat

    def fetch_mkt_cap(self, start_date: str, end_date: str) -> pd.DataFrame:
        """Long-format daily market cap for every CRSP permno (no universe filter).

        ``market_cap = shrout * abs(prc) * 1000``. Rows with NULL shrout or prc
        are excluded.

        Returns
        -------
        pd.DataFrame
            Columns: permno, date, market_cap.
        """
        sql = f"""
            SELECT
                permno,
                date,
                (shrout * ABS(prc) * 1000) AS market_cap
            FROM crsp.dsf
            WHERE date BETWEEN '{start_date}' AND '{end_date}'
              AND shrout IS NOT NULL
              AND prc IS NOT NULL
        """
        mcap = self.conn.raw_sql(sql, date_cols=['date'])
        return mcap

    def fetch_split_adj_close(
        self,
        start_date: str,
        end_date: str,
    ) -> pd.DataFrame:
        """Wide (date x permno) matrix of ``abs(prc) / cfacpr``.

        The universe is common shares (shrcd 10/11) on exchcd 1/2/3. A zero
        ``cfacpr`` yields NULL instead of a division error.
        """
        sql = f"""
            SELECT
                dsf.date,
                dsf.permno,
                ABS(dsf.prc) / NULLIF(dsf.cfacpr, 0) AS adj_close
            FROM crsp.dsf AS dsf
            JOIN crsp.stocknames AS sn
              ON dsf.permno = sn.permno
             AND sn.namedt <= dsf.date
             AND (sn.nameenddt >= dsf.date OR sn.nameenddt IS NULL)
            WHERE dsf.date BETWEEN '{start_date}' AND '{end_date}'
              AND sn.exchcd IN (1, 2, 3)
              AND sn.shrcd IN (10, 11)
        """
        df = self.conn.raw_sql(sql, date_cols=['date'])
        price_mat = df.pivot(index='date', columns='permno', values='adj_close').sort_index()
        price_mat.index.name = 'date'
        price_mat.columns.name = 'permno'
        return price_mat

    def fetch_etf_time_series(
        self,
        etf_tickers: List[str],
        start_date: str,
        end_date: str,
    ) -> pd.DataFrame:
        """Daily total returns for a list of ETF tickers (wide: date x ticker).

        ``total_ret`` is ``dlret`` on a delisting day when present, otherwise
        ``ret``. If several name records match a (date, ticker), the one with
        the latest ``namedt`` is used.

        Returns an empty DataFrame (and logs the error) when no connection is
        available or the query fails.
        """
        if not self.conn:
            logger.warning("No WRDS connection available.")
            return pd.DataFrame()

        # Escape single quotes so a ticker cannot terminate the SQL literal.
        tickers_list_str = "','".join(t.replace("'", "''") for t in etf_tickers)

        # NULL nameenddt is treated as "still active", consistent with the other
        # queries. ORDER BY puts the latest name record first within each date,
        # so aggfunc='first' below selects it.
        sql_query = f"""
        SELECT 
            a.date, 
            b.ticker,
            a.ret AS daily_ret,
            COALESCE(d.dlret, 0) AS delist_ret,
            CASE 
                WHEN d.dlret IS NOT NULL THEN d.dlret
                ELSE a.ret
            END AS total_ret
        FROM crsp.dsf AS a
        INNER JOIN crsp.stocknames AS b 
            ON a.permno = b.permno
        LEFT JOIN crsp.dsedelist AS d 
            ON a.permno = d.permno AND a.date = d.dlstdt
        WHERE 
            b.ticker IN ('{tickers_list_str}')
            AND a.date BETWEEN '{start_date}' AND '{end_date}'
            AND b.namedt <= a.date 
            AND (a.date <= b.nameenddt OR b.nameenddt IS NULL)
        ORDER BY a.date, b.namedt DESC
        """

        try:
            df = self.conn.raw_sql(sql_query, date_cols=['date'])

            # Pivot to get tickers as columns; 'first' respects the ORDER BY above.
            pivot_df = df.pivot_table(
                index='date',
                columns='ticker',
                values='total_ret',
                aggfunc='first',
            )

            # Reorder columns to match input list and sort by date.
            pivot_df = pivot_df.reindex(columns=etf_tickers).sort_index()

            return pivot_df

        except Exception:  # pragma: no cover - best-effort pull, logged
            logger.exception("An error occurred while fetching ETF data")
            return pd.DataFrame()

    def fetch_dollar_turnover(
        self,
        start_date: str,
        end_date: str,
        universe_filters: bool = True,
    ) -> pd.DataFrame:
        """Daily dollar volume, market cap and turnover (long format).

        Definitions (all from the same day's unadjusted CRSP fields, so no
        split adjustment is needed)
        -----------
        dollar_vol = abs(prc) * vol
        market_cap = shrout * abs(prc) * 1000   # shrout is in thousands
        turnover   = dollar_vol / market_cap    # NULL when market_cap == 0

        Rows with NULL prc, vol, shrout or cfacshr are excluded. ``cfacshr`` is
        returned for reference but does not enter the computations.

        Parameters
        ----------
        start_date, end_date : str
            'YYYY-MM-DD'.
        universe_filters : bool, default True
            If True, restrict to shrcd 10/11 and exchcd 1/2/3 via stocknames.

        Returns
        -------
        pd.DataFrame
            Columns: permno, date, cfacshr, dollar_vol, market_cap, turnover,
            ordered by date, permno.
        """
        if universe_filters:
            join_where = f"""
            JOIN crsp.stocknames sn
              ON dsf.permno = sn.permno
             AND sn.namedt <= dsf.date
             AND (sn.nameenddt >= dsf.date OR sn.nameenddt IS NULL)
            WHERE dsf.date BETWEEN '{start_date}' AND '{end_date}'
              AND sn.shrcd IN (10, 11)
              AND sn.exchcd IN (1, 2, 3)
            """
        else:
            join_where = f"""
            WHERE dsf.date BETWEEN '{start_date}' AND '{end_date}'
            """

        sql = dedent(
            f"""
        WITH base AS (
            SELECT
                dsf.permno,
                dsf.date,
                ABS(dsf.prc) AS prc,
                dsf.vol AS vol,
                dsf.shrout AS shrout,
                dsf.cfacshr AS cfacshr
            FROM crsp.dsf dsf
            {join_where}
              AND dsf.prc IS NOT NULL
              AND dsf.vol IS NOT NULL
              AND dsf.shrout IS NOT NULL
              AND dsf.cfacshr IS NOT NULL
        )
        SELECT
            permno,
            date,
            cfacshr,
            prc * (vol) AS dollar_vol,
            shrout * prc * 1000 AS market_cap,
            (prc * (vol)) / NULLIF(shrout * prc * 1000, 0) AS turnover
        FROM base
        ORDER BY date, permno
        """
        )

        return self.conn.raw_sql(sql, date_cols=['date'])

    def fetch_permco_map(self, start_date: str, end_date: str) -> pd.DataFrame:
        """Long (date, permno, permco) map for common shares on exchcd 1/2/3."""
        sql = f"""
            SELECT
                dsf.date,
                dsf.permno,
                sn.permco
            FROM crsp.dsf dsf
            JOIN crsp.stocknames sn
              ON dsf.permno = sn.permno
             AND sn.namedt <= dsf.date
             AND (sn.nameenddt >= dsf.date OR sn.nameenddt IS NULL)
            WHERE dsf.date BETWEEN '{start_date}' AND '{end_date}'
              AND sn.shrcd IN (10, 11)
              AND sn.exchcd IN (1, 2, 3)
        """
        return self.conn.raw_sql(sql, date_cols=['date'])

    def close(self):
        """Close the WRDS connection."""
        self.conn.close()
        logger.info("WRDS connection closed")


Fetching from WRDS

In [None]:
ETF_TICKERS = ['SPY', 'MTUM']

wrds_data_collector = WRDSDataCollector(username=USERNAME)
try:
    etf_returns_matrix = wrds_data_collector.fetch_etf_time_series(etf_tickers=ETF_TICKERS, start_date=START_DATE, end_date=END_DATE)
    returns_matrix = wrds_data_collector.fetch_returns_matrix(start_date=START_DATE, end_date=END_DATE)
    mcap = wrds_data_collector.fetch_mkt_cap(start_date=START_DATE, end_date=END_DATE)
    split_adjusted_close = wrds_data_collector.fetch_split_adj_close(start_date=START_DATE, end_date=END_DATE)
    turnover_df = wrds_data_collector.fetch_dollar_turnover(start_date=START_DATE, end_date=END_DATE, universe_filters=True)
finally:
    # Release the WRDS connection even if one of the long pulls fails midway.
    wrds_data_collector.close()

Loading library list...


2025-07-23 20:26:34,501 - INFO - Connected to WRDS as benjaminlhr555
2025-07-23 20:26:34,502 - INFO - Data will be saved under C:\Users\benja\Desktop\oxford\Dissertation\cs_strats\data


Done


2025-07-23 20:57:39,600 - INFO - WRDS connection closed


Saving files locally

In [5]:
# Wide (date x permno) matrices keep their date index under a 'date' header.
returns_matrix.to_csv('data/returns_matrix.csv', index_label='date')
split_adjusted_close.to_csv('data/split_adjusted_close.csv', index_label='date')
# The long-format market-cap table has a meaningless RangeIndex, so drop it.
mcap.to_csv('data/mcap.csv', index=False)


In [9]:
# Cast nullable extension dtypes (Int64/Float64) to plain float64 before the
# fixed-format HDF5 write (presumably because that format does not store
# pandas extension dtypes — confirm).
turnover_df.astype(
    dict.fromkeys(
        turnover_df.select_dtypes(include=['Int64', 'Float64']).columns,
        'float64',
    )
).to_hdf('data/turnover_df.h5', key='turnover_df', format='fixed')

In [17]:
# Same float64 normalization as for turnover_df, then a fixed-format HDF5 write.
etf_returns_matrix.astype(
    dict.fromkeys(
        etf_returns_matrix.select_dtypes(include=['Int64', 'Float64']).columns,
        'float64',
    )
).to_hdf('data/etf_returns_matrix.h5', key='etf_returns_matrix', format='fixed')

## Data loading

In [12]:
# Wide matrices: date index parsed to datetimes; permno column labels come
# back from the CSV header as strings, so they are cast to int.
returns_matrix = pd.read_csv('data/returns_matrix.csv', parse_dates=['date'], index_col='date')
returns_matrix.columns = returns_matrix.columns.astype(int)

split_adjusted_close = pd.read_csv('data/split_adjusted_close.csv', parse_dates=['date'], index_col='date')
split_adjusted_close.columns = split_adjusted_close.columns.astype(int)

# Long-format market cap table: only the date column needs conversion.
mcap = pd.read_csv('data/mcap.csv')
mcap['date'] = pd.to_datetime(mcap['date'])

In [10]:
# Re-load only the daily return matrix (same call as in the loading cell above).
returns_matrix = pd.read_csv('data/returns_matrix.csv', parse_dates=['date'], index_col='date')


In [16]:
# Column labels read back from the CSV header are strings; cast permnos back to int.
returns_matrix.columns = returns_matrix.columns.astype(int)

In [5]:
# Load the turnover table saved above and make permno a plain integer column.
turnover_df = pd.read_hdf("data/turnover_df.h5", key="turnover_df").astype({'permno': int})

# Initializing Mask for Stocks to Trade
The mask is refined further in later sections (for example, once the turnover features are converted to wide format).

In [13]:
NUM_STOCKS_TO_TRADE = 1000

# Eligible on date t only if the stock has a non-missing return on each of the
# 253 trading days ending at t (NaN rolling sums during warm-up compare False).
valid253 = returns_matrix.notna().rolling(253, min_periods=253).sum().eq(253)

# Wide market-cap matrix on the return matrix's grid, blanked where ineligible.
mcap_wide = mcap.pivot(index='date', columns='permno', values='market_cap')
mcap_wide_aligned = mcap_wide.reindex_like(returns_matrix)
mcap_wide_where_valid253 = mcap_wide_aligned.where(valid253)

# Keep the NUM_STOCKS_TO_TRADE largest eligible stocks per day (ties broken by
# column order; NaN ranks compare False).
mask_for_stocks_to_trade = (
    mcap_wide_where_valid253
    .rank(axis=1, method='first', ascending=False)
    .le(NUM_STOCKS_TO_TRADE)
)

In [None]:
# # Uncomment and run this chunk to save memory (if needed) since we don't need these variables anymore
# del mcap
# del mcap_wide

# Feature Engineering

## Volume Info

In [7]:
def add_turnover_rolling_avg(
    df: pd.DataFrame,
    lookbacks: Iterable[int] = (21, 63, 126, 252),
    coverage: float = 0.8,
    permno_col: str = "permno",
    date_col: str = "date",
    turnover_col: str = "turnover",
) -> pd.DataFrame:
    """
    Add rolling-average turnover columns for multiple lookback windows.

    For each row, the average uses the previous ``lb`` rows of the same permno
    and excludes the current row (shifted by one within the permno).

    Parameters
    ----------
    df : DataFrame
        Input frame containing at least [permno_col, date_col, turnover_col].
        It is not modified.
    lookbacks : iterable of int
        Window lengths (in rows / trading days) to compute.
    coverage : float
        Minimum fraction of non-null observations required in the window
        (``min_periods = floor(coverage * lb)``).
    permno_col, date_col, turnover_col : str
        Column names.

    Returns
    -------
    DataFrame
        A copy sorted by (permno_col, date_col), original index labels kept,
        with new columns ``avg_turnover_of_past_{lb}_days``.
    """
    df = df.sort_values([permno_col, date_col])

    # Shift once per permno so "today" is excluded from every window, then
    # roll within permno. The shift is reused for all lookbacks.
    shifted = df.groupby(permno_col, sort=False)[turnover_col].shift(1)
    by_permno = shifted.groupby(df[permno_col], sort=False)

    for lb in lookbacks:
        minp = int(np.floor(coverage * lb))
        colname = f"avg_turnover_of_past_{lb}_days"
        # groupby().rolling() prepends the permno level; drop it so the result
        # aligns back onto df's own index.
        df[colname] = (
            by_permno
            .rolling(window=lb, min_periods=minp)
            .mean()
            .reset_index(level=0, drop=True)
        )

    return df

In [None]:
def add_turnover_normalized_betas(
    df: pd.DataFrame,
    lookbacks: Iterable[int] = (21, 63, 126, 252),
    coverage: float = 0.8,
    permno_col: str = "permno",
    date_col: str = "date",
    turnover_col: str = "turnover",
) -> pd.DataFrame:
    """
    Add turnover-trend features: the slope (normalized by median turnover)
    and R^2 of an OLS regression of turnover on the within-permno row index.

    For row j of a permno and lookback ``lb``, the regression uses rows
    j-lb .. j-1 of that permno (the ``lb`` previous rows). The current row's
    turnover is never used. The x variable is the row position within the
    permno, not a calendar-day count.

    Columns added for each ``lb`` in ``lookbacks``:
      * ``beta_turnover_wrt_day_{lb}_days``: OLS slope divided by the median
        of the non-NaN turnovers in the same ``lb`` previous rows. It is NaN
        when that median is missing or zero.
      * ``r2_turnover_wrt_day_{lb}_days``: R^2 of the same regression.
    The raw (un-normalized) slope is computed but not attached; its column
    assignment is commented out.

    A row gets values only if all of these hold:
      * at least ``floor(lb * coverage)`` of the ``lb`` previous turnovers are
        non-NaN;
      * both x and turnover have positive variance over the valid points;
      * the permno has at least ``lb`` earlier rows.
    Otherwise the row is NaN.

    Side effects: ``df`` is sorted by (permno_col, date_col) and its index is
    reset IN PLACE. The same object is returned with the new columns added.

    Implementation:
      * iterates over contiguous per-permno slices instead of pandas groupby;
      * uses float64 throughout;
      * takes window sums from prefix sums;
      * keeps the rolling median in a sorted Python list updated with
        ``bisect``, so each step is O(lb) because of list insertion/removal.
    """
    # Sort once (in place — mutates the caller's DataFrame)
    df.sort_values([permno_col, date_col], inplace=True)
    df.reset_index(drop=True, inplace=True)

    permno_vals = df[permno_col].to_numpy()
    y_all = df[turnover_col].to_numpy(dtype=np.float64)
    N = len(df)

    # Find group boundaries (assumes sorted by permno): permno k occupies the
    # contiguous row range [starts[k], ends[k]).
    change = np.flatnonzero(permno_vals[1:] != permno_vals[:-1]) + 1
    starts = np.concatenate(([0], change))
    ends = np.concatenate((change, [N]))

    for lb in lookbacks:
        # Per-lookback output buffers; rows that never qualify stay NaN.
        beta_raw_arr = np.full(N, np.nan, dtype=np.float64)
        beta_norm_arr = np.full(N, np.nan, dtype=np.float64)
        r2_arr = np.full(N, np.nan, dtype=np.float64)
        # Minimum count of non-NaN turnovers required inside the lb-row window.
        minp = int(np.floor(lb * coverage))

        for start, end in zip(starts, ends):
            n = end - start
            if n <= lb:
                continue  # no row has enough history

            y = y_all[start:end]                 # view
            x = np.arange(n, dtype=np.float64)   # day index within permno
            valid = ~np.isnan(y)

            # Rolling median of the previous lb rows (j-lb .. j-1), kept as a
            # sorted list of the non-NaN values currently in the window.
            median_arr = np.full(n, np.nan, dtype=np.float64)
            # initialize sorted window with first lb observations (previous days for j=lb)
            window_vals = [val for val in y[0:lb] if not np.isnan(val)]
            window_vals.sort()

            def current_median(vals):
                # Median of an already-sorted list; NaN for an empty list.
                m = len(vals)
                if m == 0:
                    return np.nan
                if m % 2:
                    return vals[m // 2]
                else:
                    return 0.5 * (vals[m // 2 - 1] + vals[m // 2])

            # j runs from lb .. n-1; the median is recorded BEFORE y[j] enters,
            # so it only reflects the previous lb rows.
            for j in range(lb, n):
                if len(window_vals) >= minp:
                    median_arr[j] = current_median(window_vals)
                # slide window: remove outgoing y[j-lb], add incoming y[j]
                out_val = y[j - lb]
                if not np.isnan(out_val):
                    pos = bisect.bisect_left(window_vals, out_val)
                    # safety check (should exist)
                    if pos < len(window_vals) and window_vals[pos] == out_val:
                        window_vals.pop(pos)
                in_val = y[j]
                if not np.isnan(in_val):
                    bisect.insort(window_vals, in_val)

            # OLS slope / R^2 from prefix sums. cs_*[k] sums rows 0..k-1, so
            # cs[j] - cs[j - lb] covers rows j-lb .. j-1. NaN turnovers are
            # zeroed out of every sum and excluded from the count cs_n.
            y_valid = np.where(valid, y, 0.0)
            x_valid = np.where(valid, x, 0.0)

            cs_y  = np.empty(n + 1); cs_y[0] = 0;  cs_y[1:]  = np.cumsum(y_valid)
            cs_y2 = np.empty(n + 1); cs_y2[0] = 0; cs_y2[1:] = np.cumsum(y_valid * y_valid)
            cs_x  = np.empty(n + 1); cs_x[0] = 0;  cs_x[1:]  = np.cumsum(x_valid)
            cs_x2 = np.empty(n + 1); cs_x2[0] = 0; cs_x2[1:] = np.cumsum(np.where(valid, x * x, 0.0))
            cs_xy = np.empty(n + 1); cs_xy[0] = 0; cs_xy[1:] = np.cumsum(np.where(valid, x * y, 0.0))
            cs_n  = np.empty(n + 1, dtype=np.int64); cs_n[0] = 0; cs_n[1:]  = np.cumsum(valid.astype(np.int64))

            j = np.arange(lb, n, dtype=np.int64)
            S_y  = cs_y[j]  - cs_y[j - lb]
            S_y2 = cs_y2[j] - cs_y2[j - lb]
            S_x  = cs_x[j]  - cs_x[j - lb]
            S_x2 = cs_x2[j] - cs_x2[j - lb]
            S_xy = cs_xy[j] - cs_xy[j - lb]
            n_valid = cs_n[j] - cs_n[j - lb]

            # Closed-form simple regression: beta = cov(x,y)/var(x),
            # R^2 = cov(x,y)^2 / (var(x) var(y)), all scaled by n_valid^2.
            numer_beta = n_valid * S_xy - S_x * S_y
            denom_beta = n_valid * S_x2 - S_x * S_x
            S_yc2 = n_valid * S_y2 - S_y * S_y

            with np.errstate(divide="ignore", invalid="ignore"):
                beta_raw = numer_beta / denom_beta
                r2 = (numer_beta * numer_beta) / (denom_beta * S_yc2)

            # Valid slope mask: enough observations, and positive variance in
            # both x and y (zero-variance windows would divide by zero above).
            mask = (n_valid >= minp) & (denom_beta > 0) & (S_yc2 > 0)
            rows_global = start + j[mask]

            # Raw beta (stored, but not attached as a column below)
            beta_raw_arr[rows_global] = beta_raw[mask]
            r2_arr[rows_global] = r2[mask]

            # Normalized beta: divide by median (must have median, nonzero)
            med_vals = median_arr[j[mask]]
            norm_mask = ~np.isnan(med_vals) & (med_vals != 0)
            if norm_mask.any():
                beta_norm_arr[rows_global[norm_mask]] = (
                    beta_raw[mask][norm_mask] / med_vals[norm_mask]
                )

        # Attach columns
        # df[f"beta_turnover_wrt_day_{lb}_days_raw"] = beta_raw_arr
        df[f"beta_turnover_wrt_day_{lb}_days"] = beta_norm_arr
        df[f"r2_turnover_wrt_day_{lb}_days"] = r2_arr

    return df


In [23]:
turnover_lookbacks = [21, 63, 126, 252]

# Sort by (permno, date), then add the trailing average-turnover columns
# (each window needs >= 80% non-missing observations).
turnover_df = add_turnover_rolling_avg(
    turnover_df.sort_values(['permno', 'date']),
    lookbacks=turnover_lookbacks,
    coverage=0.8,
)

In [24]:
# Add turnover-trend slope / R^2 columns (the function also sorts turnover_df in place).
turnover_df = add_turnover_normalized_betas(
    turnover_df,
    lookbacks=turnover_lookbacks,
    coverage=0.8,
)

### Convert to wide format (so that we can update mask_for_stocks_to_trade)

In [19]:
# One wide (date x permno) matrix per average-turnover feature column.
avg_turnover_cols = [c for c in turnover_df.columns if 'avg_turnover_of_past' in c]
assert len(avg_turnover_cols) > 0
avg_turnover_dict = {
    col: (
        turnover_df.groupby(['date', 'permno'])[col]
        .mean()
        .unstack('permno')
        .sort_index()
        .sort_index(axis=1)
    )
    for col in avg_turnover_cols
}

In [None]:
# One wide (date x permno) matrix per regression feature (betas first, then R^2).
turnover_reg_cols = (
    [c for c in turnover_df.columns if "beta_turnover_wrt_day_" in c]
    + [c for c in turnover_df.columns if "r2_turnover_wrt_day_" in c]
)
assert len(turnover_reg_cols) > 0
turnover_reg_dict = {
    col: (
        turnover_df.groupby(['date', 'permno'])[col]
        .mean()
        .unstack('permno')
        .sort_index()
        .sort_index(axis=1)
    )
    for col in turnover_reg_cols
}

Save turnover_reg_dict, avg_turnover_dict

In [90]:
# `pickle` and `Path` are imported in the first cell. Create the output folder
# so the first save does not fail with FileNotFoundError.
Path('features').mkdir(parents=True, exist_ok=True)
with open('features/turnover_reg_dict.pkl', 'wb') as f:
    pickle.dump(turnover_reg_dict, f, protocol=pickle.HIGHEST_PROTOCOL)

In [20]:
# `pickle` and `Path` are imported in the first cell. Create the output folder
# so the first save does not fail with FileNotFoundError.
Path('features').mkdir(parents=True, exist_ok=True)
with open('features/avg_turnover_dict.pkl', 'wb') as f:
    pickle.dump(avg_turnover_dict, f, protocol=pickle.HIGHEST_PROTOCOL)

Load turnover_reg_dict, avg_turnover_dict

In [None]:
# Restore the wide average-turnover matrices written by the save cell.
avg_turnover_dict = pickle.loads(Path('features/avg_turnover_dict.pkl').read_bytes())

In [None]:
# Restore the turnover regression features (betas / R^2) written by the
# "Save turnover_reg_dict" cell, pairing with the avg_turnover_dict load.
with open('features/turnover_reg_dict.pkl', 'rb') as f:
    turnover_reg_dict = pickle.load(f)

## Raw returns and vol-scaled returns

In [46]:
def ewma_std_strict_window_size(
    returns_matrix: pd.DataFrame,
    span: int = 63,
    window: int = 253
) -> pd.DataFrame:
    """
    Compute an exponentially-weighted std over exactly the last `window` days.
    If any of the past `window` returns are NaN, the result is NaN.

    Weights are ``(1 - alpha) ** age`` with ``alpha = 2 / (span + 1)``, where
    the newest day has age 0. They are truncated to `window` days and
    normalized. The population variance is then scaled by
    ``S1**2 / (S1**2 - S2)`` (S1/S2 = sum of weights / squared weights). This
    is the same unbiased correction pandas applies in ``ewm(...).std()``.

    The computation is vectorized per column with a sliding-window view
    instead of calling a Python function for every window.

    Parameters
    ----------
    returns_matrix : pd.DataFrame
        Daily total-return matrix (dates × permnos).
    span : int
        EW span for decay (alpha = 2/(span+1)).
    window : int
        Number of days to include (exactly) in each vol estimate.

    Returns
    -------
    vol : DataFrame
        Same shape, index and columns as `returns_matrix`. The first
        `window-1` rows and any window containing a NaN are NaN.
    """
    from numpy.lib.stride_tricks import sliding_window_view

    alpha = 2.0 / (span + 1)
    # Raw weights ordered oldest -> newest (newest day has weight 1).
    raw_w = (1 - alpha) ** np.arange(window)[::-1]
    s1 = raw_w.sum()
    s2 = (raw_w ** 2).sum()
    # Bias-correction factor for the variance (ddof=1 analogue).
    sqrt_bias = np.sqrt(s1 ** 2 / (s1 ** 2 - s2))
    # Normalized weights for the population moments.
    w = raw_w / s1

    # Fortran order makes each column contiguous for the per-column loop.
    values = np.asfortranarray(returns_matrix.to_numpy(dtype=np.float64))
    n_rows, n_cols = values.shape
    out = np.full((n_rows, n_cols), np.nan)

    if n_rows >= window:
        for col in range(n_cols):
            # Row i of `wins` holds days i .. i+window-1 (oldest first).
            wins = sliding_window_view(values[:, col], window)
            mu = wins @ w
            var = ((wins - mu[:, None]) ** 2) @ w
            # A NaN anywhere in a window propagates through both dot products.
            out[window - 1:, col] = sqrt_bias * np.sqrt(var)

    return pd.DataFrame(out, index=returns_matrix.index, columns=returns_matrix.columns)
    
def calc_vol_and_time_scaled_kday_returns(
    ret_dict: Dict[Union[int, Tuple[int, int]], pd.DataFrame],
    vol_mat: pd.DataFrame
) -> Dict[Union[int, Tuple[int, int]], pd.DataFrame]:
    """
    Normalize each k-day return matrix by ex-ante volatility and sqrt(horizon).

    Parameters
    ----------
    ret_dict : dict[int | (int, int), DataFrame]
        Mapping window spec → raw k-day return matrix (dates × assets).
        An integer key k has horizon k. A tuple key (a, b) has horizon a - b.
    vol_mat : DataFrame
        Ex-ante daily volatility σ_t (dates × assets), e.g. from
        `ewma_std_strict_window_size`.

    Returns
    -------
    scaled_dict : dict
        Same keys as `ret_dict`. Each return matrix is reindexed to
        `vol_mat`'s shape and divided elementwise by `vol_mat`. It is then
        divided by sqrt(horizon) so different horizons are comparable.

    Raises
    ------
    TypeError
        If a key is neither an integer nor a length-2 tuple.
    ValueError
        If a key implies a non-positive horizon (sqrt scaling is undefined).
    """
    scaled = {}
    for key, ret in ret_dict.items():
        if isinstance(key, (int, np.integer)):
            length = int(key)
        elif isinstance(key, tuple) and len(key) == 2:
            length = key[0] - key[1]
        else:
            raise TypeError('Key for ret_dict should be either an integer or a length 2 tuple')
        if length <= 0:
            raise ValueError(
                f"Window {key!r} implies non-positive horizon {length}; cannot scale by sqrt."
            )
        # align on dates/assets in case of mismatches
        ret_aligned = ret.reindex_like(vol_mat)
        scaled[key] = ret_aligned.div(vol_mat).div(np.sqrt(length))
    return scaled

In [44]:
WindowSpec = Union[int, Tuple[int, int]]


def calc_kday_returns(
    returns_matrix: pd.DataFrame,
    windows: Iterable[WindowSpec]
) -> Dict[WindowSpec, pd.DataFrame]:
    """
    Compute total return matrices for multiple look-back specifications.

    For an integer window k:
        r_t^(k) = exp( sum_{i=0..k-1} log(1 + r_{t-i}) ) - 1
      i.e. k-day total return ending at date t (includes day t).

    For a tuple window (a, b) with a > b >= 0:
        Return from *a days ago* up to and including *b days ago*, skipping
        the most recent b days. Length of interval = a - b trading days.
        Formally (aligned to date t):
            r_t^(a,b) = exp( sum_{i=b .. a-1} log(1 + r_{t-i}) ) - 1
        (Note: i=0 corresponds to day t.)

        Example: (252, 21) sums log-returns for days t-251 ... t-21 (length 231),
        i.e. skipping the most recent 21 days.

    Parameters
    ----------
    returns_matrix : pd.DataFrame
        Daily simple returns (dates × assets), each entry r_t. Returns <= -1
        are treated as missing (log1p would be undefined or -inf).
    windows : iterable of int or (int, int)
        Look-back specs. Numpy integers are accepted wherever ints are.

    Returns
    -------
    Dict[WindowSpec, pd.DataFrame]
        Maps each window spec to a DataFrame of total returns. A window that
        contains any NaN gives NaN. Keys are the same objects provided in
        `windows`.

    Raises
    ------
    ValueError
        If an integer window is not positive.
    AssertionError
        If a tuple window is malformed or violates a > b >= 0. These checks
        are explicit raises (not ``assert`` statements), so they also run
        under ``python -O``. A negative b would otherwise shift future
        returns backwards (look-ahead).
    """
    # Sanitize returns to avoid log(<=0) issues: r <= -1 ⇒ NaN
    safe = returns_matrix.where(returns_matrix > -1)
    log_ret = np.log1p(safe)

    out: Dict[WindowSpec, pd.DataFrame] = {}
    # Rolling log-sums cached by interval length (shared by int and tuple specs).
    rolling_cache: Dict[int, pd.DataFrame] = {}

    def _rolling_log_sum(length: int) -> pd.DataFrame:
        # Strict window: NaN unless all `length` log-returns are present.
        if length not in rolling_cache:
            rolling_cache[length] = log_ret.rolling(window=length, min_periods=length).sum()
        return rolling_cache[length]

    for spec in windows:
        if isinstance(spec, (int, np.integer)):
            k = int(spec)
            if k <= 0:
                raise ValueError(f"Integer window must be positive, got {k}.")
            out[spec] = np.expm1(_rolling_log_sum(k))
            continue

        # Tuple case
        if not (isinstance(spec, tuple) and len(spec) == 2):
            raise AssertionError(f"Tuple window must have length 2, got {spec}.")
        a, b = spec
        if not (isinstance(a, (int, np.integer)) and isinstance(b, (int, np.integer))):
            raise AssertionError(f"Tuple elements must be ints, got {spec}.")
        a, b = int(a), int(b)
        if not a > b >= 0:
            raise AssertionError(f"For tuple (a,b) require a > b >= 0; got (a={a}, b={b}).")

        # Rolling sum over the interval length, then shift forward by b days
        # so that the window that *ended* at t-b is aligned to date t.
        out[spec] = np.expm1(_rolling_log_sum(a - b).shift(b))

    return out


In [None]:
# Plain k-day horizons plus "skip the most recent b days" (a, b) horizons.
RET_LOOKBACK_WINDOWS = [21, 63, 126, 252, (252, 21), (252, 42), (126, 21)]

# Ex-ante daily vol: EW std (span 63) over an exact 253-day window.
ewm_vol_matrix = ewma_std_strict_window_size(returns_matrix=returns_matrix, span=63, window=253)
kday_returns_dict = calc_kday_returns(returns_matrix=returns_matrix, windows=RET_LOOKBACK_WINDOWS)

In [20]:
# Divide each k-day return by ex-ante vol and sqrt(horizon).
vol_time_scaled_kday_returns_dict = calc_vol_and_time_scaled_kday_returns(
    ret_dict=kday_returns_dict,
    vol_mat=ewm_vol_matrix,
)

Save kday_returns_dict, vol_time_scaled_kday_returns_dict

In [None]:
# Create the output folder so the first save does not fail with FileNotFoundError.
Path('features').mkdir(parents=True, exist_ok=True)
with open('features/kday_returns_dict.pkl', 'wb') as f:
    pickle.dump(kday_returns_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
with open('features/vol_time_scaled_kday_returns_dict.pkl', 'wb') as f:
    pickle.dump(vol_time_scaled_kday_returns_dict, f, protocol=pickle.HIGHEST_PROTOCOL)


Load kday_returns_dict, vol_time_scaled_kday_returns_dict

In [None]:

# Restore the raw and vol/time-scaled k-day return dicts written by the save cell.
kday_returns_dict = pickle.loads(Path('features/kday_returns_dict.pkl').read_bytes())
vol_time_scaled_kday_returns_dict = pickle.loads(
    Path('features/vol_time_scaled_kday_returns_dict.pkl').read_bytes()
)

## MACD

In [None]:
def _ew_weights(span=63, window=253, oldest_first=False):
    α = 2.0 / (span + 1)                     # pandas/EMA definition
    w = (1 - α) ** np.arange(window)         # newest-first
    w /= w.sum()                             # renormalise over the truncated window
    return w[::-1] if oldest_first else w

def ewma_mean_strict_window_size(mat: pd.DataFrame, span: int, window: int = 253) -> pd.DataFrame:
    """
    Finite-window EWMA of each column, truncated to the last `window` rows.

    Row t of column c is the normalized EW average of the `window` values
    ending at t. It is written only when those `window` values are all
    non-NaN (consecutively); otherwise it is NaN. The decay uses pandas'
    convention ``alpha = 2 / (span + 1)``.

    Parameters
    ----------
    mat    : DataFrame T x N
    span   : EWMA span S (same interpretation as pd.Series.ewm(span=...)).
    window : Truncation length, default 253.

    Returns
    -------
    DataFrame with the same shape, index and columns as `mat`.

    Raises
    ------
    ValueError
        If span <= 1.
    """
    if span <= 1:
        raise ValueError("span must be > 1")

    values = mat.to_numpy(dtype=float, copy=False)
    n_rows, n_cols = values.shape
    result = np.full_like(values, np.nan)

    alpha = 2.0 / (span + 1.0)
    decay = 1.0 - alpha
    # Weight of the value that falls out of the window at each step.
    tail_weight = alpha * decay ** window
    # Sum of the truncated kernel alpha * (1 + decay + ... + decay**(window-1)).
    norm = 1.0 - decay ** window

    acc = np.zeros(n_cols, dtype=float)         # running weighted sum
    streak = np.zeros(n_cols, dtype=np.int32)   # consecutive non-NaN rows

    for t, row in enumerate(values):
        is_missing = np.isnan(row)

        streak[~is_missing] += 1
        streak[is_missing] = 0

        # EW recursion; a missing value contributes zero.
        acc *= decay
        acc += alpha * np.where(is_missing, 0.0, row)

        # Remove the contribution of the value leaving the finite window.
        if t >= window:
            leaving = values[t - window]
            acc -= tail_weight * np.where(np.isnan(leaving), 0.0, leaving)

        ready = streak >= window
        if ready.any():
            result[t, ready] = acc[ready] / norm

    return pd.DataFrame(result, index=mat.index, columns=mat.columns)

def rolling_std_strict(mat: pd.DataFrame, window: int, valid_frac: float = 0.9) -> pd.DataFrame:
    """
    Rolling sample standard deviation (ddof=1) that requires a minimum share
    of non-NaN observations in each window.

    Parameters
    ----------
    mat        : DataFrame (T x N).
    window     : Rolling window length in rows; must be >= 1.
    valid_frac : Minimum fraction of the window that must be non-NaN, in
                 [0, 1]. The threshold is ceil(window * valid_frac) observations.

    Returns
    -------
    DataFrame of the same shape as `mat`. An entry is NaN when its window holds
    fewer than the required number of non-NaN observations.

    Notes
    -----
    NOTE(review): during the first `window - 1` rows the (partial) window can
    already satisfy the threshold, in which case a value is emitted; this
    matches the previous behaviour.

    Raises
    ------
    ValueError
        If `window < 1` or `valid_frac` is outside [0, 1].
    """
    if window < 1:
        raise ValueError("window must be >= 1")
    if not (0 <= valid_frac <= 1):
        raise ValueError("valid_frac must be between 0 and 1")

    # Round before ceil: products such as 100 * 0.07 evaluate to
    # 7.000000000000001 in binary floating point, and a bare ceil would then
    # demand one observation more than intended.
    min_required = int(np.ceil(round(window * valid_frac, 9)))

    # pandas' min_periods counts non-NaN observations, so it alone enforces the
    # threshold; no separate count mask is needed.
    return mat.rolling(window=window, min_periods=min_required).std()


def baz_n_to_span(n: int | float) -> float:
    """
    Convert Baz ‘n’ (λ = (n-1)/n, α = 1/n) to the equivalent
    pandas EWMA span so that α = 2/(span+1) matches 1/n.
    """
    return 2.0 * n - 1.0

def baz_response(z: pd.DataFrame) -> pd.DataFrame:
    """
    Apply the Baz nonlinear response elementwise:

        phi(z) = z * exp(-z**2 / 4) / 0.89
    """
    squared = z.pow(2)
    damping = np.exp(-squared / 4)
    return z * damping / 0.89

def calc_macd_scores(
    prices: pd.DataFrame,
    pairs=((8, 24), (16, 48), (32, 96)),
    ew_window: int = 252,
    price_vol_window: int = 63,
    zscore_window: int = 252,
    lags_for_lagged_macd: Optional[Dict[str, int]] = None,
) -> Dict[Any, pd.DataFrame]:
    """
    Baz-style intermediate signals Ẏ̃ (Eq. 6–8) and their composite (Eq. 9).

    For each Baz pair (S, L):
        m_S, m_L = truncated EWMAs of `prices` (spans via `baz_n_to_span`)
        macd     = m_S - m_L                                        # Eq. (8)
        xi       = macd / rolling_std_strict(prices, price_vol_window)  # Eq. (7)
        y        = xi / rolling_std_strict(xi, zscore_window)         # Eq. (6)
    Composite: mean over pairs of `baz_response(y)`.

    Parameters
    ----------
    prices : DataFrame (T × N) of close prices.
    pairs  : iterable of Baz ‘n’ numbers (S, L); must be non-empty.
    ew_window        : truncation length for EWMAs (default 252).
    price_vol_window : window for σ_price in Eq. (7) (default 63).
    zscore_window    : window for σ_ξ    in Eq. (6) (default 252).
    lags_for_lagged_macd : mapping label -> lag in rows. Each current signal is
        also stored shifted by that many rows (lag 0 stores it unshifted).
        None (default) means {"curr": 0, "lag_1m": 21, "lag_3m": 63,
        "lag_6m": 126, "lag_12m": 252}.

    Returns
    -------
    dict
        {(S, L, label): Ẏ̃(S, L) shifted by lags_for_lagged_macd[label]}
        plus {"baz_comp": composite signal}.

    Raises
    ------
    ValueError
        If `pairs` is empty.
    """
    if lags_for_lagged_macd is None:
        # Built per call so callers can never mutate a shared default.
        lags_for_lagged_macd = {"curr": 0, "lag_1m": 21, "lag_3m": 63,
                                "lag_6m": 126, "lag_12m": 252}

    # Materialise so generators work and the pair count is known.
    pairs = tuple(pairs)
    if not pairs:
        raise ValueError("pairs must contain at least one (S, L) pair")

    # --- pre-compute rolling price volatility (shared by all pairs) -------
    price_std = rolling_std_strict(prices, price_vol_window)

    macd_scores: Dict[Any, pd.DataFrame] = {}
    current_y_tildes = []

    for S, L in pairs:
        # truncated EWMAs
        m_short = ewma_mean_strict_window_size(prices, span=baz_n_to_span(S), window=ew_window)
        m_long = ewma_mean_strict_window_size(prices, span=baz_n_to_span(L), window=ew_window)

        macd = m_short - m_long                         # Eq (8)
        xi = macd / price_std                           # Eq (7)
        sigma_xi = rolling_std_strict(xi, zscore_window)
        y_curr = xi / sigma_xi                          # Eq (6)

        for label, lag in lags_for_lagged_macd.items():
            macd_scores[(S, L, label)] = y_curr.shift(lag) if lag else y_curr
        current_y_tildes.append(y_curr)

    # Composite final signal (Eq. 9): Y_t = mean_k φ(Ẏ̃_k)
    composite = sum(baz_response(y) for y in current_y_tildes) / len(current_y_tildes)
    macd_scores["baz_comp"] = composite

    return macd_scores

In [None]:
# Lag offsets (in rows / trading days) applied to each current MACD signal.
lags_for_lagged_macd = {
    "curr": 0,
    "lag_1m": 21,
    "lag_3m": 63,
    "lag_6m": 126,
    "lag_12m": 252,
}
macd_dict = calc_macd_scores(
    prices=split_adjusted_close,
    pairs=((8, 24), (16, 48), (32, 96)),
    ew_window=252,
    price_vol_window=63,
    zscore_window=252,
    lags_for_lagged_macd=lags_for_lagged_macd,
)

Save `macd_dict` to disk and load it back

In [None]:
# Persist the MACD feature dictionary so later sessions can reload it.
with open('features/macd_dict.pkl', 'wb') as pkl_out:
    pickle.dump(macd_dict, pkl_out, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the cached MACD features (a local file written by the cell above).
with open('features/macd_dict.pkl', 'rb') as pkl_in:
    macd_dict = pickle.load(pkl_in)


## Rolling betas

In [None]:
spy_vix_df = pd.read_csv('features/spy_vix_changes.csv')
# Index by the parsed dates; the 'date' column itself is kept as well.
spy_vix_df = spy_vix_df.set_index(pd.to_datetime(spy_vix_df['date']))
spy_ret_df = spy_vix_df.loc[:, ['SPY']]

In [None]:
@njit(parallel=True)
def _ewm_betas_strict_numba(X: np.ndarray,
                            m: np.ndarray,
                            span: int,
                            window: int) -> np.ndarray:
    """
    Compute EW-weighted rolling betas of each column in X vs market m,
    using exactly the most recent `window` points with EWM decay `span`.
    Any NaN in the window (either series) -> beta is NaN for that date.

    Method
    ------
    Each running sum S holds sum_{i=0}^{window-1} gamma^i * v[t-i]
    (newest observation has weight 1) and is updated in O(1) per step:
        S <- gamma * S + v[t] - gamma^window * v[t-window].
    NaN values enter these sums as 0. A per-asset ring buffer records
    whether each (x, m) pair in the window was non-NaN, and a beta is
    emitted only when all `window` pairs are valid. With S1 = sum of the
    weights,
        beta = (Sxm - Sx*Sm/S1) / (Smm - Sm^2/S1),
    i.e. weighted covariance over weighted market variance (the common
    1/S1 normalisation cancels).

    Parameters
    ----------
    X : float64 array, shape (T, N)
        Asset returns (aligned to market).
    m : float64 array, shape (T,)
        Market (SPY) returns (aligned to X).
    span : int
        EW span (alpha = 2/(span+1)).
    window : int
        Exact window length. Must be >= 1: `t % window` is used as the
        ring-buffer position.

    Returns
    -------
    betas : float64 array, shape (T, N)
        Rolling betas; first `window-1` rows are NaN. Also NaN where any
        pair in the window contains a NaN or where the unnormalised market
        variance is <= 1e-14.

    Notes
    -----
    The time loop is sequential; only the per-asset loop uses `prange`.
    Iteration j reads the shared market scalars (Sm, Smm, denom_unscaled,
    m_new, m_old_val) and writes only column j of Sx, Sxm, counts,
    valid_buf and betas.
    NOTE(review): sliding sums maintained by add/subtract may accumulate
    floating-point rounding over very long series.
    """
    T, N = X.shape

    alpha = 2.0 / (span + 1.0)
    gamma = 1.0 - alpha               # decay
    g_pow_w = gamma ** window         # gamma^window: weight of the value leaving the window
    # Sum of raw weights over the finite window (constant):
    # S1 = 1 + gamma + ... + gamma^(window-1)
    S1 = (1.0 - g_pow_w) / (1.0 - gamma)

    betas = np.empty((T, N), dtype=np.float64)
    betas[:] = np.nan

    # Running (unnormalized) weighted sums; newest obs weight = 1
    Sx = np.zeros(N, dtype=np.float64)     # sum gamma^i * x
    Sxm = np.zeros(N, dtype=np.float64)    # sum gamma^i * x*m
    Sm = 0.0                               # sum gamma^i * m
    Smm = 0.0                              # sum gamma^i * m^2

    # Track how many valid (non-NaN pair) points are in the last window
    counts = np.zeros(N, dtype=np.int32)
    valid_buf = np.zeros((window, N), dtype=np.uint8)  # ring buffer of 0/1 validity flags

    for t in range(T):
        m_new = m[t]
        m_new_val = 0.0 if np.isnan(m_new) else m_new   # NaN contributes 0 to the sums

        m_old_val = 0.0
        if t >= window:
            # value that leaves the window at this step
            m_old = m[t - window]
            m_old_val = 0.0 if np.isnan(m_old) else m_old

        # Update market sums (shared across all assets)
        Sm = gamma * Sm + m_new_val - g_pow_w * m_old_val
        Smm = gamma * Smm + m_new_val * m_new_val - g_pow_w * (m_old_val * m_old_val)

        # Denominator (unnormalized weighted var of market after demeaning)
        denom_unscaled = Smm - (Sm * Sm) / S1

        pos = t % window  # ring-buffer position for this step (slot of the obs leaving)

        # Per-asset updates (parallel over assets; each j touches only column j)
        for j in prange(N):
            x_new = X[t, j]
            x_new_val = 0.0 if np.isnan(x_new) else x_new

            x_old_val = 0.0
            if t >= window:
                x_old = X[t - window, j]
                x_old_val = 0.0 if np.isnan(x_old) else x_old

            # Update asset sums
            Sx[j] = gamma * Sx[j] + x_new_val - g_pow_w * x_old_val
            Sxm[j] = gamma * Sxm[j] + (x_new_val * m_new_val) - g_pow_w * (x_old_val * m_old_val)

            # Maintain strict-window validity count (both values must be non-NaN;
            # note that +/-inf is not rejected by this check)
            new_valid = 1 if (not np.isnan(x_new) and not np.isnan(m_new)) else 0
            old_valid = 0
            if t >= window:
                old_valid = valid_buf[pos, j]
            valid_buf[pos, j] = new_valid
            counts[j] += (new_valid - old_valid)

            # Emit beta only when the last `window` pairs are all valid
            if t >= window - 1 and counts[j] == window and denom_unscaled > 1e-14:
                # Numerator is (cov * S1); denominator is (var_m * S1); S1 cancels in the ratio
                num_unscaled = Sxm[j] - (Sx[j] * Sm) / S1
                betas[t, j] = num_unscaled / denom_unscaled
            else:
                betas[t, j] = np.nan

    return betas


def rolling_beta_ewm_strict(
    returns_matrix: pd.DataFrame,
    spy_returns: pd.Series | pd.DataFrame,
    span: int = 63,
    window: int = 253,
) -> pd.DataFrame:
    """
    Rolling CAPM beta (with intercept) vs SPY using a strict, finite EW window.

    - Uses EW weights with `span` (alpha = 2/(span+1)), applied over exactly
      the last `window` days (finite horizon).
    - If *any* missing value appears in the `window` for either series,
      the beta for that date is NaN.

    Parameters
    ----------
    returns_matrix : DataFrame (dates × assets)
        Daily asset returns.
    spy_returns : Series or 1-col DataFrame
        Daily SPY returns.
    span : int
        EWM span; must be >= 1.
    window : int
        Exact window length; must be >= 1.

    Returns
    -------
    DataFrame
        Betas indexed by the sorted dates common to `returns_matrix` and
        `spy_returns`, with the same columns as `returns_matrix`.

    Raises
    ------
    ValueError
        If `window` or `span` is < 1, if `spy_returns` is a DataFrame with
        more than one column, or if either input has duplicate dates.
    """
    if window < 1:
        # The numba kernel uses `t % window` for its ring buffer.
        raise ValueError("window must be >= 1")
    if span < 1:
        raise ValueError("span must be >= 1")

    if isinstance(spy_returns, pd.DataFrame):
        if spy_returns.shape[1] != 1:
            raise ValueError("spy_returns must be a Series or a 1-column DataFrame.")
        spy_returns = spy_returns.iloc[:, 0]

    # Duplicate dates would make `.loc[idx]` return more rows than `idx`,
    # silently shifting the market series against the asset matrix.
    if not returns_matrix.index.is_unique or not spy_returns.index.is_unique:
        raise ValueError("returns_matrix and spy_returns must not contain duplicate dates.")

    # Align on common dates (strictly)
    idx = returns_matrix.index.intersection(spy_returns.index).sort_values()

    X = returns_matrix.loc[idx].to_numpy(dtype=np.float64, copy=True)
    m = spy_returns.loc[idx].to_numpy(dtype=np.float64, copy=True)

    betas = _ewm_betas_strict_numba(X, m, span, window)
    return pd.DataFrame(betas, index=idx, columns=returns_matrix.columns)


In [None]:
beta_df = rolling_beta_ewm_strict(
    returns_matrix=returns_matrix,
    spy_returns=spy_ret_df["SPY"],
    span=63,
    window=253,
)

In [28]:
# Cast the PERMNO column labels to integers.
beta_df = beta_df.set_axis(beta_df.columns.astype(int), axis="columns")

In [34]:
# Persist the rolling betas (pandas' default append mode for the file).
beta_df.to_hdf(path_or_buf='features/beta_df.h5', key='beta_df', mode='a')

## Assembling raw return, vol-scaled return and MACD

### Filter stock-dates by
#### a) All features are present on the trading day (this needs a burn-in period of close to 2 years)
#### b) Keep the top 1000 stocks by market cap

In [9]:
def update_mask_for_tradeable_stocks(
    feature_dicts: Sequence[Dict[Any, pd.DataFrame]],
    existing_mask_for_tradeable_stocks: pd.DataFrame,
) -> pd.DataFrame:
    """
    Returns a boolean DataFrame (`True` = the stock-date survives **all**
    feature-NaN checks + the external tradability filter).

    Parameters
    ----------
    feature_dicts : sequence of the dicts that hold your feature matrices
                    (e.g. [kday_returns_dict, vol_scaled_dict, macd_dict]).
    existing_mask_for_tradeable_stocks : initial boolean mask of tradeable
                    stocks (dates × permnos).

    Returns
    -------
    DataFrame[bool] with exactly the index and columns of
    `existing_mask_for_tradeable_stocks`. A cell is True only if it is True
    in the existing mask and every feature matrix has a non-NaN value there.
    Cells missing from a feature matrix count as NaN. The input mask is not
    modified.
    """
    mask = existing_mask_for_tradeable_stocks.copy()
    for feature_dict in feature_dicts:
        for mat in feature_dict.values():
            # Reindex to the mask so the result keeps the mask's shape; a bare
            # `&=` would align on the union of labels and grow the mask.
            mask &= mat.reindex_like(mask).notna()
    return mask

def build_stocks_to_trade_mask_using_mcap(
    trade_mask: pd.DataFrame,
    mcap_wide: pd.DataFrame,
    num_stocks_to_trade: int,
) -> pd.DataFrame:
    """
    Keep, on each date, the `num_stocks_to_trade` largest eligible stocks.

    Market caps outside `trade_mask` are blanked to NaN. The remaining caps
    are ranked within each row, with the largest cap getting rank 1 and ties
    broken by column order (method="first"). A stock is selected when its
    rank is within the limit. NaN ranks compare as False, so masked-out
    stocks are never selected.
    """
    eligible_caps = mcap_wide.where(trade_mask)
    size_rank = eligible_caps.rank(axis=1, method="first", ascending=False)
    return size_rank.le(num_stocks_to_trade)


In [None]:
# Start from the existing `valid253` mask and additionally require non-NaN
# k-day returns and average turnover. (An alternative feature set used
# before: [kday_returns_dict, vol_time_scaled_kday_returns_dict, macd_dict].)
mask_for_tradeable_stocks = update_mask_for_tradeable_stocks(
    feature_dicts=[kday_returns_dict, avg_turnover_dict],
    existing_mask_for_tradeable_stocks=valid253,
)
# Then keep the NUM_STOCKS_TO_TRADE largest eligible stocks by market cap per date.
mask_for_stocks_to_trade = build_stocks_to_trade_mask_using_mcap(
    trade_mask=mask_for_tradeable_stocks,
    mcap_wide=mcap_wide_aligned,
    num_stocks_to_trade=NUM_STOCKS_TO_TRADE,
)

In [38]:
# Persist the tradeable-stock mask in HDF5 "fixed" format.
mask_for_tradeable_stocks.to_hdf(
    path_or_buf="features/mask_for_tradeable_stocks.h5",
    key="mask_for_tradeable_stocks",
    format="fixed",
)

In [55]:
def generate_feat_mat_for_ltr(
    kday_returns_dict: Dict[int, pd.DataFrame],
    vol_time_scaled_kday_returns_dict: Dict[int, pd.DataFrame],
    macd_dict: Dict[Tuple[int, int], pd.DataFrame],
    avg_turnover_dict: Dict[str, pd.DataFrame],
    mask_for_stocks_to_trade: pd.DataFrame,
    drop_rows_all_nan: bool = True,
) -> pd.DataFrame:
    """
    Build the long-format LTR feature table in memory.

    Every feature matrix (dates x permnos) is masked with
    `mask_for_stocks_to_trade`, stacked to a (date, permno)-indexed Series,
    and the Series are outer-joined column-wise with ``pd.concat(axis=1)``.
    A (date, permno) pair is therefore kept if *any* feature has a value.

    Column order: "date", "permno", then
      1) f"{k}_day_raw_ret"              for k in kday_returns_dict
      2) f"{k}_day_vol_time_scaled_ret"  for k in vol_time_scaled_kday_returns_dict
      3) "_".join(map(str, k))           for tuple keys of macd_dict (other keys as-is)
      4) the keys of avg_turnover_dict
    Rows are sorted by (date, permno).

    Parameters
    ----------
    drop_rows_all_nan : bool
        Drop rows in which every feature column is NaN.

    Returns
    -------
    DataFrame with the columns described above.

    Raises
    ------
    ValueError
        If all four feature dicts are empty.
    """

    def _series(name: str, mat: pd.DataFrame) -> pd.Series:
        return mat.where(mask_for_stocks_to_trade).stack().rename(name)

    series_list = []

    # 1) kday_returns_dict
    for k, mat in kday_returns_dict.items():
        series_list.append(_series(f"{k}_day_raw_ret", mat))

    # 2) vol_time_scaled_kday_returns_dict
    for k, mat in vol_time_scaled_kday_returns_dict.items():
        series_list.append(_series(f"{k}_day_vol_time_scaled_ret", mat))

    # 3) macd_dict
    for k, mat in macd_dict.items():
        col_name = "_".join(map(str, k)) if isinstance(k, tuple) else k
        series_list.append(_series(col_name, mat))

    # 4) avg_turnover_dict
    for k, mat in avg_turnover_dict.items():
        series_list.append(_series(k, mat))

    if not series_list:
        raise ValueError("no feature matrices were supplied")

    # Outer join on (date, permno). Assigning columns one at a time onto the
    # first feature's frame would align every later feature to that frame's
    # index and silently drop pairs it lacks.
    X = pd.concat(series_list, axis=1)

    final_df = (
        X.reset_index()                               # index → cols
         .rename(columns={'level_0': 'date', 'level_1': 'permno'})
         .sort_values(['date', 'permno'])
    )

    if drop_rows_all_nan:
        final_df = final_df.dropna(how='all', subset=final_df.columns[2:])

    return final_df

In [52]:
def compute_forward_deciles(
    *,
    returns_matrix: pd.DataFrame,
    mask_for_stocks_to_trade: pd.DataFrame,
    feature_index_df: pd.DataFrame,          # feat_mat[["date","permno"]]
    target_horizon: int = 21,
    num_deciles: int = 10,
) -> pd.DataFrame:
    """
    Per-date buckets of the forward `target_horizon`-day compounded return.

    Parameters
    ----------
    returns_matrix : DataFrame (dates × permnos) of simple returns.
    mask_for_stocks_to_trade : boolean DataFrame; only True cells are bucketed.
    feature_index_df : DataFrame with "date" and "permno"; output rows are
        restricted to these pairs (inner join).
    target_horizon : forward horizon in rows.
    num_deciles : number of buckets requested per date.

    Returns
    -------
    DataFrame with columns date, permno, decile (UInt8, starting at 1) and
    fwd_ret (float32), sorted by (date, permno).

    Bucketing per date:
      - fewer valid stocks than `num_deciles`: rank-based buckets via
        ceil(rank / n * num_deciles) (some top buckets unused);
      - otherwise `pd.qcut`. When ties collapse quantile edges, fewer buckets
        are produced and they are relabelled consecutively from 1. A date on
        which every valid value is identical puts all stocks in bucket 1.
    """
    # --- 1) trailing `target_horizon`-day cumulative return ----------
    trailing_ret = (
        (1 + returns_matrix)
        .rolling(target_horizon, min_periods=target_horizon)
        .apply(np.prod, raw=True)
        .sub(1)
    )

    # --- 2) shift back by `target_horizon` to make it "forward" ------
    fwd_return = trailing_ret.shift(-target_horizon)

    # --- 3) rank into (approximately) equal-sized *per-date* deciles --
    # Mask out untradeable stocks first
    fwd_masked = fwd_return.where(mask_for_stocks_to_trade)

    def _row_deciles(row: pd.Series) -> pd.Series:
        """Buckets 1..k for one date; NA for masked / missing entries."""
        valid = row.dropna()
        n = len(valid)
        if n == 0:
            return pd.Series(index=row.index, dtype="Int8")

        if n < num_deciles:
            # rank 1..n (ties broken by order), mapped linearly into 1..num_deciles
            r = valid.rank(method="first")
            dec = np.ceil(r / n * num_deciles).astype(int)
            return dec.reindex(row.index).astype("Int8")

        if valid.nunique() == 1:
            # qcut cannot form a bin from a single distinct value.
            codes = pd.Series(0, index=valid.index)
        else:
            # labels=False returns integer bin codes, which stays valid when
            # duplicates="drop" removes edges (explicit labels would raise).
            codes = pd.qcut(valid, q=num_deciles, labels=False, duplicates="drop")

        # Dense rank relabels the codes consecutively starting at 1.
        return codes.rank(method="dense").reindex(row.index).astype("Int8")

    # Apply per row (date)
    decile_df = fwd_masked.apply(_row_deciles, axis=1)

    # --- 4) reshape to long form and join -----------------------------
    # Plain stack(): any NA rows it keeps are removed by the dropna below, and
    # every non-NA decile has a non-NaN forward return, so the result does not
    # depend on whether stack drops NaNs (the `dropna=` argument is gone in
    # newer pandas).
    y_ret_decile = (
        decile_df.stack()
                 .rename("decile")
                 .to_frame()
                 .join(fwd_return.stack().rename("fwd_ret"))
                 .reset_index()
                 .rename(columns={"level_0": "date",
                                  "level_1": "permno"})
                 .dropna(subset=["decile"])
                 .merge(feature_index_df, on=["date", "permno"], how="inner")
                 .sort_values(["date", "permno"])
                 .reset_index(drop=True)
    )

    y_ret_decile["decile"] = y_ret_decile["decile"].astype("UInt8")
    y_ret_decile["fwd_ret"] = y_ret_decile["fwd_ret"].astype("float32")

    return y_ret_decile


In [None]:
def build_ltr_xy(
    *,
    kday_returns_dict,                       # same signatures as before
    vol_time_scaled_kday_returns_dict,
    macd_dict,
    avg_turnover_dict,
    mask_for_stocks_to_trade,
    returns_matrix,
    target_horizon: int = 21,
    num_deciles: int = 10,
    drop_rows_all_nan: bool = True,
):
    """
    Build the in-memory learning-to-rank table: features joined to targets.

    Features come from `generate_feat_mat_for_ltr`, and targets come from
    `compute_forward_deciles`, restricted to the feature rows. The two are
    inner-joined on (date, permno), sorted, and given a fresh RangeIndex.

    NOTE(review): `generate_feat_mat_for_ltr` is looked up at call time. If a
    later cell has redefined it to the HDF5-writing variant, it returns None
    by default and this function fails.
    """
    join_keys = ['date', 'permno']

    features = generate_feat_mat_for_ltr(
        kday_returns_dict=kday_returns_dict,
        vol_time_scaled_kday_returns_dict=vol_time_scaled_kday_returns_dict,
        macd_dict=macd_dict,
        avg_turnover_dict=avg_turnover_dict,
        mask_for_stocks_to_trade=mask_for_stocks_to_trade,
        drop_rows_all_nan=drop_rows_all_nan,
    )

    targets = compute_forward_deciles(
        returns_matrix=returns_matrix,
        mask_for_stocks_to_trade=mask_for_stocks_to_trade,
        feature_index_df=features[join_keys],
        target_horizon=target_horizon,
        num_deciles=num_deciles,
    )

    merged = features.merge(targets, on=join_keys, how='inner')
    merged = merged.sort_values(join_keys)
    return merged.reset_index(drop=True)

# In-memory build of the LTR table (21-day horizon, 10 deciles: the defaults).
ltr_df = build_ltr_xy(
    kday_returns_dict=kday_returns_dict,
    vol_time_scaled_kday_returns_dict=vol_time_scaled_kday_returns_dict,
    macd_dict=macd_dict,
    avg_turnover_dict=avg_turnover_dict,
    mask_for_stocks_to_trade=mask_for_stocks_to_trade,
    returns_matrix=returns_matrix,
    target_horizon=21,
    num_deciles=10,
    drop_rows_all_nan=True,
)


In [None]:
# Plain-text export without the integer row index.
ltr_df.to_csv(path_or_buf='features/ltr.csv', index=False)

In [None]:
# Compressed HDF5 copy; mode='w' overwrites any existing file.
ltr_df.to_hdf(
    path_or_buf='features/ltr.h5',
    key='ltr',
    complevel=9,
    complib='zlib',
    mode='w',
)

# Heavy batching (use this to generate `ltr_df` if the cells above run out of memory)

In [None]:
def _feature_name_list(
    kday_returns_dict,
    vol_time_scaled_kday_returns_dict,
    macd_dict,
    avg_turnover_dict
):
    names = []
    for k in kday_returns_dict:                     # 1
        names.append(f"{k}_day_raw_ret")
    for k in vol_time_scaled_kday_returns_dict:     # 2
        names.append(f"{k}_day_vol_time_scaled_ret")
    for k in macd_dict:                             # 3
        names.append("_".join(map(str, k)) if isinstance(k, tuple) else k)
    for k in avg_turnover_dict:                     # 4
        names.append(k)
    return names


def generate_feat_mat_for_ltr(
    kday_returns_dict: Dict[int, pd.DataFrame],
    vol_time_scaled_kday_returns_dict: Dict[int, pd.DataFrame],
    macd_dict: Dict[Tuple[int, int], pd.DataFrame],
    avg_turnover_dict: Dict[str, pd.DataFrame],
    mask_for_stocks_to_trade: pd.DataFrame,
    drop_rows_all_nan: bool = True,
    *,
    hdf_path: str = "X_features.h5",     # where to store intermediate/final
    key: str = "X",                      # HDF5 group/key
    date_batch_size: int = 250,          # adjust to your RAM
    return_in_memory: bool = False       # False → just write to HDF
) -> Optional[pd.DataFrame]:
    """
    Write the wide LTR feature matrix to an HDF5 table in date batches.

    For each batch of `date_batch_size` dates (taken from
    `mask_for_stocks_to_trade.index`):
      - each feature matrix is masked, stacked to (date, permno) form and
        outer-joined column-wise (pd.concat(axis=1));
      - rows are sorted by (date, permno) and, if `drop_rows_all_nan`, rows
        with every feature NaN are dropped;
      - non-empty batches are appended to `hdf_path`/`key` as a table with
        "date" and "permno" as data columns.
    Columns are ["date", "permno"] + `_feature_name_list(...)`.

    Returns
    -------
    None if `return_in_memory` is False. Otherwise returns the table read back
    from disk, or an empty frame with the expected columns if no rows were
    written. The integer row index is carried over from each batch and is
    not meaningful; use reset_index(drop=True) if a clean index is needed.

    Raises
    ------
    ValueError
        If `date_batch_size < 1` or no feature matrices are supplied.
    """
    if date_batch_size < 1:
        raise ValueError("date_batch_size must be >= 1")

    # Final column order (matches the in-memory concat order)
    col_order = _feature_name_list(
        kday_returns_dict,
        vol_time_scaled_kday_returns_dict,
        macd_dict,
        avg_turnover_dict
    )
    if not col_order:
        raise ValueError("no feature matrices were supplied")

    # _feature_name_list walks the dicts in this same order, so names and
    # matrices pair up one-to-one.
    matrices = [
        *kday_returns_dict.values(),
        *vol_time_scaled_kday_returns_dict.values(),
        *macd_dict.values(),
        *avg_turnover_dict.values(),
    ]
    named_matrices = list(zip(col_order, matrices))
    out_cols = ["date", "permno"] + col_order

    dates = mask_for_stocks_to_trade.index
    n_dates = len(dates)

    # mode="w" truncates any stale file up front. If that fails, the error
    # surfaces instead of batches silently appending to old contents.
    with pd.HDFStore(hdf_path, mode="w") as store:
        for start in range(0, n_dates, date_batch_size):
            d_slice = dates[start:start + date_batch_size]
            mask_chunk = mask_for_stocks_to_trade.loc[d_slice]

            series_list = [
                mat.loc[d_slice].where(mask_chunk).stack().rename(name)
                for name, mat in named_matrices
            ]

            chunk_df = (
                pd.concat(series_list, axis=1)
                .reset_index()
                .rename(columns={"level_0": "date", "level_1": "permno"})
                .sort_values(["date", "permno"])
            )

            if drop_rows_all_nan:
                chunk_df = chunk_df.dropna(how="all", subset=col_order)

            chunk_df = chunk_df[out_cols]

            # Append rows to the HDF5 table (same schema for each chunk)
            if not chunk_df.empty:
                store.append(key, chunk_df, format="table", data_columns=["date", "permno"])

            # free memory before the next batch
            del series_list, chunk_df, mask_chunk
            gc.collect()

    if not return_in_memory:
        return None

    with pd.HDFStore(hdf_path, mode="r") as store:
        if key not in store:
            return pd.DataFrame(columns=out_cols)
        return store.select(key)


In [59]:
def compute_forward_deciles_batched(
    *,
    returns_matrix: pd.DataFrame,
    mask_for_stocks_to_trade: pd.DataFrame,
    feature_index_df: pd.DataFrame,     # X[['date','permno']]
    target_horizon: int = 21,
    num_deciles: int = 10,
    hdf_path: str = "y_targets.h5",
    key: str = "y",
    date_batch_size: int = 250,
    return_in_memory: bool = False
) -> Optional[pd.DataFrame]:
    """
    Disk-backed version of the forward-return decile targets.

    Forward `target_horizon`-day compounded returns are computed for the whole
    matrix, so every date's buckets use the full cross-section. Buckets are
    then assigned and written to `hdf_path`/`key` (file overwritten, mode "w")
    in batches of `date_batch_size` dates. Only (date, permno) pairs present in
    `feature_index_df` are kept.

    Stored columns: date, permno, decile (int8, starting at 1), fwd_ret
    (float32). A plain numpy int8 is used because NA deciles are dropped
    before writing. It also matches the Y table written by
    `build_ltr_xy_batched`. NOTE(review): HDF5 table format may reject pandas
    nullable extension dtypes such as UInt8.

    Returns
    -------
    None if `return_in_memory` is False. Otherwise returns the stored table,
    or an empty frame with the expected columns if nothing was written.

    Raises
    ------
    ValueError
        If `date_batch_size < 1`.
    """
    if date_batch_size < 1:
        raise ValueError("date_batch_size must be >= 1")

    # 1) trailing & forward returns (full, needed to compute deciles consistently)
    trailing_ret = (
        (1 + returns_matrix)
        .rolling(target_horizon, min_periods=target_horizon)
        .apply(np.prod, raw=True)
        .sub(1)
    )
    fwd_return = trailing_ret.shift(-target_horizon)
    fwd_masked = fwd_return.where(mask_for_stocks_to_trade)

    def _row_deciles(row: pd.Series) -> pd.Series:
        """Buckets 1..k for one date; NA for masked / missing entries."""
        valid = row.dropna()
        n = len(valid)
        if n == 0:
            return pd.Series(index=row.index, dtype="Int8")
        if n < num_deciles:
            r = valid.rank(method="first")
            dec = np.ceil(r / n * num_deciles).astype(int)
            return dec.reindex(row.index).astype("Int8")
        if valid.nunique() == 1:
            # qcut cannot form a bin from a single distinct value.
            codes = pd.Series(0, index=valid.index)
        else:
            # labels=False stays valid when duplicates="drop" removes edges.
            codes = pd.qcut(valid, q=num_deciles, labels=False, duplicates="drop")
        # consecutive buckets starting at 1
        return codes.rank(method="dense").reindex(row.index).astype("Int8")

    dates = fwd_masked.index
    with pd.HDFStore(hdf_path, mode="w") as store:
        for start in range(0, len(dates), date_batch_size):
            d_slice = dates[start:start + date_batch_size]

            decile_chunk = fwd_masked.loc[d_slice].apply(_row_deciles, axis=1)
            # Plain stack(): NA deciles are removed by the dropna below either way.
            y_chunk = (
                decile_chunk.stack().rename("decile").to_frame()
                .join(fwd_return.loc[d_slice].stack().rename("fwd_ret"))
                .reset_index()
                .rename(columns={"level_0": "date", "level_1": "permno"})
                .dropna(subset=["decile"])
            )
            if y_chunk.empty:
                continue

            # keep only rows present in the features
            idx_chunk = feature_index_df[
                (feature_index_df["date"] >= d_slice[0]) &
                (feature_index_df["date"] <= d_slice[-1])
            ]
            y_chunk = (
                y_chunk.merge(idx_chunk, on=["date", "permno"], how="inner")
                       .sort_values(["date", "permno"])
                       .reset_index(drop=True)
            )
            if y_chunk.empty:
                continue

            # compact numpy dtypes (no NA left after dropna)
            y_chunk["decile"] = y_chunk["decile"].astype("int8")
            y_chunk["fwd_ret"] = y_chunk["fwd_ret"].astype("float32")

            store.append(key, y_chunk, format="table", data_columns=["date", "permno"])

            del decile_chunk, y_chunk, idx_chunk
            gc.collect()

    if not return_in_memory:
        return None
    with pd.HDFStore(hdf_path, mode="r") as store:
        if key not in store:
            return pd.DataFrame(columns=["date", "permno", "decile", "fwd_ret"])
        return store.select(key)


In [None]:
def _row_deciles(row: pd.Series, num_deciles: int) -> pd.Series:
    """Assign per-row deciles with NA handling."""
    valid = row.dropna()
    n = len(valid)
    if n == 0:
        return pd.Series(index=row.index, dtype="Int8")
    if n < num_deciles:
        r = valid.rank(method="first")
        dec = np.ceil(r / n * num_deciles).astype(int)
        return dec.reindex(row.index).astype("Int8")
    bins = pd.qcut(
        valid,
        q=num_deciles,
        labels=range(1, num_deciles + 1),
        duplicates="drop",
    )
    u = pd.unique(bins)
    if len(u) < num_deciles:
        remap = {old: i + 1 for i, old in enumerate(sorted(u))}
        bins = bins.map(remap)
    return bins.reindex(row.index).astype("Int8")


def build_ltr_xy_batched(
    *,
    kday_returns_dict: Dict[int, pd.DataFrame],
    vol_time_scaled_kday_returns_dict: Dict[int, pd.DataFrame],
    macd_dict: Dict[Tuple[int, int], pd.DataFrame],
    avg_turnover_dict: Dict[str, pd.DataFrame],
    mask_for_stocks_to_trade: pd.DataFrame,
    returns_matrix: pd.DataFrame,
    target_horizon: int = 21,
    num_deciles: int = 10,
    date_batch_size: int = 250,
    hdf_path: str = "ltr.h5",
    key: str = "ltr",
    return_in_memory: bool = False,
    use_existing_X: bool = False,
    use_existing_Y: bool = False,
    X_path: str = "X_features.h5",
    X_key: str = "X",
    Y_path: str = "y_targets.h5",
    Y_key: str = "y",
) -> Optional[pd.DataFrame]:
    """Build or reuse X and Y, then merge to LTR in date batches.

    Three stages, each iterating over batches of `date_batch_size` dates
    taken from `mask_for_stocks_to_trade.index`:

    1. X (skipped if `use_existing_X`): calls `generate_feat_mat_for_ltr`
       with `hdf_path=X_path, key=X_key, return_in_memory=False`, so the
       feature table is written to disk. NOTE(review): this relies on the
       batched, HDF5-writing definition of `generate_feat_mat_for_ltr` being
       the one bound when this function runs.
    2. Y (skipped if `use_existing_Y`): forward `target_horizon`-day
       compounded returns are computed for the whole matrix and masked.
       They are bucketed per date with `_row_deciles` and appended to
       `Y_path`/`Y_key`; the file is opened with mode "w" and is overwritten.
       Deciles are stored as int8 and forward returns as float32. Unlike
       `compute_forward_deciles_batched`, Y is not filtered to feature rows
       here; the inner join in stage 3 does that.
    3. Merge: for each batch, X and Y rows in the date range are selected
       with a `where` query, inner-joined on (date, permno), sorted, and
       appended to `hdf_path`/`key` (also opened with mode "w"). Batches
       with no X rows are skipped.

    Returns
    -------
    None unless `return_in_memory` is True, in which case the merged table
    is read back from `hdf_path`.
    """
    # Dates to iterate over come from the mask
    all_dates = mask_for_stocks_to_trade.index.to_numpy()

    # X
    if not use_existing_X:
        generate_feat_mat_for_ltr(
            kday_returns_dict=kday_returns_dict,
            vol_time_scaled_kday_returns_dict=vol_time_scaled_kday_returns_dict,
            macd_dict=macd_dict,
            avg_turnover_dict=avg_turnover_dict,
            mask_for_stocks_to_trade=mask_for_stocks_to_trade,
            drop_rows_all_nan=True,
            hdf_path=X_path,
            key=X_key,
            date_batch_size=date_batch_size,
            return_in_memory=False,
        )

    # Y
    if not use_existing_Y:
        # Computed on the full matrix so each date's buckets see the whole cross-section.
        trailing_ret = (
            (1 + returns_matrix)
            .rolling(target_horizon, min_periods=target_horizon)
            .apply(np.prod, raw=True)
            .sub(1)
        )
        fwd_return = trailing_ret.shift(-target_horizon)
        fwd_masked = fwd_return.where(mask_for_stocks_to_trade)

        with pd.HDFStore(Y_path, "w") as Y_store:
            for start in range(0, len(all_dates), date_batch_size):
                end = min(start + date_batch_size, len(all_dates))
                d0, d1 = all_dates[start], all_dates[end - 1]

                # Label-based inclusive date slice; assumes fwd_masked's index
                # is sorted and covers the mask dates (TODO confirm).
                decile_chunk = fwd_masked.loc[d0:d1].apply(
                    _row_deciles, axis=1, num_deciles=num_deciles
                )

                y_chunk = (
                    decile_chunk.stack(dropna=False)
                    .rename("decile")
                    .to_frame()
                    .join(
                        fwd_return.loc[d0:d1]
                        .stack(dropna=False)
                        .rename("fwd_ret")
                    )
                    .reset_index()
                    .rename(columns={"level_0": "date", "level_1": "permno"})
                    .dropna(subset=["decile"])
                )

                # compact dtypes (NA deciles were dropped above)
                y_chunk["decile"] = y_chunk["decile"].astype("int8")
                y_chunk["fwd_ret"] = y_chunk["fwd_ret"].astype("float32")

                Y_store.append(
                    Y_key, y_chunk, format="table", data_columns=["date", "permno"]
                )

                del decile_chunk, y_chunk
                gc.collect()

        # release the full-size intermediate matrices before the merge stage
        del trailing_ret, fwd_return, fwd_masked
        gc.collect()

    # Merge X & Y -> LTR
    # NOTE(review): if no Y rows were ever written, Y_store.select presumably
    # raises for the missing key; verify if that case can occur.
    with pd.HDFStore(hdf_path, "w") as out_store, \
         pd.HDFStore(X_path, "r") as X_store, \
         pd.HDFStore(Y_path, "r") as Y_store:

        for start in range(0, len(all_dates), date_batch_size):
            end = min(start + date_batch_size, len(all_dates))
            d0, d1 = all_dates[start], all_dates[end - 1]

            # Interpolate datetimes into the where clause as ISO strings
            # ("date" is a data column in both tables, so it is queryable)
            where_dates = [
                f"date >= '{pd.Timestamp(d0)}'",
                f"date <= '{pd.Timestamp(d1)}'",
            ]

            X_chunk = X_store.select(X_key, where=where_dates)
            if X_chunk.empty:
                continue

            Y_chunk = Y_store.select(Y_key, where=where_dates)

            chunk = (
                X_chunk.merge(Y_chunk, on=["date", "permno"], how="inner")
                .sort_values(["date", "permno"])
                .reset_index(drop=True)
            )

            out_store.append(
                key, chunk, format="table", data_columns=["date", "permno"]
            )

            del X_chunk, Y_chunk, chunk
            gc.collect()

    if not return_in_memory:
        return None

    with pd.HDFStore(hdf_path, "r") as store:
        return store.select(key)


In [None]:
# Build directly to disk, no huge objects kept.
# use_existing_X=True reuses X_features.h5 (key "X") written by an earlier run.
ltr_df = build_ltr_xy_batched(
    kday_returns_dict=kday_returns_dict,
    vol_time_scaled_kday_returns_dict=vol_time_scaled_kday_returns_dict,
    macd_dict=macd_dict,
    avg_turnover_dict=avg_turnover_dict,
    mask_for_stocks_to_trade=mask_for_stocks_to_trade,
    returns_matrix=returns_matrix,
    date_batch_size=200,
    return_in_memory=False,  # keep False if RAM is tight
    hdf_path="ltr.h5",
    key="ltr",
    use_existing_X=True
)

# With return_in_memory=False the builder returns None. Load the merged table
# back from ltr.h5 before writing the compressed copy; the top-1000 mask cell
# below also needs ltr_df in memory.
if ltr_df is None:
    with pd.HDFStore("ltr.h5", "r") as store:
        ltr_df = store.select("ltr")

ltr_df.to_hdf(
    'features/ltr.h5',
    key='ltr',
    mode='w',
    complib='zlib',
    complevel=9
)


# Generate the top-1000 stock mask from `ltr_df` (needed by the BOCD feature-generation code)

In [None]:
def build_top1000_mask_from_ltr(
    ltr_long_df: pd.DataFrame,
    returns_wide: Optional[pd.DataFrame] = None,
) -> pd.DataFrame:
    """
    Create a boolean mask indicating whether (date, permno) appears in ltr_df.

    If `returns_wide` is provided, the mask is reindexed to precisely match
    its index/columns and any missing (date, permno) pairs are filled as False.

    Parameters
    ----------
    ltr_long_df : DataFrame with at least ['date', 'permno'] columns.
        Non-datetime 'date' values are parsed with ``pd.to_datetime``;
        unparseable dates become NaT and are ignored.
    returns_wide : Optional wide returns matrix to align the mask
        (index=date, columns=permno). Alignment is by label, so its
        index/column dtypes must match the parsed 'date' / 'permno' values;
        otherwise the affected cells come out False.

    Returns
    -------
    DataFrame[bool] indexed by date with PERMNO columns. Empty input yields an
    empty mask, or an all-False mask shaped like `returns_wide` when given.
    """
    # Work on a narrow copy so the caller's frame is never mutated.
    df = ltr_long_df[['date', 'permno']].copy()
    if not pd.api.types.is_datetime64_any_dtype(df['date']):
        df['date'] = pd.to_datetime(df['date'], errors='coerce')

    # crosstab ignores rows with missing keys anyway; dropping them up front
    # makes the empty-input check below reflect what would actually be counted.
    df = df.dropna(subset=['date', 'permno'])

    if df.empty:
        # Explicit empty boolean mask instead of relying on crosstab's
        # behaviour with zero-length inputs.
        mask = pd.DataFrame(
            index=pd.DatetimeIndex([], name='date'),
            columns=pd.Index([], name='permno'),
            dtype=bool,
        )
    else:
        # Cross-tabulate presence: True iff that (date, permno) appears in ltr_df.
        mask = pd.crosstab(df['date'], df['permno']).gt(0)

    if returns_wide is not None:
        # Align to the returns matrix so downstream matrices are conformable.
        # fill_value=False keeps the frame boolean; reindex(...).fillna(False)
        # would first upcast to object dtype and trip pandas' downcast warning.
        mask = mask.reindex(
            index=returns_wide.index,
            columns=returns_wide.columns,
            fill_value=False,
        )

    # Enforce boolean dtype consistently (covers the empty/reindexed paths too).
    return mask.astype('bool')


# Build the mask from the in-memory ltr_df, aligned to returns_matrix when it is
# defined. An explicit globals() lookup replaces `except NameError`, which would
# also have swallowed NameErrors raised for unrelated reasons.
is_top_1000_mask = build_top1000_mask_from_ltr(ltr_df, globals().get('returns_matrix'))


if 'mask_for_stocks_to_trade' in globals():
    try:
        # Align the reference mask to the new mask's shape. fill_value=False keeps
        # newly introduced cells boolean; fillna(False) still clears any NaN that
        # was already present in the reference mask (as the original did).
        aligned_ref = (
            mask_for_stocks_to_trade
            .reindex(
                index=is_top_1000_mask.index,
                columns=is_top_1000_mask.columns,
                fill_value=False,
            )
            .fillna(False)
            .astype('bool')
        )
        diff = (is_top_1000_mask != aligned_ref)
        n_diff = int(diff.to_numpy().sum())
        if n_diff > 0:
            # Print a short diagnostic so you can decide whether to harmonize the definitions.
            n_total = is_top_1000_mask.size
            frac = n_diff / n_total
            print(f"[Top-1000 mask] Warning: {n_diff:,} cells differ from mask_for_stocks_to_trade "
                  f"({frac:.6%} of the matrix). This can happen if ltr_df applies extra row screens.")
    except Exception as e:
        # Deliberately best-effort: the comparison is diagnostic only.
        print(f"[Top-1000 mask] Comparison to mask_for_stocks_to_trade skipped: {e}")

# Persist the mask; make sure the output folder exists first.
Path("features").mkdir(parents=True, exist_ok=True)
is_top_1000_mask.to_hdf(
    "features/is_top_1000_mask.h5",
    key="is_top_1000_mask",
    mode="w",
    complib="zlib",
    complevel=9,
)

print(is_top_1000_mask.shape)
