# DMMoE CIFAR-10 Benchmark Experiments

**Paper:** "One Model, Many Architectures: A Dynamic Modular Mixture of Heterogeneous Experts"
**Author:** Zaryab Rahman

## Description
This notebook contains the code to reproduce the main benchmarking results on the CIFAR-10 dataset. We compare our proposed **Dynamic Modular Mixture of Heterogeneous Experts (DMMoE)** against two key baselines:

1.  **Homogeneous MoE:** A standard Mixture-of-Experts model where all experts are identical MLPs.
2.  **Dense ViT:** A dense Vision Transformer baseline to measure the performance of a non-sparse model.

The experiments are designed to test the accuracy and efficiency of heterogeneous experts. This notebook also implements an auxiliary **Load Balancing Loss**, which penalizes uneven expert usage so that the router does not collapse onto a few experts.

---

## 1. Setup and Imports
This cell imports necessary libraries and sets up the device (CUDA/CPU).

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List

# --- Device setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


## 2. Core Utilities
This cell defines two essential utility components used across all models:
- **`initialize_weights`:** A helper function to apply Kaiming Normal initialization.
- **`PositionalEncoding`:** A standard sinusoidal positional encoding module for sequence data (i.e., the sequence of image patches).
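As a quick check of the formula, the pure-Python sketch below rebuilds the same table as `PositionalEncoding.__init__` for a small `max_len` and an even `d_model`. The module relies on an even `d_model` as well, since it fills even and odd columns separately. The sketch makes explicit that `exp(i * -ln(10000) / d_model)` is just `10000 ** (-i / d_model)` for the even column index `i`. `sinusoidal_pe` is a hypothetical helper, not part of the notebook.

```python
import math

def sinusoidal_pe(max_len: int, d_model: int) -> list:
    """Hypothetical pure-Python equivalent of the `pe` buffer (without the batch axis)."""
    assert d_model % 2 == 0, "even and odd columns are filled in pairs"
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            # Same expression as div_term: exp(i * -ln(10000) / d_model) == 10000 ** (-i / d_model)
            div_term = math.exp(i * (-math.log(10000.0) / d_model))
            pe[pos][i] = math.sin(pos * div_term)      # even columns
            pe[pos][i + 1] = math.cos(pos * div_term)  # odd columns
    return pe

pe = sinusoidal_pe(max_len=16, d_model=8)
print(pe[0])  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

Position 0 always encodes as alternating sin(0) = 0 and cos(0) = 1; later positions rotate each column pair at a frequency that decreases with the column index.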

In [None]:
# Weight initialization and positional encoding

def initialize_weights(module: nn.Module) -> None:
    """Applies Kaiming Normal initialization to Conv and Linear layers."""
    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
        nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding."""
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, d_model]
        """
        # The `pe` buffer is laid out as [max_len, 1, d_model], so permute to [seq_len, batch_size, d_model] for broadcasting
        x = x.permute(1, 0, 2)
        x = x + self.pe[:x.size(0)]
        return self.dropout(x.permute(1, 0, 2))

## 3. DMMoE Architecture Components
This section defines the building blocks of our proposed model.

### 3.1 - Heterogeneous Expert Modules
We define the four distinct expert architectures: `MLPExpert`, `CNNExpert`, `ViTExpert`, and `RecurrentExpert`. Each is designed to capture different types of patterns in the data.
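`CNNExpert` treats each token embedding as a one-channel signal of length `d_model` and adds the input back. This means the convolutions must preserve that length. The sketch below uses the output-length formula from the `torch.nn.Conv1d` documentation to check the padding rule `padding = (kernel_size - 1) // 2` from `CNNExpert.__init__`. The rule preserves the length for odd kernel sizes (3, 5) but not for even ones, where the residual addition would fail with a shape mismatch. `conv1d_out_len` is a hypothetical helper written for this check.

```python
def conv1d_out_len(length: int, kernel_size: int, padding: int,
                   stride: int = 1, dilation: int = 1) -> int:
    """Output length of a 1D convolution (formula from the torch.nn.Conv1d docs)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

d_model = 128
for kernel_size in (3, 5, 4):
    padding = (kernel_size - 1) // 2  # same rule as CNNExpert.__init__
    out_len = conv1d_out_len(d_model, kernel_size, padding)
    print(kernel_size, padding, out_len, out_len == d_model)
# 3 1 128 True
# 5 2 128 True
# 4 1 127 False
```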

### 3.2 - Sparse Router with Gumbel Noise
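The `SparseRouter` selects the top-k experts for each token. During training, Gumbel noise is added to the router logits to encourage exploration. The top-k selection itself is not differentiable. Instead, the router is trained through the gating weights, which are a softmax over the k selected logits.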
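To make the routing rule concrete, here is a minimal per-token sketch in plain Python (the helper name `gumbel_topk_route` is hypothetical). It follows the same three steps as `SparseRouter.forward`:

1. Optionally perturb the logits with Gumbel(0, 1) noise.
2. Keep the k largest.
3. Take a softmax over only those k values.

The noise is sampled as `-log(E)` with `E ~ Exponential(1)`, which is the distribution produced by the router's `-...exponential_().log()` expression.

```python
import math
import random

def gumbel_topk_route(logits, k, is_training, rng=random):
    """Hypothetical single-token mirror of SparseRouter.forward."""
    if is_training:
        # -log(E) with E ~ Exponential(1) is a standard Gumbel(0, 1) sample
        logits = [z - math.log(rng.expovariate(1.0)) for z in logits]
    # Indices of the k largest (possibly perturbed) logits, largest first
    top = sorted(range(len(logits)), key=lambda j: logits[j], reverse=True)[:k]
    # Softmax over the selected logits only (max subtracted for numerical stability)
    m = max(logits[j] for j in top)
    exps = [math.exp(logits[j] - m) for j in top]
    total = sum(exps)
    return top, [e / total for e in exps]

idx, gates = gumbel_topk_route([2.0, 0.5, 1.0, -1.0], k=2, is_training=False)
print(idx)                           # [0, 2]
print([round(g, 4) for g in gates])  # [0.7311, 0.2689]
```

Because only the k selected logits enter the softmax, the gating weights are differentiable with respect to those logits, while the choice of indices is not. During training, the Gumbel noise perturbs both which experts are selected and their weights. It is switched off at evaluation.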

### 3.3 - Optimized DMMoE Layer
The `DMMoELayer` is the core of the model. It flattens the token sequence, asks the router for each token's expert assignments, and runs each expert once on the batch of tokens routed to it. The sequence-level `ViTExpert` is handled separately. If at least one token is routed to it, it runs once on the full sequence, and each routed token takes its output from that cached result.
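Stripped of tensors, the dispatch in `DMMoELayer.forward` computes `x + sum_j g_j * E_j(x)` over each token's k selected experts. The pure-Python sketch below uses scalar tokens and toy lambda experts, all of them hypothetical. It mirrors the per-expert loop: gather the tokens routed to an expert, call the expert once on that batch, weight each output by its gate, and scatter-add it back. The sketch covers token-wise experts only. In the notebook, the `ViTExpert` instead runs once on the whole sequence, so its output for a token also depends on the other tokens.

```python
def moe_combine(tokens, experts, expert_indices, gating_weights):
    """Hypothetical scalar-token mirror of the dispatch loop in DMMoELayer.forward."""
    output = [0.0] * len(tokens)
    for e, expert in enumerate(experts):
        # (token, slot) pairs routed to expert e, like torch.where(expert_indices == e)
        routed = [(t, s) for t, idx in enumerate(expert_indices)
                  for s, j in enumerate(idx) if j == e]
        if not routed:
            continue  # unused experts are skipped, like the `mask.any()` check
        # One batched call per expert on the gathered tokens
        outs = expert([tokens[t] for t, _ in routed])
        for (t, s), y in zip(routed, outs):
            output[t] += gating_weights[t][s] * y  # weighted scatter-add
    # Residual connection: x + sum over selected experts of g * E(x)
    return [x + o for x, o in zip(tokens, output)]

toy_experts = [
    lambda xs: [2 * x for x in xs],  # toy expert 0
    lambda xs: [x + 1 for x in xs],  # toy expert 1
    lambda xs: [-x for x in xs],     # toy expert 2 (not selected below, so never called)
]
tokens = [1.0, 3.0]
expert_indices = [[0, 1], [1, 0]]  # top-2 experts per token
gating_weights = [[0.75, 0.25], [0.5, 0.5]]
result = moe_combine(tokens, toy_experts, expert_indices, gating_weights)
print(result)  # [3.0, 8.0]
```

For the first token, for example, the output is 1.0 + 0.75 * 2.0 + 0.25 * 2.0 = 3.0.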

In [None]:
# MLPExpert
class MLPExpert(nn.Module):
    """A standard MLP expert with proper initialization."""
    def __init__(self, d_model: int, d_ff: int = 256):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation(self.fc1(x)))

# CNNExpert
class CNNExpert(nn.Module):
    """A 1D-CNN expert with proper initialization and residual connection."""
    def __init__(self, d_model: int, n_filters: int = 32, kernel_size: int = 5):
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.conv1 = nn.Conv1d(1, n_filters, kernel_size=kernel_size, padding=padding)
        self.conv2 = nn.Conv1d(n_filters, 1, kernel_size=kernel_size, padding=padding)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = x.unsqueeze(1) # (B, D) -> (B, 1, D)
        x = self.activation(self.conv1(x))
        x = self.conv2(x)
        x = x.squeeze(1) # (B, 1, D) -> (B, D)
        return x + residual

# ViTExpert
class ViTExpert(nn.Module):
    """A single Transformer block expert. Expects full sequence for context."""
    def __init__(self, d_model: int, n_heads: int = 4, d_ff: int = 256, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_input = self.norm1(x)
        attn_output, _ = self.attn(attn_input, attn_input, attn_input)
        x = x + attn_output
        ffn_output = self.ffn(self.norm2(x))
        return x + ffn_output

# RecurrentExpert
class RecurrentExpert(nn.Module):
    """A stateless recurrent expert using LSTMCell. It is now device-aware."""
    def __init__(self, d_model: int):
        super().__init__()
        self.lstm_cell = nn.LSTMCell(input_size=d_model, hidden_size=d_model)
        self.d_model = d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stateless: the hidden and cell states are zero-initialized on every call,
        # on the same device as the input (so the module works after `.to(device)`).
        h_0 = torch.zeros(x.size(0), self.d_model, device=x.device)
        c_0 = torch.zeros(x.size(0), self.d_model, device=x.device)
        h_1, _ = self.lstm_cell(x, (h_0, c_0))
        return h_1




# SparseRouter with Gumbel Noise
class SparseRouter(nn.Module):
    """
    Selects the top-k experts per token, adding Gumbel noise to the logits during training for exploration.

    Args:
        d_model (int): The dimension of the input token embeddings.
        num_experts (int): The total number of experts in the pool.
        k (int): The number of experts to select for each token.
    """
    def __init__(self, d_model: int, num_experts: int, k: int = 2):
        super().__init__()
        self.k = k
        self.router_weights = nn.Parameter(torch.empty(d_model, num_experts))
        nn.init.kaiming_normal_(self.router_weights, a=math.sqrt(5))

    def forward(self, x: torch.Tensor, is_training: bool) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x (torch.Tensor): Input tokens, shape [batch_size * seq_len, d_model].
            is_training (bool): Flag to enable/disable noise.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - top_k_indices (torch.Tensor): Indices of selected experts, shape [batch_size * seq_len, k].
                - gating_weights (torch.Tensor): Softmax weights for selected experts, shape [batch_size * seq_len, k].
        """
        logits = x @ self.router_weights

        if is_training:
            # Add Gumbel noise for exploration.
            gumbel_noise = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
            logits += gumbel_noise

        top_k_logits, top_k_indices = torch.topk(logits, self.k, dim=-1)
        gating_weights = F.softmax(top_k_logits, dim=-1)

        return top_k_indices, gating_weights




#  The Optimized DMMoELayer with Vectorized Dispatch
class DMMoELayer(nn.Module):
    """
    An optimized DMMoE layer with vectorized expert dispatch and efficient
    handling of sequence-level experts like Vision Transformers.
    """
    def __init__(self, d_model: int, experts: nn.ModuleList, k: int = 2, vit_expert_idx: int = -1):
        super().__init__()
        self.d_model = d_model
        self.experts = experts
        self.num_experts = len(experts)
        self.router = SparseRouter(d_model, self.num_experts, k)
        self.vit_expert_idx = vit_expert_idx # Track the ViT expert if it exists

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor, shape [batch_size, seq_len, d_model].

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        batch_size, seq_len, d_model = x.shape
        x_flat = x.reshape(-1, d_model) # Shape: [B*S, D]; reshape, since x may be non-contiguous (e.g. after PositionalEncoding's permute)

        # --- Routing ---
        expert_indices, gating_weights = self.router(x_flat, self.training)

        # --- Efficient ViT Expert Handling ---
        vit_output_cache = None
        if self.vit_expert_idx != -1 and (expert_indices == self.vit_expert_idx).any():
            # If any token is routed to the ViT expert, run it ONCE on the full sequence.
            vit_expert = self.experts[self.vit_expert_idx]
            vit_output_cache = vit_expert(x).reshape(-1, d_model) # Shape: [B*S, D]

        # --- Vectorized Expert Dispatch ---
        final_output = torch.zeros_like(x_flat)
        flat_expert_indices = expert_indices.flatten()

        for i, expert in enumerate(self.experts):
            # Create a mask for tokens routed to the current expert
            mask = (flat_expert_indices == i)
            if not mask.any():
                continue

            # Get the indices of the tokens and which gate (0 or 1) they correspond to
            token_indices, gate_pos = torch.where(expert_indices == i)

            if i == self.vit_expert_idx and vit_output_cache is not None:
                # Use the pre-computed output for the ViT expert
                expert_output = vit_output_cache[token_indices]
            else:
                # For all other experts, run a single batch forward pass
                selected_tokens = x_flat[token_indices]
                expert_output = expert(selected_tokens)

            # Get the corresponding gating weights and apply them
            selected_gates = gating_weights[token_indices, gate_pos].unsqueeze(1)
            weighted_output = expert_output * selected_gates

            # Scatter the weighted outputs back to their original positions
            final_output.scatter_add_(0, token_indices.unsqueeze(1).expand_as(weighted_output), weighted_output)

        return x + final_output.view(batch_size, seq_len, d_model)

## 4. Architecture Unit Test
This cell performs a quick sanity check on the DMMoE layer to ensure the input and output shapes are correct and that the router's gating weights sum to 1 as expected.

In [None]:

# 1. Configuration
D_MODEL = 128
NUM_EXPERTS = 4
K = 2
BATCH_SIZE = 8
SEQ_LEN = 16
VIT_EXPERT_INDEX = 2 # ViTExpert is the 3rd expert in our list

# 2. Instantiate experts and apply initialization
experts = nn.ModuleList([
    MLPExpert(D_MODEL),
    CNNExpert(D_MODEL),
    ViTExpert(D_MODEL),
    RecurrentExpert(D_MODEL)
])
experts.apply(initialize_weights)

# 3. Instantiate the complete DMMoE Layer
dmmoe_layer = DMMoELayer(
    d_model=D_MODEL,
    experts=experts,
    k=K,
    vit_expert_idx=VIT_EXPERT_INDEX
).to(device)

# 4. Instantiate Positional Encoding
pos_encoder = PositionalEncoding(D_MODEL).to(device)

# 5. Create dummy input data and move to device
input_tensor = torch.randn(BATCH_SIZE, SEQ_LEN, D_MODEL).to(device)

# 6. Run a full forward pass (with positional encoding)
print("--- Running Full Forward Pass and Tests ---")
dmmoe_layer.train() # Set to training mode
input_with_pos = pos_encoder(input_tensor)
output_tensor = dmmoe_layer(input_with_pos)

# 7. Verify shape and device
print(f"Input shape:  {input_tensor.shape}")
print(f"Output shape: {output_tensor.shape}")
print(f"Output device: {output_tensor.device}")
assert input_tensor.shape == output_tensor.shape, "Output shape must match input!"
assert output_tensor.device == device, "Output must be on the correct device!"

# 8. --- Mini Unit Tests for the Router ---
dmmoe_layer.eval() # Test in eval mode for determinism
with torch.no_grad():
    indices, gates = dmmoe_layer.router(input_tensor.view(-1, D_MODEL), is_training=False)

    # Test 1: Router returns exactly k indices per token
    assert indices.shape[1] == K, f"Router should return k={K} indices, but got {indices.shape[1]}"

    # Test 2: Gating weights for each token sum to 1
    gate_sums = gates.sum(dim=-1)
    assert torch.allclose(gate_sums, torch.ones_like(gate_sums)), "Gating weights must sum to 1"

print("\nAll tests passed! The refined implementation is ready for Week 2.")

--- Running Full Forward Pass and Tests ---
Input shape:  torch.Size([8, 16, 128])
Output shape: torch.Size([8, 16, 128])
Output device: cpu

All tests passed! The refined implementation is ready for Week 2.


In [None]:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm # For progress bars
import torch.optim as optim

## 5. Data Pipeline for CIFAR-10
This cell defines the function `get_cifar10_loaders` which handles downloading, transforming (with normalization), and creating `DataLoader` objects for the CIFAR-10 training and test sets.
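For reference, `ToTensor()` maps each 8-bit channel value to `[0, 1]` by dividing by 255. `Normalize(mean, std)` then computes `(value - mean) / std` per channel. The sketch below applies the same constants as `get_cifar10_loaders` to a black pixel and a white pixel, which gives the extreme values a normalized input can take. The helper `normalize_pixel` is hypothetical.

```python
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

def normalize_pixel(rgb_uint8):
    """ToTensor() scaling followed by Normalize(mean, std), for one RGB pixel."""
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb_uint8, CIFAR10_MEAN, CIFAR10_STD)]

print([round(c, 3) for c in normalize_pixel((0, 0, 0))])        # [-2.429, -2.418, -2.221]
print([round(c, 3) for c in normalize_pixel((255, 255, 255))])  # [2.514, 2.597, 2.754]
```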

In [None]:

def get_cifar10_loaders(batch_size: int = 128):
    """Creates training and validation data loaders for CIFAR-10."""

    # transformations for the images
    # 1. Resize to 32x32 (a no-op for CIFAR-10, whose images are already 32x32)
    # 2. Convert to a tensor with values scaled to [0, 1]
    # 3. Normalize each channel with commonly used CIFAR-10 mean and std constants
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    # download and create datasets
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    # create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    return train_loader, test_loader

# --- Test the loaders ---
train_loader, test_loader = get_cifar10_loaders(batch_size=4)
images, labels = next(iter(train_loader))
print("CIFAR-10 Batch Shape:", images.shape) # (B, Channels, Height, Width)
print("CIFAR-10 Labels Shape:", labels.shape)

100%|██████████| 170M/170M [00:02<00:00, 77.8MB/s]


CIFAR-10 Batch Shape: torch.Size([4, 3, 32, 32])
CIFAR-10 Labels Shape: torch.Size([4])


## 6. Full Classifier Model Definitions
Here we define the three complete models for our benchmark comparison.

- **`DMMoE_Classifier`:** Our proposed model.
- **`HomogeneousMoE_Classifier`:** The baseline using only MLP experts.
- **`DenseViT_Classifier`:** The dense (non-MoE) baseline.
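All three classifiers share the same `PatchEncoder`, so they see the same token sequence. The defaults are `img_size=32`, `patch_size=8`, `in_channels=3`, and `d_model=128`. With these, each image becomes (32 / 8)^2 = 16 tokens, each built from 3 × 8 × 8 = 192 pixel values. The sketch below does the arithmetic; the helper `patch_encoder_stats` is hypothetical.

```python
def patch_encoder_stats(img_size=32, patch_size=8, in_channels=3, d_model=128):
    """Token count, values per patch, and parameter count of PatchEncoder's Conv2d."""
    num_patches = (img_size // patch_size) ** 2
    values_per_patch = in_channels * patch_size * patch_size
    # Conv2d(in_channels, d_model, kernel_size=patch_size): one weight per input value
    # per output channel, plus one bias per output channel
    num_params = values_per_patch * d_model + d_model
    return num_patches, values_per_patch, num_params

print(patch_encoder_stats())  # (16, 192, 24704)
```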

In [None]:

class PatchEncoder(nn.Module):
    """Converts an image into a sequence of patch embeddings."""
    def __init__(self, img_size=32, patch_size=8, in_channels=3, d_model=128):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # (B, D, H/P, W/P)
        x = x.flatten(2)   # (B, D, N_patches)
        x = x.transpose(1, 2) # (B, N_patches, D)
        return x

class DMMoE_Classifier(nn.Module):
    """The full DMMoE model for image classification."""
    def __init__(self, d_model=128, num_classes=10):
        super().__init__()
        self.patch_encoder = PatchEncoder(d_model=d_model)
        self.pos_encoder = PositionalEncoding(d_model)

        # Instantiate the experts
        experts = nn.ModuleList([
            MLPExpert(d_model), CNNExpert(d_model),
            ViTExpert(d_model), RecurrentExpert(d_model)
        ])
        experts.apply(initialize_weights) # Initialize weights

        # The core DMMoE layer
        self.dmmoe_layer = DMMoELayer(d_model, experts, k=2, vit_expert_idx=2)

        # Classification head
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        x = self.patch_encoder(x)
        x = self.pos_encoder(x)
        x = self.dmmoe_layer(x)
        # Average pooling over the patch dimension
        x = x.mean(dim=1)
        return self.head(self.norm(x))

class HomogeneousMoE_Classifier(nn.Module):
    """Baseline: An MoE model with 4 identical MLP experts."""
    def __init__(self, d_model=128, num_classes=10):
        super().__init__()
        self.patch_encoder = PatchEncoder(d_model=d_model)
        self.pos_encoder = PositionalEncoding(d_model)

        # All experts are the same type
        experts = nn.ModuleList([MLPExpert(d_model) for _ in range(4)])
        experts.apply(initialize_weights)

        # Same attribute name as in DMMoE_Classifier, so DMMoE_Model_With_Loss can wrap this model
        self.dmmoe_layer = DMMoELayer(d_model, experts, k=2) # Using the same layer logic

        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        x = self.patch_encoder(x)
        x = self.pos_encoder(x)
        x = self.dmmoe_layer(x)
        x = x.mean(dim=1)
        return self.head(self.norm(x))

class DenseViT_Classifier(nn.Module):
    """Baseline: A standard (dense) Vision Transformer."""
    def __init__(self, d_model=128, num_classes=10, num_layers=2):
        super().__init__()
        self.patch_encoder = PatchEncoder(d_model=d_model)
        self.pos_encoder = PositionalEncoding(d_model)

        self.transformer_layers = nn.ModuleList(
            [ViTExpert(d_model) for _ in range(num_layers)]
        )
        self.transformer_layers.apply(initialize_weights)

        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        x = self.patch_encoder(x)
        x = self.pos_encoder(x)
        for layer in self.transformer_layers:
            x = layer(x)
        x = x.mean(dim=1)
        return self.head(self.norm(x))

In [None]:

# --- Step 1: Modify DMMoELayer to return router logits ---
# We'll create a new class that inherits from our previous DMMoELayer
# to keep the original logic clean.

class DMMoELayerWithLoadBalancing(DMMoELayer):
    def forward(self, x: torch.Tensor):
        batch_size, seq_len, d_model = x.shape


        # Use .reshape() instead of .view() to handle non-contiguous tensors
        x_flat = x.reshape(-1, d_model)

        # Get routing decisions AND the raw logits from the router
        expert_indices, gating_weights, full_logits = self.router(x_flat, self.training)

        # --- The rest of the forward pass is the same as the parent class ---
        vit_output_cache = None
        if self.vit_expert_idx != -1 and (expert_indices == self.vit_expert_idx).any():
            vit_expert = self.experts[self.vit_expert_idx]
            vit_output_cache = vit_expert(x).reshape(-1, d_model)

        final_output = torch.zeros_like(x_flat)
        flat_expert_indices = expert_indices.flatten()

        for i, expert in enumerate(self.experts):
            mask = (flat_expert_indices == i)
            if not mask.any(): continue
            token_indices, gate_pos = torch.where(expert_indices == i)
            if i == self.vit_expert_idx and vit_output_cache is not None:
                expert_output = vit_output_cache[token_indices]
            else:
                expert_output = expert(x_flat[token_indices])
            selected_gates = gating_weights[token_indices, gate_pos].unsqueeze(1)
            weighted_output = expert_output * selected_gates
            final_output.scatter_add_(0, token_indices.unsqueeze(1).expand_as(weighted_output), weighted_output)


        # Also use .reshape() for the final output for safety
        output = x + final_output.reshape(batch_size, seq_len, d_model)
        return output, full_logits, expert_indices

# --- Step 2: Update the Router to also return full logits ---
class SparseRouterWithLoadBalancing(SparseRouter):
    def forward(self, x: torch.Tensor, is_training: bool):
        logits = x @ self.router_weights
        if is_training:
            gumbel_noise = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
            logits += gumbel_noise
        top_k_logits, top_k_indices = torch.topk(logits, self.k, dim=-1)
        gating_weights = F.softmax(top_k_logits, dim=-1)
        return top_k_indices, gating_weights, logits # Return the full logits

# --- Step 3: Implement the loss calculation and a model wrapper ---
class DMMoE_Model_With_Loss(nn.Module):
    """A wrapper for any MoE model that computes the total loss."""
    def __init__(self, model, lb_weight=0.02):
        super().__init__()
        self.model = model
        self.lb_weight = lb_weight

        # Replace the MoE layer and its router with the variants that also return
        # router logits. The new router is freshly initialized.
        self.model.dmmoe_layer = DMMoELayerWithLoadBalancing(
            self.model.dmmoe_layer.d_model,
            self.model.dmmoe_layer.experts,
            self.model.dmmoe_layer.router.k,
            self.model.dmmoe_layer.vit_expert_idx
        )
        self.model.dmmoe_layer.router = SparseRouterWithLoadBalancing(
            self.model.dmmoe_layer.d_model,
            self.model.dmmoe_layer.num_experts,
            self.model.dmmoe_layer.router.k
        )

    def compute_load_balancing_loss(self, router_logits, expert_indices):
        num_experts = router_logits.size(-1)

        # P_j: Fraction of tokens routed to expert j (each token counts once per selected expert, so P sums to k)
        expert_counts = torch.zeros(num_experts, device=router_logits.device)
        expert_counts.index_add_(0, expert_indices.flatten(), torch.ones_like(expert_indices.flatten(), dtype=torch.float))
        P = expert_counts / len(expert_indices)

        # f_j: Average router probability for expert j
        f = F.softmax(router_logits, dim=-1).mean(dim=0)

        # Loss: alpha * E * dot_product(P, f)
        # Using a fixed alpha=1 for simplicity; the lb_weight is the main knob
        loss = num_experts * torch.sum(P * f)
        return loss

    def forward(self, images, labels=None):
        # Re-implements the classifier's forward pass so that the router logits and expert indices are available for the auxiliary loss
        x = self.model.patch_encoder(images)
        x = self.model.pos_encoder(x)
        x, router_logits, expert_indices = self.model.dmmoe_layer(x)
        x = x.mean(dim=1)
        output = self.model.head(self.model.norm(x))

        if labels is not None:
            task_loss = F.cross_entropy(output, labels)
            lb_loss = self.compute_load_balancing_loss(router_logits, expert_indices)
            total_loss = task_loss + self.lb_weight * lb_loss
            return output, total_loss

        return output, None
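To see what the auxiliary term rewards, the pure-Python sketch below mirrors `compute_load_balancing_loss` on toy inputs. The function name `load_balancing_loss` and the inputs are hypothetical. The sketch keeps two properties of the notebook's formulation:

- `P` divides the assignment counts by the number of tokens, so it sums to k rather than 1, and perfectly uniform routing gives a loss of exactly k.
- During training, the router logits passed in already include the Gumbel noise, since `SparseRouterWithLoadBalancing` adds it in place before returning them.

```python
import math

def load_balancing_loss(router_logits, expert_indices, num_experts):
    """Hypothetical pure-Python mirror of compute_load_balancing_loss."""
    n_tokens = len(expert_indices)
    # P_j: top-k assignments to expert j divided by the token count (sums to k)
    counts = [0] * num_experts
    for idx in expert_indices:
        for j in idx:
            counts[j] += 1
    P = [c / n_tokens for c in counts]
    # f_j: softmax probability of expert j, averaged over tokens
    f = [0.0] * num_experts
    for logits in router_logits:
        m = max(logits)
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        for j in range(num_experts):
            f[j] += exps[j] / total / n_tokens
    # num_experts * dot(P, f), as in the notebook (alpha fixed to 1)
    return num_experts * sum(p * q for p, q in zip(P, f))

balanced = load_balancing_loss(
    router_logits=[[0.0, 0.0, 0.0, 0.0]] * 4,
    expert_indices=[[0, 1], [2, 3], [0, 1], [2, 3]],
    num_experts=4,
)
collapsed = load_balancing_loss(
    router_logits=[[4.0, 4.0, 0.0, 0.0]] * 4,
    expert_indices=[[0, 1]] * 4,
    num_experts=4,
)
print(balanced, round(collapsed, 3))  # 2.0 3.928
```

Routing every token to the same two experts with confident logits nearly doubles the loss relative to uniform routing.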

In [None]:

def train_one_epoch(model, loader, optimizer, scheduler, device):
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(loader, desc="Training", leave=False)

    is_moe = isinstance(model, DMMoE_Model_With_Loss)

    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        if is_moe:
            outputs, loss = model(images, labels)
        else:
            outputs = model(images)
            loss = F.cross_entropy(outputs, labels)

        loss.backward()
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    return total_loss / len(loader)

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0

    is_moe = isinstance(model, DMMoE_Model_With_Loss)

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            if is_moe:
                outputs, _ = model(images, labels)
            else:
                outputs = model(images)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    return accuracy

In [None]:

\
from torch.utils.tensorboard import SummaryWriter
import os
import shutil

# --- Configuration ---
NUM_EPOCHS = 50 # full training run
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
LB_WEIGHT = 0.02
D_MODEL = 128
NUM_CLASSES = 10
MODEL_SAVE_PATH = "./models"

# --- clean up previous runs for a fresh start ---
if os.path.exists('runs'):
    print("Deleting old 'runs' directory.")
    shutil.rmtree('runs')
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

print("\nSetup complete. Ready for full experiments.")


Setup complete. Ready for full experiments.


In [None]:

def evaluate(model, loader, device):
    """
    Evaluates the model and returns accuracy and expert utilization statistics.
    """
    model.eval()
    correct = 0
    total = 0

    is_moe = isinstance(model, DMMoE_Model_With_Loss)
    all_expert_indices = []

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            if is_moe:
                # Call the MoE layer directly so the expert indices are also returned
                outputs, _, expert_indices = model.model.dmmoe_layer(model.model.pos_encoder(model.model.patch_encoder(images)))
                x = outputs.mean(dim=1)
                outputs = model.model.head(model.model.norm(x))
                all_expert_indices.append(expert_indices.cpu())
            else:
                outputs = model(images)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total

    # --- Expert Utilization Calculation ---
    expert_utilization = None
    if is_moe and len(all_expert_indices) > 0:
        # Concatenate indices from all batches
        all_expert_indices = torch.cat(all_expert_indices, dim=0)
        num_experts = model.model.dmmoe_layer.num_experts

        # Count occurrences of each expert index
        counts = torch.bincount(all_expert_indices.flatten(), minlength=num_experts)
        expert_utilization = {f"expert_{i}": counts[i].item() for i in range(num_experts)}

    return accuracy, expert_utilization
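The counts that `evaluate` reports are easier to compare as fractions. Every test image contributes 16 patch tokens, each with k = 2 expert slots, so the counts for one epoch should sum to 10,000 × 16 × 2 = 320,000. The sketch below checks this and converts two epochs of counts copied from the DMMoE training log at the end of this notebook. The helper `utilization_fractions` is hypothetical. Uniform routing would give 0.25 per expert.

```python
def utilization_fractions(counts, num_images=10_000, tokens_per_image=16, k=2):
    """Convert per-expert top-k counts into fractions of all routing slots."""
    total = sum(counts.values())
    assert total == num_images * tokens_per_image * k, "every token fills exactly k slots"
    return {name: round(c / total, 2) for name, c in counts.items()}

# Validation counts for epochs 1 and 12, copied from the DMMoE_CIFAR10 log
epoch_1 = {'expert_0': 66747, 'expert_1': 153824, 'expert_2': 19589, 'expert_3': 79840}
epoch_12 = {'expert_0': 100576, 'expert_1': 61663, 'expert_2': 98525, 'expert_3': 59236}

print(utilization_fractions(epoch_1))   # {'expert_0': 0.21, 'expert_1': 0.48, 'expert_2': 0.06, 'expert_3': 0.25}
print(utilization_fractions(epoch_12))  # {'expert_0': 0.31, 'expert_1': 0.19, 'expert_2': 0.31, 'expert_3': 0.19}
```

In this run, expert 1 (`CNNExpert`) falls from about 0.48 to 0.19 of the routing slots between epochs 1 and 12. Over the same epochs, expert 2 (`ViTExpert`) rises from about 0.06 to 0.31.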

In [None]:

def run_experiment(model, model_name, train_loader, test_loader, device, num_epochs, lr, lb_weight=0.0):
    """
    Runs a full training and evaluation experiment for a given model,
    with logging to TensorBoard.
    """
    print(f"--- Starting Experiment: {model_name} ---")

    # Setup TensorBoard writer
    writer = SummaryWriter(f'runs/{model_name}')

    # Wrap MoE models to handle loss calculation
    is_moe = "DMMoE" in model_name or "Homogeneous" in model_name
    if is_moe:
        model_with_loss = DMMoE_Model_With_Loss(model, lb_weight=lb_weight).to(device)
    else:
        model_with_loss = model.to(device)

    optimizer = optim.AdamW(model_with_loss.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr,
                                              steps_per_epoch=len(train_loader), epochs=num_epochs)

    best_accuracy = 0.0

    for epoch in range(num_epochs):
        # Train for one epoch (the scheduler is stepped once per batch)
        train_loss = train_one_epoch(model_with_loss, train_loader, optimizer, scheduler, device)

        # Validation accuracy, plus expert-utilization counts for MoE models
        val_accuracy, expert_utilization = evaluate(model_with_loss, test_loader, device)

        # --- Logging to TensorBoard ---
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Accuracy/validation', val_accuracy, epoch)

        if expert_utilization:
            # This logs expert counts as a multi-line chart
            writer.add_scalars('Expert_Utilization/validation', expert_utilization, epoch)

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Accuracy: {val_accuracy:.2f}%")
        if expert_utilization:
            print(f"  Expert Utilization: {expert_utilization}")

        # --- Save the best model ---
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            save_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}_best.pth")
            torch.save(model.state_dict(), save_path)
            print(f"  New best model saved to {save_path} with accuracy: {best_accuracy:.2f}%")

    writer.close()
    print(f"--- Experiment Finished: {model_name} | Best Accuracy: {best_accuracy:.2f}% ---")
    return best_accuracy
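`run_experiment` steps `OneCycleLR` once per batch, so the schedule length is `steps_per_epoch * epochs`. CIFAR-10 has 50,000 training images, and `BATCH_SIZE = 128`. With the `DataLoader` default `drop_last=False`, the last partial batch is kept, giving 391 steps per epoch and 19,550 steps for 50 epochs. The sketch below does this arithmetic. It also computes the starting and final learning rates implied by `max_lr = 1e-3`, assuming PyTorch's defaults `div_factor=25` and `final_div_factor=1e4`, which the notebook does not override.

```python
import math

train_images, batch_size, num_epochs, max_lr = 50_000, 128, 50, 1e-3

steps_per_epoch = math.ceil(train_images / batch_size)  # drop_last=False keeps the partial batch
total_steps = steps_per_epoch * num_epochs

# Assumed OneCycleLR defaults: div_factor=25.0, final_div_factor=1e4, pct_start=0.3
initial_lr = max_lr / 25.0
min_lr = initial_lr / 1e4
warmup_steps = 0.3 * total_steps  # approximate length of the increasing phase

print(steps_per_epoch, total_steps)  # 391 19550
print(initial_lr, min_lr, warmup_steps)
```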

In [None]:

# --- Get Data Loaders ---
train_loader, test_loader = get_cifar10_loaders(batch_size=BATCH_SIZE)

# --- Experiment 1: DMMoE Classifier ---
dmm_classifier = DMMoE_Classifier(d_model=D_MODEL, num_classes=NUM_CLASSES)
run_experiment(
    model=dmm_classifier,
    model_name="DMMoE_CIFAR10",
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    num_epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    lb_weight=LB_WEIGHT
)

# Experiments 2 and 3 below are disabled by the surrounding triple-quoted string.
# To run them, remove the quotes; Experiment 1 can be commented out meanwhile.
'''# --- Experiment 2: Homogeneous MoE Baseline ---
homo_moe_classifier = HomogeneousMoE_Classifier(d_model=D_MODEL, num_classes=NUM_CLASSES)
run_experiment(
    model=homo_moe_classifier,
    model_name="HomogeneousMoE_CIFAR10",
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    num_epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    lb_weight=LB_WEIGHT
)

# --- Experiment 3: Dense ViT Baseline ---
# Note: You can comment out the other experiments while running this one.
dense_vit_classifier = DenseViT_Classifier(d_model=D_MODEL, num_classes=NUM_CLASSES, num_layers=2)
run_experiment(
    model=dense_vit_classifier,
    model_name="DenseViT_CIFAR10",
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    num_epochs=NUM_EPOCHS,
    lr=LEARNING_RATE
)'''

--- Starting Experiment: DMMoE_CIFAR10 ---


Training:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 1/50 | Train Loss: 2.1838 | Val Accuracy: 28.06%
  Expert Utilization: {'expert_0': 66747, 'expert_1': 153824, 'expert_2': 19589, 'expert_3': 79840}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 28.06%


Training:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 2/50 | Train Loss: 1.9491 | Val Accuracy: 36.02%
  Expert Utilization: {'expert_0': 69768, 'expert_1': 141176, 'expert_2': 43453, 'expert_3': 65603}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 36.02%


Training:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 3/50 | Train Loss: 1.7693 | Val Accuracy: 40.39%
  Expert Utilization: {'expert_0': 70602, 'expert_1': 125152, 'expert_2': 60370, 'expert_3': 63876}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 40.39%


Training:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 4/50 | Train Loss: 1.6612 | Val Accuracy: 44.05%
  Expert Utilization: {'expert_0': 75587, 'expert_1': 106948, 'expert_2': 72595, 'expert_3': 64870}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 44.05%


Training:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 5/50 | Train Loss: 1.5926 | Val Accuracy: 46.22%
  Expert Utilization: {'expert_0': 81845, 'expert_1': 92221, 'expert_2': 82069, 'expert_3': 63865}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 46.22%


Epoch 6/50 | Train Loss: 1.5348 | Val Accuracy: 47.52%
  Expert Utilization: {'expert_0': 88076, 'expert_1': 80356, 'expert_2': 83478, 'expert_3': 68090}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 47.52%


Epoch 7/50 | Train Loss: 1.4905 | Val Accuracy: 48.38%
  Expert Utilization: {'expert_0': 90003, 'expert_1': 77789, 'expert_2': 90657, 'expert_3': 61551}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 48.38%


Epoch 8/50 | Train Loss: 1.4504 | Val Accuracy: 50.33%
  Expert Utilization: {'expert_0': 92222, 'expert_1': 74331, 'expert_2': 90555, 'expert_3': 62892}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 50.33%


Epoch 9/50 | Train Loss: 1.4197 | Val Accuracy: 50.42%
  Expert Utilization: {'expert_0': 95960, 'expert_1': 69772, 'expert_2': 92048, 'expert_3': 62220}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 50.42%


Epoch 10/50 | Train Loss: 1.3907 | Val Accuracy: 51.19%
  Expert Utilization: {'expert_0': 94989, 'expert_1': 70145, 'expert_2': 93050, 'expert_3': 61816}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 51.19%


Epoch 11/50 | Train Loss: 1.3635 | Val Accuracy: 50.59%
  Expert Utilization: {'expert_0': 100038, 'expert_1': 67560, 'expert_2': 95947, 'expert_3': 56455}


Epoch 12/50 | Train Loss: 1.3446 | Val Accuracy: 51.92%
  Expert Utilization: {'expert_0': 100576, 'expert_1': 61663, 'expert_2': 98525, 'expert_3': 59236}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 51.92%


Epoch 13/50 | Train Loss: 1.3170 | Val Accuracy: 53.23%
  Expert Utilization: {'expert_0': 102140, 'expert_1': 64597, 'expert_2': 93643, 'expert_3': 59620}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 53.23%


Epoch 14/50 | Train Loss: 1.2934 | Val Accuracy: 52.89%
  Expert Utilization: {'expert_0': 99774, 'expert_1': 59117, 'expert_2': 93011, 'expert_3': 68098}


Epoch 15/50 | Train Loss: 1.2689 | Val Accuracy: 53.96%
  Expert Utilization: {'expert_0': 95913, 'expert_1': 60401, 'expert_2': 95349, 'expert_3': 68337}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 53.96%


Epoch 16/50 | Train Loss: 1.2431 | Val Accuracy: 54.91%
  Expert Utilization: {'expert_0': 103404, 'expert_1': 59594, 'expert_2': 97503, 'expert_3': 59499}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 54.91%


Epoch 17/50 | Train Loss: 1.2219 | Val Accuracy: 54.47%
  Expert Utilization: {'expert_0': 101411, 'expert_1': 61705, 'expert_2': 95850, 'expert_3': 61034}


Epoch 18/50 | Train Loss: 1.1974 | Val Accuracy: 54.73%
  Expert Utilization: {'expert_0': 104634, 'expert_1': 61448, 'expert_2': 96447, 'expert_3': 57471}


Epoch 19/50 | Train Loss: 1.1816 | Val Accuracy: 56.03%
  Expert Utilization: {'expert_0': 107018, 'expert_1': 59095, 'expert_2': 98295, 'expert_3': 55592}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 56.03%


Epoch 20/50 | Train Loss: 1.1559 | Val Accuracy: 55.58%
  Expert Utilization: {'expert_0': 106330, 'expert_1': 57003, 'expert_2': 95958, 'expert_3': 60709}


Epoch 21/50 | Train Loss: 1.1413 | Val Accuracy: 56.47%
  Expert Utilization: {'expert_0': 105662, 'expert_1': 58364, 'expert_2': 97965, 'expert_3': 58009}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 56.47%


Epoch 22/50 | Train Loss: 1.1195 | Val Accuracy: 56.44%
  Expert Utilization: {'expert_0': 102974, 'expert_1': 56328, 'expert_2': 98411, 'expert_3': 62287}


Epoch 23/50 | Train Loss: 1.0959 | Val Accuracy: 56.15%
  Expert Utilization: {'expert_0': 102127, 'expert_1': 55400, 'expert_2': 101682, 'expert_3': 60791}


Epoch 24/50 | Train Loss: 1.0809 | Val Accuracy: 56.51%
  Expert Utilization: {'expert_0': 104133, 'expert_1': 56808, 'expert_2': 97641, 'expert_3': 61418}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 56.51%


Epoch 25/50 | Train Loss: 1.0607 | Val Accuracy: 57.55%
  Expert Utilization: {'expert_0': 104369, 'expert_1': 55581, 'expert_2': 95381, 'expert_3': 64669}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 57.55%


Epoch 26/50 | Train Loss: 1.0447 | Val Accuracy: 57.14%
  Expert Utilization: {'expert_0': 104923, 'expert_1': 58566, 'expert_2': 98144, 'expert_3': 58367}


Epoch 27/50 | Train Loss: 1.0199 | Val Accuracy: 57.03%
  Expert Utilization: {'expert_0': 103884, 'expert_1': 54535, 'expert_2': 98409, 'expert_3': 63172}


Epoch 28/50 | Train Loss: 1.0071 | Val Accuracy: 58.12%
  Expert Utilization: {'expert_0': 105543, 'expert_1': 55792, 'expert_2': 98997, 'expert_3': 59668}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 58.12%


Epoch 29/50 | Train Loss: 0.9877 | Val Accuracy: 58.38%
  Expert Utilization: {'expert_0': 106753, 'expert_1': 53353, 'expert_2': 103299, 'expert_3': 56595}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 58.38%


Epoch 30/50 | Train Loss: 0.9667 | Val Accuracy: 57.26%
  Expert Utilization: {'expert_0': 104216, 'expert_1': 53616, 'expert_2': 100469, 'expert_3': 61699}


Epoch 31/50 | Train Loss: 0.9509 | Val Accuracy: 57.83%
  Expert Utilization: {'expert_0': 103799, 'expert_1': 53139, 'expert_2': 100989, 'expert_3': 62073}


Epoch 32/50 | Train Loss: 0.9355 | Val Accuracy: 58.46%
  Expert Utilization: {'expert_0': 103965, 'expert_1': 54200, 'expert_2': 100027, 'expert_3': 61808}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 58.46%


Epoch 33/50 | Train Loss: 0.9119 | Val Accuracy: 58.12%
  Expert Utilization: {'expert_0': 107236, 'expert_1': 51522, 'expert_2': 103501, 'expert_3': 57741}


Epoch 34/50 | Train Loss: 0.8951 | Val Accuracy: 58.36%
  Expert Utilization: {'expert_0': 104157, 'expert_1': 52330, 'expert_2': 101025, 'expert_3': 62488}


Epoch 35/50 | Train Loss: 0.8808 | Val Accuracy: 58.62%
  Expert Utilization: {'expert_0': 105265, 'expert_1': 55479, 'expert_2': 99812, 'expert_3': 59444}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 58.62%


Epoch 36/50 | Train Loss: 0.8617 | Val Accuracy: 58.85%
  Expert Utilization: {'expert_0': 104422, 'expert_1': 53376, 'expert_2': 101978, 'expert_3': 60224}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 58.85%


Epoch 37/50 | Train Loss: 0.8482 | Val Accuracy: 59.27%
  Expert Utilization: {'expert_0': 103416, 'expert_1': 52974, 'expert_2': 102928, 'expert_3': 60682}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 59.27%


Epoch 38/50 | Train Loss: 0.8289 | Val Accuracy: 58.84%
  Expert Utilization: {'expert_0': 104342, 'expert_1': 52326, 'expert_2': 103315, 'expert_3': 60017}


Epoch 39/50 | Train Loss: 0.8171 | Val Accuracy: 58.46%
  Expert Utilization: {'expert_0': 104220, 'expert_1': 52234, 'expert_2': 102675, 'expert_3': 60871}


Epoch 40/50 | Train Loss: 0.8017 | Val Accuracy: 59.08%
  Expert Utilization: {'expert_0': 104277, 'expert_1': 52753, 'expert_2': 103108, 'expert_3': 59862}


Epoch 41/50 | Train Loss: 0.7914 | Val Accuracy: 58.93%
  Expert Utilization: {'expert_0': 104625, 'expert_1': 52511, 'expert_2': 103492, 'expert_3': 59372}


Epoch 42/50 | Train Loss: 0.7827 | Val Accuracy: 59.23%
  Expert Utilization: {'expert_0': 103676, 'expert_1': 52329, 'expert_2': 103339, 'expert_3': 60656}


Epoch 43/50 | Train Loss: 0.7668 | Val Accuracy: 58.96%
  Expert Utilization: {'expert_0': 104046, 'expert_1': 51392, 'expert_2': 103247, 'expert_3': 61315}


Epoch 44/50 | Train Loss: 0.7653 | Val Accuracy: 59.25%
  Expert Utilization: {'expert_0': 104171, 'expert_1': 51327, 'expert_2': 103836, 'expert_3': 60666}


Epoch 45/50 | Train Loss: 0.7605 | Val Accuracy: 59.15%
  Expert Utilization: {'expert_0': 104300, 'expert_1': 51529, 'expert_2': 103528, 'expert_3': 60643}


Epoch 46/50 | Train Loss: 0.7501 | Val Accuracy: 59.38%
  Expert Utilization: {'expert_0': 104170, 'expert_1': 51663, 'expert_2': 103364, 'expert_3': 60803}
  New best model saved to ./models/DMMoE_CIFAR10_best.pth with accuracy: 59.38%


Epoch 47/50 | Train Loss: 0.7398 | Val Accuracy: 59.13%
  Expert Utilization: {'expert_0': 104140, 'expert_1': 51509, 'expert_2': 103590, 'expert_3': 60761}


Epoch 48/50 | Train Loss: 0.7370 | Val Accuracy: 59.29%
  Expert Utilization: {'expert_0': 103995, 'expert_1': 51460, 'expert_2': 103604, 'expert_3': 60941}


Epoch 49/50 | Train Loss: 0.7336 | Val Accuracy: 59.24%
  Expert Utilization: {'expert_0': 103965, 'expert_1': 51436, 'expert_2': 103592, 'expert_3': 61007}


Epoch 50/50 | Train Loss: 0.7374 | Val Accuracy: 59.28%
  Expert Utilization: {'expert_0': 103993, 'expert_1': 51438, 'expert_2': 103602, 'expert_3': 60967}
--- Experiment Finished: DMMoE_CIFAR10 | Best Accuracy: 59.38% ---
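To read the run at a glance, the validation accuracy and the Expert Utilization counts can be pulled out of the log text. The sketch below only assumes the log line format shown above and embeds three epochs (1, 46 and 50) copied verbatim from it. `parse_log` and `utilization_shares` are hypothetical helpers, not part of the notebook, and the max/min count ratio is a simple imbalance measure chosen for illustration; it is not the notebook's load-balancing loss.

```python
import re

# Three epochs copied verbatim from the DMMoE_CIFAR10 log above.
LOG = """\
Epoch 1/50 | Train Loss: 2.1838 | Val Accuracy: 28.06%
  Expert Utilization: {'expert_0': 66747, 'expert_1': 153824, 'expert_2': 19589, 'expert_3': 79840}
Epoch 46/50 | Train Loss: 0.7501 | Val Accuracy: 59.38%
  Expert Utilization: {'expert_0': 104170, 'expert_1': 51663, 'expert_2': 103364, 'expert_3': 60803}
Epoch 50/50 | Train Loss: 0.7374 | Val Accuracy: 59.28%
  Expert Utilization: {'expert_0': 103993, 'expert_1': 51438, 'expert_2': 103602, 'expert_3': 60967}
"""

# Patterns matching the log format printed by the notebook's training loop.
EPOCH_RE = re.compile(r"Epoch (\d+)/\d+ \| Train Loss: ([\d.]+) \| Val Accuracy: ([\d.]+)%")
UTIL_RE = re.compile(r"'(expert_\d+)': (\d+)")


def parse_log(text):
    """Hypothetical helper: return (epoch, train_loss, val_acc, {expert: count}) tuples."""
    records = []
    for line in text.splitlines():
        m = EPOCH_RE.search(line)
        if m:
            records.append((int(m.group(1)), float(m.group(2)), float(m.group(3)), {}))
        elif "Expert Utilization" in line and records:
            # Attach the utilization counts to the most recent epoch record.
            records[-1][3].update({k: int(v) for k, v in UTIL_RE.findall(line)})
    return records


def utilization_shares(counts):
    """Hypothetical helper: fraction of the logged utilization count per expert."""
    total = sum(counts.values())
    return {k: v / total for k, v in counts.items()}


records = parse_log(LOG)
best = max(records, key=lambda r: r[2])
print(f"Best val accuracy: {best[2]:.2f}% at epoch {best[0]}")
for epoch, _, _, counts in records:
    shares = utilization_shares(counts)
    ratio = max(counts.values()) / min(counts.values())  # illustrative imbalance measure
    print(epoch, {k: round(s, 3) for k, s in shares.items()}, f"max/min={ratio:.2f}")
```

For these three epochs the counts each sum to 320,000, so the shares are directly comparable. The split moves from roughly 21% / 48% / 6% / 25% after epoch 1 to roughly 32% / 16% / 32% / 19% at epoch 50, and the best validation accuracy (59.38%) is reached at epoch 46, consistent with the "Best Accuracy" line.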


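The experiment cell disables Experiments 2 and 3 by enclosing them in a triple-quoted string. Because that string is the cell's last expression, Jupyter displays it as the cell's output. A minimal sketch of an alternative, assuming you are free to restructure the cell, is to guard each run with an explicit flag. `RUN_EXPERIMENTS` and `enabled_experiments` are hypothetical names introduced here for illustration; the actual model construction and `run_experiment(...)` call would go where the comment indicates.

```python
# Hypothetical flags (not from the original notebook): set True for the runs you want.
RUN_EXPERIMENTS = {
    "DMMoE_CIFAR10": True,
    "HomogeneousMoE_CIFAR10": False,
    "DenseViT_CIFAR10": False,
}


def enabled_experiments(flags):
    """Return the names of experiments whose flag is True, in definition order."""
    return [name for name, enabled in flags.items() if enabled]


for name in enabled_experiments(RUN_EXPERIMENTS):
    # In the notebook, build the matching model here and call run_experiment(...).
    print(f"Running: {name}")
```

Unlike a string literal, a disabled flag leaves no stray output in the saved notebook, and the remaining code stays syntax-highlighted and linted.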