# **Attention Mechanism**

In [146]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import random
import warnings

random.seed(42)
torch.manual_seed(123)
warnings.filterwarnings("ignore")

## Simplified Self-Attention Example

Example input tensor containing 6 words, each represented by a 3-dimensional vector. For demo, we take the second word as the input query.

In [147]:
inputs = torch.tensor([
    [0.43, 0.15, 0.89], # word 1
    [0.55, 0.87, 0.66], # word 2
    [0.57, 0.85, 0.64], # word 3
    [0.22, 0.58, 0.33], # word 4
    [0.77, 0.25, 0.10], # word 5
    [0.05, 0.80, 0.55]  # word 6
])

input_query = inputs[1]

We need to then compute the dot product of the input query with each word in the input tensor to get the attention scores.

In [148]:
input_1 = inputs[0]
print(f"Input one: {input_1}\nInput query: {input_query}")

print(f"Dot product: {torch.dot(input_query, input_1)}")

Input one: tensor([0.4300, 0.1500, 0.8900])
Input query: tensor([0.5500, 0.8700, 0.6600])
Dot product: 0.9544000625610352


In [149]:
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for idx, input_i in enumerate(inputs):
    attn_scores_2[idx] = torch.dot(query, input_i)

print(f"Attention scores: {attn_scores_2}")

Attention scores: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


To compute the attention weights, we just need to normalize the attention scores. We can do this simply, or by using a `softmax` function.

In [150]:
def manual_softmax(n):
    # could be unstable for some n
    return torch.exp(n) / torch.exp(n).sum(dim=0)

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print(f"Attention weights: {attn_weights_2}")

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


Now to compute the context vector, we multiply the attention weights with the input words and sum them up.

In [151]:
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)

for idx, input_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[idx] * input_i

print(f"Context vector: {context_vec_2}")

Context vector: tensor([0.4419, 0.6515, 0.5683])


## Simplified Self-Attention Generalized

This time, we'll use the same idea, but we will compute the attention scores in a more generalized way rather than just focusing on one input query.

In [152]:
inputs = torch.tensor([
    [0.43, 0.15, 0.89], # word 1
    [0.55, 0.87, 0.66], # word 2
    [0.57, 0.85, 0.64], # word 3
    [0.22, 0.58, 0.33], # word 4
    [0.77, 0.25, 0.10], # word 5
    [0.05, 0.80, 0.55]  # word 6
])

attn_scores = torch.empty(inputs.shape[0], inputs.shape[0])
print(f"Attention scores shape: {attn_scores.shape}")

Attention scores shape: torch.Size([6, 6])


By the same logic of dot product we compute the attention score, but this also be done more efficiently by matrix multiplication.

In [153]:
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)

print(f"Attention scores: \n{attn_scores}\n")
print(f"Matrix multiplication result: \n{inputs @ inputs.T}")

Attention scores: 
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

Matrix multiplication result: 
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


With the same `softmax` function, we normalize the attention scores to get the attention weights.

In [154]:
attn_weights = torch.softmax(attn_scores, dim=1)
print(f"Attention weights: \n{attn_weights}")

Attention weights: 
tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


Using matrix multiplication for effieciency, we compute the context vector using the attention weights of the input words.

In [155]:
context_vec = attn_weights @ inputs
print(f"Context vector: \n{context_vec}")

Context vector: 
tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


## Trainable Self-Attention Example

In [156]:
x_2 = inputs[1]          # word 2
d_in = inputs.shape[1]   # input tensor dimensionality
d_out = 2                # output tensor dimensionality

W_query = torch.nn.Parameter(torch.rand(d_in, d_out))
W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
W_value = torch.nn.Parameter(torch.rand(d_in, d_out))

print(f"W_query: {W_query}\nW_key: {W_key}\nW_value: {W_value}")

W_query: Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]], requires_grad=True)
W_key: Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]], requires_grad=True)
W_value: Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]], requires_grad=True)


