# Transformers from scratch 

In [119]:
from typing import Optional, Tuple

import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as f
from torch import Tensor
from torch import nn
from IPython.display import Image


***Self Attention Function***

This function implements a single attention head. It takes three tensors: Q (query), K (key) and V (value), each of shape (batch_size, sequence_length, num_features), where num_features is the dimension of each token's feature vector. Q, K and V are not learnable parameters themselves. They are produced by learnable linear layers applied to the token embeddings (with positional encodings added).

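$$\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

where $d_k$ is the dimension of the query/key vectors. Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the vector size.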

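The original paper also uses masking, which can be written as adding a mask $M$ to the scaled scores before the softmax:

$$\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V$$

$M$ is $0$ where attention is allowed and $-\infty$ where it is blocked, so blocked positions get zero weight after the softmax. For causal (decoder) attention, $M$ is built from a lower-triangular matrix so that each position can only attend to itself and earlier positions.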

In [31]:
# Externally hosted figure illustrating the attention formula above
# (presumably the scaled dot-product attention diagram — confirm the source).
ATTENTION_DIAGRAM_URL = 'https://miro.medium.com/v2/resize:fit:1400/format:webp/1*BzhKcJJxv974OxWOVqUuQQ.png'
Image(url=ATTENTION_DIAGRAM_URL)

In [120]:
def self_attention(Q: Tensor, K: Tensor, V: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    """Scaled dot-product attention for a single head.

    Computes softmax(Q K^T / sqrt(d_k) + M) V with d_k = Q.size(-1).

    Args:
        Q: queries, shape (..., seq_q, d_k). Any leading batch/head dims are allowed.
        K: keys, shape (..., seq_k, d_k).
        V: values, shape (..., seq_k, d_v).
        mask: optional mask broadcastable to (..., seq_q, seq_k).
            * bool mask: True = may attend, False = blocked (same convention
              as torch.nn.functional.scaled_dot_product_attention).
            * float mask: added to the scaled scores (the "+ M" term), e.g.
              0 to keep and -inf to block.
            A query row with every position blocked yields NaN weights.

    Returns:
        (output, weights): output has shape (..., seq_q, d_v); weights are the
        softmax attention probabilities, shape (..., seq_q, seq_k), each row
        summing to 1.
    """
    # matmul + transpose(-2, -1) equals bmm + transpose(1, 2) for 3-D inputs,
    # and also accepts extra leading dimensions (e.g. a heads axis).
    scores = Q.matmul(K.transpose(-2, -1))
    scaling = Q.size(-1) ** 0.5  # sqrt(d_k)
    scores = scores / scaling
    if mask is not None:
        if mask.dtype == torch.bool:
            # Blocked positions get -inf so softmax assigns them zero weight.
            scores = scores.masked_fill(~mask, float("-inf"))
        else:
            scores = scores + mask
    softmax_weight = f.softmax(scores, dim=-1)  # probabilities over the keys
    return softmax_weight.matmul(V), softmax_weight

In [122]:
# Test the function on random inputs: batch of 30 sequences, 50 tokens each,
# 64 features per token. torch.rand gives initialised values; Tensor(30, 50, 64)
# would allocate uninitialised memory that may contain NaN/inf.
torch.manual_seed(0)
q = torch.rand(30, 50, 64)
k = torch.rand(30, 50, 64)
v = torch.rand(30, 50, 64)
result, attn_weights = self_attention(q, k, v)
# Each query's attention weights are a probability distribution over the keys.
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(30, 50))
print(result.shape)


torch.Size([30, 50, 64])


***Multi-head Attention Class***

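This takes in three tensors: Q (query), K (key) and V (value). Each attention head passes them through its own linear layers and then computes self-attention. The heads are independent of each other, so they could be computed in parallel; here they run one after another in a Python loop. The head outputs are concatenated and passed through a final linear layer. Using several heads in this way is why it is called multi-head attention.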

In [133]:
class SingleAttentionHead(nn.Module):
    """One attention head: learned linear projections of q, k, v, then self_attention.

    Args:
        model_dimension: size of the last (feature) dimension of the inputs.
        head_dimension: size that q, k and v are projected to before attention.
            Defaults to ``model_dimension`` (the original behaviour). Pass
            ``model_dimension // num_heads`` to get the per-head size used in
            "Attention Is All You Need".
    """

    def __init__(self, model_dimension, head_dimension=None):
        super().__init__()
        if head_dimension is None:
            head_dimension = model_dimension
        self.q_l = nn.Linear(model_dimension, head_dimension)
        self.k_l = nn.Linear(model_dimension, head_dimension)
        self.v_l = nn.Linear(model_dimension, head_dimension)

    def forward(self, q, k, v):
        """Project q, k and v, then apply scaled dot-product attention.

        Returns:
            (output, weights) from ``self_attention``: output has
            ``head_dimension`` features per token; weights are the softmax
            attention probabilities.
        """
        return self_attention(self.q_l(q), self.k_l(k), self.v_l(v))
    

In [135]:
# Batch of 30 sentences, at most 50 tokens (words) per sentence, each token
# represented by a 512-dimensional vector (the model dimension d_model).
batch_size, seq_len, d_model = 30, 50, 512
q = torch.rand(batch_size, seq_len, d_model)
k = torch.rand(batch_size, seq_len, d_model)
v = torch.rand(batch_size, seq_len, d_model)
model = SingleAttentionHead(d_model)

In [136]:
# Display the module's repr: the q/k/v Linear projection layers created in __init__.
model

SingleAttentionHead(
  (q_l): Linear(in_features=512, out_features=512, bias=True)
  (k_l): Linear(in_features=512, out_features=512, bias=True)
  (v_l): Linear(in_features=512, out_features=512, bias=True)
)

In [137]:
# Call the module itself rather than .forward() so nn.Module hooks are run.
attention, w = model(q, k, v)
print(f"attention: {attention.shape}")
print(f"attention_weight: {w.shape}")

attention: torch.Size([30, 50, 512])
attention_weight: torch.Size([30, 50, 50])


In [184]:
class MultiAttentionHead(nn.Module):
    """Multi-head attention built from independent SingleAttentionHead modules.

    Every head receives the same q, k, v and applies its own projections and
    self-attention. The head outputs are concatenated along the last dimension
    and mapped back to ``model_dims`` by ``linear_output``.

    NOTE(review): each head projects to the full ``model_dims`` (not
    ``model_dims // num_of_heads`` as in the paper), so the concatenation is
    ``num_of_heads * model_dims`` wide. ``self.head`` stores the paper's
    per-head size but is not used by this module.

    Args:
        num_of_heads: number of attention heads.
        model_dims: feature size of the inputs and of the output.
        verbose: if True (default, the original behaviour) ``forward`` prints
            intermediate shapes; pass False to silence it, e.g. in training loops.
    """

    def __init__(self, num_of_heads, model_dims, verbose=True):
        super().__init__()
        self.num_of_heads = num_of_heads
        self.model_dims = model_dims
        self.head = model_dims // num_of_heads  # per-head size; currently unused
        self.verbose = verbose

        self.attention_heads = nn.ModuleList(
            [SingleAttentionHead(self.model_dims) for _ in range(self.num_of_heads)]
        )
        # Maps the concatenated head outputs back to the model dimension.
        self.linear_output = nn.Linear(self.num_of_heads * self.model_dims, self.model_dims)

    def forward(self, q, k, v):
        """Run every head on (q, k, v), concatenate and project the outputs.

        Returns:
            (output, attention_weights): output is the projected concatenation
            with ``model_dims`` features; attention_weights is a list holding
            one weight tensor per head, in head order.
        """
        head_results = [attention_head(q, k, v) for attention_head in self.attention_heads]
        attention_outputs = [output for output, _ in head_results]
        attention_weights = [weights for _, weights in head_results]

        concated_output = torch.cat(attention_outputs, dim=-1)
        if self.verbose:
            print(f"cncat_ot: {concated_output.shape}")
        concated_linear_output = self.linear_output(concated_output)
        if self.verbose:
            print(f"cncat_ln_ot: {concated_linear_output.shape}")
            print(f"atten_w: {attention_weights[0].shape,len(attention_weights)}")
        return concated_linear_output, attention_weights
        

In [185]:
# Fresh random inputs for the multi-head demo, shape (batch=30, seq_len=50, d_model=512).
# Drawn in the order q, k, v.
q, k, v = (torch.rand(30, 50, 512) for _ in range(3))

In [186]:
# Eight heads over a 512-dimensional model (rebinds `model`, replacing the single-head demo).
model = MultiAttentionHead(num_of_heads=8, model_dims=512)

In [187]:
# Call the module itself rather than .forward() so nn.Module hooks are run.
multihead_output, multihead_weight = model(q, k, v)

cncat_ot: torch.Size([30, 50, 4096])
cncat_ln_ot: torch.Size([30, 50, 512])
atten_w: (torch.Size([30, 50, 50]), 8)
