Sparse Mixture of Experts ---> Model 1

In [2]:
!pip install pyreadstat torch torchvision torchaudio scikit-learn --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/666.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━[0m [32m501.8/666.4 kB[0m [31m16.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m666.4/666.4 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [5]:
import pyreadstat

file_path = "/content/drive/MyDrive/LLCP2020.XPT"  # change path if needed

print("⏳ Loading XPT file...")
df, meta = pyreadstat.read_xport(file_path)

print("✅ Raw dataframe shape:", df.shape)
print("📋 First 50 columns:", df.columns[:50].tolist())


⏳ Loading XPT file...
✅ Raw dataframe shape: (401958, 279)
📋 First 50 columns: ['_STATE', 'FMONTH', 'IDATE', 'IMONTH', 'IDAY', 'IYEAR', 'DISPCODE', 'SEQNO', '_PSU', 'CTELENM1', 'PVTRESD1', 'COLGHOUS', 'STATERE1', 'CELPHONE', 'LADULT1', 'COLGSEX', 'NUMADULT', 'LANDSEX', 'NUMMEN', 'NUMWOMEN', 'RESPSLCT', 'SAFETIME', 'CTELNUM1', 'CELLFON5', 'CADULT1', 'CELLSEX', 'PVTRESD3', 'CCLGHOUS', 'CSTATE1', 'LANDLINE', 'HHADULT', 'SEXVAR', 'GENHLTH', 'PHYSHLTH', 'MENTHLTH', 'POORHLTH', 'HLTHPLN1', 'PERSDOC2', 'MEDCOST', 'CHECKUP1', 'EXERANY2', 'SLEPTIM1', 'CVDINFR4', 'CVDCRHD4', 'CVDSTRK3', 'ASTHMA3', 'ASTHNOW', 'CHCSCNCR', 'CHCOCNCR', 'CHCCOPD2']


In [6]:
selected_cols = [
    "AGEG5YR","SEXVAR","BMI5","SMOKE100","SMOKDAY2","STOPSMK2","ALCDAY5","AVEDRNK2","DRNK3GE5","DIABETE4",
    "TOLDHI2","CVDCRHD4","CVDSTRK3","CHCCOPD1","HAVARTH3","CHCKIDNY","ADDEPEV3","MENTHLTH","PHYSHLTH","HLTHPLN1",
    "PERSDOC2","MEDCOST","CHECKUP1","EXERANY2","SLEPTIM1","FLUSHOT7","PNEUVAC4","HIVTST7","GENHLTH","POORHLTH",
    "QLACTLM2","USEEQUIP","DEAF","BLIND","DECIDE","DIFFWALK","DIFFDRES","DIFFALON","EMPLOY1","EDUCA",
    "INCOME2","MARITAL","RENTHOM1","VETERAN3","CHILDREN","HCVU651"
]

target_col = "CVDINFR4"  # heart attack label

# check which columns exist
missing = [c for c in selected_cols+[target_col] if c not in df.columns]
print("Missing columns:", missing)

# keep only if available
df = df[[c for c in selected_cols if c in df.columns] + [target_col]].copy()
df = df.rename(columns={target_col: "HeartAttack"})

print("✅ Subset shape:", df.shape)


Missing columns: ['AGEG5YR', 'BMI5', 'AVEDRNK2', 'TOLDHI2', 'CHCCOPD1', 'HAVARTH3', 'CHCKIDNY', 'QLACTLM2', 'USEEQUIP', 'HCVU651']
✅ Subset shape: (401958, 37)


In [7]:
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Replace BRFSS missing codes with NaN
df = df.replace({7: np.nan, 9: np.nan, 77: np.nan, 99: np.nan, 997: np.nan, 999: np.nan})

# Drop rows with NaN
df = df.dropna().reset_index(drop=True)

# Split features and target
X_df = df.drop(columns=['HeartAttack']).astype(np.float32)
y = df['HeartAttack'].astype(int).to_numpy()

# Scale features
scaler = MinMaxScaler()
X = scaler.fit_transform(X_df)

print("✅ Final X shape:", X.shape)
print("📊 Target distribution:", np.bincount(y))


✅ Final X shape: (5082, 36)
📊 Target distribution: [   0  293 4789]


In [11]:
# BRFSS coding: 1 = Yes, 2 = No
# We remap to 1 = HeartAttack, 0 = No HeartAttack
y = np.where(y == 1, 1, 0)

print("✅ Fixed target distribution:", np.bincount(y))


✅ Fixed target distribution: [4789  293]


In [12]:
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import balanced_accuracy_score, recall_score, roc_auc_score, roc_curve

def youdens_j_threshold(y_true, y_prob):
    fpr, tpr, thr = roc_curve(y_true, y_prob)
    j = tpr - fpr
    return thr[np.argmax(j)]

def fit_evaluate(model, X, y, n_splits=5, random_state=42):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    bal_accs, macro_recalls, recall1s, aucs, thrs = [], [], [], [], []
    for tr, te in skf.split(X, y):
        Xtr, Xte = X[tr], X[te]
        ytr, yte = y[tr], y[te]
        model.fit(Xtr, ytr)
        p = model.predict_proba(Xte)[:,1]
        thr = youdens_j_threshold(yte, p)
        thrs.append(thr)
        yhat = (p >= thr).astype(int)
        bal_accs.append(balanced_accuracy_score(yte, yhat))
        macro_recalls.append(recall_score(yte, yhat, average="macro", zero_division=0))
        recall1s.append(recall_score(yte, yhat, pos_label=1))
        aucs.append(roc_auc_score(yte, p))
    return {
        "Balanced Accuracy": np.mean(bal_accs),
        "Macro Recall": np.mean(macro_recalls),
        "Recall Class 1": np.mean(recall1s),
        "AUC": np.mean(aucs),
        "Youden Threshold": np.mean(thrs)
    }


In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.base import BaseEstimator, ClassifierMixin

class MoEClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, input_dim, n_experts=4, hidden=64, lr=1e-3, epochs=8, batch_size=4096, l2=1e-5, device=None):
        self.input_dim = input_dim
        self.n_experts = n_experts
        self.hidden = hidden
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.l2 = l2
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._build()

    def _build(self):
        self.gate = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden),
            nn.ReLU(),
            nn.Linear(self.hidden, self.n_experts)
        )
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.input_dim, self.hidden),
                nn.ReLU(),
                nn.Linear(self.hidden, self.hidden),
                nn.ReLU(),
                nn.Linear(self.hidden, 1)
            ) for _ in range(self.n_experts)
        ])
        self.crit = nn.BCEWithLogitsLoss()
        self.sigmoid = nn.Sigmoid()
        self.gate.to(self.device)
        self.experts.to(self.device)

    def fit(self, X, y):
        Xt = torch.tensor(X, dtype=torch.float32).to(self.device)
        yt = torch.tensor(y.reshape(-1,1), dtype=torch.float32).to(self.device)
        params = list(self.gate.parameters()) + list(self.experts.parameters())
        opt = optim.Adam(params, lr=self.lr, weight_decay=self.l2)

        n = Xt.shape[0]
        idx = torch.arange(n, device=self.device)

        for epoch in range(self.epochs):
            perm = idx[torch.randperm(n)]
            for i in range(0, n, self.batch_size):
                batch = perm[i:i+self.batch_size]
                xb = Xt[batch]; yb = yt[batch]
                opt.zero_grad()
                g_logits = self.gate(xb)
                g_prob = torch.softmax(g_logits, dim=1)
                exp_logits = torch.cat([e(xb) for e in self.experts], dim=1)
                mix_logits = (g_prob * exp_logits).sum(dim=1, keepdim=True)
                loss = self.crit(mix_logits, yb)
                loss.backward(); opt.step()
        return self

    def predict_proba(self, X):
        Xt = torch.tensor(X, dtype=torch.float32).to(self.device)
        with torch.no_grad():
            g_prob = torch.softmax(self.gate(Xt), dim=1)
            exp_logits = torch.cat([e(Xt) for e in self.experts], dim=1)
            mix_logits = (g_prob * exp_logits).sum(dim=1, keepdim=True)
            p1 = self.sigmoid(mix_logits).cpu().numpy().ravel()
        return np.vstack([1-p1, p1]).T


In [14]:
print("Using device:", "cuda" if torch.cuda.is_available() else "cpu")

moe = MoEClassifier(input_dim=X.shape[1], n_experts=4, hidden=64, lr=1e-3, epochs=8, batch_size=4096)
results = fit_evaluate(moe, X, y)

print("📊 Mixture-of-Experts Results:")
for k,v in results.items():
    print(f"{k}: {v:.4f}")


Using device: cpu
📊 Mixture-of-Experts Results:
Balanced Accuracy: 0.6704
Macro Recall: 0.6704
Recall Class 1: 0.6512
AUC: 0.6981
Youden Threshold: 0.1973


In [20]:
from sklearn.linear_model import LogisticRegression

# Logistic Regression (baseline)
log_reg = LogisticRegression(max_iter=1000, class_weight="balanced", solver="lbfgs")

results_lr = fit_evaluate(log_reg, X, y)

print("📊 Logistic Regression Results:")
for k,v in results_lr.items():
    print(f"{k}: {v:.4f}")


📊 Logistic Regression Results:
Balanced Accuracy: 0.7703
Macro Recall: 0.7703
Recall Class 1: 0.7304
AUC: 0.8244
Youden Threshold: 0.4627
