In [None]:
# %% [markdown]
# MH-MoE | 100k samples | 50 epochs | combined loss/accuracy plot

# %% imports & setup
%matplotlib inline
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.metrics import roc_auc_score
import numpy as np, random, math
import matplotlib.pyplot as plt
from tqdm.auto import tqdm, trange

# Reproducibility: fix the seed of every global RNG before anything is drawn.
# torch's generator drives the data, weight init and DataLoader shuffling below;
# numpy and stdlib `random` are seeded as well for completeness.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# ---------- synthetic dataset ----------
def make_data(n=100_000, d=20, n_classes=4):
    """Generate a synthetic binary-classification dataset.

    Draws ``n`` rows of ``d`` standard-normal features plus an integer
    category in ``[0, n_classes)``. Each label is a Bernoulli sample of
    ``sigmoid(logit)``, where the logit is a fixed sparse linear rule on
    feature columns 0, 5, 11 and 17 plus ``0.9 * category``. ``d`` must
    therefore be at least 18.

    All randomness comes from torch's global generator, in the order
    features -> categories -> labels.

    Returns:
        ``(X, y)``: ``X`` has shape ``[n, d + 1]``, with the category
        appended as a float in the last column. ``y`` has shape
        ``[n, 1]`` and holds float 0/1 labels.
    """
    features = torch.randn(n, d)
    category = torch.randint(0, n_classes, (n, 1)).float()

    # Keep the exact left-to-right evaluation order of the ground-truth rule
    # so the label probabilities are bit-identical.
    sparse_part = (
        1.2 * features[:, 0] - .8 * features[:, 5]
        + .6 * features[:, 11] - 1.4 * features[:, 17]
    )
    logit = sparse_part + .9 * category.squeeze()

    labels = torch.bernoulli(torch.sigmoid(logit)).unsqueeze(1)
    inputs = torch.cat((features, category), dim=1)
    return inputs, labels

# ---------- MH-MoE ----------
class MHMoE(nn.Module):
    """Soft-gated mixture-of-experts model emitting one logit per row.

    A small gating MLP (``in_dim -> 64 -> n_experts``) scores every expert.
    The scores are softmax-normalised, and the output is the weighted sum of
    the experts' scalar outputs. Every expert is an MLP
    ``in_dim -> hidden -> hidden // 2 -> 1``.

    Args:
        in_dim: number of input features.
        n_experts: number of expert networks (and gate outputs).
        hidden: width of each expert's first hidden layer.
    """

    def __init__(self, in_dim, n_experts=5, hidden=128):
        super().__init__()
        half = hidden // 2

        # The gate is created before the experts, so parameter initialisation
        # consumes the RNG in the same order as before.
        self.gate = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, n_experts),
        )

        def make_expert():
            return nn.Sequential(
                nn.Linear(in_dim, hidden), nn.ReLU(),
                nn.Linear(hidden, half), nn.ReLU(),
                nn.Linear(half, 1),
            )

        self.experts = nn.ModuleList(make_expert() for _ in range(n_experts))

    def forward(self, x):
        # Shape notes assume a 2-D input x of shape [B, in_dim].
        gate_weights = self.gate(x).softmax(dim=1)                 # [B, E]
        expert_logits = torch.cat(
            tuple(expert(x) for expert in self.experts), dim=1
        )                                                           # [B, E]
        weighted = gate_weights * expert_logits
        return weighted.sum(dim=1, keepdim=True)                   # [B, 1]

