In [1]:
import pandas as pd
import matplotlib.pyplot as plt

# Load the public mash-up summary table.
df = pd.read_csv("mashup_summary_public.csv")

# Keep only depression_rate rows that belong to the student or professional
# datasets. The result is copied so later edits never write back into `df`.
is_depression_rate = df["metric_readable"].eq("depression_rate")
is_target_dataset = df["source_dataset"].isin(["student", "professional"])
d = df.loc[is_depression_rate & is_target_dataset].copy()

# --------
# Weighted average (weights = n):
# each row of mashup_summary is the rate of one subgroup and n is that
# subgroup's sample size, so pooled rates are weighted by n.
# --------
def weighted_rate(x):
    """Return the n-weighted mean of ``rate`` over the rows of ``x``.

    Parameters
    ----------
    x : pandas.DataFrame (or any mapping of aligned Series)
        Must provide numeric ``rate`` and ``n`` columns.

    Returns
    -------
    float
        ``sum(rate * n) / sum(n)`` computed over the rows where both
        ``rate`` and ``n`` are present. ``nan`` when no such row exists or
        when the total weight is zero.
    """
    rate = x["rate"]
    n = x["n"]

    # Only rows with both values may contribute. Counting the n of a row
    # whose rate is missing would inflate the denominator (the product
    # rate * n is skipped by .sum() for that row) and bias the result low.
    usable = rate.notna() & n.notna()
    rate = rate[usable]
    n = n[usable]

    total_n = n.sum()
    if total_n == 0:
        # No weight at all: the pooled rate is undefined.
        return float("nan")
    return (rate * n).sum() / total_n

# 1) Collapse every other dimension (age_group / diet_group, ...) so that one
#    row remains per dataset x family_history x financial_bucket, holding the
#    pooled (n-weighted) rate and the pooled sample size.
#    ["rate", "n"] is selected before .apply so the grouping columns are not
#    part of the frame handed to the lambda; recent pandas releases emit a
#    deprecation warning when apply operates on them. The group keys come
#    back as ordinary columns through reset_index().
agg = (d.groupby(["source_dataset", "family_history_flag", "financial_bucket"])[["rate", "n"]]
         .apply(lambda g: pd.Series({
             "rate": weighted_rate(g),
             "n": g["n"].sum()
         }))
         .reset_index())

# 2) Pivot: one row per dataset x family_history, one column per
#    financial_bucket, in the fixed order Low / Medium / High.
bucket_order = ["Low", "Medium", "High"]
rate_by_bucket = agg.pivot_table(
    index=["source_dataset", "family_history_flag"],
    columns="financial_bucket",
    values="rate",
    aggfunc="first",
)
p = rate_by_bucket.reindex(columns=bucket_order).reset_index()

# 3) Risk difference (RD): High rate minus Low rate.
p["RD_high_minus_low"] = p["High"].sub(p["Low"])

# A missing Low or High bucket leaves RD as NaN (that bucket has no data);
# only rows whose RD could be computed are kept for plotting.
has_rd = p["RD_high_minus_low"].notna()
plot_df = p[has_rd].copy()

# 4) Plot: x = dataset, bar group = family_history_flag, y = RD.
#    The 0/1 flag is mapped to readable legend labels. Any other value falls
#    back to its string form, so an unexpected coding does not become a NaN
#    column label in the pivot below.
flag_labels = {0: "No family history", 1: "Family history"}
plot_df["family_history_label"] = plot_df["family_history_flag"].map(
    lambda flag: flag_labels.get(flag, str(flag))
)

# Reshape to wide form (rows = dataset, columns = label) for grouped bars.
wide = plot_df.pivot(index="source_dataset", columns="family_history_label", values="RD_high_minus_low")
wide = wide.reindex(index=["student", "professional"])  # display order (adjust as needed)

# Draw on an explicitly created figure/axes rather than pyplot's implicit
# "current figure", so the baseline, layout, save and close all target this
# figure even if other figures are open in the kernel.
fig, ax = plt.subplots()
wide.plot(kind="bar", ax=ax)
ax.set_title("Risk Difference (RD) = High financial - Low financial\n(Depression rate; weighted by n)")
ax.set_xlabel("source_dataset")
ax.set_ylabel("RD (High - Low)")
ax.axhline(0, linewidth=1)  # zero baseline
fig.tight_layout()

fig.savefig("plot_RD_high_minus_low.png", dpi=200)
plt.close(fig)  # release the figure; the PNG on disk is the deliverable

print("Done! Generated: plot_RD_high_minus_low.png")
print(plot_df[["source_dataset","family_history_flag","Low","High","RD_high_minus_low"]])


Done! Generated: plot_RD_high_minus_low.png
financial_bucket source_dataset  family_history_flag       Low      High  \
0                  professional                    0  0.003571  0.044534   
1                  professional                    1  0.007246  0.064039   
2                       student                    0  0.357063  0.732996   
3                       student                    1  0.392872  0.782651   

financial_bucket  RD_high_minus_low  
0                          0.040963  
1                          0.056793  
2                          0.375933  
3                          0.389779  


  .apply(lambda g: pd.Series({
