In [1]:
#!/usr/bin/env python
# demo_moe.py

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.metrics import brier_score_loss, roc_auc_score
import numpy as np, random, math, tqdm

# ---------- A. synthetic data -------------------------------------------------
def make_data(n=100_000, d=20, n_classes=4, seed=0):
    """Build a reproducible synthetic binary-classification dataset.

    The global ``torch``, ``numpy`` and ``random`` generators are all seeded
    with ``seed``. ``n`` rows of ``d`` standard-normal features are drawn,
    together with an integer "missile type" in ``[0, n_classes)`` stored as a
    float. Labels are Bernoulli samples of the sigmoid of a fixed linear
    logit.

    Returns:
        tuple: ``(X, y)`` where ``X`` has shape ``(n, d + 1)`` (the missile
        type is appended as the last column) and ``y`` has shape ``(n, 1)``
        holding 0./1. labels.
    """
    # Seed every generator the script imports so repeated calls agree.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    features = torch.randn(n, d)
    missile_type = torch.randint(0, n_classes, (n, 1)).float()

    # Only columns 0, 5, 11 and 17 plus the missile type drive the label;
    # the terms are accumulated left to right.
    signal = 1.2 * features[:, 0] - 0.8 * features[:, 5]
    signal = signal + 0.6 * features[:, 11]
    signal = signal - 1.4 * features[:, 17]
    signal = signal + 0.9 * missile_type.squeeze()

    labels = torch.bernoulli(torch.sigmoid(signal)).unsqueeze(1)

    # The missile type is appended after the d continuous features.
    design = torch.cat([features, missile_type], dim=1)
    return design, labels