# ---------- helpers ----------
@torch.no_grad()
def val_step(model, dl):
    """Evaluate *model* on every batch of *dl* without tracking gradients.

    Args:
        model: module mapping an input batch to logits shaped like the
            targets.
        dl: iterable of ``(xb, yb)`` batches, e.g. a ``DataLoader``. The
            targets are presumably 0/1 floats, since they feed
            BCE-with-logits.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the summed
        BCE-with-logits divided by the total number of target elements.
        ``accuracy`` is the fraction of elements whose prediction
        (``sigmoid(logit) > 0.5``) equals the target.

    Raises:
        ZeroDivisionError: if *dl* yields no target elements.

    The model's previous train/eval mode is restored on exit, so calling
    this from a training loop does not silently leave the model in eval
    mode.
    """
    was_training = model.training
    model.eval()
    tot, tot_loss, correct = 0, 0.0, 0
    try:
        for xb, yb in dl:
            logits = model(xb)
            # Sum (not mean) so the final division weights every element
            # equally, even when the last batch is smaller.
            loss = F.binary_cross_entropy_with_logits(logits, yb, reduction="sum")
            probs = torch.sigmoid(logits)
            tot_loss += loss.item()
            correct += ((probs > .5) == yb).sum().item()
            tot += yb.numel()
    finally:
        model.train(was_training)
    return tot_loss / tot, correct / tot

# ---------- training loop ----------
def _train_epoch(model, dl, opt):
    """Run one optimisation pass over *dl* and return the sample-weighted mean BCE loss."""
    model.train()
    running, seen = 0.0, 0
    for xb, yb in tqdm(dl, leave=False):
        opt.zero_grad()
        loss = F.binary_cross_entropy_with_logits(model(xb), yb)
        loss.backward()
        opt.step()

        # `loss` is a per-batch mean. Re-weight it by batch size so a short
        # final batch does not skew the epoch average.
        running += loss.item() * yb.size(0)
        seen += yb.size(0)
    return running / seen


def _plot_history(train_losses, val_losses, val_accs):
    """Plot train/val loss (left axis) and val accuracy (right axis) on one figure."""
    fig, ax1 = plt.subplots(figsize=(7, 4))
    ax1.plot(train_losses, label="train loss", color="tab:blue")
    ax1.plot(val_losses, label="val loss", color="tab:orange")
    ax1.set_xlabel("epoch")
    ax1.set_ylabel("BCE loss")

    ax2 = ax1.twinx()
    ax2.plot(val_accs, label="val accuracy", color="tab:green")
    ax2.set_ylabel("accuracy")
    ax2.set_ylim(0, 1)

    # Use a single legend on the twin axis. ax2 is rendered after ax1, so a
    # legend attached to ax1 can be painted over by ax2's accuracy line.
    handles1, labels1 = ax1.get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(handles1 + handles2, labels1 + labels2, loc="center right")

    ax1.set_title("MH-MoE | loss & accuracy")
    fig.tight_layout()
    plt.show()


def run_demo(epochs=50, batch=4096, lr=1e-3, gamma=.3, step=25):
    """Train an MHMoE on synthetic data, then plot and report the history.

    Generates the data with ``make_data()`` and splits it 80/20 into train
    and validation sets. Trains with Adam, decaying the learning rate by
    ``gamma`` every ``step`` epochs (StepLR). After each epoch it records
    the train loss, validation loss and validation accuracy. Finally it
    draws a combined loss/accuracy plot and prints the final validation
    accuracy (if at least one epoch ran).

    Args:
        epochs: number of training epochs (0 skips training entirely).
        batch: batch size for both loaders.
        lr: Adam learning rate.
        gamma: multiplicative LR decay factor for StepLR.
        step: StepLR period in epochs.
    """
    X, y = make_data()
    ds = TensorDataset(X, y)
    n_train = int(.8 * len(ds))
    train_ds, val_ds = random_split(ds, [n_train, len(ds) - n_train])
    train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=batch)

    model = MHMoE(in_dim=X.shape[1])
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=step, gamma=gamma)

    train_losses, val_losses, val_accs = [], [], []
    for _ in trange(epochs, desc="Epoch"):
        train_losses.append(_train_epoch(model, train_dl, opt))
        v_loss, v_acc = val_step(model, val_dl)
        val_losses.append(v_loss)
        val_accs.append(v_acc)
        sched.step()

    _plot_history(train_losses, val_losses, val_accs)
    # With epochs=0 the history is empty; indexing it would raise IndexError.
    if val_accs:
        print(f"Validation — acc {val_accs[-1]:.3f} ")

# Launch the demo. The default hyper-parameters are spelled out so they are easy to tweak.
run_demo(epochs=50, batch=4096, lr=1e-3, gamma=.3, step=25)
