In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm

import ERMoE as ermoe

# Device exposed by the ERMoE module; reused below for the heads, processor outputs and labels.
device = ermoe.device

  from .autonotebook import tqdm as notebook_tqdm
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
def collate_pil(batch):
    """Collate (PIL image, int label) pairs without tensorising the images.

    Images stay as PIL objects so that ermoe.processor can preprocess them later.

    Returns:
        (list of PIL images, int64 tensor of labels), both in batch order.
    """
    pil_images, targets = [], []
    for pil_image, target in batch:
        pil_images.append(pil_image)
        targets.append(target)
    return pil_images, torch.tensor(targets, dtype=torch.long)

batch_size = 200
num_workers = 4
# Raw PIL images (transform=None): preprocessing is left to ermoe.processor downstream.
train_ds, test_ds = (
    datasets.CIFAR10("./data", train=is_train, download=True, transform=None)
    for is_train in (True, False)
)
train_ld, test_ld = (
    DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
               pin_memory=True, collate_fn=collate_pil)
    for ds, shuffle in ((train_ds, True), (test_ds, False))
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
def build_heads(hidden_dim: int, num_classes: int, num_experts: int, device):
    """Create one linear classifier per expert.

    Each head is nn.Linear(hidden_dim, num_classes) with Xavier-uniform weights and a
    zero bias, moved to `device`.

    Returns:
        nn.ModuleList holding `num_experts` heads.
    """
    heads = nn.ModuleList(
        nn.Linear(hidden_dim, num_classes) for _ in range(num_experts)
    )
    # All layers are constructed first and re-initialised afterwards; keep this order
    # so seeded runs draw random numbers in the same sequence.
    for layer in heads:
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)
    return heads.to(device)

@torch.no_grad()
def extract_cls(expert: nn.Module, images):
    """Return the first-token (CLS) hidden state of `expert` for `images`.

    The images are preprocessed with ermoe.processor and moved to `device`. The
    function runs without autograd, so the returned features carry no graph.
    Assumes a HF-style output whose `.last_hidden_state` is (B, T, D); TODO confirm.
    """
    batch_inputs = ermoe.processor(images=images, return_tensors="pt")
    hidden_states = expert(**batch_inputs.to(device)).last_hidden_state
    return hidden_states[:, 0, :]

def combine_logits_from_selected(
    logits_per_expert, scores, selected_mask, temperature: float = 0.7
):
    """Mix per-expert logits with a temperature softmax over each sample's selected experts.

    Vectorised over the batch. The previous version looped over samples in Python,
    with a host sync (`.tolist()`) per sample.

    Args:
        logits_per_expert: sequence of per-expert logit tensors, each (B, C).
        scores: (B, E) router scores.
        selected_mask: (B, E) mask; any nonzero entry marks that expert as selected.
            A row with no selection falls back to its highest-scoring expert.
        temperature: divides the selected scores before the softmax.

    Returns:
        (B, C) tensor with the dtype/device of ``logits_per_expert[0]``.
    """
    B, _ = scores.shape
    C = logits_per_expert[0].shape[-1]
    base = logits_per_expert[0].new_zeros(B, C)

    # `!= 0` yields a fresh bool tensor, so the caller's mask is never mutated.
    mask = selected_mask.to(scores.device) != 0
    empty_rows = (~mask.any(dim=1)).nonzero(as_tuple=True)[0]
    mask[empty_rows, scores[empty_rows].argmax(dim=-1)] = True  # argmax fallback

    # Masked softmax: unselected experts get exp(-inf) = 0 weight. Every row has at
    # least one selected entry after the fallback, so no row is all -inf.
    weights = F.softmax(
        (scores / temperature).masked_fill(~mask, float("-inf")), dim=-1
    ).to(base.device)

    # Only experts selected by at least one sample enter the sum. Unselected experts
    # stay out of the autograd graph, so their head's .grad remains None and the
    # optimizer skips them, as in the per-sample implementation.
    mixed = base
    for ei in mask.any(dim=0).nonzero(as_tuple=True)[0].tolist():
        mixed = mixed + weights[:, ei, None] * logits_per_expert[ei]
    return mixed.to(base.dtype)

def count_trainable(m: nn.Module) -> int:
    """Total number of scalar parameters in `m` that have requires_grad=True."""
    total = 0
    for param in m.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [4]:
# Head input width = hidden size of ermoe.vit_base (assumed to match every expert; TODO confirm).
hidden_dim = ermoe.vit_base.config.hidden_size
num_experts = ermoe.num_experts
num_classes = 10  # CIFAR-10
heads = build_heads(hidden_dim, num_classes, num_experts, device)
# Explicitly mark every head parameter trainable; the heads are the only module given to the optimizer.
heads.requires_grad_(True)

