implementing a very naive attention mechanism 

In [66]:
import torch 

# Toy data: six input vectors, each with three features (one row per vector).
inputs = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
])

In [3]:
# Unnormalized attention scores: entry (i, j) is the dot product of inputs[i]
# and inputs[j]. The single matrix product below computes the same values as
# the explicit double loop:
#
#     attn_scores = torch.zeros(inputs.shape[0], inputs.shape[0])
#     for i, x_i in enumerate(inputs):
#         for j, x_j in enumerate(inputs):
#             attn_scores[i][j] = torch.dot(x_i, x_j)
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [4]:
# Normalize every row of scores with a softmax so that each row sums to 1.
attn_weights = attn_scores.softmax(dim=-1)
attn_weights


tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [5]:
# Row i of the result is the sum of all input vectors, each weighted by
# attn_weights[i][j]. The matrix product is equivalent to the explicit loop:
#
#     context_vec = torch.zeros(inputs.shape)
#     for i, a_i in enumerate(attn_weights):
#         for j, x_j in enumerate(inputs):
#             context_vec[i] += a_i[j] * x_j
context_vec = attn_weights @ inputs
context_vec


tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

coding a better, trainable attention mechanism

In [6]:
# Use the second input vector as the running example.
x_2 = inputs[1]
d_in = x_2.shape[0]   # input feature size
d_out = d_in - 1      # projected size, deliberately one smaller than d_in

In [7]:
torch.manual_seed(123)

# Query, key and value projection matrices, drawn in exactly that order from
# the seeded generator. requires_grad=False: they are not trained here.
w_q, w_k, w_v = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [8]:
# Keys and values for every input vector at once.
keys = torch.matmul(inputs, w_k)
values = torch.matmul(inputs, w_v)

# Query, key and value of the single example vector x_2.
x_q = torch.matmul(x_2, w_q)
x_k = torch.matmul(x_2, w_k)
x_v = torch.matmul(x_2, w_v)

In [9]:
# Key, query and value of x_2, one tensor per line.
print(x_k, x_q, x_v, sep="\n")

tensor([0.4433, 1.1419])
tensor([0.4306, 1.4551])
tensor([0.3951, 1.0037])


In [10]:
# Keys and values of all inputs, one tensor after the other.
print(keys, values, sep="\n")

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])
tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])


In [11]:
d_k = keys.shape[-1]
# Scores for x_2: dot product of its query with the key of every input x_i.
attn_scores_2 = torch.matmul(x_q, keys.T)
# Divide by sqrt(d_k) before the softmax (scaled dot-product attention).
attn_weights_2 = (attn_scores_2 / d_k ** 0.5).softmax(dim=-1)

In [12]:
# Display the softmax-normalized, scaled attention weights of x_2 over all inputs.
attn_weights_2

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])

In [13]:
# Context vector for x_2: the attention-weighted sum of all value vectors.
context_vec_2 = torch.matmul(attn_weights_2, values)
context_vec_2

tensor([0.3061, 0.8210])

In [14]:
import torch.nn as nn
'''implementing a python class for the above, improved attention'''

class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The query, key and value projections are ``nn.Parameter`` matrices of shape
    ``(d_in, d_out)`` initialised with ``torch.rand``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Creation order (query, key, value) determines which random numbers
        # each matrix receives under a fixed seed, so it is kept as is.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute one context vector per input position.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, batched,
                ``(..., num_tokens, d_in)``.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by ``d_out``.
        """
        keys = x @ self.W_key
        values = x @ self.W_value
        queries = x @ self.W_query

        # transpose(-2, -1) swaps only the last two dims. `.T` reverses *all*
        # dims, which is only correct for 2-D input and breaks batched input.
        attn_scores = queries @ keys.transpose(-2, -1)
        attention_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attention_weights @ values
        return context_vec


In [15]:
# Seed first: the module draws its weights with torch.rand at construction time,
# so the seed must be set before SelfAttentionV1 is instantiated.
torch.manual_seed(123)
sa_v1 = SelfAttentionV1(d_in, d_out)
# One context vector of size d_out per row of `inputs`.
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [16]:
import torch.nn as nn
'''re-implementing the class with nn.Linear for better weight initialization '''

class SelfAttentionV2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Same computation as a raw-matrix implementation, but the projections are
    ``nn.Linear`` layers and therefore use that layer's default initialisation.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the output.
        qkv_bias: Whether the query/key/value projections include a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias = False):
        super().__init__()
        # Creation order (query, key, value) determines which random numbers
        # each layer receives under a fixed seed, so it is kept as is.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input position.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, batched,
                ``(..., num_tokens, d_in)``.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by ``d_out``.
        """
        keys = self.W_key(x)
        values = self.W_value(x)
        queries = self.W_query(x)

        # transpose(-2, -1) swaps only the last two dims. `.T` reverses *all*
        # dims, which is only correct for 2-D input and breaks batched input.
        attn_scores = queries @ keys.transpose(-2, -1)
        attention_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attention_weights @ values
        return context_vec


In [17]:
# Seed first: the nn.Linear layers are initialised randomly at construction time,
# so the seed must be set before SelfAttentionV2 is instantiated.
torch.manual_seed(123)
sa_v2 = SelfAttentionV2(d_in, d_out)
# One context vector of size d_out per row of `inputs`.
print(sa_v2(inputs))

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)


In [18]:
# NOTE(review): no seed is set in this cell, so the weights of this fresh
# module depend on the RNG state left behind by the cells executed before it.
sa_v2 = SelfAttentionV2(d_in, d_out)
# nn.Linear stores its weight as (d_out, d_in); transpose to (d_in, d_out) so
# the matrices can be applied as `inputs @ w`.
wq = sa_v2.W_query.weight.t()
wk = sa_v2.W_key.weight.t()

In [19]:
# Project all inputs with the weights taken from sa_v2.
# Note: `keys` is rebound here and replaces the value assigned in an earlier cell.
queries = torch.matmul(inputs, wq)
keys = torch.matmul(inputs, wk)

In [20]:
# Pairwise query-key scores; `attn_scores` is rebound to the new values here.
attn_scores = torch.matmul(queries, keys.T)
# Inspect the type of the shape object.
type(attn_scores.shape)

torch.Size

In [21]:
# Ones strictly above the diagonal mark the positions to hide.
mask = torch.triu(torch.ones(attn_scores.shape), diagonal=1)
# -inf scores become exactly 0 after the softmax.
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
attn_weight = (masked / keys.shape[-1] ** 0.5).softmax(dim=-1)
# Dropout zeroes random entries and rescales the survivors by 1 / (1 - p);
# no seed is set here, so the displayed result varies between runs.
dropout = torch.nn.Dropout(p=0.3)
dropout(attn_weight)


tensor([[1.4286, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6286, 0.8000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4042, 0.5114, 0.5129, 0.0000, 0.0000, 0.0000],
        [0.3234, 0.3685, 0.3690, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.2891, 0.2894, 0.0000, 0.2929, 0.0000],
        [0.0000, 0.2450, 0.2455, 0.2453, 0.2512, 0.2406]],
       grad_fn=<MulBackward0>)

In [27]:
class CausalAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Each position may attend only to itself and to earlier positions.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the output.
        context_length: Side length of the pre-built causal mask, i.e. the
            largest number of tokens the module can process.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections include a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias = False):
        super().__init__()
        # Creation order (key, query, value) determines which random numbers
        # each layer receives under a fixed seed, so it is kept as is.
        self.W_K = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_Q = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_V = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = torch.nn.Dropout(dropout)
        # Buffer (not a parameter): follows the module across devices via .to().
        # Ones strictly above the diagonal mark the "future" positions.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Compute causally masked context vectors.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)`` with
                ``num_tokens <= context_length``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        num_tokens = x.shape[-2]
        keys = self.W_K(x)
        values = self.W_V(x)
        queries = self.W_Q(x)
        attn_scores = queries @ keys.transpose(-2, -1)
        # Slice the mask to the actual sequence length; using the full
        # (context_length, context_length) mask fails to broadcast whenever the
        # input is shorter than context_length.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim = -1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values

        return context_vec


In [28]:
# Build a batch of two identical sequences: shape (2, num_tokens, d_in).
batch = torch.stack([inputs, inputs], dim=0)
context_length = batch.size(1)

# Dropout probability 0.0 keeps the forward pass deterministic.
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
cv = ca(batch)
print(cv)

tensor([[[-0.2636,  0.4184],
         [-0.3537,  0.2172],
         [-0.3752,  0.1474],
         [-0.3611,  0.0971],
         [-0.2495,  0.0649],
         [-0.3074,  0.0523]],

        [[-0.2636,  0.4184],
         [-0.3537,  0.2172],
         [-0.3752,  0.1474],
         [-0.3611,  0.0971],
         [-0.2495,  0.0649],
         [-0.3074,  0.0523]]], grad_fn=<UnsafeViewBackward0>)


Internally, for batched inputs PyTorch flattens the tensors (B, T, D) -> (B*T, D), performs the operation, and then restores the original dimensions.

### Why one head is not enough

- Limited subspace: A single attention head attends in a fixed representation space. It can't simultaneously focus on diverse semantic/structural patterns (like subject-object relationships and polysemy resolution).

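- Capacity bottleneck: One head produces only a single attention pattern (one softmax distribution) per token, so it has limited capacity to model all the nuances at once. Multi-head attention splits the job into specialized roles — one head does syntax, another semantics, etc. (Note that in the implementation below the same `d_out` is divided among the heads, so the gain comes from having several independent attention patterns rather than from extra projection parameters.)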

- Inductive bias: Multi-head attention acts like an ensemble. Each head learns a different perspective. This enforces modularity and improves generalization.

- Empirical evidence: Multi-head models like Transformers perform significantly better than single-head counterparts on benchmarks (e.g., BLEU for translation, GLUE for NLP tasks).

In [82]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head scaled dot-product self-attention.

    The ``d_out``-sized query/key/value projections are split into
    ``num_heads`` chunks of ``head_dim = d_out // num_heads``. Every head
    attends independently, then the heads are concatenated and passed through
    a final linear layer.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output size across all heads.
        context_length: Side length of the pre-built causal mask, i.e. the
            largest number of tokens the module can process.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections include a bias term.

    Raises:
        AssertionError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        # Creation order (query, value, key, out_proj) determines which random
        # numbers each layer receives under a fixed seed, so it is kept as is.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Compute causally masked multi-head context vectors.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)`` with
                ``num_tokens <= context_length``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        b, num_tokens, _ = x.shape
        keys = self.W_key(x)
        values = self.W_value(x)
        queries = self.W_query(x)

        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        #                        -> (b, num_heads, num_tokens, head_dim)
        # so the matmuls below run once per (batch, head) pair in parallel.
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)
        # Slice the mask to the actual sequence length; using the full
        # (context_length, context_length) mask fails to broadcast whenever the
        # input is shorter than context_length.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Back to (b, num_tokens, num_heads, head_dim); contiguous() is needed
        # before view() because transpose() returns a non-contiguous tensor.
        context_vector = (attn_weights @ values).transpose(1, 2)
        context_vector = context_vector.contiguous().view(b, num_tokens, self.d_out)
        context_vector = self.out_proj(context_vector)

        return context_vector

        

        



In [85]:
# Seed first: the nn.Linear layers inside MultiHeadAttention are initialised
# randomly at construction time.
torch.manual_seed(123)
# Batch of two identical sequences.
batch = torch.stack((inputs, inputs), dim=0)
# Note: this rebinds the notebook-level d_in and context_length.
batch_size, context_length, d_in = batch.shape
# Rebinds d_out; with num_heads=2 below, each head gets d_out // num_heads = 1 dim.
d_out = 2

# Dropout probability 0 keeps the forward pass deterministic.
mha = MultiHeadAttention(d_in, d_out, context_length, 0, num_heads=2)
cv = mha(batch)
print(cv)


tensor([[[0.3510, 0.4315],
         [0.3571, 0.3429],
         [0.3587, 0.3164],
         [0.3373, 0.3482],
         [0.3329, 0.3713],
         [0.3238, 0.3734]],

        [[0.3510, 0.4315],
         [0.3571, 0.3429],
         [0.3587, 0.3164],
         [0.3373, 0.3482],
         [0.3329, 0.3713],
         [0.3238, 0.3734]]], grad_fn=<ViewBackward0>)


#### The first n-2 dims define how many independent matrix multiplications happen in parallel.


Imagine you have a batch of data or multiple sequences you want to process simultaneously:

Those leading dims act as batch dimensions (in the multi-head attention code above they are the batch and head dimensions).

PyTorch treats each slice along those dims as an independent matrix.

It runs the matrix multiplication on all these slices in parallel efficiently on GPU/CPU.

#### Multi-Head Attention Example

Input Sentence: "I go to the bank to deposit money"

Each attention head processes the same input, but focuses on different relationships due to different learned projections (W_q, W_k, W_v).

##### Head 1: Semantic Disambiguation
- Focus: "bank" ↔ "deposit", "money"
- Learns that "bank" is a financial institution, not a river bank.

##### Head 2: Subject Identification
- Focus: "I" ↔ "go", "to"
- Captures who is acting — the grammatical subject.

##### Head 3: Grammatical Structure
- Focus: "to" ↔ "go", "deposit"
- Models motion or action phrases — verb-preposition-object relationships.


Multi-head attention enables different heads to learn specialized roles by attending in different representational subspaces, improving model expressiveness.