Masking additional weights with dropout

*We'll apply dropout to the attention weights.*

In [1]:
import torch
# Toy input: one 3-dimensional embedding per token of the sentence
# "Your journey starts with one step".
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x^1: Your
        [0.55, 0.87, 0.66],  # x^2: journey
        [0.57, 0.85, 0.64],  # x^3: starts
        [0.22, 0.58, 0.33],  # x^4: with
        [0.77, 0.25, 0.10],  # x^5: one
        [0.05, 0.80, 0.55],  # x^6: step
    ]
)


In [2]:

import torch.nn as nn


class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention with learned projections.

    Args:
        d_in: Size of the last dimension of the input (embedding size).
        d_out: Size of the query/key/value projections and of each
            returned context vector.
        qkv_bias: Whether the three ``nn.Linear`` projections use a bias.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Creation order is kept (query, key, value) so that a fixed
        # torch.manual_seed yields the same initial weights as before.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute context vectors for every position in ``x``.

        Args:
            x: Tensor whose last two dimensions are ``(num_tokens, d_in)``;
                any leading dimensions (e.g. a batch dimension) are
                carried through unchanged.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``d_out``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two axes. ``keys.T`` reverses *all* axes, which
        # is only correct for 2-D input and breaks on batched tensors.
        attn_scores = queries @ keys.transpose(-2, -1)

        # Scale by sqrt(d_k) before the softmax, then normalise over the
        # key axis so each row of weights sums to 1.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec
    
    

In [3]:
# Layer sizes for the attention module.
d_in = inputs.size(1)  # input embedding size (3)
d_out = 2              # output embedding size


In [4]:
# Seed before construction so the randomly initialised projection
# weights are the same on every run.
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)


In [5]:
# Project the inputs with the module's own key and query layers.
keys = sa_v2.W_key(inputs)
queries = sa_v2.W_query(inputs)

# Dot product of every query with every key -> one score per token pair.
attn_score = torch.matmul(queries, keys.T)

In [6]:
# Show the raw attention scores, then their shape on the next line.
print(attn_score, attn_score.shape, sep="\n")

tensor([[ 0.2899,  0.0716,  0.0760, -0.0138,  0.1344, -0.0511],
        [ 0.4656,  0.1723,  0.1751,  0.0259,  0.1771,  0.0085],
        [ 0.4594,  0.1703,  0.1731,  0.0259,  0.1745,  0.0090],
        [ 0.2642,  0.1024,  0.1036,  0.0186,  0.0973,  0.0122],
        [ 0.2183,  0.0874,  0.0882,  0.0177,  0.0786,  0.0144],
        [ 0.3408,  0.1270,  0.1290,  0.0198,  0.1290,  0.0078]],
       grad_fn=<MmBackward0>)
torch.Size([6, 6])


*We'll now apply masking. The pipeline is: attention scores → apply mask → normalize with softmax (attention weights).*


In [7]:
# Square matrix with ones strictly above the main diagonal and zeros elsewhere.
context_length = attn_score.size(0)
mask = torch.ones(context_length, context_length).triu(diagonal=1)
print(mask)


tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])


In [8]:
# Overwrite every score flagged by the mask with -inf; softmax will turn
# those entries into exact zeros.
masked = attn_score.masked_fill(mask.to(torch.bool), float("-inf"))
print(masked)


tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [9]:
# Scale by sqrt(key dimension), then normalise each row with softmax.
attn_weights = (masked / keys.size(-1) ** 0.5).softmax(dim=-1)
print(attn_weights)


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


Masking additional weights with dropout

In [10]:
# Dropout demo on an all-ones matrix, seeded so the dropped positions repeat.
torch.manual_seed(123)
example = torch.ones((6, 6))
dropout = torch.nn.Dropout(p=0.5)  # zero each element with probability 0.5

# Surviving entries are rescaled by 1 / (1 - p), i.e. 2 for p = 0.5.
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


In [11]:
# similarly
torch.manual_seed(123)
print(dropout(attn_weights))


tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)


In [12]:
# causal attention supports the batch outputs produced by the dataloader
batch=torch.stack((inputs, inputs), dim=0)
print(batch.shape)


torch.Size([2, 6, 3])
