<a href="https://colab.research.google.com/github/weedge/doraemon-nb/blob/main/makeMoA_MoE_from_Scratch_with_Expert_Capacity_Aux_Loss_Balance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### 从头开始的稀疏专家混合语言模型，灵感来源于（并在很大程度上基于）[Andrej Karpathy的makemore项目](https://github.com/karpathy/makemore) :)

这是一个从头开始实现的稀疏专家混合语言模型。这受到了Andrej Karpathy项目'makemore'的启发，并且大部分重用的组件都来自于该实现。就像makemore一样，makeMoE也是一个自回归的字符级语言模型，但是使用了上述稀疏专家的架构。

与makemore体系结构相比，有显着的变化

- 稀疏专家混合而不是孤立的前馈神经网络。
- 使用了top-k门控和嘈杂的top-k门控实现。
- 初始化 - 这里使用了Kaiming He初始化，但这个笔记本的重点是可hack性，所以你可以替换为Xavier Glorot等，并进行尝试。

与makemore不变的部分

- Andrej最初选择的数据集、预处理（tokenizer）和语言建模任务 - 生成类似莎士比亚的文本
- 自注意力因果实现
- 训练循环
- 推断逻辑

在此实现中大量引用的论文：

- Mixtral of Experts：https://arxiv.org/pdf/2401.04088.pdf
- Outrageosly Large Neural Networks: The Sparsely-Gated Mixture-Of-Experts layer：https://arxiv.org/pdf/1701.06538.pdf
- Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity: https://arxiv.org/pdf/2101.03961.pdf

这个笔记本演示了整个模型架构的直觉以及所有内容是如何相互关联的。


请注意，该实现强调易读性和可hack性而不是性能，因此有许多方法可以改进此实现。请尝试并告诉我。

如果在colab中运行，选择t4 GPU即可。


![mixture of experts overview](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/moe.png)

In [2]:
#Import the necessary packages and set seed for reproducibility. For this notebook, pytorch is all you need
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)


<torch._C.Generator at 0x79f8342f66f0>

接下来的几个部分，下载数据、预处理数据和自注意力直接来自makemore。我稍微详细说明了自注意力，并添加了一些可视化辅助，以便更好地理解这个过程。

In [3]:
# Downloading the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt

--2024-04-09 15:09:09--  https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-04-09 15:09:09 (207 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [5]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [None]:
# let's look at the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [6]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [7]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [8]:
# let's now encode the entire text dataset and store it into a torch.Tensor
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [9]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [10]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [11]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


In [12]:
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

In [13]:
ix = torch.randint(len(data) - block_size, (batch_size,))
ix

tensor([250930, 237205, 974116, 383898])

In [14]:
x = torch.stack([data[i:i+block_size] for i in ix])
y = torch.stack([data[i+1:i+block_size+1] for i in ix])
x

tensor([[42,  1, 58, 46, 59, 57,  1, 21],
        [54, 56, 47, 43, 57, 58, 11,  0],
        [49, 47, 52, 45, 12,  1, 58, 46],
        [58, 46, 53, 59, 58,  1, 56, 43]])

In [15]:
y

tensor([[ 1, 58, 46, 59, 57,  1, 21,  1],
        [56, 47, 43, 57, 58, 11,  0, 37],
        [47, 52, 45, 12,  1, 58, 46, 53],
        [46, 53, 59, 58,  1, 56, 43, 42]])

以下代码块清楚地展示了预测的自回归性质，以及上下文是对token（在本例中是字符）的一维排列的滚动窗口。

In [16]:
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[ 6,  0, 14, 43, 44, 53, 56, 43],
        [39,  1, 42, 59, 43,  1, 39, 52],
        [47, 41, 43,  1, 39, 52, 42,  1],
        [53, 44,  1, 50, 43, 58,  1, 58]])
targets:
torch.Size([4, 8])
tensor([[ 0, 14, 43, 44, 53, 56, 43,  1],
        [ 1, 42, 59, 43,  1, 39, 52, 42],
        [41, 43,  1, 39, 52, 42,  1, 42],
        [44,  1, 50, 43, 58,  1, 58, 46]])
