In [24]:
import torch
from torch import tensor
from math import sqrt
from torch.nn.functional import softmax

def par_attention(queries: tensor, keys: tensor, values: tensor, dim: int) -> tensor:
    # a batch of 2d matrixes is produced: 3d
    raw_weights = torch.bmm(queries, keys.transpose(1, 2))

    mask = torch.tril(torch.ones_like(raw_weights), diagonal=0)
    raw_weights = raw_weights.masked_fill(mask == 0, float('-inf'))
    # ^ still a batch of matrices
    # print(f"raw_weights.shape:{raw_weights.shape}\nraw_weights: {raw_weights}")

    scale_factor = sqrt(float(dim))
    scaled_weights = softmax(raw_weights / scale_factor, dim=2) 
    # ^ still same structure

    # I add a bonus dimension. why tf did I do this...?
    reshaped_scaled_weights = scaled_weights.view(scaled_weights.shape[0], scaled_weights.shape[1], scaled_weights.shape[2], 1)
    # I do the same thing to the original values 3d tensor, making that 4d too. But I add the extra dimension at the start
    reshaped_values = values.view(values.shape[0], values.shape[1], 1, values.shape[2])

    # The goal now is: for each value in each row of weights in scaled_weights, I multiply a value row by that scalar.
    print(f"reshaped_scaled_weights:{reshaped_scaled_weights.shape}\nreshaped_values:{reshaped_values.shape}")
    scaled_values = reshaped_scaled_weights * reshaped_values

    contextualized_values = torch.sum(scaled_values, 2)
    return contextualized_values

In [27]:
Q = tensor([[1, 3], [1, 1]]).float().unsqueeze(0)
K = tensor([[2, 1], [3, 5]]).float().unsqueeze(0)
V = tensor([[0, 5], [3, 2]]).float().unsqueeze(0)
d = 2
output = par_attention(Q, K, V, d)
output[0]

reshaped_scaled_weights:torch.Size([1, 2, 2, 1])
reshaped_values:torch.Size([1, 2, 1, 2])


tensor([[0., 5.],
        [3., 2.]])

In [31]:
Q_batch = torch.zeros([3, 2, 2]).float()
Q_batch[0] = Q
Q_batch[1] = Q
Q_batch[2] = Q
K_batch = torch.zeros([3, 2, 2]).float()
K_batch[0] = K
K_batch[1] = K
K_batch[2] = K
V_batch = torch.zeros([3, 2, 2]).float()
V_batch[0] = V
V_batch[1] = V
V_batch[2] = V
d = 2

par_attention(Q_batch, K_batch, V_batch, d)

reshaped_scaled_weights:torch.Size([3, 2, 2, 1])
reshaped_values:torch.Size([3, 2, 1, 2])


tensor([[[0., 5.],
         [3., 2.]],

        [[0., 5.],
         [3., 2.]],

        [[0., 5.],
         [3., 2.]]])