In [2]:
# ============================================================
# Pipeline: "Delayed hemodynamic stabilization" -> RRT after 24h
# MIMIC-III (Postgres) - copy/paste into a Jupyter Notebook
#
# Core idea (Target-trial light):
# - Cohort: ICU stays with ICD9 584* (AKI by code)
# - Landmark: 24h after ICU intime
# - Exposure: vasopressor start timing category:
#     0 = none in first 24h
#     1 = early (<=6h)
#     2 = delayed (6-24h)
# - Outcome: new RRT initiated AFTER 24h landmark (exclude early RRT <=24h)
# - Baseline covariates: age, sex, ICU type, admission_type, Charlson,
#   SOFA approx WITHOUT cardio component (no vasopressor in score),
#   baseline creat, peak creat 0-24h, early vent, early fluids, platelets, bili.
#
# Output:
# - Descriptives
# - Multinomial IPTW (stabilized + trimming)
# - Weighted risks and RD/RR vs reference group (none)
# - Sensitivity: logistic regression (sklearn) for outcome with exposure dummies
# - Subgroups by AKI stage (24h)
# ============================================================

import os
import numpy as np
import pandas as pd
from pathlib import Path
from dotenv import load_dotenv
from sqlalchemy import create_engine, text

# ----------------------------
# 0) DB connection helpers
# ----------------------------
load_dotenv(Path("..") / ".env")  # adjust if your notebook is elsewhere

DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASS = os.getenv("DB_PASSWORD")

assert DB_HOST and DB_NAME and DB_USER, "Missing DB vars (DB_HOST/DB_NAME/DB_USER)"

engine = create_engine(f"postgresql+psycopg2://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}")

def q(sql: str) -> pd.DataFrame:
    with engine.connect() as conn:
        return pd.read_sql(text(sql), conn)

print("Connected OK")

# ----------------------------
# 1) Create/refresh Charlson view (HADM-level)
# ----------------------------
sql_create_charlson = """
CREATE OR REPLACE VIEW charlson_hadm AS
WITH dx AS (
  SELECT DISTINCT hadm_id, icd9_code
  FROM diagnoses_icd
),
flags AS (
  SELECT
    hadm_id,
    MAX(CASE WHEN icd9_code LIKE '410%' OR icd9_code LIKE '412%' THEN 1 ELSE 0 END) AS mi,
    MAX(CASE WHEN icd9_code LIKE '428%' THEN 1 ELSE 0 END) AS chf,
    MAX(CASE WHEN icd9_code LIKE '440%' OR icd9_code LIKE '441%' OR icd9_code LIKE '4439%' OR icd9_code LIKE '7854%' OR icd9_code LIKE 'V434%' THEN 1 ELSE 0 END) AS pvd,
    MAX(CASE WHEN icd9_code LIKE '430%' OR icd9_code LIKE '431%' OR icd9_code LIKE '432%' OR icd9_code LIKE '433%' OR icd9_code LIKE '434%' OR icd9_code LIKE '435%' OR icd9_code LIKE '436%' OR icd9_code LIKE '437%' OR icd9_code LIKE '438%' THEN 1 ELSE 0 END) AS cva,
    MAX(CASE WHEN icd9_code LIKE '290%' OR icd9_code LIKE '2941%' OR icd9_code LIKE '3312%' THEN 1 ELSE 0 END) AS dementia,
    MAX(CASE WHEN icd9_code LIKE '490%' OR icd9_code LIKE '491%' OR icd9_code LIKE '492%' OR icd9_code LIKE '493%' OR icd9_code LIKE '494%' OR icd9_code LIKE '495%' OR icd9_code LIKE '496%' THEN 1 ELSE 0 END) AS copd,
    MAX(CASE WHEN icd9_code LIKE '7100%' OR icd9_code LIKE '7101%' OR icd9_code LIKE '7104%' OR icd9_code LIKE '714%' OR icd9_code LIKE '725%' THEN 1 ELSE 0 END) AS rheum,
    MAX(CASE WHEN icd9_code LIKE '531%' OR icd9_code LIKE '532%' OR icd9_code LIKE '533%' OR icd9_code LIKE '534%' THEN 1 ELSE 0 END) AS pud,
    MAX(CASE WHEN icd9_code LIKE '5712%' OR icd9_code LIKE '5714%' OR icd9_code LIKE '5715%' OR icd9_code LIKE '5716%' THEN 1 ELSE 0 END) AS mild_liver,
    MAX(CASE WHEN icd9_code LIKE '2500%' OR icd9_code LIKE '2501%' OR icd9_code LIKE '2502%' OR icd9_code LIKE '2503%' THEN 1 ELSE 0 END) AS dm_uncomp,
    MAX(CASE WHEN icd9_code LIKE '2504%' OR icd9_code LIKE '2505%' OR icd9_code LIKE '2506%' OR icd9_code LIKE '2507%' OR icd9_code LIKE '2508%' OR icd9_code LIKE '2509%' THEN 1 ELSE 0 END) AS dm_comp,
    MAX(CASE WHEN icd9_code LIKE '342%' OR icd9_code LIKE '343%' OR icd9_code LIKE '3441%' THEN 1 ELSE 0 END) AS hemi,
    MAX(CASE WHEN icd9_code LIKE '582%' OR icd9_code LIKE '583%' OR icd9_code LIKE '585%' OR icd9_code LIKE '586%' OR icd9_code LIKE '588%' THEN 1 ELSE 0 END) AS renal,
    MAX(CASE WHEN (icd9_code >= '140' AND icd9_code <= '172')
              OR (icd9_code >= '174' AND icd9_code <= '195')
              OR (icd9_code >= '200' AND icd9_code <= '208')
             THEN 1 ELSE 0 END) AS malignancy,
    MAX(CASE WHEN icd9_code LIKE '5722%' OR icd9_code LIKE '5723%' OR icd9_code LIKE '5724%' OR icd9_code LIKE '5728%' THEN 1 ELSE 0 END) AS severe_liver,
    MAX(CASE WHEN icd9_code LIKE '196%' OR icd9_code LIKE '197%' OR icd9_code LIKE '198%' OR icd9_code LIKE '199%' THEN 1 ELSE 0 END) AS mets,
    MAX(CASE WHEN icd9_code LIKE '042%' OR icd9_code LIKE '043%' OR icd9_code LIKE '044%' THEN 1 ELSE 0 END) AS aids
  FROM dx
  GROUP BY hadm_id
)
SELECT
  hadm_id,
  (
    1*mi + 1*chf + 1*pvd + 1*cva + 1*dementia + 1*copd + 1*rheum + 1*pud + 1*mild_liver
    + 1*dm_uncomp + 2*dm_comp + 2*hemi + 2*renal + 2*malignancy + 3*severe_liver + 6*mets + 6*aids
  ) AS charlson
FROM flags;
"""
with engine.begin() as conn:
    conn.execute(text(sql_create_charlson))
