In [2]:
import torch
import torch.nn as nn
import math
import numpy as np

In [3]:
# Single-head self-attention module, parameterised by dk, dv and dmodel.
class selfattention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Computes ``softmax(Q K^T / sqrt(dk)) V`` where ``Q = x @ wq``,
    ``K = x @ wk`` and ``V = x @ wv``.

    Args:
        dk: Width of the query and key projections.
        dv: Width of the value projection (and of the output).
        dmodel: Size of the last dimension of the input ``x``.
    """

    def __init__(self, dk, dv, dmodel):
        super(selfattention, self).__init__()
        # Parameters are created in the original order (wk, wq, wv) so that a
        # seeded run draws the same initial weights as before.
        self.wk = nn.Parameter(torch.randn((dmodel, dk)), requires_grad=True)
        self.wq = nn.Parameter(torch.randn((dmodel, dk)), requires_grad=True)
        self.wv = nn.Parameter(torch.randn((dmodel, dv)), requires_grad=True)
        # Scores are divided by sqrt(dk) before the softmax.
        self.scale = math.sqrt(dk)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor whose last dimension is ``dmodel``; the second-to-last
                dimension is the one attended over.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``dv``.
        """
        # Queries must use wq; projecting with wk made q identical to k and
        # left wq without any gradient.
        q = torch.matmul(x, self.wq)
        k = torch.matmul(x, self.wk)
        v = torch.matmul(x, self.wv)

        k_trans = torch.transpose(k, -2, -1)

        scores = torch.matmul(q, k_trans) / self.scale

        # Normalise over the key axis so each query's weights sum to 1.
        weights = torch.softmax(scores, dim=-1)

        output = torch.matmul(weights, v)

        return output

In [6]:
# Smoke-test the selfattention class on a small random input.
dmodel = 10
dk = 10
dv = 10

sa = selfattention(dk, dv, dmodel)
x = torch.randn(6, dmodel)  # 6 rows of width dmodel, i.e. a sequence of length 6 (not 10)
output = sa(x)
print(output, output.shape)  # output has one row per input row and dv columns

tensor([[ 0.1506, -3.4778,  3.5662,  0.2548,  2.0066, -5.2641, -2.3709, -1.4959,
         -2.6784,  2.5706],
        [-2.3301, -3.6530,  1.0091, -0.9456,  0.0206, -0.5263, -2.5090,  3.6106,
          1.2469, -0.0290],
        [ 2.5769,  3.6930,  0.7355,  2.1224,  2.1392,  2.2069,  4.6420, -0.5704,
          3.7105,  6.4722],
        [ 2.0587, -0.8596,  0.0837,  4.1209, -0.8531,  2.6658,  4.3049, -2.0848,
         -4.6817, -2.2809],
        [-0.1594, -0.3190,  4.1963,  2.9322,  0.2026,  0.0194, -0.4789, -4.2381,
         -1.3737,  2.5007],
        [ 2.1397,  4.5830,  0.3462,  2.3870,  1.8430, -5.0613,  0.2383,  4.5549,
         -4.4616, -4.4314]], grad_fn=<MmBackward0>) torch.Size([6, 10])


In [8]:
# Set up the model, loss and optimizer used for training, then define the training function.
model = selfattention(dk, dv, dmodel)

# Mean-squared-error loss, optimised with Adam over all model parameters.
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Random input/target pair: 6 rows, dmodel input columns and dv target columns.
# NOTE(review): the cell that calls train() re-creates x and y_true, so these
# two tensors look unused apart from advancing the RNG -- confirm before removing.
x = torch.randn(6, dmodel)
y_true = torch.randn(6, dv)

