In [1]:
import torch

In [68]:
import torch.nn as nn
import torch.nn.functional as F

In [106]:
# Seed the global RNG first so that a fresh Restart & Run All reproduces the
# same random input and the same randomly initialised projection weights.
SEED = 0
torch.manual_seed(SEED)

# Toy sizes, small enough to inspect every tensor by eye.
D_MODEL = 4  # size of the last (embedding) dimension
SEQ_LEN = 2  # size of the sequence dimension
N = 2  # batch size
N_HEADS = 2  # not referenced by any of the cells shown here

In [107]:
# Random stand-in for the embeddings that would come out of positional encoding.
x = torch.rand((N, SEQ_LEN, D_MODEL))

In [108]:
# Inspect the mock input batch.
x

tensor([[[0.5666, 0.6814, 0.1057, 0.7824],
         [0.3482, 0.1262, 0.6898, 0.4915]],

        [[0.7603, 0.7400, 0.3171, 0.7637],
         [0.5961, 0.7009, 0.8971, 0.6194]]])

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

In [109]:
# We need the Q, K and V matrices: three separate learned projections of the same input.
class QKVMatrices(nn.Module):
    """Project an input tensor into query, key and value tensors.

    Holds three independent ``nn.Linear(embed_dim, embed_dim)`` layers, one
    per projection, exposed as ``w_q``, ``w_k`` and ``w_v``.
    """

    def __init__(self, embed_dim: int = D_MODEL):
        """Build the three projections.

        Args:
            embed_dim: Size of the last dimension of the input; every
                projection maps it to an output of the same size.
        """
        super().__init__()
        # Creation order (query, key, value) is kept as-is, since it decides
        # the order in which the layers consume the RNG when initialised.
        self.w_q = nn.Linear(embed_dim, embed_dim)
        self.w_k = nn.Linear(embed_dim, embed_dim)
        self.w_v = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Return the tuple ``(q, k, v)``: each projection applied to ``x``."""
        return self.w_q(x), self.w_k(x), self.w_v(x)

In [110]:
# Build the projection module with its default embed_dim and apply it to x
# straight away; the module instance itself is not kept.
q, k, v = QKVMatrices()(x)

In [111]:
# Each projection maps embed_dim -> embed_dim, so q keeps the shape of x.
q.shape

torch.Size([2, 2, 4])

In [112]:
# K^T: swap only the last two axes so the leading batch axis stays in place.
k_t = k.transpose(-2, -1)

In [113]:
# The batched product Q K^T written three ways; the printed tensors carry the
# same values and differ only in the grad_fn attached to them.
print('EXISTEM 3 MANEIRAS:')
# 1) the @ operator
print('Usando @: \n', q @ k_t)
# 2) torch.bmm on the same two tensors
print('\n Usando bmm: \n', torch.bmm(q, k_t))
# 3) einsum: b is kept as the batch index, d is summed over, giving (b, s, e)
print('\n Usando einsum: \n', torch.einsum("bsd, bde -> bse", q, k_t))

EXISTEM 3 MANEIRAS:
Usando @: 
 tensor([[[-0.2340, -0.3028],
         [-0.3030, -0.1154]],

        [[-0.2700, -0.4067],
         [-0.4408, -0.4764]]], grad_fn=<UnsafeViewBackward0>)

 Usando bmm: 
 tensor([[[-0.2340, -0.3028],
         [-0.3030, -0.1154]],

        [[-0.2700, -0.4067],
         [-0.4408, -0.4764]]], grad_fn=<BmmBackward0>)

 Usando einsum: 
 tensor([[[-0.2340, -0.3028],
         [-0.3030, -0.1154]],

        [[-0.2700, -0.4067],
         [-0.4408, -0.4764]]], grad_fn=<ViewBackward0>)


In [118]:
# Inline scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V, with
# d_k = D_MODEL. The square root is written as `** 0.5` because `math` is never
# imported in this notebook, so `math.sqrt` raises NameError on a fresh kernel.
# NOTE: the name `attention` is rebound to a function in a later cell.
attention = F.softmax((q @ k_t) / D_MODEL**0.5, dim=-1) @ v

In [119]:
# Show the result of the inline attention computation.
attention

tensor([[[ 0.1077,  0.2344, -0.8078,  0.6266],
         [ 0.0956,  0.2222, -0.8025,  0.6164]],

        [[ 0.1622,  0.4760, -1.0948,  0.9112],
         [ 0.1611,  0.4747, -1.0960,  0.9107]]], grad_fn=<UnsafeViewBackward0>)

In [125]:
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        q: Query tensor; its last dimension is taken as ``d_k``.
        k: Key tensor; its last two dimensions are swapped before the
            product with ``q``, so its last dimension must match ``q``'s.
        v: Value tensor, multiplied by the attention weights.

    Returns:
        The attention-weighted combination of ``v``; the softmax is taken
        over the last dimension of the score matrix.
    """
    d_k = q.shape[-1]
    # `** 0.5` rather than math.sqrt: the notebook never imports `math`, so the
    # original raised NameError when run on a fresh kernel.
    scores = (q @ k.transpose(-1, -2)) / d_k**0.5
    weights = F.softmax(scores, dim=-1)
    return weights @ v