print("Created/updated view: charlson_hadm")

# ----------------------------
# 2) Build main dataset (one row per ICU stay)
#    - exposure timing categories
#    - outcome: new RRT after 24h
#    - baseline covariates <=24h
# ----------------------------
# Lab ITEMIDs (common in MIMIC-III)
CREAT_ITEMID = 50912   # Creatinine
PLT_ITEMID   = 51265   # Platelet Count
BILI_ITEMID  = 50885   # Bilirubin, Total

sql_main = f"""
WITH aki_hadm AS (
  SELECT DISTINCT hadm_id
  FROM diagnoses_icd
  WHERE icd9_code LIKE '584%'
),
cohort AS (
  SELECT
    i.subject_id,
    i.hadm_id,
    i.icustay_id,
    i.intime,
    i.first_careunit,
    (i.intime + interval '6 hour') AS t6,
    (i.intime + interval '24 hour') AS t24
  FROM icustays i
  JOIN aki_hadm a ON a.hadm_id = i.hadm_id
),
demo AS (
  SELECT
    c.icustay_id,
    c.subject_id,
    c.hadm_id,
    c.intime,
    c.first_careunit,
    a.admission_type,
    p.gender,
    CASE WHEN p.gender = 'F' THEN 1 ELSE 0 END AS female,
    a.hospital_expire_flag::int AS hospital_mortality,
    EXTRACT(YEAR FROM a.admittime) - EXTRACT(YEAR FROM p.dob)
      - CASE WHEN (EXTRACT(MONTH FROM a.admittime), EXTRACT(DAY FROM a.admittime))
              < (EXTRACT(MONTH FROM p.dob), EXTRACT(DAY FROM p.dob))
             THEN 1 ELSE 0 END AS age
  FROM cohort c
  JOIN admissions a ON a.hadm_id = c.hadm_id
  JOIN patients p ON p.subject_id = c.subject_id
),
charlson AS (
  SELECT hadm_id, charlson FROM charlson_hadm
),

-- Creatinine in [-24h, +24h] around ICU intime (baseline window)
creat AS (
  SELECT
    c.icustay_id,
    l.charttime,
    l.valuenum AS creat
  FROM cohort c
  JOIN labevents l ON l.hadm_id = c.hadm_id
  WHERE l.itemid = {CREAT_ITEMID}
    AND l.valuenum IS NOT NULL
    AND l.charttime BETWEEN c.intime - interval '24 hour' AND c.t24
),
baseline_creat AS (
  SELECT icustay_id, MIN(creat) AS baseline_creat
  FROM creat
  GROUP BY icustay_id
),
peak_creat_24h AS (
  SELECT icustay_id, MAX(creat) AS peak_creat_24h
  FROM creat
  GROUP BY icustay_id
),

-- Platelets min (0-24h) and bili max (0-24h)
labs_24h AS (
  SELECT
    c.icustay_id,
    MIN(CASE WHEN l.itemid = {PLT_ITEMID} THEN l.valuenum END) AS platelets_min_24h,
    MAX(CASE WHEN l.itemid = {BILI_ITEMID} THEN l.valuenum END) AS bili_max_24h
  FROM cohort c
  JOIN labevents l ON l.hadm_id = c.hadm_id
  WHERE l.valuenum IS NOT NULL
    AND l.charttime BETWEEN c.intime AND c.t24
    AND l.itemid IN ({PLT_ITEMID}, {BILI_ITEMID})
  GROUP BY c.icustay_id
),

-- First vasopressor time within 24h (if any)
vaso_first AS (
  SELECT
    c.icustay_id,
    MIN(ie.starttime) AS first_vaso_time
  FROM cohort c
  JOIN inputevents_mv ie ON ie.icustay_id = c.icustay_id
  LEFT JOIN d_items di ON di.itemid = ie.itemid
  WHERE ie.starttime <= c.t24
    AND (
      LOWER(COALESCE(di.label,'')) LIKE '%norepi%' OR LOWER(COALESCE(di.label,'')) LIKE '%noradren%'
      OR LOWER(COALESCE(di.label,'')) LIKE '%levophed%' OR LOWER(COALESCE(di.label,'')) = 'ne'
      OR LOWER(COALESCE(di.label,'')) LIKE '%vasopress%'
      OR LOWER(COALESCE(di.label,'')) LIKE '%epine%' OR LOWER(COALESCE(di.label,'')) LIKE '%adrenalin%'
      OR LOWER(COALESCE(di.label,'')) LIKE '%phenyleph%'
      OR LOWER(COALESCE(di.label,'')) LIKE '%dopamine%'
      OR LOWER(COALESCE(di.label,'')) LIKE '%dobutamine%'
    )
  GROUP BY c.icustay_id
),

-- Exposure category:
-- 0 none within 24h
-- 1 early <=6h
-- 2 delayed (6-24h)
exposure AS (
  SELECT
    c.icustay_id,
    vf.first_vaso_time,
    CASE
      WHEN vf.first_vaso_time IS NULL THEN 0
      WHEN vf.first_vaso_time <= c.t6 THEN 1
      ELSE 2
    END AS vaso_timing
  FROM cohort c
  LEFT JOIN vaso_first vf ON vf.icustay_id = c.icustay_id
),

-- Early mechanical ventilation within 24h (proxy severity)
early_vent AS (
  SELECT DISTINCT
    c.icustay_id,
    1 AS early_vent
  FROM cohort c
  JOIN procedureevents_mv pe ON pe.icustay_id = c.icustay_id
  LEFT JOIN d_items di ON di.itemid = pe.itemid
  WHERE pe.starttime <= c.t24
    AND (
      LOWER(COALESCE(di.label,'')) LIKE '%vent%'
      OR LOWER(COALESCE(di.label,'')) LIKE '%intubat%'
      OR LOWER(COALESCE(di.label,'')) LIKE '%respirat%'
    )
),

-- Early fluids within 24h (as you used before; excludes nutrition-ish labels)
early_fluids AS (
  SELECT DISTINCT
    c.icustay_id,
    1 AS early_fluids
  FROM cohort c
  JOIN inputevents_mv ie ON ie.icustay_id = c.icustay_id
  LEFT JOIN d_items di ON di.itemid = ie.itemid
  WHERE ie.starttime <= c.t24
    AND (
      LOWER(COALESCE(di.label,'')) LIKE '%nacl%' OR LOWER(COALESCE(di.label,'')) LIKE '%normal saline%' OR LOWER(COALESCE(di.label,'')) = 'ns'
      OR LOWER(COALESCE(di.label,'')) LIKE '%0.9%' OR LOWER(COALESCE(di.label,'')) LIKE '%0.45%' OR LOWER(COALESCE(di.label,'')) LIKE '%1/2ns%'
      OR LOWER(COALESCE(di.label,'')) LIKE '%d5 1/2ns%' OR LOWER(COALESCE(di.label,'')) LIKE '%half ns%'
      OR LOWER(COALESCE(di.label,'')) = 'lr' OR LOWER(COALESCE(di.label,'')) LIKE '%lactated%' OR LOWER(COALESCE(di.label,'')) LIKE '%ringer%'
      OR LOWER(COALESCE(di.label,'')) LIKE '%dextrose%' OR LOWER(COALESCE(di.label,'')) LIKE '%d5w%' OR LOWER(COALESCE(di.label,'')) LIKE '%d10w%'
      OR LOWER(COALESCE(di.label,'')) LIKE '%sterile water%' OR LOWER(COALESCE(di.label,'')) LIKE '%free water%'
      OR LOWER(COALESCE(di.label,'')) LIKE '%plasma-lyte%' OR LOWER(COALESCE(di.label,'')) LIKE '%plasmalyte%'
    )
    AND LOWER(COALESCE(di.label,'')) NOT LIKE '%po intake%'
    AND LOWER(COALESCE(di.label,'')) NOT LIKE '%pre-admission intake%'
    AND LOWER(COALESCE(di.label,'')) NOT LIKE '%solution%'
    AND LOWER(COALESCE(di.label,'')) NOT LIKE '%piggyback%'
),

-- RRT/dialysis events: identify RRT procedures in procedureevents_mv by label
rrt_events AS (
  SELECT
    c.icustay_id,
    pe.starttime
  FROM cohort c
  JOIN procedureevents_mv pe ON pe.icustay_id = c.icustay_id
  LEFT JOIN d_items di ON di.itemid = pe.itemid
  WHERE (
    LOWER(COALESCE(di.label,'')) LIKE '%dialysis%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%crrt%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%cvvh%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%hemofiltration%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%hemodialysis%'
  )
),
rrt_early AS (
  SELECT DISTINCT c.icustay_id, 1 AS rrt_early_24h
  FROM cohort c
  JOIN rrt_events r ON r.icustay_id = c.icustay_id
  WHERE r.starttime <= c.t24
),
rrt_after AS (
  SELECT DISTINCT c.icustay_id, 1 AS rrt_after_24h
  FROM cohort c
  JOIN rrt_events r ON r.icustay_id = c.icustay_id
  WHERE r.starttime > c.t24
)

SELECT
  d.subject_id,
  d.hadm_id,
  d.icustay_id,
  d.hospital_mortality,
  d.gender,
  d.female,
  CASE WHEN d.age > 89 THEN 90 ELSE d.age END AS age,
  d.admission_type,
  d.first_careunit,

  COALESCE(ch.charlson, 0) AS charlson,
  COALESCE(bc.baseline_creat, NULL) AS baseline_creat,
  COALESCE(pc.peak_creat_24h, NULL) AS peak_creat_24h,
  COALESCE(l24.platelets_min_24h, NULL) AS platelets_min_24h,
  COALESCE(l24.bili_max_24h, NULL) AS bili_max_24h,

  COALESCE(ev.early_vent, 0) AS early_vent,
  COALESCE(ef.early_fluids, 0) AS early_fluids,

  ex.vaso_timing,

  COALESCE(re.rrt_early_24h, 0) AS rrt_early_24h,
  COALESCE(ra.rrt_after_24h, 0) AS rrt_after_24h

FROM demo d
LEFT JOIN charlson ch ON ch.hadm_id = d.hadm_id
LEFT JOIN baseline_creat bc ON bc.icustay_id = d.icustay_id
LEFT JOIN peak_creat_24h pc ON pc.icustay_id = d.icustay_id
LEFT JOIN labs_24h l24 ON l24.icustay_id = d.icustay_id
LEFT JOIN early_vent ev ON ev.icustay_id = d.icustay_id
LEFT JOIN early_fluids ef ON ef.icustay_id = d.icustay_id
LEFT JOIN exposure ex ON ex.icustay_id = d.icustay_id
LEFT JOIN rrt_early re ON re.icustay_id = d.icustay_id
LEFT JOIN rrt_after ra ON ra.icustay_id = d.icustay_id;
"""

