In [4]:
import torch
from torch import nn
import torch.nn.functional as F

import math
from inspect import isfunction

# constants

# Lower bound applied to the per-expert capacity computed in Top2Gating.forward
# (the capacity is clamped with max(expert_capacity, MIN_EXPERT_CAPACITY)).
MIN_EXPERT_CAPACITY = 4

# helper functions

def default(val, default_val):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``default_val``.

    Args:
        val: Candidate value. Anything other than ``None`` (including falsy
            values such as ``0`` or ``''``) is returned unchanged.
        default_val: Fallback value. If it is a plain Python function (as
            judged by ``inspect.isfunction``) it is treated as a zero-argument
            factory and called to produce the fallback.

    Returns:
        ``val`` when it is not ``None``; otherwise ``default_val`` or, for a
        function, the result of calling it.
    """
    if val is not None:
        return val
    # Invoke the factory only when the fallback is actually needed, so callers
    # can pass expensive constructors (e.g. a lambda building a module) without
    # paying for them when an explicit value was supplied.
    return default_val() if isfunction(default_val) else default_val

def cast_tuple(el):
    """Return ``el`` unchanged if it is a tuple, otherwise wrap it in a 1-tuple."""
    if isinstance(el, tuple):
        return el
    return (el,)

# tensor related helper functions

def top1(t):
    """Return the largest entry along the last dimension together with its index.

    Args:
        t: Tensor to search along ``dim=-1``.

    Returns:
        A ``(values, index)`` pair, each with the last dimension of ``t`` removed.
    """
    top_values, top_indices = t.topk(k=1, dim=-1)
    # topk keeps a trailing axis of size k (= 1 here); squeeze it away so the
    # results have one dimension fewer than the input.
    return top_values.squeeze(dim=-1), top_indices.squeeze(dim=-1)

def cumsum_exclusive(t, dim=-1):
    """Exclusive cumulative sum of ``t`` along ``dim``.

    Element ``i`` of the result along ``dim`` is the sum of elements
    ``0 .. i-1`` of ``t`` (so the first element is always zero). The output
    has the same shape as ``t``.

    Args:
        t: Input tensor.
        dim: Axis to accumulate over. Both negative and non-negative indices
            are accepted.

    Returns:
        Tensor of the same shape as ``t`` holding the exclusive running sums.

    Raises:
        IndexError: If a non-negative ``dim`` is not a valid axis of ``t``.
    """
    num_dims = t.dim()
    # The padding and slicing below count axes from the end, so convert a
    # non-negative axis into its negative equivalent first.
    if dim >= 0:
        if dim >= num_dims:
            raise IndexError(f"Dimension out of range (got {dim} for a tensor with {num_dims} dims)")
        dim -= num_dims

    num_pad_dims = -dim - 1
    # F.pad takes (left, right) pairs starting from the last axis: leave the
    # trailing axes untouched and prepend a single zero along `dim`.
    pre_padding = (0, 0) * num_pad_dims
    pre_slice = (slice(None),) * num_pad_dims
    padded_t = F.pad(t, (*pre_padding, 1, 0)).cumsum(dim=dim)
    # Drop the final (inclusive) entry along `dim` to restore the input shape.
    return padded_t[(..., slice(None, -1), *pre_slice)]

# pytorch one hot throws an error if there are out of bound indices.
# tensorflow, in contrast, does not throw an error
def safe_one_hot(indexes, max_length):
    """One-hot encode ``indexes`` into ``max_length`` classes, tolerating large indices.

    ``F.one_hot`` raises when an index is >= ``num_classes``. Here the encoding
    is built with enough classes to hold the largest index and then truncated
    to ``max_length``, so any index >= ``max_length`` becomes an all-zero row.

    Args:
        indexes: Integer (long) tensor of class indices.
        max_length: Size of the trailing one-hot dimension of the result.

    Returns:
        Tensor of shape ``(*indexes.shape, max_length)`` with the dtype of
        ``indexes``.
    """
    # An empty tensor has no maximum; return an empty encoding of the right shape.
    if indexes.numel() == 0:
        return indexes.new_zeros((*indexes.shape, max_length))
    # Use a plain Python int for the class count: wide enough for the largest
    # index, and never narrower than the requested length.
    num_classes = max(int(indexes.max()) + 1, max_length)
    return F.one_hot(indexes, num_classes)[..., :max_length]

def init_(t):
    """Fill ``t`` in place with uniform samples and return it.

    Values are drawn from ``U(-b, b)`` where ``b = 1 / sqrt(t.shape[-1])``.

    Args:
        t: Tensor to initialise; it is modified in place.

    Returns:
        The same tensor object ``t``.
    """
    bound = 1 / math.sqrt(t.size(-1))
    t.uniform_(-bound, bound)
    return t

# activations

class GELU_(nn.Module):
    """GELU activation computed with the tanh approximation.

    Evaluates ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))``
    element-wise.
    """

    def forward(self, x):
        """Apply the activation element-wise to ``x`` and return the result."""
        cubic_term = x + 0.044715 * torch.pow(x, 3)
        squashed = torch.tanh(math.sqrt(2 / math.pi) * cubic_term)
        return 0.5 * x * (1 + squashed)

# Use torch's built-in nn.GELU when this torch version provides it; otherwise
# fall back to the local tanh-approximation module defined above.
GELU = getattr(nn, 'GELU', GELU_)

# expert class

class Experts(nn.Module):
    """A bank of independent two-layer feed-forward experts.

    Each expert owns its own pair of weight matrices; all experts are applied
    at once through batched einsums.

    Args:
        dim: Input and output feature dimension of every expert.
        num_experts: Number of experts, or a tuple of leading expert
            dimensions for the weight tensors.
        hidden_dim: Hidden width of each expert; ``dim * 4`` when ``None``.
        activation: Zero-argument callable (e.g. an ``nn.Module`` class) that
            builds the activation applied between the two projections.
    """

    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = GELU):
        super().__init__()

        hidden_dim = default(hidden_dim, dim * 4)
        expert_dims = cast_tuple(num_experts)

        # Allocate both weight tensors first, then initialise the input
        # projection before the output projection (this fixes the order in
        # which random numbers are consumed).
        in_proj = torch.zeros(*expert_dims, dim, hidden_dim)
        out_proj = torch.zeros(*expert_dims, hidden_dim, dim)
        in_proj = init_(in_proj)
        out_proj = init_(out_proj)

        self.w1 = nn.Parameter(in_proj)
        self.w2 = nn.Parameter(out_proj)
        self.act = activation()

    def forward(self, x):
        """Run every expert on its own slice of ``x``.

        Args:
            x: Tensor whose last two dimensions are (tokens, dim); the leading
                dimensions are matched against the leading dimensions of the
                weights by the einsum ellipsis.

        Returns:
            Tensor shaped like ``x``, with the last dimension again ``dim``.
        """
        projected = torch.einsum('...nd,...dh->...nh', x, self.w1)
        activated = self.act(projected)
        return torch.einsum('...nh,...hd->...nd', activated, self.w2)

# the below code is almost all transcribed from the official tensorflow version, from which the papers are written
# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/research/moe.py

# gating network

class Top2Gating(nn.Module):
    """Top-2 gating network for a mixture of experts.

    For every position the gate computes a softmax over ``num_gates`` experts,
    selects the best and second-best expert, optionally drops the second
    choice according to a policy, enforces a per-expert capacity, and returns
    the tensors used to dispatch inputs to the experts and to combine their
    outputs, together with an auxiliary load-balancing loss.

    Args:
        dim: Feature dimension of the input.
        num_gates: Number of experts to route between.
        eps: Small constant added to the top-2 normaliser; also used as a
            lower bound on the threshold in the ``'random'`` policy.
        outer_expert_dims: Extra leading dimensions prepended to the gating
            weight, whose full shape is ``(*outer_expert_dims, dim, num_gates)``.
        second_policy_train: Policy for the second expert in training mode:
            ``'all'``, ``'none'``, ``'threshold'`` or ``'random'``.
        second_policy_eval: Same as above, used in eval mode.
        second_threshold_train: Threshold used by the ``'threshold'`` and
            ``'random'`` policies in training mode.
        second_threshold_eval: Same as above, used in eval mode.
        capacity_factor_train: Multiplier on ``group_size / num_gates`` that
            sets the per-expert capacity in training mode.
        capacity_factor_eval: Same as above, used in eval mode.
    """

    def __init__(
        self,
        dim,
        num_gates,
        eps = 1e-9,
        outer_expert_dims = tuple(),
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.):
        super().__init__()

        self.eps = eps
        self.num_gates = num_gates
        # Gating projection of shape (*outer_expert_dims, dim, num_gates):
        # `dim` is contracted against the input features and the last axis
        # holds one logit per gate.
        self.w_gating = nn.Parameter(torch.randn(*outer_expert_dims, dim, num_gates))

        self.second_policy_train = second_policy_train
        self.second_policy_eval = second_policy_eval
        self.second_threshold_train = second_threshold_train
        self.second_threshold_eval = second_threshold_eval
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval

    def forward(self, x, importance = None):
        """Compute dispatch/combine tensors and the balancing loss for ``x``.

        Args:
            x: Tensor whose last three dimensions are
                (batch, group_size, dim). Any leading dimensions are matched
                against the leading dimensions of ``w_gating`` by the einsum
                ellipsis.
            importance: Optional tensor indexed like the positions of ``x``
                (i.e. without the feature axis). Positions whose value equals
                1 may use their first expert; positions whose value is > 0 may
                use a second expert.

        Returns:
            A tuple ``(dispatch_tensor, combine_tensor, loss)``.
            ``combine_tensor`` has shape
            (..., batch, group_size, num_gates, expert_capacity) and holds the
            gate weight of each (position, expert, slot) assignment;
            ``dispatch_tensor`` is its non-zero mask cast back to the same
            dtype; ``loss`` is the scalar load-balancing loss.

        Raises:
            ValueError: If the active second-expert policy is not one of
                ``'all'``, ``'none'``, ``'threshold'`` or ``'random'``.
        """
        *_, b, group_size, dim = x.shape
        num_gates = self.num_gates

        # nn.Module.training is True after .train() and False after .eval().
        if self.training:
            policy = self.second_policy_train
            threshold = self.second_threshold_train
            capacity_factor = self.capacity_factor_train
        else:
            policy = self.second_policy_eval
            threshold = self.second_threshold_eval
            capacity_factor = self.capacity_factor_eval

        raw_gates = torch.einsum('...bnd,...de->...bne', x, self.w_gating)
        # The last axis indexes the gates, so normalise over it.
        raw_gates = raw_gates.softmax(dim=-1)

        # FIND TOP 2 EXPERTS PER POSITON
        # Find the top expert for each position. shape=[batch, group]

        gate_1, index_1 = top1(raw_gates)
        mask_1 = F.one_hot(index_1, num_gates).float()
        density_1_proxy = raw_gates

        # `importance` is optional and not used by default.
        if importance is not None:
            equals_one_mask = (importance == 1.).float()
            mask_1 *= equals_one_mask[..., None]
            gate_1 *= equals_one_mask
            # Out-of-place on purpose: density_1_proxy is the same tensor
            # object as raw_gates, and raw_gates must stay intact both for the
            # second-expert selection below and for the softmax backward pass.
            density_1_proxy = density_1_proxy * equals_one_mask[..., None]
            del equals_one_mask

        # mask_1 is the one-hot of the top-1 index, so (1 - mask_1) zeroes the
        # top-1 gate and keeps the others; a second top-1 over the result
        # therefore yields the second-best expert.
        gates_without_top_1 = raw_gates * (1. - mask_1)

        gate_2, index_2 = top1(gates_without_top_1)
        mask_2 = F.one_hot(index_2, num_gates).float()

        if importance is not None:
            greater_zero_mask = (importance > 0.).float()
            mask_2 *= greater_zero_mask[..., None]
            del greater_zero_mask

        # normalize top2 gate scores
        # Only the two selected gates are kept; rescale them so that they sum
        # to (almost) one.
        denom = gate_1 + gate_2 + self.eps
        gate_1 /= denom
        gate_2 /= denom

        # BALANCING LOSSES
        # shape = [batch, experts]
        # We want to equalize the fraction of the batch assigned to each expert
        density_1 = mask_1.mean(dim=-2)
        # Something continuous that is correlated with what we want to equalize.
        # Both means run over the group (position) axis, dim=-2: density_1 is
        # the fraction of positions whose first choice is each expert, and
        # density_1_proxy is the mean gate probability of each expert.
        density_1_proxy = density_1_proxy.mean(dim=-2)
        loss = (density_1_proxy * density_1).mean() * float(num_gates ** 2)

        # Depending on the policy in the hparams, we may drop out some of the
        # second-place experts.
        # 'all' keeps every second choice, 'none' drops them all, 'threshold'
        # keeps those whose gate exceeds the threshold, and 'random' keeps
        # each one with probability gate_2 / threshold.
        if policy == "all":
            pass
        elif policy == "none":
            mask_2 = torch.zeros_like(mask_2)
        elif policy == "threshold":
            # gate_2 has no expert axis; add one so the comparison broadcasts
            # over the experts of mask_2 (as the 'random' branch does).
            mask_2 *= (gate_2 > threshold).float()[..., None]
        elif policy == "random":
            # uniform_ samples from the half-open interval [0, 1).
            probs = torch.zeros_like(gate_2).uniform_(0., 1.)
            # Multiply into mask_2 (do not overwrite it); unsqueeze adds the
            # expert axis needed for broadcasting.
            mask_2 *= (probs < (gate_2 / max(threshold, self.eps))).float().unsqueeze(-1)
        else:
            raise ValueError(f"Unknown policy {policy}")

        # Each sequence sends (at most?) expert_capacity positions to each expert.
        # Static expert_capacity dimension is needed for expert batch sizes
        expert_capacity = min(group_size, int((group_size * capacity_factor) / num_gates))
        expert_capacity = max(expert_capacity, MIN_EXPERT_CAPACITY)
        expert_capacity_f = float(expert_capacity)

        # COMPUTE ASSIGNMENT TO EXPERTS
        # [batch, group, experts]
        # This is the position within the expert's mini-batch for this sequence
        position_in_expert_1 = cumsum_exclusive(mask_1, dim=-2) * mask_1
        # Remove the elements that don't fit. [batch, group, experts]
        mask_1 *= (position_in_expert_1 < expert_capacity_f).float()
        # [batch, experts]
        # How many examples in this sequence go to this expert
        mask_1_count = mask_1.sum(dim=-2, keepdim=True)
        # [batch, group] - mostly ones, but zeros where something didn't fit
        mask_1_flat = mask_1.sum(dim=-1)
        # [batch, group]
        position_in_expert_1 = position_in_expert_1.sum(dim=-1)
        # Weight assigned to first expert.  [batch, group]
        gate_1 *= mask_1_flat

        # Second choices are queued behind all first choices of the same
        # expert, hence the mask_1_count offset.
        position_in_expert_2 = cumsum_exclusive(mask_2, dim=-2) + mask_1_count
        position_in_expert_2 *= mask_2
        mask_2 *= (position_in_expert_2 < expert_capacity_f).float()
        mask_2_flat = mask_2.sum(dim=-1)

        position_in_expert_2 = position_in_expert_2.sum(dim=-1)
        gate_2 *= mask_2_flat

        # [batch, group, experts, expert_capacity]
        combine_tensor = (
            gate_1[..., None, None]
            * mask_1_flat[..., None, None]
            * F.one_hot(index_1, num_gates)[..., None]
            * safe_one_hot(position_in_expert_1.long(), expert_capacity)[..., None, :] +
            gate_2[..., None, None]
            * mask_2_flat[..., None, None]
            * F.one_hot(index_2, num_gates)[..., None]
            * safe_one_hot(position_in_expert_2.long(), expert_capacity)[..., None, :]
        )

        # Non-zero combine weights mark the (position, expert, slot) triples
        # that are dispatched; cast back to the dtype/device of combine_tensor.
        dispatch_tensor = combine_tensor.bool().to(combine_tensor)
        return dispatch_tensor, combine_tensor, loss

# plain mixture of experts

class MoE(nn.Module):
    """Single-level mixture-of-experts layer with top-2 gating.

    Args:
        dim: Feature dimension of the inputs and outputs.
        num_experts: Number of experts (and of gates).
        hidden_dim: Hidden width handed to the default ``Experts`` module.
        activation: Activation class handed to the default ``Experts`` module.
        second_policy_train: Forwarded to ``Top2Gating``.
        second_policy_eval: Forwarded to ``Top2Gating``.
        second_threshold_train: Forwarded to ``Top2Gating``.
        second_threshold_eval: Forwarded to ``Top2Gating``.
        capacity_factor_train: Forwarded to ``Top2Gating``.
        capacity_factor_eval: Forwarded to ``Top2Gating``.
        loss_coef: Multiplier applied to the gating loss returned by ``forward``.
        experts: Optional module used in place of the default ``Experts``.
    """

    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = nn.ReLU,
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.,
        loss_coef = 1e-2,
        experts = None):
        super().__init__()

        self.num_experts = num_experts

        gating_kwargs = dict(
            second_policy_train = second_policy_train,
            second_policy_eval = second_policy_eval,
            second_threshold_train = second_threshold_train,
            second_threshold_eval = second_threshold_eval,
            capacity_factor_train = capacity_factor_train,
            capacity_factor_eval = capacity_factor_eval,
        )
        # The gate receives the number of gates and the routing
        # hyper-parameters; outer_expert_dims keeps its default.
        self.gate = Top2Gating(dim, num_gates = num_experts, **gating_kwargs)

        def build_experts():
            return Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation)

        self.experts = default(experts, build_experts)
        self.loss_coef = loss_coef

    def forward(self, inputs, **kwargs):
        """Route ``inputs`` through the experts and merge their outputs.

        Args:
            inputs: Tensor of shape (batch, tokens, dim).
            **kwargs: Accepted but not used by this method.

        Returns:
            A tuple ``(output, loss)``: ``output`` has the shape of ``inputs``
            and ``loss`` is the gating loss scaled by ``loss_coef``.
        """
        _, _, feat_dim, n_experts = *inputs.shape, self.num_experts
        dispatch, combine, balance_loss = self.gate(inputs)

        # Gather, for every expert and batch element, the tokens assigned to
        # each capacity slot.
        routed = torch.einsum('bnd,bnec->ebcd', inputs, dispatch)
        routed_shape = routed.shape

        # Experts expect (experts, tokens, dim): fold batch and capacity
        # together, run the experts, then unfold again.
        processed = self.experts(routed.reshape(n_experts, -1, feat_dim))
        processed = processed.reshape(*routed_shape)

        # Scatter the expert outputs back to token positions, weighted by the
        # gate values stored in the combine tensor.
        merged = torch.einsum('ebcd,bnec->bnd', processed, combine)
        return merged, balance_loss * self.loss_coef


In [1]:
import torch
from torch import nn
# from mixture_of_experts import MoE

# Usage example. MoE is defined in the definitions cell of this notebook, which
# must have been executed before this cell.
moe = MoE(
    dim = 512,
    num_experts = 16,               # increase the experts (# parameters) of your model without increasing computation
    hidden_dim = 512 * 4,           # size of hidden dimension in each expert, defaults to 4 * dimension
    activation = nn.LeakyReLU,      # activation class used inside the experts; MoE's own default is nn.ReLU
    second_policy_train = 'random', # in top_2 gating, policy for whether to use a second-place expert
    second_policy_eval = 'random',  # all (always) | none (never) | threshold (if gate value > the given threshold) | random (if gate value > threshold * random_uniform(0, 1))
    second_threshold_train = 0.2,
    second_threshold_eval = 0.2,
    capacity_factor_train = 1.25,   # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
    capacity_factor_eval = 2.,      # capacity_factor_* should be set to a value >=1
    loss_coef = 1e-2                # multiplier on the auxiliary expert-balancing loss
)

# (batch, tokens, dim) input; forward returns the mixed output and the scaled auxiliary loss.
inputs = torch.randn(4, 1024, 512)
out, aux_loss = moe(inputs)

In [5]:
# Small stand-alone reproduction of the gating weight built in Top2Gating.__init__:
# shape (*outer_expert_dims, dim, num_gates) = (3, 4, 5).
outer_expert_dims=tuple([3])
dim=4
num_gates=5
w_gating = nn.Parameter(torch.randn(*outer_expert_dims, dim, num_gates))
print(w_gating)
# nn.Parameter marks a tensor as trainable (the printed value shows requires_grad=True).
# When a Parameter is assigned as an attribute of an nn.Module it is registered with that
# module and returned by module.parameters(), so an optimizer can update it during training.
# In this cell it is only a free-standing variable used to experiment with shapes.

Parameter containing:
tensor([[[ 1.1544, -0.1697, -1.3483,  0.8665, -1.9129],
         [ 2.5419,  0.4158,  0.5876,  1.1696,  0.8588],
         [-0.2754, -1.1637, -0.7319, -1.6875, -0.9345],
         [ 0.0663,  1.8295,  1.0952,  2.0829,  0.0970]],

        [[-1.9983,  1.3038, -1.6020, -0.1983, -0.1377],
         [-0.8009, -0.8584,  0.4908, -1.7664,  1.3419],
         [-1.6858, -0.4301, -0.0996, -0.5580,  0.3219],
         [ 0.1491, -0.5547,  0.4204, -0.0390, -1.5064]],

        [[ 0.0232,  0.6683,  0.1763, -0.4046,  0.4990],
         [-1.4328,  2.4181, -0.1386, -0.0922, -0.1296],
         [ 0.0056,  0.1546,  0.3910,  0.7146,  1.0534],
         [-1.4351, -1.7781, -2.5079,  1.3089, -0.6859]]], requires_grad=True)


In [6]:
# Toy input of shape (3, 4): three rows with dim = 4 features each.
inputs=torch.randn(3,4)
print("inputs=",inputs)
# *_, b, group_size, dim = inputs.shape
# print("_,b,group_size,dim=",_,b,group_size,dim)
# Contract the feature axis d of the input with each of the three (4, 5) gating
# matrices in w_gating; the result has shape (a, b, e) = (3, 3, 5).
raw_gates = torch.einsum('ad,bde->abe', inputs, w_gating)
print("raw_gates=",raw_gates)
# Softmax over the last (gate) axis turns the logits into per-gate probabilities.
raw_gates = raw_gates.softmax(dim=-1)
print("raw_gates=",raw_gates)

inputs= tensor([[ 1.1591,  0.0032, -2.7902,  2.0255],
        [-0.5143, -0.8529, -0.2640,  0.0589],
        [ 0.4107,  2.9308,  0.0175,  0.6375]])
raw_gates= tensor([[[ 2.2488,  6.7571,  2.6995,  9.9354,  0.5894],
         [ 2.6871,  1.5851, -0.7260,  1.2426, -4.1048],
         [-2.9000, -3.2506, -5.9669,  0.1880, -3.7503]],

        [[-2.6850,  0.1476,  0.4500, -0.8750,  0.5039],
         [ 2.1646,  0.1425,  0.4563,  1.7535, -1.2474],
         [ 1.1241, -2.5515, -0.2233,  0.1751, -0.4646]],

        [[ 7.9613,  2.2948,  1.8538,  5.0821,  1.7767],
         [-3.1022, -2.3415,  1.0467, -5.2929,  2.9218],
         [-5.1043,  6.2305, -1.9258,  0.4105, -0.5937]]],
       grad_fn=<ViewBackward0>)
raw_gates= tensor([[[4.4005e-04, 3.9944e-02, 6.9059e-04, 9.5884e-01, 8.3720e-05],
         [6.2416e-01, 2.0736e-01, 2.0559e-02, 1.4722e-01, 7.0084e-04],
         [4.1474e-02, 2.9210e-02, 1.9313e-03, 9.0966e-01, 1.7722e-02]],

        [[1.4015e-02, 2.3812e-01, 3.2219e-01, 8.5643e-02, 3.4003e-01],
   

In [7]:
# Largest gate probability per position and the index of the gate that produced it.
gate_1,index_1=top1(raw_gates)

In [8]:
# Show the top-1 gate values followed by their gate indices.
print(gate_1,index_1)

tensor([[0.9588, 0.6242, 0.9097],
        [0.3400, 0.4977, 0.5329],
        [0.9399, 0.8612, 0.9957]], grad_fn=<SqueezeBackward1>) tensor([[3, 0, 3],
        [4, 0, 0],
        [0, 4, 1]])


In [16]:
# Step-by-step version of top1(): topk with k=1 keeps a trailing axis of size 1 ...
values, index = raw_gates.topk(k=1, dim=-1)
print(values, index)
# ... which squeeze(dim=-1) then removes from both the values and the indices.
values, index = map(lambda x: x.squeeze(dim=-1), (values, index))
print(values, index)

tensor([[[0.9588],
         [0.6242],
         [0.9097]],

        [[0.3400],
         [0.4977],
         [0.5329]],

        [[0.9399],
         [0.8612],
         [0.9957]]], grad_fn=<TopkBackward0>) tensor([[[3],
         [0],
         [3]],

        [[4],
         [0],
         [0]],

        [[0],
         [4],
         [1]]])
tensor([[0.9588, 0.6242, 0.9097],
        [0.3400, 0.4977, 0.5329],
        [0.9399, 0.8612, 0.9957]], grad_fn=<SqueezeBackward1>) tensor([[3, 0, 3],
        [4, 0, 0],
        [0, 4, 1]])


In [17]:
# One-hot encoding of the top-1 gate index for every position.
mask_1 = F.one_hot(index_1, num_gates).float()
# NOTE: this binds a second name to the same tensor; it is not a copy of raw_gates.
density_1_proxy = raw_gates
print(mask_1)
# Zero out the top-1 gate of every position so that the next top1() call
# returns the second-best gate.
gates_without_top_1 = raw_gates * (1. - mask_1)
print(gates_without_top_1)
gate_2, index_2 = top1(gates_without_top_1)
mask_2 = F.one_hot(index_2, num_gates).float()

tensor([[[0., 0., 0., 1., 0.],
         [1., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 1.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1.],
         [0., 1., 0., 0., 0.]]])
tensor([[[4.4005e-04, 3.9944e-02, 6.9059e-04, 0.0000e+00, 8.3720e-05],
         [0.0000e+00, 2.0736e-01, 2.0559e-02, 1.4722e-01, 7.0084e-04],
         [4.1474e-02, 2.9210e-02, 1.9313e-03, 0.0000e+00, 1.7722e-02]],

        [[1.4015e-02, 2.3812e-01, 3.2219e-01, 8.5643e-02, 0.0000e+00],
         [0.0000e+00, 6.5875e-02, 9.0162e-02, 3.2990e-01, 1.6410e-02],
         [0.0000e+00, 1.3500e-02, 1.3850e-01, 2.0630e-01, 1.0881e-01]],

        [[0.0000e+00, 3.2521e-03, 2.0923e-03, 5.2803e-02, 1.9370e-03],
         [2.0840e-03, 4.4595e-03, 1.3206e-01, 2.3308e-04, 0.0000e+00],
         [1.1898e-05, 0.0000e+00, 2.8567e-04, 2.9546e-03, 1.0824e-03]]],
       grad_fn=<MulBackward0>)


In [18]:
# Fraction of positions (mean over dim=-2) whose first choice is each gate.
density_1 = mask_1.mean(dim=-2)
print("density_1=", density_1)
# Something continuous that is correlated with what we want to equalize.
# NOTE: rebinding density_1_proxy makes this cell non-idempotent; re-running it
# would average the already-reduced tensor again.
density_1_proxy = density_1_proxy.mean(dim=-2)
print("density_1_proxy=", density_1_proxy)
# Mean of the element-wise product, scaled by num_gates ** 2.
loss = (density_1_proxy * density_1).mean() * float(num_gates ** 2)
print("loss=",loss)

density_1= tensor([[0.3333, 0.0000, 0.0000, 0.6667, 0.0000],
        [0.6667, 0.0000, 0.0000, 0.0000, 0.3333],
        [0.3333, 0.3333, 0.0000, 0.0000, 0.3333]])
density_1_proxy= tensor([[0.2220, 0.0922, 0.0077, 0.6719, 0.0062],
        [0.3482, 0.1058, 0.1836, 0.2073, 0.1551],
        [0.3140, 0.3345, 0.0448, 0.0187, 0.2881]], grad_fn=<MeanBackward1>)
loss= tensor(1.8632, grad_fn=<MulBackward0>)


1. `outer_expert_dims` 是什么？为什么 MoE 在构造 Top2Gating 时没有传入它，而是使用默认值 `tuple()`？
2. 辅助损失 loss（负载均衡损失）的原理到底是什么？
3. `expert_capacity` 的作用是什么？

In [22]:
# Reproduces the 'random' second-expert policy of Top2Gating.forward with
# threshold = 0.2 and eps = 1e-9 written out as literals.
# One uniform sample in [0, 1) per position.
probs = torch.zeros_like(gate_2).uniform_(0., 1.)
print("probs=",probs)
# A position keeps its second expert when probs < gate_2 / threshold;
# unsqueeze(-1) adds the gate axis so the result broadcasts against mask_2.
print("drop=",(probs < (gate_2 / max(0.2, 1e-9))).float().unsqueeze(-1))
# NOTE: mask_2 is modified in place and probs is random, so re-running this
# cell can only clear more entries of mask_2.
mask_2 *= (probs < (gate_2 / max(0.2, 1e-9))).float().unsqueeze(-1)
print("gate_2=", gate_2)
print("mask_2",mask_2)

probs= tensor([[0.2740, 0.5237, 0.1820],
        [0.2794, 0.5914, 0.1781],
        [0.9341, 0.2405, 0.6881]])
drop= tensor([[[0.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.]],

        [[0.],
         [1.],
         [0.]]])
gate_2= tensor([[0.0399, 0.2074, 0.0415],
        [0.3222, 0.3299, 0.2063],
        [0.0528, 0.1321, 0.0030]], grad_fn=<SqueezeBackward1>)
mask_2 tensor([[[0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0.]]])