In [157]:
query_2 = x_2 @ W_query
keys = inputs @ W_key
value = inputs @ W_value
print(f"Keys: \n{keys}\n\nValues: \n{value}\n\nQuery vector: {query_2}")

Keys: 
tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]], grad_fn=<MmBackward0>)

Values: 
tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]], grad_fn=<MmBackward0>)

Query vector: tensor([0.4306, 1.4551], grad_fn=<SqueezeBackward4>)


In [158]:
attn_scores_2 = query_2 @ keys.T
print(f"Attention scores: \n{attn_scores_2}")

Attention scores: 
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
       grad_fn=<SqueezeBackward4>)


In [159]:
d_k = keys.shape[1]  # dimensionality of keys

attn_weights_2 = torch.softmax(attn_scores_2 / d_k ** 0.5, dim=-1)
print(f"Attention weights: \n{attn_weights_2}")

Attention weights: 
tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
       grad_fn=<SoftmaxBackward0>)


In [160]:
context_vec_2 = attn_weights_2 @ value
print(f"Context vector: \n{context_vec_2}")

Context vector: 
tensor([0.3061, 0.8210], grad_fn=<SqueezeBackward4>)


## Trainable Self-Attention Generalized

In [161]:
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries = inputs @ W_query
        keys = inputs @ W_key
        values = inputs @ W_value

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / d_k ** 0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec
    
sa_v1 = SelfAttention_v1(d_in, d_out)
print(f"Context vector from self-attention v1: \n{sa_v1(inputs)}")

Context vector from self-attention v1: 
tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [162]:
example_tensor = torch.nn.Linear(2, 3)
print(f"Example weights: {example_tensor.weight}\n\nExample bias: {example_tensor.bias}")

Example weights: Parameter containing:
tensor([[-0.5980, -0.2029],
        [-0.4980,  0.0467],
        [-0.1320, -0.3793]], requires_grad=True)

Example bias: Parameter containing:
tensor([-0.0643,  0.6699, -0.0558], requires_grad=True)


In [163]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / d_k ** 0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec

sa = SelfAttention_v2(d_in, d_out)
print(f"Context vector from final self-attention: \n{sa(inputs)}")

Context vector from final self-attention: 
tensor([[-0.0089,  0.0272],
        [-0.0049,  0.0264],
        [-0.0049,  0.0264],
        [-0.0094,  0.0277],
        [-0.0079,  0.0272],
        [-0.0089,  0.0276]], grad_fn=<MmBackward0>)


## Causal Attention

In [164]:
queries = sa.W_query(inputs)
keys = sa.W_key(inputs)
values = sa.W_value(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / d_k ** 0.5, dim=-1)

print(f"Attention weights: \n{attn_weights}")

Attention weights: 
tensor([[0.1723, 0.1562, 0.1570, 0.1706, 0.1828, 0.1612],
        [0.1749, 0.1489, 0.1502, 0.1737, 0.1943, 0.1580],
        [0.1749, 0.1489, 0.1502, 0.1737, 0.1942, 0.1580],
        [0.1712, 0.1570, 0.1577, 0.1707, 0.1812, 0.1622],
        [0.1726, 0.1543, 0.1552, 0.1716, 0.1855, 0.1607],
        [0.1716, 0.1561, 0.1569, 0.1710, 0.1826, 0.1618]],
       grad_fn=<SoftmaxBackward0>)


In [165]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(f"Simple mask: \n{mask_simple}")

Simple mask: 
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [166]:
masked_simple = attn_weights * mask_simple
print(f"Masked attention weights (simple): \n{masked_simple}")

Masked attention weights (simple): 
tensor([[0.1723, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1749, 0.1489, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1749, 0.1489, 0.1502, 0.0000, 0.0000, 0.0000],
        [0.1712, 0.1570, 0.1577, 0.1707, 0.0000, 0.0000],
        [0.1726, 0.1543, 0.1552, 0.1716, 0.1855, 0.0000],
        [0.1716, 0.1561, 0.1569, 0.1710, 0.1826, 0.1618]],
       grad_fn=<MulBackward0>)


