In [11]:
import math
from typing import Optional, Tuple

import torch
from colossalai.moe.layers import SparseMLP
from torch import nn
from torch.nn import functional as F
from transformers import LlamaConfig

In [12]:
class OpenMoeConfig(LlamaConfig):
    """``LlamaConfig`` extended with the OpenMoE options used in this notebook.

    Every argument declared below is stored on the instance under the same
    name; all remaining keyword arguments are forwarded to ``LlamaConfig``.

    Of these options, the classes in this notebook read ``num_experts``,
    ``router_topk``, ``router_capacity_factor_train``,
    ``router_capacity_factor_eval``, ``router_min_capacity`` and
    ``router_drop_tks`` (plus ``hidden_size``, ``intermediate_size`` and
    ``hidden_act`` inherited from ``LlamaConfig``). The remaining options are
    only stored.

    ``num_experts`` and ``moe_layer_interval`` have defaults (the values this
    notebook uses) so that the class can be instantiated without arguments.
    transformers does that internally when it serialises a config as a diff
    against the class defaults (e.g. for ``repr()`` / ``to_json_string()``);
    with required positional arguments those calls raise ``TypeError``.
    """

    def __init__(
            self,
            num_experts: int = 8,
            moe_layer_interval: int = 6,
            router_topk: int = 2,
            router_capacity_factor_train: float = 1.25,
            router_capacity_factor_eval: float = 2.0,
            router_min_capacity: int = 4,
            router_noisy_policy: Optional[str] = None,
            router_drop_tks: bool = True,
            router_aux_loss_factor: float = 0.01,
            router_z_loss_factor: float = 0.0001,
            mlp_gated: bool = True,
            label_smoothing: float = 0.001,
            z_loss_factor: float = 0.01,
            enable_load_balance: bool = False,
            load_balance_tolerance: float = 0.1,
            load_balance_beam_width: int = 8,
            load_balance_group_swap_factor: float = 0.4,
            enable_kernel: bool = False,
            enable_comm_overlap: bool = False,
            enable_hierarchical_alltoall: bool = False,
            **kwargs
    ):
        # Expert layout.
        self.num_experts = num_experts
        self.moe_layer_interval = moe_layer_interval

        # Router options.
        self.router_topk = router_topk
        self.router_capacity_factor_train = router_capacity_factor_train
        self.router_capacity_factor_eval = router_capacity_factor_eval
        self.router_min_capacity = router_min_capacity
        self.router_noisy_policy = router_noisy_policy
        self.router_drop_tks = router_drop_tks
        self.router_aux_loss_factor = router_aux_loss_factor
        self.router_z_loss_factor = router_z_loss_factor

        # Expert MLP and loss options.
        self.mlp_gated = mlp_gated
        self.label_smoothing = label_smoothing
        self.z_loss_factor = z_loss_factor

        # Load-balancing and kernel/communication switches.
        self.enable_load_balance = enable_load_balance
        self.load_balance_tolerance = load_balance_tolerance
        self.load_balance_beam_width = load_balance_beam_width
        self.load_balance_group_swap_factor = load_balance_group_swap_factor
        self.enable_kernel = enable_kernel
        self.enable_comm_overlap = enable_comm_overlap
        self.enable_hierarchical_alltoall = enable_hierarchical_alltoall

        # Custom attributes are set first, then the Llama fields are
        # initialised from whatever is left in kwargs.
        super().__init__(**kwargs)