optimizer = torch.optim.AdamW(heads.parameters(), lr=1e-3, weight_decay=0.0)
temperature = 0.7  # softmax temperature used when mixing the selected experts

print(f"trainable parameters (heads only): {count_trainable(heads):,}")

trainable parameters (heads only): 61,520


In [5]:
def _forward_mixed_logits(experts, heads, images, temperature):
    """Route `images`, score them with every expert's head, and mix the selected experts.

    Only `heads` participate in autograd. The router and the expert backbones run under
    `torch.no_grad()` because neither is passed to the optimizer in this notebook.

    NOTE(review): experts/router are never updated here, so if the experts run in eval
    mode their CLS features are identical every epoch; caching them once would remove
    most of the per-epoch cost. Confirm the experts' mode before doing so.
    """
    with torch.no_grad():
        scores, sel_mask = ermoe.moE_router_forward(images)
        # Preprocess once per batch rather than once per expert: the processor output
        # is the same for every expert, so repeating it was wasted CPU work.
        inputs = ermoe.processor(images=images, return_tensors="pt").to(device)
        cls_per_expert = [
            expert(**inputs).last_hidden_state[:, 0, :] for expert in experts
        ]

    logits_per_expert = [heads[i](cls) for i, cls in enumerate(cls_per_expert)]
    return combine_logits_from_selected(
        logits_per_expert, scores, sel_mask, temperature=temperature
    )


def train(experts, heads, loader, optimizer, temperature=0.7):
    """Run one training epoch over `loader`; only `heads` receive gradients.

    Gradients of `heads` are clipped to norm 1.0 before each optimizer step.

    Returns:
        (mean per-sample cross-entropy, accuracy) over all samples in the epoch.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    ce = nn.CrossEntropyLoss()
    running_loss, correct, total = 0.0, 0, 0
    heads.train()

    for images, labels in tqdm(loader):
        mixed_logits = _forward_mixed_logits(experts, heads, images, temperature)
        labels = labels.to(device, non_blocking=True)
        loss = ce(mixed_logits, labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(heads.parameters(), 1.0)
        optimizer.step()

        # CE uses mean reduction, so weight by batch size for an exact epoch mean.
        running_loss += loss.item() * labels.size(0)
        correct += (mixed_logits.argmax(dim=-1) == labels).sum().item()
        total += labels.numel()

    if total == 0:
        raise ValueError("train(): loader yielded no samples")
    return running_loss / total, correct / total


@torch.no_grad()
def eval(experts, heads, loader, temperature=0.7):
    """Evaluate `heads` over `loader` without gradients.

    Shadows the builtin `eval`; the name is kept because the epoch loop calls it.

    Returns:
        (mean per-sample cross-entropy, accuracy) over all samples in `loader`.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    ce = nn.CrossEntropyLoss()
    running_loss, correct, total = 0.0, 0, 0
    heads.eval()

    for images, labels in loader:
        mixed_logits = _forward_mixed_logits(experts, heads, images, temperature)
        labels = labels.to(device, non_blocking=True)
        loss = ce(mixed_logits, labels)

        running_loss += loss.item() * labels.size(0)
        correct += (mixed_logits.argmax(dim=-1) == labels).sum().item()
        total += labels.numel()

    if total == 0:
        raise ValueError("eval(): loader yielded no samples")
    return running_loss / total, correct / total

In [None]:
# NOTE(review): at the observed ~48 s/it x 250 it, one epoch takes ~3.3 h, so 100 epochs
# is roughly two weeks of compute; consider fewer epochs or caching the frozen features.
# NOTE(review): picking the "best" epoch on the test split leaks test information; a
# held-out validation split would give an unbiased final test number.
epochs = 100
best = 0.0
best_state = None  # heads' weights at the best test accuracy so far (None until it beats 0.0)
for ep in range(1, epochs + 1):
    tr_loss, tr_acc = train(ermoe.experts, heads, train_ld, optimizer, temperature)
    te_loss, te_acc = eval(ermoe.experts, heads, test_ld, temperature)
    if te_acc > best:
        best = te_acc
        # state_dict() tensors share storage with the live parameters; clone them so
        # later optimizer steps do not overwrite the snapshot.
        best_state = {k: v.detach().clone() for k, v in heads.state_dict().items()}
    print(f"epoch {ep:02d} | train {tr_loss:.4f}/{tr_acc*100:.2f}% "
          f"| test {te_loss:.4f}/{te_acc*100:.2f}% | best {best*100:.2f}%")

 21%|██        | 53/250 [42:42<2:36:48, 47.76s/it]

In [None]:
# Peek at the router's decision for the first sample of the first test batch.
images, labels = next(iter(test_ld))
scores, sel_mask = ermoe.moE_router_forward(images)
print("scores[0]:", scores[0].tolist())
print("selected experts for sample 0:", sel_mask[0].nonzero(as_tuple=True)[0].tolist())