In [167]:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_normalized = masked_simple / row_sums
print(f"Masked attention weights (simple): \n{masked_simple_normalized}")

Masked attention weights (simple): 
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5402, 0.4598, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3690, 0.3142, 0.3168, 0.0000, 0.0000, 0.0000],
        [0.2607, 0.2392, 0.2402, 0.2599, 0.0000, 0.0000],
        [0.2056, 0.1839, 0.1849, 0.2045, 0.2211, 0.0000],
        [0.1716, 0.1561, 0.1569, 0.1710, 0.1826, 0.1618]],
       grad_fn=<DivBackward0>)


In [168]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(f"Masked attention scores: \n{masked}")

Masked attention scores: 
tensor([[-0.1760,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.3144, -0.5421,    -inf,    -inf,    -inf,    -inf],
        [-0.3139, -0.5413, -0.5294,    -inf,    -inf,    -inf],
        [-0.1709, -0.2928, -0.2864, -0.1750,    -inf,    -inf],
        [-0.2171, -0.3751, -0.3669, -0.2248, -0.1147,    -inf],
        [-0.1869, -0.3200, -0.3130, -0.1913, -0.0989, -0.2699]],
       grad_fn=<MaskedFillBackward0>)


In [169]:
example_neg_inf_tensor = torch.tensor([float('-inf')])
print(f"Example negative infinity tensor: {example_neg_inf_tensor} \ne ** -inf or exp(-inf): {torch.exp(example_neg_inf_tensor)}")


Example negative infinity tensor: tensor([-inf]) 
e ** -inf or exp(-inf): tensor([0.])


In [170]:
attn_weights = torch.softmax(masked, dim=-1)
print(f"Masked attention weights: \n{attn_weights}")

Masked attention weights: 
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5567, 0.4433, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3842, 0.3061, 0.3097, 0.0000, 0.0000, 0.0000],
        [0.2651, 0.2347, 0.2362, 0.2640, 0.0000, 0.0000],
        [0.2077, 0.1773, 0.1788, 0.2061, 0.2301, 0.0000],
        [0.1735, 0.1518, 0.1529, 0.1727, 0.1894, 0.1596]],
       grad_fn=<SoftmaxBackward0>)


## Dropout Masking

In [171]:
dropout_rate = 0.5
dropout_layer = torch.nn.Dropout(dropout_rate)

In [172]:
example_tensor = torch.ones(6,6)

print(f"Example tensor before dropout: \n{example_tensor}\n\nExample tensor after dropout: \n{dropout_layer(example_tensor)}\n")

print(f"Dropout scale factor: {1 / (1 - dropout_rate)}")

Example tensor before dropout: 
tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

Example tensor after dropout: 
tensor([[2., 2., 0., 0., 2., 0.],
        [2., 2., 0., 2., 0., 2.],
        [2., 2., 2., 0., 0., 0.],
        [2., 0., 0., 0., 0., 0.],
        [0., 0., 2., 2., 2., 2.],
        [2., 2., 0., 2., 2., 0.]])

Dropout scale factor: 2.0


In [173]:
attn_weights = dropout_layer(attn_weights)
print(f"Attention weights after dropout: \n{attn_weights}")

Attention weights after dropout: 
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.1134, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7684, 0.0000, 0.6195, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4154, 0.0000, 0.3576, 0.4122, 0.4602, 0.0000],
        [0.3469, 0.3037, 0.0000, 0.3454, 0.0000, 0.3193]],
       grad_fn=<MulBackward0>)


## Causal Self-Attention Generalized

In [174]:
print(f"Inputs: \n{inputs}")

Inputs: 
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])


In [175]:
batch = torch.stack((inputs, inputs), dim=0)
print(f"Batch shape: {batch.shape}\nBatch: \n{batch}")

