<a href="https://colab.research.google.com/github/Santosw-Git/Transformer_code_from_scratch/blob/main/Transformer_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0. **Import Modules**

In [None]:
import torch
import torch.nn as nn
import math

# **1.Input Embedding**

In [None]:
class InputEmbeddings(nn.Module):
  """Token-id to vector lookup, scaled by sqrt(d_model).

  Args:
    d_model: Size of each embedding vector.
    vocab_size: Number of entries in the embedding table.
  """

  def __init__(self,d_model:int,vocab_size:int):
    super().__init__()
    self.d_model=d_model
    self.vocab_size=vocab_size
    self.embedding=nn.Embedding(vocab_size,d_model)

  def forward(self,x):
    """Return the embeddings of token ids ``x`` multiplied by sqrt(d_model)."""
    # The method must be named ``forward`` for ``module(x)`` to dispatch to it.
    return self.embedding(x) * math.sqrt(self.d_model)

  # Alias kept so code that called the original misspelled name still works.
  forword=forward


# **2. Positional Encoding**

In [None]:
class PositionalEncoding(nn.Module):
  """Add a fixed sinusoidal position table to the input, then apply dropout.

  Args:
    d_model: Size of the last dimension of the inputs.
    seq_len: Number of positions precomputed in the table.
    dropout: Dropout probability applied after the addition.
  """

  def __init__(self,d_model:int,seq_len:int,dropout:float):
    super().__init__()
    self.d_model=d_model
    self.seq_len=seq_len
    self.dropout=nn.Dropout(dropout)

    PE=torch.zeros(self.seq_len,self.d_model)
    position=torch.arange(0,seq_len,dtype=torch.float).unsqueeze(1)
    div_term=torch.exp(torch.arange(0,self.d_model,2).float() * (-math.log(10000.0)/self.d_model))
    # Even columns get sine, odd columns get cosine.
    PE[:,0::2]=torch.sin(position * div_term)
    # For an odd d_model there is one fewer odd column than even columns,
    # so trim div_term to the number of odd columns.
    PE[:,1::2]=torch.cos(position * div_term[:self.d_model//2])
    # Leading dimension of size 1 so the table broadcasts over the batch.
    PE=PE.unsqueeze(0)
    # Registered as a buffer: saved and moved with the module, but not trained.
    self.register_buffer("PE",PE)

  def forward(self,x):
    """Return ``dropout(x + PE[:, :x.shape[1], :])``."""
    # The table is a buffer built without autograd, so it needs no explicit
    # requires_grad_(False) call.
    x=x+self.PE[:,:x.shape[1],:]
    return self.dropout(x)

  # Alias kept so code that called the original misspelled name still works.
  forword=forward




# **3. Multi-Head Attention**

In [None]:
class MultiHeadAttention(nn.Module):
  """Scaled dot-product attention computed over ``head`` parallel heads.

  Args:
    d_model: Size of the last dimension of the inputs; must be divisible
      by ``head``.
    head: Number of attention heads.
    dropout: Dropout probability applied to the attention weights.
  """

  def __init__(self,d_model:int,head:int,dropout:float):
    super().__init__()
    self.d_model=d_model
    self.head=head
    assert self.d_model % self.head==0, "d_model is not divisible by head"
    # Width of each individual head.
    self.d_k=d_model//head
    # Projections for query, key, value and the final output.
    self.w_q=nn.Linear(d_model,d_model)
    self.w_k=nn.Linear(d_model,d_model)
    self.w_v=nn.Linear(d_model,d_model)
    self.w_o=nn.Linear(d_model,d_model)
    self.dropout=nn.Dropout(dropout)

  @staticmethod
  def attention(query,key,value,mask,dropout:nn.Dropout):
    """Return ``(weights @ value, weights)`` for scaled dot-product attention.

    Positions where ``mask == 0`` are filled with -1e9 before the softmax.
    ``mask`` and ``dropout`` may each be ``None`` to skip that step.
    """
    scale=math.sqrt(query.shape[-1])
    scores=(query @ key.transpose(-2,-1))/scale
    if mask is not None:
      scores.masked_fill_(mask==0 , -1e9)
    weights=scores.softmax(dim=-1)
    if dropout is not None:
      weights=dropout(weights)
    return (weights @ value) , weights

  def _split_heads(self,tensor):
    """Reshape (batch, seq, d_model) into (batch, head, seq, d_k)."""
    batch,seq=tensor.shape[0],tensor.shape[1]
    return tensor.view(batch,seq,self.head,self.d_k).transpose(1,2)

  def forward(self,q,k,v,mask):
    """Project q/k/v, attend per head, merge the heads and apply ``w_o``.

    The attention weights of the last call are kept in
    ``self.attention_score``.
    """
    # Linear projections first, then the split into heads.
    query=self._split_heads(self.w_q(q))
    key=self._split_heads(self.w_k(k))
    value=self._split_heads(self.w_v(v))

    context,self.attention_score=MultiHeadAttention.attention(query,key,value,mask,self.dropout)

    # (batch, head, seq, d_k) -> (batch, seq, head * d_k)
    batch=context.shape[0]
    merged=context.transpose(1,2).contiguous().view(batch,-1,self.head * self.d_k)
    return self.w_o(merged)



# 4.Layer **Normalization**

In [None]:
class LayerNormalization(nn.Module):
  """Normalize over the last dimension with a learnable scalar scale and shift.

  Args:
    eps: Small constant added to the standard deviation to avoid division
      by zero.
  """

  def __init__(self,eps:float=10**-6) -> None:
    super().__init__()
    self.eps=eps
    # Single-element scale and shift, shared across all features.
    self.alpha=nn.Parameter(torch.ones(1))
    self.bias=nn.Parameter(torch.zeros(1))

  def forward(self,x):
    """Return ``alpha * (x - mean) / (std + eps) + bias`` over the last dim."""
    mean=x.mean(dim=-1,keepdim=True)
    std=x.std(dim=-1,keepdim=True)
    # The original read ``self.alpa``, which is never defined.
    return self.alpha * (x-mean)/(std+self.eps) + self.bias



# 5.Feed Forward **Network**

In [None]:
class FeedForwardNetwork(nn.Module):
  """Two linear layers with ReLU and dropout in between.

  Args:
    d_model: Input and output feature size.
    d_ff: Hidden feature size between the two linear layers.
    dropout: Dropout probability applied after the ReLU.
  """

  def __init__(self,d_model:int,d_ff:int,dropout:float)-> None:
    # ``super`` must be called; the bare ``super.__init__()`` raised TypeError.
    super().__init__()
    self.d_model=d_model
    self.linear_1=nn.Linear(self.d_model,d_ff)
    self.dropout=nn.Dropout(dropout)
    self.linear_2=nn.Linear(d_ff,self.d_model)

  def forward(self,x):
    """Return ``linear_2(dropout(relu(linear_1(x))))``."""
    return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))




