# 05 - Association Rules

In [11]:
import os, re, json, warnings, itertools
from pathlib import Path
import numpy as np
import pandas as pd

# Silence every warning for the rest of the session so cell output stays compact.
# NOTE(review): this also hides library warnings that may matter — confirm it is intended.
warnings.filterwarnings("ignore")
# Print NumPy floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)
# Let pandas display up to 120 columns before truncating a DataFrame.
pd.set_option("display.max_columns", 120)

In [None]:
# -----------------------------
# 0) Config / paths
# -----------------------------
# Root folder (already extracted) that is searched recursively for the diagnoses CSV.
DATA_DIR = r"D:/HealthAI Project/data"

# Folder that receives every artifact written by this notebook; created if missing.
ART = Path("./Models/Assoc_Models")
ART.mkdir(parents=True, exist_ok=True)

# Thresholds for the high-confidence search, plus the largest itemset size to mine.
MIN_CONFIDENCE_TARGET, MIN_LIFT_TARGET, MAX_LEN = 0.70, 1.50, 3

# Minimum-support values to try, strictest first; the search stops at the
# first one that yields usable high-confidence rules.
MIN_SUPPORT_GRID = (
    0.08, 0.06, 0.05, 0.04, 0.03, 0.02,
    0.015, 0.01, 0.008, 0.006, 0.005,
)

In [4]:
# -----------------------------
# 1) Find the diagnoses CSV (already extracted)
# -----------------------------
def find_diagnoses_file(base_dir: str) -> str:
    """Return the path of the best-matching diagnoses CSV under *base_dir*.

    The directory tree is walked recursively and file names are matched
    case-insensitively. Candidates are ranked by:

    1. pattern specificity -- a name with ``icd`` after ``diagnos`` (e.g.
       ``diagnoses_icd.csv``) beats a name that only matches the loose
       pattern (e.g. ``diagnoses.csv`` or the dictionary table
       ``d_icd_diagnoses.csv``);
    2. shorter full path (often the main copy);
    3. the path string itself, so the choice does not depend on the order in
       which ``os.walk`` happens to list entries.

    Raises:
        FileNotFoundError: if no file name matches any pattern.
    """
    # Ordered from most to least specific; the index of the first pattern
    # that matches is the candidate's rank.
    patterns = [
        re.compile(r"diagnos.*icd.*\.csv(?:\.gz)?$"),  # e.g., diagnoses_icd.csv(.gz)
        re.compile(r"diagnos.*\.csv(?:\.gz)?$"),       # e.g., diagnoses.csv(.gz)
    ]
    hits = []
    for root, _, files in os.walk(base_dir):
        for f in files:
            low = f.lower()
            rank = next((i for i, p in enumerate(patterns) if p.search(low)), None)
            if rank is None:
                continue
            path = os.path.join(root, f)
            hits.append((rank, len(path), path))
    if not hits:
        raise FileNotFoundError(
            f"Could not locate a diagnoses CSV (e.g., diagnoses_icd.csv) under: {base_dir}"
        )
    # Tuples compare element-wise, which is exactly the ranking described above.
    return min(hits)[2]

# Resolve the diagnoses file once and report which one was picked.
DIAG_PATH = find_diagnoses_file(DATA_DIR)
print(f"[INFO] Using diagnoses CSV: {DIAG_PATH}")

# low_memory=False makes pandas parse the file in one pass instead of in
# internal chunks, which avoids mixed-type inference within a column.
diagnoses = pd.read_csv(DIAG_PATH, low_memory=False)
print("[INFO] diagnoses shape:", diagnoses.shape)
# Only the first 12 column names are shown; the trailing "..." is printed unconditionally.
print("[INFO] diagnoses columns:", diagnoses.columns.tolist()[:12], "...")

[INFO] Using diagnoses CSV: D:/HealthAI Project/data\MIMIC IV\diagnoses_icd.csv
[INFO] diagnoses shape: (6364490, 5)
[INFO] diagnoses columns: ['subject_id', 'hadm_id', 'seq_num', 'icd_code', 'icd_version'] ...


In [5]:
# -----------------------------
# 2) Normalize columns
# -----------------------------
# Case-insensitive lookup: lower-cased column name -> name as spelled in the file.
cols = {name.lower(): name for name in diagnoses.columns}
SUBJ_COL = cols.get("subject_id")
# First available code column, in order of preference.
CODE_COL = next(
    (cols[key] for key in ("icd_code", "icd9_code", "icd10_code") if key in cols),
    None,
)
VER_COL = cols.get("icd_version")

