<a href="https://colab.research.google.com/github/weedge/doraemon-nb/blob/main/makeMoA_MoE_from_Scratch_with_Expert_Capacity_Aux_Loss_Balance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### 从头开始的稀疏专家混合语言模型，灵感来源于（并在很大程度上基于）[Andrej Karpathy的makemore项目](https://github.com/karpathy/makemore) :)

这是一个从头开始实现的稀疏专家混合语言模型。这受到了Andrej Karpathy项目'makemore'的启发，并且大部分重用的组件都来自于该实现。就像makemore一样，makeMoE也是一个自回归的字符级语言模型，但是使用了上述稀疏专家的架构。

与makemore体系结构相比，有显着的变化

- 稀疏专家混合而不是孤立的前馈神经网络。
- 使用了top-k门控和嘈杂的top-k门控实现。
- 初始化 - 这里使用了Kaiming He初始化，但这个笔记本的重点是可hack性，所以你可以替换为Xavier Glorot等，并进行尝试。

与makemore不变的部分

- Andrej最初选择的数据集、预处理（tokenizer）和语言建模任务 - 生成类似莎士比亚的文本
- 自注意力因果实现
- 训练循环
- 推断逻辑

在此实现中大量引用的论文：

- Mixtral of Experts：https://arxiv.org/pdf/2401.04088.pdf
- Outrageosly Large Neural Networks: The Sparsely-Gated Mixture-Of-Experts layer：https://arxiv.org/pdf/1701.06538.pdf
- Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity: https://arxiv.org/pdf/2101.03961.pdf

这个笔记本演示了整个模型架构的直觉以及所有内容是如何相互关联的。


请注意，该实现强调易读性和可hack性而不是性能，因此有许多方法可以改进此实现。请尝试并告诉我。

如果在colab中运行，训练时加速，选择t4 GPU即可。


![](https://raw.githubusercontent.com/weedge/baby-llm/main/docs/simple-moa-moe.drawio.png)

In [2]:
#Import the necessary packages and set seed for reproducibility. For this notebook, pytorch is all you need
import torch
import torch.nn as nn
from torch.nn import functional as F
import math # SDPA  q@k / sqrt(C); u can use q@k * C**-0.5 stead it
torch.manual_seed(42)


<torch._C.Generator at 0x7dd4e4519c90>

### 数据集处理

接下来的几个部分，下载数据、预处理数据和自注意力直接来自makemore。我稍微详细说明了自注意力，并添加了一些可视化辅助，以便更好地理解这个过程。

In [1]:
# Downloading the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt

--2024-04-11 07:01:06--  https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-04-11 07:01:06 (20.1 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [3]:
# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [5]:
# let's look at the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [6]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [7]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [8]:
# let's now encode the entire text dataset and store it into a torch.Tensor
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [9]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [10]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [11]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


In [12]:
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

In [13]:
ix = torch.randint(len(data) - block_size, (batch_size,))
ix

tensor([250930, 237205, 974116, 383898])

In [14]:
x = torch.stack([data[i:i+block_size] for i in ix])
y = torch.stack([data[i+1:i+block_size+1] for i in ix])
x

tensor([[42,  1, 58, 46, 59, 57,  1, 21],
        [54, 56, 47, 43, 57, 58, 11,  0],
        [49, 47, 52, 45, 12,  1, 58, 46],
        [58, 46, 53, 59, 58,  1, 56, 43]])

In [15]:
y

tensor([[ 1, 58, 46, 59, 57,  1, 21,  1],
        [56, 47, 43, 57, 58, 11,  0, 37],
        [47, 52, 45, 12,  1, 58, 46, 53],
        [46, 53, 59, 58,  1, 56, 43, 42]])

以下代码块清楚地展示了预测的自回归性质，以及上下文是对token（在本例中是字符）的一维排列的滚动窗口。

In [16]:
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[ 6,  0, 14, 43, 44, 53, 56, 43],
        [39,  1, 42, 59, 43,  1, 39, 52],
        [47, 41, 43,  1, 39, 52, 42,  1],
        [53, 44,  1, 50, 43, 58,  1, 58]])
targets:
torch.Size([4, 8])
tensor([[ 0, 14, 43, 44, 53, 56, 43,  1],
        [ 1, 42, 59, 43,  1, 39, 52, 42],
        [41, 43,  1, 39, 52, 42,  1, 42],
        [44,  1, 50, 43, 58,  1, 58, 46]])
----
when input is [6] the target: 0
when input is [6, 0] the target: 14
when input is [6, 0, 14] the target: 43
when input is [6, 0, 14, 43] the target: 44
when input is [6, 0, 14, 43, 44] the target: 53
when input is [6, 0, 14, 43, 44, 53] the target: 56
when input is [6, 0, 14, 43, 44, 53, 56] the target: 43
when input is [6, 0, 14, 43, 44, 53, 56, 43] the target: 1
when input is [39] the target: 1
when input is [39, 1] the target: 42
when input is [39, 1, 42] the target: 59
when input is [39, 1, 42, 59] the target: 43
when input is [39, 1, 42, 59, 43] the target: 1
when input is [39, 1, 42, 59, 4

如果想直接进行训练，直接进入  Putting it all together to train and generate 这节进行训练

### 理解 Causal Scaled Dot Product Self Attention (SDPA)

这段代码来自于Andrej Karpathy出色的makemore代码库，链接在仓库中。





![scaled dot product self attention](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/self_attention.png)

提供的代码演示了自注意力的机制和基本概念，特别是关注经典的缩放点积自注意力。在这个变体中，查询、键和值矩阵都来自同一个输入序列。为了确保自回归语言生成过程的完整性，特别是在仅包含解码器的模型中，代码实现了掩码。这种掩码技术至关重要，因为它隐藏了当前标记位置后面的任何信息，从而将模型的注意力引导到序列的前面部分。这样的注意力机制称为因果自注意力。值得注意的是，稀疏专家混合模型并不局限于仅包含解码器的Transformer架构。事实上，在这个领域的许多重要工作，特别是由Shazeer等人完成的工作，都围绕着T5架构展开，该架构包含了Transformer模型中的编码器和解码器组件。

In [28]:
torch.manual_seed(1337)
B,T,C = 2,4,16 # batch(batch_size), time(block_size seq_len), channels(n_embd)
x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 8 # H
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)   # (B, T, H)
q = query(x) # (B, T, H)
wei =  q @ k.transpose(-2, -1) # (B, T, H) @ (B, H, T) ---> (B, T, T)
print(wei)
wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, H) @ (B, H, T) -> (B, T, T)
print(wei)

tril = torch.tril(torch.ones(T, T))
print(tril)
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
print(wei)

wei = F.softmax(wei, dim=-1) #B,T,T
print(wei)

v = value(x) #B,T,H
out = wei @ v # (B,T,T) @ (B,T,H) -> (B,T,H)
print(out)

#The output from this final matrix product is subsequently passsed through a linear layer as shown in the diagram above

out.shape

tensor([[[ 0.3811,  1.2460, -2.0277, -0.7762],
         [ 0.9262, -0.0302,  0.7229, -0.4251],
         [-0.1234, -0.8850, -0.8566, -0.6796],
         [ 0.2638, -1.8139,  1.2615, -2.6516]],

        [[-0.1472, -0.9865,  0.2428, -0.4005],
         [-0.9285,  0.3154,  0.4412,  2.0914],
         [ 0.3901,  0.0709,  0.4188,  0.5025],
         [-0.1971, -0.4844,  0.2090, -1.5026]]], grad_fn=<UnsafeViewBackward0>)
tensor([[[ 0.0953,  0.3115, -0.5069, -0.1941],
         [ 0.2315, -0.0075,  0.1807, -0.1063],
         [-0.0308, -0.2213, -0.2141, -0.1699],
         [ 0.0659, -0.4535,  0.3154, -0.6629]],

        [[-0.0368, -0.2466,  0.0607, -0.1001],
         [-0.2321,  0.0788,  0.1103,  0.5228],
         [ 0.0975,  0.0177,  0.1047,  0.1256],
         [-0.0493, -0.1211,  0.0523, -0.3756]]], grad_fn=<MulBackward0>)
tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])
tensor([[[ 0.0953,    -inf,    -inf,    -inf],
         [ 0.2315, -0.0075,    -i

torch.Size([2, 4, 8])

对因果自注意力和多头因果自注意力的代码进行泛化和模块化。多头自注意力将多个注意力头并行应用，每个注意力头专注于通道的不同部分（嵌入维度）。

In [19]:
#Causal scaled dot product self-Attention Head

n_embd = 64 # hidden_size
n_head = 4
n_layer = 4
head_size:int = 16 # n_embd/n_head
dropout = 0.1

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

In [20]:
#Multi-Headed Self Attention
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


In [21]:
#Confirming that what's output from multi head attention is the original embedding size
B,T,C = 4,block_size,n_embd # batch(batch_size), time(block_size seq_len), channels(n_embd)
x = torch.randn(B,T,C)
print(x.shape)
mha = MultiHeadAttention(n_head,head_size)
mha(x).shape

torch.Size([4, 4, 64])


torch.Size([4, 4, 64])

### 创建专家模型： 简单的多层感知器（Multi Layer Perceptron - MLP）

在稀疏专家（MoE）架构中，每个Transformer块内部的自注意力机制保持不变。然而，每个块的结构发生了显著变化：标准的前馈神经网络被替换为几个稀疏激活的前馈网络，称为专家。 "稀疏激活" 指的是序列中的每个标记仅被路由到总池中的有限数量的这些专家之一或两个 - 通常是一个或两个。这种修改允许对输入数据的不同部分进行专门处理，使模型能够有效地处理更广泛的复杂性。

![experts](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/experts.png)

In [None]:
#Expert module
class Expert(nn.Module):
    """ An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

### Top-k Gating Router

![top k gating](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/topk.png)

门控网络，也称为路由器，确定每个token从多头注意力中由哪个专家网络接收输出。让我们考虑一个简单的例子：假设有4个专家，并且要将标记路由到前2个专家。最初，我们通过一个线性层将token输入到门控网络中。这个层将输入张量从形状为（2，4，32）——表示（批量大小，tokens，n_embed，其中n_embed是输入的通道维度）——投影到一个新形状为（2，4，4）的张量，对应于（批量大小，tokens，num_experts），其中num_experts是专家网络的数量。接下来，我们确定最后一维中前k=2个最高值及其相应的索引。

In [None]:
#Understanding how gating works
num_experts = 4
top_k=2
n_embed=32


#Example multi-head attention output for a simple illustrative example, consider n_embed=32, context_length=4 and batch_size=2
mh_output = torch.randn(2, 4, n_embed)

topkgate_linear = nn.Linear(n_embed, num_experts) # nn.Linear(32, 4)

logits = topkgate_linear(mh_output)
top_k_logits, top_k_indices = logits.topk(top_k, dim=-1)  # Get top-k experts
top_k_logits, top_k_indices

(tensor([[[ 0.9558,  0.1610],
          [ 0.8659, -0.1494],
          [ 0.8765,  0.7202],
          [ 0.9496, -0.6609]],
 
         [[ 0.4419, -0.2500],
          [ 1.2602,  0.8430],
          [ 0.8570,  0.7822],
          [ 0.7376,  0.2561]]], grad_fn=<TopkBackward0>),
 tensor([[[2, 0],
          [3, 2],
          [3, 0],
          [1, 2]],
 
         [[1, 3],
          [1, 2],
          [1, 2],
          [0, 1]]]))

通过仅保留沿着最后一个维度的各自索引处的前k个值，获取稀疏门控输出。用'-inf'填充其余部分，并通过softmax激活函数传递。这将'-inf'值推向零，使前两个值更加突出，并且总和为1。这种总和为1有助于对专家输出进行加权。

In [None]:
zeros = torch.full_like(logits, float('-inf')) #full_like clones a tensor and fills it with a specified value (like infinity) for masking or calculations.
sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)
sparse_logits

tensor([[[ 0.1610,    -inf,  0.9558,    -inf],
         [   -inf,    -inf, -0.1494,  0.8659],
         [ 0.7202,    -inf,    -inf,  0.8765],
         [   -inf,  0.9496, -0.6609,    -inf]],

        [[   -inf,  0.4419,    -inf, -0.2500],
         [   -inf,  1.2602,  0.8430,    -inf],
         [   -inf,  0.8570,  0.7822,    -inf],
         [ 0.7376,  0.2561,    -inf,    -inf]]], grad_fn=<ScatterBackward0>)

In [None]:
gating_output= F.softmax(sparse_logits, dim=-1)
gating_output

tensor([[[0.3111, 0.0000, 0.6889, 0.0000],
         [0.0000, 0.0000, 0.2660, 0.7340],
         [0.4610, 0.0000, 0.0000, 0.5390],
         [0.0000, 0.8335, 0.1665, 0.0000]],

        [[0.0000, 0.6664, 0.0000, 0.3336],
         [0.0000, 0.6028, 0.3972, 0.0000],
         [0.0000, 0.5187, 0.4813, 0.0000],
         [0.6181, 0.3819, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)

### Noisy Top-k Gating Router (adding noisy top-k Gating for load balancing)

泛化和模块化上述代码，并添加嘈杂的top-k门控以实现负载平衡。

In [None]:
# First define the top k router module
class TopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(TopkRouter, self).__init__()
        self.top_k = top_k
        self.linear =nn.Linear(n_embed, num_experts)

    def forward(self, mh_ouput):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.linear(mh_output)
        top_k_logits, indices = logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices



In [None]:
#Testing this out:
num_experts = 4
top_k = 2
n_embd = 32

mh_output = torch.randn(2, 4, n_embd)  # Example input
top_k_gate = TopkRouter(n_embd, num_experts, top_k)
gating_output, indices = top_k_gate(mh_output)
gating_output.shape, gating_output, indices
#And it works!!

(torch.Size([2, 4, 4]),
 tensor([[[0.0000, 0.0000, 0.6237, 0.3763],
          [0.8167, 0.0000, 0.1833, 0.0000],
          [0.2440, 0.7560, 0.0000, 0.0000],
          [0.4934, 0.0000, 0.0000, 0.5066]],
 
         [[0.0000, 0.0000, 0.5009, 0.4991],
          [0.0000, 0.0000, 0.4645, 0.5355],
          [0.0000, 0.7588, 0.2412, 0.0000],
          [0.6103, 0.0000, 0.3897, 0.0000]]], grad_fn=<SoftmaxBackward0>),
 tensor([[[2, 3],
          [0, 2],
          [1, 0],
          [3, 0]],
 
         [[2, 3],
          [3, 2],
          [1, 2],
          [0, 2]]]))

虽然最近发布的Mixtral论文没有提到，但我认为嘈杂的top-k门控是训练MoE模型的重要工具。基本上，您不希望所有的token都被发送到同一组“偏爱”的专家中。您希望在开发和探索之间达到良好的平衡。为此，为门控线性层的logits添加标准正态噪声有助于负载平衡，使训练更加高效。

![noisy top-k gating](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/noisytopkgating.png)

In [None]:
#Changing the above to accomodate noisy top-k gating
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        #layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear =nn.Linear(n_embed, num_experts)


    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        #Noise logits
        noise_logits = self.noise_linear(mh_output)

        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits)*F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

In [None]:
#Testing this out, again:
num_experts = 8
top_k = 2
n_embd = 16

mh_output = torch.randn(2, 4, n_embd)  # Example input
noisy_top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)
gating_output, indices = noisy_top_k_gate(mh_output)
gating_output.shape, gating_output, indices
#It works!!

(torch.Size([2, 4, 8]),
 tensor([[[0.8108, 0.0000, 0.0000, 0.0000, 0.1892, 0.0000, 0.0000, 0.0000],
          [0.5673, 0.0000, 0.0000, 0.0000, 0.4327, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.3452, 0.6548, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.6709, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3291]],
 
         [[0.5440, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4560],
          [0.9737, 0.0000, 0.0000, 0.0000, 0.0000, 0.0263, 0.0000, 0.0000],
          [0.4904, 0.5096, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5366, 0.0000, 0.0000, 0.0000, 0.0000, 0.4634, 0.0000, 0.0000]]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[[0, 4],
          [0, 4],
          [4, 3],
          [1, 7]],
 
         [[0, 7],
          [0, 5],
          [1, 0],
          [0, 5]]]))


### 创建一个稀疏专家模型 (Sparse MoE)


这个过程的主要方面涉及门控网络的输出。在获得这些结果后，会选择性地将前k个值与相应的前k个专家的输出相乘，以获得给定token的结果。这种选择性的乘法形成了加权求和，构成了SparseMoe块的输出。这个过程中的关键和具有挑战性的部分是避免不必要的乘法。只对前k个专家进行前向传播，然后计算这个加权和是至关重要的。对每个专家都进行前向传播会违背使用稀疏MoE的初衷，因为它将不再是稀疏的。

In [None]:
class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)

        # Reshape inputs for batch processing
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        # Process each expert in parallel
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)

            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                expert_output = expert(expert_input)

                # Extract and apply gating scores
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                # Update final output
                # We need to scatter_add the weighted outputs to their original positions in the batch
                final_output.masked_scatter_(expert_mask.unsqueeze(-1), weighted_output)

        return final_output.view_as(x)




In [None]:
import torch
import torch.nn as nn

#Let's test this out
num_experts = 8
top_k = 2
n_embd = 16
dropout=0.1

mh_output = torch.randn(4, 8, n_embd)  # Example multi-head attention output
sparse_moe = SparseMoE(n_embd, num_experts, top_k)
final_output = sparse_moe(mh_output)
print("Shape of the final output:", final_output.shape)
print(final_output)

Shape of the final output: torch.Size([4, 8, 16])
tensor([[[ 7.0902e-03, -3.9785e-01,  2.5950e-01, -2.7855e-01,  1.0532e-01,
           2.1916e-02, -3.9044e-01, -2.9938e-01, -1.6038e-01,  8.4919e-03,
          -5.1904e-01, -2.6361e-01, -8.4711e-02,  3.5468e-02,  2.9799e-02,
          -6.0143e-03],
         [ 1.0294e-01,  4.1642e-01, -4.5188e-01,  3.9192e-03, -1.1848e-02,
           5.1312e-02,  1.7714e-01, -1.9039e-01, -0.0000e+00,  1.9526e-01,
           3.2151e-01, -3.0537e-01,  2.8723e-02,  2.9380e-01, -2.0152e-02,
          -3.2795e-01],
         [ 4.4361e-03,  0.0000e+00,  1.0039e-01, -2.2677e-02,  8.2390e-02,
           5.1152e-02, -0.0000e+00, -1.0229e-01, -6.3003e-02, -7.9471e-03,
          -2.7593e-01, -1.4275e-01,  1.6774e-02,  1.4232e-02,  1.9370e-02,
          -3.2314e-02],
         [-1.8278e-01, -3.1073e-01, -2.5140e-02,  1.5505e-01,  6.3831e-02,
           7.3816e-02, -6.7672e-02, -8.8567e-02,  1.9844e-03,  9.7871e-02,
          -3.4875e-01, -3.8098e-01, -1.6249e-01, -1.6

强调一下，需要认识到路由器/门控网络输出的前k个专家的幅值，正如上面的代码所示，也是非常重要的。这些前k个索引确定了被激活的专家，而在这些前k个维度中数值的大小决定了它们各自的权重。这种加权求和的概念在下面的图示中进一步强调了。

![sparse MoE](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/sparseMoEfinal.png)

## Load Balancing Auxiliary Loss

from:

switch transformers: https://arxiv.org/pdf/2101.03961.pdf  

A. Differentiable Load Balancing Loss

In [None]:
#@torch.jit.script
def compute_aux_loss(num_experts: int,
                     top_k_gates: torch.Tensor,
                     top_k_indices: torch.Tensor,
                     logits: torch.Tensor):
    """
    Calculate and return the auxiliary loss based on the accumulated statistics.
    switch transformers: https://arxiv.org/pdf/2101.03961.pdf
    A. Differentiable Load Balancing Loss

    Args:
        num_experts (int): The number of experts.
        top_k_gates (tensor): k个最大值的对应logits, 其每个元素表示对应logit概率值。
        top_k_indices (tensor): k个最大值的对应logits索引, 其每个元素表示logit对应索引值。
        logits (tensor): 其每个元素表示对应logit概率值。

    Returns:
        torch.Tensor: The calculated auxiliary loss.
    """
    # 对logits进行softmax操作，得到每个类别的概率分布
    probs = torch.softmax(logits, dim=-1)
    zeros = torch.zeros_like(probs)
    # Convert zeros to match top_k_gates dtype
    zeros = zeros.to(top_k_gates.dtype)
    gates = zeros.scatter(-1, top_k_indices, top_k_gates)

    # 获取 logits 张量的批次大小，即样本数量
    count = logits.size(0)
    # 计算每个专家被选中的概率之和，即将概率沿着批次维度求和。
    probs = probs.sum(0)
    # 计算每个专家被选中的频率，即计算门控值大于0的次数（即专家被选中的次数），
    # 然后将其沿着批次维度求和。
    freq = (gates > 0).float().sum(0)
    # 计算 logits 张量经过 softmax 处理后的平方和的对数。
    # 这里首先使用 softmax 函数将 logits 转换为概率分布，
    # 然后计算概率分布的每个样本的平方和，并取对数，最后将结果沿着批次维度求和。
    lsesq = (torch.log(torch.exp(logits).sum(dim=-1)) ** 2).sum()

    # 计算专家选择损失，其计算方式为对每个专家的概率和频率进行归一化，然后计算它们的点积，最后将结果乘以专家数量。
    switchloss = num_experts * \
        (F.normalize(probs, p=1, dim=0) * F.normalize(freq, p=1, dim=0)).sum()
    # 计算 z 损失，即 logits 的对数平方和除以样本数量
    zloss = lsesq / count
    # 将专家选择损失和 z 损失加权相加得到最终的辅助损失
    loss = switchloss + 0.1 * zloss

    return loss

In [None]:
#noisy top-k gating
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        #layer for router logits
        print(n_embed,num_experts,top_k)
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear = nn.Linear(n_embed, num_experts)
        self.aux_loss = 0.0


    def forward(self, mh_output):
        print(f"mh_output.shape:{mh_output.shape}")

        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)
        print(f"logits.shape:{logits.shape}")


        #Noise logits
        noise_logits = self.noise_linear(mh_output)


        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits)*F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        # 训练时才计算辅助loss值, 为了专家之间的负载平衡
        if self.training:
          self.aux_loss = compute_aux_loss(self.num_experts, router_output, indices, noisy_logits)

        return router_output, indices


In [None]:
#Testing this out, again:
num_experts = 8
top_k = 2
n_embd = 16

input = torch.randn(2, 4, n_embd)  # Example input
input = input.reshape(-1,n_embd)
print(f"input.shape:{input.shape}")
noisy_top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)
noisy_top_k_gate.training = True

gating_output, indices = noisy_top_k_gate(input)
print(noisy_top_k_gate.aux_loss.shape, noisy_top_k_gate.aux_loss)
gating_output.shape, gating_output, indices
#It works!!

input.shape:torch.Size([8, 16])
noise_logits.shape:torch.Size([8, 8])
torch.Size([]) tensor(1.7808, grad_fn=<AddBackward0>)


(torch.Size([8, 8]),
 tensor([[0.0000, 0.5588, 0.0000, 0.0000, 0.0000, 0.4412, 0.0000, 0.0000],
         [0.0000, 0.3876, 0.0000, 0.6124, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.3627, 0.0000, 0.6373, 0.0000, 0.0000],
         [0.4443, 0.0000, 0.0000, 0.0000, 0.0000, 0.5557, 0.0000, 0.0000],
         [0.5004, 0.0000, 0.0000, 0.0000, 0.0000, 0.4996, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.3144, 0.0000, 0.0000, 0.0000, 0.6856, 0.0000],
         [0.0000, 0.0000, 0.2767, 0.0000, 0.7233, 0.0000, 0.0000, 0.0000],
         [0.3451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6549, 0.0000]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[1, 5],
         [3, 1],
         [5, 3],
         [5, 0],
         [0, 5],
         [6, 2],
         [4, 2],
         [6, 0]]))

## 引入专家容量 （Expert Capacity factor）

from: https://huggingface.co/blog/AviSoori1x/makemoe2



在预训练混合专家语言模型或任何大型语言模型时，该过程通常跨越多个GPU，并且通常涉及许多机器。跨这些硬件资源并行训练的方式对于平衡计算负载至关重要。然而，如果某些专家或一组专家过度受到偏爱——反映出对开发的偏好超过探索——它不仅可能导致模型中的性能问题，还可能导致集群中的计算负载不平衡。

[Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) 实现使用专家容量来规避这个问题。专家容量确定每个专家在训练或推理过程中负责处理多少个标记。它是基于批次中的标记数和可用专家的数量定义的，通常通过容量因子进行调整。该因子允许在分配中灵活性，提供缓冲区以考虑数据分布的变化，并确保没有单个专家由于过载而成为瓶颈。在训练这些大型模型时，硬件故障是很常见的，可能持续数周甚至数月，因此这一点非常重要。

以下是专家容量通常计算的方式：

专家容量 = （每批标记数 / 专家数量）× 容量因子 其中：
```python
expert_capacity = int((tokens_per_batch / self.num_experts) * self.capacity_factor)
```

- 每批标记数`tokens_per_batch`是需要处理的批次中存在的总标记数。
- 专家数量`num_experts`是MoE层中可用于处理数据的专家总数。
- 容量因子`capacity_factor`是用于调整基础容量（每批标记数除以专家数量）的乘数。大于1的容量因子允许每个专家处理超出均匀分配份额的缓冲区，适应标记分配的不平衡。该值的一般范围为1-1.25。

以下代码块进行了轻微调整，以实现专家容量的简单版本。

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k, capacity_factor=1.0):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k
        # add capacity_factor
        self.capacity_factor = capacity_factor
        self.num_experts = num_experts

    def forward(self, x):
    # Assuming x has shape [batch_size, seq_len, n_embd]
        batch_size, seq_len, _ = x.shape
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)

        # Flatten the batch and sequence dimensions to treat each token independently
        flat_x = x.view(-1, x.size(-1))  # Now shape [batch_size * seq_len, n_embd]
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        tokens_per_batch = batch_size * seq_len * self.top_k
        expert_capacity = int((tokens_per_batch / self.num_experts) * self.capacity_factor)

        updates = torch.zeros_like(flat_x)

        for i, expert in enumerate(self.experts):
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)
            selected_indices = torch.nonzero(flat_mask).squeeze(-1)

            limited_indices = selected_indices[:expert_capacity] if selected_indices.numel() > expert_capacity else selected_indices
            if limited_indices.numel() > 0:
                expert_input = flat_x[limited_indices]
                expert_output = expert(expert_input)

                gating_scores = flat_gating_output[limited_indices, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                updates.index_add_(0, limited_indices, weighted_output)

        # Reshape updates to match the original dimensions of x
        final_output += updates.view(batch_size, seq_len, -1)

        return final_output


In [None]:
import torch
import torch.nn as nn

#Let's test this out
num_experts = 8
top_k = 2
n_embd = 16
dropout=0.1

mh_output = torch.randn(4, 8, n_embd)  # Example multi-head attention output
sparse_moe = SparseMoE(n_embd, num_experts, top_k)
final_output = sparse_moe(mh_output)
print("Shape of the final output:", final_output.shape)
print(final_output)


Shape of the final output: torch.Size([4, 8, 16])
tensor([[[-1.3410e-01,  2.8907e-01, -1.8396e-01,  2.1652e-01,  3.1087e-01,
           1.5474e-01,  1.2759e-01,  2.1254e-01,  6.7634e-02, -1.8771e-01,
           2.2336e-02, -2.4315e-01, -2.3298e-01, -2.9945e-02,  9.7308e-02,
          -3.2861e-02],
         [-3.1417e-01,  1.1285e-01,  1.9766e-02,  5.2405e-02,  1.3877e-01,
           4.7633e-01,  1.2400e-01,  8.7759e-02,  1.1468e-01,  2.7183e-01,
           1.4489e-01, -2.7205e-02,  3.2547e-01, -2.2932e-01, -5.1435e-01,
           2.6505e-01],
         [-2.4992e-01,  3.7111e-02,  2.6726e-01,  1.9991e-01,  1.3635e-01,
           9.2536e-02,  2.1399e-01,  2.9728e-03, -1.0015e-01, -1.3600e-02,
          -1.4013e-01, -4.1275e-02,  1.7718e-01, -1.4034e-01,  2.3375e-02,
          -2.1945e-01],
         [-1.5340e-01, -6.3935e-02, -1.4276e-01, -5.8719e-02,  2.7454e-01,
          -5.0834e-02, -1.9131e-01,  1.7582e-01, -1.0842e-01, -1.0077e-01,
           3.6268e-01,  3.8537e-02, -1.8663e-02,  2.1

## SMoE+MultiHeadAttention

![](https://raw.githubusercontent.com/weedge/baby-llm/main/docs/moe-self-attention.jpg)




根据 router 返回的 top-k 门控值（gating values），计算批量级别的门控值(batch_gates)

In [17]:
batch_size = 2
block_size = 4
n_embd = 16 # hidden_size

# fake hidden states randn tensor B(batch_size),T(block_size),C(n_embd)
fake_hidden_states = torch.randn(batch_size, block_size, n_embd)
print(fake_hidden_states.shape)

#Causal scaled dot product self-Attention Head
n_head = 4
n_layer = 2
head_size:int = 4 # n_embd/n_head
dropout = 0.1

# Sparse Top-K gating router + Experts (SMoE)
num_experts = 8
top_k = 2


torch.Size([2, 4, 16])


In [None]:
@torch.jit.script
def compute_gating(k: int, num_experts: int, top_k_gates: torch.Tensor, top_k_indices: torch.Tensor):
    """
    Compute gating values for the mixture of experts based on probabilities and top-k indices.

    Args:
        k (int): Number of experts to select.
        num_experts (int): Total number of experts.
        top_k_gates (torch.Tensor): Gating values for top-k experts (batch_size x k).
        top_k_indices (torch.Tensor): Indices of top-k experts (batch_size x k).

    Returns:
        torch.Tensor: Batch-level gating values.
        torch.Tensor: Batch-level expert indices.
        torch.Tensor: Expert size for each expert.
        torch.Tensor: Sorted indices of top-k experts.
    """
    zeros = torch.zeros([top_k_gates.size(0), num_experts],
                        dtype=top_k_gates.dtype, device=top_k_gates.device)
    gates = zeros.scatter(-1, top_k_indices, 1)
    print(gates)
    # 计算每个专家被选择的次数，即每列中值为 1 的数量，得到专家大小（expert_size）。
    expert_size = gates.long().sum(0)
    print(expert_size)
    # 将顶部 k 个专家的门控值和索引展平为一维张量，并对专家索引进行排序。
    top_k_gates = top_k_gates.flatten()
    #print(top_k_gates)
    top_k_experts = top_k_indices.flatten()
    _, index_sorted_experts = top_k_experts.sort(0)

    # 根据专家索引的排序结果，确定每个样本所属的批次索引（batch_index）。
    # 将排序后的索引张量 index_sorted_experts 中的每个元素除以一个标量 k，
    # 并指定舍入模式为“截断”（truncation）。这意味着将索引除以 k 后取整数部分，舍去小数部分
    batch_index = index_sorted_experts.div(k, rounding_mode="trunc")
    # 提取排序后的专家门控值，得到批次级别的门控值（batch_gates）。
    batch_gates = top_k_gates[index_sorted_experts]

    return batch_gates, batch_index, expert_size, index_sorted_experts

In [None]:
# Test noisy top-k gate router compute_gating method

input=fake_hidden_states.reshape(-1, n_embd)# B*T, C
print(f"input.shape:{input.shape}")

noisy_top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)
#noisy_top_k_gate.eval()
noisy_top_k_gate.training = True #default True

top_k_gates, top_k_indices = noisy_top_k_gate(input)
print(f"aux_loss:{noisy_top_k_gate.aux_loss}")
print(top_k_gates.shape, top_k_gates)
print(top_k_indices.shape, top_k_indices)


batch_gates, batch_index, expert_size, index_sorted_experts = compute_gating(
  top_k, num_experts, top_k_gates, top_k_indices
)
print(f"batch_gates:{batch_gates}")
print(f"batch_index:{batch_index}")
print(f"expert_size:{expert_size}")
print(f"index_sorted_experts:{index_sorted_experts}")
expert_size = expert_size.tolist()
print(f"expert_size:{expert_size}")


input.shape:torch.Size([8, 16])
16 8 2
mh_output.shape:torch.Size([8, 16])
logits.shape:torch.Size([8, 8])
aux_loss:1.65873122215271
torch.Size([8, 8]) tensor([[0.6291, 0.0000, 0.0000, 0.0000, 0.3709, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5793, 0.0000, 0.4207],
        [0.0000, 0.0000, 0.0000, 0.4442, 0.0000, 0.0000, 0.5558, 0.0000],
        [0.0000, 0.0000, 0.4570, 0.5430, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.2763, 0.0000, 0.7237, 0.0000],
        [0.0000, 0.0000, 0.7778, 0.0000, 0.0000, 0.2222, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.6454, 0.0000, 0.3546, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.5659, 0.0000, 0.0000, 0.4341, 0.0000]],
       grad_fn=<SoftmaxBackward0>)
torch.Size([8, 2]) tensor([[0, 4],
        [5, 7],
        [6, 3],
        [3, 2],
        [6, 4],
        [2, 5],
        [4, 6],
        [3, 6]])
 1  0  0  0  1  0  0  0
 0  0  0  0  0  1  0  1
 0  0  0  1  0  0  1  0


使用并行专家处理，分为输入并行专家(专家数目num_experts)和输出并行专家(专家数目num_experts)，

初始化专家的参数权重W B,T,C（num_experts, output_size, input_size）； 并初始化为在 [-1/output_size, 1/output_size] 范围内的均匀分布的随机值

In [None]:
# init weight
assert n_head % top_k == 0
num_key_value_heads = int(n_head/top_k)
print(f"num_key_value_heads:{num_key_value_heads}")
kv_proj_size=num_key_value_heads*head_size
print(f"kv_proj_size:{kv_proj_size}")
weight = nn.Parameter(torch.empty(num_experts, kv_proj_size, n_embd))
print(f"weight.shape:{weight.shape}")
#print(weight)
nn.init.uniform_(weight, -1.0 / weight.size(1), 1.0 / weight.size(1))
print(weight)

num_key_value_heads:2
kv_proj_size:8
weight.shape:torch.Size([8, 8, 16])
Parameter containing:
tensor([[[ 0.1162, -0.0343,  0.0675,  ..., -0.0964,  0.0922,  0.0681],
         [ 0.0016, -0.0472, -0.0787,  ...,  0.0549, -0.0476, -0.0060],
         [-0.0668,  0.1190,  0.0771,  ..., -0.0331,  0.1091, -0.1133],
         ...,
         [-0.0979,  0.1120,  0.0739,  ..., -0.1052, -0.0721,  0.1052],
         [-0.0765, -0.0952,  0.0303,  ..., -0.1188, -0.1188,  0.0708],
         [ 0.0260,  0.0751, -0.1059,  ...,  0.1188,  0.0343, -0.0411]],

        [[-0.0151, -0.1161, -0.1086,  ...,  0.0219,  0.0250, -0.0181],
         [-0.0239, -0.1244,  0.0143,  ...,  0.0865, -0.0011,  0.1112],
         [ 0.0210, -0.0955, -0.0176,  ...,  0.1119, -0.1043,  0.0331],
         ...,
         [-0.1228, -0.1043,  0.0418,  ..., -0.0954, -0.0621,  0.1240],
         [-0.0826, -0.0012,  0.0303,  ...,  0.0408, -0.0891, -0.0392],
         [-0.0761, -0.0749,  0.1163,  ...,  0.0658,  0.0944, -0.0178]],

        [[ 0.0954, -0

In [None]:
expert_inputs = input[batch_index]
print(expert_inputs.shape)
print(expert_inputs)

print(f"expert_size:{expert_size}")
#在指定维度 dim=0 上将输入张量 expert_inputs 按照给定的尺寸 expert_size 进行分割，然后返回分割后的子张量列表
input_list = expert_inputs.split(expert_size, dim=0)# return tuple
print(input_list)

output_list = []
for i in range(num_experts):
    print(f"input_list[{i}].shape:{input_list[i].shape}")
    print(f"weight[{i}].shape:{weight[i].shape}")
    # B = A*W^T A:(N, in_features) W:(out_features, in_features) => B:(N,out_features)
    output=F.linear(input_list[i], weight[i])
    print(f"output.shape:{output.shape}")
    output_list.append(output)

# 张量的形状在除了连接维度 dim=0 外都是一致的，那么这个操作会将这些张量沿着第一个维度 进行连接，形成一个新的张量。
# 如果连接的张量在其他维度的大小不同，那么连接操作会失败
results = torch.cat(output_list, dim=0)
print(results.shape)
print(results)


torch.Size([16, 16])
tensor([[ 0.6275, -0.2494, -0.9234, -0.1392, -0.6819, -1.2787,  0.4516, -1.2711,
         -1.0155, -1.0718,  1.7033,  0.3927, -0.4351, -0.7506,  0.6580,  0.1792],
        [-1.0995,  0.4704,  0.3262, -0.4887, -0.5590,  1.0151, -1.4383,  1.2668,
         -0.5727, -0.5563, -0.0041,  0.6063, -0.3828, -0.6839,  0.1887, -2.0636],
        [-0.2277, -0.2768, -0.9280,  0.1068, -1.1972, -0.8575, -0.4312,  0.1373,
         -1.0293, -0.1143, -0.7472,  1.0169,  1.0057,  1.7110,  0.5539, -0.2593],
        [-2.7094,  1.2737, -0.4098,  0.8823, -1.4135,  1.4640,  0.1716, -0.9132,
         -0.5465,  0.8197,  0.2410,  1.5406, -1.2586, -0.4886, -0.7066,  0.4511],
        [-1.5670,  1.0781, -1.2979, -1.1760, -0.3268, -0.5293, -0.7467,  0.5153,
         -1.6209,  0.9633, -1.0558, -0.6734, -0.1359, -0.2297, -1.8250,  0.0945],
        [-0.2609,  0.5300, -0.3991,  1.2225, -1.1739, -0.7944,  2.2792, -0.0253,
          0.4196, -1.0681,  2.0502, -1.3337,  0.2271, -1.4689,  0.7892,  0.3918],
 

封装成ParallelExperts

In [None]:
class ParallelExperts(nn.Module):
    def __init__(self, num_experts, input_size, output_size) -> None:
        """
        Initialize the ParallelExperts module.
        like a Expert pool
        maybe manager diff export pool for feature to load :)

        Args:
            num_experts (int): Number of experts.
            input_size (int): Size of the input.
            output_size (int): Size of the output.
            bias (bool): Whether to include bias terms.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.empty(
            num_experts, output_size, input_size))
        self.reset_parameters()
        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size

    def extra_repr(self):
        return "num_experts={}, input_size={}, output_size={}".format(
            self.num_experts, self.input_size, self.output_size
        )

    def reset_parameters(self) -> None:
        """
        Reset the parameters of the model.
        """
        nn.init.uniform_(self.weight, -1.0 / self.weight.size(1),
                         1.0 / self.weight.size(1))

    def forward(self, inputs, expert_size):
        """
        Forward pass of the ParallelExperts module.

        Args:
            inputs (Tensor): Input tensor.
            expert_size: Expert size information.

        Returns:
            Tensor: Output tensor.
        """
        input_list = inputs.split(expert_size, dim=0)# return tuple
        output_list = []
        for i in range(self.num_experts):
            output_list.append(F.linear(input_list[i], self.weight[i]))
        results = torch.cat(output_list, dim=0)
        return results

In [None]:
assert n_head % top_k == 0
#num_key_value_heads = int(n_head/top_k)
print(f"num_key_value_heads:{num_key_value_heads}")
#kv_proj_size=num_key_value_heads*head_size
print(f"kv_proj_size:{kv_proj_size}")

parallel_experts = ParallelExperts(num_experts, input_size=n_embd, output_size=kv_proj_size)
#expert_inputs = input[batch_index]
#print(expert_inputs.shape)
#print(expert_inputs)

results = parallel_experts(expert_inputs, expert_size)
print(results.shape)
print(results)


num_key_value_heads:2
kv_proj_size:8
torch.Size([16, 8])
tensor([[ 5.7369e-02, -2.7929e-01, -6.2102e-01,  1.4588e-01,  1.5471e-01,
         -2.1183e-01, -6.1586e-01,  2.2009e-01],
        [ 3.6758e-02,  4.7301e-02,  2.0505e-01, -1.5489e-01, -9.4189e-02,
         -4.6538e-02,  2.6035e-02,  1.4479e-01],
        [-4.8403e-01, -1.3046e-01, -2.6872e-01, -2.0892e-01, -2.1421e-01,
         -1.2554e-04, -5.1007e-02,  4.3072e-01],
        [-1.0765e-01,  2.7371e-01,  5.0974e-02, -1.0747e-01, -4.5881e-01,
          1.7182e-01,  5.8526e-01, -1.9386e-01],
        [-2.7319e-01, -4.8780e-01,  1.0511e-01, -8.3513e-01, -4.7752e-01,
          1.5330e-01, -6.3043e-02,  2.7391e-01],
        [ 6.4076e-01,  3.2031e-01,  4.3063e-01,  1.6182e-01,  1.4516e-01,
         -2.2927e-01,  2.9131e-01, -7.2021e-03],
        [-3.1115e-01, -5.2544e-02, -9.6139e-02, -2.6856e-01,  3.3494e-01,
         -1.9269e-01, -1.9188e-01, -2.4091e-01],
        [-2.4421e-01,  7.2130e-02,  7.0269e-02,  1.5036e-02, -4.0587e-01,
        

输入 hidden states 对其 map操作 输出 Q

QKV 进行 SDPA 输出 对其 reduce操作 输出 attetion_output

整体封装成 SparseMoEMultiHeadAttention， 相关map/reduce 操作见代码注释:

In [None]:
class SparseMoEMultiHeadAttention(nn.Module):
    """ spare moe + multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size, n_embed, block_size, dropout, num_experts=8, top_k=2, reduce_bias=True):
        super(SparseMoEMultiHeadAttention, self).__init__()

        # 偏置是可学习的参数，通常用于线性层（如全连接层）和卷积层中: a = Wx + Bias
        # 模型中引入偏置项，有助于模型更好地拟合训练数据和提高模型的表达能力
        # 在训练过程中，模型会通过梯度下降等优化算法自动学习到合适的偏置值，从而使模型的预测更准确。
        self.p_reduce_bias = None
        if reduce_bias:
            self.p_reduce_bias = torch.nn.Parameter(torch.empty(n_embed))
            torch.nn.init.zeros_(self.p_reduce_bias)

        self.n_embed = n_embed
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_experts = num_experts
        self.top_k = min(top_k, self.num_experts)

        assert self.top_k > 0, f"topk must > 0"
        assert self.num_heads > 0, f"num_heads must > 0"
        assert num_heads % \
            self.top_k == 0, f"need num_heads:{num_heads}%top_k:{self.top_k} == 0"

        # num_heads = topk * num_key_val_heads
        # kv_proj_size = num_key_val_heads * head_size
        # num_heads * head_size = topk * kv_proj_size
        self.num_key_val_heads = int(num_heads/top_k)
        self.kv_proj_size = self.num_key_val_heads*head_size

        self.input_linear = ParallelExperts(
            num_experts, n_embed, self.kv_proj_size)
        self.output_linear = ParallelExperts(
            num_experts, self.kv_proj_size, n_embed)

        self.router = NoisyTopkRouter(n_embed, num_experts, self.top_k)

        self.k_proj = torch.nn.Linear(
            n_embed, self.kv_proj_size, bias=False)
        self.v_proj = torch.nn.Linear(
            n_embed, self.kv_proj_size, bias=False)

        self.register_buffer('tril', torch.tril(
            torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        # B:bsz, S:seq_len=block_size, C:feat_dim=n_embed
        bsz, seq_len, feat_dim = x.size()

        # H:num_heads, kvH:num_key_val_heads, D:head_size
        query_states = self.map(x)  # B S H*D
        key_states = self.k_proj(x)  # B S kvH*D
        value_states = self.v_proj(x)  # B S kvH*D

        query_states = query_states.view(
            bsz, seq_len, self.num_heads, self.head_size
        ).transpose(1, 2)  # B H S D
        key_states = key_states.view(
            bsz, seq_len, self.num_key_val_heads, self.head_size
        ).transpose(1, 2)  # B kvH S D
        value_states = value_states.view(
            bsz, seq_len, self.num_key_val_heads, self.head_size
        ).transpose(1, 2)  # B kvH S D

        # repeat k/v heads if num_key_val_heads < num_heads, it's true
        key_states = key_states.repeat(1, self.top_k, 1, 1)  # B H S D
        value_states = value_states.repeat(1, self.top_k, 1, 1)  # B H S D

        # (B H S D) @ (B H D S) * D**-0.5 -> (B H S S)
        attn_weights = query_states@key_states.transpose(2, 3) * self.head_size**-0.5
        #attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_size)

        # check attention weights shape
        if attn_weights.size() != (bsz, self.num_heads, seq_len, seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, seq_len, seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        # cuasle sequence masked fill with -inf
        attn_weights = attn_weights.masked_fill(
            self.tril[:seq_len, :seq_len] == 0, float('-inf'))  # (B H S S)

        # upcast attention to fp32
        attn_weights = F.softmax(
            attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # dropout, if trainning loss have some overfit happen, open it
        attn_weights = self.dropout(attn_weights)
        # attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # (B H S S) @ (B H S D) -> (B H S D)
        attn_output = attn_weights@value_states
        #attn_output = torch.matmul(attn_weights, value_states)

        # check attention output shape
        if attn_output.size() != (bsz, self.num_heads, seq_len, self.head_size):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, seq_len, self.head_size)}, but is"
                f" {attn_output.size()}"
            )

        # 内存连续的张量意味着张量的元素在内存中是按照其在张量中的顺序连续存储的，没有间隔
        # 调用一些需要连续张量作为输入的函数(reshape)时可能会引发错误。
        # 因此，在执行一些操作之前，需要确保张量是连续的
        attn_output = attn_output.transpose(1, 2).contiguous()  # B S H D
        # num_heads(H) * head_size(D) = topk * kv_proj_size
        attn_output = attn_output.reshape(
            bsz, seq_len, self.top_k, self.kv_proj_size)  # B S topk kv_proj_size

        attn_output = self.reduce(attn_output)
        attn_output = attn_output.view(bsz, seq_len, -1)

        attn_output = self.dropout(attn_output)

        return attn_output
        # return attn_output, attn_weights

    def map(self, x):
        # 解析输入张量的形状，获取批次大小（bsz）、序列长度（length）和输入特征维度（emb_size）。
        bsz, length, emb_size = x.size()
        # 将输入张量 x 重新整形为二维张量，形状为 (bsz * length, emb_size)，以便进行批次级别的处理。
        x = x.reshape(-1, emb_size)
        # 调用 compute_gate 方法计算门控损失。
        self.compute_gate(x)

        # 根据 batch_index 提取每个样本所属的专家输入，形状为 (num_experts, expert_size)。
        expert_inputs = x[self.batch_index]
        # 将专家输入传递给 input_linear 层，使用专家大小信息进行线性变换，得到专家输出。
        expert_outputs = self.input_linear(expert_inputs, self.expert_size)

        # 创建一个全零张量 zeros，形状为 (bsz * length * top_k, kv_proj_size)，数据类型和设备与 expert_outputs 相同。
        zeros = torch.zeros(
            (bsz * length * self.top_k, self.kv_proj_size), dtype=expert_outputs.dtype, device=expert_outputs.device
        )
        # 使用 index_add 方法将专家输出根据 index_sorted_experts 分散到全零张量 zeros 中，得到混合输出张量 y。
        y = zeros.index_add(0, self.index_sorted_experts, expert_outputs)
        # 将混合输出张量 y 重新整形为四维张量，形状为(bsz, length, top_k, kv_proj_size)。
        y = y.view(bsz, length, self.top_k, -1)
        return y

    def reduce(self, x: torch.Tensor):
        # 解析输入张量的形状，获取批次大小（bsz）、序列长度（length）、专家数量（k）和嵌入维度（emb_size=kv_proj_size）。
        bsz, length, k, emb_size = x.size()
        # 将输入张量 x 重新整形为二维张量，形状为 (bsz * length * k, emb_size)。
        x = x.reshape(-1, emb_size)

        # 根据 index_sorted_experts 提取每个样本所属的专家输入，形状为 (num_experts, expert_size)。
        expert_inputs = x[self.index_sorted_experts]
        # 将专家输入传递给 output_linear 层，使用专家大小信息进行线性变换，得到专家输出。
        expert_outputs = self.output_linear(expert_inputs, self.expert_size)

        # 将专家输出乘以对应的门控值。
        expert_outputs = expert_outputs * self.batch_gates[:, None]

        # 创建一个全零张量 zeros，形状为 (bsz * length, n_embed)，数据类型和设备与 expert_outputs 相同。
        zeros = torch.zeros((bsz * length, self.n_embed),
                            dtype=expert_outputs.dtype, device=expert_outputs.device)
        # 使用 index_add 方法将乘以门控值的专家输出张量根据 batch_index 分散到全零张量 zeros 中，得到降维后的输出张量 y。
        y = zeros.index_add(0, self.batch_index, expert_outputs)
        # 将降维后的输出张量 y 重新整形为三维张量，形状为 (bsz, length, n_embed)。
        y = y.view(bsz, length, self.n_embed)
        # 如果设置了偏置项，则将偏置项添加到输出张量 y 中。
        if self.p_reduce_bias is not None:
            y = y + self.p_reduce_bias
        return y

    def compute_gate(self, x):
        self.top_k_gates, top_k_indices = self.router(x)

        self.batch_gates, self.batch_index, expert_size, self.index_sorted_experts = compute_gating(
            self.top_k, self.num_experts, self.top_k_gates, top_k_indices
        )
        self.expert_size = expert_size.tolist()

In [None]:
sparse_moe_mha = SparseMoEMultiHeadAttention(n_head, head_size, n_embd, block_size, dropout, num_experts, top_k)
print(sparse_moe_mha)
sparse_moe_mha_final_output = sparse_moe_mha(fake_hidden_states)
print("sparse_moe_mha Shape of the final output:", sparse_moe_mha_final_output.shape)
print(sparse_moe_mha_final_output)


16 8 2
SparseMoEMultiHeadAttention(
  (input_linear): ParallelExperts(num_experts=8, input_size=16, output_size=8)
  (output_linear): ParallelExperts(num_experts=8, input_size=8, output_size=16)
  (router): NoisyTopkRouter(
    (topkroute_linear): Linear(in_features=16, out_features=8, bias=True)
    (noise_linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (k_proj): Linear(in_features=16, out_features=8, bias=False)
  (v_proj): Linear(in_features=16, out_features=8, bias=False)
  (dropout): Dropout(p=0.1, inplace=False)
)
mh_output.shape:torch.Size([8, 16])
logits.shape:torch.Size([8, 8])
 0  0  0  0  0  0  1  1
 0  0  0  1  0  0  0  1
 0  0  1  1  0  0  0  0
 0  0  0  0  0  1  0  1
 0  1  0  0  1  0  0  0
 0  0  1  1  0  0  0  0
 0  0  1  1  0  0  0  0
 0  0  0  0  0  1  1  0
[ CPUFloatType{8,8} ]
 0
 1
 3
 4
 1
 2
 2
 3
[ CPULongType{8} ]
expert_outputs.shape:torch.Size([16, 8])
zeros.shape:torch.Size([16, 8])
index_sorted_experts.shape:torch.Size([16])
sparse_moe_mha 

## Putting it all together to train and generate

In [29]:
#First defining hyperparameters and boiler plate code. Imports and data preparation code is repeated for convenience
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
import math

# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 400
head_size = 16
n_embed = 128
n_head = 8
n_layer = 8
dropout = 0.1
num_experts = 8
top_k = 2
aux_loss_coef=0.01
moe_self_attention=False
# ------------

torch.manual_seed(1337)

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [30]:
#@torch.jit.script
def compute_aux_loss(num_experts: int,
                     top_k_gates: torch.Tensor,
                     top_k_indices: torch.Tensor,
                     logits: torch.Tensor):
    """
    Calculate and return the auxiliary loss based on the accumulated statistics.
    switch transformers: https://arxiv.org/pdf/2101.03961.pdf
    A. Differentiable Load Balancing Loss

    Args:
        num_experts (int): The number of experts.
        top_k_gates (tensor): k个最大值的对应logits, 其每个元素表示对应logit概率值。
        top_k_indices (tensor): k个最大值的对应logits索引, 其每个元素表示logit对应索引值。
        logits (tensor): 其每个元素表示对应logit概率值。

    Returns:
        torch.Tensor: The calculated auxiliary loss.
    """
    # 对logits进行softmax操作，得到每个类别的概率分布
    probs = torch.softmax(logits, dim=-1)
    zeros = torch.zeros_like(probs)
    # Convert zeros to match top_k_gates dtype
    zeros = zeros.to(top_k_gates.dtype)
    gates = zeros.scatter(-1, top_k_indices, top_k_gates)

    # 获取 logits 张量的批次大小，即样本数量
    count = logits.size(0)
    # 计算每个专家被选中的概率之和，即将概率沿着批次维度求和。
    probs = probs.sum(0)
    # 计算每个专家被选中的频率，即计算门控值大于0的次数（即专家被选中的次数），
    # 然后将其沿着批次维度求和。
    freq = (gates > 0).float().sum(0)
    # 计算 logits 张量经过 softmax 处理后的平方和的对数。
    # 这里首先使用 softmax 函数将 logits 转换为概率分布，
    # 然后计算概率分布的每个样本的平方和，并取对数，最后将结果沿着批次维度求和。
    lsesq = (torch.log(torch.exp(logits).sum(dim=-1)) ** 2).sum()

    # 计算专家选择损失，其计算方式为对每个专家的概率和频率进行归一化，然后计算它们的点积，最后将结果乘以专家数量。
    switchloss = num_experts * \
        (F.normalize(probs, p=1, dim=0) * F.normalize(freq, p=1, dim=0)).sum()
    # 计算 z 损失，即 logits 的对数平方和除以样本数量
    zloss = lsesq / count
    # 将专家选择损失和 z 损失加权相加得到最终的辅助损失
    loss = switchloss + 0.1 * zloss

    return loss

In [31]:
@torch.jit.script
def compute_gating(k: int, num_experts: int, top_k_gates: torch.Tensor, top_k_indices: torch.Tensor):
    """
    Compute gating values for the mixture of experts based on probabilities and top-k indices.

    Args:
        k (int): Number of experts to select.
        num_experts (int): Total number of experts.
        top_k_gates (torch.Tensor): Gating values for top-k experts (batch_size x k).
        top_k_indices (torch.Tensor): Indices of top-k experts (batch_size x k).

    Returns:
        torch.Tensor: Batch-level gating values.
        torch.Tensor: Batch-level expert indices.
        torch.Tensor: Expert size for each expert.
        torch.Tensor: Sorted indices of top-k experts.
    """
    zeros = torch.zeros([top_k_gates.size(0), num_experts],
                        dtype=top_k_gates.dtype, device=top_k_gates.device)
    gates = zeros.scatter(-1, top_k_indices, 1)
    #print(gates)
    # 计算每个专家被选择的次数，即每列中值为 1 的数量，得到专家大小（expert_size）。
    expert_size = gates.long().sum(0)
    #print(expert_size)
    # 将顶部 k 个专家的门控值和索引展平为一维张量，并对专家索引进行排序。
    top_k_gates = top_k_gates.flatten()
    #print(top_k_gates)
    top_k_experts = top_k_indices.flatten()
    _, index_sorted_experts = top_k_experts.sort(0)

    # 根据专家索引的排序结果，确定每个样本所属的批次索引（batch_index）。
    # 将排序后的索引张量 index_sorted_experts 中的每个元素除以一个标量 k，
    # 并指定舍入模式为“截断”（truncation）。这意味着将索引除以 k 后取整数部分，舍去小数部分
    batch_index = index_sorted_experts.div(k, rounding_mode="trunc")
    # 提取排序后的专家门控值，得到批次级别的门控值（batch_gates）。
    batch_gates = top_k_gates[index_sorted_experts]

    return batch_gates, batch_index, expert_size, index_sorted_experts

In [32]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

#Multi-Headed Self Attention
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


In [33]:
#Expert module
class Expert(nn.Module):
    """ An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """

    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

#noisy top-k gating
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        #layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear =nn.Linear(n_embed, num_experts)
        self.aux_loss = 0.0


    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        #Noise logits
        noise_logits = self.noise_linear(mh_output)

        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits)*F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        # 训练时才计算辅助loss值, 为了专家之间的负载平衡
        if self.training:
            self.aux_loss = compute_aux_loss(self.num_experts, router_output,
                                             indices, noisy_logits)

        return router_output, indices

#Now create the sparse mixture of experts module
class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k, capacity_factor=1.0):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k
        # add capacity_factor
        self.capacity_factor = capacity_factor
        self.num_experts = num_experts

    def forward(self, x):
    # Assuming x has shape [batch_size, seq_len, n_embd]
        batch_size, seq_len, _ = x.shape
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)

        # Flatten the batch and sequence dimensions to treat each token independently
        flat_x = x.view(-1, x.size(-1))  # Now shape [batch_size * seq_len, n_embd]
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        tokens_per_batch = batch_size * seq_len * self.top_k
        expert_capacity = int((tokens_per_batch / self.num_experts) * self.capacity_factor)

        updates = torch.zeros_like(flat_x)

        for i, expert in enumerate(self.experts):
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)
            selected_indices = torch.nonzero(flat_mask).squeeze(-1)

            limited_indices = selected_indices[:expert_capacity] if selected_indices.numel() > expert_capacity else selected_indices
            if limited_indices.numel() > 0:
                expert_input = flat_x[limited_indices]
                expert_output = expert(expert_input)

                gating_scores = flat_gating_output[limited_indices, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                updates.index_add_(0, limited_indices, weighted_output)

        # Reshape updates to match the original dimensions of x
        final_output += updates.view(batch_size, seq_len, -1)

        return final_output


In [34]:
class ParallelExperts(nn.Module):
    def __init__(self, num_experts, input_size, output_size) -> None:
        """
        Initialize the ParallelExperts module.
        like a Expert pool
        maybe manager diff export pool for feature to load :)

        Args:
            num_experts (int): Number of experts.
            input_size (int): Size of the input.
            output_size (int): Size of the output.
            bias (bool): Whether to include bias terms.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.empty(
            num_experts, output_size, input_size))
        self.reset_parameters()
        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size

    def extra_repr(self):
        return "num_experts={}, input_size={}, output_size={}".format(
            self.num_experts, self.input_size, self.output_size
        )

    def reset_parameters(self) -> None:
        """
        Reset the parameters of the model.
        """
        nn.init.uniform_(self.weight, -1.0 / self.weight.size(1),
                         1.0 / self.weight.size(1))

    def forward(self, inputs, expert_size):
        """
        Forward pass of the ParallelExperts module.

        Args:
            inputs (Tensor): Input tensor.
            expert_size: Expert size information.

        Returns:
            Tensor: Output tensor.
        """
        input_list = inputs.split(expert_size, dim=0)# return tuple
        output_list = []
        for i in range(self.num_experts):
            output_list.append(F.linear(input_list[i], self.weight[i]))
        results = torch.cat(output_list, dim=0)
        return results

In [43]:
class SparseMoEMultiHeadAttention(nn.Module):
    """ spare moe + multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size, n_embed, block_size, dropout, num_experts=8, top_k=2, reduce_bias=True):
        super(SparseMoEMultiHeadAttention, self).__init__()

        # 偏置是可学习的参数，通常用于线性层（如全连接层）和卷积层中: a = Wx + Bias
        # 模型中引入偏置项，有助于模型更好地拟合训练数据和提高模型的表达能力
        # 在训练过程中，模型会通过梯度下降等优化算法自动学习到合适的偏置值，从而使模型的预测更准确。
        self.p_reduce_bias = None
        if reduce_bias:
            self.p_reduce_bias = torch.nn.Parameter(torch.empty(n_embed))
            torch.nn.init.zeros_(self.p_reduce_bias)

        self.n_embed = n_embed
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_experts = num_experts
        self.top_k = min(top_k, self.num_experts)

        assert self.top_k > 0, f"topk must > 0"
        assert self.num_heads > 0, f"num_heads must > 0"
        assert num_heads % \
            self.top_k == 0, f"need num_heads:{num_heads}%top_k:{self.top_k} == 0"

        # num_heads = topk * num_key_val_heads
        # kv_proj_size = num_key_val_heads * head_size
        # num_heads * head_size = topk * kv_proj_size
        self.num_key_val_heads = int(num_heads/top_k)
        self.kv_proj_size = self.num_key_val_heads*head_size

        self.input_linear = ParallelExperts(
            num_experts, n_embed, self.kv_proj_size)
        self.output_linear = ParallelExperts(
            num_experts, self.kv_proj_size, n_embed)

        self.router = NoisyTopkRouter(n_embed, num_experts, self.top_k)

        self.k_proj = torch.nn.Linear(
            n_embed, self.kv_proj_size, bias=False)
        self.v_proj = torch.nn.Linear(
            n_embed, self.kv_proj_size, bias=False)

        self.register_buffer('tril', torch.tril(
            torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        # B:bsz, S:seq_len=block_size, C:feat_dim=n_embed
        bsz, seq_len, feat_dim = x.size()

        # H:num_heads, kvH:num_key_val_heads, D:head_size
        query_states = self.map(x)  # B S H*D
        key_states = self.k_proj(x)  # B S kvH*D
        value_states = self.v_proj(x)  # B S kvH*D

        query_states = query_states.view(
            bsz, seq_len, self.num_heads, self.head_size
        ).transpose(1, 2)  # B H S D
        key_states = key_states.view(
            bsz, seq_len, self.num_key_val_heads, self.head_size
        ).transpose(1, 2)  # B kvH S D
        value_states = value_states.view(
            bsz, seq_len, self.num_key_val_heads, self.head_size
        ).transpose(1, 2)  # B kvH S D

        # repeat k/v heads if num_key_val_heads < num_heads, it's true
        key_states = key_states.repeat(1, self.top_k, 1, 1)  # B H S D
        value_states = value_states.repeat(1, self.top_k, 1, 1)  # B H S D

        # (B H S D) @ (B H D S) * D**-0.5 -> (B H S S)
        attn_weights = query_states@key_states.transpose(2, 3) * self.head_size**-0.5
        #attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_size)

        # check attention weights shape
        if attn_weights.size() != (bsz, self.num_heads, seq_len, seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, seq_len, seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        # cuasle sequence masked fill with -inf
        attn_weights = attn_weights.masked_fill(
            self.tril[:seq_len, :seq_len] == 0, float('-inf'))  # (B H S S)

        # upcast attention to fp32
        attn_weights = F.softmax(
            attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # dropout, if trainning loss have some overfit happen, open it
        attn_weights = self.dropout(attn_weights)
        # attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # (B H S S) @ (B H S D) -> (B H S D)
        attn_output = attn_weights@value_states
        #attn_output = torch.matmul(attn_weights, value_states)

        # check attention output shape
        if attn_output.size() != (bsz, self.num_heads, seq_len, self.head_size):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, seq_len, self.head_size)}, but is"
                f" {attn_output.size()}"
            )

        # 内存连续的张量意味着张量的元素在内存中是按照其在张量中的顺序连续存储的，没有间隔
        # 调用一些需要连续张量作为输入的函数(reshape)时可能会引发错误。
        # 因此，在执行一些操作之前，需要确保张量是连续的
        attn_output = attn_output.transpose(1, 2).contiguous()  # B S H D
        # num_heads(H) * head_size(D) = topk * kv_proj_size
        attn_output = attn_output.reshape(
            bsz, seq_len, self.top_k, self.kv_proj_size)  # B S topk kv_proj_size

        attn_output = self.reduce(attn_output)
        attn_output = attn_output.view(bsz, seq_len, -1)

        attn_output = self.dropout(attn_output)

        return attn_output
        # return attn_output, attn_weights

    def map(self, x):
        # 解析输入张量的形状，获取批次大小（bsz）、序列长度（length）和输入特征维度（emb_size）。
        bsz, length, emb_size = x.size()
        # 将输入张量 x 重新整形为二维张量，形状为 (bsz * length, emb_size)，以便进行批次级别的处理。
        x = x.reshape(-1, emb_size)
        # 调用 compute_gate 方法计算门控损失。
        self.compute_gate(x)

        # 根据 batch_index 提取每个样本所属的专家输入，形状为 (num_experts, expert_size)。
        expert_inputs = x[self.batch_index]
        # 将专家输入传递给 input_linear 层，使用专家大小信息进行线性变换，得到专家输出。
        expert_outputs = self.input_linear(expert_inputs, self.expert_size)

        # 创建一个全零张量 zeros，形状为 (bsz * length * top_k, kv_proj_size)，数据类型和设备与 expert_outputs 相同。
        zeros = torch.zeros(
            (bsz * length * self.top_k, self.kv_proj_size), dtype=expert_outputs.dtype, device=expert_outputs.device
        )
        # 使用 index_add 方法将专家输出根据 index_sorted_experts 分散到全零张量 zeros 中，得到混合输出张量 y。
        y = zeros.index_add(0, self.index_sorted_experts, expert_outputs)
        # 将混合输出张量 y 重新整形为四维张量，形状为(bsz, length, top_k, kv_proj_size)。
        y = y.view(bsz, length, self.top_k, -1)
        return y

    def reduce(self, x: torch.Tensor):
        # 解析输入张量的形状，获取批次大小（bsz）、序列长度（length）、专家数量（k）和嵌入维度（emb_size=kv_proj_size）。
        bsz, length, k, emb_size = x.size()
        # 将输入张量 x 重新整形为二维张量，形状为 (bsz * length * k, emb_size)。
        x = x.reshape(-1, emb_size)

        # 根据 index_sorted_experts 提取每个样本所属的专家输入，形状为 (num_experts, expert_size)。
        expert_inputs = x[self.index_sorted_experts]
        # 将专家输入传递给 output_linear 层，使用专家大小信息进行线性变换，得到专家输出。
        expert_outputs = self.output_linear(expert_inputs, self.expert_size)

        # 将专家输出乘以对应的门控值。
        expert_outputs = expert_outputs * self.batch_gates[:, None]

        # 创建一个全零张量 zeros，形状为 (bsz * length, n_embed)，数据类型和设备与 expert_outputs 相同。
        zeros = torch.zeros((bsz * length, self.n_embed),
                            dtype=expert_outputs.dtype, device=expert_outputs.device)
        # 使用 index_add 方法将乘以门控值的专家输出张量根据 batch_index 分散到全零张量 zeros 中，得到降维后的输出张量 y。
        y = zeros.index_add(0, self.batch_index, expert_outputs)
        # 将降维后的输出张量 y 重新整形为三维张量，形状为 (bsz, length, n_embed)。
        y = y.view(bsz, length, self.n_embed)
        # 如果设置了偏置项，则将偏置项添加到输出张量 y 中。
        if self.p_reduce_bias is not None:
            y = y + self.p_reduce_bias
        return y

    def compute_gate(self, x):
        self.top_k_gates, top_k_indices = self.router(x)

        self.batch_gates, self.batch_index, expert_size, self.index_sorted_experts = compute_gating(
            self.top_k, self.num_experts, self.top_k_gates, top_k_indices
        )
        self.expert_size = expert_size.tolist()

In [44]:
#First create a self attention + mixture of experts block, that may be repeated several number of times
#Copy pasting key architecture variables for clarity

class Block(nn.Module):
    """ Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) """

    def __init__(self, n_embed, n_head, num_experts, top_k):
        # n_embed: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        if moe_self_attention:
            # moe_self_attention, block_size, dropout is global var
            self.sa = SparseMoEMultiHeadAttention(n_head, head_size, n_embed, block_size, dropout, num_experts, top_k)

        self.smoe = SparseMoE(n_embed, num_experts, top_k)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.smoe(self.ln2(x))
        return x

In [45]:
#Finally putting it all together to crease a sparse mixture of experts language model
class SparseMoELanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.ModuleList([Block(n_embed, n_head=n_head, num_experts=num_experts,top_k=top_k) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed) # final layer norm
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        # x = self.blocks(x)  # (B,T,C)
        aux_loss = 0.0
        for block in self.blocks:
            x = block(x)
            if self.training:
              #print(block.smoe.router.aux_loss)
              aux_loss += block.smoe.router.aux_loss
              if moe_self_attention:
                aux_loss += block.sa.router.aux_loss


        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        if targets is not None and self.training:
            loss += aux_loss_coef*aux_loss.to(loss.device)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

这里使用Kaiming He初始化，因为专家中存在ReLU激活函数。可以随意尝试使用更常用于Transformer的Glorot初始化。Jeremy Howard的Fastai第2部分有一个非常出色的讲座，从零开始实现了这些初始化方法：https://course.fast.ai/Lessons/lesson17.html

In [46]:
def kaiming_init_weights(m):
    if isinstance (m, (nn.Linear)):
        init.kaiming_normal_(m.weight)

In [39]:
moe_self_attention=False
model = SparseMoELanguageModel()
model.apply(kaiming_init_weights)


SparseMoELanguageModel(
  (token_embedding_table): Embedding(65, 128)
  (position_embedding_table): Embedding(32, 128)
  (blocks): ModuleList(
    (0-7): 8 x Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=128, out_features=16, bias=False)
            (query): Linear(in_features=128, out_features=16, bias=False)
            (value): Linear(in_features=128, out_features=16, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (smoe): SparseMoE(
        (router): NoisyTopkRouter(
          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)
          (noise_linear): Linear(in_features=128, out_features=8, bias=True)
        )
        (experts): ModuleList(
          (0-7): 8 x Expert(
            (net): Sequential(
       

In [40]:
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')


8.996545 M parameters


In [None]:
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

model.train()
for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if (iter+1) % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

8.996545 M parameters
step 99: train loss 2.8684, val loss 2.8967
step 199: train loss 2.6146, val loss 2.6185
step 299: train loss 2.4957, val loss 2.5016
step 399: train loss 2.4058, val loss 2.4177
step 499: train loss 2.3337, val loss 2.3613
step 599: train loss 2.2537, val loss 2.2832
step 699: train loss 2.2103, val loss 2.2446
step 799: train loss 2.1495, val loss 2.1818
step 899: train loss 2.1208, val loss 2.1598
step 999: train loss 2.0753, val loss 2.1309
step 1099: train loss 2.0560, val loss 2.1215
step 1199: train loss 2.0330, val loss 2.1079
step 1299: train loss 1.9918, val loss 2.0804
step 1399: train loss 1.9599, val loss 2.0594
step 1499: train loss 1.9431, val loss 2.0496
step 1599: train loss 1.9134, val loss 2.0195
step 1699: train loss 1.8929, val loss 2.0026
step 1799: train loss 1.8740, val loss 1.9935
step 1899: train loss 1.8562, val loss 1.9703
step 1999: train loss 1.8378, val loss 1.9672
step 2099: train loss 1.8274, val loss 1.9422
step 2199: train loss 1

In [47]:
moe_self_attention=True
model = SparseMoELanguageModel()
model.apply(kaiming_init_weights)


SparseMoELanguageModel(
  (token_embedding_table): Embedding(65, 128)
  (position_embedding_table): Embedding(32, 128)
  (blocks): ModuleList(
    (0-7): 8 x Block(
      (sa): SparseMoEMultiHeadAttention(
        (input_linear): ParallelExperts(num_experts=8, input_size=128, output_size=64)
        (output_linear): ParallelExperts(num_experts=8, input_size=64, output_size=128)
        (router): NoisyTopkRouter(
          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)
          (noise_linear): Linear(in_features=128, out_features=8, bias=True)
        )
        (k_proj): Linear(in_features=128, out_features=64, bias=False)
        (v_proj): Linear(in_features=128, out_features=64, bias=False)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (smoe): SparseMoE(
        (router): NoisyTopkRouter(
          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)
          (noise_linear): Linear(in_features=128, out_features=8, bias=True)

In [48]:
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')


9.668417 M parameters


In [49]:
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

model.train()
for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if (iter+1) % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

9.668417 M parameters
step 99: train loss 2.7353, val loss 2.7500
step 199: train loss 2.6042, val loss 2.5989
step 299: train loss 2.5329, val loss 2.5403
step 399: train loss 2.5097, val loss 2.5125
step 499: train loss 2.4708, val loss 2.4865
step 599: train loss 2.4513, val loss 2.4713
step 699: train loss 2.4285, val loss 2.4464
step 799: train loss 2.3995, val loss 2.4206
step 899: train loss 2.4006, val loss 2.4150
step 999: train loss 2.3821, val loss 2.4041
step 1099: train loss 2.3690, val loss 2.3878
step 1199: train loss 2.3629, val loss 2.3894
step 1299: train loss 2.3496, val loss 2.3668
step 1399: train loss 2.3359, val loss 2.3636
step 1499: train loss 2.3296, val loss 2.3492
step 1599: train loss 2.3085, val loss 2.3423
step 1699: train loss 2.2993, val loss 2.3355
step 1799: train loss 2.2820, val loss 2.3084
step 1899: train loss 2.2791, val loss 2.3135
step 1999: train loss 2.2668, val loss 2.2964
step 2099: train loss 2.2583, val loss 2.2986
step 2199: train loss 2

In [50]:
# generate from the model. Not great. Not too bad either
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))


oi ts if, hart oke be sytervesta vaken,
This and furson ho mas, this Rall OraPpatoo he with bady,
Ahicqechle our thing live;O Vhat.

TYCANILLARY CIZAR:
Whree, not werrad Jitlice,ersuns:
Toothen st that but gyouities anle
pisid, acteryf tor gatse aregs the rock
Luw fittes.
And, I manind my fairsties li's:'ht baid,-
Ws astiour tree anwt:ere ston ash your hease nevers;
An fivur aight:
Athar, soad my case sheak not this joet you hanh sted hach,ry bcope want.

QULIET:
This as timoth amploil aflest you:
Withe my sure is comiat by, and this dalet,
AygWard geast!uit,
Merebae with cuse was beht!
Do pervorge; as purslesan plie of tplaid, and thabboy,
Nuram as you yo alard I soun sharld
Amramy woul frawe to of thon woBothat balmaniss in ut re'd ou.

LUUUEENID By here BAWTBY:
Dais nobly themonestrad you vid gall baddoa gatt ere midins to soim an ill cash! Here on your oraid on sustilt? INC'
And those!y my amance, grey hario thyow orm, of 'twer
Waok'WAh agail aich wYCENTER:
TaWelll now
Ur it these

## 工程优化考虑

以上是简单实现梳理训练过程。

如果考虑硬件资源训练成本，以及加速训练和推理过程，需要对模型中的tensor操作进行优化， 比如：

1. 稀疏tensor 尽量降维，变成紧密tensor, 然后进行+、*等算子操作，减低内存空间，和计算成本，利用SIMD指令集提高计算效率；

2. 计算尽量批量处理，并且尽量并行化处理。
