## Multi-Head Attention

In [64]:
import torch
import torch.nn as nn
import math

![Multi-head attention architecture](https://data-science-blog.com/wp-content/uploads/2022/01/mha_img_original.png) (alternative figure: [Papers with Code](https://production-media.paperswithcode.com/methods/multi-head-attention_l1A3G7a.png))

In [93]:
## embeddings 

class InputEmbedding(nn.Module):
    """Look up a learned vector for each token id and scale it by sqrt(d_model).

    Args:
        vocab_size: number of distinct token ids accepted.
        d_model: size of each embedding vector.

    Input is a tensor of integer token ids (intended shape: batch_size x
    sequence_length); the output appends a trailing ``d_model`` dimension.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # Multiplying by sqrt(d_model) follows the original Transformer recipe.
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)
        return vectors * scale

In [96]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    For position ``pos`` and channel-pair index ``i``::

        PE[pos, 2i]     = sin(pos / 10000 ** (2i / d_model))
        PE[pos, 2i + 1] = cos(pos / 10000 ** (2i / d_model))

    The table is precomputed for ``seq_len`` positions and registered as a
    non-trainable buffer.

    Args:
        d_model: embedding size. May be odd; the last channel is then a sine.
        seq_len: maximum sequence length ``forward`` can encode.
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_len, d_model)  # (seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # 1 / 10000 ** (2i / d_model), computed in log space for numerical stability.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        angles = position * div_term  # (seq_len, ceil(d_model / 2))

        # Even channels: there are ceil(d_model / 2) of them.
        pe[:, 0::2] = torch.sin(angles)
        # Odd channels: only floor(d_model / 2) of them. Trim the angles so an
        # odd d_model does not raise a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Add a batch dimension so the table broadcasts over the batch.
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, seq_len, d_model)

    def forward(self, x):
        """Add the encoding for the first ``x.shape[1]`` positions, then apply dropout.

        Assumes ``x`` is (batch, seq_len, d_model) with seq_len <= ``self.seq_len``.
        """
        x = x + self.pe[:, : x.shape[1], :].requires_grad_(False)  # (batch, seq_len, d_model)
        return self.dropout(x)

In Layer Normalization, each sample $i$ is normalized across all of its $K$ features $x_{i,k}$ rather than across the batch; this removes the dependency between different input sequences in the same batch.

First, we calculate the mean and variance of each sample:

\begin{gather} \mu_i = \frac{1}{K} \sum_{k=1}^{K} x_{i,k} \\ \sigma_i^2 = \frac{1}{K} \sum_{k=1}^{K} (x_{i,k} - \mu_i)^2 \\ \end{gather}


Then we normalize each sample so that its elements have zero mean and unit variance. $\epsilon$ is a small constant added for numerical stability, so the denominator cannot become zero when the variance is zero.
 
 $$\hat{x}_{i,k} = \frac{x_{i,k}-\mu_i}{\sqrt{\sigma_i^2 + \epsilon}}$$
 
Finally, there is a scaling and shifting step, where $\gamma$ and $\beta$ are learnable parameters.
 
 $$y_i = \gamma \hat{x}_{i} + \beta \equiv {\text{LN}}_{\gamma, \beta} (x_i)$$
 
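The learnable parameters $\gamma$ and $\beta$ let the model rescale and shift the normalized values. It can therefore recover whatever scale and offset suit the next layer, and in the extreme it can undo the normalization entirely.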

In [94]:
embed = InputEmbedding(vocab_size=2000, d_model=256)

In [95]:
# nn.Embedding looks up integer token ids, so passing a raw string such as 'yoo'
# fails with: TypeError: embedding(): argument 'indices' (position 2) must be
# Tensor, not str. Tokenize first and pass a LongTensor of ids in [0, 2000).
embed(torch.tensor([[1, 2, 3]])).shape

In [80]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The module works in four steps:

    1. Linearly project the query, key and value.
    2. Split each projection into ``h`` heads of size ``d_k = d_model // h``.
    3. Run scaled dot-product attention on every head independently.
    4. Concatenate the heads back to ``d_model`` and apply the output
       projection ``W_O``.

    Args:
        h: number of attention heads; must divide ``d_model``.
        d_model: model (embedding) dimension.
        dropout: default dropout probability applied to the attention weights
            when ``forward`` is not given an explicit ``dropout``. The default
            of 0.0 disables it.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``h``.
    """

    def __init__(self, h: int, d_model: int, dropout: float = 0.0) -> None:
        super().__init__()
        if d_model % h != 0:
            raise ValueError("d_model is not divisible by head")

        self.h = h
        self.d_k = d_model // h
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        # Default dropout on the attention weights. As an nn.Module it respects
        # train()/eval() mode.
        self.dropout = nn.Dropout(dropout)
        # Weights from the most recent forward() call, kept for inspection.
        self.attention_weights = None

    @staticmethod
    def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Args:
            query, key, value: tensors whose last dimension is d_k. In
                ``forward`` they have shape (batch, h, seq_len, d_k).
            mask: optional tensor broadcastable to the score shape
                (..., seq_q, seq_k). Positions where ``mask == 0`` receive
                (effectively) zero attention weight.
            dropout: optional ``nn.Module`` (e.g. ``nn.Dropout``) applied to
                the weights, or a float drop probability. A float is applied
                with ``nn.functional.dropout`` in training mode.

        Returns:
            A tuple ``(output, attention_weights)``. ``output`` has the shape
            of ``value`` with seq_q rows; ``attention_weights`` has shape
            (..., seq_q, seq_k).
        """
        d_k = query.shape[-1]
        # Scale the dot products so their variance does not grow with d_k.
        scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)

        if mask is not None:
            # Fill with the dtype's most negative value. A literal -1e9
            # overflows float16 and raises in masked_fill.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        weights = scores.softmax(dim=-1)

        if dropout is not None:
            if isinstance(dropout, nn.Module):
                weights = dropout(weights)
            else:
                weights = nn.functional.dropout(weights, p=float(dropout))

        # The weights mix the *values*, not the keys.
        return weights @ value, weights

    def _split_heads(self, x):
        """Reshape (batch, seq_len, d_model) into (batch, h, seq_len, d_k)."""
        return x.view(x.shape[0], x.shape[1], self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None, dropout=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q, k, v: (batch_size, sequence_length, d_model) inputs.
            mask: optional mask broadcastable to (batch, h, seq_q, seq_k).
                Zeros mark positions that must not be attended.
            dropout: overrides the module's default attention dropout; see
                ``scaled_dot_product_attention``.

        Returns:
            ``(output, attention_weights)``, where output is
            (batch, seq_q, d_model) and the weights are (batch, h, seq_q, seq_k).
        """
        query = self._split_heads(self.W_Q(q))
        key = self._split_heads(self.W_K(k))
        value = self._split_heads(self.W_V(v))

        output, self.attention_weights = self.scaled_dot_product_attention(
            query,
            key,
            value,
            mask=mask,
            dropout=self.dropout if dropout is None else dropout,
        )

        # Concatenate heads: (batch, h, seq, d_k) -> (batch, seq, h, d_k) ->
        # (batch, seq, d_model). contiguous() is needed because view() cannot
        # reshape the non-contiguous transposed layout.
        batch_size = output.shape[0]
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        return self.W_O(output), self.attention_weights
        


In [81]:
attention = MultiHeadAttention(h=8, d_model=256)

In [82]:
q = torch.rand((8, 10, 256))  # (batch_size, sequence_length, d_model)

In [83]:
q.size()

torch.Size([8, 10, 256])

In [85]:
# Call the module itself rather than .forward() so nn.Module hooks run.
x, y = attention(q, q, q)

In [86]:
x.size()

torch.Size([8, 10, 256])

In [87]:
y.size()

torch.Size([8, 8, 10, 10])