# **6. Residual Connection**

In [None]:
class ResidualConnection(nn.Module):
  """Skip connection around a sub-layer: ``x + dropout(layer(norm(x)))``.

  Args:
    dropout: Dropout probability applied to the sub-layer output.
  """

  def __init__(self,dropout:float) -> None:
    # ``super`` must be called; the bare ``super.__init__()`` raised TypeError.
    super().__init__()
    self.dropout=nn.Dropout(dropout)
    self.norm=LayerNormalization()

  def forward(self,x,layer):
    """Apply ``layer`` to the normalized input and add the result to ``x``.

    Args:
      x: Input tensor.
      layer: Callable taking the normalized tensor and returning a tensor
        that can be added to ``x``.
    """
    # Normalization is applied before the sub-layer, not after it.
    return x + self.dropout(layer(self.norm(x)))


# **7.Encoder**

In [None]:
class EncoderBlock(nn.Module):
  """One encoder layer: self-attention then feed-forward, each with a residual.

  Args:
    self_attention_block: Attention module called as ``(q, k, v, mask)``.
    feed_forward_block: Position-wise feed-forward module.
    dropout: Dropout probability passed to both residual connections.
  """

  def __init__(self,self_attention_block:MultiHeadAttention , feed_forward_block:FeedForwardNetwork,dropout:float)->None:
    # ``super`` must be called; the bare ``super.__init__()`` raised TypeError.
    super().__init__()
    self.self_attention_block=self_attention_block
    self.feed_forward_block=feed_forward_block
    self.residual_connections=nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

  def forward(self,x,src_mask):
    """Run self-attention with ``src_mask`` and then the feed-forward block."""
    # The lambda binds q = k = v to the tensor the residual wrapper passes in.
    x=self.residual_connections[0](x,lambda x : self.self_attention_block(x,x,x,src_mask))
    x=self.residual_connections[1](x,self.feed_forward_block)
    return x