def train(x, y_true, num_epochs, model, optimizer, loss_function):
    """Fit ``model`` to a single ``(x, y_true)`` pair by full-batch updates.

    Args:
        x: Input passed to ``model``.
        y_true: Target compared against ``model(x)``.
        num_epochs: Number of optimisation steps to run.
        model: Module being trained; it is switched to training mode.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_function: Callable ``(output, y_true) -> scalar loss tensor``.

    Returns:
        Tuple ``(loss, steps)``. ``loss`` is the loss tensor from the last
        step, computed before that step's parameter update and still attached
        to the autograd graph; it is ``None`` when no step was run. ``steps``
        is the number of steps performed.
    """
    # The mode does not change between steps, so set it once up front.
    model.train()

    # Pre-bind loss so num_epochs <= 0 returns (None, 0) rather than raising
    # UnboundLocalError at the return statement.
    loss = None
    steps = 0
    for epoch in range(num_epochs):
        optimizer.zero_grad()  # clear gradients left over from the previous step
        output = model(x)  # forward pass
        loss = loss_function(output, y_true)
        loss.backward()  # backward pass
        optimizer.step()  # parameter update
        steps += 1
        # Report the loss every 10 epochs.
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

    return loss, steps

In [10]:
# Exercise the training function.

# Synthetic training data.
x = torch.randn(6, dmodel)  # input: 6 rows of width dmodel
y_true = torch.randn(6, dv)  # target: 6 rows of width dv

# Train the model for 100 epochs; the returned (loss, count) tuple is the cell's displayed value.
train(x, y_true, num_epochs=100, model=model, optimizer=optimizer, loss_function=loss_function)

Epoch [10/100], Loss: 10.5026
Epoch [20/100], Loss: 9.3150
Epoch [30/100], Loss: 8.1951
Epoch [40/100], Loss: 7.4871
Epoch [50/100], Loss: 7.1009
Epoch [60/100], Loss: 6.8334
Epoch [70/100], Loss: 6.6070
Epoch [80/100], Loss: 6.3987
Epoch [90/100], Loss: 6.2013
Epoch [100/100], Loss: 6.0118


(tensor(6.0118, grad_fn=<MseLossBackward0>), 100)

In [12]:
# Self-attention over batched input; otherwise the same computation as the plain version.

class batchselfattention(nn.Module):
    """Single-head scaled dot-product self-attention for batched input.

    Computes ``softmax(Q K^T / sqrt(dk)) V`` independently for every item in
    the batch, where ``Q = x @ wq``, ``K = x @ wk`` and ``V = x @ wv``.

    Args:
        dk: Width of the query and key projections.
        dv: Width of the value projection (and of the output).
        dmodel: Size of the last dimension of the input ``x``.
    """

    def __init__(self, dk, dv, dmodel):
        super(batchselfattention, self).__init__()
        self.wq = nn.Parameter(torch.randn(dmodel, dk), requires_grad=True)
        self.wk = nn.Parameter(torch.randn(dmodel, dk), requires_grad=True)
        self.wv = nn.Parameter(torch.randn(dmodel, dv), requires_grad=True)
        # Scores are divided by sqrt(dk) before the softmax.
        self.scales = math.sqrt(dk)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, dmodel)``. Because the
                projections broadcast over leading dimensions, inputs with
                other numbers of leading dimensions are accepted as well.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``dv``.
        """
        q = torch.matmul(x, self.wq)
        k = torch.matmul(x, self.wk)
        v = torch.matmul(x, self.wv)

        k_trans = k.transpose(-2, -1)

        # scores[..., i, j] is the scaled similarity of query i and key j.
        scores = torch.matmul(q, k_trans) / self.scales

        # Normalise over the key axis (the last one). With dim=1 a 3-D score
        # tensor was normalised over the query axis, so the weights for a
        # given query did not sum to 1.
        weights = torch.softmax(scores, dim=-1)

        output = torch.matmul(weights, v)

        return output

In [14]:
# Smoke-test the batched self-attention module.

# NOTE(review): this cell rebinds the globals dk, dv, dmodel and model that
# earlier cells also define, so re-running earlier cells afterwards will pick
# up these values.
dk = 768
dv = 768
dmodel = 768
batch_size = 10
N = 20  # number of rows per batch item (sequence length)

model = batchselfattention(dk, dv, dmodel)
x = torch.randn(batch_size, N, dmodel)
output = model(x)
print(np.shape(output))  # expected: (batch_size, N, dv)

torch.Size([10, 20, 768])
