In [1]:
from __future__ import annotations

from dataclasses import dataclass, replace, field
from typing import List, Literal, Optional, Dict, Any

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ===== library imports (your refactored code) =====
import qresearch.signals as qsigs
from qresearch.data.utils import get_processed_dir
from qresearch.data.yfinance import *
from qresearch.backtest.portfolio import *
from qresearch.backtest.metrics import *
from qresearch.portfolio.weights import *

# -----------------------------
# Parquet -> MarketData
# -----------------------------
def _read_wide(path: str) -> pd.DataFrame:
    """
    Read one wide parquet panel and normalise its axes.

    The index is converted to a DatetimeIndex when it is not one already
    (unparseable labels are coerced to NaT), rows with a NaT index label are
    dropped, the remaining rows are sorted by date, and column labels are cast
    to ``str``.
    """
    panel = pd.read_parquet(path)

    # Coerce non-datetime indexes; bad labels become NaT and are removed below.
    if not isinstance(panel.index, pd.DatetimeIndex):
        panel.index = pd.to_datetime(panel.index, errors="coerce")

    has_date = panel.index.notna()
    panel = panel.loc[has_date].sort_index()

    # Column labels are cast to plain strings.
    panel.columns = panel.columns.astype(str)
    return panel


def load_marketdata_from_parquet(
    data_dir: str,
    tickers: list[str],
    start: str,
    end: Optional[str] = None,
    strict_calendar: bool = True,
) -> MarketData:
    """
    Load wide parquet panels from ``data_dir`` and assemble a MarketData.

    ``ohlcv_wide_close.parquet`` is required and defines the master calendar.
    The open / high / low / volume / turnover / pct_chg panels are optional:
    each one is loaded when its file exists and passed as ``None`` otherwise.

    Args:
        data_dir: Directory containing the ``ohlcv_wide_<panel>.parquet`` files.
        tickers: Tickers to keep; each is converted to ``str`` and stripped.
        start: First date to keep (inclusive), parsed with ``pd.to_datetime``.
        end: Last date to keep (inclusive); ``None`` means the last close date.
        strict_calendar: When True, drop every date on which ANY selected
            ticker has a missing close, from all panels.

    Returns:
        MarketData built from the aligned panels (the ``pct_chg`` parquet is
        mapped onto the ``pct_change`` field).

    Raises:
        KeyError: If a requested ticker is not a column of the close panel.
        Exception: Any error other than ``FileNotFoundError`` raised while
            reading an optional panel is propagated, so a corrupt file is not
            silently treated as "panel absent".
    """
    tickers = [str(t).strip() for t in tickers]

    # Required panel.
    close = _read_wide(f"{data_dir}/ohlcv_wide_close.parquet")

    def maybe(panel_name: str) -> pd.DataFrame | None:
        # Only a missing file means "optional panel not available"; every
        # other failure (corrupt file, missing engine, ...) must surface.
        try:
            return _read_wide(f"{data_dir}/ohlcv_wide_{panel_name}.parquet")
        except FileNotFoundError:
            return None

    panels = {
        name: maybe(name)
        for name in ("open", "high", "low", "volume", "turnover", "pct_chg")
    }

    # Every requested ticker must exist in the close panel.
    missing = [t for t in tickers if t not in close.columns]
    if missing:
        raise KeyError(f"These tickers are not in close parquet columns: {missing}")

    close = close[tickers].copy()

    # Inclusive date slicing on the close panel.
    start_dt = pd.to_datetime(start)
    end_dt = pd.to_datetime(end) if end is not None else close.index.max()
    close = close.loc[(close.index >= start_dt) & (close.index <= end_dt)]

    # The (sliced) close index is the master calendar for all other panels.
    cal = close.index

    def align(panel: pd.DataFrame | None) -> pd.DataFrame | None:
        if panel is None:
            return None
        # Keep only the requested tickers that the panel actually has.
        cols = [t for t in tickers if t in panel.columns]
        if not cols:
            return None
        out = panel[cols].reindex(cal)
        # Tickers missing from this panel become all-NaN columns so that the
        # panel has the same column layout as close.
        if set(cols) != set(tickers):
            for t in tickers:
                if t not in out.columns:
                    out[t] = np.nan
            out = out[tickers]
        return out

    panels = {name: align(panel) for name, panel in panels.items()}

    # Strict calendar: close must be complete for all tickers on a kept date.
    if strict_calendar:
        ok = close.notna().all(axis=1)
        close = close.loc[ok]
        panels = {
            name: (panel.loc[ok] if panel is not None else None)
            for name, panel in panels.items()
        }

    return MarketData(
        close=close,
        open=panels["open"],
        high=panels["high"],
        low=panels["low"],
        volume=panels["volume"],
        turnover=panels["turnover"],
        pct_change=panels["pct_chg"],  # map parquet pct_chg -> MarketData.pct_change
    )

