
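# 🧭 Causal EDA Checklist + Template Functions (Runnable Notebook)
This notebook provides:
- a **causal EDA checklist** (timeline, variable roles, missingness, baseline balance, overlap, distributions, collinearity, survival-data features, and so on);
- a set of **reusable template functions** (SMD calculation, balance table, PS overlap plot, and the one-call quick check `causal_eda_quickcheck`);
- a **demo dataset** (synthetic data from DoWhy) and a **placeholder section for your own data** (replace the DataFrame and column names as instructed).

> How to use: run the "Utility Functions" cell first, then pick the demo data or substitute your own, and finally call `causal_eda_quickcheck(...)`.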



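## ✅ Causal EDA Checklist (print it and stick it next to your monitor)
1) **Temporal order**: treatment/exposure precedes the outcome; covariates come from the pre-treatment window.  
2) **Variable roles**: distinguish confounders, mediators, and colliders; keep mediators and colliders out of the PS model and the regression adjustment set.  
3) **Missingness mechanism**: make an initial MCAR/MAR/MNAR assessment; define an imputation strategy where needed (do not use outcome information to impute pre-treatment variables).  
4) **Baseline balance**: SMD table + Love plot; record the imbalanced covariates as targets for matching, weighting, or adjustment.  
5) **Overlap**: the PS distributions of the two arms overlap sufficiently; have a strategy for extreme PS values (≈0 or ≈1), such as truncation, trimming, or stratification.  
6) **Distribution and scale**: apply log transforms, standardization, or mild winsorizing where needed; keep clinical interpretability in mind.  
7) **Collinearity**: correlation matrix + VIF; merge variables, reduce dimensionality, or regularize where needed.  
8) **Sampling and generalizability**: inclusion/exclusion criteria; differences across time periods, departments, and hospitals (stratify the analysis where needed).  
9) **Survival-data features** (for time-to-event data): identify censoring and competing risks; align "treatment start" with "start of follow-up".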
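
Checklist item 7 mentions VIF, which the utility functions in this notebook do not compute. The following is a minimal sketch, assuming purely numeric covariates and rows with missing values dropped; `vif_table` and the toy column names are hypothetical names introduced here for illustration.

```python
import numpy as np
import pandas as pd

def vif_table(X: pd.DataFrame) -> pd.DataFrame:
    """Variance inflation factor per column: VIF_j = 1 / (1 - R^2_j)."""
    X = X.dropna().astype(float)
    rows = []
    for col in X.columns:
        y = X[col].to_numpy()
        others = X.drop(columns=col).to_numpy()
        # Regress this column on an intercept plus all other covariates
        A = np.column_stack([np.ones(len(X)), others])
        coef, *_ = np.linalg.lstsq(A, y, rcond=None)
        resid = y - A @ coef
        ss_tot = ((y - y.mean()) ** 2).sum()
        r2 = 1.0 - (resid ** 2).sum() / ss_tot
        rows.append((col, 1.0 / max(1.0 - r2, 1e-12)))
    out = pd.DataFrame(rows, columns=["variable", "VIF"])
    return out.sort_values("VIF", ascending=False)

# Toy data: "a_plus_noise" is almost a copy of "a", "b" is independent
rng = np.random.default_rng(0)
a = rng.normal(size=500)
b = rng.normal(size=500)
toy = pd.DataFrame({"a": a, "b": b, "a_plus_noise": a + 0.1 * rng.normal(size=500)})
vif = vif_table(toy)
print(vif)
```

The two near-duplicate columns should show large VIFs while the independent column stays close to 1. Which cutoff counts as "too high" is a judgment call and is not fixed by this sketch.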



## 🧰 Utility Functions (SMD, Balance Table, PS Overlap, Quick Check)
Run this cell first; afterwards the functions can be called on any dataset.
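
For reference, one common convention for the standardized mean difference (SMD), with the treated and control groups indexed by $t$ and $c$. Other denominators are in use as well, so read this as one convention and not the only definition:

$$
\mathrm{SMD}_{\text{continuous}} = \frac{\bar{x}_t - \bar{x}_c}{\sqrt{(s_t^2 + s_c^2)/2}},
\qquad
\mathrm{SMD}_{\text{binary}} = \frac{\hat{p}_t - \hat{p}_c}{\sqrt{\left[\hat{p}_t(1-\hat{p}_t) + \hat{p}_c(1-\hat{p}_c)\right]/2}}
$$

Here $\bar{x}$ and $s^2$ are the group mean and sample variance, and $\hat{p}$ is the group proportion. The default reference lines drawn by `love_plot` sit at $|\mathrm{SMD}| = 0.1$ and $0.2$.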


In [None]:

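import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display  # explicit import so the helpers also work when moved into a module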

def smd_cont(x_t, x_c):
    """SMD for continuous variable"""
    x_t = pd.to_numeric(x_t, errors="coerce")
    x_c = pd.to_numeric(x_c, errors="coerce")
    mu_t, mu_c = np.nanmean(x_t), np.nanmean(x_c)
    sd_pool = np.sqrt((np.nanvar(x_t, ddof=1) + np.nanvar(x_c, ddof=1)) / 2)
    return (mu_t - mu_c) / (sd_pool + 1e-12)

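def smd_binary(x_t, x_c):
    """SMD for a binary (0/1) variable: difference in proportions over the averaged-variance SD."""
    x_t = pd.to_numeric(x_t, errors="coerce").astype(float)
    x_c = pd.to_numeric(x_c, errors="coerce").astype(float)
    p1, p0 = np.nanmean(x_t), np.nanmean(x_c)
    # Average the two group variances (as smd_cont does) instead of using the
    # combined-sample proportion, which depends on the relative group sizes
    denom = np.sqrt((p1 * (1 - p1) + p0 * (1 - p0)) / 2)
    return (p1 - p0) / (denom + 1e-12)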

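def balance_table(df, treat_col, covs, cat_covs=None):
    """Baseline balance table with SMDs (one row per level for multi-level categoricals)."""
    cat_covs = set(cat_covs or [])
    tmask = df[treat_col] == 1
    cmask = df[treat_col] == 0
    rows = []
    for v in covs:
        s = df[v]
        is_categorical = (v in cat_covs) or s.dtype == object or isinstance(s.dtype, pd.CategoricalDtype)
        if isinstance(s.dtype, pd.CategoricalDtype):
            s = s.astype(object)  # plain values so numeric coercion behaves predictably
        if s.dropna().isin([0, 1]).all():
            rows.append((v, smd_binary(s[tmask], s[cmask])))
        elif is_categorical:
            # Multi-level categorical: one 0/1 indicator per level, missing values kept as NaN
            levels = pd.get_dummies(s, prefix=v, dtype=float)
            levels.loc[s.isna(), :] = np.nan
            for col in levels.columns:
                rows.append((col, smd_binary(levels.loc[tmask, col], levels.loc[cmask, col])))
        else:
            rows.append((v, smd_cont(s[tmask], s[cmask])))
    out = pd.DataFrame(rows, columns=["variable", "SMD"])
    out["|SMD|"] = out["SMD"].abs()
    return out.sort_values("|SMD|", ascending=False)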

def love_plot(smd_df, cutoff_minor=0.1, cutoff_major=0.2):
    """Horizontal bar chart of |SMD| (Love plot)."""
    smd_sorted = smd_df.sort_values("|SMD|")
    plt.figure(figsize=(6, max(2.0, 0.4 * len(smd_sorted))))  # minimum height so short tables stay readable
    plt.barh(smd_sorted["variable"], smd_sorted["|SMD|"])
    plt.axvline(cutoff_minor, linestyle="--")
    plt.axvline(cutoff_major, linestyle="--")
    plt.xlabel("|SMD|"); plt.title("Baseline Imbalance (Pre-adjustment)")
    plt.tight_layout()
    plt.show()

def plot_ps_overlap(ps_t, ps_c, bins=30):
    """Propensity score overlap plot."""
    plt.figure()
    plt.hist(ps_t, bins=bins, alpha=0.6, label="Treatment")
    plt.hist(ps_c, bins=bins, alpha=0.6, label="Control")
    plt.xlabel("Propensity Score"); plt.ylabel("Count"); plt.legend(); plt.title("PS Overlap")
    plt.show()

def causal_eda_quickcheck(df, treat='treatment', outcome='outcome', pre_covs=None, show_plots=True):
    """Quick causal EDA: size/missing/num desc + SMD + PS overlap (if pre_covs provided)."""
    from sklearn.preprocessing import StandardScaler
    from sklearn.linear_model import LogisticRegression

    print("== Basic info ==")
    print("Shape:", df.shape)
    subset = [c for c in [treat, outcome] if c in df.columns]
    if subset:
        display(df[subset].head())
    print("\nMissing (Top 10):")
    print(df.isnull().sum().sort_values(ascending=False).head(10))

    print("\nNumeric describe (head):")
    numdesc = df.select_dtypes(include=[np.number]).describe().T
    display(numdesc.head(10))

    if pre_covs:
        print("\n== Baseline balance (SMD) ==")
        smd_df = balance_table(df, treat, pre_covs)
        display(smd_df.head(15))
        if show_plots:
            love_plot(smd_df)

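        print("\n== Propensity score overlap ==")
        # dtype=float keeps the dummy columns numeric; pandas >= 2.0 otherwise returns
        # bool dummies, which a numeric-only filter would silently drop
        X = pd.get_dummies(df[pre_covs], drop_first=True, dtype=float)
        X = X.select_dtypes(include=["number", "bool"]).astype(float)
        X = X.fillna(X.median(numeric_only=True))
        scaler = StandardScaler()
        Xs = scaler.fit_transform(X)
        y = df[treat].astype(int).to_numpy()
        ps = LogisticRegression(max_iter=300).fit(Xs, y).predict_proba(Xs)[:, 1]
        if show_plots:
            # Index with the NumPy array y, not a pandas Series, to avoid index-alignment surprises
            plot_ps_overlap(ps[y == 1], ps[y == 0])
        print("PS percentiles:", np.percentile(ps, [0,1,5,50,95,99,100]))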
        return smd_df, ps
    else:
        print("No pre_covs provided; skip SMD & PS checks.")
        return None, None



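## 🧪 Demo Data (DoWhy synthetic data, safe to practice on)
Note: run the cell below to generate a synthetic dataset with a treatment, an outcome, and several confounders (the actual column names are read from the dictionary DoWhy returns);
then call `causal_eda_quickcheck(...)` to walk through the full causal EDA workflow.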


In [None]:

import pandas as pd

try:
    import dowhy.datasets as dwd
    data = dwd.linear_dataset(
        beta=8, num_common_causes=6, num_instruments=1, num_samples=1200,
        treatment_is_binary=True, stddev_treatment_noise=1.0, stddev_outcome_noise=1.0, seed=2025
    )
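    df_demo = data["df"].copy()
    # DoWhy may return the treatment name(s) as a list; unwrap to a single column name
    T = data["treatment_name"]
    T = T[0] if isinstance(T, (list, tuple)) else T
    Y = data["outcome_name"]
    Y = Y[0] if isinstance(Y, (list, tuple)) else Y
    C = data["common_causes_names"]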
    print("Demo shape:", df_demo.shape)
    print("treatment:", T, "| outcome:", Y, "| covariates:", C)
    display(df_demo.head())
except Exception as e:
    print("DoWhy not installed or failed to generate:", e)
    df_demo, T, Y, C = None, None, None, None
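
If DoWhy is unavailable, `df_demo` stays `None` and the demo is skipped. One possible stand-in is a hand-rolled synthetic dataset with the same ingredients: a binary treatment, a linear outcome, and confounders that affect both. This is a sketch, not DoWhy's generator; `make_fallback_demo`, the coefficient ranges, and the column names are assumptions made for illustration.

```python
import numpy as np
import pandas as pd

def make_fallback_demo(n=1200, n_confounders=6, beta=8.0, seed=2025):
    """Synthetic confounded data: W -> treatment, W -> outcome, treatment -> outcome."""
    rng = np.random.default_rng(seed)
    W = rng.normal(size=(n, n_confounders))
    w_names = [f"W{i}" for i in range(n_confounders)]
    gamma = rng.uniform(0.2, 0.8, size=n_confounders)  # confounder -> treatment (assumed range)
    delta = rng.uniform(0.5, 1.5, size=n_confounders)  # confounder -> outcome (assumed range)
    p_treat = 1.0 / (1.0 + np.exp(-(W @ gamma)))       # logistic assignment model
    treatment = rng.binomial(1, p_treat)
    outcome = beta * treatment + W @ delta + rng.normal(size=n)
    df = pd.DataFrame(W, columns=w_names)
    df["treatment"] = treatment
    df["outcome"] = outcome
    return df, "treatment", "outcome", w_names

df_fallback, T_fb, Y_fb, C_fb = make_fallback_demo()
print(df_fallback.shape)
print(df_fallback[T_fb].value_counts())
```

Calling `causal_eda_quickcheck(df_fallback, treat=T_fb, outcome=Y_fb, pre_covs=C_fb)` should then run the same checks as on the DoWhy data.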


In [None]:

# One-call quick check (SMD table + Love plot + PS overlap)
if df_demo is not None:
    smd_df, ps = causal_eda_quickcheck(df_demo, treat=T, outcome=Y, pre_covs=C, show_plots=True)
else:
    print("Skip demo quickcheck.")
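
The quick check prints PS percentiles but leaves the handling of extreme scores (checklist item 5) to you. A minimal sketch of one option follows: trim to the common-support range intersected with fixed bounds. `ps_overlap_summary` is a hypothetical helper, and the 0.05/0.95 bounds are illustrative assumptions, not recommendations from this notebook.

```python
import numpy as np

def ps_overlap_summary(ps, treat, lo=0.05, hi=0.95):
    """Flag units outside [lo, hi] or outside the PS range shared by both arms."""
    ps = np.asarray(ps, dtype=float)
    treat = np.asarray(treat).astype(int)
    ps_t, ps_c = ps[treat == 1], ps[treat == 0]
    # Common support: the PS interval observed in both arms
    support_lo = max(ps_t.min(), ps_c.min())
    support_hi = min(ps_t.max(), ps_c.max())
    keep = (ps >= max(lo, support_lo)) & (ps <= min(hi, support_hi))
    return {
        "common_support": (support_lo, support_hi),
        "n_total": int(ps.size),
        "n_kept": int(keep.sum()),
        "n_dropped_treated": int((~keep & (treat == 1)).sum()),
        "n_dropped_control": int((~keep & (treat == 0)).sum()),
        "keep_mask": keep,
    }

# Toy propensity scores (synthetic, for illustration only)
rng = np.random.default_rng(0)
ps_toy = rng.beta(2, 2, size=1000)
treat_toy = rng.binomial(1, ps_toy)
summary = ps_overlap_summary(ps_toy, treat_toy)
print({k: v for k, v in summary.items() if k != "keep_mask"})
```

On the demo data the call would be `ps_overlap_summary(ps, df_demo[T])`, assuming `T` is a single column name. Trimming changes the population the estimate applies to, so it is worth reporting how many units were dropped from each arm.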



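## 🧩 Your Data (replace with your own DataFrame and column names)
Name your DataFrame `df` and specify:
- `treat`: the treatment/exposure column name (0/1, or convertible to 0/1);
- `outcome`: the outcome column name (optional; this EDA does not require it);
- `pre_covs`: the list of **pre-treatment** covariates (used for the SMDs and the PS model).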
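
Before calling the quick check on your own data, it can help to validate the column names and coerce the treatment to 0/1. The sketch below shows one way to do that; `prepare_inputs`, its `treat_map` argument, and the toy columns are hypothetical and not part of the utility functions above.

```python
import pandas as pd

def prepare_inputs(df, treat, pre_covs, outcome=None, treat_map=None):
    """Check that the columns exist and return a copy with the treatment coded as int 0/1."""
    missing = [c for c in [treat, *pre_covs] if c not in df.columns]
    if missing:
        raise KeyError(f"Columns not found: {missing}")
    if outcome is not None and outcome in pre_covs:
        raise ValueError("The outcome must not appear among the pre-treatment covariates")
    out = df.copy()
    if treat_map is not None:
        out[treat] = out[treat].map(treat_map)  # e.g. {"yes": 1, "no": 0}
    out[treat] = pd.to_numeric(out[treat], errors="coerce")
    bad = ~out[treat].isin([0, 1])              # unmapped or missing values end up here
    if bad.any():
        raise ValueError(f"{int(bad.sum())} rows have a treatment value outside 0/1")
    out[treat] = out[treat].astype(int)
    return out

# Toy example with a string-coded treatment
toy = pd.DataFrame({"arm": ["yes", "no", "yes", "no"], "age": [61, 54, 70, 48]})
clean = prepare_inputs(toy, treat="arm", pre_covs=["age"], treat_map={"yes": 1, "no": 0})
print(clean)
```

Raising on unmapped or missing treatment values, instead of silently dropping those rows, is a deliberate choice here: the caller decides how to handle them.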


In [None]:

# === Example template ===
# import pandas as pd
# df = pd.read_csv("/path/to/your/medical_dataset.csv")
# treat = "treat_col_name"
# outcome = "outcome_col_name"   # optional
# pre_covs = ["age","sex","sofa","charlson","sbp","dbp"]  # replace with your pre-treatment covariates
# smd_df, ps = causal_eda_quickcheck(df, treat=treat, outcome=outcome, pre_covs=pre_covs, show_plots=True)
pass
