In [None]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# 打印当前选择的GPU
print(f"Using GPU:{os.environ['CUDA_VISIBLE_DEVICES']}")

Using GPU:2


In [3]:
import torch
from torch import nn
from torch.functional import F
import math
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

Running on device: cuda


In [None]:
X = torch.randn(128, 64, 512).to(device) # Batch, Token, Dimension
# 128个句子，64个词元，512维的词向量
print(X.shape)

torch.Size([128, 64, 512])


In [5]:
d_model = 512
n_heads = 8

In [None]:
class multi_head_attention(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        
        self.n_head = n_head
        self.d_model = d_model
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, q, k, v):
        # q, k, v: [batch_size, seq_len, d_model]
        batch, token, dimensiton = q.shape # 128, 64, 512
        n_d = self.d_model // self.n_head # 512 // 8 = 64
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
        
        q = q.view(batch, token, self.n_head, n_d).permute(0,2,1,3) #[batch, token, head, d] to [batch, head, token, d]
        k = k.view(batch, token, self.n_head, n_d).permute(0,2,1,3)
        v = v.view(batch, token, self.n_head, n_d).permute(0,2,1,3)
        
        #[a, b] @ [b, c] → [a, c]
        # score = q @ k.transpose(2,3) / math.sqrt(n_d)
        score = q @ k.transpose(-2,-1) / math.sqrt(n_d) # [batch, head, token, d_k] * [batch, head, d_k, token]
        mask = torch.tril(torch.ones(token,token,dtype=torch.bool, device=q.device))
        # score = score.masked_fill(mask == 0, -1e9)
        score = score.masked_fill(mask == 0, float('-inf'))
        score = self.softmax(score) @ v
        
        score = score.permute(0,2,1,3).contiguous().view(batch, token, dimensiton)
        
        output = self.w_combine(score)
        return output
        
attention = multi_head_attention(d_model, n_heads).to(device)
output = attention(X, X, X)
print(output, output.shape)

tensor([[[ 8.6273e-02,  4.7680e-02,  1.2444e-01,  ...,  1.6692e-01,
          -3.4143e-01,  1.9715e-01],
         [ 4.3156e-01, -2.4046e-01, -7.2708e-02,  ..., -1.4109e-01,
          -2.5770e-01,  1.9147e-01],
         [ 3.9320e-01, -1.0397e-01, -1.0461e-01,  ...,  7.1226e-02,
          -2.4283e-01,  1.0542e-01],
         ...,
         [-2.3892e-02, -6.2497e-02,  3.2211e-03,  ..., -5.3380e-03,
          -4.2989e-02,  2.6861e-02],
         [-1.1776e-02, -8.4258e-02,  2.9447e-02,  ...,  1.0789e-02,
          -4.2125e-02, -3.4765e-04],
         [-7.7546e-03, -8.0037e-02, -3.0760e-02,  ...,  1.8437e-02,
          -6.0307e-02,  1.4884e-02]],

        [[ 2.4255e-01,  1.8704e-01, -3.8051e-01,  ...,  3.9135e-01,
          -1.1913e-01,  2.1020e-01],
         [ 1.0736e-01,  1.4980e-01, -3.5683e-01,  ..., -1.0606e-01,
          -2.5445e-01,  3.0958e-01],
         [ 5.3697e-02, -2.9903e-02, -2.4500e-01,  ...,  7.5165e-02,
          -1.2057e-01,  3.3893e-01],
         ...,
         [ 2.1646e-02, -9

self-attention 就是 input 一排Vector 然后 output 一排 Vector