In [1]:
import torch

In [2]:
# These are the inputs to the attention mechanism, which represents the respectve vectors of each word in the sequence.
inputs = torch.tensor([[ 0.3374, -0.1778, -0.1690],
        [ 0.9178,  1.5810,  1.3010],
        [ 1.2753, -0.2010, -0.1606],
        [-0.4015,  0.9666, -1.1481],
        [-1.1589,  0.3255, -0.6315],
        [-2.8400, -0.7849, -1.4096]])

In [3]:
atten_weights = torch.tensor([[1.6860e-01, 1.4012e-01, 1.6816e-01, 1.6587e-01, 1.6843e-01, 1.8881e-01],
        [9.2893e-04, 9.9621e-01, 1.2864e-03, 9.9685e-04, 5.7419e-04, 5.8191e-06],
        [1.3251e-01, 4.2602e-01, 1.4244e-01, 1.2845e-01, 1.1736e-01, 5.3208e-02],
        [1.5271e-01, 4.4836e-02, 1.4663e-01, 1.4493e-01, 1.5998e-01, 3.5092e-01],
        [9.9106e-02, 8.7224e-03, 8.8134e-02, 9.7618e-02, 1.1824e-01, 5.8818e-01],
        [1.8241e-03, 3.4638e-07, 1.1996e-03, 1.7521e-03, 3.4420e-03, 9.9178e-01]])

In [4]:
context_lenght = 6
simple_mask = torch.tril(torch.ones(context_lenght, context_lenght))
print(simple_mask)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [5]:
atten_weights_masked = atten_weights * simple_mask
print(atten_weights_masked)

tensor([[1.6860e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.2893e-04, 9.9621e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.3251e-01, 4.2602e-01, 1.4244e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.5271e-01, 4.4836e-02, 1.4663e-01, 1.4493e-01, 0.0000e+00, 0.0000e+00],
        [9.9106e-02, 8.7224e-03, 8.8134e-02, 9.7618e-02, 1.1824e-01, 0.0000e+00],
        [1.8241e-03, 3.4638e-07, 1.1996e-03, 1.7521e-03, 3.4420e-03, 9.9178e-01]])


In [6]:
row_sums = atten_weights_masked.sum(dim=1)
atten_weights_normalized = atten_weights_masked / row_sums
print(atten_weights_normalized)

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [5.5097e-03, 9.9907e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [7.8594e-01, 4.2724e-01, 2.0320e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.0575e-01, 4.4965e-02, 2.0918e-01, 2.9632e-01, 0.0000e+00, 0.0000e+00],
        [5.8782e-01, 8.7474e-03, 1.2573e-01, 1.9958e-01, 2.8712e-01, 0.0000e+00],
        [1.0819e-02, 3.4737e-07, 1.7113e-03, 3.5822e-03, 8.3580e-03, 9.9178e-01]])


In [7]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


In [8]:
torch.manual_seed(123)
print(dropout(atten_weights_normalized))

tensor([[2.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 1.9981e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 4.0641e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.8115e+00, 8.9929e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.1756e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 6.9475e-07, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]])


In [9]:
batch = torch.stack((inputs, inputs))
print(batch.shape)

torch.Size([2, 6, 3])


In [25]:
import torch.nn as nn
class CausalAttention(nn.Module):
    def __init__(self, dim_in, dim_out, context_lenght, dropout, bias = False):
        super().__init__()
        self.dim_out = dim_out
        self.w_query = nn.Linear(dim_in, dim_out, bias= bias)
        self.w_key = nn.Linear(dim_in, dim_out, bias= bias)
        self.w_value = nn.Linear(dim_in, dim_out, bias= bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_lenght, context_lenght), diagonal= 1))
        
    def forward(self, x):
        batch, num_token, d_in = x.shape
        
        querys = self.w_query(x)
        keys = self.w_key(x)
        values = self.w_value(x)
        
        atten_scores = querys @ keys.transpose(1, 2)
        atten_weights = torch.softmax(atten_scores, dim= -1)
        
        atten_weights = self.dropout(atten_weights)

        context_vecs = atten_weights @ values
        
        return context_vecs

In [26]:
torch.manual_seed(123)
context_lenght = batch.shape[1]
causal_attention = CausalAttention(dim_in=3, dim_out=6, context_lenght=context_lenght,dropout=0.5, bias=0.0)
context_vecs = causal_attention(batch)
print(context_vecs)

tensor([[[ 2.3359e-01,  6.6592e-02,  5.3487e-02, -3.3602e-02, -1.7786e-01,
          -8.6936e-02],
         [ 9.1773e-02,  7.4496e-02,  1.5637e-01, -5.3266e-02,  2.3278e-02,
          -3.8601e-02],
         [ 1.1449e-01,  2.2122e-02,  4.3281e-01, -1.0257e-01,  9.1141e-02,
          -1.8214e-03],
         [ 4.2301e-02,  3.5649e-02,  4.6898e-02, -1.9366e-02,  1.1846e-04,
          -2.0548e-02],
         [ 9.7960e-01,  1.5314e-01,  1.1931e-02, -6.6993e-02, -9.2484e-01,
          -3.4012e-01],
         [ 1.5192e-01,  8.5207e-02, -7.4075e-02, -7.3957e-03, -1.3794e-01,
          -8.1052e-02]],

        [[ 5.2393e-01,  1.1631e-01,  4.2029e-01, -1.3341e-01, -2.8370e-01,
          -1.5607e-01],
         [ 1.4661e-01,  6.3644e-03,  1.0459e-01, -2.8740e-02, -1.0239e-01,
          -3.5758e-02],
         [-2.2798e-02,  2.8899e-02, -3.6739e-02,  2.3480e-03,  2.6085e-02,
          -6.7021e-03],
         [ 3.1075e+00,  6.4244e-01, -1.8788e-01, -1.9788e-01, -2.9344e+00,
          -1.1540e+00],
        