In [3]:
# Directory that holds the wide parquet panels for the CN ETF universe.
# NOTE(review): presumably get_processed_dir() returns a path-like object that
# supports "/" joining — confirm against qresearch.data.utils.
DATA_DIR = get_processed_dir() / "data_cn_etf_universe"  # change to your parquet directory

In [9]:
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

from sklearn.cluster import KMeans


# =========================
# Config
# =========================
@dataclass(frozen=True)
class UniverseDiagConfig:
    """Immutable settings for the universe diagnosis (availability, correlation, clustering)."""

    report_start: str = "2015-01-01"   # availability statistics start from this date
    report_end: Optional[str] = None   # None -> close.index.max()

    window: int = 240                 # corr/cluster use the most recent `window` trading days
    min_obs_window: int = 120         # minimum non-null returns required inside the window (too few is unstable)

    universe_size: int = 10           # target size of the suggested pool (also used as K for KMeans)

    corr_prune_thr: float = 0.90      # threshold for the second de-duplication pass over the representatives

    # KMeans
    random_state: int = 42
    n_init: int = 20
    zscore_per_asset: bool = True     # z-score each asset's return series first; True is strongly recommended


# =========================
# Core utilities
# =========================
def _to_simple_returns(close: pd.DataFrame) -> pd.DataFrame:
    """
    Return daily simple returns of a wide close-price panel.

    Prices are sorted by index first. Missing prices are NOT forward-filled
    (``fill_method=None``), the first row is set to 0.0 for every column, and
    infinite returns (e.g. after a zero price) are replaced by NaN.

    An empty input yields an empty frame instead of raising.
    """
    close = close.sort_index()
    ret = close.pct_change(fill_method=None)
    # The first row has no previous price; it is defined as a 0.0 return.
    # Guarded because positional assignment fails on a frame with no rows.
    if len(ret.index) > 0:
        ret.iloc[0] = 0.0
    return ret.replace([np.inf, -np.inf], np.nan)


def _availability_table(
    close: pd.DataFrame,
    start: pd.Timestamp,
    end: pd.Timestamp,
) -> pd.DataFrame:
    """
    Build the per-ticker availability table over ``[start, end]``.

    - The denominator of ``missing_rate`` is the number of rows of ``close``
      inside ``[start, end]`` (the shared calendar), not the ticker's own span.
    - A ticker that misses a day or two in the middle counts those as missing.
    - A ticker that starts late has its whole early gap counted as missing,
      which is why late-listed tickers show a large ``missing_rate``.

    Returns:
        DataFrame with columns ``ticker``, ``start_date``, ``end_date``,
        ``n_obs``, ``missing_rate``, sorted by ``missing_rate``, then
        ``start_date``, then ``ticker`` (all ascending). When ``close`` has no
        columns the result is an empty frame with that same schema.
    """
    columns = ["ticker", "start_date", "end_date", "n_obs", "missing_rate"]

    px = close.loc[start:end]
    total_days = len(px.index)

    rows = []
    for t in px.columns:
        s = px[t]
        n_obs = int(s.notna().sum())
        miss = int(s.isna().sum())

        first_valid = s.first_valid_index()
        last_valid = s.last_valid_index()

        rows.append({
            "ticker": t,
            # Dates are stored as plain ``date`` objects; NaT marks "never observed".
            "start_date": pd.to_datetime(first_valid).date() if first_valid is not None else pd.NaT,
            "end_date": pd.to_datetime(last_valid).date() if last_valid is not None else pd.NaT,
            "n_obs": n_obs,
            "missing_rate": float(miss / total_days) if total_days > 0 else np.nan,
        })

    # Passing the schema explicitly keeps the sort keys present even when
    # there are no tickers at all.
    table = pd.DataFrame(rows, columns=columns)
    return table.sort_values(
        ["missing_rate", "start_date", "ticker"],
        ascending=[True, True, True],
    ).reset_index(drop=True)