In [None]:
class Encoder(nn.Module):
  """Stack of encoder layers followed by a final layer normalization.

  Args:
    layers: Modules applied in order, each called as ``layer(x, mask)``.
  """

  def __init__(self,layers:nn.ModuleList)->None:
    # ``super`` must be called; the bare ``super.__init__()`` raised TypeError.
    super().__init__()
    self.layers=layers
    self.norm=LayerNormalization()

  def forward(self,x,mask):
    """Pass ``x`` through every layer with ``mask``, then normalize."""
    for layer in self.layers:
      x=layer(x,mask)
    return self.norm(x)




# 8.Decoder **Block**

In [None]:
class DecoderBlock(nn.Module):
  """One decoder layer: self-attention, cross-attention, then feed-forward.

  Each of the three sub-layers is wrapped in its own residual connection.

  Args:
    self_attention_block: Attention module applied to the decoder input.
    cross_attention_block: Attention module whose keys and values come from
      the encoder output.
    feed_forward_block: Position-wise feed-forward module.
    dropout: Dropout probability passed to the residual connections.
  """

  def __init__(self,self_attention_block:MultiHeadAttention,cross_attention_block:MultiHeadAttention,feed_forward_block:FeedForwardNetwork,dropout:float)->None:
    super().__init__()
    self.self_attention_block=self_attention_block
    self.cross_attention_block=cross_attention_block
    self.feed_forward_block=feed_forward_block
    # Instantiate three wrappers; the original listed the misspelled class
    # itself instead of instances.
    self.residual_connections=nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

  def forward(self,x,encoder_output,src_mask,target_mask):
    """Return the decoder-layer output for ``x`` given ``encoder_output``.

    ``target_mask`` is used for self-attention and ``src_mask`` for
    cross-attention.
    """
    x=self.residual_connections[0](x,lambda x: self.self_attention_block(x,x,x,target_mask))
    # Query is the decoder state; key and value are the encoder output.
    x=self.residual_connections[1](x,lambda x:self.cross_attention_block(x,encoder_output,encoder_output,src_mask))
    x=self.residual_connections[2](x,self.feed_forward_block)
    return x


In [None]:
class Decoder(nn.Module):
  """Stack of decoder layers followed by a final layer normalization.

  Args:
    layers: Modules applied in order, each called as
      ``layer(x, encoder_output, src_mask, target_mask)``.
  """

  def __init__(self,layers:nn.ModuleList)->None:
    super().__init__()
    self.layers=layers
    self.norm=LayerNormalization()

  def forward(self,x,encoder_output,src_mask,target_mask):
    """Feed ``x`` through every layer in turn and normalize the result."""
    hidden=x
    for block in self.layers:
      hidden=block(hidden,encoder_output,src_mask,target_mask)
    normalized=self.norm(hidden)
    return normalized



# **9. Linear (Projection) Layer**

In [None]:
class ProjectionLayer(nn.Module):
  """Linear map from ``d_model`` features to vocabulary log-probabilities.

  Args:
    d_model: Size of the last dimension of the inputs.
    vocab_size: Number of output classes.
  """

  def __init__(self,d_model:int,vocab_size:int) -> None:
    super().__init__()
    self.proj=nn.Linear(d_model,vocab_size)

  def forward(self,x):
    """Return the log-softmax over the last dimension of ``proj(x)``."""
    logits=self.proj(x)
    return logits.log_softmax(dim=-1)

# 10.Building the Transformer

