# Light Model: Sparse Mixture-of-Experts (k-active tiny experts)

**Mục tiêu:** Giảm **FLOPs/latency** trong khi giữ chất lượng bằng cách chỉ kích hoạt **k** chuyên gia cho mỗi token/mẫu. Ở đây ta khảo sát **3 giải thuật**:

### 1) Switch-MoE (Top-1 hard routing)
- **Ý tưởng:** với mỗi mẫu, router chọn **1 expert duy nhất** (Top-1) → cực nhẹ khi suy luận.
- **Cơ chế:** `g = Linear(GAP(x))` → `argmax` → one-hot → chỉ tính 1 expert.  
- **Cân bằng tải:** auxiliary loss khuyến khích phân phối chọn expert gần **đồng đều**.

### 2) Noisy Top-k MoE (Shazeer-style)
- **Ý tưởng:** router **thêm nhiễu** vào logits để tăng khám phá, rồi **Top-k** (vd. k=2), **chuẩn hoá softmax trên top-k**.
- **Ưu điểm:** mềm dẻo hơn Top-1, thường **ổn định** hơn với dữ liệu đa dạng.
- **Cân bằng tải:** loss cân bằng (KL/entropy) lên **trung bình gate**.

### 3) Hash-Routed MoE (Fixed random router)
- **Ý tưởng:** **không học router**; dùng **phép chiếu ngẫu nhiên** cố định để ánh xạ mẫu → expert (gần với “hash”).  
- **Ưu điểm:** **siêu nhẹ** (không thêm params), latency thấp, dễ triển khai ở **edge**.
- **Nhược:** router không thích nghi → độ chính xác có thể thấp hơn router học được.

**Các so sánh sẽ báo cáo:** #Params, FLOPs, latency, Val@1/Val@5.  
**Checkpoint:** `checkpoints/lm_smoe_switch.pth`, `checkpoints/lm_smoe_noisy_k2.pth`, `checkpoints/lm_smoe_hash.pth`.


In [6]:
import os, math, time, random, itertools, json
from pathlib import Path
from typing import Tuple, List, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset, Sampler
from torchvision import datasets, transforms, models
from tqdm import tqdm

SEED = 1337
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Control quick demo vs fuller training
QUICKRUN = True  # set False for better accuracy (longer runs)

DATA_ROOT = "./data"
CKPT_DIR = Path("checkpoints")
CKPT_DIR.mkdir(parents=True, exist_ok=True)

# Simple helpers
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed); 
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def count_params(model: nn.Module):
    return sum(p.numel() for p in model.parameters())

def try_flops(model, input_size=(1,3,32,32)):
    # Attempt FLOPs via thop; return 'N/A' if unavailable.
    try:
        from thop import profile
        inp = torch.randn(*input_size).to(next(model.parameters()).device)
        macs, params = profile(model, inputs=(inp,), verbose=False)
        return int(macs*2)  # MACs*2 ~ FLOPs
    except Exception as e:
        return "N/A"

@torch.no_grad()
def benchmark_latency(model, input_size=(1,3,32,32), nwarm=20, niter=50, use_jit=False):
    model.eval()
    x = torch.randn(*input_size).to(next(model.parameters()).device)
    if use_jit:
        try:
            ts = torch.jit.trace(model, x)
            model = ts
        except Exception as e:
            pass
    # Warmup
    for _ in range(nwarm):
        _ = model(x)
    # Measure
    times = []
    for _ in range(niter):
        t0 = time.time()
        _ = model(x)
        times.append(time.time()-t0)
    return 1000*np.mean(times)  # ms

# Data transforms
MEAN = (0.4914, 0.4822, 0.4465)
STD  = (0.2023, 0.1994, 0.2010)
train_tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])
val_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

# CIFAR-10 dataset and splits
full_train = datasets.CIFAR10(DATA_ROOT, train=True, download=True, transform=train_tf)
test_set   = datasets.CIFAR10(DATA_ROOT, train=False, download=True, transform=val_tf)

