#  Part 1: Creating a Simplified Attention Mechanism
---

In [None]:
import torch
# Toy 3-dimensional embeddings: one row per token of the sentence
# "Your journey starts with one step".
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your
    [0.55, 0.87, 0.66],  # journey
    [0.57, 0.85, 0.64],  # starts
    [0.22, 0.58, 0.33],  # with
    [0.77, 0.25, 0.10],  # one
    [0.05, 0.80, 0.55],  # step
])

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">PLOT THE SIMILARITY BETWEEN WORDS USING WORD EMBEDDINGS</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Visualizing high-dimensional vector relationships in 3D space to observe how the model clusters semantically related tokens, providing a clear window into the model's internal understanding of language.)
    </p>
</div>

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# One label per row of `inputs`, in the same order as the embedding matrix.
words = ['Your', 'journey', 'starts', 'with', 'one', 'step']

# Split the embedding matrix into its three coordinate columns for plotting.
x_coords = inputs[:, 0].numpy()
y_coords = inputs[:, 1].numpy()
z_coords = inputs[:, 2].numpy()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Scatter each token separately so every point gets its own colour, then
# annotate it with the word it represents.
for x, y, z, word in zip(x_coords, y_coords, z_coords, words):
    ax.scatter(x, y, z)
    ax.text(x, y, z, word, fontsize=10)

ax.set_xlabel('X')
ax.set_ylabel('Y')
# Fixed: this used to call set_xlabel('Z'), which overwrote the X label and
# left the Z axis unlabeled.
ax.set_zlabel('Z')

plt.title("A tale of 3D word embeddings")
plt.show()

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">CODE A SIMPLE DOT PRODUCT BASED ATTENTION SCORE</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Implementing the fundamental mathematical operation where the dot product between query and key vectors measures their alignment, determining how much focus one token should place on another.)
    </p>
</div>

In [None]:

# The second token acts as the query; its raw attention score against every
# token is the dot product of the two embedding vectors.
query = inputs[1]
print("Inputs Shape :", inputs.shape)

# Uninitialised buffer with one slot per input token; every slot is
# overwritten in the loop below.
attn_scores_2 = torch.empty(inputs.size(0))
print("Init Attention Scores: ", attn_scores_2)

for idx, token_vec in enumerate(inputs):
    print("The I : ", idx, "The XI : ", token_vec)
    attn_scores_2[idx] = torch.dot(token_vec, query)

print("Attention Scores : ", attn_scores_2)



<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">NORMALIZE ATTENTION SCORE TO REPRESENT IT AS A PERCENTAGE OF WEIGHTS THAT SUM UP TO ONE</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
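        (Dividing each raw similarity score by the sum of all scores converts them into weights that sum to one, so the influence of every input token is expressed relative to the others. The Softmax transformation introduced in the next step is the preferred way to obtain such a probability distribution.)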
    </p>
</div>

In [None]:
# Simplest possible normalisation: divide every score by the total so the
# resulting weights add up to one.
total_sum_scores = torch.sum(attn_scores_2)
print("total sum : ", total_sum_scores)

normalized_attention_scores = torch.div(attn_scores_2, total_sum_scores)
print("Normalized attention scores : ", normalized_attention_scores)
print("Sum of normalized attention scores : ", normalized_attention_scores.sum())

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">USING SOFTMAX NORMALIZATION FOR A DIFFERENTIABLE GRADIENT</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Leveraging the Softmax function not just for probability distribution, but to maintain a smooth, differentiable surface that allows backpropagation to effectively update weights across the entire network.)
    </p>
</div>

In [None]:
def softmax_naive(x):
    """Textbook softmax along dim 0: exp(x_i) / sum_j exp(x_j).

    No max-subtraction is performed, so very large inputs can overflow;
    this is the deliberately simple reference version.
    """
    exponentials = torch.exp(x)
    normaliser = exponentials.sum(dim=0)
    return exponentials / normaliser

# Turn the raw scores of the second token into weights with the hand-written
# softmax and print their sum.
attn_scores_2_naive = softmax_naive(attn_scores_2)
print("Softmax normalized attention weights : ", attn_scores_2_naive)
print("Sum: ", attn_scores_2_naive.sum())

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 8px; background-color: #fafafa;">
    <p style="margin: 0; font-weight: bold; color: #333;">USING PYTORCH'S IMPLEMENTATION OF SOFTMAX:</p>
    <p style="margin: 10px 0 0 0; font-family: monospace; font-size: 1.1em; color:black">
        exp(xi) / Σ exp(xj) 
    </p>
    <p style="margin: 5px 0 0 0; font-size: 0.9em; color: #666;">(SAFE FOR VERY LARGE OR SMALL VALUES)</p>
</div>

In [None]:
# Same normalisation using PyTorch's built-in softmax; dim=0 is the only axis
# of the 1-D score vector.
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention Weights : ", attn_weights_2)
print("Sum: ", attn_weights_2.sum())

<div style="padding: 15px; border: 1px solid #d1d1d1; border-radius: 4px; background-color: #fcfcfc;">
    <span style="text-transform: uppercase; font-weight: bold; color: #444; letter-spacing: 0.5px;">
        The context vector is calculated as a weighted sum of all input vectors
    </span>
</div>

In [None]:
# Context vector for the second token: accumulate every input embedding scaled
# by its attention weight.
query = inputs[1]
print(query.shape)

context_vec_2 = torch.zeros_like(query)
for weight, token_vec in zip(attn_weights_2, inputs):
    weighted = weight * token_vec
    context_vec_2 += weighted
    print("ATW* X_I: \n", weighted)
    print("CV in loop : \n", context_vec_2)

print("Context Vector : \n", context_vec_2)

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">CAN CREATE A TENSOR REPRESENTING AN ATTENTION SCORE BETWEEN EACH PAIR OF INPUTS</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Since for loops are generally slow, same results can be achieved using matrix multiplication)
    </p>
</div>

In [None]:
# inputs @ inputs.T takes the dot product between every pair of rows, so
# entry [i, j] is the raw score of token i against token j.
print("Inputs : \n ", inputs)
print("Inputs Tranposed : \n ", inputs.T)
attn_scores = inputs @ inputs.T
print("Attention Scores :\n ", attn_scores)
print("Attention Scores Shapes : ", attn_scores.shape)

In [None]:
# dim=-1 applies softmax along the last axis, so each row of the score matrix
# becomes a set of weights that sums to one.
# NOTE: this rebinds the name used earlier for the 1-D sum-normalised vector.
normalized_attention_scores = torch.softmax(attn_scores,dim=-1)
print("Normalized attention scores : \n ", normalized_attention_scores)

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">COMPUTE ALL CONTEXT VECTORS</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Synthesizing the final context vectors by calculating the weighted sum of the value vectors. This step aggregates information from the entire sequence based on the normalized attention scores, producing the enriched representation for each token.)
    </p>
</div>

In [None]:
# Row i of the product is the sum of all input rows weighted by row i of the
# attention-weight matrix, i.e. the context vector of token i.
all_context_vectors = normalized_attention_scores @ inputs
print("All context vectors : \n", all_context_vectors)

#  Part 2: Creating a Self Attention Mechanism With Trainable Weights (SCALED)
---

In [None]:
import torch
# Same toy embeddings as in Part 1 (one 3-dimensional row per token),
# redefined here so this part can be run on its own.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your
    [0.55, 0.87, 0.66],  # journey
    [0.57, 0.85, 0.64],  # starts
    [0.22, 0.58, 0.33],  # with
    [0.77, 0.25, 0.10],  # one
    [0.05, 0.80, 0.55],  # step
])

In [None]:
# Second token's embedding (row index 1 of `inputs`).
x_2 = inputs[1]
# Input embedding size: the number of columns of `inputs`.
d_in = inputs.shape[1]
# Output size used for the query/key/value projections below.
d_out = 2

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">INITIALIZE THE WEIGHT MATRICES OF QUERY, KEY AND VALUE</p>
</div>

In [None]:
# Fixed seed so the random projection matrices are reproducible.
torch.manual_seed(123)
# Three (d_in, d_out) projection matrices. requires_grad=False because nothing
# is trained in this walkthrough; the values are only used for forward math.
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
print("The Weights Query : \n", W_query)
print("The Weights Key : \n", W_key)
print("The Weights Value : \n", W_value)

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">COMPUTE QUERY, KEY AND VALUE VECTORS</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Applying the weight matrices to the input embeddings to generate the specific Q, K, and V representations)
    </p>
</div>

In [None]:
# `inputs` is the full token matrix, so each product below holds one projected
# row per token (despite the `_2` suffix, these are not single-token vectors).
query_2 = inputs @ W_query
key_2 = inputs @ W_key
value_2 = inputs @ W_value
print("Query : \n", query_2)

In [None]:
# Project every token into key, query and value space in one matrix product
# each; row i of each result belongs to token i.
keys = inputs @ W_key
query = inputs @ W_query
value =  inputs @ W_value

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">WE CAN MEASURE HOW THE QUERY MATRIX ATTENDS TO THE KEY MATRIX BY TAKING A DOT PRODUCT, I.E. MULTIPLYING THE QUERIES BY THE TRANSPOSED KEY MATRIX</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Calculating the raw attention scores by measuring the alignment between queries and the transposed key matrix)
    </p>
</div>

In [None]:
# key_2 is rebound to the key vector of the second token only.
key_2 = keys[1]
# query_2 is still the full per-token query matrix, so the two printed shapes
# differ in rank.
print("Shaopes :", key_2.shape , query_2.shape)
# Scores of the second token's query (row 1 of query_2) against every key.
attn_scores_22 = query_2[1] @ keys.T
print("Attention Scores : ",attn_scores_22)

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">GET ALL ATTENTION SCORES VIA MATRIX MULTIPLICATIONS</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Performing a single optimized matrix operation to compute the pairwise compatibility scores for the entire sequence)
    </p>
</div>

In [None]:
# Entry [i, j] is the dot product of token i's query with token j's key.
attn_scores = query @ keys.T
print("All attention scores : \n", attn_scores)

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">COMPUTE ATTENTION WEIGHTS BY SCALING THE ATTENTION SCORES AND APPLYING THE SOFTMAX FUNCTION. WE SCALE THE ATTENTION SCORES BY DIVIDING THEM BY THE SQUARE ROOT OF THE KEY EMBEDDING DIMENSION (I.E. √2 HERE), WHICH IS THE SAME AS DIVIDING BY THE DIMENSION RAISED TO THE POWER 0.5 [TO AVOID VANISHING GRADIENTS, OVERCONFIDENCE AND LEARNING INSTABILITY, AND TO REDUCE THE VARIANCE OF THE DOT PRODUCTS, WHICH GROWS LINEARLY WITH THE DIMENSION OF THE MATRICES]</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
         Without this scaling factor ($\sqrt{d_{out}}$), the dot products grow large in magnitude, pushing the softmax function into regions where gradients are extremely small, effectively "killing" the ability of the model to learn during backpropagation.
    </p>
</div>

In [None]:
# d_k is the size of each key vector (last axis of `keys`).
d_k = keys.shape[-1]
print("Embedding dimensions: ", d_k)
# Divide the scores by sqrt(d_k) before the row-wise softmax (dim=-1).
attn_weights = torch.softmax(attn_scores / d_k**0.5 , dim=-1)
print("Scaled Attentions with sqrt : \n ", attn_weights)

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">CALCULATING THE CONTEXT VECTOR BY MATRIX MULTIPLICATION OF THE WEIGHTING FACTOR (ATTN_WEIGHTS) AND THE VALUE VECTORS</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (This operation produces the final output of the attention mechanism by linearly combining the values based on their calculated importance scores.)
    </p>
</div>

In [None]:
# Row i is the sum of all value rows weighted by token i's attention weights.
context_vec = attn_weights @ value
print("Context vectors : \n", context_vec)

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">ORGANIZE THE SELF ATTENTION MECHANISM AS A PYTHON CLASS</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Encapsulating the logic into a reusable PyTorch module that maintains the learnable weight matrices and handles the forward pass efficiently.)
    </p>
</div>

In [None]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention with raw parameter matrices.

    Args:
        d_in: Size of each input embedding.
        d_out: Size of the projected query/key/value vectors.

    The three projection matrices are initialised with ``torch.rand`` and
    registered as trainable ``nn.Parameter`` objects.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return one context vector per token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with leading batch
                dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        query = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        # Swap only the last two axes. `.T` is valid solely for 2-D tensors and
        # breaks on batched input, whereas transpose(-2, -1) gives the same
        # result for 2-D input and also works with batch dimensions.
        attn_scores = query @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the row-wise softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [None]:
# Seed before construction: the module draws its weights with torch.rand.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print("Context vectors : \n ", sa_v1(inputs))

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">IMPROVE THE SELF ATTENTION CLASS BY USING TORCH'S LINEAR LAYERS: NN.LINEAR PERFORMS A MATRIX MULTIPLICATION WHEN BIAS UNITS ARE DISABLED, REMOVES THE NEED TO MANUALLY CREATE NN.PARAMETER OBJECTS, AND ALSO HAS AN OPTIMIZED WEIGHT INITIALIZATION SCHEME</p>
</div>

In [None]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Args:
        d_in: Size of each input embedding.
        d_out: Size of the projected query/key/value vectors.
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return one context vector per token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with leading batch
                dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        query = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Swap only the last two axes. `.T` is valid solely for 2-D tensors and
        # breaks on batched input, whereas transpose(-2, -1) gives the same
        # result for 2-D input and also works with batch dimensions.
        attn_scores = query @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the row-wise softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [None]:
# Seed before construction so the nn.Linear weights are reproducible.
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print("Context vectors : \n ", sa_v2(inputs))

#  Part 3: Creating a Causal/Masked Self Attention Mechanism
---

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">FIRSTLY WE GET THE ATTENTION WEIGHTS FROM THE ATTENTION SCORES, THEN WE HIDE FUTURE OR SUBSEQUENT WORDS</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Implementing causal masking to ensure that the model can only attend to previous and current positions, preventing information leakage from future tokens during training.)
    </p>
</div>

In [None]:
# Reuse the projection layers of sa_v2 to recompute the unmasked attention
# weights step by step (scores, then scaled row-wise softmax).
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print("Attention Weights : \n", attn_weights)

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">CREATE A MASK, WHICH IS A LOWER-TRIANGULAR MATRIX OF ONES, THAT WE CAN MULTIPLY THE ATTENTION WEIGHTS BY</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Generating a lower-triangular matrix of ones that, when element-wise multiplied with the attention weights, zero-out the connections to future tokens.)
    </p>
</div>

In [None]:
# Number of tokens = number of rows of the square score matrix.
context_length = attn_scores.shape[0]
# torch.tril keeps the diagonal and everything below it: ones where a token
# may attend (itself and earlier tokens), zeros for later positions.
simple_mask = torch.tril(torch.ones(context_length, context_length))
print("The mask : \n", simple_mask)

In [None]:
# Element-wise product zeroes every weight above the diagonal; the rows no
# longer sum to one after this step.
masked_weights = attn_weights * simple_mask
print("Masked weights : \n", masked_weights)

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">NOW WE GOTTA NORMALIZE EACH ROW TO SUM UP TO ONE BECAUSE ITS AN ATTENTION WEIGHT SO A SIMPLE NORMALIZATION WILL DO</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Dividing each masked attention weight by the sum of its row to ensure the total probability distribution across all valid previous tokens equals 1.)
    </p>
</div>

In [None]:
# Renormalise each row of the masked weights so it sums to one again.
# keepdim=True keeps a trailing axis of size 1 so the division broadcasts
# row-wise.
sum_of_rows = masked_weights.sum(dim=-1, keepdim=True)
masked_weights = masked_weights / sum_of_rows
# Sanity check on the actual data (replaces a hard-coded row pasted from an
# earlier run): every row should now sum to 1.
print("Row sums : ", masked_weights.sum(dim=-1))
print("Masked Weights \n :", masked_weights )

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">MASKING AFTER SOFTMAX NORMALIZATION LEAVES WEIGHTS WHOSE DENOMINATOR STILL INCLUDED THE SUBSEQUENT TOKENS, SO AN EXTRA RENORMALIZATION IS NEEDED TO REMOVE THEIR INFLUENCE. A MORE DIRECT WAY IS TO APPLY AN UPPER-TRIANGULAR MASK OF NEGATIVE INFINITY TO THE SCORES: WHEN EXPONENTIATED THESE ENTRIES GIVE ZERO, SO AFTER SOFTMAX NORMALIZATION THE MASKED VALUES DO NOT AFFECT THE OTHERS</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (By filling the upper triangle of the raw scores with $-\infty$ before the softmax, the exponentiation $e^{-\infty}$ results in 0, effectively removing future tokens from the probability distribution calculation entirely.)
    </p>
</div>

In [None]:
# diagonal=1 selects the entries strictly above the main diagonal, i.e. the
# positions of later tokens.
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
# Overwrite those positions in the raw scores with -inf.
masked_scores = attn_scores.masked_fill(mask.bool(), -torch.inf)
print("Masked Score s: \n", masked_scores)

In [None]:
# Row-wise softmax of the scaled, masked scores (dim=1 is the last axis of
# this 2-D matrix); the -inf entries become exactly zero.
masked_weights =  torch.softmax(masked_scores / keys.shape[-1]**0.5, dim=1)
print("Attention Weights Masked : \n", masked_weights)

In [None]:
# Causal context vectors: masked weights applied to sa_v2's value projections.
context_vec = masked_weights @ sa_v2.W_value(inputs)

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">DROPOUT IN CAUSAL ATTENTION HELPS PREVENT OVERFITTING. SINCE ATTENTION HAS LOTS OF PARAMETERS, IT CAN LOCK ONTO TRAINING DATA TOO CLOSELY. BY DROPPING SOME ATTENTION WEIGHTS RANDOMLY DURING TRAINING, THE MODEL LEARNS MORE DIVERSE CONNECTIONS AND DEVELOPS ROBUST, GENERALIZABLE PATTERNS, WHICH IS IMPORTANT FOR COHERENT FUTURE TEXT. TO COMPENSATE, THE REMAINING ELEMENTS ARE SCALED BY A FACTOR OF 1/(1-P), WHICH IS 2 IN THIS CASE (P = 0.5); THE DROPPED POSITIONS ARE CHOSEN BY SAMPLING FROM A BERNOULLI DISTRIBUTION.</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Applying a Bernoulli mask to the attention weights to ensure that no single token relationship dominates the learning process, maintaining expected values through inverse-probability scaling.)
    </p>
</div>

In [None]:
# Seed so the randomly dropped positions are reproducible.
torch.manual_seed(123)
# Dropout with probability 0.5.
dropout = torch.nn.Dropout(0.5)
# Despite its name, `dropout_mask` holds the attention weights after dropout
# has been applied to them, not a 0/1 mask.
dropout_mask = dropout(masked_weights)
print("Dropout Mask : \n", dropout_mask)

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">NOW CODING THE CAUSAL ATTENTION MECHANISM AS A CLASS, ALSO MAKING IT SO THAT IT CAN HANDLE BATCHES CONSISTING OF MORE THAN ONE INPUT. THIS ENSURES THAT THE CLASS SUPPORTS BATCH OUTPUTS FROM THE GPTDATASETV1 DATA LOADER AND DATASET CLASSES IMPLEMENTED IN THE OTHER NOTEBOOK</p>
</div>

In [None]:
# Simulate a batch of two identical sequences by stacking `inputs` twice along
# a new leading axis.
batch = torch.stack((inputs, inputs), dim=0)
print("Simulated Batc Shape: \n", batch.shape)
print("Batch : \n", batch)

In [None]:
class CasualAttention(nn.Module):
    """Single-head causal (masked) self-attention over batched input.

    Args:
        d_in: Size of each input embedding.
        d_out: Size of the projected query/key/value vectors.
        context_length: Side length of the pre-built causal mask.
        dropout_rate: Dropout probability applied to the attention weights.
        qkv_bias: Whether the projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout_rate, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout_rate)
        # Ones strictly above the diagonal mark the "future" positions.
        # Registered as a buffer so it follows the module across devices.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer('mask', future_positions)

    def forward(self, x):
        """Map ``(batch, num_tokens, d_in)`` to ``(batch, num_tokens, d_out)``."""
        # Unpacking enforces a 3-D input, exactly as a three-way unpack would.
        _, seq_len, _ = x.shape

        q = self.W_query(x)
        k = self.W_key(x)
        v = self.W_value(x)

        # Crop the stored mask to the actual sequence length.
        hide_future = self.mask[:seq_len, :seq_len].bool()
        scores = torch.matmul(q, k.transpose(1, 2))
        scores = scores.masked_fill(hide_future, -torch.inf)

        scale = k.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))
        return weights @ v

In [None]:
# Seed before construction so the nn.Linear weights are reproducible.
torch.manual_seed(123)
# Sequence length of the simulated batch (axis 1).
context_length = batch.shape[1]
# Dropout rate 0.0: no attention weights are dropped in this demo.
ca = CasualAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("Context Vectors : \n", context_vecs)

#  Part 4: Creating a Multi-Head Self Attention Mechanism
---

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">THIS MECHANISM CREATES MULTIPLE INSTANCES OF THE SELF-ATTENTION MECHANISM, EACH WITH ITS OWN WEIGHT MATRICES AND OUTPUTS, THEN CONCATENATES THE CONTEXT VECTORS FROM THE ATTENTION HEADS INTO ONE CONTEXT VECTOR ACROSS THE COLUMNS</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Implementing Multi-Head Attention to allow the model to jointly attend to information from different representation subspaces at different positions.)
    </p>
</div>

In [None]:
class MultiHeadAttenitionWrapper(nn.Module):
    """Run several independent ``CasualAttention`` heads and concatenate them.

    Every head receives the same constructor arguments; the per-head outputs
    are joined along the last axis, so the result has ``num_heads * d_out``
    features per token.
    """

    def __init__(self, d_in, d_out, context_length, dropout_rate, num_heads, qkv_bias=False):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(
                CasualAttention(d_in, d_out, context_length, dropout_rate, qkv_bias)
            )
        # ModuleList registers each head's parameters with this module.
        self.heads = nn.ModuleList(head_modules)

    def forward(self, x):
        """Apply every head to ``x`` and concatenate the results feature-wise."""
        per_head_context = [head(x) for head in self.heads]
        return torch.cat(per_head_context, dim=-1)

In [None]:
# Seed before construction so the heads' weights are reproducible.
torch.manual_seed(123)
context_length = batch.shape[1]
d_in, d_out = 3, 2
# Two heads of size d_out=2 each are concatenated on the last axis, so the
# output carries 4 features per token.
mha = MultiHeadAttenitionWrapper(d_in,d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print("Context Vectors shape : ", context_vecs.shape)
print("Context Vectors : \n", context_vecs)

#  Part 5: Implementing a Multi-Head Self Attention Mechanism With Split Weights
---

<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
    <p style="margin: 0; font-weight: bold; color: #222;">INSTEAD OF HAVING A WRAPPER AND A CAUSAL ATTENTION CLASS, WE CAN COMBINE THEM INTO A SINGLE CLASS WITH MORE ATTENTION TO EFFICIENCY</p>
    <hr style="margin: 10px 0; border: 0; border-top: 1px solid #ddd;">
    <p style="margin: 0; font-size: 0.95em; color: #555; font-style: italic;">
        (Rather than instantiating multiple independent heads, we use a single large weight matrix and split it into multiple subspaces using tensor reshaping. This "split weights" approach allows for highly optimized, parallelized matrix multiplications across all heads simultaneously.)
    </p>
</div>

In [None]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention using one shared projection per role.

    Queries, keys and values are each produced by a single ``d_in -> d_out``
    linear layer; the ``d_out`` features are then viewed as ``num_heads``
    chunks of ``head_dim`` so all heads are processed in one batched matmul.

    Args:
        d_in: Size of each input embedding.
        d_out: Total output size across all heads (must divide by num_heads).
        context_length: Side length of the pre-built causal mask.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Input projections for the three attention roles.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Mixes the concatenated head outputs.
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the positions of later tokens.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer("mask", future_positions)

    def _split_heads(self, projected, b, num_tokens):
        """Reshape (b, T, d_out) into per-head layout (b, num_heads, T, head_dim)."""
        per_head = projected.view(b, num_tokens, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Map ``(b, num_tokens, d_in)`` to ``(b, num_tokens, d_out)``."""
        b, num_tokens, _ = x.shape

        # Project, then split the feature axis into heads.
        keys = self._split_heads(self.W_key(x), b, num_tokens)
        queries = self._split_heads(self.W_query(x), b, num_tokens)
        values = self._split_heads(self.W_value(x), b, num_tokens)

        # Per-head scores: (b, nh, T, hd) @ (b, nh, hd, T) -> (b, nh, T, T).
        scores = torch.matmul(queries, keys.transpose(2, 3))

        # Hide later tokens; the (T, T) mask broadcasts over batch and heads.
        hide_future = self.mask[:num_tokens, :num_tokens].bool()
        scores = scores.masked_fill(hide_future, -torch.inf)

        # Scale by sqrt(head_dim), normalise each row, then apply dropout.
        weights = torch.softmax(scores / self.head_dim**0.5, dim=-1)
        weights = self.dropout(weights)

        # (b, nh, T, hd) -> (b, T, nh, hd) -> (b, T, d_out): heads side by side.
        per_head_context = torch.matmul(weights, values).transpose(1, 2)
        merged = per_head_context.contiguous().view(b, num_tokens, self.d_out)

        return self.out_proj(merged)

In [None]:
# Seed before construction so the layer weights are reproducible.
torch.manual_seed(123)
context_length = batch.shape[1]
d_in, d_out = 3, 2
# Here d_out is the total across heads: with num_heads=2 each head works on
# d_out // num_heads = 1 feature, and the output keeps d_out=2 features.
# NOTE: this rebinds `mha`, previously the wrapper-based model.
mha = MultiHeadAttention(d_in,d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print("Context Vectors shape : ", context_vecs.shape)
print("Context Vectors : \n", context_vecs)