# Building single-head self-attention 

In [7]:
import torch
from torch import nn
import torch.nn.functional as F

## NOTE:
### Any self-attention layer (single-head or multi-head) has learnable parameters (Wq, Wk, Wv): the input vectors are multiplied by these matrices to form the queries, keys and values before attention is computed. So it is just like any other neural network module.

In [15]:
#defining single-head self-attention module

class single_head_self_attentional(nn.Module):
    """Single-head scaled dot-product self-attention for one sequence.

    Args:
        k: dimension of every input vector. Queries, keys and values are
            produced by bias-free ``k -> k`` linear layers, so the output
            vectors are ``k``-dimensional as well.
    """

    def __init__(self,k):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise the Linear layers below cannot be registered.
        super().__init__()
        self.k=k
        
        #Wq, Wk, Wv are just matrices which linearly transform the input vectors to form queries,keys,values
        #using nn.Linear layer with no bias will do the job for you
        #passing the i/p vectors through this layer will project them
        
        self.toqueries=nn.Linear(self.k,self.k,bias=False) #Wq
        self.tokeys=nn.Linear(self.k,self.k,bias=False)    #Wk
        self.tovalues=nn.Linear(self.k,self.k,bias=False)  #Wv
        
    
    def forward(self,X):
        """Apply self-attention to a single sequence.

        Args:
            X: 2-D tensor of size (t, k) -- t vectors of dimension k
                (no batch dimension; ``torch.mm`` only accepts 2-D input).

        Returns:
            Tensor of size (t, k). Row i is the attention-weighted sum of
            the value vectors, using the softmax-normalised scores of
            query i against every key.
        """
        
        Q=self.toqueries(X)
        K=self.tokeys(X)
        V=self.tovalues(X)
        
        #(t,k) x (k,t) -> (t,t); entry [i,j] is the dot product of query i with key j
        unscaled_weights=torch.mm(Q,K.transpose(0,1))
        scaled_weights=unscaled_weights/(self.k ** (1/2)) #scaling for preventing large values going into softmax which 
                                            #will lead to derivative being near 0 and thus prevent/slow down learning
            
        
        #normalise over the keys, so every row of weights sums to 1
        weights=F.softmax(scaled_weights,dim=1)
        
        #the weights are applied to the values V (not to the raw input X),
        #otherwise the Wv projection would never be used or trained
        result=torch.mm(weights,V)
        
        return result
        
        
        
        

# Building multi-head self-attention

## NOTE:
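### The number of attention heads is the number of independent self-attention operations (each with its own Wq, Wk, Wv) performed on the same i/p sequence; every head produces its own o/p vector for each position.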
### 2 options
### 1. Results are concatenated and then dimensions are reduced (higher time complexity)
### 2. I/p vectors are scaled down to a lower dimension for each head, multi-head attention is performed, and the results are concatenated with additional scaling (lower time complexity)
### Let's go with the 2nd option

In [18]:
#defining multi-head self-attention module

class multi_head_self_attention(nn.Module):
    """Multi-head scaled dot-product self-attention for one sequence.

    The k-dimensional query/key/value projections are cut into ``h``
    contiguous chunks of size ``k // h``; attention is computed separately
    on every chunk (head) and the per-head outputs are concatenated back
    into a k-dimensional vector. No further output projection is applied.

    Args:
        k: dimension of every input vector.
        h: number of attention heads; must divide ``k`` exactly.

    Raises:
        AssertionError: if ``k`` is not divisible by ``h``.
    """
    
    def __init__(self,k,h=4):  
        # nn.Module must be initialised before any sub-module is assigned
        super().__init__()
        
        #then only the dimensions can be split evenly across the heads
        #heads=4 by default as most dimensions are divisible by 4
        assert k%h==0, f"embedding dimension k={k} must be divisible by the number of heads h={h}"
            
        self.k=k
        self.heads=h
        self.s=k//h #integer division -> dimension handled by each head
        
        
        #small hack -> instead of applying several short linear layers(matrices) which are 3h in number (also, variable), we can just
        #apply a single linear layer and split the results into chunks as required (which will be very easy)
        
        self.toqueries=nn.Linear(self.k,self.k,bias=False)
        self.tokeys=nn.Linear(self.k,self.k,bias=False)
        self.tovalues=nn.Linear(self.k,self.k,bias=False)
        

    def forward(self,X):
        """Apply multi-head self-attention to a single sequence.

        Args:
            X: 2-D tensor of size (t, k) -- t vectors of dimension k
                (no batch dimension).

        Returns:
            Tensor of size (t, k): for every position, the outputs of the
            heads concatenated in head order.
        """
        
        #second dimension is the embedding size k, NOT the number of heads
        t,k=X.size()
        
        Q=self.toqueries(X)
        K=self.tokeys(X)
        V=self.tovalues(X)
        
        #breaking them into small chunks/parts 
        
        #original dimension was (t,k); head i owns features [i*s, (i+1)*s)
        queries=Q.reshape(t,self.heads,self.s)
        keys=K.reshape(t,self.heads,self.s)
        values=V.reshape(t,self.heads,self.s)
        
        #we need to do dot product between Q and K, and this is the same for all the heads
        #so we make the head dimension as the batched dimension, so that the dot product is done for all the batches,
        #thus for all the heads
        
        b_Q=queries.transpose(0,1)
        b_K=keys.transpose(0,1)
        b_V=values.transpose(0,1)
        
        #dimension of b_Q,b_K,b_V is (h,t,s)
        
        #(h,t,s) x (h,s,t) -> (h,t,t): one score matrix per head
        unscaled_weights=torch.bmm(b_Q,b_K.transpose(1,2))
        #each head works in s dimensions, so the scaling uses s rather than k
        scaled_weights=unscaled_weights/(self.s ** (1/2))
        
        #normalise over the keys (last dimension)
        weights=F.softmax(scaled_weights,dim=2)
        
        out=torch.bmm(weights,b_V)
        #having dim (h,t,s)
        
        #(h,t,s) -> (t,h,s) -> (t,h*s): concatenate the heads for every position
        result=out.transpose(0,1).reshape(t,self.heads*self.s)
        
        return result
        
        
        
        