In [1]:
# imports
import torch
import torch.nn as nn
import torch.nn.functional as F


## I. Basic MOE

In [5]:
class BasicExpert(nn.Module): 
    def __init__(self, feature_in, feature_out): 
        super().__init__()
        self.fc = nn.Linear(feature_in, feature_out)

    def forward(self, x): 
        return self.fc(x)

        

In [6]:
class BasicMOE(nn.Module): 
    def __init__(self, feature_in, feature_out, num_experts): 
        super().__init__()
        self.gate = nn.Linear(feature_in,  num_experts)
        # output shape: (batch_size, num_experts)
        self.experts = nn.ModuleList(
            BasicExpert(
                feature_in, feature_out
            ) for _ in range(num_experts)
        )

    def forward(self, x): 
        # x shape is (batch, feature_in)
        # feature_in / hidden_size hidden_dim
        expert_weights = self.gate(x)
        expert_out_list = [
            expert(x) for expert in self.experts
        ] # each expert will output a shape of (batch_size, feature_out)

        expert_outputs = [
            expert_out.unsqueeze(1)
            for expert_out in expert_out_list
        ] 

        # expert out is (b, 1, feature_out)
        expert_output = torch.concat(
            expert_outputs, 
            dim=1
        )


        # expert_weights
        expert_weights = F.softmax(expert_weights, dim=1)
        # expert_weights shape (batch_size, num_experts)

        expert_weights = expert_weights.unsqueeze(1)

        # expected output shape (batch_size, feature_out)

        output = expert_weights @ expert_output
        return output.squeeze(1)
    

def test_basic_moe(): 
    x = torch.rand(4, 512)
    basic_moe = BasicMOE(512, 128, 4)
    output = basic_moe(x)
    print(output.shape)

test_basic_moe()







torch.Size([4, 128])


## II. Sparse MOE

#### Sparse MOE would pick top k experts, and weights the outputs from these k experts. Tokens will be processed by ALL of the k experts.

In [32]:
# mistral MOE

class MOEConfig: 
    def __init__(
            self, 
            hidden_dim, 
            expert_number, 
            top_k, 
            shared_experts_number=2): 
        self.hidden_dim = hidden_dim
        self.expert_number = expert_number
        self.top_k = top_k
        self.shared_experts_number = shared_experts_number

    
class MOERouter(nn.Module): 
    def __init__(self, config): 
        super().__init__()
        self.gate = nn.Linear(config.hidden_dim, config.expert_number)
        # but only choose top k experts

        self.expert_number = config.expert_number
        self.top_k = config.top_k

    def forward(self, x): 
        # E.g., expert_num = 8; k = 2
        router_logits = self.gate(x) # (batch * seq_len, expert_number)

        # calculate the possibility of each expert
        router_probs = F.softmax(router_logits, dim=1, dtype=torch.float)

        # calcualte top_k experts outputs
        # weight and the index/position of the expert
        # topk can be backpropagated
        router_weights, selected_experts_indices = torch.topk(
            input=router_probs,
            k=self.top_k,
            dim=-1
        )

        # shape of router_weights and selected_experts_indices: 
        # (batch * seq_len, top_k)

        # normalization
        router_weights = router_weights / router_weights.sum(
            dim=-1, 
            keepdim=True
        )
        router_weights = router_weights.to(x.dtype)

        expert_masks = F.one_hot(
            selected_experts_indices,
            num_classes=self.expert_number
        ) # (batch * seq_len, top_k, expert_num)

        expert_masks = expert_masks.permute(2, 1, 0)
        # (expert_num, top_k, batch*seq_len)

        return router_logits, router_weights, selected_experts_indices, expert_masks
        # router_logits: (batch*seq_len, expert_num)
        # router_weights: (batch*seq_len, top_k)
        # selected_expert_indices: (batch*seq_len, top_k)
        # expert_mask: (expert_number, top_k, batch * seq_len)



class SparseMOE(nn.Module): 
    def __init__(self, config): 
        super().__init__()
        self.config = config

        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number
        self.top_k = config.top_k

        # init experts
        self.experts = nn.ModuleList(
            BasicExpert(
                config.hidden_dim, 
                config.hidden_dim, 
            ) for _ in range(config.expert_number)
        )
        self.router = MOERouter(config)

    def forward(self, x): 
        # x shape: (batch, seq_len, hidden_dim)
        batch_size, seq_len, hidden_dim = x.size()

        # token dimension calculation
        # reshape x (batch * seq_len, hidden_dim)
        hidden_states = x.view(-1, hidden_dim)

        # expert calculation
        router_logits, router_weights, selected_experts_indices, expert_masks = self.router(hidden_states)
        # expert_mask shape: (exper_num, top_k, batch*seq_len)
        # selected_experts_indices shape: (batch_size * seq_len, top_k)
        
        # the final hidden_states shape: (batch*seq_len, hidden_dim)
        final_hidden_states = torch.zeros(
            (batch_size * seq_len, hidden_dim), 
            dtype=hidden_states.dtype,
            device=hidden_states.device
        )

        # retrieve each expert
        # add the hidden_states of the expert's token to final_hidden_states

        for expert_idx in range(self.expert_number): 
            expert_layer = self.experts[expert_idx]

            current_expert_mask = expert_masks[expert_idx]

            router_weights_idx, top_x = torch.where(current_expert_mask)
            # router_weights_idx is the index of the expert  
            # top_x is the index of the token in the batch * seq_len

            # router_weights_idx is used to pick the weight
            # top_x is used to pick the hidden_states
            # they are both one-dim values

            current_state = hidden_states.unsqueeze(0)[:, top_x, :].reshape(-1, hidden_dim)

            current_state = expert_layer(current_state)

            # hidden_states shape: (batch * seq_len, hidden_dim)
            # hidden_states.unsequeeze(0) shape: (1, batch * seq_len, hidden_dim)
            # current_state shape: (selected_token_num, hidden_dim)

            current_token_router_weight = router_weights[top_x, router_weights_idx]
            # router_weight shape: (batch_size * seq_len, top_k)
            # current_token_router_weight shape: (selected_token_number)

            current_token_router_weight = current_token_router_weight.unsqueeze(-1)
            # current_token_router_weight shape: (selected_token_number, 1)

            current_hidden_states = current_state * current_token_router_weight
            # curret_state shape:  (selected_token_num, hidden_dim)
            # current_token_router_weight shape: (selected_token_number, 1)

            final_hidden_states.index_add (
                0, 
                top_x, 
                current_hidden_states.to(hidden_states.dtype)
            )

        # revert the final_hidden_states to original shape
        final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)

        return final_hidden_states


