In [4]:
import torch
import numpy as np

import os
import matplotlib.pyplot

In [119]:
class Attention(torch.nn.Module):

    def __init__(self, input_embedding_size: int, kq_embedding_size: int, v_embedding_size: int, masked: bool = False):
        super().__init__()
        # Input Shape=(T, S) 
        
        # Shape=(S, KQ)
        self.K_w = torch.nn.Linear(input_embedding_size, kq_embedding_size)
        # Shape=(S, KQ)
        self.Q_w = torch.nn.Linear(input_embedding_size, kq_embedding_size)

        # Shape=(S,V)
        self.V_w = torch.nn.Linear(input_embedding_size, v_embedding_size)
        self.V_scale = 1.0 / np.sqrt(v_embedding_size)

        # if true, hides future values (decoder), otherwise allows peeking forward
        # in time (encoder)
        self._masked = masked

    def forward(self, input_sequence):
        # (T, S) x (S, KQ) -> (T, KQ)
        queries = self.Q_w(input_sequence)
        # (T, S) x (S, KQ) -> (T, KQ)
        keys = self.K_w(input_sequence)

        # (T, KQ) X (KQ, T) -> (T, T)
        attention_matrix = self.V_scale * (queries @ torch.transpose(keys, 0, 1))

        # Block out future values for each item in the sequence.
        # -infinity goes to 0 in the softmax.
        if self._masked:
            upper_mask = torch.ones_like(attention_matrix).tril() == 0
            attention_matrix = attention_matrix.masked_fill(upper_mask, -np.inf)
        # Normalizes row-by-row. (T, T)
        normalized_attention_matrix = torch.softmax(attention_matrix, dim=1)
        print('weights', normalized_attention_matrix)

        # (T, S) x (S, V) -> (T, V)
        values =  self.V_w(input_sequence)

        # (T, T) x (T, V) = (T, V)
        return normalized_attention_matrix @ values

In [120]:
class MultiheadedAttention(torch.nn.Module):

    def __init__(self, n_heads: int, input_embedding_size: int, kq_embedding_size: int, v_embedding_size: int, masked:bool=False):
        super().__init__()
        # TODO: only final embedding size should be specified. we should divide this among the heads
        self.head_stack = [Attention(input_embedding_size, kq_embedding_size, v_embedding_size, masked=masked) for _ in range(n_heads)]

    def forward(self, input_sequence):
        return torch.hstack([head(input_sequence) for head in self.head_stack])

In [121]:
test_tensor = torch.tensor([[0.2, 0.8], [0.5, 0.5]])
test_tensor

tensor([[0.2000, 0.8000],
        [0.5000, 0.5000]])

In [122]:
att = Attention(input_embedding_size=2, kq_embedding_size=5, v_embedding_size=10, masked=True)
mha = MultiheadedAttention(n_heads =5, input_embedding_size=2, kq_embedding_size=5, v_embedding_size=10, masked=True)

In [123]:
att = Attention(input_embedding_size=2, kq_embedding_size=5, v_embedding_size=10, masked=False)
att.forward(test_tensor)

weights tensor([[0.5128, 0.4872],
        [0.5151, 0.4849]], grad_fn=<SoftmaxBackward0>)


tensor([[-0.8743,  0.0041,  0.5484,  0.2443,  0.5566, -0.8311, -0.2982, -0.5566,
         -0.7567,  0.0167],
        [-0.8741,  0.0047,  0.5481,  0.2437,  0.5573, -0.8311, -0.2980, -0.5570,
         -0.7573,  0.0169]], grad_fn=<MmBackward0>)

In [124]:
masked_att = Attention(input_embedding_size=2, kq_embedding_size=5, v_embedding_size=10, masked=True)
masked_att.forward(test_tensor)

weights tensor([[1.0000, 0.0000],
        [0.4947, 0.5053]], grad_fn=<SoftmaxBackward0>)


tensor([[ 0.1343,  0.7997, -0.9043,  0.3885,  0.5285,  0.4651,  0.1492,  0.3160,
         -0.3017, -0.0447],
        [ 0.1696,  0.7208, -0.7852,  0.3348,  0.4258,  0.5256,  0.0345,  0.3818,
         -0.2358, -0.0599]], grad_fn=<MmBackward0>)

In [125]:
mha.forward(test_tensor)

weights tensor([[1.0000, 0.0000],
        [0.5130, 0.4870]], grad_fn=<SoftmaxBackward0>)
weights tensor([[1.0000, 0.0000],
        [0.4910, 0.5090]], grad_fn=<SoftmaxBackward0>)
weights tensor([[1.0000, 0.0000],
        [0.4915, 0.5085]], grad_fn=<SoftmaxBackward0>)
weights tensor([[1.0000, 0.0000],
        [0.5191, 0.4809]], grad_fn=<SoftmaxBackward0>)
weights tensor([[1.0000, 0.0000],
        [0.5166, 0.4834]], grad_fn=<SoftmaxBackward0>)


tensor([[ 0.0493, -0.3194, -0.3123,  0.2131, -0.8150,  0.4920,  0.3484,  0.5023,
         -0.7059, -0.4539,  0.7817, -1.1859, -0.5948,  0.2744, -0.2611, -0.7955,
         -0.6004, -0.6219,  0.2608, -0.6604,  0.3387,  0.0361,  0.3525, -0.6725,
         -0.2645, -0.5994, -0.0470, -0.4691,  0.1977,  0.1626, -0.8820, -0.1713,
          0.1643, -0.4443, -0.0565,  0.8806, -0.8181, -0.4663,  0.0050,  0.9629,
         -0.8639, -0.6341,  0.5698,  0.2015,  0.7313,  0.3077,  0.5483, -0.4744,
          0.2041, -0.1124],
        [ 0.2301, -0.2138, -0.2308,  0.1769, -0.6883,  0.5825,  0.3498,  0.4796,
         -0.6806, -0.4117,  0.6675, -1.1405, -0.6076,  0.3759, -0.2764, -0.7715,
         -0.4457, -0.5670,  0.4318, -0.5113,  0.1845, -0.0296,  0.3962, -0.5674,
         -0.1871, -0.7022,  0.0090, -0.3089,  0.2506, -0.0073, -0.8790, -0.1680,
          0.1970, -0.3761, -0.2002,  0.8629, -0.7113, -0.4607,  0.0389,  0.9654,
         -0.7377, -0.4755,  0.4311,  0.1662,  0.8252,  0.3023,  0.4956, -0.4705,