if SUBJ_COL is None or CODE_COL is None:
    raise RuntimeError(f"Missing required columns. Found: {diagnoses.columns.tolist()}")

if VER_COL is None:
    # No version column: guess one version for the whole file from the first
    # 200 codes -- a majority starting with a letter is taken to mean ICD-10.
    sample_codes = diagnoses[CODE_COL].astype(str).head(200).tolist()
    guess_10 = len([c for c in sample_codes if re.match(r"^[A-Z]", str(c).strip(), re.I)])
    if guess_10 > (len(sample_codes) / 2.0):
        icd_version_default = 10
    else:
        icd_version_default = 9
    diagnoses["icd_version"] = icd_version_default
    VER_COL = "icd_version"

# Canonical columns used by every later cell: string codes and integer versions
# (versions that fail to parse as numbers fall back to 10).
diagnoses["icd_code"] = diagnoses[CODE_COL].astype(str)
diagnoses["icd_version"] = (
    pd.to_numeric(diagnoses[VER_COL], errors="coerce")
    .fillna(10)
    .astype(int)
)

In [6]:
# -----------------------------
# 3) ICD -> comorbidity mapping
# -----------------------------
# (label, prefixes) tables. Prefixes are matched against the upper-cased code
# with dots removed, so "Z72.0" is looked up as "Z720".
_ICD9_CONDITION_PREFIXES = (
    ("Hypertension", ("401", "402", "403", "404", "405")),
    ("Diabetes", ("250",)),
    ("Obesity", ("278",)),
    ("Hyperlipidemia", ("272",)),
    ("Coronary_Artery_Disease", ("410", "411", "412", "413", "414")),
    ("Heart_Failure", ("428",)),
    ("Chronic_Kidney_Disease", ("585",)),
    ("COPD", ("491", "492", "496")),
    ("Asthma", ("493",)),
    ("Depression", ("2962", "2963", "311")),
    ("Anemia", ("280", "281", "282", "283", "284", "285")),
    ("GERD", ("530",)),
    ("Tobacco_Use", ("3051",)),
)

_ICD10_CONDITION_PREFIXES = (
    ("Hypertension", ("I10", "I11", "I12", "I13", "I14", "I15", "I16")),
    ("Diabetes", ("E10", "E11", "E12", "E13", "E14")),
    ("Obesity", ("E66",)),
    ("Hyperlipidemia", ("E78",)),
    ("Coronary_Artery_Disease", ("I20", "I21", "I22", "I23", "I24", "I25")),
    ("Heart_Failure", ("I50",)),
    ("Chronic_Kidney_Disease", ("N18",)),
    ("COPD", ("J44",)),
    ("Asthma", ("J45",)),
    ("Depression", ("F32", "F33")),
    ("Anemia", tuple(f"D{n}" for n in range(50, 65))),  # D50..D64
    ("GERD", ("K21",)),
    # Only Z72.0 is tobacco use. The rest of Z72 covers other lifestyle
    # problems (exercise, diet, gambling, ...) and must not be matched.
    ("Tobacco_Use", ("F17", "Z720")),
)


def icd_to_conditions(code, version):
    """Return a set of comorbidity item labels from an ICD code (9/10).

    Args:
        code: ICD code; non-strings are converted with ``str``. Surrounding
            whitespace, letter case and dots are ignored.
        version: ICD version. ``9`` selects the ICD-9 table. Any other value
            selects the ICD-10 table, unless the code starts with a digit and
            has at least three characters, in which case it is still treated
            as ICD-9.

    Returns:
        A new ``set`` of label strings; empty when the code is blank or maps
        to none of the tracked conditions.
    """
    if not isinstance(code, str):
        code = str(code)
    code = code.strip().upper().replace('.', '')
    if not code:
        return set()

    looks_like_icd9 = code[:1].isdigit() and len(code) >= 3
    if version == 9 or looks_like_icd9:
        table = _ICD9_CONDITION_PREFIXES
    else:
        table = _ICD10_CONDITION_PREFIXES
    # str.startswith accepts a tuple, so each label needs a single call.
    return {label for label, prefixes in table if code.startswith(prefixes)}

In [7]:
# -----------------------------
# 4) Build patient-level transactions
# -----------------------------
baskets = {}  # subject_id -> set(items)