def test_token_level_moe(): 
    x = torch.rand(2, 4, 16)
    config = MOEConfig(16, 2, 2)
    token_level_moe = SparseMOE(config)
    out = token_level_moe(x)
    print(out[0].shape, out[1].shape)

test_token_level_moe()

torch.Size([4, 16]) torch.Size([4, 16])


## III. ShareExpert SparseMOE

#### Compare to the SparseMOE, SESMOE has shared expert models, which all the token would be passed to the shared experts, and each token will pick the top k experts with the calculated Router weights. Finally, combine the output from top k experts and shared experts, weight them and summerize them.

In [34]:
class SharedExpertMOE(nn.Module): 
    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.router_experts_moe = SparseMOE(config)
        self.shared_experts = nn.ModuleList(
            [
                BasicExpert(self.config.hidden_dim, self.config.hidden_dim)
            ]
        )

    def forward(self, x): 
        # x shape: (batch_size, seq_len, hidden_dim)
        batch_size, seq_len, hidden_dim = x.size()

        shared_experts_output_list = [
            expert(x) for expert in self.shared_experts
        ]

        shared_expert_output = torch.stack(
            shared_experts_output_list, 
            dim=0
        )
        # shape (shared_experts_num, batch_size, seq_len, hidden_dim)

        shared_expert_output = shared_expert_output.sum(dim=0, keepdim=False)
        # shape (batch_size, seq_len, hidden_dim)

        sparse_moe_out, router_logits = self.router_experts_moe(
            x
        )

        output = shared_expert_output + sparse_moe_out

        return output, router_logits


def test_share_expert_moe():
    x = torch.rand(2, 4, 16)
    config = MOEConfig(16, 2, 2)
    share_expert_moe = SharedExpertMOE(config)
    out = share_expert_moe(x)
    print(out[0].shape, out[1].shape)


test_share_expert_moe()

torch.Size([2, 4, 16]) torch.Size([4, 16])


In [None]:
def switch_load_balancing_loss(router_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
    """
    
    Args:
        router_logits: shape [batch_size * sequence_length, num_experts]
        num_experts
    
    Returns:
        total_loss: auxiliary_loss + z_loss
    """
    
    router_probs = torch.softmax(router_logits, dim=-1)  # [b*s, num_experts]
    
   
    _, selected_experts = torch.topk(router_probs, k=2, dim=-1)  # [b*s]
    

    mask = torch.nn.functional.one_hot(selected_experts, num_experts).float()  # [b*s, num_experts]
    
    
    expected_load = torch.ones_like(router_probs) / num_experts
    
   
    actual_load = mask.mean(dim=0)  # [num_experts]
    
 
    aux_loss = torch.sum(actual_load * router_probs.mean(dim=0)) * num_experts
    
  
    z_loss = torch.mean(torch.square(router_logits))
    z_loss_weight = 0.001  
    
    
    total_loss = aux_loss + z_loss * z_loss_weight
    
    return total_loss

def test_moe_training():
    # Create a simple dataset
    batch_size = 32
    seq_len = 16
    hidden_dim = 32
    num_batches = 100
    
    # Initialize model and optimizer
    config = MOEConfig(hidden_dim=hidden_dim, 
                      expert_number=4,
                      top_k=2,
                      shared_experts_number=2)
    model = SharedExpertMOE(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Training loop
    model.train()
    for batch in range(num_batches):
        # Generate random input data
        x = torch.randn(batch_size, seq_len, hidden_dim)
        target = torch.randn(batch_size, seq_len, hidden_dim)
        
        # Forward pass
        output, router_logits = model(x)

        # Compute losses
        # MSE loss for prediction
        mse_loss = F.mse_loss(output, target)
        
        aux_loss = switch_load_balancing_loss(router_logits, config.expert_number)
        # Combined loss
        total_loss = mse_loss + 0.01 * aux_loss
        
        # Backward pass and optimize
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        if batch % 10 == 0:
            print(f"Batch {batch}, Loss: {total_loss.item():.4f} "
                  f"(MSE: {mse_loss.item():.4f}, Aux: {aux_loss.item():.4f})")

# Run the training test
test_moe_training()

In [20]:
a = torch.rand(2, 2, 3)

In [21]:
batch_size, seq_len, hidden_dim = a.size()
batch_size, seq_len, hidden_dim

(2, 2, 3)

In [25]:
a

tensor([[[0.8919, 0.8194, 0.3395],
         [0.8881, 0.0511, 0.7392]],

        [[0.1639, 0.6840, 0.3140],
         [0.5633, 0.6653, 0.5879]]])

In [24]:
a_sum = a.sum(dim=-1, keepdim=False)
a_sum

tensor([[2.0508, 1.6783],
        [1.1619, 1.8164]])