# In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score

from config import Config
from utils import ensure_dirs, device
from models.moe import MoEClassifier

def entropy_regularization(weights, eps=1e-8):
    """Return the batch-averaged entropy (in nats) of the rows of *weights*.

    Each row is treated as a distribution over the second axis (the original
    comment gives the layout as [B, K]; presumably these are gate
    probabilities - confirm against the caller). ``eps`` is added inside the
    logarithm so that zero entries do not produce ``-inf``.
    """
    log_w = torch.log(weights + eps)
    # sum_k w * log(w) per row, then average those sums over the batch
    mean_row_sum = torch.sum(weights * log_w, dim=1).mean()
    return -mean_row_sum

def _predict(model, loader, dev):
    """Run *model* over *loader* in eval mode and collect hard predictions.

    The model is expected to return a ``(logits, gate_weights)`` pair; only the
    logits are used. Returns ``(pred, true)`` as concatenated numpy arrays in
    loader order.
    """
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for xb, yb in loader:
            logits, _ = model(xb.to(dev))
            preds.append(logits.argmax(dim=1).cpu().numpy())
            # labels never leave the CPU, so no .cpu() is needed here
            trues.append(yb.numpy())
    return np.concatenate(preds), np.concatenate(trues)


def train():
    """Train the MoE classifier, keep the best checkpoint, and report test metrics.

    Steps:
      1. Load the train/val/test feature and label arrays from
         ``cfg.features_dir`` and keep only the columns where
         ``best_mask.npy`` equals 1.
      2. Train ``MoEClassifier`` with Adam for ``cfg.moe_epochs`` epochs using
         cross-entropy plus ``cfg.entropy_reg`` times the negative gate entropy.
      3. After each epoch compute macro-F1 on the validation split and save the
         model's state dict to ``cfg.moe_ckpt`` whenever it improves.
      4. Reload the saved checkpoint and print accuracy and macro-F1 on the
         test split.

    Takes no arguments and returns nothing; all configuration comes from
    ``Config()``.
    """
    cfg = Config()
    dev = device()
    ensure_dirs(cfg.out_root/"checkpoints")

    Xtr = np.load(cfg.features_dir/"X_train.npy")
    ytr = np.load(cfg.features_dir/"y_train.npy")
    Xv  = np.load(cfg.features_dir/"X_val.npy")
    yv  = np.load(cfg.features_dir/"y_val.npy")
    Xt  = np.load(cfg.features_dir/"X_test.npy")
    yt  = np.load(cfg.features_dir/"y_test.npy")

    # Apply the stored feature-selection mask (1 = keep column) to every split.
    mask = np.load(cfg.features_dir/"best_mask.npy")
    keep = mask == 1
    Xtr = Xtr[:, keep]
    Xv  = Xv[:, keep]
    Xt  = Xt[:, keep]

    # Labels are taken to be 0-based integers, so the class count is max + 1
    # over all three splits.
    num_classes = int(max(ytr.max(), yv.max(), yt.max()) + 1)
    in_dim = Xtr.shape[1]

    tr_ds = TensorDataset(torch.tensor(Xtr, dtype=torch.float32), torch.tensor(ytr, dtype=torch.long))
    v_ds  = TensorDataset(torch.tensor(Xv, dtype=torch.float32), torch.tensor(yv, dtype=torch.long))
    te_ds = TensorDataset(torch.tensor(Xt, dtype=torch.float32), torch.tensor(yt, dtype=torch.long))

    tr_dl = DataLoader(tr_ds, batch_size=cfg.moe_batch_size, shuffle=True)
    v_dl  = DataLoader(v_ds, batch_size=cfg.moe_batch_size, shuffle=False)
    te_dl = DataLoader(te_ds, batch_size=cfg.moe_batch_size, shuffle=False)

    model = MoEClassifier(in_dim=in_dim, num_classes=num_classes, num_experts=cfg.num_experts).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=cfg.moe_lr)

    # Start below any achievable F1 so the first epoch always writes a checkpoint.
    best_f1 = -1

    for epoch in range(cfg.moe_epochs):
        model.train()
        pbar = tqdm(tr_dl, desc=f"MoE Epoch {epoch+1}/{cfg.moe_epochs}")
        for xb, yb in pbar:
            xb, yb = xb.to(dev), yb.to(dev)
            logits, weights = model(xb)

            loss_cls = F.cross_entropy(logits, yb)
            # Encourage a non-collapsing gate: maximizing entropy is the same
            # as minimizing its negative.
            loss_ent = -entropy_regularization(weights)

            loss = loss_cls + cfg.entropy_reg * loss_ent

            opt.zero_grad()
            loss.backward()
            opt.step()

            pbar.set_postfix({"loss": float(loss), "cls": float(loss_cls)})

        # Validate and keep the checkpoint with the best macro-F1 so far.
        val_pred, val_true = _predict(model, v_dl, dev)
        f1 = f1_score(val_true, val_pred, average="macro")

        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), cfg.moe_ckpt)

        print(f"Val F1: {f1:.4f} | Best: {best_f1:.4f}")

    # Test with the best validation checkpoint, not the last-epoch weights.
    model.load_state_dict(torch.load(cfg.moe_ckpt, map_location=dev))
    test_pred, test_true = _predict(model, te_dl, dev)
    print("Test Accuracy:", accuracy_score(test_true, test_pred))
    print("Test Macro F1:", f1_score(test_true, test_pred, average="macro"))
    print("Saved best MoE:", cfg.moe_ckpt)

# Run the full train/validate/test pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    train()