# Iterate the three needed columns in lockstep instead of DataFrame.iterrows(),
# which builds a Series per row and is very slow on millions of rows. The same
# (code, version) pair recurs constantly, so each distinct pair is mapped once.
_cond_cache = {}  # (icd_code, icd_version) -> set(items)
for sid, code, ver in zip(
    diagnoses[SUBJ_COL].tolist(),
    diagnoses["icd_code"].tolist(),
    diagnoses["icd_version"].tolist(),
):
    key = (code, ver)
    conds = _cond_cache.get(key)
    if conds is None:
        conds = _cond_cache[key] = icd_to_conditions(code, ver)
    if not conds:
        continue
    # update() copies the labels, so the cached set is never shared or mutated.
    baskets.setdefault(sid, set()).update(conds)

print(f"[INFO] Patients with ≥1 mapped comorbidity: {len(baskets)}")
if not baskets:
    raise RuntimeError("No comorbidity items were mapped from codes. Check mapping / columns.")

# Save transactions for transparency
tx_rows = [{"subject_id": sid, "items": ", ".join(sorted(items))} for sid, items in baskets.items()]
pd.DataFrame(tx_rows).to_csv(ART/"assoc_transactions.csv", index=False)

[INFO] Patients with ≥1 mapped comorbidity: 177633


In [8]:
# -----------------------------
# 5) One-hot encode transactions (boolean DataFrame)
# -----------------------------
from sklearn.preprocessing import MultiLabelBinarizer

# Freeze the patient order once so the matrix rows line up with `sids`.
sids = list(baskets)
item_lists = [sorted(baskets[sid]) for sid in sids]

# One boolean column per distinct item label; one row per patient.
mlb = MultiLabelBinarizer()
X_bool = pd.DataFrame(
    mlb.fit_transform(item_lists).astype(bool),
    index=sids,
    columns=list(mlb.classes_),
)

items = X_bool.columns.tolist()
N = len(X_bool)
print(f"[INFO] Basket matrix: {X_bool.shape} (patients x items)")

[INFO] Basket matrix: (177633, 13) (patients x items)


In [9]:
# -----------------------------
# 6) Mining helpers (mlxtend or fallback)
# -----------------------------
# Prefer mlxtend when it can be imported; otherwise the fallback miners
# defined in this cell are used. USE_MLXTEND records which path is active.
try:
    from mlxtend.frequent_patterns import apriori, association_rules
except Exception:
    USE_MLXTEND = False
    print("[INFO] mlxtend not found; using fast fallback Apriori (max_len<=3)")
else:
    USE_MLXTEND = True
    print("[INFO] Using mlxtend.frequent_patterns.apriori / association_rules")

def apriori_fallback_bool(B: pd.DataFrame, min_support=0.1, max_len=3):
    """Compact Apriori over a boolean one-hot DataFrame.

    Args:
        B: one row per transaction, one column per item; values are cast to bool.
        min_support: minimum fraction of rows that must contain an itemset.
        max_len: largest itemset size to generate.

    Returns:
        DataFrame with columns ``itemset`` (sorted tuple of column names) and
        ``support`` (float), ordered by descending support. When nothing
        reaches ``min_support`` the frame is empty but still has both columns.
    """
    Bv = B.values.astype(bool)
    names = list(B.columns)

    freq = []  # (frozenset of names, support) for every frequent itemset
    prev = []  # frequent itemsets of the previous size: (frozenset, column indices)

    # Size 1: individual columns.
    for j, name in enumerate(names):
        s = float(Bv[:, j].mean())
        if s >= min_support:
            iset = frozenset([name])
            prev.append((iset, (j,)))
            freq.append((iset, s))

    k = 2
    while k <= max_len and prev:
        # Candidates of size k: unions of two frequent (k-1)-itemsets.
        cand = {}
        for (iset_a, idx_a), (iset_b, idx_b) in itertools.combinations(prev, 2):
            iset = iset_a | iset_b
            if len(iset) == k:
                cand[iset] = tuple(sorted(set(idx_a) | set(idx_b)))

        prev = []
        for iset, idxs in cand.items():
            # A row supports the itemset when every member column is True.
            s = float(Bv[:, list(idxs)].all(axis=1).mean())
            if s >= min_support:
                prev.append((iset, idxs))
                freq.append((iset, s))
        k += 1

    # Explicit columns keep sort_values working when `freq` is empty.
    out = pd.DataFrame(
        [{"itemset": tuple(sorted(iset)), "support": s} for iset, s in freq],
        columns=["itemset", "support"],
    )
    return out.sort_values("support", ascending=False).reset_index(drop=True)