df = q(sql_main)
print("Rows:", len(df))
df.head()

# ----------------------------
# 3) Define outcome and exclusions
# ----------------------------
# Outcome: new RRT after 24h
df["outcome_rrt_after_24h"] = df["rrt_after_24h"].astype(int)

# Exclude those already on RRT in first 24h (prevalent cases at time zero)
df = df[df["rrt_early_24h"] == 0].copy()
print("After excluding early RRT (<=24h):", len(df))
print("Outcome rate (RRT after 24h):", df["outcome_rrt_after_24h"].mean())

# Exposure labels for readability
map_expo = {0: "none_24h", 1: "early_<=6h", 2: "delayed_6-24h"}
df["vaso_timing_label"] = df["vaso_timing"].map(map_expo)

df["vaso_timing_label"].value_counts(dropna=False)

# ----------------------------
# 4) Build AKI stage within 24h (KDIGO creatinine-only)
# ----------------------------
# Impute creatinine minimally (low missingness expected)
for col in ["baseline_creat", "peak_creat_24h"]:
    df[col + "_missing"] = df[col].isna().astype(int)
    df[col] = df[col].fillna(df[col].median())

def stage_kdigo_24h(baseline, peak):
    if pd.isna(baseline) or pd.isna(peak):
        return np.nan
    if peak >= 4.0 or peak >= 3.0 * baseline:
        return 3
    if peak >= 2.0 * baseline:
        return 2
    if peak >= 1.5 * baseline or (peak - baseline) >= 0.3:
        return 1
    return 0