# ---------- B. MH-MoE module --------------------------------------------------
class MHMoE(nn.Module):
    """Dense mixture-of-experts head that emits one logit per sample.

    A small gating MLP produces softmax weights over ``n_experts`` expert
    MLPs. Every expert is evaluated on every input and the scalar expert
    outputs are blended using the gate weights.
    """

    def __init__(self, in_dim, n_experts=3, hidden=80):
        """Create the gate and the experts.

        Args:
            in_dim: Width of the input feature vector.
            n_experts: Number of expert networks (and gate outputs).
            hidden: Width of each expert's first hidden layer.
        """
        super().__init__()
        # The gate is built before the experts so seeded initialisation
        # consumes random numbers in a fixed order.
        gate_layers = [
            nn.Linear(in_dim, 32),
            nn.ReLU(),
            nn.Linear(32, n_experts),
        ]
        self.gate = nn.Sequential(*gate_layers)

        expert_nets = []
        for _ in range(n_experts):
            expert_nets.append(self._build_expert(in_dim, hidden))
        self.experts = nn.ModuleList(expert_nets)

    @staticmethod
    def _build_expert(in_dim, hidden):
        """Return one expert: in_dim -> hidden -> 32 -> 1 with ReLU between."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 32), nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, x):
        """Blend the expert outputs with the gate weights.

        Args:
            x: Input batch of shape ``[B, in_dim]``.

        Returns:
            tuple: ``(y_hat, weights)`` with the blended logit ``[B, 1]`` and
            the softmax gate weights ``[B, E]``.
        """
        mixture = F.softmax(self.gate(x), dim=1)               # [B, E]
        per_expert = torch.cat(
            [expert(x) for expert in self.experts], dim=1)     # [B, E]
        blended = torch.sum(mixture * per_expert, dim=1, keepdim=True)
        return blended, mixture

# ---------- C. training loop --------------------------------------------------
def train(model, dl, opt):
    """Run one optimisation pass over every batch in ``dl``.

    Puts ``model`` in training mode, then for each ``(inputs, targets)``
    batch minimises the binary cross-entropy between the model's logits and
    the targets with a single ``opt`` step. Returns nothing.

    Args:
        model: Module whose forward returns ``(logits, extra)``.
        dl: Iterable of ``(inputs, targets)`` batches.
        opt: Optimizer holding the model's parameters.
    """
    model.train()
    for features, targets in dl:
        opt.zero_grad()
        # The second output (the gate weights) is not used for the loss.
        predictions, _unused = model(features)
        batch_loss = F.binary_cross_entropy_with_logits(predictions, targets)
        batch_loss.backward()
        opt.step()

@torch.no_grad()
def eval_metrics(model, dl, temp=1.0):
    """Score ``model`` on ``dl`` and return ``(accuracy, Brier, ROC AUC)``.

    Logits are divided by ``temp`` before the sigmoid, so passing a fitted
    temperature evaluates the temperature-scaled probabilities. The model is
    switched to eval mode and left there.

    Args:
        model: Module whose forward returns ``(logits, extra)``.
        dl: Iterable of ``(inputs, targets)`` batches.
        temp: Divisor applied to the logits before the sigmoid.

    Returns:
        tuple: ``(acc, brier, auc)`` where accuracy thresholds the
        probabilities at 0.5.
    """
    model.eval()
    targets, probs = [], []
    for xb, yb in dl:
        batch_logits, _unused = model(xb)
        targets.append(yb)
        probs.append(torch.sigmoid(batch_logits / temp))

    # Flatten to 1-D numpy arrays for the metric functions.
    y_true = torch.cat(targets).cpu().numpy().ravel()
    y_prob = torch.cat(probs).cpu().numpy().ravel()

    predicted_positive = y_prob > .5
    accuracy = (predicted_positive == y_true).mean()
    brier = brier_score_loss(y_true, y_prob)
    auc = roc_auc_score(y_true, y_prob)
    return accuracy, brier, auc

def fit_temperature(model, dl):
    """Fit a scalar temperature for post-hoc calibration of the logits.

    The model's logits and the labels are collected over ``dl`` without
    gradients, then the binary cross-entropy of ``logits / T`` is minimised
    with L-BFGS.

    Args:
        model: ``nn.Module`` whose forward returns ``(logits, extra)``.
        dl: Iterable of ``(inputs, targets)`` batches; targets must be valid
            for ``binary_cross_entropy_with_logits`` against the logits.

    Returns:
        float: The fitted temperature. It is optimised in log space, so it
        is positive by construction.

    Raises:
        ValueError: If ``dl`` yields no batches.
    """
    # Collect logits in eval mode (as eval_metrics does) so the fitted
    # temperature matches the logits it will later rescale; the caller's
    # train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    logit_list, y_list = [], []
    try:
        with torch.no_grad():
            for xb, yb in dl:
                logits, _ = model(xb)
                logit_list.append(logits)
                y_list.append(yb)
    finally:
        model.train(was_training)

    if not logit_list:
        raise ValueError("fit_temperature needs at least one batch in dl")
    logits = torch.cat(logit_list)
    y = torch.cat(y_list)

    # Optimise log(T) rather than T: the temperature can then never reach
    # zero or turn negative. It is created on the logits' device and dtype.
    log_t = torch.zeros(
        1, dtype=logits.dtype, device=logits.device, requires_grad=True)
    # A strong-Wolfe line search picks the step length, so convergence does
    # not depend on a hand-tuned fixed learning rate.
    opt = torch.optim.LBFGS(
        [log_t], lr=1.0, max_iter=50, line_search_fn="strong_wolfe")

    def closure():
        opt.zero_grad()
        # logits * exp(-log_t) == logits / T
        loss = F.binary_cross_entropy_with_logits(
            logits * torch.exp(-log_t), y)
        loss.backward()
        return loss

    opt.step(closure)
    return log_t.detach().exp().item()

# ---------- D. counterfactual search -----------------------------------------
def counterfactual(model, x, mask, lam=0.1, steps=40, lr=0.1):
    """Search for a nearby input that the model scores as the positive class.

    Starting from a copy of ``x``, runs ``steps`` Adam updates that minimise
    ``BCE(model(x_cf), 1) + lam * sum(|(x_cf - x) * mask|)`` and clamps every
    coordinate to ``[-4, 4]`` after each update.

    Args:
        model: Module whose forward returns ``(logit, extra)``.
        x: Input tensor to perturb; it is not modified.
        mask: Tensor broadcastable against ``x`` that weights the L1
            distance penalty per feature (a zero entry leaves that feature
            unpenalised).
        lam: Weight of the distance penalty.
        steps: Number of Adam steps.
        lr: Adam learning rate.

    Returns:
        torch.Tensor: The detached counterfactual, same shape as ``x``.
    """
    x_cf = x.clone().detach().requires_grad_(True)
    opt = torch.optim.Adam([x_cf], lr=lr)
    for _ in range(steps):
        logit, _ = model(x_cf)
        loss = F.binary_cross_entropy_with_logits(logit, torch.ones_like(logit))
        loss = loss + lam * ((x_cf - x) * mask).abs().sum()
        # Differentiate with respect to the input only. A plain
        # loss.backward() would also accumulate gradients into the model's
        # parameters, which this search never zeroes.
        (grad,) = torch.autograd.grad(loss, x_cf)
        x_cf.grad = grad
        opt.step()
        # Simple box constraint, applied in place without recording it in
        # the autograd graph.
        with torch.no_grad():
            x_cf.clamp_(-4, 4)
    return x_cf.detach()

# ---------- E. main -----------------------------------------------------------
def main():
    X, y = make_data()
    ds = TensorDataset(X, y)
    n_train = int(0.85 * len(ds))
    n_val = len(ds) - n_train
    train_ds, val_ds = random_split(ds, [n_train, n_val])
    train_dl = DataLoader(train_ds, batch_size=2048, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=4096)

    model = MHMoE(in_dim=X.shape[1])
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in tqdm.trange(15):
        train(model, train_dl, opt)

    acc, brier, auc = eval_metrics(model, val_dl)
    print(f"raw  : acc={acc:.3f}  Brier={brier:.3f}  AUC={auc:.3f}")

    T = fit_temperature(model, DataLoader(train_ds, batch_size=4096))
    acc, brier, auc = eval_metrics(model, val_dl, temp=T)
    print(f"calib: acc={acc:.3f}  Brier={brier:.3f}  AUC={auc:.3f}  T={T:.2f}")

    # counterfactual demo on first validation sample that failed
    xb, yb = val_ds[0]
    if yb.item() == 0:
        mask = torch.ones_lik_