def rules_from_itemsets_fallback(freq_df: pd.DataFrame, min_conf=0.3):
    """Generate association rules from a frequent-itemset table.

    Args:
        freq_df: DataFrame with columns ``itemset`` (iterable of item labels)
            and ``support``.
        min_conf: minimum confidence for a rule to be kept.

    Returns:
        DataFrame with columns ``antecedents``, ``consequents`` (sorted
        tuples), ``support``, ``confidence`` and ``lift``, ordered by
        descending lift, confidence, support. When no rule qualifies the frame
        is empty but still has all five columns.
    """
    columns = ["antecedents", "consequents", "support", "confidence", "lift"]
    sup = {frozenset(t): s for t, s in zip(freq_df["itemset"], freq_df["support"])}

    rules = []
    for iset_tup, sXY in zip(freq_df["itemset"], freq_df["support"]):
        XU = frozenset(iset_tup)
        if len(XU) < 2:
            continue
        members = sorted(XU)
        # Every non-empty proper subset is tried as the antecedent.
        for r in range(1, len(members)):
            for combo in itertools.combinations(members, r):
                A = frozenset(combo)
                B = XU - A
                sX = sup.get(A)
                sY = sup.get(B)
                # Either side missing from the table means its support is unknown.
                if sX is None or sY is None:
                    continue
                # The 1e-12 terms guard against division by zero.
                conf = sXY / (sX + 1e-12)
                if conf < min_conf:
                    continue
                rules.append({
                    "antecedents": tuple(sorted(A)),
                    "consequents": tuple(sorted(B)),
                    "support": float(sXY),
                    "confidence": float(conf),
                    "lift": float(conf / (sY + 1e-12)),
                })

    # Explicit columns keep sort_values working when `rules` is empty.
    out = pd.DataFrame(rules, columns=columns)
    return out.sort_values(["lift", "confidence", "support"], ascending=False).reset_index(drop=True)

def mine_rules_once(X_bool: pd.DataFrame, min_support: float, max_len: int):
    """Mine itemsets and rules *one time* for a given support.

    Uses mlxtend when ``USE_MLXTEND`` is true, otherwise the fallback miners.

    Args:
        X_bool: one-hot transaction matrix (rows = transactions, columns = items).
        min_support: minimum itemset support.
        max_len: maximum itemset size.

    Returns:
        ``(freq_std, rules)``. ``freq_std`` has a ``support`` column and an
        ``itemset`` column of sorted tuples. ``rules`` has ``antecedents``,
        ``consequents`` (sorted tuples), ``support``, ``confidence`` and
        ``lift``, with no confidence filtering applied (threshold 0.0).
    """
    if not USE_MLXTEND:
        freq_std = apriori_fallback_bool(X_bool, min_support=min_support, max_len=max_len)
        rules = rules_from_itemsets_fallback(freq_std, min_conf=0.0)
        return freq_std, rules

    rule_cols = ["antecedents", "consequents", "support", "confidence", "lift"]

    # mlxtend expects a boolean one-hot frame; non-bool input makes it warn.
    freq = apriori(X_bool.astype(bool), min_support=min_support, use_colnames=True, max_len=max_len)

    # Standardize to 'itemset' column (tuple)
    freq_std = freq.copy()
    freq_std["itemset"] = freq_std["itemsets"].apply(lambda s: tuple(sorted(s)))
    freq_std = freq_std.drop(columns=["itemsets"])

    # association_rules raises ValueError on an empty itemset table. Return
    # empty frames instead so the caller's `.empty` checks can skip this support.
    if freq.empty:
        return freq_std, pd.DataFrame(columns=rule_cols)

    # Rules (we'll filter later)
    rules = association_rules(freq, metric="confidence", min_threshold=0.0)
    rules = rules[rule_cols].copy()
    for side in ("antecedents", "consequents"):
        rules[side] = rules[side].apply(lambda s: tuple(sorted(s)))
    return freq_std, rules

[INFO] Using mlxtend.frequent_patterns.apriori / association_rules


