In [None]:
import torch 
import torch.nn as nn

In [2]:
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu



## `STEP 1:`

#### Matrix Calculation of Self-Attention
- The first step is to calculate the Query, Key, and Value matrices. We pack the token embeddings into a matrix X (one row per token) and multiply it by the trained weight matrices WQ, WK, and WV.
![alt text](<IMG/Matrix Calculation of Self-Attention.jpg>)
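
A minimal sketch of this step, assuming a toy batch of 2 sentences with 4 tokens each and an embedding size of 8. The sizes and the names `W_q`, `W_k`, `W_v` are illustrative placeholders; bias-free `nn.Linear` layers stand in for the weight matrices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_size = 8                      # illustrative size, chosen only for this sketch
X = torch.randn(2, 4, embed_size)   # (batch, tokens, embed_size): the packed embeddings

# Bias-free linear layers act as the weight matrices WQ, WK, WV
W_q = nn.Linear(embed_size, embed_size, bias=False)
W_k = nn.Linear(embed_size, embed_size, bias=False)
W_v = nn.Linear(embed_size, embed_size, bias=False)

Q, K, V = W_q(X), W_k(X), W_v(X)
print(Q.shape, K.shape, V.shape)    # each is (2, 4, 8)
```

Note that `nn.Linear` stores its weight transposed, so `W_q(X)` equals `X @ W_q.weight.T`.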

## `STEP 2:`
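Since we’re dealing with matrices, the remaining steps (scoring each query against every key, scaling by the square root of the key dimension, applying softmax, and taking the weighted sum of the values) condense into a single formula for the output of the self-attention layer.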
![alt text](IMG/attention_score.jpg)
![alt text](IMG/self_attn.jpg)
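
The formula can be sketched for a single head as follows. This is a minimal illustration with random inputs; the function name `scaled_dot_product_attention` and the tensor sizes are assumptions made for the example.

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d_k)) V for inputs shaped (batch, len, d_k)."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (batch, query_len, key_len)
    weights = torch.softmax(scores, dim=-1)             # each query row sums to 1
    return weights @ V, weights

torch.manual_seed(0)
Q = torch.randn(2, 4, 8)
K = torch.randn(2, 4, 8)
V = torch.randn(2, 4, 8)
Z, weights = scaled_dot_product_attention(Q, K, V)
print(Z.shape)   # (2, 4, 8)
```

The softmax runs over the key dimension, so every output row of `Z` is a weighted average of the rows of `V`.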

##### MULTI HEAD ATTENTION 
The paper further refined the self-attention layer by adding a mechanism called “multi-headed” attention. This improves the performance of the attention layer in two ways:

- It expands the model’s ability to focus on different positions. If we’re translating a sentence like “The animal didn’t cross the street because it was too tired”, it would be useful to know which word “it” refers to.

- It gives the attention layer multiple “representation subspaces”. As we’ll see next, with multi-headed attention we have not only one, but multiple sets of Query/Key/Value weight matrices (the Transformer uses eight attention heads, so we end up with eight sets for each encoder/decoder). Each of these sets is randomly initialized. Then, after training, each set is used to project the input embeddings (or vectors from lower encoders/decoders) into a different representation subspace.

![alt text](IMG/Multiheadattention.jpg)

If we do the same self-attention calculation we outlined above, just eight different times with different weight matrices, we end up with eight different Z matrices.

![alt text](<IMG/Screenshot 2025-03-03 103818.jpg>)

This leaves us with a bit of a challenge. The feed-forward layer is not expecting eight matrices; it expects a single matrix (one vector per word). So we need a way to condense these eight down into a single matrix.

## `STEP 3:`
How do we do that? We concatenate the matrices and then multiply the result by an additional weight matrix WO.

![alt text](IMG/Multihead_attention.jpg)

![alt text](<IMG/Screenshot 2025-03-03 103959.jpg>)
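
A minimal sketch of the concatenate-and-project step, assuming eight heads of dimension 64 (so an embedding size of 512). The per-head outputs are random stand-ins, and `W_o` is a placeholder name for the output projection.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

heads, head_dim = 8, 64            # 8 heads of size 64, so embed_size = 512
embed_size = heads * head_dim

# Eight per-head outputs, each (batch, tokens, head_dim); random stand-ins here
Z_heads = [torch.randn(2, 4, head_dim) for _ in range(heads)]

Z_concat = torch.cat(Z_heads, dim=-1)                 # (2, 4, 512)
W_o = nn.Linear(embed_size, embed_size, bias=False)   # the output projection WO
Z = W_o(Z_concat)                                     # (2, 4, 512): one vector per token
print(Z_concat.shape, Z.shape)
```

Concatenating along the last dimension places head 0 in columns 0 to 63, head 1 in columns 64 to 127, and so on, which is the layout a reshape from `(N, len, heads, head_dim)` to `(N, len, embed_size)` also produces.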


In [None]:
class SelfAttention(nn.Module):
    ## embed_size is d_model, the embedding dimension
    ## heads is the number of attention heads used in each Transformer layer
    def __init__(self,embed_size,heads):
        super(SelfAttention,self).__init__()
        self.embed_size=embed_size
        self.heads=heads
        ## The paper uses heads = 8 parallel attention heads, each of dimension
        ## d_model/heads = 64. Here d_model is embed_size
        self.head_dim=embed_size//heads

        assert (self.head_dim*heads==embed_size), "Embedding size needs to be divisible by number of heads"

        self.values=nn.Linear(embed_size,embed_size,bias=False) ## computes xW, where W is a trainable weight matrix
        self.keys=nn.Linear(embed_size,embed_size,bias=False)   
        self.queries=nn.Linear(embed_size,embed_size,bias=False)
        self.fc_out=nn.Linear(embed_size,embed_size,bias=False) ## W_o, applied to Concat(head_outputs)
    def forward(self,values,keys,query,mask=None):
        ## query shape: (N,query_len,embed_size)
        N=query.shape[0] ## Batch size (number of sequences)
        ## Number of tokens in values, keys, and query
        values_len,key_len,query_len=values.shape[1],keys.shape[1],query.shape[1]
        ## query, keys, values have shape (N,len,embed_size); we want (N,len,heads,head_dim),
        ## which is possible because embed_size=heads*head_dim
        values=self.values(values) ## (N,value_len,embed_size)
        keys=self.keys(keys) ## (N,key_len,embed_size)
        queries=self.queries(query) ## (N,query_len,embed_size)

        ## Split the embedding into self.heads pieces
        values=values.reshape(N,values_len,self.heads,self.head_dim) ## VALUE SHAPE: (N,values_len,head,head_dim)
        keys=keys.reshape(N,key_len,self.heads,self.head_dim) ## KEYS SHAPE:(N,key_len,heads,head_dim)
        queries=queries.reshape(N,query_len,self.heads,self.head_dim) ## QUERY SHAPE: (N,query_len,head,head_dim)
        ## attn_scores shape: (N,heads,query_len,key_len) --> nhqk
        attn_scores=torch.einsum("nqhd,nkhd->nhqk",[queries,keys])

        ## Fill positions where mask==0 (e.g. padding) with a large negative value so their softmax weight becomes ~0
        if mask is not None:
            attn_scores=attn_scores.masked_fill(mask==0,float("-1e20"))
        ## Scale and Normalize
        attention=torch.softmax(attn_scores/(self.head_dim**0.5),dim=3)

        ## Attention Shape: (N,heads,query_len,key_len)
        ## Values Shape: (N,value_len,heads,head_dim)
        ## out after multiply :(N,query_len,heads,head_dim)
        out =torch.einsum("nhql,nlhd->nqhd",[attention,values])
        ## Reshape to merge the last two dimensions (heads, head_dim), which concatenates the heads
        ## Resulting out shape: (N,query_len,embed_size)
        out=out.reshape(N,query_len,self.heads*self.head_dim)
        out=self.fc_out(out) ## (N,query_len,embed_size)
        return out
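
The `mask` argument is compared against 0 and must broadcast to the score shape `(N, heads, query_len, key_len)`. Below is a minimal sketch of one way to build a padding mask with that property, assuming a hypothetical padding id `pad_idx = 0` and made-up token ids; the `forward` method above does not prescribe how the mask is built.

```python
import torch

N, heads, seq_len = 2, 8, 5
pad_idx = 0                                   # hypothetical padding token id
tokens = torch.tensor([[5, 7, 9, 0, 0],
                       [3, 4, 6, 8, 2]])      # (N, seq_len), made-up token ids

# Padding mask: True for real tokens, False for padding.
# Shape (N, 1, 1, seq_len) broadcasts over heads and query positions.
pad_mask = (tokens != pad_idx).unsqueeze(1).unsqueeze(2)

scores = torch.zeros(N, heads, seq_len, seq_len)   # stand-in attention scores
masked = scores.masked_fill(pad_mask == 0, float("-1e20"))
weights = torch.softmax(masked, dim=3)
print(weights[0, 0, 0])   # the two padded key positions receive weight 0
```

In the first sequence the three real tokens share the attention weight equally (the stand-in scores are all zero), while the second sequence, which has no padding, is unaffected by the mask.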