In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Load the public mashup summary table.
df = pd.read_csv("mashup_summary_public.csv")

# Keep only the depression-rate rows that come from the student or
# professional datasets; copy so later edits never write back into `df`.
is_depression_rate = df["metric_readable"].eq("depression_rate")
is_target_dataset = df["source_dataset"].isin(["student", "professional"])
d = df.loc[is_depression_rate & is_target_dataset].copy()

# --------
# weighted average (weighted by n)
# --------
def weighted_rate(x):
    """Return the n-weighted mean of ``rate`` for the rows in ``x``.

    Parameters
    ----------
    x : pandas.DataFrame
        Must provide a ``rate`` column and an ``n`` column used as weights.

    Returns
    -------
    float
        ``sum(rate * n) / sum(n)`` taken over the rows where both values are
        present, or NaN when no weight remains (empty input, all values
        missing, or weights summing to zero).
    """
    # pandas sums skip NaN, so a row with a missing rate would vanish from
    # the numerator while its n still inflated the denominator. Restrict
    # both sums to the same set of complete rows.
    valid = x["rate"].notna() & x["n"].notna()
    weights = x.loc[valid, "n"]
    total_weight = weights.sum()
    if total_weight == 0:
        # Avoid a 0/0 division (and its RuntimeWarning); the mean is undefined.
        return float("nan")
    return (x.loc[valid, "rate"] * weights).sum() / total_weight

# 1) Collapse the remaining dimensions (e.g. age_group / diet_group) so that
#    each dataset x family_history x financial_bucket cell has a single rate.
#    Selecting ["rate", "n"] before apply() keeps the grouping columns out of
#    the per-group frames; letting apply() operate on the grouping columns is
#    deprecated in recent pandas and emits a DeprecationWarning.
group_keys = ["source_dataset", "family_history_flag", "financial_bucket"]
agg = (d.groupby(group_keys)[["rate", "n"]]
         .apply(lambda g: pd.Series({
             "rate": weighted_rate(g),
             "n": g["n"].sum()
         }))
         # The keys live in the index here; turn them back into columns.
         .reset_index())

# 2) Reshape to one row per (dataset, family_history) pair, with the
#    financial buckets spread across the columns in Low / Medium / High order.
bucket_order = ["Low", "Medium", "High"]
rate_by_bucket = agg.pivot_table(
    values="rate",
    index=["source_dataset", "family_history_flag"],
    columns="financial_bucket",
    aggfunc="first",
)
p = rate_by_bucket.reindex(columns=bucket_order).reset_index()

# 3) Risk difference (RD) per row: High-bucket rate minus Low-bucket rate.
p["RD_high_minus_low"] = p["High"].sub(p["Low"])

# A row missing either the Low or the High bucket has a NaN difference;
# keep only the rows where the difference could be computed.
has_rd = p["RD_high_minus_low"].notna()
plot_df = p.loc[has_rd].copy()

# 4) Plot: x axis = dataset, one bar per family-history group, y = RD.
flag_labels = {0: "No family history", 1: "Family history"}
plot_df["family_history_label"] = plot_df["family_history_flag"].map(flag_labels)

# One row per dataset, one column per family-history label, in a fixed
# student -> professional order.
wide = plot_df.pivot(
    index="source_dataset",
    columns="family_history_label",
    values="RD_high_minus_low",
).reindex(index=["student", "professional"])

ax = wide.plot(kind="bar")
ax.set_title("Risk Difference (RD) = High financial - Low financial\n(Depression rate; weighted by n)")
ax.set_xlabel("source_dataset")
ax.set_ylabel("RD (High - Low)")
ax.axhline(0, linewidth=1)  # zero baseline

# Finish and save through the figure that owns the axes, then release it.
fig = ax.get_figure()
fig.tight_layout()
fig.savefig("plot_RD_high_minus_low.png", dpi=200)
plt.close(fig)

print("Done! Generated: plot_RD_high_minus_low.png")
report_columns = ["source_dataset", "family_history_flag", "Low", "High", "RD_high_minus_low"]
print(plot_df[report_columns])


Done! Generated: plot_RD_high_minus_low.png
financial_bucket source_dataset  family_history_flag       Low      High  \
0                  professional                    0  0.003571  0.044534   
1                  professional                    1  0.007246  0.064039   
2                       student                    0  0.357063  0.732996   
3                       student                    1  0.392872  0.782651   

financial_bucket  RD_high_minus_low  
0                          0.040963  
1                          0.056793  
2                          0.375933  
3                          0.389779  


  .apply(lambda g: pd.Series({