df["aki_stage_24h"] = [stage_kdigo_24h(b,p) for b,p in zip(df["baseline_creat"], df["peak_creat_24h"])]
print(df["aki_stage_24h"].value_counts(dropna=False).sort_index())

# ----------------------------
# 5) SOFA approx WITHOUT cardio component (no pressor info inside)
#    - coag: platelets min 24h
#    - liver: bilirubin max 24h
#    - renal: peak creat 24h
#    - resp: early_vent (proxy; 2 points if ventilated)
# ----------------------------
for col in ["platelets_min_24h", "bili_max_24h"]:
    df[col + "_missing"] = df[col].isna().astype(int)
    df[col] = df[col].fillna(df[col].median())

def sofa_coag_platelets(plt):
    if plt < 20: return 4
    if plt < 50: return 3
    if plt < 100: return 2
    if plt < 150: return 1
    return 0

def sofa_liver_bili(bili):
    if bili >= 12: return 4
    if bili >= 6:  return 3
    if bili >= 2:  return 2
    if bili >= 1.2:return 1
    return 0

def sofa_renal_creat(creat):
    if creat >= 5.0: return 4
    if creat >= 3.5: return 3
    if creat >= 2.0: return 2
    if creat >= 1.2: return 1
    return 0

df["sofa_coag_24h"] = df["platelets_min_24h"].apply(sofa_coag_platelets)
df["sofa_liver_24h"] = df["bili_max_24h"].apply(sofa_liver_bili)
df["sofa_renal_24h"] = df["peak_creat_24h"].apply(sofa_renal_creat)
df["sofa_resp_24h"]  = df["early_vent"].astype(int) * 2  # proxy

df["sofa_nocv_24h"] = df["sofa_coag_24h"] + df["sofa_liver_24h"] + df["sofa_renal_24h"] + df["sofa_resp_24h"]
df["sofa_nocv_24h"].describe()

# ----------------------------
# 6) Descriptives by exposure group
# ----------------------------
grp = df.groupby("vaso_timing_label").agg(
    n=("icustay_id","count"),
    rrt_after_rate=("outcome_rrt_after_24h","mean"),
    mort_rate=("hospital_mortality","mean"),
    age_mean=("age","mean"),
    charlson_mean=("charlson","mean"),
    sofa_mean=("sofa_nocv_24h","mean"),
    aki_stage_mean=("aki_stage_24h","mean")
).sort_index()

grp

# ----------------------------
# 7) Multinomial propensity score + stabilized IPTW
# ----------------------------
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression

# Treatment: 3 classes (0/1/2)
T = df["vaso_timing"].astype(int)

# Baseline covariates (strictly pre/within 24h; avoid post-treatment)
num_cols = [
    "age",
    "charlson",
    "baseline_creat", "peak_creat_24h",
    "platelets_min_24h", "bili_max_24h",
    "sofa_nocv_24h",
    "early_vent", "early_fluids",
    "baseline_creat_missing", "peak_creat_24h_missing",
    "platelets_min_24h_missing", "bili_max_24h_missing"
]
cat_cols = ["female", "admission_type", "first_careunit"]

# Ensure exist
num_cols = [c for c in num_cols if c in df.columns]
cat_cols = [c for c in cat_cols if c in df.columns]

X = df[num_cols + cat_cols].copy()

preprocess = ColumnTransformer(
    transformers=[
        ("num", "passthrough", num_cols),
        ("cat", OneHotEncoder(handle_unknown="ignore"), cat_cols),
    ],
    remainder="drop"
)

# Multinomial logistic regression for propensity
ps_model = Pipeline(steps=[
    ("prep", preprocess),
    ("mnl", LogisticRegression(
    solver="lbfgs",
    max_iter=4000
))
])

ps_model.fit(X, T)
ps = ps_model.predict_proba(X)  # shape (n, 3)

# Stabilized weights: P(T=t)/P(T=t|X)
p_marg = np.bincount(T, minlength=3) / len(T)
w = np.zeros(len(df), dtype=float)
for t in [0,1,2]:
    w[T.values == t] = p_marg[t] / ps[T.values == t, t]

# Trim extreme weights
w_trunc = np.clip(w, np.quantile(w, 0.01), np.quantile(w, 0.99))

df["w"] = w_trunc

pd.Series(df["w"]).describe()

# ----------------------------
# 8) Weighted risks + RD/RR vs reference (none_24h)
# ----------------------------
def weighted_mean(a, w):
    a = np.asarray(a, dtype=float)
    w = np.asarray(w, dtype=float)
    return np.sum(a*w)/np.sum(w)

def weighted_risk_by_group(df_in, outcome_col="outcome_rrt_after_24h"):
    out = []
    for t, label in sorted(map_expo.items()):
        d = df_in[df_in["vaso_timing"] == t]
        if len(d) == 0:
            continue
        r = weighted_mean(d[outcome_col].values, d["w"].values)
        out.append((label, len(d), r))
    return pd.DataFrame(out, columns=["group","n","weighted_risk"])

risks = weighted_risk_by_group(df)
risks
ref = risks.loc[risks["group"]=="none_24h","weighted_risk"].values[0]

risks["RD_vs_none"] = risks["weighted_risk"] - ref
risks["RR_vs_none"] = risks["weighted_risk"] / ref if ref > 0 else np.nan
risks

# ----------------------------
# 9) Sensitivity: logistic regression for outcome with exposure dummies (sklearn)
# ----------------------------
from sklearn.linear_model import LogisticRegression