----
when input is [6] the target: 0
when input is [6, 0] the target: 14
when input is [6, 0, 14] the target: 43
when input is [6, 0, 14, 43] the target: 44
when input is [6, 0, 14, 43, 44] the target: 53
when input is [6, 0, 14, 43, 44, 53] the target: 56
when input is [6, 0, 14, 43, 44, 53, 56] the target: 43
when input is [6, 0, 14, 43, 44, 53, 56, 43] the target: 1
when input is [39] the target: 1
when input is [39, 1] the target: 42
when input is [39, 1, 42] the target: 59
when input is [39, 1, 42, 59] the target: 43
when input is [39, 1, 42, 59, 43] the target: 1
when input is [39, 1, 42, 59, 4

### Understanding the intuition of Causal Scaled Dot Product Self Attention

这段代码来自于Andrej Karpathy出色的makemore代码库，链接在仓库中。





![scaled dot product self attention](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/self_attention.png)

提供的代码演示了自注意力的机制和基本概念，特别是关注经典的缩放点积自注意力。在这个变体中，查询、键和值矩阵都来自同一个输入序列。为了确保自回归语言生成过程的完整性，特别是在仅包含解码器的模型中，代码实现了掩码。这种掩码技术至关重要，因为它隐藏了当前标记位置后面的任何信息，从而将模型的注意力引导到序列的前面部分。这样的注意力机制称为因果自注意力。值得注意的是，稀疏专家混合模型并不局限于仅包含解码器的Transformer架构。事实上，在这个领域的许多重要工作，特别是由Shazeer等人完成的工作，都围绕着T5架构展开，该架构包含了Transformer模型中的编码器和解码器组件。

In [17]:
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
wei =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1) #B,T,T

v = value(x) #B,T,H
out = wei @ v # (B,T,T) @ (B,T,H) -> (B,T,H)
#The output from this final matrix product is subsequently passsed through a linear layer as shown in the diagram above

out.shape

torch.Size([4, 8, 16])

对因果自注意力和多头因果自注意力的代码进行泛化和模块化。多头自注意力将多个注意力头并行应用，每个注意力头专注于通道的不同部分（嵌入维度）。

In [18]:
#Causal scaled dot product self-Attention Head

n_embd = 64
n_head = 4
n_layer = 4
head_size = 16
dropout = 0.1

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

In [19]:
#Multi-Headed Self Attention
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


In [20]:
#Confirming that what's output from multi head attention is the original embedding size
B,T,C = 4,8,64 # batch, time, channels
x = torch.randn(B,T,C)
mha = MultiHeadAttention(4,16)
mha(x).shape

torch.Size([4, 8, 64])

### Creating an Expert module i.e. a simple Multi Layer Perceptron

在稀疏专家（MoE）架构中，每个Transformer块内部的自注意力机制保持不变。然而，每个块的结构发生了显著变化：标准的前馈神经网络被替换为几个稀疏激活的前馈网络，称为专家。 "稀疏激活" 指的是序列中的每个标记仅被路由到总池中的有限数量的这些专家之一或两个 - 通常是一个或两个。这种修改允许对输入数据的不同部分进行专门处理，使模型能够有效地处理更广泛的复杂性。

![experts](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/experts.png)