In [12]:
# -----------------------------
# 7) High-confidence search (conf ≥ MIN_CONFIDENCE_TARGET = 0.70, lift ≥ MIN_LIFT_TARGET = 1.50)
# -----------------------------
def mine_high_confidence_rules(
    X_bool: pd.DataFrame,
    min_support_grid=MIN_SUPPORT_GRID,
    max_len=MAX_LEN,
    min_conf=MIN_CONFIDENCE_TARGET,
    min_lift=MIN_LIFT_TARGET,
    top_k=25
):
    """Walk the support grid and stop at the first value that yields strong rules.

    A rule qualifies when its confidence is at least ``min_conf``, its lift is
    at least ``min_lift`` and its consequent is a single item.

    Returns a 4-tuple ``(min_support, itemsets, all_rules, strong_rules)``:

    * on success ``all_rules`` is ``None`` and ``strong_rules`` holds at most
      ``top_k`` rules sorted by confidence, lift, support (descending), with
      the extra columns ``conseq``, ``baseline_support`` and
      ``abs_risk_increase``;
    * when no grid value qualifies, the last grid value is mined again,
      ``all_rules`` holds every rule sorted by lift, confidence, support, and
      ``strong_rules`` is an empty DataFrame.
    """
    tried = []  # (min_support, n_itemsets, 0) for every grid value that failed

    for ms in min_support_grid:
        freq, rules = mine_rules_once(X_bool, min_support=ms, max_len=max_len)
        if freq.empty or rules.empty:
            tried.append((ms, len(freq), 0))
            continue

        # Filter to interpretable rules
        keep = (
            (rules["confidence"] >= min_conf)
            & (rules["lift"] >= min_lift)
            & (rules["antecedents"].map(len) >= 1)
            & (rules["consequents"].map(len) == 1)
        )
        strong = rules.loc[keep].copy()
        if strong.empty:
            tried.append((ms, len(freq), 0))
            continue

        # Baseline prevalence (support of 1-item consequent)
        prevalence = {
            next(iter(it)): float(sup)
            for it, sup in zip(freq["itemset"], freq["support"])
            if len(it) == 1
        }
        strong["conseq"] = strong["consequents"].map(lambda c: next(iter(c)))
        strong["baseline_support"] = strong["conseq"].map(prevalence).astype(float)
        # How far the rule's confidence exceeds the consequent's overall prevalence.
        strong["abs_risk_increase"] = strong["confidence"] - strong["baseline_support"]

        strong = (
            strong.sort_values(["confidence", "lift", "support"], ascending=False)
            .head(top_k)
            .reset_index(drop=True)
        )
        print(f"[OK] min_support={ms:.3f} → {len(strong)} high-confidence rules")
        return ms, freq, None, strong

    # Nothing in the grid produced a qualifying rule.
    print(f"[WARN] No rules at confidence ≥ {min_conf:.2f} & lift ≥ {min_lift:.2f}. "
          f"Consider lowering min_support or raising max_len.")
    print("Tried:", tried)
    # Fall back to the weakest grid point for general rules (no high-conf filter), so you still get outputs
    ms = min_support_grid[-1]
    freq, rules = mine_rules_once(X_bool, min_support=ms, max_len=max_len)
    rules_all = rules.copy().sort_values(["lift", "confidence", "support"], ascending=False)
    return ms, freq, rules_all, pd.DataFrame()

# NOTE: `rules_80` is a historical name; the threshold actually applied is
# MIN_CONFIDENCE_TARGET.
chosen_ms, freq_used, rules_all_fallback, rules_80 = mine_high_confidence_rules(
    X_bool,
    min_support_grid=MIN_SUPPORT_GRID,
    max_len=MAX_LEN,
    min_conf=MIN_CONFIDENCE_TARGET,
    min_lift=MIN_LIFT_TARGET,
    top_k=25
)

# If we found high-confidence rules, also produce a "general" rules table (for context)
if rules_80 is not None and not rules_80.empty:
    # The search returns only the filtered rules on success, so re-mine at the
    # chosen support and keep everything for the context table.
    freq_all, rules_all = mine_rules_once(X_bool, min_support=chosen_ms, max_len=MAX_LEN)
    rules_all = rules_all.sort_values(["lift","confidence","support"], ascending=False).reset_index(drop=True)
else:
    # No high-confidence set; keep whatever we had as the "all rules"
    freq_all, rules_all = freq_used, rules_all_fallback

print(f"[INFO] Frequent itemsets: {len(freq_used)} (at min_support={chosen_ms:.3f})")
print(f"[INFO] All-rules table:   {len(rules_all)} rows")
print(f"[INFO] High-conf (≥{MIN_CONFIDENCE_TARGET:.2f}, lift≥{MIN_LIFT_TARGET:.2f}): "
      f"{0 if rules_80 is None else len(rules_80)} rows")

