In [3]:
import torch

In [19]:
import torch.nn as nn

class SelfAttention1(nn.Module):
    def __init__(self , d_in , d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in , d_out))
        self.W_key = nn.Parameter(torch.rand(d_in , d_out))
        self.W_value = nn.Parameter(torch.rand(d_in , d_out))

    def forward(self , input):
        query = input@self.W_query   
        key = input@self.W_key
        values = input@self.W_value

        attention_score = query@key
        attention_weights = torch.softmax(attention_score/key.shape[-1]**0.5 , dim=-1) 

        context_vector = attention_weights@values

        return context_vector 



       


In [20]:
d_in = 6
d_out =3

torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in,d_out) , requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in,d_out) , requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in,d_out) , requires_grad=False)

In [21]:
inputs = torch.tensor(
    [[0.43 , 0.15 , 0.89], #Your
    [0.55 , 0.87 , 0.66], #journey
    [0.57 , 0.85 , 0.64], #starts
    [0.22 , 0.58 , 0.33], #with
    [0.77 , 0.25, 0.10], #one
    [0.05 , 0.80, 0.55]]#step
)

In [24]:
sa = SelfAttention1(3 , 3)
queries = W_query@inputs.T
keys = W_key@inputs.T

attention_score = queries@keys

attention_weights = torch.softmax(attention_score/keys.shape[-1]**0.5 , dim=-1)

In [27]:
context_length = attention_score.shape[0]
mask_sample = torch.tril(torch.ones(context_length , context_length))
print(mask_sample)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [28]:
mask_sample = attention_score*mask_sample
print(mask_sample)

tensor([[2.8975, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [4.2748, 6.4082, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.1340, 1.7285, 1.7092, 0.0000, 0.0000, 0.0000],
        [4.5643, 6.9088, 6.8318, 3.8341, 0.0000, 0.0000],
        [1.6228, 2.5020, 2.4739, 1.3939, 1.2721, 0.0000],
        [3.6002, 5.4342, 5.3715, 3.0204, 2.7298, 3.8237]])


In [29]:
row_sums = mask_sample.sum(dim=1 , keepdim=True)
mask_sample_norm = mask_sample / row_sums
print(mask_sample_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4001, 0.5998, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2480, 0.3781, 0.3739, 0.0000, 0.0000, 0.0000],
        [0.2062, 0.3121, 0.3086, 0.1732, 0.0000, 0.0000],
        [0.1752, 0.2701, 0.2670, 0.1505, 0.1373, 0.0000],
        [0.1501, 0.2266, 0.2240, 0.1260, 0.1138, 0.1595]])


In [30]:
print(attention_score)

tensor([[2.8975, 4.4936, 4.4457, 2.4991, 2.3309, 3.1250],
        [4.2748, 6.4082, 6.3347, 3.5557, 3.2284, 4.4984],
        [1.1340, 1.7285, 1.7092, 0.9605, 0.8811, 1.2094],
        [4.5643, 6.9088, 6.8318, 3.8341, 3.5225, 4.8288],
        [1.6228, 2.5020, 2.4739, 1.3939, 1.2721, 1.7556],
        [3.6002, 5.4342, 5.3715, 3.0204, 2.7298, 3.8237]])


In [39]:
mask = torch.triu(torch.ones(context_length , context_length),diagonal=1)
print(mask)
masked = attention_score.masked_fill(mask.bool(),-torch.inf)
print(masked)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[2.8975,   -inf,   -inf,   -inf,   -inf,   -inf],
        [4.2748, 6.4082,   -inf,   -inf,   -inf,   -inf],
        [1.1340, 1.7285, 1.7092,   -inf,   -inf,   -inf],
        [4.5643, 6.9088, 6.8318, 3.8341,   -inf,   -inf],
        [1.6228, 2.5020, 2.4739, 1.3939, 1.2721,   -inf],
        [3.6002, 5.4342, 5.3715, 3.0204, 2.7298, 3.8237]])


In [40]:
attention_weight = torch.softmax(masked / keys.shape[-1]**0.5 , dim=1)
print(attention_weight)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2951, 0.7049, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2825, 0.3601, 0.3573, 0.0000, 0.0000, 0.0000],
        [0.1456, 0.3791, 0.3673, 0.1080, 0.0000, 0.0000],
        [0.1778, 0.2546, 0.2517, 0.1619, 0.1541, 0.0000],
        [0.1288, 0.2724, 0.2655, 0.1017, 0.0903, 0.1412]])


In [41]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)


In [42]:
ex = torch.ones(6,6)
print(ex)

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])


In [44]:
print(dropout(ex))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])
