In [42]:
from typing import Optional

import torch
from colossalai.moe.layers import SparseMLP
from torch import nn
from transformers import LlamaConfig

In [None]:
class OpenMoeConfig(LlamaConfig):
    """LLaMA configuration extended with OpenMoE mixture-of-experts settings.

    Every MoE-specific argument below is stored unchanged as an attribute of
    the same name. All remaining keyword arguments (e.g. ``hidden_size``,
    ``intermediate_size``) are forwarded to ``LlamaConfig.__init__``.

    Args:
        num_experts: number of experts (required).
        moe_layer_interval: interval between MoE layers (required);
            presumably every ``moe_layer_interval``-th decoder layer is
            sparse. TODO confirm against the model code.
        router_topk, router_capacity_factor_train, router_capacity_factor_eval,
        router_min_capacity, router_noisy_policy, router_drop_tks,
        router_aux_loss_factor, router_z_loss_factor: router hyper-parameters.
            They are stored as-is. ``router_noisy_policy`` may be ``None``.
        mlp_gated: whether expert MLPs are gated; stored as-is.
        label_smoothing, z_loss_factor: loss hyper-parameters; stored as-is.
        enable_load_balance, load_balance_tolerance, load_balance_beam_width,
        load_balance_group_swap_factor: load-balancing settings; stored as-is.
        enable_kernel, enable_comm_overlap, enable_hierarchical_alltoall:
            performance toggles; stored as-is.
        **kwargs: forwarded unchanged to ``LlamaConfig``.
    """

    def __init__(
        self,
        num_experts: int,
        moe_layer_interval: int,
        router_topk: int = 2,
        router_capacity_factor_train: float = 1.25,
        router_capacity_factor_eval: float = 2.0,
        router_min_capacity: int = 4,
        # Defaults to None, so the annotation must admit None.
        router_noisy_policy: Optional[str] = None,
        router_drop_tks: bool = True,
        router_aux_loss_factor: float = 0.01,
        router_z_loss_factor: float = 0.0001,
        mlp_gated: bool = True,
        label_smoothing: float = 0.001,
        z_loss_factor: float = 0.01,
        enable_load_balance: bool = False,
        load_balance_tolerance: float = 0.1,
        load_balance_beam_width: int = 8,
        load_balance_group_swap_factor: float = 0.4,
        enable_kernel: bool = False,
        enable_comm_overlap: bool = False,
        enable_hierarchical_alltoall: bool = False,
        **kwargs
    ):
        self.num_experts = num_experts
        self.moe_layer_interval = moe_layer_interval
        self.router_topk = router_topk
        self.router_capacity_factor_train = router_capacity_factor_train
        self.router_capacity_factor_eval = router_capacity_factor_eval
        self.router_min_capacity = router_min_capacity
        self.router_noisy_policy = router_noisy_policy
        self.router_drop_tks = router_drop_tks
        self.router_aux_loss_factor = router_aux_loss_factor
        self.router_z_loss_factor = router_z_loss_factor
        self.mlp_gated = mlp_gated
        self.label_smoothing = label_smoothing
        self.z_loss_factor = z_loss_factor
        self.enable_load_balance = enable_load_balance
        self.load_balance_tolerance = load_balance_tolerance
        self.load_balance_beam_width = load_balance_beam_width
        self.load_balance_group_swap_factor = load_balance_group_swap_factor
        self.enable_kernel = enable_kernel
        self.enable_comm_overlap = enable_comm_overlap
        self.enable_hierarchical_alltoall = enable_hierarchical_alltoall

        # The MoE attributes are set first; the base config then consumes the
        # remaining (LLaMA) keyword arguments.
        super().__init__(**kwargs)

