In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [43]:
class MaskedSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with an optional mask.

    Computes ``softmax(Q @ K^T / sqrt(d_k)) @ V`` where Q, K and V are
    bias-free linear projections of the same token embeddings.

    Args:
        d_model: Number of features per token; each projection maps
            ``d_model`` inputs to ``d_model`` outputs.
        row_dim: Dimension swapped with ``col_dim`` when transposing K.
        col_dim: Feature dimension. It is swapped with ``row_dim`` to
            transpose K, its size is the ``d_k`` used for scaling, and the
            softmax is taken along it.

    With the defaults (``row_dim=0``, ``col_dim=1``) the input is handled as a
    2-D ``(tokens, d_model)`` matrix.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_embeddings, mask=None):
        """Return the attention output for ``token_embeddings``.

        Args:
            token_embeddings: Floating-point tensor whose feature dimension
                has size ``d_model``.
            mask: Optional boolean tensor broadcastable to the similarity
                matrix. Positions where it is ``True`` are hidden from the
                softmax. ``None`` disables masking.

        Returns:
            The attention-weighted combination of the value vectors, one row
            per query token.
        """
        Q = self.W_q(token_embeddings)
        K = self.W_k(token_embeddings)
        V = self.W_v(token_embeddings)

        sims = torch.matmul(Q, K.transpose(dim0=self.row_dim, dim1=self.col_dim))
        # Divide by a plain Python float: the scale then follows the dtype of
        # `sims` instead of being rounded to float32 by an intermediate tensor.
        scaled_sims = sims / (K.size(self.col_dim) ** 0.5)
        if mask is not None:
            # -1e9 is not representable in float16 and would make masked_fill
            # raise; clamp it to the most negative finite value of the dtype.
            # For float32/float64 this is still exactly -1e9.
            fill_value = max(-1e9, torch.finfo(scaled_sims.dtype).min)
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=fill_value)
        attention_percent = F.softmax(scaled_sims, dim=self.col_dim)
        attention_scores = torch.matmul(attention_percent, V)

        return attention_scores


In [44]:
# Toy input: three tokens, each described by two embedding values.
ecodings_matrix = torch.tensor(
    [
        [1.16, 0.23],
        [0.57, 1.36],
        [4.41, -2.16],
    ]
)

# Seed before constructing the layer so its randomly initialised weights
# are the same on every run.
torch.manual_seed(42)
maskedSelfAttention = MaskedSelfAttention(d_model=2, row_dim=0, col_dim=1)

# Causal mask: True strictly above the diagonal, i.e. at the positions
# MaskedSelfAttention hides from the softmax.
mask = torch.ones(3, 3).tril() == 0
mask

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [45]:
# Run masked self-attention on the toy embeddings using the causal mask.
maskedSelfAttention(ecodings_matrix, mask)

tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)

In [None]:
# Show the W_q matrix: the query layer's weight, transposed so that
# Q = ecodings_matrix @ W_q.
maskedSelfAttention.W_q.weight.transpose(0,1)

tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)

In [None]:
# Show the W_k matrix: the key layer's weight, transposed so that
# K = ecodings_matrix @ W_k.
maskedSelfAttention.W_k.weight.transpose(0,1)

tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)

In [48]:
# Show the W_v matrix: the value layer's weight, transposed so that
# V = ecodings_matrix @ W_v.
maskedSelfAttention.W_v.weight.transpose(0,1)

tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)

In [None]:
# Show the Q matrix: the query projection of every token embedding.
maskedSelfAttention.W_q(ecodings_matrix)

tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)

In [50]:
# Show the K matrix: the key projection of every token embedding.
maskedSelfAttention.W_k(ecodings_matrix)

tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)

In [None]:
# Verify the computation by hand: Q = ecodings_matrix @ W_q, where W_q is the
# transposed weight of the query layer.
torch.matmul(ecodings_matrix, maskedSelfAttention.W_q.weight.transpose(0,1))

tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)