# Causal Attention with Droupouts

In [38]:
import torch


In [39]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [40]:
import torch.nn as nn
class SelfAttentionV2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

In [41]:
sav2 = SelfAttentionV2(3,2)

In [42]:
keys = sav2.W_key(inputs)
queries = sav2.W_query(inputs)
attention_scores = queries @ keys.T
print(attention_scores)

tensor([[-0.2327,  0.1055,  0.1098,  0.0913,  0.1549,  0.0521],
        [-0.2396,  0.1015,  0.1057,  0.0902,  0.1501,  0.0518],
        [-0.2323,  0.1004,  0.1045,  0.0885,  0.1481,  0.0507],
        [-0.1344,  0.0502,  0.0523,  0.0470,  0.0753,  0.0272],
        [-0.0349,  0.0520,  0.0538,  0.0331,  0.0708,  0.0174],
        [-0.2142,  0.0650,  0.0679,  0.0668,  0.1004,  0.0395]],
       grad_fn=<MmBackward0>)


### Here, we are trying to mask out attention scores above the diagonal such that, our model cannot cheat, that is
### it only has access to tokens till i-1 if it is trying to predict ith token
### we can achive this my simply multiplying our attention scores with a lower-triangular matrix of ones. 

In [43]:
context_len = 6
simple_mask = torch.tril(torch.ones(context_len, context_len))
print(simple_mask)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [44]:
simple_masked_scores = attention_scores * simple_mask
print(simple_masked_scores)
#However, the data till the ith token is still influenced by the i+1th token 

tensor([[-0.2327,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.2396,  0.1015,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.2323,  0.1004,  0.1045,  0.0000,  0.0000,  0.0000],
        [-0.1344,  0.0502,  0.0523,  0.0470,  0.0000,  0.0000],
        [-0.0349,  0.0520,  0.0538,  0.0331,  0.0708,  0.0000],
        [-0.2142,  0.0650,  0.0679,  0.0668,  0.1004,  0.0395]],
       grad_fn=<MulBackward0>)


In [45]:
row_sums = simple_masked_scores.sum(dim = 1, keepdim = True)
simple_masked_weights = simple_masked_scores / row_sums
print(simple_masked_weights)

tensor([[ 1.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
        [ 1.7352, -0.7352, -0.0000, -0.0000, -0.0000, -0.0000],
        [ 8.4746, -3.6631, -3.8116, -0.0000, -0.0000, -0.0000],
        [-8.9549,  3.3429,  3.4830,  3.1290,  0.0000,  0.0000],
        [-0.1997,  0.2974,  0.3077,  0.1896,  0.4050,  0.0000],
        [-1.7071,  0.5185,  0.5414,  0.5327,  0.7999,  0.3146]],
       grad_fn=<DivBackward0>)


### We successfully masked out the weights below diagonal. However, our current weights still are influenced by our future weights (masked out). This is called data-leakage
### We can take use of the softmax function to handle this. 
### We will create an upper-triangular matrix of ones and replace all the ones with -inf. While applying softmax, -inf will be considered as 0

In [46]:
values = sav2.W_value(inputs)
simple_masked_weights @ values

tensor([[ 0.4772,  0.1063],
        [ 0.3303, -0.1816],
        [-1.0134, -2.8082],
        [ 1.4447,  3.3911],
        [ 0.5725,  0.4566],
        [ 0.5892,  0.8548]], grad_fn=<MmBackward0>)

In [47]:
ones = torch.ones(context_len, context_len)
ones

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

In [48]:
mask = torch.triu(ones, diagonal=1)
mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [49]:
masked = attention_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[-0.2327,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.2396,  0.1015,    -inf,    -inf,    -inf,    -inf],
        [-0.2323,  0.1004,  0.1045,    -inf,    -inf,    -inf],
        [-0.1344,  0.0502,  0.0523,  0.0470,    -inf,    -inf],
        [-0.0349,  0.0520,  0.0538,  0.0331,  0.0708,    -inf],
        [-0.2142,  0.0650,  0.0679,  0.0668,  0.1004,  0.0395]],
       grad_fn=<MaskedFillBackward0>)


In [50]:
attention_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=-1)
print(attention_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4400, 0.5600, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2830, 0.3580, 0.3590, 0.0000, 0.0000, 0.0000],
        [0.2264, 0.2579, 0.2583, 0.2574, 0.0000, 0.0000],
        [0.1903, 0.2024, 0.2026, 0.1997, 0.2051, 0.0000],
        [0.1408, 0.1715, 0.1718, 0.1717, 0.1758, 0.1684]],
       grad_fn=<SoftmaxBackward0>)


### Masking in Transformers sets scores for future tokens to a large negative value, making their influence in the softmax calculation effectively zero.
### The softmax function then recalculates attention weights only among the unmasked tokens.
### This process ensures no information leakage from masked tokens, focusing the model solely on the intended data.

### MASKING ADDITIONAL ATTENTION WEIGHTS WITH DROPOUT
- When applying dropout to an attention weight matrix with a rate of 50%, half of the elements in the matrix are randomly set to zero.
- To compensate for the reduction in active elements, the values of the remaining elements in the matrix are scaled up by a factor of 1/0.5 =2.
- This scaling is crucial to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phases.

In [51]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
print(dropout(attention_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5659, 0.7160, 0.7181, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5159, 0.5167, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4047, 0.0000, 0.3993, 0.0000, 0.0000],
        [0.0000, 0.3430, 0.3437, 0.3434, 0.3516, 0.0000]],
       grad_fn=<MulBackward0>)


### Creating Causal Attention Class with Dropouts class

In [52]:
batch = torch.stack((inputs, inputs), dim=0)
batch.shape

torch.Size([2, 6, 3])

In [53]:
print(batch)

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])


In [54]:
class CausalAttention(nn.Module):
    def __init__(self, dim_in, dim_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.W_key = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.W_value = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.dropout = torch.nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, _ = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attention_scores = queries @ keys.transpose(1,2) 
        masked_scores = attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attention_weights = torch.softmax(masked_scores/keys.shape[-1] ** 0.5, dim=-1)
        attention_weights = self.dropout(attention_weights)

        context_vec = attention_weights @ values
        return context_vec
        

In [55]:
torch.manual_seed(123)
context_length = batch.shape[1]
dim_in = 3
dim_out = 2
ca = CausalAttention(dim_in, dim_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])


In [56]:
context_vecs

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)