y = df["outcome_rrt_after_24h"].astype(int)

# exposure dummies (reference none_24h)
df["expo_early"] = (df["vaso_timing"] == 1).astype(int)
df["expo_delayed"] = (df["vaso_timing"] == 2).astype(int)

features = ["expo_early", "expo_delayed"] + num_cols + cat_cols
Xr = df[features].copy()
Xr = Xr.fillna(Xr.median(numeric_only=True))

# one-hot encode categoricals
preprocess_r = ColumnTransformer(
    transformers=[
        ("num", "passthrough", [c for c in features if c in num_cols or c in ["expo_early","expo_delayed"]]),
        ("cat", OneHotEncoder(handle_unknown="ignore"), [c for c in features if c in cat_cols]),
    ],
    remainder="drop"
)

logit = Pipeline(steps=[
    ("prep", preprocess_r),
    ("lr", LogisticRegression(solver="lbfgs", max_iter=4000))
])

logit.fit(Xr, y)

# Extract ORs for exposure terms from the trained pipeline
# (We need feature names after encoding)
ohe = logit.named_steps["prep"].named_transformers_["cat"]
num_names = [c for c in features if c in num_cols or c in ["expo_early","expo_delayed"]]
cat_names = []
if len(cat_cols) > 0:
    cat_names = list(ohe.get_feature_names_out(cat_cols))
feature_names = num_names + cat_names

coef = logit.named_steps["lr"].coef_[0]
coef_s = pd.Series(coef, index=feature_names)

or_expo = pd.DataFrame({
    "OR": np.exp(coef_s[["expo_early","expo_delayed"]])
})
or_expo
# Interpretation:
# expo_early OR: early vs none
# expo_delayed OR: delayed vs none

# ----------------------------
# 10) Subgroup: AKI stage (24h) – weighted risks by exposure within stage
# ----------------------------
def weighted_risks_by_stage(df_in):
    rows = []
    for stage in sorted(df_in["aki_stage_24h"].dropna().unique()):
        d = df_in[df_in["aki_stage_24h"] == stage]
        r = weighted_risk_by_group(d)
        r["aki_stage_24h"] = stage
        rows.append(r)
    return pd.concat(rows, ignore_index=True) if rows else pd.DataFrame()

stage_risks = weighted_risks_by_stage(df)
stage_risks
# Pivot for readability
stage_pivot = stage_risks.pivot(index="aki_stage_24h", columns="group", values="weighted_risk")
stage_pivot


Connected OK
Created/updated view: charlson_hadm
Rows: 12879
After excluding early RRT (<=24h): 12507
Outcome rate (RRT after 24h): 0.029983209402734468
aki_stage_24h
0    5189
1    5129
2     643
3    1546
Name: count, dtype: int64


STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT

Increase the number of iterations to improve the convergence (max_iter=4000).
You might also want to scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


group,delayed_6-24h,early_<=6h,none_24h
aki_stage_24h,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,0.056276,0.047748,0.01875
1,0.07993,0.052186,0.024541
2,0.093389,0.073327,0.027443
3,0.384362,0.185152,0.056054


In [3]:
sql_first_vaso = """
WITH vaso AS (
  SELECT
    ie.icustay_id,
    MIN(ie.starttime) AS first_vaso_time
  FROM inputevents_mv ie
  JOIN d_items di ON di.itemid = ie.itemid
  WHERE
    LOWER(COALESCE(di.label,'')) LIKE '%norepi%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%noradren%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%levophed%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%vasopress%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%epine%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%phenyleph%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%dopamine%'
  GROUP BY ie.icustay_id
)
SELECT * FROM vaso;
"""
df_vaso_time = q(sql_first_vaso)


In [4]:
df = df.merge(df_vaso_time, on="icustay_id", how="left")
df = df[~df["first_vaso_time"].isna()].copy()  # nur Patienten mit Vasopressoren
print("ICU stays with vaso:", len(df))


ICU stays with vaso: 2316


In [5]:
ITEMS_VITALS = {
    "map": [456, 52],        # Mean BP
    "sbp": [51, 455],        # Systolic BP
    "hr":  [211, 220045],    # Heart Rate
}


In [7]:
sql_first_vaso_sub = """
WITH vaso AS (
  SELECT
    ie.icustay_id,
    MIN(ie.starttime) AS first_vaso_time
  FROM inputevents_mv ie
  JOIN d_items di ON di.itemid = ie.itemid
  WHERE
    LOWER(COALESCE(di.label,'')) LIKE '%norepi%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%noradren%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%levophed%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%vasopress%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%epine%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%phenyleph%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%dopamine%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%dobutamine%'
  GROUP BY ie.icustay_id
)
SELECT icustay_id, first_vaso_time
FROM vaso
"""


In [8]:
sql_vitals_around_vaso = f"""
SELECT
  ce.icustay_id,
  ce.charttime,
  ce.valuenum,
  ce.itemid
FROM chartevents ce
JOIN (
  {sql_first_vaso_sub}
) v ON v.icustay_id = ce.icustay_id
WHERE ce.valuenum IS NOT NULL
  AND ce.charttime BETWEEN v.first_vaso_time - interval '6 hour'
                       AND v.first_vaso_time + interval '6 hour'
  AND ce.itemid IN (456,52,51,455,211,220045);
"""

df_vitals = q(sql_vitals_around_vaso)
df_vitals.head()


Unnamed: 0,icustay_id,charttime,valuenum,itemid
0,249202,2144-07-01 11:10:00,127.0,220045
1,249202,2144-07-01 11:16:00,126.0,220045
2,249202,2144-07-01 13:51:00,81.0,220045
3,249202,2144-07-01 13:52:00,81.0,220045
4,249202,2144-07-01 13:53:00,82.0,220045


In [10]:
df_vitals = df_vitals.merge(
    df[["icustay_id", "first_vaso_time", "vaso_timing", "aki_stage_24h"]],
    on="icustay_id",
    how="left"
)

df_vitals["dt_hours"] = (
    pd.to_datetime(df_vitals["charttime"]) -
    pd.to_datetime(df_vitals["first_vaso_time"])
).dt.total_seconds() / 3600


