<a href="https://colab.research.google.com/github/Justin-Qian/ML_Homework_Collection/blob/main/MoMa_MoE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Report the installed PyTorch build, whether a CUDA GPU is visible,
# and pick the device every later cell moves tensors/modules to.
has_cuda = torch.cuda.is_available()

print("PyTorch version:", torch.__version__)
print("CUDA available:", has_cuda)

if has_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

PyTorch version: 2.6.0+cu124
CUDA available: True
Using device: cuda


Expert

In [None]:
class Expert(nn.Module):
    """SwiGLU feed-forward network used as a single expert.

    Computes ``w3(SiLU(w1(x)) * w2(x))`` with bias-free linear projections.

    Args:
        input_dim (int): Size of the last dimension of the input.
        output_dim (int): Size of the last dimension of the output.
        hidden_dim (int, optional): Width of the gated hidden layer.
            Defaults to ``2 * output_dim``.
    """
    def __init__(self, input_dim, output_dim, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = output_dim * 2
        # Creation order (w1, w2, w3) is kept so parameter init consumes the RNG identically.
        self.w1 = nn.Linear(input_dim, hidden_dim, bias=False)   # gate projection
        self.w2 = nn.Linear(input_dim, hidden_dim, bias=False)   # value projection
        self.w3 = nn.Linear(hidden_dim, output_dim, bias=False)  # output projection

    def forward(self, x):
        # Invoke the Linear modules themselves (rather than F.linear on their .weight)
        # so forward hooks and Linear subclasses/replacements are respected.
        hidden = F.silu(self.w1(x)) * self.w2(x)  # SwiGLU: SiLU(gate) ⊗ value
        return self.w3(hidden)

Router

In [None]:
class Router(nn.Module):
    """Expert-Choice (EC) Routing for Mixture of Experts.

    Each expert selects the ``ke`` tokens with the highest sigmoid affinity,
    where ``ke = int(batch_size * capacity_factor / num_experts)`` clamped to
    the range ``[1, batch_size]``.
    """
    def __init__(self, input_dim, num_experts, capacity_factor=1.0):
        """
        Args:
            input_dim (int): The input feature dimension.
            num_experts (int): The number of experts.
            capacity_factor (float): Scales how many tokens each expert processes
                (roughly ``batch_size * capacity_factor / num_experts`` per expert).
        """
        super().__init__()
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.WMg = nn.Linear(input_dim, num_experts)  # Project input features to expert selection scores

    def forward(self, x):
        """
        Args:
            x (Tensor): Input tensor of shape [batch_size, input_dim].

        Returns:
            Tensor: Routing weights of shape [batch_size, num_experts]. Column ``e``
            holds the sigmoid affinities of expert ``e``'s top-ke tokens and is
            zero for every token that expert did not select.
        """
        batch_size = x.shape[0]

        # Compute token-to-expert affinity scores
        scores = self.WMg(x)  # Shape: [batch_size, num_experts]
        routing_weights = torch.sigmoid(scores)  # Shape: [batch_size, num_experts]

        # Per-expert capacity. Without clamping, the truncated value is 0 whenever
        # batch_size * capacity_factor < num_experts (e.g. 3 tokens, 4 experts),
        # which would zero every routing weight and hence the whole layer output;
        # it also exceeds batch_size when capacity_factor > num_experts, which
        # makes torch.topk raise.
        ke = int(batch_size * (self.capacity_factor / self.num_experts))
        ke = min(batch_size, max(1, ke))

        return self.top_k_expert_selection(routing_weights, ke)

    def top_k_expert_selection(self, routing_weights, ke):
        """
        Select top-ke tokens for each expert.

        Args:
            routing_weights (Tensor): Shape [batch_size, num_experts].
            ke (int): Number of tokens each expert can process; values larger
                than batch_size are clamped to batch_size.

        Returns:
            Tensor: Routing weights with only the top-ke tokens per expert
            (per column) kept; all other entries are zero.
        """
        batch_size = routing_weights.shape[0]
        ke = min(ke, batch_size)  # torch.topk raises if k exceeds the dimension size

        # For every expert (column) pick the ke tokens (rows) with the largest weight.
        _, top_indices = torch.topk(routing_weights, ke, dim=0)
        mask = torch.zeros_like(routing_weights)
        mask.scatter_(0, top_indices, 1.0)

        # Apply mask to routing weights (only top-ke tokens per expert remain)
        return routing_weights * mask


MoE Layer

In [None]:
class MoE_Layer(nn.Module):
    """Mixture of Experts (MoE) layer: SwiGLU experts fed by an Expert-Choice router."""
    def __init__(self, input_dim, output_dim, num_experts=4):
        super().__init__()
        self.experts = nn.ModuleList([Expert(input_dim, output_dim) for _ in range(num_experts)])
        self.router = Router(input_dim, num_experts)
        # Kept so forward can size its output without running an expert just to probe it.
        self.output_dim = output_dim

    def forward(self, x):
        """
        Args:
            x (Tensor): Input tensor of shape [batch_size, input_dim].
        Returns:
            Tensor: Shape [batch_size, output_dim] with the same dtype as ``x``.
            Row ``t`` is the sum over experts of ``routing_weight[t, e] * expert_e(x[t])``,
            taken only over the experts that selected token ``t`` (zero if none did).
        """
        routing_weights = self.router(x)  # [batch_size, num_experts]
        mask = routing_weights != 0       # [batch_size, num_experts]

        output = x.new_zeros(x.shape[0], self.output_dim)

        # Run each expert **only** on the tokens it selected (this is what makes MoE
        # efficient) and scatter-add its weighted output straight into `output`,
        # avoiding a [batch_size, num_experts, output_dim] buffer.
        for i, expert in enumerate(self.experts):
            token_idx = mask[:, i].nonzero(as_tuple=True)[0]
            if token_idx.numel() == 0:
                continue
            weighted = routing_weights[token_idx, i].unsqueeze(-1) * expert(x[token_idx])
            # Cast so index_add works when expert outputs differ in dtype (e.g. under autocast).
            output = output.index_add(0, token_idx, weighted.to(output.dtype))

        return output

Model Architecture

In [None]:
class MoMa_Module(nn.Module):
    """Complete MoMa module with hierarchical routing and modality-aware experts.

    Tokens are first split by modality (text vs. image); each group is then
    routed through its own MoE layer.
    """
    def __init__(self, input_dim, output_dim, num_text_experts=4, num_image_experts=4):
        super().__init__()
        # Needed to size the output buffer; it may differ from input_dim.
        self.output_dim = output_dim
        # Text MoE
        self.text_moe = MoE_Layer(input_dim, output_dim, num_text_experts)
        # Image MoE
        self.image_moe = MoE_Layer(input_dim, output_dim, num_image_experts)

    def forward(self, x, modality_mask):
        """
        Args:
            x (Tensor): Input tensor of shape [batch_size, input_dim].
            modality_mask (Tensor): Mask of shape [batch_size],
                                    True for text ("T"), False for image ("I").
                                    Non-bool masks (e.g. 0/1 ints) are converted to bool.

        Returns:
            Tensor: Shape [batch_size, output_dim] (dtype of ``x``); row ``t`` is the
            output of the MoE for token ``t``'s modality, preserving original order.
        """
        # `~` on an integer tensor is bitwise NOT, and integer tensors index by position,
        # so normalize to bool before building the complementary masks.
        text_mask = modality_mask.bool()  # True for text tokens
        image_mask = ~text_mask           # True for image tokens

        # One buffer sized by output_dim (not input_dim); each modality fills its own rows.
        output = x.new_zeros(x.shape[0], self.output_dim)

        if text_mask.any():
            output[text_mask] = self.text_moe(x[text_mask]).to(output.dtype)

        if image_mask.any():
            output[image_mask] = self.image_moe(x[image_mask]).to(output.dtype)

        return output


Test

In [None]:
# Define input parameters
input_dim = 128  # Transformer hidden dimension
output_dim = 128  # Output feature dimension
batch_size = 8  # Number of tokens per batch

# Seed every RNG so the modality split, inputs and weights are reproducible across runs
torch.manual_seed(0)

# Generate random input tokens directly on the target device (no host->device copy)
x = torch.randn(batch_size, input_dim, device=device)

# Generate a random modality mask (True for text, False for image)
modality_mask = torch.randint(0, 2, (batch_size,), device=device).bool()

# Initialize MoMa module
moma = MoMa_Module(input_dim, output_dim).to(device)

# Forward pass
output = moma(x, modality_mask)

# Print results
print("Modality mask:", modality_mask)
print("Output shape:", output.shape)

# Sanity check: one output row per token, sized by output_dim
assert output.shape == (batch_size, output_dim)

Modality mask: tensor([False, False,  True,  True,  True,  True, False,  True],
       device='cuda:0')
Output shape: torch.Size([8, 128])


In [None]:
# Display the module tree: per-modality experts and routers.
print(str(moma))

MoMa_Module(
  (text_moe): MoE_Layer(
    (experts): ModuleList(
      (0-3): 4 x Expert(
        (w1): Linear(in_features=128, out_features=256, bias=False)
        (w2): Linear(in_features=128, out_features=256, bias=False)
        (w3): Linear(in_features=256, out_features=128, bias=False)
      )
    )
    (router): Router(
      (WMg): Linear(in_features=128, out_features=4, bias=True)
    )
  )
  (image_moe): MoE_Layer(
    (experts): ModuleList(
      (0-3): 4 x Expert(
        (w1): Linear(in_features=128, out_features=256, bias=False)
        (w2): Linear(in_features=128, out_features=256, bias=False)
        (w3): Linear(in_features=256, out_features=128, bias=False)
      )
    )
    (router): Router(
      (WMg): Linear(in_features=128, out_features=4, bias=True)
    )
  )
)