# create a small validation split from train
indices = list(range(len(full_train)))
random.shuffle(indices)
val_ratio = 0.1
val_count = int(len(indices)*val_ratio)
val_idx   = indices[:val_count]
train_idx = indices[val_count:]
train_set = Subset(full_train, train_idx)
val_set   = Subset(datasets.CIFAR10(DATA_ROOT, train=True, download=False, transform=val_tf), val_idx)

BATCH = 128 if not QUICKRUN else 64
train_loader = DataLoader(train_set, batch_size=BATCH, shuffle=True, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_set, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

NUM_CLASSES = 10

def accuracy(logits, y, topk=(1,5)):
    maxk = max(topk)
    _, pred = logits.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(y.view(1,-1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append((correct_k.item()/y.size(0))*100.0)
    return res  # top1, top5

class AverageMeter:
    def __init__(self):
        self.reset()
    def reset(self):
        self.sum = 0.0; self.n = 0
    def update(self, val, n=1):
        self.sum += val*n; self.n += n
    @property
    def avg(self):
        return self.sum/max(1,self.n)


In [7]:
# ========= SHARED FOR MoE =========
import math, random, time, numpy as np, torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path

# ---- Env & folders ----
if 'device' not in globals():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if 'CKPT_DIR' not in globals():
    CKPT_DIR = Path("checkpoints"); CKPT_DIR.mkdir(exist_ok=True, parents=True)

# ---- Check loaders ----
needed = ['train_loader','val_loader','test_loader','NUM_CLASSES']
missing = [n for n in needed if n not in globals()]
if missing:
    raise RuntimeError(f"Thiếu biến: {missing}. Hãy chạy phần setup/dataloader trước.")

# ---- Utils ----
if 'AverageMeter' not in globals():
    class AverageMeter:
        def __init__(self): self.sum=0.0; self.n=0
        def update(self,v,n=1): self.sum+=v*n; self.n+=n
        @property
        def avg(self): return self.sum/max(1,self.n)

if 'accuracy' not in globals():
    def accuracy(logits, y, topk=(1,5)):
        maxk = max(topk)
        _, pred = logits.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(y.view(1,-1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append((correct_k.item()/y.size(0))*100.0)
        return res

if 'try_flops' not in globals():
    def try_flops(model, input_size=(1,3,32,32)):
        try:
            from thop import profile
            dummy = torch.randn(*input_size).to(next(model.parameters()).device)
            macs,_ = profile(model, inputs=(dummy,), verbose=False)
            return int(macs*2)
        except Exception:
            return "N/A"

if 'benchmark_latency' not in globals():
    @torch.no_grad()
    def benchmark_latency(model, input_size=(1,3,32,32), nwarm=10, niter=30):
        model.eval()
        x = torch.randn(*input_size).to(next(model.parameters()).device)
        for _ in range(nwarm): _ = model(x)
        ts=[]
        for _ in range(niter):
            t0=time.time(); _=model(x); ts.append(time.time()-t0)
        return 1000*np.mean(ts)

@torch.no_grad()
def eval_topk(model, loader, topk=(1,5)):
    model.eval()
    n, t1_sum, t5_sum = 0, 0, 0
    for x,y in loader:
        x,y = x.to(device), y.to(device)
        out = model(x)
        logits = out[0] if isinstance(out, tuple) else out
        _, pred = logits.topk(max(topk), 1, True, True)
        correct = pred.eq(y.view(-1,1))
        n += y.size(0)
        t1_sum += correct[:, :1].sum().item()
        t5_sum += correct[:, :5].sum().item()
    return 100.0*t1_sum/n, 100.0*t5_sum/n

def train_moe(model, label, epochs=10, lr=0.1, lb_weight=0.02, print_every=1):
    """Train + log mỗi epoch (Train loss/@k, Val@k, Test@k). Save best theo Val@1."""
    model = model.to(device)
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    best_val = -1.0
    ckpt = CKPT_DIR/f"{label}.pth"

    print(f"[Start] {label} | epochs={epochs} lr={lr} lb_w={lb_weight}")
    for ep in range(1, epochs+1):
        # ---- Train ----
        model.train()
        loss_meter = AverageMeter()
        tr_t1, tr_t5 = AverageMeter(), AverageMeter()
        for x,y in train_loader:
            x,y = x.to(device), y.to(device)
            opt.zero_grad()
            out = model(x)
            logits, lb = out if isinstance(out, tuple) else (out, None)
            ce = F.cross_entropy(logits, y)
            loss = ce + (lb_weight*lb if lb is not None else 0.0)
            loss.backward(); opt.step()
            t1,t5 = accuracy(logits.detach(), y, topk=(1,5))
            loss_meter.update(loss.item(), x.size(0))
            tr_t1.update(t1, x.size(0)); tr_t5.update(t5, x.size(0))
        sch.step()

        # ---- Val & Test ----
        val_t1, val_t5 = eval_topk(model, val_loader, topk=(1,5))
        test_t1, test_t5 = eval_topk(model, test_loader, topk=(1,5))

        # ---- Save best on Val@1 ----
        is_best = val_t1 > best_val
        if is_best:
            best_val = val_t1
            torch.save({"model_state": model.state_dict(),
                        "meta":{"label":label,"val_top1":best_val}}, ckpt)

        if ep % print_every == 0:
            print(f"[{label}] Ep {ep:03d}/{epochs} | "
                  f"TrainLoss {loss_meter.avg:.4f} | Train@1 {tr_t1.avg:.2f} | Train@5 {tr_t5.avg:.2f} || "
                  f"Val@1 {val_t1:.2f} | Val@5 {val_t5:.2f} || "
                  f"Test@1 {test_t1:.2f} | Test@5 {test_t5:.2f} || "
                  f"Best Val@1 {best_val:.2f} {'*' if is_best else ''}")

    # ---- Summary ----
    params = sum(p.numel() for p in model.parameters())
    flops  = try_flops(model)
    lat    = benchmark_latency(model, niter=20)
    print(f"[Done:{label}] Params:{params:,} | FLOPs:{flops} | Lat(ms):{lat:.2f} | Best Val@1:{best_val:.2f} | ckpt={ckpt}")
    return ckpt

# ---- Tiny expert shared by all MoE variants ----
class DWConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.dw = nn.Conv2d(in_ch, in_ch, 3, padding=1, groups=in_ch, bias=False)
        self.pw = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)
    def forward(self, x): return self.act(self.bn(self.pw(self.dw(x))))

class TinyExpert(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.block = nn.Sequential(DWConvBlock(ch, ch), DWConvBlock(ch, ch))
    def forward(self, x): return self.block(x)




## Ký hiệu

* Số expert: $E$.
* Số expert hoạt động/token: $k \le E$.
* Chiều ẩn đầu vào: $d$.
* Chiều ẩn FFN của **mỗi expert**: $d_{\text{ff}}$.
* Ma trận router: $W_g \in \mathbb{R}^{d \times E}$, bias $b_g \in \mathbb{R}^{E}$.
* Đầu vào một token: $x \in \mathbb{R}^{d}$.
* Logits router: $z = W_g^\top x + b_g \in \mathbb{R}^{E}$.
* Softmax: $\operatorname{softmax}(z)_i = \dfrac{e^{z_i}}{\sum_{j=1}^E e^{z_j}}$.
* Expert $i$: hàm $F_i: \mathbb{R}^{d} \to \mathbb{R}^{d}$ (ví dụ FFN hai lớp).
* Hệ số capacity: $c \ge 1$; **capacity/ expert**:  
  $$
  C = \left\lceil c \cdot \frac{N_{\text{tokens}} \cdot k}{E} \right\rceil.
  $$

**MoE layer (dạng tổng quát).** Gọi $S \in \{0,1\}^{N_{\text{tokens}} \times E}$ là ma trận **dispatch** (token→expert) và $G \in \mathbb{R}_{\ge 0}^{N_{\text{tokens}} \times E}$ là ma trận **kết hợp** (weight sau gate, chỉ khác 0 ở top-$k$). Với token $t$:

$$
\tilde{y}_t = \sum_{i=1}^{E} S_{t,i} \cdot F_i(x_t), \quad 
y_t = x_t + \sum_{i=1}^E G_{t,i} \cdot F_i(x_t).
$$

Ràng buộc capacity: mỗi cột $i$ của $S$ có nhiều nhất $C$ phần tử bằng 1 (token vượt ngưỡng có thể **drop** hoặc đi qua **residual**).

---

## 1) Switch-MoE (Top-1, hard routing)

**Router.** Tính $z = W_g^\top x + b_g$. Với nhiệt độ $\tau>0$:

$$
\pi = \operatorname{softmax}(z/\tau), \quad 
i^* = \operatorname*{argmax}_{i \in [E]} z_i, \quad 
S_{t,i} = \mathbb{1}[i=i^*].
$$

**Kết hợp.** $k=1$ nên $G_{t,i}=\mathbb{1}[i=i^*]$ (hoặc $G_{t,i}=\pi_i \cdot \mathbb{1}[i=i^*]$ nếu muốn scale theo xác suất).

**Đầu ra.** $y_t = x_t + F_{i^*}(x_t)$.

**Auxiliary load-balancing loss.** Gọi:

* $\text{load}_i = \dfrac{1}{N_{\text{tokens}}} \sum_t \mathbb{1}[i=i^*(t)]$ (tỉ lệ **token** định tuyến tới expert $i$),
* $\text{imp}_i = \dfrac{1}{N_{\text{tokens}}} \sum_t \pi_i(t)$ (tỉ lệ **khối xác suất** router dành cho expert $i$).

Hàm phạt cân bằng (dạng Switch) với hệ số $\lambda_{\text{bal}}$:

$$
\mathcal{L}_{\text{bal}} = \lambda_{\text{bal}} \cdot E \cdot \sum_{i=1}^{E} \text{load}_i \, \text{imp}_i.
$$

**Z-loss (ổn định router).**

$$
\mathcal{L}_{z} = \lambda_{z} \cdot \frac{1}{N_{\text{tokens}}} \sum_t \big( \operatorname{logsumexp}(z(t)) \big)^2.
$$

---

## Loss tổng

$$
\mathcal{L} = \mathcal{L}_{\text{task}} + \mathcal{L}_{\text{bal}} + \mathcal{L}_{z}.
$$

---

## Đo đạc & đánh giá

* **#Params**: số tham số.
* **FLOPs**: /token hoặc /chuỗi.
* **Latency**: p50/p95/p99 trên phần cứng.
* **Val@1 / Val@5**:

$$
\text{Top-1} = \frac{1}{N}\sum_{n} \mathbb{1}[y_n^{\text{pred}} = y_n^{\text{true}}], \quad
\text{Top-5} = \frac{1}{N}\sum_{n} \mathbb{1}[y_n^{\text{true}} \in \text{Top5}(p_n)].
$$

* **Expert utilization entropy**:

$$
H(u) = -\sum_{i=1}^E u_i \log u_i.
$$

---

## Checkpoints

* `checkpoints/lm_smoe_switch.pth`
* `checkpoints/lm_smoe_noisy_k2.pth`
* `checkpoints/lm_smoe_hash.pth`


In [8]:
# ========= SWITCH-MoE =========
class SwitchMoE(nn.Module):
    def __init__(self, channels, num_experts=4, use_load_balance=True):
        super().__init__()
        self.E = num_experts
        self.use_lb = use_load_balance
        self.experts = nn.ModuleList([TinyExpert(channels) for _ in range(self.E)])
        self.gate = nn.Linear(channels, self.E)

    def forward(self, x):
        B,C,H,W = x.shape
        gap = F.adaptive_avg_pool2d(x,1).view(B,C)
        probs = F.softmax(self.gate(gap), dim=1)   # [B,E]
        top1 = probs.argmax(dim=1)                 # [B]
        onehot = F.one_hot(top1, num_classes=self.E).float()  # [B,E]
        out = 0.0
        for e_id, expert in enumerate(self.experts):
            out = out + expert(x) * onehot[:, e_id].view(B,1,1,1)
        lb = None
        if self.use_lb and self.training:
            mean_usage = probs.mean(dim=0)
            prior = torch.full_like(mean_usage, 1.0/self.E)
            lb = F.kl_div(mean_usage.log(), prior, reduction="batchmean")
        return out, lb

class TinyCNN_SwitchMoE(nn.Module):
    def __init__(self, num_classes=10, base_ch=32, num_experts=4, use_lb=True):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, base_ch, 3, padding=1, bias=False),
                                  nn.BatchNorm2d(base_ch), nn.ReLU(True))
        self.m1 = SwitchMoE(base_ch, num_experts=num_experts, use_load_balance=use_lb)
        self.down = nn.Conv2d(base_ch, base_ch*2, 3, 2, 1)
        self.m2 = SwitchMoE(base_ch*2, num_experts=num_experts, use_load_balance=use_lb)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(base_ch*2, num_classes)
    def forward(self, x):
        x, lb_sum = self.stem(x), 0.0
        x, lb1 = self.m1(x); x = F.relu(self.down(x))
        x, lb2 = self.m2(x)
        if lb1 is not None: lb_sum += lb1
        if lb2 is not None: lb_sum += lb2
        x = self.pool(x).flatten(1); logits = self.fc(x)
        return logits, (lb_sum if lb_sum!=0.0 else None)

# ---- Train Switch-MoE ----
EPOCHS = 30
model_sw = TinyCNN_SwitchMoE(num_classes=NUM_CLASSES, base_ch=(24 if 'QUICKRUN' in globals() and QUICKRUN else 32),
                             num_experts=4, use_lb=True)
ckpt_switch = train_moe(model_sw, label="lm_smoe_switch",
                        epochs=EPOCHS, lr=(0.05 if 'QUICKRUN' in globals() and QUICKRUN else 0.1), lb_weight=0.02)


[Start] lm_smoe_switch | epochs=30 lr=0.05 lb_w=0.02
[lm_smoe_switch] Ep 001/30 | TrainLoss 1.8189 | Train@1 30.29 | Train@5 84.26 || Val@1 36.60 | Val@5 89.12 || Test@1 37.27 | Test@5 89.39 || Best Val@1 36.60 *
[lm_smoe_switch] Ep 002/30 | TrainLoss 1.6079 | Train@1 39.44 | Train@5 90.01 || Val@1 39.36 | Val@5 89.92 || Test@1 38.43 | Test@5 90.41 || Best Val@1 39.36 *
[lm_smoe_switch] Ep 003/30 | TrainLoss 1.4892 | Train@1 44.81 | Train@5 91.99 || Val@1 42.62 | Val@5 92.30 || Test@1 43.43 | Test@5 92.14 || Best Val@1 42.62 *
[lm_smoe_switch] Ep 004/30 | TrainLoss 1.3850 | Train@1 49.21 | Train@5 93.24 || Val@1 48.98 | Val@5 93.02 || Test@1 48.12 | Test@5 92.87 || Best Val@1 48.98 *
[lm_smoe_switch] Ep 005/30 | TrainLoss 1.3239 | Train@1 52.10 | Train@5 93.80 || Val@1 48.18 | Val@5 93.54 || Test@1 48.72 | Test@5 93.95 || Best Val@1 48.98 
[lm_smoe_switch] Ep 006/30 | TrainLoss 1.2745 | Train@1 54.01 | Train@5 94.09 || Val@1 43.92 | Val@5 90.04 || Test@1 44.55 | Test@5 89.86 || Best Va

---

## 2) Noisy Top-$k$ MoE (kiểu Shazeer)

**Router có nhiễu.** Với $\epsilon \sim \mathcal{N}(0,\sigma^2 I)$:

$$
\tilde{z} = z + \epsilon, \quad 
\mathcal{I}_k = \operatorname{TopK}(\tilde{z}, k).
$$

**Chuẩn hoá trên top-$k$.** Với nhiệt độ $\tau$:

$$
G_{t,i} = \begin{cases}
\dfrac{\exp(\tilde{z}_i/\tau)}{\sum_{j\in \mathcal{I}_k} \exp(\tilde{z}_j/\tau)} & i\in \mathcal{I}_k, \\
0 & \text{ngược lại.}
\end{cases}
$$

**Dispatch.** $S_{t,i}=\mathbb{1}[i\in \mathcal{I}_k]$ (sau khi áp dụng capacity $C$).

**Đầu ra.**

$$
y_t = x_t + \sum_{i\in\mathcal{I}_k} G_{t,i} \, F_i(x_t).
$$

**Cân bằng tải.** Với

$$
\text{load}_i = \frac{1}{N_{\text{tokens}}}\sum_t \mathbb{1}[i\in\mathcal{I}_k(t)], \quad
\text{imp}_i = \frac{1}{N_{\text{tokens}}}\sum_t G_{t,i},
$$

đặt

$$
\mathcal{L}_{\text{bal}} = \lambda_{\text{bal}} \cdot E \cdot \sum_{i=1}^{E} \text{load}_i \, \text{imp}_i.
$$

---

In [9]:
# ========= NOISY TOP-k MoE =========
class NoisyTopKMoE(nn.Module):
    def __init__(self, channels, num_experts=8, k=2, noise_std=1.0, use_load_balance=True):
        super().__init__()
        assert 1 <= k <= num_experts
        self.E, self.k, self.noise_std = num_experts, k, noise_std
        self.use_lb = use_load_balance
        self.experts = nn.ModuleList([TinyExpert(channels) for _ in range(self.E)])
        self.gate = nn.Linear(channels, self.E)

    def forward(self, x):
        B,C,H,W = x.shape
        gap = F.adaptive_avg_pool2d(x,1).view(B,C)
        logits = self.gate(gap)
        if self.training and self.noise_std>0:
            logits = logits + torch.randn_like(logits)*self.noise_std
        vals, idx = torch.topk(logits, self.k, dim=1)      # [B,k]
        w = F.softmax(vals, dim=1)                         # [B,k]
        out = 0.0
        for i in range(self.k):
            pick = idx[:, i]
            contrib = 0.0
            for e_id, expert in enumerate(self.experts):
                contrib = contrib + expert(x) * (pick==e_id).float().view(B,1,1,1)
            out = out + contrib * w[:, i].view(B,1,1,1)
        lb = None
        if self.use_lb and self.training:
            probs = F.softmax(logits, dim=1)
            mean_usage = probs.mean(dim=0)
            prior = torch.full_like(mean_usage, 1.0/self.E)
            lb = F.kl_div(mean_usage.log(), prior, reduction="batchmean")
        return out, lb

class TinyCNN_NoisyTopK(nn.Module):
    def __init__(self, num_classes=10, base_ch=32, num_experts=8, k=2, noise_std=1.0, use_lb=True):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, base_ch, 3, padding=1, bias=False),
                                  nn.BatchNorm2d(base_ch), nn.ReLU(True))
        self.m1 = NoisyTopKMoE(base_ch, num_experts=num_experts, k=k, noise_std=noise_std, use_load_balance=use_lb)
        self.down = nn.Conv2d(base_ch, base_ch*2, 3, 2, 1)
        self.m2 = NoisyTopKMoE(base_ch*2, num_experts=num_experts, k=k, noise_std=noise_std, use_load_balance=use_lb)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(base_ch*2, num_classes)
    def forward(self, x):
        x, lb_sum = self.stem(x), 0.0
        x, lb1 = self.m1(x); x = F.relu(self.down(x))
        x, lb2 = self.m2(x)
        if lb1 is not None: lb_sum += lb1
        if lb2 is not None: lb_sum += lb2
        x = self.pool(x).flatten(1); logits = self.fc(x)
        return logits, (lb_sum if lb_sum!=0.0 else None)

# ---- Train Noisy Top-k ----
model_nk = TinyCNN_NoisyTopK(num_classes=NUM_CLASSES, base_ch=(24 if 'QUICKRUN' in globals() and QUICKRUN else 32),
                             num_experts=8, k=2, noise_std=1.0, use_lb=True)
ckpt_noisy = train_moe(model_nk, label="lm_smoe_noisy_k2",
                       epochs=EPOCHS, lr=(0.05 if 'QUICKRUN' in globals() and QUICKRUN else 0.1), lb_weight=0.02)


[Start] lm_smoe_noisy_k2 | epochs=30 lr=0.05 lb_w=0.02
[lm_smoe_noisy_k2] Ep 001/30 | TrainLoss 1.8226 | Train@1 29.93 | Train@5 83.62 || Val@1 31.92 | Val@5 85.42 || Test@1 32.63 | Test@5 86.77 || Best Val@1 31.92 *
[lm_smoe_noisy_k2] Ep 002/30 | TrainLoss 1.4227 | Train@1 48.04 | Train@5 92.99 || Val@1 47.90 | Val@5 94.08 || Test@1 48.02 | Test@5 93.59 || Best Val@1 47.90 *
[lm_smoe_noisy_k2] Ep 003/30 | TrainLoss 1.2050 | Train@1 57.15 | Train@5 95.31 || Val@1 54.64 | Val@5 95.46 || Test@1 54.39 | Test@5 95.38 || Best Val@1 54.64 *
[lm_smoe_noisy_k2] Ep 004/30 | TrainLoss 1.1046 | Train@1 60.89 | Train@5 96.08 || Val@1 55.08 | Val@5 92.10 || Test@1 55.16 | Test@5 92.27 || Best Val@1 55.08 *
[lm_smoe_noisy_k2] Ep 005/30 | TrainLoss 1.0425 | Train@1 63.23 | Train@5 96.67 || Val@1 63.30 | Val@5 96.62 || Test@1 63.99 | Test@5 97.07 || Best Val@1 63.30 *
[lm_smoe_noisy_k2] Ep 006/30 | TrainLoss 0.9905 | Train@1 65.40 | Train@5 97.00 || Val@1 62.36 | Val@5 96.74 || Test@1 62.04 | Test@5 9

---

## 3) Hash-Routed MoE (router cố định/ngẫu nhiên)

**Router cố định.** Lấy phép chiếu ngẫu nhiên cố định $R \in \mathbb{R}^{d \times h}$. Với token $x$:

$$
q = R^\top x \in \mathbb{R}^{h}, \quad 
s = \operatorname{sign}(q) \in \{-1,+1\}^{h}.
$$

Ánh xạ băm tới chỉ số expert:

$$
\phi(s) = \Big(\sum_{m=1}^{h} \mathbb{1}[s_m=+1] \cdot 2^{m-1}\Big) \bmod E, \quad 
i^* = \phi(s).
$$

**Routing/Kết hợp.** $k=1$, $S_{t,i}=\mathbb{1}[i=i^*]$, $G_{t,i}=\mathbb{1}[i=i^*]$.

---


In [10]:
# ========= HASH-ROUTED MoE =========
class HashMoE(nn.Module):
    def __init__(self, channels, num_experts=4, k=1, seed=42):
        super().__init__()
        assert k==1
        self.E = num_experts
        g = torch.Generator(); g.manual_seed(seed)
        self.register_buffer('proj', torch.randn(channels, num_experts, generator=g))  # fixed router
        self.experts = nn.ModuleList([TinyExpert(channels) for _ in range(self.E)])

    def forward(self, x):
        B,C,H,W = x.shape
        gap = F.adaptive_avg_pool2d(x,1).view(B,C)
        top1 = (gap @ self.proj).argmax(dim=1)            # [B]
        onehot = F.one_hot(top1, num_classes=self.E).float()
        out = 0.0
        for e_id, expert in enumerate(self.experts):
            out = out + expert(x) * onehot[:, e_id].view(B,1,1,1)
        return out, None

class TinyCNN_HashMoE(nn.Module):
    def __init__(self, num_classes=10, base_ch=32, num_experts=4, seed=123):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, base_ch, 3, padding=1, bias=False),
                                  nn.BatchNorm2d(base_ch), nn.ReLU(True))
        self.m1 = HashMoE(base_ch, num_experts=num_experts, k=1, seed=seed)
        self.down = nn.Conv2d(base_ch, base_ch*2, 3, 2, 1)
        self.m2 = HashMoE(base_ch*2, num_experts=num_experts, k=1, seed=seed+1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(base_ch*2, num_classes)
    def forward(self, x):
        x, _ = self.m1(self.stem(x)); x = F.relu(self.down(x))
        x, _ = self.m2(x)
        x = self.pool(x).flatten(1)
        return self.fc(x), None

# ---- Train Hash-Routed ----
model_hash = TinyCNN_HashMoE(num_classes=NUM_CLASSES, base_ch=(24 if 'QUICKRUN' in globals() and QUICKRUN else 32),
                             num_experts=4, seed=123)
ckpt_hash = train_moe(model_hash, label="lm_smoe_hash",
                      epochs=EPOCHS, lr=(0.05 if 'QUICKRUN' in globals() and QUICKRUN else 0.1), lb_weight=0.0)


[Start] lm_smoe_hash | epochs=30 lr=0.05 lb_w=0.0
[lm_smoe_hash] Ep 001/30 | TrainLoss 1.7471 | Train@1 33.71 | Train@5 86.19 || Val@1 35.20 | Val@5 88.50 || Test@1 35.11 | Test@5 89.12 || Best Val@1 35.20 *
[lm_smoe_hash] Ep 002/30 | TrainLoss 1.4353 | Train@1 47.08 | Train@5 92.51 || Val@1 38.44 | Val@5 84.60 || Test@1 37.69 | Test@5 85.36 || Best Val@1 38.44 *
[lm_smoe_hash] Ep 003/30 | TrainLoss 1.2741 | Train@1 53.55 | Train@5 94.04 || Val@1 47.52 | Val@5 90.96 || Test@1 46.21 | Test@5 90.24 || Best Val@1 47.52 *
[lm_smoe_hash] Ep 004/30 | TrainLoss 1.1783 | Train@1 57.23 | Train@5 95.48 || Val@1 56.34 | Val@5 95.42 || Test@1 56.62 | Test@5 95.68 || Best Val@1 56.34 *
[lm_smoe_hash] Ep 005/30 | TrainLoss 1.1252 | Train@1 59.58 | Train@5 95.94 || Val@1 52.72 | Val@5 94.26 || Test@1 52.07 | Test@5 94.29 || Best Val@1 56.34 
[lm_smoe_hash] Ep 006/30 | TrainLoss 1.0851 | Train@1 61.27 | Train@5 96.25 || Val@1 52.54 | Val@5 93.74 || Test@1 52.56 | Test@5 93.29 || Best Val@1 56.34 
[lm_