In [25]:
def swiglu_act_fn(x, axis: int = -1):
    """Gated linear unit activation function.

    Splits ``x`` into two equal halves ``x1, x2`` along ``axis`` and returns
    ``x1 * (x2 * sigmoid(x2))``.

    Args:
        x : input tensor; its size along ``axis`` must be even
        axis: the axis along which the split should be computed (default: -1)

    Returns:
        A tensor shaped like ``x`` except that ``axis`` has half the size.

    Raises:
        AssertionError: if the size of ``x`` along ``axis`` is odd.
    """
    size = x.shape[axis]
    assert size % 2 == 0, "axis size must be divisible by 2"
    # torch.split with an int chunk size of size // 2 yields exactly two halves.
    x1, x2 = torch.split(x, size // 2, axis)
    return x1 * (x2 * torch.sigmoid(x2))


class OpenMoeMLP(nn.Module):
    """One expert feed-forward block with a SwiGLU-gated projection.

    ``gate_proj`` maps the hidden size to twice the intermediate size so that
    ``swiglu_act_fn`` can split it into two halves; the gated result is
    multiplied by ``up_proj(hidden_states)`` and mapped back to the hidden
    size by ``down_proj``. No layer has a bias.
    """

    def __init__(self, config: OpenMoeConfig):
        super().__init__()
        assert config.hidden_act=="swiglu"
        inner_width = config.intermediate_size
        model_width = config.hidden_size
        self.ffn_dim = inner_width
        self.hidden_dim = model_width

        # Layers are created in the same order as before so that random
        # initialisation consumes the RNG identically.
        self.gate_proj = nn.Linear(model_width, inner_width * 2, bias=False)
        self.up_proj = nn.Linear(model_width, inner_width, bias=False)
        self.down_proj = nn.Linear(inner_width, model_width, bias=False)

    def forward(self, hidden_states):
        """Return ``down_proj(swiglu(gate_proj(h)) * up_proj(h))``."""
        gated = swiglu_act_fn(self.gate_proj(hidden_states))
        lifted = self.up_proj(hidden_states)
        return self.down_proj(gated * lifted)



In [70]:
def moe_cumsum(inputs: torch.Tensor):
    """Inclusive cumulative sum along dim 0, shifted down by one.

    For a 0/1 mask this turns each ``1`` into its zero-based position among
    the ones of its column.
    """
    running_total = inputs.cumsum(dim=0)
    return running_total - 1

class OpenMoeTop2Router(torch.nn.Module):
    """Top-2 token router with a fixed per-expert capacity.

    For every row of the gating logits it selects the two experts with the
    highest softmax probability, assigns each selection a slot in that
    expert's capacity buffer, and returns the weights and mask needed to
    dispatch tokens to experts and combine the expert outputs again.

    The router holds no parameters. ``drop_tks`` is stored from the config
    but is not read anywhere in this class.
    """

    def __init__(self, config: OpenMoeConfig):
        super().__init__()
        # Only top-2 routing is implemented here.
        assert config.router_topk == 2
        self.k_value = 2
        self.capacity_factor_train = config.router_capacity_factor_train
        self.capacity_factor_eval = config.router_capacity_factor_eval
        self.min_capacity = config.router_min_capacity
        self.drop_tks = config.router_drop_tks

    def get_capacity(self, logits_shape):
        """Return the number of slots each expert gets for this batch.

        Computes ``floor(k * capacity_factor * logits_shape[-2] / logits_shape[-1])``
        using the train or eval capacity factor depending on ``self.training``,
        rounds odd results up to the next even number, and enforces
        ``min_capacity`` as a lower bound.
        """
        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
        capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
        # Adds 1 when the value is odd, so the capacity is always even.
        capacity += capacity % 2
        capacity = max(capacity, self.min_capacity)
        assert capacity > 0
        return int(capacity)

    def forward(self, inputs: torch.Tensor) -> Tuple:
        """Route tokens to their top-2 experts.

        Args:
            inputs: FP32 gating logits. The indexing below treats them as
                2-D, ``[s, e]`` (one row per token, one column per expert).

        Returns:
            A tuple ``(used_capacity, cb_weight, sec_mask)``:
            ``used_capacity`` is the per-expert count of kept assignments,
            ``cb_weight`` has shape ``inputs.shape + (capacity,)`` and holds
            the routing probability of each kept (token, expert, slot)
            assignment, and ``sec_mask`` is the boolean mask of the same
            shape marking those assignments.
        """
        assert inputs.dtype == torch.float, "Router input should be FP32"

        probs = F.softmax(inputs, dim=-1)
        num_experts = probs.size(-1)
        capacity = self.get_capacity(inputs.shape)

        # First choice: the most probable expert per token.
        top1_idx = torch.argmax(probs, dim=-1)
        mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
        # Second choice: mask out the first choice and take the argmax again.
        logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
        top2_idx = torch.argmax(logits_except1, dim=-1)
        mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)

        rank1 = moe_cumsum(mask1)  # rank1: [s, e]
        rank2 = moe_cumsum(mask2)
        # Second choices are ranked after all first choices of the same
        # expert. This must use mask1 before it is truncated below.
        rank2 += torch.sum(mask1, dim=-2, keepdim=True)

        # Zero out assignments whose rank does not fit into the capacity.
        mask1 *= torch.lt(rank1, capacity)
        mask2 *= torch.lt(rank2, capacity)
        used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0)

        # Collapse to one slot index per token; a zeroed-out assignment
        # yields slot 0 (and contributes weight 0 / mask False below).
        rank1 = torch.sum(mask1 * rank1, dim=-1)
        rank2 = torch.sum(mask2 * rank2, dim=-1)

        weight1 = mask1 * probs.type_as(inputs)
        weight2 = mask2 * probs.type_as(inputs)

        # Scatter the weights and the mask into [token, expert, slot] tensors.
        cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device)
        sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
        indices = torch.arange(0, inputs.shape[0], device=inputs.device)
        cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
        cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
        sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
        sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]

        return used_capacity, cb_weight, sec_mask

