In [2]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

# 1.简化版本

In [3]:
class SelfAttentionV1(nn.Module):
    """Minimal single-head scaled dot-product self-attention.

    Queries, keys and values each come from their own ``nn.Linear`` layer.
    The attention weights are ``softmax(Q K^T / sqrt(hidden_dim))`` taken over
    the last axis, and the result is those weights applied to the values.

    Args:
        hidden_dim: Size of the last dimension of the input; the output keeps
            the same size.
    """

    def __init__(self, hidden_dim: int = 728) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim

        # Initialize the three linear projection layers. They are created in
        # query/key/value order so seeded initialisation stays reproducible.
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X):
        """Apply self-attention to ``X``.

        Args:
            X: Tensor of shape [batch_size, seq_len, hidden_dim].

        Returns:
            Tensor with the same shape as ``X``.

        Side effects:
            Prints the attention weight tensor on every call.
        """
        queries, keys, values = self.query(X), self.key(X), self.value(X)

        # Pairwise query/key similarity, scaled by sqrt(hidden_dim) before the
        # softmax normalises each row.
        scores = torch.matmul(queries, keys.transpose(-2, -1))
        scaled_scores = scores / math.sqrt(self.hidden_dim)
        att_weight = F.softmax(scaled_scores, dim=-1)

        print(f'att_weight is {att_weight}')

        return torch.matmul(att_weight, values)

# Smoke test: a 3-dim attention layer applied to a random input of shape
# (3, 2, 3); the cell displays the returned tensor.
att_layer = SelfAttentionV1(hidden_dim=3)
att_layer(torch.rand((3, 2, 3)))






att_weight is tensor([[[0.5508, 0.4492],
         [0.5309, 0.4691]],

        [[0.5566, 0.4434],
         [0.5220, 0.4780]],

        [[0.5153, 0.4847],
         [0.5186, 0.4814]]], grad_fn=<SoftmaxBackward0>)


tensor([[[-0.0848, -0.2373,  0.7479],
         [-0.0890, -0.2324,  0.7399]],

        [[-0.1690, -0.2583,  0.7524],
         [-0.1685, -0.2413,  0.7298]],

        [[-0.1749, -0.2377,  0.6996],
         [-0.1733, -0.2369,  0.6989]]], grad_fn=<UnsafeViewBackward0>)

# 2.效率优化

In [4]:
class SelfAttentionV2(nn.Module):
    """Single-head self-attention using one fused Q/K/V projection.

    A single ``nn.Linear(hidden_dim, 3 * hidden_dim)`` replaces three separate
    projections. Its output is split along the last axis into Q, K and V.

    Args:
        hidden_dim: Size of the last dimension of the input; the output keeps
            the same size.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.dim = hidden_dim
        # One matmul yields Q, K and V concatenated on the last axis.
        self.proj = nn.Linear(hidden_dim, 3 * hidden_dim)

    def forward(self, X):
        """Apply self-attention to ``X``.

        Args:
            X: Tensor whose last dimension equals ``hidden_dim``.

        Returns:
            Tensor with the same shape as ``X``.

        Side effects:
            Prints the attention weight tensor on every call.
        """
        # Slice the fused projection into three chunks of width ``self.dim``.
        Q, K, V = self.proj(X).split(self.dim, dim=-1)

        scaled_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dim)
        weight = torch.softmax(scaled_scores, dim=-1)
        print(weight)
        return torch.matmul(weight, V)

# Random input of shape (3, 2, 2). It is drawn before the module is built so
# the RNG is consumed in the same order as before.
X = torch.rand((3, 2, 2))

# The layer width is taken from the input's last dimension (2).
self_att = SelfAttentionV2(hidden_dim=X.size(-1))
self_att(X)






tensor([[[0.5157, 0.4843],
         [0.5253, 0.4747]],

        [[0.4877, 0.5123],
         [0.4911, 0.5089]],

        [[0.4657, 0.5343],
         [0.4710, 0.5290]]], grad_fn=<SoftmaxBackward0>)


tensor([[[-0.8945,  0.4268],
         [-0.8974,  0.4279]],

        [[-0.6319,  0.0989],
         [-0.6314,  0.0993]],

        [[-0.5164,  0.0375],
         [-0.5142,  0.0361]]], grad_fn=<UnsafeViewBackward0>)

# 3.加入一些细节

In [5]:
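# Details this version is meant to cover:
# 1. Where dropout goes (on the attention weights, after the softmax)
# 2. attention_mask
# 3. Output projection matrix (NOTE: not implemented in SelfAttentionV3 below)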


class SelfAttentionV3(nn.Module):
    """Single-head scaled dot-product self-attention with masking and dropout.

    A fused linear layer produces Q, K and V in one matmul. The scores
    ``Q K^T / sqrt(hidden_dim)`` are optionally masked, normalised with a
    softmax over the last axis, passed through dropout, and then applied to V.
    No output projection is applied.

    Args:
        hidden_dim: Size of the last dimension of the input; the output keeps
            the same size.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        self.dim = hidden_dim
        # One matmul yields Q, K and V concatenated on the last axis.
        self.proj = nn.Linear(hidden_dim, 3 * hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X, att_mask=None):
        """Apply masked self-attention to ``X``.

        Args:
            X: Tensor whose last dimension equals ``hidden_dim``, e.g.
                ``(batch, seq_len, hidden_dim)``.
            att_mask: Optional tensor broadcastable to the score shape
                ``(..., seq_len, seq_len)``. Positions where it equals 0 get
                zero attention weight. If every position in a row is masked,
                the softmax for that row yields NaN.

        Returns:
            Tensor with the same shape as ``X``.
        """
        QKV = self.proj(X)
        Q, K, V = torch.split(QKV, self.dim, dim=-1)

        att_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
        if att_mask is not None:
            # ``masked_fill`` is out-of-place, so its result must be assigned
            # back; otherwise the mask is silently ignored.
            att_weight = att_weight.masked_fill(att_mask == 0, float('-inf'))
        att_weight = torch.softmax(att_weight, dim=-1)
        # Dropout is applied to the normalised weights, before mixing V.
        att_weight = self.dropout(att_weight)
        output = att_weight @ V

        return output

# Random input of shape (3, 4, 2).
X = torch.rand((3, 4, 2))
# One 0/1 row per sequence; 0 marks key positions intended to be masked out.
b = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 0, 0, 0]])
# Copy each row across the 4 query positions: (3, 4) -> (3, 1, 4) -> (3, 4, 4).
mask = b[:, None, :].repeat(1, 4, 1)
net = SelfAttentionV3(hidden_dim=2)
net(X, mask).shape
        


torch.Size([3, 4, 2])

# 4.面试写法

In [11]:
class SelfAttentionV4(nn.Module):
    """Single-head self-attention with separate Q/K/V layers, mask and dropout.

    Args:
        dim: Size of the last dimension of the input; the output keeps the
            same size.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dim: int = 2, dropout: float = 0.1):
        super().__init__()
        self.dim = dim
        # Created in query/key/value order so seeded initialisation stays
        # reproducible.
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X, mask=None):
        """Apply (optionally masked) self-attention to ``X``.

        Args:
            X: Tensor whose last dimension equals ``dim``.
            mask: Optional tensor broadcastable to the score shape; positions
                where it equals 0 are filled with ``-inf`` before the softmax.

        Returns:
            Tensor with the same shape as ``X``.

        Side effects:
            Prints the post-softmax, pre-dropout attention weights.
        """
        q, k, v = self.query(X), self.key(X), self.value(X)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dim)
        if mask is not None:
            # Out-of-place fill: masked positions end up with zero weight.
            scores = scores.masked_fill(mask == 0, float('-inf'))

        att_weight = torch.softmax(scores, dim=-1)
        print(f'att_weight is {att_weight}')

        # Dropout acts on the normalised weights, before they mix the values.
        att_weight = self.dropout(att_weight)
        return torch.matmul(att_weight, v)
    
# Random input of shape (3, 4, 2).
X = torch.rand((3, 4, 2))
# One 0/1 row per sequence; the layer masks key positions where the value is 0.
b = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 0, 0, 0]])
# Copy each row across the 4 query positions: (3, 4) -> (3, 1, 4) -> (3, 4, 4).
mask = b[:, None, :].repeat(1, 4, 1)
net = SelfAttentionV4(dim=2)
net(X, mask)


att_weight is tensor([[[0.2955, 0.3588, 0.3457, 0.0000],
         [0.2950, 0.3591, 0.3459, 0.0000],
         [0.3020, 0.3550, 0.3430, 0.0000],
         [0.2946, 0.3593, 0.3461, 0.0000]],

        [[0.4963, 0.5037, 0.0000, 0.0000],
         [0.5144, 0.4856, 0.0000, 0.0000],
         [0.5009, 0.4991, 0.0000, 0.0000],
         [0.5005, 0.4995, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)


tensor([[[0.0780, 0.0666],
         [0.1386, 0.1928],
         [0.3819, 0.4381],
         [0.3831, 0.4403]],

        [[0.5325, 0.5278],
         [0.5404, 0.5322],
         [0.5345, 0.5289],
         [0.5343, 0.5288]],

        [[0.5542, 0.5316],
         [0.5542, 0.5316],
         [0.0000, 0.0000],
         [0.5542, 0.5316]]], grad_fn=<UnsafeViewBackward0>)

# 完结