In [126]:
# Same q, k, v as the inline computation; the output below matches it.
attention(q, k, v)

tensor([[[ 0.1077,  0.2344, -0.8078,  0.6266],
         [ 0.0956,  0.2222, -0.8025,  0.6164]],

        [[ 0.1622,  0.4760, -1.0948,  0.9112],
         [ 0.1611,  0.4747, -1.0960,  0.9107]]], grad_fn=<UnsafeViewBackward0>)

In [127]:
# Hand-written test fixtures: each literal below is 4 x 3 x 2 (four batches of
# three rows with two values each).

# Queries
q_teste = torch.tensor(
    [
        [[-0.9461, -0.8619], [-0.4798, 1.2657], [-1.1975, -0.5603]],
        [[1.8781, -0.1852], [-0.8300, -0.4783], [1.2326, 2.0119]],
        [[0.0388, 0.0122], [0.2351, 1.8343], [1.1487, 0.3338]],
        [[-0.2851, -2.6427], [-0.5660, 0.2364], [-2.6867, 0.7265]],
    ]
)
# Keys
k_teste = torch.tensor(
    [
        [[-1.4583, 1.1762], [-1.7270, -1.0825], [0.1402, 0.9177]],
        [[0.3978, -1.4284], [0.7356, 1.3742], [-0.7733, -1.6810]],
        [[0.7199, 0.1448], [-1.0729, -0.9027], [0.1009, -0.0685]],
        [[-0.5599, 1.3669], [-0.6013, 2.0099], [-0.8953, -0.8493]],
    ]
)
# Values
v_teste = torch.tensor(
    [
        [[0.3650, -0.3458], [0.7181, -0.0432], [1.7011, 0.7965]],
        [[-0.4509, -1.3015], [-1.3350, -0.0976], [-0.0171, 0.1326]],
        [[0.4331, 0.2897], [0.3835, 0.5914], [-0.2202, 0.4520]],
        [[-1.2591, 0.1445], [2.1758, -1.5864], [-0.4733, 0.0633]],
    ]
)

In [131]:
# Run the attention function on the hand-written fixtures.
retorno = attention(q=q_teste, k=k_teste, v=v_teste)

In [132]:
# Answer key ("gabarito") that `retorno` is checked against, written to four
# decimal places.
gabarito = torch.tensor(
    [
        [[0.7249, -0.0375], [0.7806, 0.0096], [0.7016, -0.0575]],
        [[-0.8263, -0.5985], [-0.2619, -0.3001], [-1.3185, -0.1132]],
        [[0.1979, 0.4415], [0.1897, 0.3796], [0.2141, 0.3741]],
        [[-0.4730, 0.0572], [0.2842, -0.5481], [0.5559, -0.7088]],
    ]
)

In [133]:
# Check the shape before the values so a shape mismatch gets its own message.
assert gabarito.shape == retorno.shape, "o shape da resposta esta errado"

In [134]:
# Compare within an absolute tolerance instead of exact equality on rounded
# floats: rounding both sides and testing `==` can fail spuriously when a value
# sits next to a rounding boundary. The answer key has four decimal places,
# hence atol=1e-4; every result the previous check accepted still passes.
assert torch.allclose(
    retorno, gabarito, rtol=0, atol=1e-4
), "os valores da resposta estao errados"