In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import numpy as np
import matplotlib.pyplot as plt

$$\Large
\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V
$$

In [16]:
class Attention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    The module holds no learnable parameters; it only combines the tensors
    passed to ``forward``.
    """

    def forward(self, q, k, v, mask=None):
        """Compute attention outputs and attention weights.

        Args:
            q: Queries of shape [..., Lq, d_k].
            k: Keys of shape [..., Lk, d_k].
            v: Values of shape [..., Lk, d_v].
            mask: Optional tensor broadcastable to [..., Lq, Lk]. Positions
                where the mask is zero / ``False`` are excluded from the
                softmax and receive a weight of exactly 0. A query row whose
                keys are all masked yields NaN weights.

        Returns:
            Tuple ``(out, attn)`` where ``out`` has shape [..., Lq, d_v] and
            ``attn`` has shape [..., Lq, Lk], each row of ``attn`` summing
            to 1.
        """
        scale = q.size(-1) ** -0.5
        # Out-of-place scaling: an in-place ``q *= scale`` would modify the
        # caller's tensor (so repeated calls give different results) and
        # raises for leaf tensors that require grad.
        q = q * scale

        scores = torch.matmul(q, k.transpose(-1, -2))  # [..., Lq, Lk]
        if mask is not None:
            # -inf scores become exactly zero probability after softmax.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = scores.softmax(dim=-1)

        out = torch.matmul(attn, v)  # [..., Lq, d_v]
        return out, attn

In [17]:
# Instantiate the attention module; it takes no constructor arguments.
atn = Attention()

In [23]:
# Toy example: four keys with three features each (the last two keys are
# identical), one two-feature value per key, and a single query.
# Indexing with [None, None] prepends two singleton dimensions.
k = torch.tensor([[10., 0., 0.],
                  [0., 10., 0.],
                  [0., 0., 10.],
                  [0., 0., 10.]])[None, None]

v = torch.tensor([[1., 0.],
                  [10., 0.],
                  [100., 5.],
                  [1000., 6.]])[None, None]

# The query matches the second key exactly.
q = torch.tensor([[0., 10., 0.]])[None, None]

In [24]:
# Display the shapes of the query, key and value tensors.
q.shape, k.shape, v.shape

(torch.Size([1, 1, 1, 3]), torch.Size([1, 1, 4, 3]), torch.Size([1, 1, 4, 2]))

In [25]:
# Run attention: `o` is the attended output, `a` the attention weights.
o, a = atn(q,k,v)

In [26]:
# Display the shapes of the attention weights and of the output.
a.shape, o.shape

(torch.Size([1, 1, 1, 4]), torch.Size([1, 1, 1, 2]))

In [27]:
# Round to whole numbers so the weights and the output are easy to read.
a.round(), o.round()

(tensor([[[[0., 1., 0., 0.]]]]), tensor([[[[10.,  0.]]]]))