In [11]:
def map_vital(itemid):
    if itemid in ITEMS_VITALS["map"]:
        return "MAP"
    if itemid in ITEMS_VITALS["sbp"]:
        return "SBP"
    if itemid in ITEMS_VITALS["hr"]:
        return "HR"
    return None

df_vitals["vital"] = df_vitals["itemid"].apply(map_vital)
df_vitals = df_vitals.dropna(subset=["vital"])


In [12]:
def summarize_pre_post(df, t_pre=(-2,0), t_post=(0,2)):
    out = []
    for (icu, vital), d in df.groupby(["icustay_id", "vital"]):
        pre = d[(d["dt_hours"]>=t_pre[0]) & (d["dt_hours"]<t_pre[1])]["valuenum"]
        post = d[(d["dt_hours"]>=t_post[0]) & (d["dt_hours"]<=t_post[1])]["valuenum"]
        if len(pre)>=2 and len(post)>=2:
            out.append({
                "icustay_id": icu,
                "vital": vital,
                "pre_mean": pre.mean(),
                "post_mean": post.mean(),
                "delta": post.mean() - pre.mean()
            })
    return pd.DataFrame(out)

df_delta = summarize_pre_post(df_vitals)
df_delta.head()


Unnamed: 0,icustay_id,vital,pre_mean,post_mean,delta
0,200024,HR,119.666667,112.333333,-7.333333
1,200063,HR,106.083333,97.5,-8.583333
2,200095,HR,116.25,63.25,-53.0
3,200116,HR,48.0,60.0,12.0
4,200143,HR,106.0,99.0,-7.0


In [13]:
df_delta = df_delta.merge(
    df[["icustay_id", "vaso_timing", "aki_stage_24h"]],
    on="icustay_id",
    how="left"
)

df_delta["timing_label"] = df_delta["vaso_timing"].map({
    1: "early",
    2: "delayed"
})

summary = (
    df_delta
    .groupby(["vital", "timing_label"])
    .agg(
        n=("delta", "count"),
        delta_mean=("delta", "mean"),
        delta_median=("delta", "median")
    )
)

summary


Unnamed: 0_level_0,Unnamed: 1_level_0,n,delta_mean,delta_median
vital,timing_label,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
HR,delayed,340,-1.410956,-1.0
HR,early,754,-2.448791,-1.816667
MAP,delayed,1,16.166667,16.166667
SBP,delayed,1,27.166667,27.166667


In [9]:
sql_vitals_around_vaso = f"""
SELECT
  ce.icustay_id,
  ce.charttime,
  ce.valuenum,
  ce.itemid
FROM chartevents ce
JOIN (
  {sql_first_vaso_sub}
) v ON v.icustay_id = ce.icustay_id
WHERE ce.valuenum IS NOT NULL
  AND ce.charttime BETWEEN v.first_vaso_time - interval '6 hour'
                       AND v.first_vaso_time + interval '6 hour'
  AND ce.itemid IN (456,52,51,455,211,220045);
"""

df_vitals = q(sql_vitals_around_vaso)
df_vitals.head()


Unnamed: 0,icustay_id,charttime,valuenum,itemid
0,249202,2144-07-01 11:10:00,127.0,220045
1,249202,2144-07-01 11:16:00,126.0,220045
2,249202,2144-07-01 13:51:00,81.0,220045
3,249202,2144-07-01 13:52:00,81.0,220045
4,249202,2144-07-01 13:53:00,82.0,220045


In [14]:
sql_esrd_flags = """
WITH dx AS (
  SELECT DISTINCT hadm_id, icd9_code
  FROM diagnoses_icd
)
SELECT
  hadm_id,
  MAX(CASE
        WHEN icd9_code IN ('5856') THEN 1  -- ESRD
        WHEN icd9_code LIKE 'V451%' THEN 1 -- renal dialysis status
        WHEN icd9_code LIKE 'V56%' THEN 1  -- dialysis encounter/care
        ELSE 0
      END) AS esrd_or_chronic_dialysis_flag
FROM dx
GROUP BY hadm_id;
"""
df_esrd = q(sql_esrd_flags)
df_esrd.head()


Unnamed: 0,hadm_id,esrd_or_chronic_dialysis_flag
0,155703,0
1,130406,0
2,194150,0
3,117336,0
4,151812,0


In [15]:
sql_rrt_late = """
WITH rrt_events AS (
  SELECT
    i.icustay_id,
    i.hadm_id,
    pe.starttime AS rrt_time
  FROM icustays i
  JOIN procedureevents_mv pe ON pe.icustay_id = i.icustay_id
  JOIN d_items di ON di.itemid = pe.itemid
  WHERE
    LOWER(COALESCE(di.label,'')) LIKE '%dialysis%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%crrt%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%cvvh%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%hemofiltration%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%hemodialysis%'
),
rrt_summary AS (
  SELECT
    r.hadm_id,
    MIN(r.rrt_time) AS first_rrt_time,
    MAX(r.rrt_time) AS last_rrt_time,
    COUNT(*) AS n_rrt_events
  FROM rrt_events r
  GROUP BY r.hadm_id
),
disch AS (
  SELECT hadm_id, dischtime
  FROM admissions
)
SELECT
  s.hadm_id,
  s.first_rrt_time,
  s.last_rrt_time,
  s.n_rrt_events,
  d.dischtime,
  CASE
    WHEN s.last_rrt_time IS NOT NULL
     AND s.last_rrt_time >= d.dischtime - interval '48 hour'
    THEN 1 ELSE 0
  END AS late_rrt_within_48h_of_discharge
FROM rrt_summary s
JOIN disch d ON d.hadm_id = s.hadm_id;
"""
df_rrt = q(sql_rrt_late)
df_rrt.head()


Unnamed: 0,hadm_id,first_rrt_time,last_rrt_time,n_rrt_events,dischtime,late_rrt_within_48h_of_discharge
0,182383,2121-11-30 21:58:00,2121-12-01 09:15:00,2,2121-12-05 14:18:00,0
1,174162,2122-05-14 21:00:00,2122-05-16 13:05:00,2,2122-05-18 15:11:00,0
2,131345,2141-09-06 08:00:00,2141-09-06 08:00:00,1,2141-09-08 18:30:00,0
3,172335,2141-09-21 18:00:00,2141-09-22 18:00:00,2,2141-09-24 13:53:00,1
4,126055,2141-10-15 19:00:00,2141-10-24 18:45:00,20,2141-11-03 18:45:00,0