In [21]:
#Expert module
class Expert(nn.Module):
    """ An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

### Top-k Gating Intuition through an Example

![top k gating](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/topk.png)

门控网络，也称为路由器，确定每个token从多头注意力中由哪个专家网络接收输出。让我们考虑一个简单的例子：假设有4个专家，并且要将标记路由到前2个专家。最初，我们通过一个线性层将token输入到门控网络中。这个层将输入张量从形状为（2，4，32）——表示（批量大小，tokens，n_embed，其中n_embed是输入的通道维度）——投影到一个新形状为（2，4，4）的张量，对应于（批量大小，tokens，num_experts），其中num_experts是专家网络的数量。接下来，我们确定最后一维中前k=2个最高值及其相应的索引。

In [22]:
#Understanding how gating works
num_experts = 4
top_k=2
n_embed=32


#Example multi-head attention output for a simple illustrative example, consider n_embed=32, context_length=4 and batch_size=2
mh_output = torch.randn(2, 4, n_embed)

topkgate_linear = nn.Linear(n_embed, num_experts) # nn.Linear(32, 4)

logits = topkgate_linear(mh_output)
top_k_logits, top_k_indices = logits.topk(top_k, dim=-1)  # Get top-k experts
top_k_logits, top_k_indices

(tensor([[[ 0.9558,  0.1610],
          [ 0.8659, -0.1494],
          [ 0.8765,  0.7202],
          [ 0.9496, -0.6609]],
 
         [[ 0.4419, -0.2500],
          [ 1.2602,  0.8430],
          [ 0.8570,  0.7822],
          [ 0.7376,  0.2561]]], grad_fn=<TopkBackward0>),
 tensor([[[2, 0],
          [3, 2],
          [3, 0],
          [1, 2]],
 
         [[1, 3],
          [1, 2],
          [1, 2],
          [0, 1]]]))

通过仅保留沿着最后一个维度的各自索引处的前k个值，获取稀疏门控输出。用'-inf'填充其余部分，并通过softmax激活函数传递。这将'-inf'值推向零，使前两个值更加突出，并且总和为1。这种总和为1有助于对专家输出进行加权。

In [23]:
zeros = torch.full_like(logits, float('-inf')) #full_like clones a tensor and fills it with a specified value (like infinity) for masking or calculations.
sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)
sparse_logits

tensor([[[ 0.1610,    -inf,  0.9558,    -inf],
         [   -inf,    -inf, -0.1494,  0.8659],
         [ 0.7202,    -inf,    -inf,  0.8765],
         [   -inf,  0.9496, -0.6609,    -inf]],

        [[   -inf,  0.4419,    -inf, -0.2500],
         [   -inf,  1.2602,  0.8430,    -inf],
         [   -inf,  0.8570,  0.7822,    -inf],
         [ 0.7376,  0.2561,    -inf,    -inf]]], grad_fn=<ScatterBackward0>)

In [24]:
gating_output= F.softmax(sparse_logits, dim=-1)
gating_output

tensor([[[0.3111, 0.0000, 0.6889, 0.0000],
         [0.0000, 0.0000, 0.2660, 0.7340],
         [0.4610, 0.0000, 0.0000, 0.5390],
         [0.0000, 0.8335, 0.1665, 0.0000]],

        [[0.0000, 0.6664, 0.0000, 0.3336],
         [0.0000, 0.6028, 0.3972, 0.0000],
         [0.0000, 0.5187, 0.4813, 0.0000],
         [0.6181, 0.3819, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)

### Generalizing and Modularizing above code and adding noisy top-k Gating for load balancing
泛化和模块化上述代码，并添加嘈杂的top-k门控以实现负载平衡。

In [25]:
# First define the top k router module
class TopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(TopkRouter, self).__init__()
        self.top_k = top_k
        self.linear =nn.Linear(n_embed, num_experts)

    def forward(self, mh_ouput):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.linear(mh_output)
        top_k_logits, indices = logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices



In [26]:
#Testing this out:
num_experts = 4
top_k = 2
n_embd = 32

mh_output = torch.randn(2, 4, n_embd)  # Example input
top_k_gate = TopkRouter(n_embd, num_experts, top_k)
gating_output, indices = top_k_gate(mh_output)
gating_output.shape, gating_output, indices
#And it works!!

(torch.Size([2, 4, 4]),
 tensor([[[0.0000, 0.4249, 0.5751, 0.0000],
          [0.3467, 0.6533, 0.0000, 0.0000],
          [0.3970, 0.0000, 0.6030, 0.0000],
          [0.0000, 0.0000, 0.7713, 0.2287]],
 
         [[0.4043, 0.5957, 0.0000, 0.0000],
          [0.0000, 0.5281, 0.0000, 0.4719],
          [0.7053, 0.0000, 0.2947, 0.0000],
          [0.0000, 0.4602, 0.5398, 0.0000]]], grad_fn=<SoftmaxBackward0>),
 tensor([[[2, 1],
          [1, 0],
          [2, 0],
          [2, 3]],
 
         [[1, 0],
          [1, 3],
          [0, 2],
          [2, 1]]]))

虽然最近发布的Mixtral论文没有提到，但我认为嘈杂的top-k门控是训练MoE模型的重要工具。基本上，您不希望所有的token都被发送到同一组“偏爱”的专家中。您希望在开发和探索之间达到良好的平衡。为此，为门控线性层的logits添加标准正态噪声有助于负载平衡，使训练更加高效。

![noisy top-k gating](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/noisytopkgating.png)

In [82]:
#Changing the above to accomodate noisy top-k gating
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        #layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear =nn.Linear(n_embed, num_experts)


    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        #Noise logits
        noise_logits = self.noise_linear(mh_output)

        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits)*F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

In [83]:
#Testing this out, again:
num_experts = 8
top_k = 2
n_embd = 16

mh_output = torch.randn(2, 4, n_embd)  # Example input
noisy_top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)
gating_output, indices = noisy_top_k_gate(mh_output)
gating_output.shape, gating_output, indices
#It works!!

(torch.Size([2, 4, 8]),
 tensor([[[0.0000, 0.7710, 0.0000, 0.0000, 0.0000, 0.2290, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.3704, 0.6296, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.5412, 0.0000, 0.4588, 0.0000],
          [0.6004, 0.0000, 0.0000, 0.0000, 0.0000, 0.3996, 0.0000, 0.0000]],
 
         [[0.4171, 0.0000, 0.0000, 0.0000, 0.0000, 0.5829, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.5208, 0.0000, 0.0000, 0.4792],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.7725, 0.0000, 0.0000, 0.2275],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.8713, 0.0000, 0.0000, 0.1287]]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[[1, 5],
          [4, 3],
          [4, 6],
          [0, 5]],
 
         [[5, 0],
          [4, 7],
          [4, 7],
          [4, 7]]]))


### Creating a sparse Mixture of Experts module


这个过程的主要方面涉及门控网络的输出。在获得这些结果后，会选择性地将前k个值与相应的前k个专家的输出相乘，以获得给定token的结果。这种选择性的乘法形成了加权求和，构成了SparseMoe块的输出。这个过程中的关键和具有挑战性的部分是避免不必要的乘法。只对前k个专家进行前向传播，然后计算这个加权和是至关重要的。对每个专家都进行前向传播会违背使用稀疏MoE的初衷，因为它将不再是稀疏的。

In [84]:
class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)

        # Reshape inputs for batch processing
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        # Process each expert in parallel
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)

            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                expert_output = expert(expert_input)

                # Extract and apply gating scores
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                # Update final output
                # We need to scatter_add the weighted outputs to their original positions in the batch
                final_output.masked_scatter_(expert_mask.unsqueeze(-1), weighted_output)

        return final_output.view_as(x)




In [85]:
import torch
import torch.nn as nn

#Let's test this out
num_experts = 8
top_k = 2
n_embd = 16
dropout=0.1

mh_output = torch.randn(4, 8, n_embd)  # Example multi-head attention output
sparse_moe = SparseMoE(n_embd, num_experts, top_k)
final_output = sparse_moe(mh_output)
print("Shape of the final output:", final_output.shape)
print(final_output)

Shape of the final output: torch.Size([4, 8, 16])
tensor([[[-7.2254e-02,  2.5299e-01, -1.2675e-01,  1.3837e-01,  5.4608e-02,
           0.0000e+00, -7.7531e-02, -3.5290e-02, -1.4103e-02, -7.2050e-02,
          -3.9991e-02, -0.0000e+00,  7.5821e-02,  2.1454e-02,  1.2398e-01,
           5.4608e-02],
         [-2.3598e-02, -2.6230e-02, -5.5828e-02, -5.5061e-02, -1.0413e-01,
           8.5451e-02, -7.5070e-02, -0.0000e+00, -0.0000e+00, -4.5918e-02,
           9.0797e-02, -1.0355e-01, -3.5142e-03,  8.1981e-02,  1.5646e-02,
           3.1881e-02],
         [-0.0000e+00,  1.2757e-01,  4.8477e-01,  2.2833e-02, -7.3061e-02,
           6.9625e-02,  1.0888e-01,  8.7231e-04,  3.8779e-01, -6.9699e-02,
          -1.0221e-01,  7.9689e-03, -3.9457e-01, -4.9123e-02, -2.2001e-01,
          -0.0000e+00],
         [-7.1534e-02, -1.1351e-01, -1.2371e-01, -1.3362e-01, -1.5074e-01,
           1.5168e-01, -5.7698e-02, -1.0665e-01, -4.2614e-01, -1.0385e-01,
           2.8249e-01, -3.8319e-01,  2.2413e-01,  1.8

强调一下，需要认识到路由器/门控网络输出的前k个专家的幅值，正如上面的代码所示，也是非常重要的。这些前k个索引确定了被激活的专家，而在这些前k个维度中数值的大小决定了它们各自的权重。这种加权求和的概念在下面的图示中进一步强调了。

![sparse MoE](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/sparseMoEfinal.png)

## 引入专家容量 （Expert Capacity factor）

from: https://huggingface.co/blog/AviSoori1x/makemoe2



在预训练混合专家语言模型或任何大型语言模型时，该过程通常跨越多个GPU，并且通常涉及许多机器。跨这些硬件资源并行训练的方式对于平衡计算负载至关重要。然而，如果某些专家或一组专家过度受到偏爱——反映出对开发的偏好超过探索——它不仅可能导致模型中的性能问题，还可能导致集群中的计算负载不平衡。

[Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) 实现使用专家容量来规避这个问题。专家容量确定每个专家在训练或推理过程中负责处理多少个标记。它是基于批次中的标记数和可用专家的数量定义的，通常通过容量因子进行调整。该因子允许在分配中灵活性，提供缓冲区以考虑数据分布的变化，并确保没有单个专家由于过载而成为瓶颈。在训练这些大型模型时，硬件故障是很常见的，可能持续数周甚至数月，因此这一点非常重要。

以下是专家容量通常计算的方式：

专家容量 = （每批标记数 / 专家数量）× 容量因子 其中：
```python
expert_capacity = int((tokens_per_batch / self.num_experts) * self.capacity_factor)
```

- 每批标记数`tokens_per_batch`是需要处理的批次中存在的总标记数。
- 专家数量`num_experts`是MoE层中可用于处理数据的专家总数。
- 容量因子`capacity_factor`是用于调整基础容量（每批标记数除以专家数量）的乘数。大于1的容量因子允许每个专家处理超出均匀分配份额的缓冲区，适应标记分配的不平衡。该值的一般范围为1-1.25。

以下代码块进行了轻微调整，以实现专家容量的简单版本。

In [86]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k, capacity_factor=1.0):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k
        # add capacity_factor
        self.capacity_factor = capacity_factor
        self.num_experts = num_experts

    def forward(self, x):
    # Assuming x has shape [batch_size, seq_len, n_embd]
        batch_size, seq_len, _ = x.shape
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)

        # Flatten the batch and sequence dimensions to treat each token independently
        flat_x = x.view(-1, x.size(-1))  # Now shape [batch_size * seq_len, n_embd]
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        tokens_per_batch = batch_size * seq_len * self.top_k
        expert_capacity = int((tokens_per_batch / self.num_experts) * self.capacity_factor)

        updates = torch.zeros_like(flat_x)

        for i, expert in enumerate(self.experts):
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)
            selected_indices = torch.nonzero(flat_mask).squeeze(-1)

            limited_indices = selected_indices[:expert_capacity] if selected_indices.numel() > expert_capacity else selected_indices
            if limited_indices.numel() > 0:
                expert_input = flat_x[limited_indices]
                expert_output = expert(expert_input)

                gating_scores = flat_gating_output[limited_indices, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                updates.index_add_(0, limited_indices, weighted_output)

        # Reshape updates to match the original dimensions of x
        final_output += updates.view(batch_size, seq_len, -1)

        return final_output


In [87]:
import torch
import torch.nn as nn

#Let's test this out
num_experts = 8
top_k = 2
n_embd = 16
dropout=0.1

mh_output = torch.randn(4, 8, n_embd)  # Example multi-head attention output
sparse_moe = SparseMoE(n_embd, num_experts, top_k)
final_output = sparse_moe(mh_output)
print("Shape of the final output:", final_output.shape)
print(final_output)


Shape of the final output: torch.Size([4, 8, 16])
tensor([[[ 1.1761e-01, -2.3379e-02, -4.7678e-01,  9.5960e-03, -1.1987e-01,
           2.1115e-01,  1.2021e-01, -4.8881e-02,  2.6757e-01, -5.7497e-02,
          -1.7639e-01,  2.4494e-01, -2.7037e-02, -2.1677e-03,  8.5174e-02,
          -3.0657e-02],
         [ 6.2244e-02, -3.8118e-02,  4.6825e-02,  2.3763e-02, -1.3538e-01,
          -7.2525e-03,  3.6083e-01, -2.0759e-01,  2.2128e-01,  8.7651e-02,
           3.9775e-01,  1.0490e-01,  7.7816e-02, -1.6218e-01,  5.7070e-02,
          -1.2370e-01],
         [-5.1299e-02, -3.1918e-01, -3.8807e-02,  1.5957e-01, -2.4989e-02,
           2.3948e-01,  2.1835e-01, -6.3932e-02,  7.1299e-02,  2.5659e-02,
           1.7443e-01, -1.3179e-01,  2.7071e-01,  1.5567e-01, -7.9719e-02,
           7.4685e-02],
         [-8.4326e-02,  1.5591e-01,  3.9071e-02,  2.1825e-03,  3.6338e-02,
          -2.0504e-01,  6.0617e-01, -1.1862e-02, -4.1676e-02,  1.8078e-02,
           1.9232e-02,  3.8714e-02,  2.6343e-01,  2.4

## Load Balancing Loss

from:

switch transformers: https://arxiv.org/pdf/2101.03961.pdf  

A. Differentiable Load Balancing Loss

In [103]:
#@torch.jit.script
def compute_aux_loss(num_experts: int,
                     top_k_gates: torch.Tensor,
                     top_k_indices: torch.Tensor,
                     logits: torch.Tensor):
    """
    Calculate and return the auxiliary loss based on the accumulated statistics.
    switch transformers: https://arxiv.org/pdf/2101.03961.pdf
    A. Differentiable Load Balancing Loss

    Args:
        num_experts (int): The number of experts.
        top_k_gates (tensor): k个最大值的对应logits, 其每个元素表示对应logit概率值。
        top_k_indices (tensor): k个最大值的对应logits索引, 其每个元素表示logit对应索引值。
        logits (tensor): 其每个元素表示对应logit概率值。

    Returns:
        torch.Tensor: The calculated auxiliary loss.
    """
    # 对logits进行softmax操作，得到每个类别的概率分布
    probs = torch.softmax(logits, dim=-1)
    zeros = torch.zeros_like(probs)
    # Convert zeros to match top_k_gates dtype
    zeros = zeros.to(top_k_gates.dtype)
    gates = zeros.scatter(-1, top_k_indices, top_k_gates)

    # 获取 logits 张量的批次大小，即样本数量
    count = logits.size(0)
    # 计算每个专家被选中的概率之和，即将概率沿着批次维度求和。
    probs = probs.sum(0)
    # 计算每个专家被选中的频率，即计算门控值大于0的次数（即专家被选中的次数），
    # 然后将其沿着批次维度求和。
    freq = (gates > 0).float().sum(0)
    # 计算 logits 张量经过 softmax 处理后的平方和的对数。
    # 这里首先使用 softmax 函数将 logits 转换为概率分布，
    # 然后计算概率分布的每个样本的平方和，并取对数，最后将结果沿着批次维度求和。
    lsesq = (torch.log(torch.exp(logits).sum(dim=-1)) ** 2).sum()

    # 计算专家选择损失，其计算方式为对每个专家的概率和频率进行归一化，然后计算它们的点积，最后将结果乘以专家数量。
    switchloss = num_experts * \
        (F.normalize(probs, p=1, dim=0) * F.normalize(freq, p=1, dim=0)).sum()
    # 计算 z 损失，即 logits 的对数平方和除以样本数量
    zloss = lsesq / count
    # 将专家选择损失和 z 损失加权相加得到最终的辅助损失
    loss = switchloss + 0.1 * zloss

    return loss

In [144]:
#noisy top-k gating
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        #layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear =nn.Linear(n_embed, num_experts)
        self.aux_loss = 0.0


    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        #Noise logits
        noise_logits = self.noise_linear(mh_output)

        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits)*F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        # 训练时才计算辅助loss值, 为了专家之间的负载平衡
        if self.training:
          self.aux_loss = compute_aux_loss(self.num_experts, router_output, indices, noisy_logits)

        return router_output, indices


In [145]:
#Testing this out, again:
num_experts = 8
top_k = 2
n_embd = 16

mh_output = torch.randn(2, 4, n_embd)  # Example input
noisy_top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)
noisy_top_k_gate.training = True

gating_output, indices = noisy_top_k_gate(mh_output)
print(noisy_top_k_gate.aux_loss.shape, noisy_top_k_gate.aux_loss)
gating_output.shape, gating_output, indices
#It works!!

torch.Size([]) tensor(10.0999, grad_fn=<AddBackward0>)


(torch.Size([2, 4, 8]),
 tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.6761, 0.0000, 0.3239, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4239, 0.0000, 0.5761],
          [0.0000, 0.0000, 0.9223, 0.0000, 0.0000, 0.0000, 0.0777, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.4978, 0.0000, 0.5022, 0.0000, 0.0000]],
 
         [[0.6397, 0.0000, 0.0000, 0.0000, 0.3603, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.7183, 0.0000, 0.0000, 0.0000, 0.2817, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.8339, 0.1661, 0.0000, 0.0000],
          [0.4386, 0.0000, 0.0000, 0.5614, 0.0000, 0.0000, 0.0000, 0.0000]]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[[4, 6],
          [7, 5],
          [2, 6],
          [5, 3]],
 
         [[0, 4],
          [2, 6],
          [4, 5],
          [3, 0]]]))

## Putting it all together

In [146]:
#First defining hyperparameters and boiler plate code. Imports and data preparation code is repeated for convenience
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init

# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 400
head_size = 16
n_embed = 128
n_head = 8
n_layer = 8
dropout = 0.1
num_experts = 8
top_k = 2
aux_loss_coef=0.01
# ------------

torch.manual_seed(1337)

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [147]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

#Multi-Headed Self Attention
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


In [178]:
#Expert module
class Expert(nn.Module):
    """ An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """

    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

#noisy top-k gating
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        #layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear =nn.Linear(n_embed, num_experts)
        self.aux_loss = 0.0


    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        #Noise logits
        noise_logits = self.noise_linear(mh_output)

        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits)*F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        # 训练时才计算辅助loss值, 为了专家之间的负载平衡
        if self.training:
            self.aux_loss = compute_aux_loss(self.num_experts, router_output,
                                             indices, noisy_logits)

        return router_output, indices

#Now create the sparse mixture of experts module
class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k, capacity_factor=1.0):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k
        # add capacity_factor
        self.capacity_factor = capacity_factor
        self.num_experts = num_experts

    def forward(self, x):
    # Assuming x has shape [batch_size, seq_len, n_embd]
        batch_size, seq_len, _ = x.shape
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)

        # Flatten the batch and sequence dimensions to treat each token independently
        flat_x = x.view(-1, x.size(-1))  # Now shape [batch_size * seq_len, n_embd]
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        tokens_per_batch = batch_size * seq_len * self.top_k
        expert_capacity = int((tokens_per_batch / self.num_experts) * self.capacity_factor)

        updates = torch.zeros_like(flat_x)

        for i, expert in enumerate(self.experts):
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)
            selected_indices = torch.nonzero(flat_mask).squeeze(-1)

            limited_indices = selected_indices[:expert_capacity] if selected_indices.numel() > expert_capacity else selected_indices
            if limited_indices.numel() > 0:
                expert_input = flat_x[limited_indices]
                expert_output = expert(expert_input)

                gating_scores = flat_gating_output[limited_indices, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                updates.index_add_(0, limited_indices, weighted_output)

        # Reshape updates to match the original dimensions of x
        final_output += updates.view(batch_size, seq_len, -1)

        return final_output


In [151]:
#First create a self attention + mixture of experts block, that may be repeated several number of times
#Copy pasting key architecture variables for clarity

class Block(nn.Module):
    """ Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) """

    def __init__(self, n_embed, n_head, num_experts, top_k):
        # n_embed: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.smoe = SparseMoE(n_embed, num_experts, top_k)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.smoe(self.ln2(x))
        return x

In [179]:
#Finally putting it all together to crease a sparse mixture of experts language model
class SparseMoELanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.ModuleList([Block(n_embed, n_head=n_head, num_experts=num_experts,top_k=top_k) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed) # final layer norm
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        # x = self.blocks(x)  # (B,T,C)
        aux_loss = 0.0
        for block in self.blocks:
            x = block(x)
            if self.training:
              #print(block.smoe.router.aux_loss)
              aux_loss += block.smoe.router.aux_loss

        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        if targets is not None and self.training:
            loss += aux_loss_coef*aux_loss.to(loss.device)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

这里使用Kaiming He初始化，因为专家中存在ReLU激活函数。可以随意尝试使用更常用于Transformer的Glorot初始化。Jeremy Howard的Fastai第2部分有一个非常出色的讲座，从零开始实现了这些初始化方法：https://course.fast.ai/Lessons/lesson17.html

In [180]:
def kaiming_init_weights(m):
    if isinstance (m, (nn.Linear)):
        init.kaiming_normal_(m.weight)

In [181]:
model = SparseMoELanguageModel()
model.apply(kaiming_init_weights)


SparseMoELanguageModel(
  (token_embedding_table): Embedding(65, 128)
  (position_embedding_table): Embedding(32, 128)
  (blocks): ModuleList(
    (0-7): 8 x Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=128, out_features=16, bias=False)
            (query): Linear(in_features=128, out_features=16, bias=False)
            (value): Linear(in_features=128, out_features=16, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (smoe): SparseMoE(
        (router): NoisyTopkRouter(
          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)
          (noise_linear): Linear(in_features=128, out_features=8, bias=True)
        )
        (experts): ModuleList(
          (0-7): 8 x Expert(
            (net): Sequential(
       

In [None]:
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

model.train()
for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if (iter+1) % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

8.996545 M parameters
step 99: train loss 2.8891, val loss 2.8999
step 199: train loss 2.6157, val loss 2.6062
step 299: train loss 2.5084, val loss 2.5145
step 399: train loss 2.4519, val loss 2.4494
step 499: train loss 2.3518, val loss 2.3574
step 599: train loss 2.2881, val loss 2.3115
step 699: train loss 2.2374, val loss 2.2419
step 799: train loss 2.1859, val loss 2.2104
step 899: train loss 2.1419, val loss 2.1815
step 999: train loss 2.0992, val loss 2.1380
step 1099: train loss 2.0773, val loss 2.1293
step 1199: train loss 2.0361, val loss 2.1078
step 1299: train loss 2.0126, val loss 2.0945
step 1399: train loss 1.9767, val loss 2.0741
step 1499: train loss 1.9619, val loss 2.0489
step 1599: train loss 1.9214, val loss 2.0253
step 1699: train loss 1.9138, val loss 2.0122
step 1799: train loss 1.8816, val loss 1.9845
step 1899: train loss 1.8613, val loss 1.9789
step 1999: train loss 1.8488, val loss 1.9656
step 2099: train loss 1.8336, val loss 1.9536
step 2199: train loss 1

In [None]:
# generate from the model. Not great. Not too bad either
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))


ae
old, buem neran dose: bo'ld bua woold nor may siver!

LyYORTOLNSIA:
you, what, recooatusan that thou mody!

BuNhRUTUO SINE:

Frotsef'O Rich Lenger:

reemendsseou:
 groung at yo thes Eave hers how had? Whht blse
the tronighnest with of that but mind son
What a gentle, wo thave rear amore off this the poma thesrixe.

FUY BLOKE VFIY BOLLUCVGge:Dlectst yook; a goad, them.

COMINIALE:
vilny an? with! Priviini modeisus!
ThAsesd uproe--

CORuTHy.
FRray;
Yoo,e sads yoush higher waslif you minhim! he'f over hrins miet.

CAMILLO:
Hase dost coe must nou, and live thoug fel!

JuCeIvg:
I o? no there, arm you. cand his cannow thmake a stopworn fothin aling:
Sive lauge meot an to the sort uae: ne
Ius ast begent agrody thugh. Some shimRs of dele, so now.
Whot, Whut not hase one's presier broberes
Thatt him sent my nainher's preself,
Bouviu. Asfriet me mays server yet bew wall if wragaft its:
But it lover faminzs, havH, livNa rhav,
How shat add nobty quars that gost lamor.

KING RICHARR LIEDWD IRDk