In [2]:
import torch
import torch.nn as nn
import math

In [3]:
#x.shape = (batch_size, N, dmodel)

class MHselfattention(nn.Module):
    def __init__(self, dk, dv, num_heads, dmodel):
        super(MHselfattention, self).__init__()
        self.wq = nn.Parameter(torch.randn((num_heads, dmodel, dk)), requires_grad=True)
        self.wk = nn.Parameter(torch.randn((num_heads, dmodel, dk)), requires_grad=True)
        self.wv = nn.Parameter(torch.randn((num_heads, dmodel, dv)), requires_grad=True)
        self.output_linear = nn.Linear(num_heads * dv, dmodel)
        self.scale = math.sqrt(dk)
        self.num_heads = num_heads
        
    def forward(self, x):
        batch_size, seq_len, dmodel = x.size()
        
        q_splits = torch.stack([torch.matmul(x, self.wq[i]) for i in range(self.num_heads)], dim=0)# wq[i]形状为(dmodel, dk),因此matmul(x, wq)形状为(batch_size, N, dk)
        k_splits = torch.stack([torch.matmul(x, self.wk[i]) for i in range(self.num_heads)], dim=0)# 再在第一维度stack起来，因此最后的形状是(num_heads, batch_size, N, dk)
        v_splits = torch.stack([torch.matmul(x, self.wv[i]) for i in range(self.num_heads)], dim=0)

        outputs = []
        
        for i in range(self.num_heads):
            q = q_splits[i] # 每个q的shape是(batch_size, N, dk)
            k = k_splits[i]# 每个k的shape是(batch_size, N, dk)
            v = v_splits[i]# 每个v的shape是(batch_size, N, dv)

            k_trans = k.transpose(-2, -1) # 变成(batch_size, dk, N)
            scores = torch.matmul(q, k_trans) / self.scale # shape是(batch_size, N, N)
            weights = torch.softmax(scores, dim=1) # shape是(batch_size, N, N)
            output = torch.matmul(weights, v) # shape是(batch_size, N, dv)
            outputs.append(output) # outputs在经历num_heads个循环之后，本质上是一列长度为num_heads的向量列表，其中的每个元素是output，大小为(batch_size, N, dv)
            
        concat_outputs = torch.cat(outputs, dim=-1) # 按最后一个维度cat起outputs，因此shape=(batch_size, N, num_heads*dv)
        final_output = self.output_linear(concat_outputs) # 定义了nn.Linear(num_heads * dv, dmodel)，因此(batch_size, N, num_heads*dv)的最后一维变成了dmodel
        
        return final_output # 输出，即shape为(batch_size, N, dmodel)

In [6]:
#测试

def test_mhselfattention():
    # 定义超参数
    batch_size = 10
    N = 6
    dmodel = 512
    dk = 512
    dv = 512
    num_heads = 8
    
    # 创建输入张量
    x = torch.randn(batch_size, N, dmodel)
    
    # 初始化模型
    model = MHselfattention(dk=dk, dv=dv, num_heads=num_heads, dmodel=dmodel)
    
    # 前向传播
    output = model(x)
    
    # 检查输出形状
    expected_shape = (batch_size, N, dmodel)
    
    assert output.shape == expected_shape, f"Expected output shape {expected_shape}, but got {output.shape}"
    
    print("Test passed!")

# 运行测试
test_mhselfattention()


Test passed!


In [8]:
# 训练
#定义超参
batch_size = 10
N = 6
dmodel = 512
dk = 512
dv = 512
num_heads = 8
num_epoths = 100

#定义模型、损失函数、优化器
model = MHselfattention(dk=dk, dv=dv, num_heads=num_heads, dmodel=dmodel)
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

#定义数据
x = torch.randn(batch_size, N, dmodel)
y_true = torch.randn(batch_size, N, dmodel)

def train(x, y_true, num_epochs, model, optimizer, loss_function):
    time = 0
    for epoch in range(num_epochs):
        model.train()  # 设定模型为训练模式
        optimizer.zero_grad()  # 清除之前的梯度
        output = model(x) # 前向传播
        # 确保输出的维度与 y_true 匹配
        if output.shape != y_true.shape:
            raise ValueError(f"Output shape {output.shape} does not match target shape {y_true.shape}")

        loss = loss_function(output, y_true)# 计算损失
        loss.backward() # 反向传播
        optimizer.step() # 更新参数
        time += 1 # 增加时间计数
        # 打印损失
        if (epoch + 1) % 10 == 0:  # 每 10 个 epoch 打印一次损失
            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
    
    return loss, time

final_loss, final_time = train(x, y_true, num_epoths, model=model, optimizer=optimizer, loss_function=loss_function)

# 打印最终结果
print(f'Final Loss: {final_loss.item():.4f}')
print(f'Total Time (steps): {final_time}')

            

Epoch [10/100], Loss: 19.1212
Epoch [20/100], Loss: 8.1841
Epoch [30/100], Loss: 1.9286
Epoch [40/100], Loss: 0.9113
Epoch [50/100], Loss: 0.3228
Epoch [60/100], Loss: 0.1145
Epoch [70/100], Loss: 0.0410
Epoch [80/100], Loss: 0.0151
Epoch [90/100], Loss: 0.0055
Epoch [100/100], Loss: 0.0020
Final Loss: 0.0020
Total Time (steps): 100