In [16]:
# Merge ESRD flag + RRT timing summary
tmp = df[["hadm_id"]].drop_duplicates().merge(df_esrd, on="hadm_id", how="left")
tmp = tmp.merge(df_rrt, on="hadm_id", how="left")

tmp["esrd_or_chronic_dialysis_flag"] = tmp["esrd_or_chronic_dialysis_flag"].fillna(0).astype(int)
tmp["late_rrt_within_48h_of_discharge"] = tmp["late_rrt_within_48h_of_discharge"].fillna(0).astype(int)

# Any RRT during hospitalization?
tmp["any_rrt"] = tmp["last_rrt_time"].notna().astype(int)

# Outcome: dialysis-dependent at discharge (proxy)
tmp["dialysis_dependent_at_discharge"] = (
    (tmp["any_rrt"] == 1) &
    (
        (tmp["late_rrt_within_48h_of_discharge"] == 1) |
        (tmp["esrd_or_chronic_dialysis_flag"] == 1)
    )
).astype(int)

tmp[["any_rrt","late_rrt_within_48h_of_discharge","esrd_or_chronic_dialysis_flag","dialysis_dependent_at_discharge"]].mean()


any_rrt                             0.138385
late_rrt_within_48h_of_discharge    0.039927
esrd_or_chronic_dialysis_flag       0.024047
dialysis_dependent_at_discharge     0.055354
dtype: float64

In [17]:
df = df.merge(
    tmp[["hadm_id","dialysis_dependent_at_discharge","any_rrt","late_rrt_within_48h_of_discharge","esrd_or_chronic_dialysis_flag"]],
    on="hadm_id",
    how="left"
)

df["dialysis_dependent_at_discharge"] = df["dialysis_dependent_at_discharge"].fillna(0).astype(int)

df["dialysis_dependent_at_discharge"].value_counts(dropna=False), df["dialysis_dependent_at_discharge"].mean()


(dialysis_dependent_at_discharge
 0    2183
 1     133
 Name: count, dtype: int64,
 np.float64(0.057426597582038))

In [19]:
df.loc[df["dialysis_dependent_at_discharge"]==1, [
    "aki_stage_24h", "any_rrt", "late_rrt_within_48h_of_discharge",
    "esrd_or_chronic_dialysis_flag"
]].mean()


aki_stage_24h                       1.165414
any_rrt                             1.000000
late_rrt_within_48h_of_discharge    0.729323
esrd_or_chronic_dialysis_flag       0.375940
dtype: float64

In [20]:
df_dep = df[df["dialysis_dependent_at_discharge"] == 1].copy()
print("N dialysis-dependent at discharge:", len(df_dep))


N dialysis-dependent at discharge: 133


In [22]:
df_dep["vaso_timing_label"] = df_dep["vaso_timing"].map({
    0: "no_vasopressors_24h",
    1: "early_<=6h",
    2: "delayed_6-24h"
})

df_dep["vaso_timing_label"].value_counts(normalize=True).mul(100).round(1)


vaso_timing_label
early_<=6h             46.6
no_vasopressors_24h    34.6
delayed_6-24h          18.8
Name: proportion, dtype: float64

In [24]:
sql_intime = """
SELECT icustay_id, intime
FROM icustays
"""
df_intime = q(sql_intime)

df = df.merge(df_intime, on="icustay_id", how="left")

df[["icustay_id", "intime"]].head()


Unnamed: 0,icustay_id,intime
0,200024,2127-03-03 16:09:07
1,200063,2141-03-09 23:20:49
2,200095,2113-10-27 15:23:21
3,200116,2198-03-19 20:16:11
4,200143,2191-04-01 21:45:49


In [26]:
df["hours_to_vaso"] = (
    pd.to_datetime(df["first_vaso_time"]) -
    pd.to_datetime(df["intime"])
).dt.total_seconds() / 3600


In [27]:
df_dep["vaso_timing_label"] = df_dep["vaso_timing"].map({
    0: "no_vasopressors_24h",
    1: "early_<=6h",
    2: "delayed_6-24h"
})

df_dep["vaso_timing_label"].value_counts()
df_dep["vaso_timing_label"].value_counts(normalize=True).mul(100).round(1)



vaso_timing_label
early_<=6h             46.6
no_vasopressors_24h    34.6
delayed_6-24h          18.8
Name: proportion, dtype: float64

In [28]:
pd.crosstab(
    df["dialysis_dependent_at_discharge"],
    df["vaso_timing_label"],
    normalize="index"
).rename(
    index={0: "not_dialysis_dependent", 1: "dialysis_dependent"}
).round(3)


vaso_timing_label,delayed_6-24h,early_<=6h,none_24h
dialysis_dependent_at_discharge,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
not_dialysis_dependent,0.183,0.655,0.162
dialysis_dependent,0.188,0.466,0.346


In [29]:
baseline_vars = [
    "age",
    "female",
    "charlson",
    "sofa_approx_24h",
    "baseline_creat",
    "aki_stage_24h",
    "early_vent",
    "early_fluids"
]


In [30]:
import numpy as np
import pandas as pd

def smd(x_t, x_c):
    """Standardized Mean Difference"""
    m1, m0 = np.mean(x_t), np.mean(x_c)
    s1, s0 = np.var(x_t, ddof=1), np.var(x_c, ddof=1)
    return (m1 - m0) / np.sqrt((s1 + s0) / 2)

def weighted_mean(x, w):
    return np.sum(x * w) / np.sum(w)

def weighted_var(x, w):
    m = weighted_mean(x, w)
    return np.sum(w * (x - m)**2) / np.sum(w)

def weighted_smd(x, t, w):
    x1, w1 = x[t==1], w[t==1]
    x0, w0 = x[t==0], w[t==0]
    m1, m0 = weighted_mean(x1, w1), weighted_mean(x0, w0)
    v1, v0 = weighted_var(x1, w1), weighted_var(x0, w0)
    return (m1 - m0) / np.sqrt((v1 + v0) / 2)