In [None]:
class Transformer(nn.Module):
  """Encoder-decoder model assembled from already-built components.

  Args:
    encoder: Module called as ``encoder(src, src_mask)``.
    decoder: Module called as
      ``decoder(target, encoder_output, src_mask, target_mask)``.
    src_embed: Embedding applied to the source tokens.
    target_embed: Embedding applied to the target tokens.
    src_pos: Positional encoding applied after ``src_embed``.
    target_pos: Positional encoding applied after ``target_embed``.
    projection_layer: Module applied by :meth:`project`.
  """

  # Annotations are quoted so the class definition does not depend on the
  # order in which the notebook cells were run; the original referenced the
  # undefined names ``PositionalEmbedding`` and ``Position_Embedding``.
  def __init__(self,encoder:"Encoder",decoder:"Decoder",src_embed:"InputEmbeddings",target_embed:"InputEmbeddings",src_pos:"PositionalEncoding",target_pos:"PositionalEncoding",projection_layer:"ProjectionLayer"):
    super().__init__()
    # Stored under names that differ from the ``encoder``/``decoder`` methods.
    # With identical names, attribute lookup finds the method first and the
    # method ends up calling itself instead of the sub-module.
    self.encoder_stack=encoder
    self.decoder_stack=decoder
    self.src_embed=src_embed
    self.target_embed=target_embed
    self.src_pos=src_pos
    self.target_pos=target_pos
    self.projection_layer=projection_layer

  def encode(self,src,src_mask):
    """Embed ``src``, add positions, and run the encoder stack."""
    src=self.src_embed(src)
    src=self.src_pos(src)
    return self.encoder_stack(src,src_mask)

  def decode(self,encoder_output,src_mask,target,target_mask):
    """Embed ``target``, add positions, and run the decoder stack."""
    target=self.target_embed(target)
    target=self.target_pos(target)
    return self.decoder_stack(target,encoder_output,src_mask,target_mask)

  # The original method names remain available as aliases.
  encoder=encode
  decoder=decode

  def project(self,x):
    """Apply the projection layer to ``x``."""
    return self.projection_layer(x)




In [None]:
def build_transformer(src_vocab_size:int,target_vocab_size:int,src_seq_len:int,target_seq_len:int,d_model:int=512,N:int=6,head:int=8,dropout:float=0.1,d_ff:int=2048) -> Transformer:
  """Assemble a Transformer and Xavier-initialize its weight matrices.

  Args:
    src_vocab_size: Vocabulary size for the source embedding.
    target_vocab_size: Vocabulary size for the target embedding and the
      output projection.
    src_seq_len: Number of positions in the source positional encoding.
    target_seq_len: Number of positions in the target positional encoding.
    d_model: Feature size used throughout the model.
    N: Number of encoder blocks and of decoder blocks.
    head: Number of attention heads per attention module.
    dropout: Dropout probability passed to every component that takes one.
    d_ff: Hidden size of each feed-forward network.

  Returns:
    The constructed Transformer.
  """
  # Both vocabularies use token embeddings; the original built the target one
  # from an undefined class and stored it under a misspelled name.
  src_embed=InputEmbeddings(d_model,src_vocab_size)
  target_embed=InputEmbeddings(d_model,target_vocab_size)
  src_pos=PositionalEncoding(d_model,src_seq_len,dropout)
  target_pos=PositionalEncoding(d_model,target_seq_len,dropout)

  encoder_blocks=[]
  for _ in range(N):
    encoder_self_attention_block=MultiHeadAttention(d_model,head,dropout)
    feed_forward_block=FeedForwardNetwork(d_model,d_ff,dropout)
    # EncoderBlock requires dropout as its third argument.
    encoder_block=EncoderBlock(encoder_self_attention_block,feed_forward_block,dropout)
    encoder_blocks.append(encoder_block)

  decoder_blocks=[]
  for _ in range(N):
    decoder_self_attention_block=MultiHeadAttention(d_model,head,dropout)
    decoder_cross_attention_block=MultiHeadAttention(d_model,head,dropout)
    feed_forward_block=FeedForwardNetwork(d_model,d_ff,dropout)
    decoder_block=DecoderBlock(decoder_self_attention_block,decoder_cross_attention_block,feed_forward_block,dropout)
    decoder_blocks.append(decoder_block)

  encoder=Encoder(nn.ModuleList(encoder_blocks))
  decoder=Decoder(nn.ModuleList(decoder_blocks))
  projection=ProjectionLayer(d_model,target_vocab_size)
  transformer=Transformer(encoder,decoder,src_embed,target_embed,src_pos,target_pos,projection)

  # Only tensors with more than one dimension are re-initialized.
  for p in transformer.parameters():
    if p.dim()>1:
      nn.init.xavier_uniform_(p)
  # Return after the loop; returning inside it stopped at the first parameter.
  return transformer

# **11.Tokenizer**

In [None]:
!pip install datasets
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path




In [None]:
def get_or_build_tokenizer(config,ds,lang):
  tokenizer_path = Path(config["tokenizer_file"].format(lang))