<a href="https://colab.research.google.com/github/RCortez25/PhD/blob/main/LLM/4.%20Attention%20mechanism/2_Causal_self_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Causal self-attention

> Causal self-attention is a special form of self-attention that restricts the model to the current token and the tokens before it. Standard self-attention computes the attention weights for a given query in relation to all tokens in the sequence, both before and after the query. In causal self-attention, we instead mask out the future tokens and only consider the query and the tokens that precede it. What we need to modify are the attention weights.
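
> In formula form, a sketch under assumed notation: $\omega_{ij}$ is the attention score between query $i$ and key $j$ (the code comments in this notebook call the scores omega), $d_k$ is the key dimension, and $\alpha_{ij}$ is the resulting attention weight. The symbols $\alpha$ and $d_k$ are labels chosen here for illustration.

$$
\alpha_{ij} =
\begin{cases}
\dfrac{\exp\left(\omega_{ij} / \sqrt{d_k}\right)}{\sum_{l \le i} \exp\left(\omega_{il} / \sqrt{d_k}\right)} & \text{if } j \le i \\[2ex]
0 & \text{if } j > i
\end{cases}
$$

> Each row $i$ is therefore normalized over positions $1, \dots, i$ only, so it still sums to 1.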

> Let's reuse the code we had for the self-attention mechanism.




In [2]:
import torch
import torch.nn as nn

# Create the input tensors
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
     [0.55, 0.87, 0.66], # journey  (x^2)
     [0.57, 0.85, 0.64], # starts   (x^3)
     [0.22, 0.58, 0.33], # with     (x^4)
     [0.77, 0.25, 0.10], # one      (x^5)
     [0.05, 0.80, 0.55]] # step     (x^6)
)

class SelfAttentionV2(nn.Module):
    # Initialize the object with the number of input and output dimensions
    # Initialize the bias to False
    def __init__(self, dimension_inputs, dimension_outputs, qkv_bias=False):
        super().__init__()
        # Initialize the matrices using Linear layers
        self.W_q = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        self.W_k = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        self.W_v = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)

    # Method to calculate the context vector
    def forward(self, input_vectors):
        # Calculate the query, key, and value vectors by calling the Linear layers
        # Note that in this case we don't perform a direct matrix multiplication
        # but rather just call the Linear layer as a function passing in the
        # input vectors. The Linear layer performs the matrix multiplication
        queries = self.W_q(input_vectors)
        keys = self.W_k(input_vectors)
        values = self.W_v(input_vectors)

        # Calculate attention scores, that is, omega
        attention_scores = queries @ keys.T

        # Calculate attention weights
        dimension_keys = keys.shape[-1]
        attention_weights = torch.softmax(attention_scores / (dimension_keys ** 0.5), dim=-1)

        # Calculate and return the context vectors
        context_vectors = attention_weights @ values

        return context_vectors

> Let's use this class to obtain the attention weights.

In [5]:
# Testing the class
torch.manual_seed(123)

# Obtain queries and keys
oSelfAttentionV2 = SelfAttentionV2(3, 2)
queries = oSelfAttentionV2.W_q(inputs)
keys = oSelfAttentionV2.W_k(inputs)

# Calculate attention scores
attention_scores = queries @ keys.T

# Calculate attention weights
dimension_keys = keys.shape[-1]
attention_weights = torch.softmax(attention_scores / (dimension_keys ** 0.5), dim=-1)
attention_weights

tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
        [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
        [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)

> Now, create a mask to hide future tokens

In [10]:
# Obtain the context length
context_length = inputs.shape[0]

# Create the mask using PyTorch's tril (lower triangular) function
# torch.ones creates a tensor of 1s, and tril zeroes out
# the elements above the diagonal
mask_simple = torch.tril(torch.ones(context_length, context_length))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [13]:
# Apply the simple mask by multiplying the two tensors element-wise
attention_weights_simple_masked = attention_weights * mask_simple
attention_weights_simple_masked

tensor([[0.1717, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1749, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1637, 0.1749, 0.1746, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.0000, 0.0000],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<MulBackward0>)

> However, the rows no longer sum to 1 as required, so we renormalize them.

In [14]:
# Normalize the masked attention weights
attention_weights_normalized = attention_weights_simple_masked / attention_weights_simple_masked.sum(dim=-1, keepdim=True)
attention_weights_normalized

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<DivBackward0>)

> At first glance this looks like a data leakage problem: softmax was applied before the masking, so the future tokens contributed to the denominator of every weight. Renormalizing the masked rows cancels that contribution, which is why the result matches what we obtain next, but it requires an extra normalization step. A more direct approach is to mask the attention scores, not the weights, by filling the positions above the diagonal with negative infinity. The scores of interest keep their values, and only those above the diagonal change. Then, when softmax is applied, the negative infinities become zeros (since $e^{-\infty} = 0$) and do not affect the tokens of interest.
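
> To see why negative infinity works, here is a minimal sketch on made-up score rows (the values and variable names are hypothetical, not taken from this notebook). Softmax exponentiates each score, and the exponential of negative infinity is exactly 0, so masked positions receive zero weight while the remaining weights still sum to 1.

```python
import torch

# Hypothetical score row: one visible position, two masked ones
scores_one_visible = torch.tensor([0.3, -torch.inf, -torch.inf])
weights_one_visible = torch.softmax(scores_one_visible, dim=-1)
print(weights_one_visible)  # tensor([1., 0., 0.])

# Hypothetical score row: two visible positions, one masked
scores_two_visible = torch.tensor([0.2, 0.4, -torch.inf])
weights_two_visible = torch.softmax(scores_two_visible, dim=-1)
print(weights_two_visible)        # last entry is exactly 0
print(weights_two_visible.sum())  # the visible weights sum to 1
```

> Note that a row made up entirely of negative infinities would produce NaNs, which is one reason the diagonal is left unmasked.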

In [19]:
# Create the upper triangular (triu) mask with 1s above the diagonal
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [21]:
# Mask out the attention scores using negative infinity
# mask.bool() is True wherever the mask is nonzero, and masked_fill replaces
# the scores at those positions with -infinity
attention_scores_masked = attention_scores.masked_fill(mask.bool(), -torch.inf)
attention_scores_masked

tensor([[0.3111,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1655, 0.2602,   -inf,   -inf,   -inf,   -inf],
        [0.1667, 0.2602, 0.2577,   -inf,   -inf,   -inf],
        [0.0510, 0.1080, 0.1064, 0.0643,   -inf,   -inf],
        [0.1415, 0.1875, 0.1863, 0.0987, 0.1121,   -inf],
        [0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MaskedFillBackward0>)

In [23]:
# Print the original attention scores to confirm that the values of interest
# were not affected
attention_scores

tensor([[0.3111, 0.3479, 0.3471, 0.1714, 0.2350, 0.1928],
        [0.1655, 0.2602, 0.2576, 0.1445, 0.1384, 0.1790],
        [0.1667, 0.2602, 0.2577, 0.1443, 0.1391, 0.1784],
        [0.0510, 0.1080, 0.1064, 0.0643, 0.0476, 0.0835],
        [0.1415, 0.1875, 0.1863, 0.0987, 0.1121, 0.1174],
        [0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MmBackward0>)

In [27]:
# Apply the softmax to the masked attention scores
attention_weights_masked = torch.softmax(attention_scores_masked / (dimension_keys ** 0.5), dim=-1)
attention_weights_masked

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)

In [28]:
# Checking that all rows sum up to 1
attention_weights_masked.sum(dim=-1, keepdim=True)

tensor([[1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000]], grad_fn=<SumBackward1>)
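
> The printed weights from the mask-and-renormalize route and from the negative-infinity route are identical. A minimal sketch to check this numerically on random stand-in scores follows; the names `n`, `d_k`, `route_1`, and `route_2` are illustrative assumptions and not part of the notebook.

```python
import torch

torch.manual_seed(0)
n = 6    # hypothetical context length
d_k = 2  # hypothetical key dimension
scores = torch.randn(n, n)  # random stand-in for the attention scores

# Route 1: softmax first, then zero out future positions and renormalize
weights = torch.softmax(scores / d_k ** 0.5, dim=-1)
lower = torch.tril(torch.ones(n, n))
route_1 = weights * lower
route_1 = route_1 / route_1.sum(dim=-1, keepdim=True)

# Route 2: fill future positions with -inf, then apply a single softmax
upper = torch.triu(torch.ones(n, n), diagonal=1).bool()
route_2 = torch.softmax(scores.masked_fill(upper, -torch.inf) / d_k ** 0.5, dim=-1)

# The two routes should agree up to floating-point rounding
print(torch.allclose(route_1, route_2, atol=1e-6))
```

> The second route needs only one normalization pass, which is why it is the one carried forward here.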

# Dropout

> We will use this technique from deep learning to randomly zero out additional attention weights during training in order to prevent overfitting and improve generalization. In transformer architectures, dropout can be applied after calculating the attention weights or after applying the attention weights to the value vectors. In this case, we will apply it after calculating the attention weights.

In [35]:
torch.manual_seed(42)
# Create a dropout layer with a dropout rate of 50%, that is, to mask out
# about 50% of the attention weights
dropout = torch.nn.Dropout(0.5)

# Create an example for demonstration purposes
example = torch.ones(6, 6)

# Apply dropout
dropout(example)

tensor([[0., 0., 2., 2., 2., 2.],
        [2., 0., 2., 0., 2., 0.],
        [0., 0., 2., 2., 2., 0.],
        [2., 2., 0., 2., 0., 2.],
        [2., 0., 2., 2., 2., 2.],
        [2., 2., 2., 0., 2., 0.]])

> Note that the surviving values are no longer 1s. To compensate for the elements that were zeroed out, dropout scales the survivors by 1 / (1 - 0.5) = 2, which keeps the expected sum of the tensor unchanged. This produces the 2s we see in the resulting tensor. Now, let us apply it to the attention weights.
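
> As a side check before moving on, here is a minimal sketch of the scaling rule with a different, hypothetical dropout rate (`p = 0.2` and the variable names are assumptions for illustration). It also shows that `nn.Dropout` only acts in training mode: after calling `.eval()`, the input passes through unchanged.

```python
import torch

torch.manual_seed(0)
p = 0.2  # hypothetical dropout rate, chosen only for illustration
dropout_layer = torch.nn.Dropout(p)
ones = torch.ones(4, 4)

# Training mode (the default): each element is either zeroed out
# or scaled by 1 / (1 - p) = 1.25
dropped = dropout_layer(ones)
print(dropped)

# Evaluation mode: dropout is a no-op
dropout_layer.eval()
unchanged = dropout_layer(ones)
print(torch.equal(unchanged, ones))  # True
```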

In [43]:
# With this seed, the first element of the first row (which is 1) is not
# zeroed out, so we can see it being scaled to 2 after applying the dropout
torch.manual_seed(123)
print(attention_weights_masked)
print(dropout(attention_weights_masked))

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0335, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6804, 0.0000, 0.0000, 0.0000],
        [0.4889, 0.5090, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3988, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3418, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)
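
> The steps above (negative-infinity mask, scaled softmax, dropout on the attention weights) can be gathered into one module. The following is a minimal sketch under stated assumptions, not a definitive implementation: the class name `CausalAttentionSketch`, its `context_length` and `dropout_rate` parameters, and the use of `register_buffer` to store the mask are choices made here for illustration. It handles a single unbatched sequence, as in the cells above.

```python
import torch
import torch.nn as nn


class CausalAttentionSketch(nn.Module):
    """Hypothetical wrapper combining the steps above; not part of the notebook."""

    def __init__(self, dimension_inputs, dimension_outputs, context_length,
                 dropout_rate, qkv_bias=False):
        super().__init__()
        self.W_q = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        self.W_k = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        self.W_v = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout_rate)
        # Boolean mask that is True at future positions (above the diagonal).
        # Storing it as a buffer keeps it on the same device as the module.
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
        self.register_buffer("mask", mask)

    def forward(self, input_vectors):
        # input_vectors has shape (number_of_tokens, dimension_inputs)
        number_of_tokens = input_vectors.shape[0]
        queries = self.W_q(input_vectors)
        keys = self.W_k(input_vectors)
        values = self.W_v(input_vectors)

        # Mask the scores before softmax so future tokens get zero weight
        attention_scores = queries @ keys.T
        attention_scores = attention_scores.masked_fill(
            self.mask[:number_of_tokens, :number_of_tokens], -torch.inf
        )
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1] ** 0.5, dim=-1
        )

        # Dropout on the attention weights (active only in training mode)
        attention_weights = self.dropout(attention_weights)
        return attention_weights @ values


torch.manual_seed(123)
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)
# A dropout rate of 0.0 keeps this usage example deterministic
causal_attention = CausalAttentionSketch(3, 2, context_length=6, dropout_rate=0.0)
context_vectors = causal_attention(inputs)
print(context_vectors.shape)  # torch.Size([6, 2])
```

> Because the first token can only attend to itself, the first context vector equals the first value vector, which is a quick way to confirm the mask is working.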