Batch shape: torch.Size([2, 6, 3])
Batch: 
tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])


In [176]:
class CasualAttention(nn.Module):
    def __init__(self, d_in, d_out, 
                 context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.dropout = torch.nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_k = x.shape

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )

        attn_weights = torch.softmax(
            attn_scores / d_k ** 0.5, dim=-1
        )

        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values

        return context_vec

torch.manual_seed(789)
context_length = batch.shape[1]
dropout = 0.0
ca = CasualAttention(d_in, 
                     d_out, 
                     context_length, 
                     dropout)
print(f"Context vector from casual attention: \n{ca(batch)}")

Context vector from casual attention: 
tensor([[[-0.0872,  0.0286],
         [-0.0993,  0.0505],
         [-0.1001,  0.0638],
         [-0.0983,  0.0489],
         [-0.0513,  0.1100],
         [-0.0757,  0.0688]],

        [[-0.0872,  0.0286],
         [-0.0993,  0.0505],
         [-0.1001,  0.0638],
         [-0.0983,  0.0489],
         [-0.0513,  0.1100],
         [-0.0757,  0.0688]]], grad_fn=<UnsafeViewBackward0>)


## Multi-Head Attention

In [177]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads=2, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CasualAttention(d_in,  d_out,  context_length,  dropout,  qkv_bias) for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

context_length = batch.shape[1]
d_in, d_out = inputs.shape[-1], 2

mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0, num_heads=2)
print(f"Context vector from multi-head attention: \n{mha(batch)}\nShape: {mha(batch).shape}")

Context vector from multi-head attention: 
tensor([[[ 0.2482, -0.4838,  0.1874,  0.2627],
         [ 0.3719, -0.4306,  0.2216,  0.1136],
         [ 0.4159, -0.4116,  0.2300,  0.0680],
         [ 0.3774, -0.3525,  0.2117,  0.0237],
         [ 0.3980, -0.3240,  0.1728,  0.0655],
         [ 0.3670, -0.3078,  0.1810,  0.0206]],

        [[ 0.2482, -0.4838,  0.1874,  0.2627],
         [ 0.3719, -0.4306,  0.2216,  0.1136],
         [ 0.4159, -0.4116,  0.2300,  0.0680],
         [ 0.3774, -0.3525,  0.2117,  0.0237],
         [ 0.3980, -0.3240,  0.1728,  0.0655],
         [ 0.3670, -0.3078,  0.1810,  0.0206]]], grad_fn=<CatBackward0>)
Shape: torch.Size([2, 6, 4])


## Efficient Multi-Head Attention

In [178]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            f"{d_out} must be divisible by {num_heads}."
        
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        queries = queries.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)

        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec
    
batch_size, context_length, d_in = batch.shape
d_out = 4
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=2)

context_vecs = mha(batch)
print(f"Context vectors from multi-head attention: \n{context_vecs}\nContext vectors shape: {context_vecs.shape}")

Context vectors from multi-head attention: 
tensor([[[0.2119, 0.2659, 0.5129, 0.6238],
         [0.2040, 0.2688, 0.4774, 0.6668],
         [0.2020, 0.2700, 0.4630, 0.6830],
         [0.1947, 0.2909, 0.4538, 0.6645],
         [0.1842, 0.3148, 0.4116, 0.6585],
         [0.1855, 0.3131, 0.4304, 0.6513]],

        [[0.2119, 0.2659, 0.5129, 0.6238],
         [0.2040, 0.2688, 0.4774, 0.6668],
         [0.2020, 0.2700, 0.4630, 0.6830],
         [0.1947, 0.2909, 0.4538, 0.6645],
         [0.1842, 0.3148, 0.4116, 0.6585],
         [0.1855, 0.3131, 0.4304, 0.6513]]], grad_fn=<ViewBackward0>)
Context vectors shape: torch.Size([2, 6, 4])