In [81]:
class OpenMoeSparseMLP(nn.Module):
    """Sparse mixture-of-experts MLP built from plain ``nn.Module`` parts.

    A bias-free linear gate produces per-expert logits for every token, the
    top-2 router turns them into a dispatch mask and combine weights, each
    expert processes the tokens dispatched to it, and the expert outputs are
    recombined into a tensor shaped like the input.
    """

    def __init__(self, config: OpenMoeConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_experts = config.num_experts

        self.gate = nn.Linear(self.hidden_size, config.num_experts, bias=False)

        expert_modules = []
        for _ in range(self.num_experts):
            expert_modules.append(OpenMoeMLP(config))
        self.experts = nn.ModuleList(expert_modules)
        self.router = OpenMoeTop2Router(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the mixture-of-experts layer; the output has the input's shape."""
        original_shape = hidden_states.shape
        # Flatten every leading dimension into one token axis: [s, h].
        flat_tokens = hidden_states.reshape(-1, self.hidden_size)

        # The gate always runs in fp32, whatever dtype the tokens have.
        tokens_fp32 = flat_tokens.to(torch.float)
        self.gate = self.gate.to(torch.float)
        router_logits = self.gate(tokens_fp32)

        # The first result (per-expert used capacity) is not needed here.
        _, combine_weights, dispatch_mask = self.router(inputs=router_logits)

        # [e, c, s] @ [s, h] -> [e, c, h]: gather each expert's tokens by slot.
        dispatch_mask = dispatch_mask.type_as(hidden_states)
        expert_inputs = torch.matmul(dispatch_mask.permute(1, 2, 0), flat_tokens)

        expert_outputs = self._local_process(expert_inputs)

        # [s, e*c] @ [e*c, h] -> [s, h]: probability-weighted sum of outputs.
        combine_weights = combine_weights.type_as(hidden_states)
        combine_weights = combine_weights.view(combine_weights.shape[0], -1)
        expert_outputs = expert_outputs.view(-1, expert_outputs.shape[-1])
        combined = torch.matmul(combine_weights, expert_outputs)

        return combined.reshape(original_shape)

    def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
        """Run every expert on its own slice of ``expert_in``.

        ``expert_in`` is indexed by expert along dim 0. The result carries an
        extra leading dimension of size 1 and is contiguous.
        """
        # Copied from colossalai MLPExperts class
        grouped = expert_in.unsqueeze(0)
        expert_count = grouped.size(1)
        feature_dim = grouped.size(-1)

        # Put the expert axis first and remember that layout for the way back.
        expert_major = grouped.transpose(0, 1)
        expert_major_shape = expert_major.shape
        expert_major = expert_major.reshape(expert_count, -1, feature_dim)

        per_expert_out = []
        for idx in range(expert_count):
            per_expert_out.append(self.experts[idx](expert_major[idx]))

        merged = torch.cat([chunk.unsqueeze(0) for chunk in per_expert_out], dim=0)
        merged = merged.reshape(expert_major_shape)
        return merged.transpose(0, 1).contiguous()



In [82]:
# Config for the plain-PyTorch re-implementation, using the same
# hyper-parameters that are passed to colossalai's SparseMLP.
hf_config = OpenMoeConfig(
    moe_layer_interval=6,
    num_experts=8,
    hidden_size=16,
    intermediate_size=64,
    # OpenMoeMLP asserts hidden_act == "swiglu"; without this the inherited
    # LlamaConfig default would be used and the assertion would fail.
    hidden_act='swiglu',
    # OpenMoeConfig spells this `router_topk`. Passing `router_top_k` would
    # fall into **kwargs and leave the real option at its default.
    router_topk=2,
    router_capacity_factor_train=1.25,
    router_capacity_factor_eval=2.0,
    router_min_capacity=4,
    router_noisy_policy=None,
    router_drop_tks=True,
    # Not a declared OpenMoeConfig argument: forwarded through **kwargs,
    # kept so the config mirrors the SparseMLP arguments.
    mlp_activation='swiglu',
    mlp_gated=True,
    enable_load_balance=False,
    load_balance_tolerance=0.1,
    load_balance_beam_width=8,
    load_balance_group_swap_factor=0.4,
    enable_kernel=False,
    enable_comm_overlap=False,
)

In [83]:
# Seed torch's global RNG before the first randomly initialised module so a
# Restart & Run All draws the same random numbers from torch each time.
# NOTE(review): this assumes the colossalai layer created below also draws
# its initial weights from torch's global RNG — confirm.
torch.manual_seed(0)
hf_moe_layer = OpenMoeSparseMLP(hf_config)

In [84]:
# Reference layer from colossalai that the re-implementation is compared to.
moe_layer = SparseMLP(
    # layer sizes
    num_experts=8,
    hidden_size=16,
    intermediate_size=64,
    # expert MLP
    mlp_activation="swiglu",
    mlp_gated=True,
    # router
    router_top_k=2,
    router_capacity_factor_train=1.25,
    router_capacity_factor_eval=2.0,
    router_min_capacity=4,
    router_noisy_policy=None,
    router_drop_tks=True,
    # load balancing
    enable_load_balance=False,
    load_balance_tolerance=0.1,
    load_balance_beam_width=8,
    load_balance_group_swap_factor=0.4,
    # kernels / communication
    enable_kernel=False,
    enable_comm_overlap=False,
)

In [85]:
# Copy the reference layer's weights into the re-implementation so both
# layers compute with identical parameters.
with torch.no_grad():
    # The gate weight is shared by all experts: copy it once, not per expert.
    hf_moe_layer.gate.weight.copy_(moe_layer.gate_weight)

    for expert_idx, expert in enumerate(hf_moe_layer.experts):
        # The reference stores each expert weight as [in, out]; nn.Linear
        # keeps [out, in], hence the transpose.
        expert.gate_proj.weight.copy_(moe_layer.experts.wi_gate[expert_idx].t())
        expert.up_proj.weight.copy_(moe_layer.experts.wi_up[expert_idx].t())
        expert.down_proj.weight.copy_(moe_layer.experts.wo[expert_idx].t())

In [86]:
# List the reference layer's parameters and their shapes.
for param_name, param in moe_layer.named_parameters():
    print(param_name, param.shape)

gate_weight torch.Size([8, 16])
experts.wi_gate torch.Size([8, 16, 128])
experts.wi_up torch.Size([8, 16, 64])
experts.wo torch.Size([8, 64, 16])


In [87]:
# NOTE(review): this cell repeats the previous one and prints the same
# parameters; presumably it was meant to list hf_moe_layer's — confirm.
for param_name, param in moe_layer.named_parameters():
    print(param_name, param.shape)

gate_weight torch.Size([8, 16])
experts.wi_gate torch.Size([8, 16, 128])
experts.wi_up torch.Size([8, 16, 64])
experts.wo torch.Size([8, 64, 16])


In [88]:
# Random test input; the last dimension (16) matches hidden_size above.
input_hidden_states = torch.randn((4, 12, 16))

In [89]:
# Output shape of the reference layer.
moe_layer(input_hidden_states).size()

torch.Size([4, 12, 16])

In [90]:
# Output shape of the re-implementation.
hf_moe_layer(input_hidden_states).size()

torch.Size([4, 12, 16])

In [93]:
# Both layers should produce numerically matching outputs for the same input.
reference_out = moe_layer(input_hidden_states)
hf_out = hf_moe_layer(input_hidden_states)
torch.allclose(reference_out, hf_out)

True

In [94]:
# Repeat the equivalence check on fresh random inputs. A separate variable
# is used so the `input_hidden_states` from the earlier cell is not
# overwritten, and a failing trial reports how far apart the outputs are.
for trial in range(100):
    trial_inputs = torch.randn(11, 12, 16)
    reference_out = moe_layer(trial_inputs)
    hf_out = hf_moe_layer(trial_inputs)
    assert torch.allclose(reference_out, hf_out), (
        f"trial {trial}: outputs differ, "
        f"max abs diff = {(reference_out - hf_out).abs().max().item():.3e}"
    )