In [None]:
df_bal = df[df["vaso_timing"].isin([1, 2])].copy()
df_bal["treat"] = (df_bal["vaso_timing"] == 1).astype(int)  # 1 = early


In [31]:
df_vaso = df[df["vaso_timing"].isin([1, 2])].copy()
df_vaso["vaso_timing_label"] = df_vaso["vaso_timing"].map({
    1: "early_<=6h",
    2: "delayed_6-24h"
})

df_vaso["vaso_timing_label"].value_counts()


vaso_timing_label
early_<=6h       1492
delayed_6-24h     425
Name: count, dtype: int64

In [32]:
df_pre = df_vitals[
    (df_vitals["dt_hours"] >= -1) &
    (df_vitals["dt_hours"] < 0)
].copy()


In [33]:
pre_vitals = (
    df_pre
    .groupby(["icustay_id", "vital"])
    .agg(pre_value=("valuenum", "median"))
    .reset_index()
)


In [35]:
pre_vitals_pivot = pre_vitals.pivot(
    index="icustay_id",
    columns="vital",
    values="pre_value"
).reset_index()


In [36]:
pre_state = pre_vitals_pivot.merge(
    df_vaso[[
        "icustay_id",
        "vaso_timing_label",
        "aki_stage_24h",
        "baseline_creat",
        "peak_creat_24h",
        "early_vent",
        "early_fluids"
    ]],
    on="icustay_id",
    how="left"
)


In [37]:
summary_state = (
    pre_state
    .groupby("vaso_timing_label")
    .agg(
        n=("icustay_id", "count"),
        HR=("HR", "median"),
        SBP=("SBP", "median"),
        MAP=("MAP_final", "median"),
        baseline_creat=("baseline_creat", "median"),
        peak_creat_24h=("peak_creat_24h", "median"),
        aki_stage=("aki_stage_24h", "median")
    )
)

summary_state


KeyError: "Column(s) ['MAP_final'] do not exist"

In [38]:
import numpy as np
import pandas as pd

# 1) Pre-window on raw vitals
df_pre = df_vitals[(df_vitals["dt_hours"] >= -1) & (df_vitals["dt_hours"] < 0)].copy()

# Ensure vital labels exist
# (Falls du das schon gemacht hast, schadet es nicht)
ITEM_MAP = {
    "MAP": [456, 52],
    "SBP": [51, 455],
    "DBP": [8368, 8441],
    "HR":  [211, 220045],
}
def map_vital(itemid):
    for k, v in ITEM_MAP.items():
        if itemid in v:
            return k
    return None

if "vital" not in df_pre.columns:
    df_pre["vital"] = df_pre["itemid"].apply(map_vital)

df_pre = df_pre.dropna(subset=["vital"])

# 2) Aggregate within the hour before start: median per vital
pre_vitals = (
    df_pre
    .groupby(["icustay_id", "vital"])
    .agg(pre_value=("valuenum", "median"))
    .reset_index()
)

pre_vitals_pivot = (
    pre_vitals
    .pivot(index="icustay_id", columns="vital", values="pre_value")
    .reset_index()
)

pre_vitals_pivot.head()


vital,icustay_id,HR,MAP,SBP
0,200024,119.0,,
1,200063,104.0,,
2,200095,119.0,,
3,200116,48.0,,
4,200143,104.0,,


In [40]:
# Was ist in df_vitals überhaupt drin?
df_vitals["itemid"].value_counts().head(30)


itemid
220045    117145
211          120
455           72
456           72
52            62
51            62
Name: count, dtype: int64

In [41]:
sql_labels = """
SELECT itemid, label
FROM d_items
WHERE itemid IN (456,52,51,455,8368,8441,211,220045)
ORDER BY itemid;
"""
q(sql_labels)


Unnamed: 0,itemid,label
0,51,Arterial BP [Systolic]
1,52,Arterial BP Mean
2,211,Heart Rate
3,455,NBP [Systolic]
4,456,NBP Mean
5,8368,Arterial BP [Diastolic]
6,8441,NBP [Diastolic]
7,220045,Heart Rate


In [42]:
sql_uo_items = """
SELECT DISTINCT itemid, label
FROM d_items
WHERE LOWER(label) LIKE '%urine%';
"""
q(sql_uo_items)


Unnamed: 0,itemid,label
0,706,Urine [Color]
1,707,Urine Source
2,941,urine culture
3,1011,urine osmolarity
4,1352,urine pH
...,...,...
151,226631,PACU Urine
152,227059,UrineScore_ApacheIV
153,227471,Specific Gravity (urine)
154,227489,GU Irrigant/Urine Volume Out


In [43]:
sql_uo_near_t0 = """
WITH vaso AS (
  SELECT
    ie.icustay_id,
    MIN(ie.starttime) AS first_vaso_time
  FROM inputevents_mv ie
  JOIN d_items di ON di.itemid = ie.itemid
  WHERE
    LOWER(COALESCE(di.label,'')) LIKE '%norepi%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%noradren%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%levophed%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%vasopress%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%epine%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%phenyleph%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%dopamine%'
    OR LOWER(COALESCE(di.label,'')) LIKE '%dobutamine%'
  GROUP BY ie.icustay_id
),
uo AS (
  SELECT
    v.icustay_id,
    SUM(CASE
          WHEN oe.charttime BETWEEN v.first_vaso_time - interval '6 hour'
                               AND v.first_vaso_time
          THEN oe.value ELSE 0 END) AS uo_pre_6h,
    SUM(CASE
          WHEN oe.charttime BETWEEN v.first_vaso_time
                               AND v.first_vaso_time + interval '6 hour'
          THEN oe.value ELSE 0 END) AS uo_post_6h
  FROM vaso v
  JOIN outputevents oe ON oe.icustay_id = v.icustay_id
  WHERE oe.value IS NOT NULL
  GROUP BY v.icustay_id
)
SELECT * FROM uo;
"""
df_uo = q(sql_uo_near_t0)
df_uo.head()


Unnamed: 0,icustay_id,uo_pre_6h,uo_post_6h
0,263381,0.0,2463.0
1,292890,120.0,2290.0
2,247054,0.0,1915.0
3,286920,0.0,1900.0
4,234799,630.0,265.0
