## Attention mechanisms
Now we move onto self attention.

In self attention, our goal is to calculate context vectors for each element in the input sequence. A context vector can be interpreted as an enriched embedding vector.

There are 4 types of attention mechanisms we'll be looking at:
- simplified self-attention (simple self attention technique)
- self-attention (self attention w/ trainable weights)
- casual attention (self-attention used in LLMs that allow model to consider previous and current inputs)
- multi-head attention (extension of self and casual attention that allows model to attend info from different representation subspaces)

#### Simple self attention technique

- There is an input sequence in the form of a sentence 
- Sentence is transformed into token embeddings
- We will calculate the context vector for each of element in the input sequence using their token embeddings

In [1]:
import torch

In [2]:
# Let's create our embedding vectors for the sentence: Your journey starts with one step

inputs = torch.tensor([
    [0.43, 0.15, 0.89], # Your (x^1)
    [0.55, 0.87, 0.66], # journey (x^2)
    [0.57, 0.85, 0.64], # starts (x^3)
    [0.22, 0.58, 0.33], # with (x^4)
    [0.77, 0.25, 0.10], # one (x^5)
    [0.05, 0.80, 0.55]  # step (x^6)
    ])

In [3]:
query = inputs[1]
attention_scores = torch.empty(inputs.shape[0])
for idx, elem in enumerate(inputs):
    attention_scores[idx] = torch.dot(elem, query)
print(attention_scores)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


**Normalizing attention scores**

Now that we have the attention scores, we need to calculate the attention weights by normalizing the attention scores

Why do we normalize? So we obtain attention weights that add up to 1, which maintains stability in LLM training

We can do this with:
```
temp_weights = attention_scores / attention_scores.sum()
print("Attention weights: ", temp_weights)
print("Sum: ", temp_weights.sum())  # tensor(1.0000)
```
But we will use the softmax approach, which is better at handling extreme values
```
def softmax(x):
    return torch.exp(x) / torch.exp(x).sum()
```

In [4]:
# Use PyTorch's softmax function

attention_weights = torch.softmax(attention_scores, dim=0)
print("Attention weights:", attention_weights)
print("Sum:", attention_weights.sum())


Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [5]:
# Now let's calculate the actual context vector
query = inputs[1]
context_vector = torch.zeros(query.shape)
for idx, elem in enumerate(inputs):
    context_vector += attention_weights[idx] * elem
print("Context vector:", context_vector)

Context vector: tensor([0.4419, 0.6515, 0.5683])


Let's calculate attention weights and context vectors for all inputs
```
attention_scores = torch.empty(inputs.shape[0], inputs.shape[0])
for i, elem in enumerate(inputs):
    for j, elem2 in enumerate(inputs):
        attention_scores[i,j] = torch.dot(elem, elem2)
```
But for-loops are slow, so let's use matrix multiplication

In [6]:
attention_scores = inputs @ inputs.T
# Now let's normalize the attention scores
attention_weights = torch.softmax(attention_scores, dim=1)
print("Attention weights:\n", attention_weights)

Attention weights:
 tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [7]:
# Now we compute context vector via matrix multiplication
all_context_vectors = attention_weights @ inputs
print("Context vector:\n", all_context_vectors)

Context vector:
 tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


#### Self attention w/ trainable weights

- Also called scaled dot-production attention
- Main difference is that the weights are updated during model training
- First we will do it step by step and then organize code into class

In [8]:
# We compute query, key and value vectors for each input element
x = inputs[1]
d_in = inputs.shape[1]
d_out = 2   # (usually d_in = d_out but we use different values here)

In [9]:
# Initialize the weight matrices
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

# Compute query, key and value vectors
query2 = x @ W_query
key2 = x @ W_key
value2 = x @ W_value
print("Query vector:", query2)

Query vector: tensor([0.4306, 1.4551])


In [10]:
# Now let's do it for all inputs
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


**Weight parameters vs attention weights**

*Attention weights*: determine the extent to which context vector depends on different parts of the input (dynamic, context specific)

*Weight parameters*: learned coeffecients that define the network's connections

In [11]:
# Let's compute the attention scores
keys2 = keys[1]
attention_score_22 = query2.dot(keys2)
print(attention_score_22)

tensor(1.8524)


In [12]:
# All attention scores

attention_scores2 = query2 @ keys.T
print("Attention scores:\n", attention_scores2)

Attention scores:
 tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [13]:
# Attention scores to attention weights by normalizing w/ softmax
# The differenc is we are going to scale the attention scores by the square root of the dimension of the key vectors
d_k = keys.shape[-1]
attention_weights2 = torch.softmax(attention_scores2 / (d_k ** 0.5), dim=-1)
print(attention_weights2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


We normalize by dividing by the embedding dimension to improve training performance by avoiding small gradients. As dot products increase, softmax function results in gradients nearing zero, which slows down learning

In [14]:
# Computing context vectors
context_vector2 = attention_weights2 @ values
print("Context vector:\n", context_vector2)

Context vector:
 tensor([0.3061, 0.8210])


Let's organize all of these steps into one class

In [15]:
class selfAttentionV1(torch.nn.Module):
    def __init__(self, d_in , d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

    def forward(self, x):
        keys = x @ self.W_key
        values = x @ self.W_value
        queries = x @ self.W_query
        attention_scores = queries @ keys.T
        attention_weights = torch.softmax(attention_scores / (self.d_out ** 0.5), dim=-1)
        context_vector = attention_weights @ values
        return context_vector

In [16]:
torch.manual_seed(123)
sa_v1 = selfAttentionV1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])


