# 手撕GPT2 loss

In [1]:
import torch
import math
import torch.nn.functional as F

## 构造数据
比如X(1,6,512), Y(1,6)

In [2]:
B=1 # batch
T=6 # length
D=512 # dimension
vocab_len = 32000

X=torch.randn(B, T, D)
Y=torch.randint(low=0, high=vocab_len, size=(B, T), dtype=torch.long)
print(f'X.shape:{X.shape}')
print(f'Y.shape:{Y.shape}')
print(f'Y[:10]:{Y[:10]}')

X.shape:torch.Size([1, 6, 512])
Y.shape:torch.Size([1, 6])
Y[:10]:tensor([[ 8230, 28369,  4677, 26946, 16820, 15115]])


## Attention

In [3]:
w_q = torch.randn(D,D)
w_k = torch.randn(D,D)
w_v = torch.randn(D,D)
w_o = torch.randn(D,D)

# mask是作用在序列T上的
mask = torch.tril(torch.ones(T,T))
print(mask)

# 同w_q = nn.linear(D, D), w_q(X)
Q,K,V = X @ w_q, X @ w_k, X @ w_v
# 除根号d，因为dk增大，点积累计会很大，softmax计算的分布会很不均匀，接近于0-1分布，导致softmax数值不稳定
# 带来梯度消失问题，增加根号dk，好处是softmax输出更平滑了，在不同维度dk下都能保证较好的训练效果
scores = Q @ K.transpose(1,2) / math.sqrt(D)
scores = scores.masked_fill(mask==0, float('-inf'))
weight = F.softmax(scores, dim=2)
attn = weight @ V
attn = attn @ w_o
attn.shape

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


torch.Size([1, 6, 512])

## mlp

In [4]:
mlp_up = torch.randn(D, 1024)
mlp_down = torch.randn(1024, D)

mlp = attn @ mlp_up @ mlp_down
mlp.shape

torch.Size([1, 6, 512])

## output

In [5]:
# 获取next token在词表上的概率
lm_head = torch.randn(D, vocab_len)
logits = mlp @ lm_head
logits.shape

torch.Size([1, 6, 32000])

## Loss

In [7]:
# 将logits的值转换到0-1之间
probs = F.softmax(logits, dim=-1)
# 6个token 对应 6个logits[32000]
print(f'probs.shape:{probs.shape}')
print(f'Y:{Y}')

# Loss 函数，交叉熵
# 以下两种计算loss的结构都正确，cross_entropy输入input(N,C), output(N), N为样本量，C为分类数量
loss = F.cross_entropy(probs.view(-1, probs.size(-1)), Y.view(-1))
loss = F.cross_entropy(probs.transpose(1, 2), Y)
print(f'loss:{loss}')

# 找到概率最大的token
pred = torch.argmax(probs, dim=-1)
# 训练时，每个tokne都对预测next token
print(f"预测的next token:{pred}")
print(pred.shape)
print(pred[:,-1])

probs.shape:torch.Size([1, 6, 32000])
Y:tensor([[ 8230, 28369,  4677, 26946, 16820, 15115]])
loss:10.373614311218262
预测的next token:tensor([[21070, 24383, 26998, 24149, 26998, 22766]])
torch.Size([1, 6])
tensor([22766])