def _corr_summary(ret_win: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Summarise the pairwise correlation structure of a return window.

    Returns:
        A tuple ``(summary, top_pairs, corr)``:
        - ``summary``: one row per ticker with ``mean_corr`` / ``max_corr`` /
          ``min_corr`` against all OTHER tickers and ``n_pairs`` (number of
          non-null off-diagonal correlations), sorted by ``mean_corr``
          descending.
        - ``top_pairs``: up to 50 ticker pairs (columns ``a``, ``b``,
          ``corr``) with the highest correlation, descending.
        - ``corr``: the full correlation matrix from ``DataFrame.corr()``.
    """
    corr = ret_win.corr()  # pairwise complete
    tickers = corr.columns.tolist()

    # Mask the diagonal on an explicit writable copy: writing through
    # ``DataFrame.values`` fails when pandas hands back a read-only view
    # (Copy-on-Write).
    off_diag = corr.to_numpy(dtype=float, copy=True)
    np.fill_diagonal(off_diag, np.nan)
    corr2 = pd.DataFrame(off_diag, index=corr.index, columns=corr.columns)

    summary = pd.DataFrame({
        "ticker": corr2.columns,
        "mean_corr": corr2.mean(axis=1, skipna=True).values,
        "max_corr": corr2.max(axis=1, skipna=True).values,
        "min_corr": corr2.min(axis=1, skipna=True).values,
        "n_pairs": corr2.notna().sum(axis=1).values,
    }).sort_values("mean_corr", ascending=False).reset_index(drop=True)

    # Each unordered pair once: strict upper triangle, row-major order.
    upper_i, upper_j = np.triu_indices(len(tickers), k=1)
    pairs = [
        (tickers[i], tickers[j], float(off_diag[i, j]))
        for i, j in zip(upper_i.tolist(), upper_j.tolist())
        if not np.isnan(off_diag[i, j])
    ]

    top_pairs = (
        pd.DataFrame(pairs, columns=["a", "b", "corr"])
        .sort_values("corr", ascending=False)
        .head(50)
        .reset_index(drop=True)
    )
    return summary, top_pairs, corr


def _zscore_cols(x: pd.DataFrame) -> pd.DataFrame:
    """
    Z-score every column of ``x`` using the population std (``ddof=0``).

    Columns with zero standard deviation, and any NaN / infinite result,
    come out as 0.0.
    """
    col_mean = x.mean(axis=0)
    # A zero std is turned into NaN so the division yields NaN, not +/-inf.
    col_std = x.std(axis=0, ddof=0).replace(0.0, np.nan)

    scaled = x.sub(col_mean, axis="columns").div(col_std, axis="columns")
    scaled = scaled.replace([np.inf, -np.inf], np.nan)
    return scaled.fillna(0.0)


def _cluster_kmeans(
    ret_win: pd.DataFrame,
    k: int,
    cfg: UniverseDiagConfig,
) -> pd.Series:
    """
    Cluster the tickers of a return window with KMeans.

    - Features: each ticker's return series over the window (one row per
      ticker after transposing, one feature per date).
    - Missing returns are filled with 0 first.
    - Optionally each ticker's series is z-scored (``cfg.zscore_per_asset``).

    Returns:
        Series named ``cluster`` indexed by ticker, with labels 1..k.
    """
    filled = ret_win.copy().fillna(0.0)
    features = _zscore_cols(filled) if cfg.zscore_per_asset else filled

    # KMeans expects samples in rows: shape (n_assets, window).
    per_asset = features.T

    model = KMeans(
        n_clusters=k,
        random_state=cfg.random_state,
        n_init=cfg.n_init,
    )
    zero_based = model.fit_predict(per_asset.values)  # labels 0..k-1

    # Shift to 1-based labels.
    return pd.Series(zero_based + 1, index=per_asset.index, name="cluster")


def _pick_cluster_representatives(
    clusters: pd.Series,
    availability: pd.DataFrame,
    size: int,
) -> List[str]:
    """
    Pick one representative ticker per cluster: the one with the earliest
    ``start_date`` (ticker name breaks ties).

    When the number of clusters differs from ``size``:
      - more clusters than ``size``: only the ``size`` clusters with the most
        members are used;
      - fewer clusters than ``size``: after one representative per cluster,
        the remaining slots are filled from the other tickers, earliest
        ``start_date`` first.
    """
    by_ticker = availability.set_index("ticker")

    def _earliest_first(names):
        # Order the given tickers by parsed start_date, then ticker;
        # unparseable dates become NaT via errors="coerce".
        frame = by_ticker.loc[names].copy()
        frame["start_date_ts"] = pd.to_datetime(frame["start_date"], errors="coerce")
        frame = frame.sort_values(["start_date_ts", "ticker"], ascending=[True, True])
        return frame.index.tolist()

    # Cluster ids, largest cluster first; truncate only when there are too many.
    ranked_ids = clusters.value_counts().sort_values(ascending=False).index.tolist()
    chosen_ids = ranked_ids[:size] if len(ranked_ids) > size else ranked_ids

    reps = [
        _earliest_first(clusters[clusters == cid].index.tolist())[0]
        for cid in chosen_ids
    ]

    # Not enough representatives: top up from everything not yet chosen.
    shortfall = size - len(reps)
    if shortfall > 0:
        leftovers = [t for t in clusters.index if t not in reps]
        reps.extend(_earliest_first(leftovers)[:shortfall])

    return reps


def _corr_prune_reps(
    reps: List[str],
    corr: pd.DataFrame,
    availability: pd.DataFrame,
    thr: float,
) -> List[str]:
    """
    在 reps 上再做一次去冗余：
    若 corr(a,b) > thr，保留 start_date 更早者。
    """
    av = availability.set_index("ticker")
    sd = pd.to_datetime(av["start_date"], errors="coerce")

    keep = set(reps)
    reps_list = list(reps)

    pairs = []
    for i in range(len(reps_list)):
        for j in range(i + 1, len(reps_list)):
            a, b = reps_list[i], reps_list[j]
            c = corr.loc[a, b]
            if pd.notna(c) and c > thr:
                pairs.append((a, b, float(c)))
    pairs.sort(key=lambda x: x[2], reverse=True)

    for a, b, c in pairs:
        if a not in keep or b not in keep:
            continue
        sda = sd.get(a, pd.NaT)
        sdb = sd.get(b, pd.NaT)

        # 保留更早 start_date；缺失则保留有日期者
        if pd.isna(sda) and pd.isna(sdb):
            drop = b
        elif pd.isna(sda):
            drop = a
        elif pd.isna(sdb):
            drop = b
        else:
            drop = b if sda <= sdb else a

        keep.remove(drop)

    # 保持原先顺序输出
    out = [t for t in reps if t in keep]
    return out


def _rolling_26w_corr_for_pair(
    close: pd.DataFrame,
    a: str,
    b: str,
    freq: str = "W-FRI",
    window: int = 26,
) -> pd.Series:
    """
    Rolling correlation between two tickers on resampled (weekly by default)
    compounded returns.

    Daily simple returns (no forward fill; NaN / infinite values set to 0)
    are compounded within each ``freq`` bucket, then the ``window``-period
    rolling correlation of ``a`` against ``b`` is computed. Leading and other
    NaN values are dropped from the result.
    """
    prices = close[[a, b]].copy().sort_index()

    daily = prices.pct_change(fill_method=None)
    daily = daily.replace([np.inf, -np.inf], np.nan).fillna(0.0)

    def _bucket_return(chunk: pd.Series) -> float:
        # Compound the returns inside one bucket: prod(1 + r) - 1.
        growth = (1.0 + chunk).prod()
        return float(growth - 1.0)

    periodic = daily.resample(freq).apply(_bucket_return)

    rolling = periodic[a].rolling(window).corr(periodic[b])
    return rolling.dropna()


# =========================
# Main: universe diagnosis
# =========================
def diagnose_universe(close: pd.DataFrame, cfg: UniverseDiagConfig) -> Dict[str, object]:
    """
    Run the full universe diagnosis on a wide close-price panel.

    Steps: availability table over the report range, daily returns restricted
    to the last ``cfg.window`` rows, a minimum-observation filter, correlation
    summary, KMeans clustering (K = ``cfg.universe_size``, capped at the
    number of surviving tickers), one representative per cluster, a
    correlation-based prune of those representatives, and a rolling 26-week
    correlation for the most correlated pair.

    Returns:
        Dict with keys ``availability``, ``corr_summary``, ``top_pairs``,
        ``clusters``, ``suggested_universe_raw``, ``suggested_universe``,
        ``top_pair`` and ``rolling_26w_corr`` (the last two are ``None`` when
        there is no correlated pair to report).

    Raises:
        ValueError: If no ticker has at least ``cfg.min_obs_window`` non-null
            returns inside the window, so there is nothing to cluster.
    """
    close = close.sort_index()

    report_start = pd.to_datetime(cfg.report_start)
    report_end = pd.to_datetime(cfg.report_end) if cfg.report_end else close.index.max()

    # 1) Availability
    availability = _availability_table(close, report_start, report_end)

    # 2) Returns + window slice
    ret = _to_simple_returns(close)
    ret_win = ret.iloc[-cfg.window:].copy()

    # 3) Minimum observations inside the window (too few samples make
    #    correlation / clustering unstable)
    obs = ret_win.notna().sum(axis=0)
    keep_cols = obs[obs >= cfg.min_obs_window].index.tolist()
    ret_win = ret_win[keep_cols]

    # Fail early with an actionable message instead of asking KMeans for
    # zero clusters on zero samples.
    if ret_win.shape[1] == 0:
        raise ValueError(
            "No ticker has at least "
            f"{cfg.min_obs_window} non-null returns in the last {cfg.window} rows; "
            "nothing to cluster."
        )

    # 4) Correlation summary + top pairs
    corr_summary, top_pairs, corr = _corr_summary(ret_win)

    # 5) Clustering (K = universe_size, capped when there are fewer tickers)
    k = min(cfg.universe_size, ret_win.shape[1])
    clusters = _cluster_kmeans(ret_win, k=k, cfg=cfg)

    # 6) Representatives: earliest start_date in each cluster
    availability_kept = availability[availability["ticker"].isin(ret_win.columns)].copy()
    reps_raw = _pick_cluster_representatives(clusters, availability_kept, size=k)

    # 7) Correlation prune on the representatives
    reps_pruned = _corr_prune_reps(
        reps=reps_raw,
        corr=corr,
        availability=availability_kept,
        thr=cfg.corr_prune_thr,
    )

    # 8) Rolling 26W correlation for the most correlated pair
    rolling_corr = None
    top_pair = None
    if len(top_pairs) > 0:
        a, b = top_pairs.loc[0, "a"], top_pairs.loc[0, "b"]
        top_pair = (a, b)
        rolling_corr = _rolling_26w_corr_for_pair(close, a, b, freq="W-FRI", window=26)

    return {
        "availability": availability,
        "corr_summary": corr_summary,
        "top_pairs": top_pairs,
        "clusters": clusters.sort_values(),
        "suggested_universe_raw": reps_raw,
        "suggested_universe": reps_pruned,
        "top_pair": top_pair,
        "rolling_26w_corr": rolling_corr,
    }


# =========================
# Pretty print (exactly align your log blocks)
# =========================
def print_universe_report(out: Dict[str, object], universe_size: int = 10) -> None:
    """
    Print the diagnosis result produced by ``diagnose_universe`` as text blocks.

    Sections, in order: availability (first 15 rows), correlation summary
    (first 15), top correlated pairs (first 20), cluster assignment, the
    suggested universe truncated to ``universe_size``, and — when both a top
    pair and its rolling correlation are present — the last 10 rolling values.
    """
    # Look everything up before printing anything.
    availability: pd.DataFrame = out["availability"]
    corr_summary: pd.DataFrame = out["corr_summary"]
    top_pairs: pd.DataFrame = out["top_pairs"]
    clusters: pd.Series = out["clusters"]
    suggested: List[str] = out["suggested_universe"]

    tabular_sections = (
        ("\n=== Availability (top 15) ===", availability, 15),
        ("\n=== Corr summary (most redundant first) ===", corr_summary, 15),
        ("\n=== Top correlated pairs ===", top_pairs, 20),
    )
    for heading, table, limit in tabular_sections:
        print(heading)
        print(table.head(limit).to_string(index=True))

    print("\n=== Cluster assignment ===")
    print(clusters.to_string())

    # Printed as a plain Python list: ['...', '...']
    shortlist = suggested[:universe_size]
    print(f"\n=== Suggested universe (size={len(shortlist)}) ===")
    print(shortlist)

    # Rolling correlation block is optional.
    pair = out.get("top_pair")
    rolling = out.get("rolling_26w_corr")
    if pair and rolling is not None:
        a, b = pair
        print(f"\nRolling 26W corr for top pair {a}-{b}:")
        print(rolling.tail(10).to_string())


# =========================
# Example usage
# =========================
if __name__ == "__main__":
    # Usage sketch only: nothing below loads data or runs the diagnosis.
    #
    # You have to supply ``close`` yourself:
    #   close: pd.DataFrame
    #   - index: trading dates
    #   - columns: tickers such as '510880', '159928', ...
    # For example:
    #   close = md.close  # the close panel of your MarketData
    cfg = UniverseDiagConfig(
        report_start="2015-01-01", report_end=None,
        window=240, min_obs_window=120,
        universe_size=10,
        corr_prune_thr=0.90,
        random_state=42, n_init=20, zscore_per_asset=True,
    )

    # out = diagnose_universe(close, cfg)
    # print_universe_report(out, universe_size=cfg.universe_size)


PortfolioBacktestResult(gross_ret=date
2019-01-02    0.000000
2019-01-03    0.000000
2019-01-04    0.000000
2019-01-07    0.000000
2019-01-08    0.000000
                ...   
2026-01-15   -0.021169
2026-01-16    0.004831
2026-01-19   -0.003523
2026-01-20   -0.008081
2026-01-21    0.047352
Length: 1712, dtype: float64, net_ret=date
2019-01-02    0.000000
2019-01-03    0.000000
2019-01-04    0.000000
2019-01-07    0.000000
2019-01-08    0.000000
                ...   
2026-01-15   -0.021569
2026-01-16    0.004431
2026-01-19   -0.003923
2026-01-20   -0.008081
2026-01-21    0.047352
Length: 1712, dtype: float64, equity_gross=date
2019-01-02    1.000000
2019-01-03    1.000000
2019-01-04    1.000000
2019-01-07    1.000000
2019-01-08    1.000000
                ...   
2026-01-15    1.075664
2026-01-16    1.080861
2026-01-19    1.077053
2026-01-20    1.068349
2026-01-21    1.118938
Length: 1712, dtype: float64, equity_net=date
2019-01-02    1.000000
2019-01-03    1.000000
2019-01-04    1.000