We can use PyTorch's Linear layers which perform matrix multiplication more effectively and has an optimized weight initialization scheme

In [17]:
class selfAttentionV2(torch.nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = torch.nn.Linear(d_in, d_out, bias=False)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=False)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=False)
    
    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [18]:
torch.manual_seed(789)
sa_v2 = selfAttentionV2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


#### Casual Attention

Casual attention (or *masked attention*) restricts the model to only consider current and previous inputs when processing a given token. To do this, we apply a zero mask for all entries above the diagonal in the input token attention weight matrix  

In [19]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [20]:
context_length = attn_scores.shape[0]
simple_mask = torch.tril(torch.ones(context_length, context_length))
print(simple_mask)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [21]:
masked_weights = attn_weights * simple_mask
print(masked_weights)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


In [22]:
row_sums = masked_weights.sum(dim=1, keepdim=True)
masked_weights_norm = masked_weights / row_sums
print(masked_weights_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


In [23]:
# A more accurate way to mask the attention weights is to apply a mask with negative infinity
upper_ones = torch.triu(torch.ones(context_length, context_length), diagonal=1)
attn_scores_masked = attn_scores.masked_fill(upper_ones.bool(), -torch.inf)
attn_weights = torch.softmax(attn_scores_masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


**Dropout**: a technique where random hidden layer units are ignored during training, thereby preventing overfitting. This is only done during training.

During training it is done either after calculating the attention scores or after applying the attention weights to the value vectors

In [24]:
torch.manual_seed(123)
dropout = torch.nn.Dropout()
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


Since the probability of dropping an element is 50%, the values of the remaining elements are scaled up by 1/0.5 = 2 to ensure average influence of attention mechanism remains consistent.

Now let's implement it in a class

In [25]:
# First ensure the code can handle batched inputs
batch = torch.stack((inputs, inputs), dim=0)    # Literally just duplicating the inputs tensor
print(batch)

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])


In [26]:
class casualAttention(torch.nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = torch.nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
            diagonal=1)
        )
        # Why use register buffer? Because buffers are automatically moved to the 
        # appropriate device (CPU or GPU) along with our model. So we don't need to
        # ensure that the mask tensor is on the same device as the model's parameters.

    def forward(self, x):
        batch_dim, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf
        )
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

In [27]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = casualAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)


#### Multi-head attention

Muliple attention "heads" are being ran independently, allowing more than one set of attention weights to process the input.

First, we'll build the multi-head module by stacking multiple CasualAttention modules and then implement it in a harder but more efficient manner

In [45]:
class MultiHeadAttentionWrapper(torch.nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = torch.nn.ModuleList([casualAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)])
        self.out_proj = torch.nn.Linear(d_out*num_heads, d_out*num_heads)


    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [46]:
torch.manual_seed(123)
context_length = batch.shape[1]
d_in = 3
d_out = 2
num_heads = 2

multi_head = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads)
context_vecs = multi_head(batch)
print("context_vecs.shape:", context_vecs.shape)
print(context_vecs)

context_vecs.shape: torch.Size([2, 6, 4])
tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)


**Dims explained**:

1st dim: 2 since we have 2 duplicated input text corpuses

2nd dim: 6 tokens in each input

3rd dim: 4-dim embedding of each token

Note: because of `head(x) for head in self.heads` the single-head attention modules are processed sequentially. We'll process these heads in parallel by computing outputs for all heads simultaneously via matrix multiplication

In [61]:
# Combined

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        self.num_heads = num_heads
        self.d_out = d_out
        self.head_dim = d_out // num_heads

        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = torch.nn.Linear(d_out, d_out)
        self.dropout = torch.nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
            diagonal=1)
        )
        
    def forward(self, x):
        n_corpus, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(n_corpus, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(n_corpus, num_tokens, self.num_heads, self.head_dim)
        values = values.view(n_corpus, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1,2)
        queries = queries.transpose(1,2)
        values = values.transpose(1,2)

        attention_scores = queries @ keys.transpose(2,3)

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attention_scores.masked_fill_(mask_bool, -torch.inf)

        attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1)
        attention_weights = self.dropout(attention_weights)

        context_vecs = (attn_weights @ values).transpose(1,2)

        context_vecs = context_vecs.contiguous().view(n_corpus, num_tokens, self.d_out)
        context_vecs = self.out_proj(context_vecs)    # not necessary but commonly used in llms

        return context_vecs

In [64]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2

mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print("context_vecs.shape:", context_vecs.shape)
print(context_vecs)

context_vecs.shape: torch.Size([2, 6, 2])
tensor([[[0.3190, 0.4858],
         [0.2961, 0.4023],
         [0.2872, 0.3705],
         [0.2712, 0.3910],
         [0.2643, 0.3961],
         [0.2598, 0.4053]],

        [[0.3190, 0.4858],
         [0.2961, 0.4023],
         [0.2872, 0.3705],
         [0.2712, 0.3910],
         [0.2643, 0.3961],
         [0.2598, 0.4053]]], grad_fn=<ViewBackward0>)