[OK] min_support=0.080 → 13 high-confidence rules
[INFO] Frequent itemsets: 66 (at min_support=0.080)
[INFO] All-rules table:   190 rows
[INFO] High-conf (≥0.70, lift≥1.50): 13 rows


In [None]:
# -----------------------------
# 8) Pretty print & save
# -----------------------------
def tup2str(s):
    """Render an itemset as its members in sorted order, joined by ", "."""
    ordered = list(s)
    ordered.sort()
    return ", ".join(ordered)

# Save frequent itemsets
freq_out = freq_used.copy()
if "itemset" in freq_out.columns:
    freq_out["itemset"] = freq_out["itemset"].apply(tup2str)
freq_out.to_csv(ART/"frequent_itemsets.csv", index=False)

# Save all rules (context)
rules_all_out = rules_all.copy()
rules_all_out["antecedents"] = rules_all_out["antecedents"].apply(tup2str)
rules_all_out["consequents"] = rules_all_out["consequents"].apply(tup2str)
rules_all_out.to_csv(ART/"association_rules_all.csv", index=False)

# Names of the artifacts that exist in ART at this point. The list is extended
# as files are written, so the final report only names files that were saved.
saved_files = [
    "assoc_transactions.csv",
    "frequent_itemsets.csv",
    "association_rules_all.csv",
]

# Save high-confidence rules
if rules_80 is not None and not rules_80.empty:
    hc = rules_80.copy()
    hc["antecedents"] = hc["antecedents"].apply(tup2str)
    hc["consequents"] = hc["consequents"].apply(tup2str)
    # order columns nicely (a separate name, so the `cols` column map from
    # section 2 is not overwritten)
    hc_cols = ["antecedents","consequents","support","confidence","lift","baseline_support","abs_risk_increase"]
    hc = hc[hc_cols]
    hc.to_csv(ART/"high_confidence_rules.csv", index=False)
    saved_files.append("high_confidence_rules.csv")

    print("\n=== Top high-confidence rules (conf ≥ {:.2f}, lift ≥ {:.2f}) ===".format(
        MIN_CONFIDENCE_TARGET, MIN_LIFT_TARGET))
    print(hc.head(25).to_string(index=False,
                                formatters={
                                    "support": "{:.3f}".format,
                                    "confidence": "{:.3f}".format,
                                    "lift": "{:.2f}".format,
                                    "baseline_support": "{:.3f}".format,
                                    "abs_risk_increase": "{:.3f}".format
                                }))
else:
    print("\n[INFO] No high-confidence rules met the thresholds; "
          "see association_rules_all.csv for the best available rules.")

# Summary JSON
summary = {
    "data_dir_used": str(Path(DATA_DIR).resolve()),
    "diagnoses_csv_used": str(Path(DIAG_PATH).resolve()),
    "n_patients_baskets": int(N),
    "n_items": int(len(items)),
    "min_support_grid": list(MIN_SUPPORT_GRID),
    "chosen_min_support": float(chosen_ms),
    "min_confidence_target": float(MIN_CONFIDENCE_TARGET),
    "min_lift_target": float(MIN_LIFT_TARGET),
    "max_len": int(MAX_LEN),
    "n_frequent_itemsets": int(len(freq_used)),
    "n_rules_all": int(len(rules_all)),
    "n_rules_high_conf": 0 if (rules_80 is None) else int(len(rules_80)),
}
with open(ART/"summary.json","w",encoding="utf-8") as f:
    json.dump(summary, f, indent=2)
saved_files.append("summary.json")

print("\n=== Saved ===")
print({k: str((ART/k).resolve()) for k in saved_files})


=== Top high-confidence rules (conf ≥ 0.70, lift ≥ 1.50) ===
                                    antecedents             consequents support confidence lift baseline_support abs_risk_increase
         Chronic_Kidney_Disease, Hyperlipidemia            Hypertension   0.108      0.965 1.54            0.626             0.340
Chronic_Kidney_Disease, Coronary_Artery_Disease            Hypertension   0.083      0.958 1.53            0.626             0.333
                  Heart_Failure, Hyperlipidemia            Hypertension   0.111      0.940 1.50            0.626             0.315
              Coronary_Artery_Disease, Diabetes          Hyperlipidemia   0.087      0.808 1.78            0.454             0.354
          Coronary_Artery_Disease, Hypertension          Hyperlipidemia   0.173      0.781 1.72            0.454             0.327
                Anemia, Coronary_Artery_Disease          Hyperlipidemia   0.104      0.766 1.69            0.454             0.312
         Coronary_Art