# Propensity score-stratifiering

Denna notebook implementerar propensity score (PS)-baserad stratifiering för att balansera grupperna
CCI ≤ 5 (kontroll) och CCI > 5 (behandling). PS skattas med logistisk regression baserat på ålder,
kön och laboratorievärden (glukos, natrium samt log1p-transformerade njurmarkörer BUN och kreatinin).
Balans utvärderas med standardiserad medelvärdesskillnad (SMD) före och efter stratifiering. Mortalitet
rapporteras som rå risk och strata-standardiserad risk för 30, 90 och 365 dagar.


In [None]:
import sys, os
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
import statsmodels.api as sm 

In [None]:
# Locate the project root. When the notebook is started from notebooks/, step up one
# level so that both the local `src` package and the `sql/` folder resolve.
PROJECT_ROOT = Path.cwd().resolve()
if PROJECT_ROOT.name == "notebooks":
    PROJECT_ROOT = PROJECT_ROOT.parent
sys.path.insert(0, str(PROJECT_ROOT))

print("PROJECT_ROOT =", PROJECT_ROOT)
from src.connect_db import run_query, close_tunnel


def _report_frame(label, frame):
    """Print the column names and the row count of a freshly loaded query result."""
    print(f"{label} kolumner:", frame.columns.tolist())
    print(f"{label} antal rader:", len(frame))


# --- Load base data (cohort) ---
sql_path = PROJECT_ROOT / "sql" / "stroke_cci_ps_cohort.sql"
cohort_sql = sql_path.read_text(encoding="utf-8")
df_base = run_query(cohort_sql)
_report_frame("df_base", df_base)

# --- Load covariates (first 24 h) ---
sql_path = PROJECT_ROOT / "sql" / "stroke_covariates_first24h.sql"
cov_sql = sql_path.read_text(encoding="utf-8")
df_cov = run_query(cov_sql)
_report_frame("df_cov", df_cov)

### Baslinje: gruppstorlek och mortalitet före balansering

In [None]:
# Outcome flags sometimes arrive as bool; cast those to 0/1 ints so sums are death counts.
outcome_cols = ["died", "died_30d", "died_90d", "died_1y"]
for col in outcome_cols:
    is_bool = col in df_base.columns and df_base[col].dtype == "bool"
    if is_bool:
        df_base[col] = df_base[col].astype(int)

# Deaths per CCI group (0 = CCI ≤ 5, 1 = CCI > 5), with readable row labels.
by_cci = df_base.groupby("high_cci")
counts = by_cci[outcome_cols].sum().astype(int)
counts.index = counts.index.map({0: "CCI ≤ 5", 1: "CCI > 5"})

# Group sizes, assigned positionally because counts.index was relabelled above.
# Both results come from the same groupby, so the group order matches.
counts["n_patients"] = by_cci.size().astype(int).values

# Group size first, then the outcomes in their listed order.
counts = counts[["n_patients", *outcome_cols]]
display(counts)

## Datakvalitet: duplikat och saknade värden


In [None]:
def _count_dup_hadm(frame):
    """Return how many rows repeat an earlier hadm_id in *frame* (NaN if the column is absent)."""
    if "hadm_id" in frame.columns:
        return frame.duplicated(subset=["hadm_id"]).sum()
    return np.nan


# Duplicate check on the admission id in both query results.
dup_base = _count_dup_hadm(df_base)
dup_cov = _count_dup_hadm(df_cov)
print("Duplikat hadm_id i df_base:", dup_base)
print("Duplikat hadm_id i df_cov:", dup_cov)

# Missingness (number of NaN) per lab variable.
lab_cols = ["creatinine_first","glucose_first","sodium_first","bun_first"]
print("\nMissingness (antal NaN) i df_cov:")
display(df_cov[lab_cols].isna().sum())

# Summary statistics for age and whichever lab columns are present in df_cov.
print("\nSanity check describe (df_cov):")
show_cols = [col for col in ("anchor_age", *lab_cols) if col in df_cov.columns]
display(df_cov[show_cols].describe())


## Sammanfogning av basdata och kovariater


In [None]:
# Take only the join key and the lab columns from df_cov, so that demographic
# columns carried by both frames are not duplicated by the merge.
lab_keep = ["hadm_id","creatinine_first","glucose_first","sodium_first","bun_first"]
missing_keep = [c for c in lab_keep if c not in df_cov.columns]
if missing_keep:
    raise ValueError(f"Saknar kolumner i df_cov: {missing_keep}")

# Left join: every cohort admission is kept, and admissions without covariates get NaN labs.
# validate="many_to_one" makes pandas raise MergeError when hadm_id is not unique in
# df_cov. Without it, a duplicated covariate row would silently duplicate that admission
# in df_ps and bias the PS model, the strata and every count that follows.
df_ps = df_base.merge(df_cov[lab_keep], on="hadm_id", how="left", validate="many_to_one")

if "gender" not in df_ps.columns:
    raise ValueError("Saknar 'gender' i df_ps efter merge. Kontrollera df_base.")

# Binary sex indicator for the PS model: 1 when gender == "M", otherwise 0 (missing gender included).
df_ps["is_male"] = (df_ps["gender"] == "M").astype(int)

print("df_ps antal rader:", len(df_ps))
print("\nNaN i labb efter merge:")
display(df_ps[["creatinine_first","glucose_first","sodium_first","bun_first"]].isna().sum())


## Hantering av saknade labbvärden och log-transform

Saknade labbvärden hanteras med medianimputering per variabel innan PS skattas, för att undvika
bortfall av observationer. Kreatinin och BUN log-transformeras med log1p för att minska påverkan
av snedfördelning och extremvärden.


In [None]:
lab_cols = ["creatinine_first","glucose_first","sodium_first","bun_first"]

# Median imputation per lab variable (median of the observed values in df_ps).
for lab in lab_cols:
    as_float = df_ps[lab].astype(float)
    df_ps[lab] = as_float.fillna(as_float.median(skipna=True))

# log1p transform of the kidney markers to damp skewness and extreme values.
for src_col, dst_col in (("bun_first", "log_bun"), ("creatinine_first", "log_creatinine")):
    df_ps[dst_col] = np.log1p(df_ps[src_col])

remaining_nan = int(df_ps[lab_cols].isna().sum().sum())
print("NaN kvar i labb (ska vara 0):", remaining_nan)


## Skattning av propensity score och stratifiering (10 strata)

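PS skattas med L2-regulariserad logistisk regression där behandlingsindikatorn är high_cci (1=CCI>5, 0=CCI≤5).
Individer delas in i 10 strata via kvantilindelning (deciler) av PS. Sammanfallande kvantilgränser slås ihop,
så antalet strata kan bli färre än 10.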

In [None]:
# Covariates of the PS model; BUN and creatinine enter on the log1p scale.
use_covars_log = ["anchor_age","is_male","log_creatinine","glucose_first","sodium_first","log_bun"]

missing_cov = [name for name in use_covars_log if name not in df_ps.columns]
if missing_cov:
    raise ValueError(f"Saknar kovariater i df_ps: {missing_cov}")

# Design matrix as floats. ±inf is treated as missing, and any gap gets the column median.
X = df_ps[use_covars_log].astype(float).replace([np.inf, -np.inf], np.nan)
X = X.fillna(X.median())

# Treatment indicator: high_cci (1 = CCI > 5, 0 = CCI ≤ 5).
y = df_ps["high_cci"].astype(int).values

# L2-penalised logistic regression. max_iter is set well above sklearn's default,
# presumably so that lbfgs converges on these unscaled covariates.
clf = LogisticRegression(penalty="l2", solver="lbfgs", max_iter=5000)
design = X.to_numpy()
clf.fit(design, y)

# PS = predicted probability of high_cci == 1.
df_ps["ps_log"] = clf.predict_proba(design)[:, 1]

# Decile strata of the PS. Tied quantile edges are merged (duplicates="drop"),
# so there can be fewer than 10 strata.
df_ps["ps_stratum10_log"] = pd.qcut(df_ps["ps_log"], q=10, labels=False, duplicates="drop")

ps_min, ps_max = df_ps["ps_log"].min(), df_ps["ps_log"].max()
print("ps_log min/max:", float(ps_min), float(ps_max))
display(pd.crosstab(df_ps["ps_stratum10_log"], df_ps["high_cci"]))


### Baslinje före balansering (N, medel/andel, SMD)

In [None]:
# Baseline before balancing: covariates summarised per CCI group.
vars_base = ["anchor_age","is_male","bun_first","creatinine_first","glucose_first","sodium_first"]

# df_ps must already exist (built by the merge step) and contain high_cci and every covariate.
need = ["high_cci", *vars_base]
miss = [name for name in need if name not in df_ps.columns]
if len(miss) > 0:
    raise ValueError(f"Saknar kolumner i df_ps: {miss}")

def smd_cont(a, b):
    """Standardized mean difference for a continuous covariate.

    The difference of the means of ``a`` and ``b`` is divided by the square root
    of the average of their sample variances (ddof=1). Returns NaN when that
    pooled SD is not strictly positive, i.e. when it is zero or NaN (the latter
    happens e.g. when one sample has fewer than two values).
    """
    x_a = np.asarray(a)
    x_b = np.asarray(b)
    pooled_sd = np.sqrt((x_a.var(ddof=1) + x_b.var(ddof=1)) / 2)
    if not pooled_sd > 0:
        return np.nan
    return (x_a.mean() - x_b.mean()) / pooled_sd

def smd_bin(a, b):
    """Standardized mean difference for a binary (0/1) covariate.

    With p_a, p_b the proportions in ``a`` and ``b`` and p their average, returns
    (p_a - p_b) / sqrt(p * (1 - p)), or NaN when that SD is not strictly positive
    (both samples all 0 or all 1).

    Note: this averages the proportions before forming p(1 - p). The pooled-variance
    form sqrt((p_a(1 - p_a) + p_b(1 - p_b)) / 2) is smaller by (p_a - p_b)^2 / 4
    under the root, so it gives a slightly larger |SMD|.
    """
    prop_a = np.asarray(a).mean()
    prop_b = np.asarray(b).mean()
    p_mid = (prop_a + prop_b) / 2
    sd = np.sqrt(p_mid * (1 - p_mid))
    if sd > 0:
        return (prop_a - prop_b) / sd
    return np.nan

# Treated (CCI > 5) and control (CCI ≤ 5) subsets.
t = df_ps.loc[df_ps["high_cci"] == 1]
c = df_ps.loc[df_ps["high_cci"] == 0]


def _baseline_row(var):
    """One baseline row: control mean, treated mean and pre-balancing SMD of *var*."""
    smd_fn = smd_bin if var == "is_male" else smd_cont
    return [var, float(c[var].mean()), float(t[var].mean()), float(smd_fn(t[var], c[var]))]


rows = [_baseline_row(v) for v in vars_base]

tab_baseline = pd.DataFrame(
    rows,
    columns=["Variabel", "Kontroll (CCI≤5)", "Behandling (CCI>5)", "SMD_pre"],
)
display(tab_baseline)


### Mortalitet före (antal + andel)

In [None]:
# Mortality before balancing: counts and percentages per CCI group.
outcomes = ["died_30d","died_90d","died_1y"]


def _crude_mortality(outcome):
    """Return [outcome, N, deaths, % dead] for control then treated, on complete cases."""
    complete = df_ps.dropna(subset=[outcome, "high_cci"])
    row = [outcome]
    for grp in (0, 1):
        in_grp = complete["high_cci"] == grp
        n = int(in_grp.sum())
        dead = int(complete.loc[in_grp, outcome].sum())
        row.extend([n, dead, 100 * dead / n])
    return row


rows = [_crude_mortality(oc) for oc in outcomes]

tab_mort_pre = pd.DataFrame(
    rows,
    columns=["Outcome","N_kontroll","Döda_kontroll","%_kontroll","N_behandling","Döda_behandling","%_behandling"]
)
pct_cols = ["%_kontroll","%_behandling"]
tab_mort_pre[pct_cols] = tab_mort_pre[pct_cols].round(1)
display(tab_mort_pre)


## Kovariatbalans: SMD före och efter PS-stratifiering

SMD beräknas dels före justering, dels efter stratifiering som ett viktat medelvärde av stratum-specifika
SMD med stratumstorlek som vikt. Strata som saknar någon av grupperna utesluts.


In [None]:
def smd_cont(a, b):
    """Return the standardized mean difference of two continuous samples.

    SMD = (mean(a) - mean(b)) / sqrt((var(a) + var(b)) / 2), using sample
    variances (ddof=1). NaN is returned when the pooled SD is zero or NaN
    (for example when one sample has fewer than two observations).
    """
    arr_a, arr_b = np.asarray(a), np.asarray(b)
    var_sum = arr_a.var(ddof=1) + arr_b.var(ddof=1)
    pooled = np.sqrt(var_sum / 2)
    return (arr_a.mean() - arr_b.mean()) / pooled if pooled > 0 else np.nan

def smd_bin(a, b):
    """Return the standardized mean difference of two 0/1 samples.

    SMD = (p_a - p_b) / sqrt(p * (1 - p)), where p is the average of the two
    proportions. NaN is returned when the SD is not strictly positive (both
    samples all 0 or all 1). This uses the averaged proportion rather than the
    pooled-variance form sqrt((p_a(1 - p_a) + p_b(1 - p_b)) / 2).
    """
    p_a = np.asarray(a).mean()
    p_b = np.asarray(b).mean()
    p_avg = (p_a + p_b) / 2
    sd = np.sqrt(p_avg * (1 - p_avg))
    return (p_a - p_b) / sd if sd > 0 else np.nan

# Covariates checked for balance (kidney markers on their untransformed scale).
balance_vars = ["anchor_age","is_male","creatinine_first","glucose_first","sodium_first","bun_first"]

t = df_ps[df_ps["high_cci"]==1]
c = df_ps[df_ps["high_cci"]==0]


def _smd(var, treated, control):
    """SMD of *var* between *treated* and *control*: binary formula for is_male, continuous otherwise."""
    smd_fn = smd_bin if var == "is_male" else smd_cont
    return smd_fn(treated[var], control[var])


# Unadjusted SMD over the whole cohort.
pre = {v: _smd(v, t, c) for v in balance_vars}

# Post-stratification SMD: average of stratum-specific SMDs weighted by stratum size.
# Two kinds of strata are skipped:
#   * strata lacking one of the groups (no comparison possible);
#   * strata whose SMD is not finite. Examples are a group with a single member
#     (var(ddof=1) is NaN) and a covariate with zero pooled SD in the stratum.
#     Without this skip, np.average would return NaN for the whole covariate as
#     soon as one such stratum appears.
post = {}
strata = sorted(df_ps["ps_stratum10_log"].dropna().unique())
for v in balance_vars:
    smds = []
    w = []
    for s in strata:
        dss = df_ps[df_ps["ps_stratum10_log"]==s]
        t_s = dss[dss["high_cci"]==1]
        c_s = dss[dss["high_cci"]==0]
        if len(t_s)==0 or len(c_s)==0:
            continue
        smd_s = _smd(v, t_s, c_s)
        if not np.isfinite(smd_s):
            continue
        smds.append(smd_s)
        w.append(len(dss))
    post[v] = np.average(smds, weights=w) if smds else np.nan

df_smd_log = pd.DataFrame({
    "SMD_pre": pd.Series(pre),
    "SMD_post_10strata_logPS": pd.Series(post)
})
df_smd_log["|SMD_pre|"] = df_smd_log["SMD_pre"].abs()
df_smd_log["|SMD_post_10strata_logPS|"] = df_smd_log["SMD_post_10strata_logPS"].abs()

display(df_smd_log.sort_values("|SMD_pre|", ascending=False))
print("Max |SMD_post_10strata_logPS|:", float(df_smd_log["|SMD_post_10strata_logPS|"].max()))
print("Max |SMD_post_10strata_logPS|:", float(df_smd_log["|SMD_post_10strata_logPS|"].max()))

## Utfall: mortalitet före och efter PS-stratifiering (strata-standardiserad risk)

Efter PS-stratifiering beräknas strata-standardiserad risk genom att vikta stratum-specifika risker
med den totala stratafördelningen i kohorten. Detta ger en jämförbar mortalitetsandel för grupperna.


In [None]:
strata_col = "ps_stratum10_log"
outcomes = ["died_30d", "died_90d", "died_1y"]


def standardized_risk(d, outcome_col, group_value, strata_col):
    """Return the strata-standardised risk of *outcome_col* for one CCI group.

    The risk within ``high_cci == group_value`` in each stratum is weighted by
    that stratum's share of the whole frame *d* (both groups together). Strata
    where the group has no members give no risk estimate. They are skipped, and
    the weights of the remaining strata are renormalised to sum to 1. The result
    is therefore a proper weighted average rather than one that silently counts
    the missing strata as zero risk.

    Parameters
    ----------
    d : DataFrame with columns ``high_cci``, *outcome_col* and *strata_col*.
    outcome_col : name of the outcome column (averaged as the risk).
    group_value : value of ``high_cci`` that selects the group.
    strata_col : name of the stratum label column.

    Returns
    -------
    float
        The standardised risk, or NaN when the group is absent from every stratum.
    """
    strata_weights = d[strata_col].value_counts(normalize=True).sort_index()
    weighted_sum = 0.0
    weight_used = 0.0
    for s, w in strata_weights.items():
        ds = d[d[strata_col] == s]
        r = ds.loc[ds["high_cci"] == group_value, outcome_col].mean()
        if pd.isna(r):
            # Group not represented in this stratum: drop it and its weight.
            continue
        weighted_sum += w * float(r)
        weight_used += w
    if weight_used == 0:
        return float("nan")
    return float(weighted_sum / weight_used)

crude_label = "Före (rå)"
std_label = "Efter (PS-strata std)"


def _outcome_frame(outcome):
    """Complete cases for *outcome*, with the group, outcome and stratum columns cast to int."""
    frame = df_ps.dropna(subset=[outcome, "high_cci", strata_col]).copy()
    for col in ("high_cci", outcome, strata_col):
        frame[col] = frame[col].astype(int)
    return frame


# One row per (outcome, type): crude risk and strata-standardised risk per group.
rows = []
for oc in outcomes:
    d = _outcome_frame(oc)

    # Before: crude risk per group.
    crude = d.groupby("high_cci")[oc].mean()
    rows.append([oc, crude_label, float(crude.loc[0]), float(crude.loc[1])])

    # After: strata-standardised risk per group.
    rows.append([
        oc,
        std_label,
        standardized_risk(d, oc, 0, strata_col),
        standardized_risk(d, oc, 1, strata_col),
    ])

df_plot = pd.DataFrame(rows, columns=["Outcome", "Typ", "Risk_low", "Risk_high"])
display(df_plot)


def _risks(label, col):
    """Values of *col* for the rows of type *label*, in outcome order."""
    return df_plot.loc[df_plot["Typ"] == label, col].values


pre_low = _risks(crude_label, "Risk_low")
pre_high = _risks(crude_label, "Risk_high")
post_low = _risks(std_label, "Risk_low")
post_high = _risks(std_label, "Risk_high")

# --- Plot: grouped bars, four per outcome ---
x = np.arange(len(outcomes))
width = 0.18
bar_specs = [
    (-1.5, pre_low, "Före low_cci (0)"),
    (-0.5, pre_high, "Före high_cci (1)"),
    (0.5, post_low, "Efter low_cci (0)"),
    (1.5, post_high, "Efter high_cci (1)"),
]

plt.figure(figsize=(9, 4))
for offset, heights, label in bar_specs:
    plt.bar(x + offset * width, heights, width=width, label=label)
plt.xticks(x, outcomes)
plt.ylabel("Mortalitetsandel")
plt.title("Mortalitet före vs efter PS-stratifiering (strata-standardiserad risk)")
plt.legend(ncol=2)
plt.tight_layout()
plt.show()

# --- Summary: risk difference (high - low) before and after ---
df_summary = pd.DataFrame({
    "Outcome": outcomes,
    "RD_pre": pre_high - pre_low,
    "RD_post": post_high - post_low,
})
display(df_summary)


## Appendix: Diagnostikfigurer

Detta avsnitt innehåller kompletterande visualiseringar av kovariatbalans.


### SMD-plot före/efter

In [None]:
# Covariates sorted ascending by |SMD_pre|; one point per covariate before and after.
smd_sorted = df_smd_log.sort_values("|SMD_pre|", ascending=True)
series_to_plot = [
    ("SMD_pre", "Före"),
    ("SMD_post_10strata_logPS", "Efter (log-PS + 10 strata)"),
]

plt.figure(figsize=(7, 4))
for col, lab in series_to_plot:
    plt.scatter(smd_sorted[col], smd_sorted.index, label=lab)

# Reference lines: solid at zero, dashed at the ±0.1 balance threshold.
plt.axvline(0, linewidth=1)
for threshold in (0.1, -0.1):
    plt.axvline(threshold, linestyle="--", linewidth=1)

plt.xlabel("Standardized Mean Difference (SMD)")
plt.ylabel("Kovariat")
plt.title("Kovariatbalans före och efter PS-stratifiering")
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
# Efter PS-stratifiering (strata-standardiserade medel + SMD_post)
strata_col = "ps_stratum10_log"
group_col = "high_cci"
vars_show = ["anchor_age","is_male","bun_first","creatinine_first","glucose_first","sodium_first"]

need = [strata_col, group_col] + vars_show
miss = [c for c in need if c not in df_ps.columns]
if miss:
    raise ValueError(f"Saknar kolumner i df_ps: {miss}")

d = df_ps.dropna(subset=[strata_col, group_col]).copy()
d[strata_col] = d[strata_col].astype(int)
d[group_col] = d[group_col].astype(int)

# strata-vikter = stratafördelning i HELA kohorten (standardiseringsmål)
w_strata = d[strata_col].value_counts(normalize=True).sort_index()

def std_mean(var, gval):
    """Return the strata-standardised mean of *var* in group ``group_col == gval``.

    Reads ``d``, ``w_strata``, ``strata_col`` and ``group_col`` from the notebook's
    global scope. Each stratum's group mean is weighted by ``w_strata``. Strata
    where the group has no non-missing value of *var* are skipped, and the
    remaining weights are renormalised to sum to 1. Without the renormalisation,
    a skipped stratum would effectively count as a mean of 0. For example, an
    age mean of 70 would shrink toward 63 if a stratum holding 10 % of the
    cohort were missing.

    Returns NaN if the group has no values of *var* in any stratum.
    """
    weighted_total = 0.0
    weight_used = 0.0
    for s, w in w_strata.items():
        ds = d[d[strata_col] == s]
        m = ds.loc[ds[group_col] == gval, var].mean()
        if pd.isna(m):
            continue
        weighted_total += w * float(m)
        weight_used += w
    if weight_used == 0:
        return float("nan")
    return float(weighted_total / weight_used)

# SMD_post is taken from df_smd_log, which the SMD before/after cell must already have created.
if "df_smd_log" not in globals():
    raise ValueError("df_smd_log saknas. Kör cellen där du beräknar SMD före/efter först.")


def _post_balance_row(var):
    """Row of the post-stratification table: standardised means (control, treated) and SMD_post."""
    m0 = std_mean(var, 0)
    m1 = std_mean(var, 1)
    if var in df_smd_log.index:
        smd_post = float(df_smd_log.loc[var, "SMD_post_10strata_logPS"])
    else:
        smd_post = np.nan
    return [var, m0, m1, smd_post]


rows = [_post_balance_row(v) for v in vars_show]

tab_post = pd.DataFrame(rows, columns=["Variabel","Kontroll (std)","Behandling (std)","SMD_efter"])
display(tab_post)


In [None]:
# === Table C: mortality before vs after (PS-strata standardised risk) ===

strata_col = "ps_stratum10_log"
outcomes = ["died_30d","died_90d","died_1y"]


def standardized_risk(d, outcome_col, group_value, strata_col):
    """Return the strata-standardised risk of *outcome_col* for one CCI group.

    The risk within ``high_cci == group_value`` in each stratum is weighted by
    that stratum's share of the whole frame *d* (both groups together). Strata
    where the group has no members give no risk estimate. They are skipped, and
    the weights of the remaining strata are renormalised to sum to 1. The result
    is therefore a proper weighted average rather than one that silently counts
    the missing strata as zero risk.

    Parameters
    ----------
    d : DataFrame with columns ``high_cci``, *outcome_col* and *strata_col*.
    outcome_col : name of the outcome column (averaged as the risk).
    group_value : value of ``high_cci`` that selects the group.
    strata_col : name of the stratum label column.

    Returns
    -------
    float
        The standardised risk, or NaN when the group is absent from every stratum.
    """
    strata_weights = d[strata_col].value_counts(normalize=True).sort_index()
    weighted_sum = 0.0
    weight_used = 0.0
    for s, w in strata_weights.items():
        ds = d[d[strata_col] == s]
        r = ds.loc[ds["high_cci"] == group_value, outcome_col].mean()
        if pd.isna(r):
            # Group not represented in this stratum: drop it and its weight.
            continue
        weighted_sum += w * float(r)
        weight_used += w
    if weight_used == 0:
        return float("nan")
    return float(weighted_sum / weight_used)

def _mortality_row(outcome):
    """Table C row for *outcome*: N, crude %, standardised % and expected deaths per group."""
    dd = df_ps.dropna(subset=[outcome, "high_cci", strata_col]).copy()
    for col in ("high_cci", outcome, strata_col):
        dd[col] = dd[col].astype(int)

    # Group sizes.
    n0 = int((dd["high_cci"] == 0).sum())
    n1 = int((dd["high_cci"] == 1).sum())

    # Before: crude risk per group.
    crude = dd.groupby("high_cci")[outcome].mean()
    pre0, pre1 = float(crude.loc[0]), float(crude.loc[1])

    # After: strata-standardised risk per group.
    post0 = standardized_risk(dd, outcome, 0, strata_col)
    post1 = standardized_risk(dd, outcome, 1, strata_col)

    # Expected deaths = standardised risk times the group size.
    return [
        outcome,
        n0, 100 * pre0, 100 * post0, post0 * n0,
        n1, 100 * pre1, 100 * post1, post1 * n1,
    ]


rows = [_mortality_row(oc) for oc in outcomes]

tab_mort_post = pd.DataFrame(
    rows,
    columns=[
        "Outcome",
        "N_kontroll", "%_före_kontroll", "%_efter_std_kontroll", "Döda_exp_efter_kontroll",
        "N_behandling", "%_före_behandling", "%_efter_std_behandling", "Döda_exp_efter_behandling"
    ]
)

pct_cols = ["%_före_kontroll", "%_efter_std_kontroll", "%_före_behandling", "%_efter_std_behandling"]
exp_cols = ["Döda_exp_efter_kontroll", "Döda_exp_efter_behandling"]
tab_mort_post[pct_cols] = tab_mort_post[pct_cols].round(1)
tab_mort_post[exp_cols] = tab_mort_post[exp_cols].round(1)

display(tab_mort_post)