In [41]:
def swiglu_act_fn(x, axis: int = -1):
    """SwiGLU gated activation.

    Splits ``x`` into two equal halves ``x1`` and ``x2`` along ``axis`` and
    returns ``x1 * silu(x2)``, where ``silu(t) = t * sigmoid(t)``.

    Args:
        x: input tensor. Its size along ``axis`` must be even.
        axis: the dimension along which ``x`` is split. The default of -1
            (the last dimension) preserves the previous behavior.

    Returns:
        A tensor shaped like ``x``, except that dimension ``axis`` is halved.

    Raises:
        AssertionError: if the size of ``x`` along ``axis`` is odd.
    """
    size = x.shape[axis]
    assert size % 2 == 0, "axis size must be divisible by 2"
    x1, x2 = torch.split(x, size // 2, axis)
    # x2 * sigmoid(x2) is SiLU / Swish applied to the gate half.
    return x1 * (x2 * torch.sigmoid(x2))


class OpenMoeMLP(torch.nn.Module):
    """Single feed-forward expert built from an OpenMoE-style config.

    ``gate_proj`` maps ``hidden_size`` to ``2 * intermediate_size`` features,
    and ``swiglu_act_fn`` halves that back to ``intermediate_size``. The result
    is multiplied element-wise with ``up_proj(x)`` and projected back to
    ``hidden_size`` by ``down_proj``.

    NOTE(review): this gates twice (SwiGLU on the gate output, then the product
    with ``up_proj``). LLaMA's MLP uses ``silu(gate(x)) * up(x)`` with a
    single-width gate. Confirm that the double gating is intended.
    """

    def __init__(self, config: "OpenMoeConfig"):
        """
        Args:
            config: any object exposing ``hidden_size`` and
                ``intermediate_size``, e.g. an ``OpenMoeConfig``.
        """
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        # The gate is twice as wide because swiglu_act_fn consumes two halves.
        self.gate_proj = nn.Linear(self.hidden_dim, self.ffn_dim * 2)
        self.up_proj = nn.Linear(self.hidden_dim, self.ffn_dim)
        self.down_proj = nn.Linear(self.ffn_dim, self.hidden_dim)

    def forward(self, hidden_states):
        """Apply the expert MLP. The last dimension of ``hidden_states`` must equal ``hidden_dim``."""
        gated = swiglu_act_fn(self.gate_proj(hidden_states))
        return self.down_proj(gated * self.up_proj(hidden_states))



In [39]:
class OpenMoeSparseMLP(torch.nn.Module):
    """Mixture-of-experts feed-forward layer (work in progress).

    This class currently builds only the gating projection and the list of
    experts. No ``forward`` is defined yet, and most router / load-balance
    arguments are accepted for signature compatibility but are not used.

    Args:
        config: passed unchanged to every ``OpenMoeMLP`` expert, which takes
            its sizes from it. NOTE(review): ``config.hidden_size`` should
            presumably match ``hidden_size``; nothing checks this.
        num_experts: number of experts, and also the output width of ``gate_weight``.
        hidden_size: input width of ``gate_weight``.
        intermediate_size: stored on the module. The experts do not read it.
        router_top_k: stored as ``self.topk``.
        router_loss: stored as ``self.router_loss``.
        router_norm: stored as ``self.router_norm``.
        mlp_gated: stored as ``self.gated``.
        return_gate_logits: stored as ``self.return_gate_logits``.
        All other arguments are accepted but currently unused.
    """

    def __init__(
        self,
        config,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        router_top_k: int = 1,
        router_loss: bool = True,
        router_norm: bool = False,
        router_capacity_factor_train: float = 1.25,
        router_capacity_factor_eval: float = 2.0,
        router_min_capacity: int = 4,
        router_drop_tks: bool = True,
        mlp_activation: Optional[str] = None,
        mlp_gated: bool = False,
        enable_load_balance: bool = False,
        load_balance_tolerance: float = 0.1,
        load_balance_beam_width: int = 8,
        load_balance_group_swap_factor: float = 0.4,
        enable_kernel: bool = False,
        enable_comm_overlap: bool = False,
        enable_hierarchical_comm: bool = True,
        return_gate_logits: bool = False
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_experts = num_experts
        self.gated = mlp_gated
        self.return_gate_logits = return_gate_logits
        self.router_loss = router_loss
        self.router_norm = router_norm

        # moe router
        # TODO: select and instantiate a router class once routing is
        # implemented. No router helper is imported in this notebook.
        self.topk = router_top_k

        # Produces one logit per expert for each token.
        self.gate_weight = torch.nn.Linear(self.hidden_size, num_experts)

        self.experts = nn.ModuleList([OpenMoeMLP(config) for _ in range(self.num_experts)])
        # TODO: consider ColossalAI's ``MLPExperts`` (expert parallelism,
        # optional fused kernel) in place of a plain ModuleList.

NameError: name 'Optional' is not defined

In [None]:
# Reference ColossalAI SparseMLP whose router / load-balance settings mirror
# the OpenMoeConfig defaults defined above.
moe_layer = SparseMLP(
    # layer geometry
    num_experts=8,
    hidden_size=16,
    intermediate_size=64,
    # expert MLP
    mlp_activation='swiglu',
    mlp_gated=True,
    # router
    router_top_k=2,
    router_capacity_factor_train=1.25,
    router_capacity_factor_eval=2.0,
    router_min_capacity=4,
    router_noisy_policy=None,
    router_drop_tks=True,
    # load balancing
    enable_load_balance=False,
    load_balance_tolerance=0.1,
    load_balance_beam_width=8,
    load_balance_group_swap_factor=0.4,
    # performance toggles
    enable_kernel=False,
    enable_comm_overlap=False,
)

In [None]:
# One line per parameter: "<name> torch.Size([...])".
for n, p in moe_layer.named_parameters():
    print(f"{n} {p.shape}")

In [None]:
# Sample input; the last dimension (16) must equal the layer's hidden_size.
# NOTE(review): leading dims are presumably (batch, seq_len) — confirm.
# A private fixed-seed generator makes the input reproducible under
# Restart & Run All without reseeding the global torch RNG.
input_hidden_states = torch.randn(4, 12, 16, generator=torch.Generator().manual_seed(0))

In [None]:
# Output shape of the reference layer for the sample input (displayed as the cell result).
moe_layer(input_hidden_states).size()