
In [1]:
import math
import torch
from torch import nn

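The mask in _sequence_mask is built by the expression torch.arange(maxlen, ...)[None, :] < valid_len[:, None]. torch.arange(maxlen, ...) first creates the position indices [0, 1, ..., maxlen - 1]; the remaining parts work as follows.

[None, :]: This adds a new dimension at the beginning of the index tensor, turning it into a row vector of shape (1, maxlen). So, [0, 1, 2, 3, 4] becomes [[0, 1, 2, 3, 4]]. This is crucial for broadcasting later.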

valid_len[:, None]: This takes your valid_len tensor (which contains the actual lengths of each sequence) and adds a new dimension at the end, turning it into a column vector. For example, if valid_len is torch.tensor([3, 2]) (meaning the first sequence has 3 valid items and the second has 2), this part would become [[3], [2]].

<: This performs an element-wise comparison between the two tensors described above (the row vector of indices and the column vector of lengths). Due to broadcasting rules in PyTorch, the row vector [[0, 1, 2, 3, 4]] and the column vector [[3], [2]] are expanded to a common shape before the comparison.

Let's walk through an example:

Assume:

maxlen = 5
valid_len = torch.tensor([3, 2]) (meaning you have two sequences; the first has 3 valid elements, the second has 2).
Step-by-step:

The first part torch.arange(...) creates [0, 1, 2, 3, 4]. After [None, :], it becomes index_tensor = [[0, 1, 2, 3, 4]].
The valid_len[:, None] part creates lengths_tensor = [[3], [2]].
Now, we compare index_tensor < lengths_tensor:

[[0, 1, 2, 3, 4]]
<
[[3],
 [2]]
PyTorch broadcasts these. Conceptually, it expands them to:

index_tensor becomes:
[[0, 1, 2, 3, 4],
 [0, 1, 2, 3, 4]]
lengths_tensor becomes:
[[3, 3, 3, 3, 3],
 [2, 2, 2, 2, 2]]
Performing the element-wise comparison (<) then yields:

mask = [[True, True, True, False, False],   (0<3, 1<3, 2<3, 3<3 is False, 4<3 is False)
        [True, True, False, False, False]]  (0<2, 1<2, 2<2 is False, 3<2 is False, 4<2 is False)
This mask tensor tells you exactly which positions in each sequence (row) are valid (True) and which are padding/invalid (False).
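
The walkthrough above can be reproduced directly. This is a minimal sketch that uses the same maxlen and valid_len values; the names index_tensor and lengths_tensor are taken from the explanation, not from the masked_softmax code.

```python
import torch

maxlen = 5
valid_len = torch.tensor([3, 2])

# Row vector of positions, shape (1, maxlen)
index_tensor = torch.arange(maxlen)[None, :]
# Column vector of valid lengths, shape (batch, 1)
lengths_tensor = valid_len[:, None]

# Broadcasting expands both operands to (batch, maxlen) before comparing
mask = index_tensor < lengths_tensor
print(mask.shape)  # torch.Size([2, 5])
print(mask)
```

Each row of mask has exactly valid_len[i] leading True entries.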

------

The line X[~mask] = value sets specific elements of a tensor X to a particular value, based on a boolean mask. The key part here is ~mask.

mask: As we saw, the mask tensor contains True for valid elements and False for invalid (padded) elements. Continuing with our previous example, mask was:

mask = [[True, True, True, False, False],
        [True, True, False, False, False]]
~mask: The tilde ~ operator inverts the boolean values in the mask tensor. So, True becomes False and False becomes True.

Applying ~ to our mask example:

~mask = [[False, False, False, True, True],
         [False, False, True, True, True]]
Now, True indicates the positions that were originally invalid (padded) in our sequence.

X[~mask] = value: This operation uses the inverted mask (~mask) to select elements within the tensor X and then assigns value to only those selected elements. In other words, it sets all the invalid elements (where ~mask is True) in X to the specified value.

Example:

Let's assume you have a tensor X with some values, for instance:

X = [[10, 11, 12, 13, 14],
     [20, 21, 22, 23, 24]]
And let's say value = -1e6 (a negative number of very large magnitude, often used so that masked elements become 0 after a softmax operation).

Using our ~mask:

~mask = [[False, False, False, True, True],
         [False, False, True, True, True]]
The operation X[~mask] = value will look at all the positions where ~mask is True and replace the corresponding elements in X with -1e6.

So, X would become:

X = [[10, 11, 12, -1e6, -1e6],
     [20, 21, -1e6, -1e6, -1e6]]
This effectively 'masks out' the invalid elements by setting them to a specific value, usually one that will have no impact (or a desired specific impact, like becoming 0) in subsequent calculations (like a softmax function).
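
A small sketch of the same example, assuming X is a float tensor so that -1e6 can be stored in it. It also applies a softmax afterwards to show why this value is chosen: the masked positions end up with weight 0.

```python
import torch

X = torch.tensor([[10., 11., 12., 13., 14.],
                  [20., 21., 22., 23., 24.]])
mask = torch.tensor([[True, True, True, False, False],
                     [True, True, False, False, False]])

# In-place assignment: only positions where ~mask is True are overwritten
X[~mask] = -1e6
print(X)

# exp(-1e6 - max) underflows to 0, so masked positions get zero weight
weights = torch.softmax(X, dim=-1)
print(weights)
```

Note that the assignment modifies X in place, so the original values at the masked positions are lost.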

----------

Now, let's apply valid_lens = torch.repeat_interleave(valid_lens, shape[1]):

input is valid_lens = torch.tensor([5, 4])
repeats is shape[1] = 3
The operation will take the first element of valid_lens (which is 5) and repeat it 3 times: [5, 5, 5]. Then, it takes the second element of valid_lens (which is 4) and repeats it 3 times: [4, 4, 4].

Combining these, the new valid_lens tensor will be:

torch.tensor([5, 5, 5, 4, 4, 4])
So, from an initial valid_lens of [5, 4] for a batch of 2, if each batch item contains 3 rows of scores, one per query (shape[1] = 3), valid_lens is expanded to [5, 5, 5, 4, 4, 4]. It now holds one valid length for each of the 2 * 3 = 6 rows produced when X is reshaped to (-1, shape[-1]), so it corresponds element-wise to the flattened rows that are passed to _sequence_mask.
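
The expansion can be checked with a short sketch. The shape (2, 3, 6) is an assumed example for (batch_size, num_queries, num_keys); only shape[1] = 3 comes from the text above.

```python
import torch

valid_lens = torch.tensor([5, 4])
shape = (2, 3, 6)  # assumed (batch_size, num_queries, num_keys)

# Each length is repeated once per query in its batch item
repeated = torch.repeat_interleave(valid_lens, shape[1])
print(repeated)  # tensor([5, 5, 5, 4, 4, 4])

# Flattening the scores to 2D gives one row per (batch, query) pair,
# so there is exactly one valid length per row
X_flat = torch.zeros(shape).reshape(-1, shape[-1])
print(X_flat.shape)  # torch.Size([6, 6])
```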

Masked Softmax Operation

In [2]:
def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
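    def _sequence_mask(X, valid_len, value=0):
        # X: (rows, maxlen), valid_len: (rows,). X is modified in place.
        maxlen = X.size(1)
        # True where the column index is smaller than the row's valid length
        mask = torch.arange(maxlen, dtype=torch.float32,
                            device=X.device)[None, :] < valid_len[:, None]
        X[~mask] = value
        return X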

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

In [3]:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))

tensor([[[0.4416, 0.5584, 0.0000, 0.0000],
         [0.3445, 0.6555, 0.0000, 0.0000]],

        [[0.2835, 0.4354, 0.2811, 0.0000],
         [0.3568, 0.3621, 0.2811, 0.0000]]])

In [4]:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.2025, 0.3446, 0.4530, 0.0000]],

        [[0.5520, 0.4480, 0.0000, 0.0000],
         [0.2307, 0.2830, 0.2738, 0.2126]]])
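
The properties visible in these outputs (zeros beyond the valid length, rows summing to 1) can also be checked programmatically. The sketch below is not the masked_softmax defined above; it is a hypothetical out-of-place variant, masked_softmax_sketch, that uses masked_fill and only handles 1D valid_lens, written here so the check is self-contained.

```python
import torch

def masked_softmax_sketch(X, valid_lens):
    """Hypothetical variant. X: (batch, queries, keys), valid_lens: (batch,)."""
    positions = torch.arange(X.shape[-1], device=X.device)
    # (1, 1, keys) < (batch, 1, 1) broadcasts to (batch, 1, keys)
    mask = positions[None, None, :] < valid_lens[:, None, None]
    # masked_fill returns a new tensor, so X itself is left untouched
    return torch.softmax(X.masked_fill(~mask, -1e6), dim=-1)

X = torch.rand(2, 2, 4)
weights = masked_softmax_sketch(X, torch.tensor([2, 3]))
print(weights)
print(weights.sum(dim=-1))  # every row sums to 1
```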

torch.bmm performs a batched matrix multiplication: if Q has shape (batch, n, m) and K has shape (batch, m, p), the output has shape (batch, n, p). In the cell below, Q is (2, 3, 4) and K is (2, 4, 6), so the result is (2, 3, 6), which the check_shape helper confirms.



Batch Matrix Multiplication

In [5]:
def check_shape(tensor, expected_shape):
    assert tensor.shape == expected_shape, f"Expected shape {expected_shape}, but got {tensor.shape}"
    print(f"Shape check passed: {tensor.shape} == {expected_shape}")

Q = torch.ones((2, 3, 4))
K = torch.ones((2, 4, 6))
check_shape(torch.bmm(Q, K), (2, 3, 6))

Shape check passed: torch.Size([2, 3, 6]) == (2, 3, 6)
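
To make "batched" concrete, the following sketch compares torch.bmm with an explicit loop that multiplies one batch item at a time. The random inputs and the names batched and looped are illustrative only.

```python
import torch

torch.manual_seed(0)
Q = torch.randn(2, 3, 4)
K = torch.randn(2, 4, 6)

# One call for the whole batch
batched = torch.bmm(Q, K)

# Equivalent: an ordinary matrix product per batch item, then stacked
looped = torch.stack([Q[b] @ K[b] for b in range(Q.shape[0])])

print(batched.shape)  # torch.Size([2, 3, 6])
print(torch.allclose(batched, looped, atol=1e-5))
```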


Let's explain scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d) with an example.

This line is a core component of the scaled dot-product attention mechanism, which is fundamental in many modern neural networks, especially transformers. It calculates a raw score for how much each query 'attends' to each key, and then scales these scores.

Here's a breakdown of each part:

queries: This is a tensor representing the queries. Its shape is typically (batch_size, num_queries, d), where d is the feature dimension.

keys: This is a tensor representing the keys. Its shape is typically (batch_size, num_key_value_pairs, d).

keys.transpose(1, 2): This operation swaps the second and third dimensions of the keys tensor. If keys is (batch_size, num_key_value_pairs, d), then keys.transpose(1, 2) becomes (batch_size, d, num_key_value_pairs). This is done to prepare the keys for matrix multiplication with the queries.

torch.bmm(queries, keys.transpose(1, 2)): torch.bmm stands for "batch matrix-matrix product". It performs matrix multiplication for each batch independently. If:

queries has shape (batch, N, D) (where N is num_queries)
keys.transpose(1, 2) has shape (batch, D, M) (where M is num_key_value_pairs)
The result, scores (before scaling), will have shape (batch, N, M). Each element scores[b, i, j] represents the dot product between the i-th query in batch b and the j-th key in batch b.
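
d = queries.shape[-1]: This gets the size of the last dimension of the queries tensor, which is the feature dimension d.

/ math.sqrt(d): This divides every score by the square root of d. If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance d, so the division brings the variance of the scores back to 1 regardless of d and keeps the softmax from saturating when d is large.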
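
As a rough check of the shapes and of what a single entry means, the sketch below computes the scaled scores for assumed sizes (batch = 2, N = 3, M = 5, d = 4) and compares one entry against a dot product computed by hand.

```python
import math
import torch

torch.manual_seed(0)
queries = torch.randn(2, 3, 4)  # (batch, N, d)
keys = torch.randn(2, 5, 4)     # (batch, M, d)
d = queries.shape[-1]

scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
print(scores.shape)  # torch.Size([2, 3, 5])

# scores[b, i, j] is the scaled dot product of query i and key j in batch b
b, i, j = 1, 2, 4
manual = torch.dot(queries[b, i], keys[b, j]) / math.sqrt(d)
print(torch.allclose(scores[b, i, j], manual, atol=1e-5))
```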



Scaled Dot Product Attention

In [6]:
class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Swap the last two dimensions of keys with keys.transpose(1, 2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

In [7]:
queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.normal(0, 1, (2, 10, 2))
values = torch.normal(0, 1, (2, 10, 4))
valid_lens = torch.tensor([2, 6])

attention = DotProductAttention(dropout=0.5)
attention.eval()
check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))

Shape check passed: torch.Size([2, 1, 4]) == (2, 1, 4)


Additive Attention

Let's explain features = queries.unsqueeze(2) + keys.unsqueeze(1) with an example.

This line is performing an element-wise addition between two tensors (queries and keys) after expanding their dimensions. This dimension expansion is crucial for enabling a powerful feature in PyTorch called broadcasting, which allows tensors of different shapes to be combined under certain conditions.

Here's a breakdown:

queries.unsqueeze(2):

queries typically has a shape like (batch_size, num_queries, d). Let's assume (2, 1, 20), as in the AdditiveAttention example cell (batch size 2, 1 query per batch item, 20 features per query).
The .unsqueeze(2) method adds a new dimension of size 1 at the specified index (index 2, which is the third dimension, as Python uses 0-based indexing).
So, if queries was (batch_size, num_queries, d), after unsqueeze(2) it becomes (batch_size, num_queries, 1, d).
Using our example: (2, 1, 20) becomes (2, 1, 1, 20).
keys.unsqueeze(1):

keys typically has a shape like (batch_size, num_key_value_pairs, d). Let's assume (2, 10, 20) (batch size 2, 10 key-value pairs per batch item, 20 features per key).
The .unsqueeze(1) method adds a new dimension of size 1 at index 1 (the second dimension).
So, if keys was (batch_size, num_key_value_pairs, d), after unsqueeze(1) it becomes (batch_size, 1, num_key_value_pairs, d).
Using our example: (2, 10, 20) becomes (2, 1, 10, 20).
+ (Broadcasting):

Now you have two tensors with shapes (2, 1, 1, 20) and (2, 1, 10, 20). In the AdditiveAttention class, however, W_q and W_k first map the queries and keys to num_hiddens features, so the shapes below use num_hiddens=8 as in the code:

queries originally (batch_size, no. of queries, input_dim) -> after W_q(queries): (batch_size, no. of queries, num_hiddens). E.g., (2, 1, 8).

queries.unsqueeze(2) becomes (batch_size, no. of queries, 1, num_hiddens). E.g., (2, 1, 1, 8).

keys originally (batch_size, no. of key-value pairs, input_dim) -> after W_k(keys): (batch_size, no. of key-value pairs, num_hiddens). E.g., (2, 10, 8).

keys.unsqueeze(1) becomes (batch_size, 1, no. of key-value pairs, num_hiddens). E.g., (2, 1, 10, 8).

Now, when you add queries.unsqueeze(2) ((2, 1, 1, 8)) and keys.unsqueeze(1) ((2, 1, 10, 8)), PyTorch's broadcasting rules apply:

It aligns the dimensions from right to left.
If dimensions are equal, they match.
If one dimension is 1, it's stretched (broadcast) to match the other.
If dimensions are unequal and neither is 1, it's an error.
In our example:

Dimension 3 (features): 8 and 8 - Match.
Dimension 2 (key-value pairs): 1 (the placeholder added to queries) and 10 - The 1 is broadcast to 10.
Dimension 1 (queries): 1 (the single query) and 1 (the placeholder added to keys) - Match. With more than one query, the 1 in keys would be broadcast to num_queries.
Dimension 0 (batch): 2 and 2 - Match.
The resulting features tensor will have the shape of the broadcasted dimensions, which is (batch_size, num_queries, num_key_value_pairs, num_hiddens). In our example, (2, 1, 10, 8).

What this achieves:

This operation effectively creates a tensor where each query's representation is combined element-wise with every key's representation. For instance, features[b, i, j, :] would contain the element-wise sum of the i-th query's features and the j-th key's features, for batch b.

This is a common pattern in attention mechanisms, particularly additive attention, to prepare a tensor where every possible query-key interaction is explicitly represented, which then gets passed through a non-linear activation (like torch.tanh) before being projected to a single score by self.w_v.

----------

If you have Q as (Q1) and K as (K1, K2, K3), the operation effectively computes (Q1+K1, Q1+K2, Q1+K3). If Q were (Q1, Q2), it would compute ((Q1+K1, Q1+K2, Q1+K3), (Q2+K1, Q2+K2, Q2+K3)), and so on. This creates all possible query-key pairs, which is a key step in additive attention.
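
The two-query case can be checked with a short sketch. The sizes (batch = 2, 2 queries, 3 keys, num_hiddens = 8) and the random values are assumptions for illustration; the tensors stand in for the outputs of W_q and W_k.

```python
import torch

torch.manual_seed(0)
queries = torch.randn(2, 2, 8)  # (batch, num_queries, num_hiddens)
keys = torch.randn(2, 3, 8)     # (batch, num_key_value_pairs, num_hiddens)

# (2, 2, 1, 8) + (2, 1, 3, 8) broadcasts to (2, 2, 3, 8)
features = queries.unsqueeze(2) + keys.unsqueeze(1)
print(features.shape)  # torch.Size([2, 2, 3, 8])

# features[b, i, j] is the sum of query i and key j within batch item b
b, i, j = 0, 1, 2
pair = queries[b, i] + keys[b, j]
print(torch.allclose(features[b, i, j], pair))
```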

------------

Let's explain squeeze(-1) with a clear example.

In PyTorch (and similar libraries), the .squeeze() method is used to remove dimensions of size 1 from a tensor. When you specify an argument like -1, you're telling it to only remove dimensions of size 1 at that specific position.

squeeze(): Without any arguments, squeeze() will remove all dimensions that have a size of 1. For example, a tensor with shape (1, 3, 1, 4, 1) would become (3, 4).

squeeze(-1): When you pass -1 as an argument, it means "remove the last dimension if its size is 1". If the last dimension is not 1, it does nothing.

Example:

Imagine you have a tensor named my_tensor with the following shape:

my_tensor.shape = (2, 1, 3, 1, 5, 1)

Let's see what happens with squeeze(-1):

my_tensor.squeeze(-1): This checks the last dimension. Its size is 1. So, this dimension is removed.
The new shape becomes: (2, 1, 3, 1, 5)
Now, let's take the result from that operation and apply squeeze(-1) again:

my_tensor_after_first_squeeze.squeeze(-1): This checks the last dimension, which now has a size of 5. Since it's not 1, this squeeze(-1) operation does nothing.
The shape remains: (2, 1, 3, 1, 5)
Why is squeeze(-1) often used in attention mechanisms?

In the AdditiveAttention class shown below, the self.w_v linear layer is defined as nn.LazyLinear(1, bias=False). This means it maps each feature vector to a single output value. So, if the input to self.w_v is (batch_size, num_queries, num_key_value_pairs, num_hiddens), the output is (batch_size, num_queries, num_key_value_pairs, 1).

This final 1 in the dimension is often redundant. It's just a single number, so having it wrapped in an extra dimension of size 1 doesn't add much meaning and can sometimes complicate subsequent operations. squeeze(-1) efficiently removes this unnecessary dimension, making the scores tensor have a cleaner shape like (batch_size, num_queries, num_key_value_pairs), which is more intuitive for representing the attention scores between each query and each key-value pair.
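
The shapes discussed above can be verified with a minimal sketch. The tensor scores_4d is a stand-in, filled with zeros, for the output of self.w_v with the assumed sizes batch_size = 2, 1 query and 10 key-value pairs.

```python
import torch

my_tensor = torch.zeros(2, 1, 3, 1, 5, 1)

once = my_tensor.squeeze(-1)        # last dimension has size 1: removed
twice = once.squeeze(-1)            # last dimension has size 5: unchanged
all_removed = my_tensor.squeeze()   # every size-1 dimension removed

print(once.shape)         # torch.Size([2, 1, 3, 1, 5])
print(twice.shape)        # torch.Size([2, 1, 3, 1, 5])
print(all_removed.shape)  # torch.Size([2, 3, 5])

# Stand-in for the output of self.w_v: (batch, queries, key-value pairs, 1)
scores_4d = torch.zeros(2, 1, 10, 1)
scores = scores_4d.squeeze(-1)
print(scores.shape)       # torch.Size([2, 1, 10])
```

Using squeeze(-1) rather than squeeze() matters here: with a single query, a bare squeeze() would also drop the query dimension.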

In [8]:
class AdditiveAttention(nn.Module):
    """Additive attention."""
    def __init__(self, num_hiddens, dropout, **kwargs):
        super().__init__(**kwargs)
        self.W_k = nn.LazyLinear(num_hiddens, bias=False)
        self.W_q = nn.LazyLinear(num_hiddens, bias=False)
        self.w_v = nn.LazyLinear(1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of queries: (batch_size, no. of
        # queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
        # key-value pairs, num_hiddens). Sum them up with broadcasting
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # There is only one output of self.w_v, so we remove the last
        # one-dimensional entry from the shape. Shape of scores: (batch_size,
        # no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of values: (batch_size, no. of key-value pairs, value
        # dimension)
        return torch.bmm(self.dropout(self.attention_weights), values)

In [9]:
queries = torch.normal(0, 1, (2, 1, 20))

attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.eval()
check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))

Shape check passed: torch.Size([2, 1, 4]) == (2, 1, 4)


In [10]:
import torch

# Example for queries after W_q and unsqueeze(2)
# Shape: (batch_size, num_queries, 1, num_hiddens)
# Let's simplify with batch_size=1, num_queries=1, num_hiddens=2
queries_expanded = torch.tensor([[[[1.0, 2.0]]]]) # Shape (1, 1, 1, 2)
print(f"Shape of queries_expanded: {queries_expanded.shape}\n{queries_expanded}")

# Example for keys after W_k and unsqueeze(1)
# Shape: (batch_size, 1, num_key_value_pairs, num_hiddens)
# Let's simplify with batch_size=1, num_key_value_pairs=3, num_hiddens=2
keys_expanded = torch.tensor([[[
    [10.0, 20.0],
    [30.0, 40.0],
    [50.0, 60.0]
]]]) # Shape (1, 1, 3, 2)
print(f"Shape of keys_expanded: {keys_expanded.shape}\n{keys_expanded}")

# Perform element-wise addition with broadcasting
features = queries_expanded + keys_expanded

print(f"\nShape of result (features): {features.shape}\n{features}")

# Let's verify the broadcasting for one example (Q1 + K1, Q1 + K2, Q1 + K3)
# Conceptually, Q1 = [1.0, 2.0]
# K1 = [10.0, 20.0]
# K2 = [30.0, 40.0]
# K3 = [50.0, 60.0]

# Expected result for the first query:
# Q1 + K1 = [1.0+10.0, 2.0+20.0] = [11.0, 22.0]
# Q1 + K2 = [1.0+30.0, 2.0+40.0] = [31.0, 42.0]
# Q1 + K3 = [1.0+50.0, 2.0+60.0] = [51.0, 62.0]

Shape of queries_expanded: torch.Size([1, 1, 1, 2])
tensor([[[[1., 2.]]]])
Shape of keys_expanded: torch.Size([1, 1, 3, 2])
tensor([[[[10., 20.],
          [30., 40.],
          [50., 60.]]]])

Shape of result (features): torch.Size([1, 1, 3, 2])
tensor([[[[11., 22.],
          [31., 42.],
          [51., 62.]]]])
