## Importing Libraries

In [None]:
import torch
import torch.nn as nn
import math

## Input Embeddings

In [None]:
class InputEmbeddings(nn.Module):
  """Look up token embeddings and scale them by sqrt(d_model).

  Args:
    d_model: width of each embedding vector.
    vocab_size: number of rows in the embedding table.
  """

  def __init__(self, d_model, vocab_size):
    super().__init__()
    self.vocab_size = vocab_size
    self.d_model = d_model
    self.embedding = nn.Embedding(vocab_size, d_model)

  def forward(self, x):
    """Return ``embedding(x) * sqrt(d_model)``; ``x`` holds integer token ids."""
    scale = math.sqrt(self.d_model)
    embedded = self.embedding(x)
    return embedded * scale

## Positional Encoding


In [None]:
class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

  Even feature columns hold sin(pos * div_term) and odd columns hold
  cos(pos * div_term), with div_term = 10000^(-2i / d_model).

  Args:
    d_model: embedding width (odd values are supported).
    seq_length: number of positions in the precomputed table; inputs may not
      be longer than this.
    dropout: dropout probability, or an already-built ``nn.Module`` to apply.
  """

  def __init__(self, d_model, seq_length, dropout):
    super().__init__()
    self.d_model = d_model
    self.seq_length = seq_length
    # The original stored the raw float and then called it in forward(), which
    # raises TypeError. Wrap probabilities in nn.Dropout; modules pass through.
    self.dropout = dropout if isinstance(dropout, nn.Module) else nn.Dropout(dropout)

    pe = torch.zeros(seq_length, d_model)
    position = torch.arange(0, seq_length, dtype=torch.float).unsqueeze(1)  # (seq_length, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    angles = position * div_term  # (seq_length, ceil(d_model / 2))

    pe[:, 0::2] = torch.sin(angles)
    # There are only floor(d_model / 2) odd columns. Trim the angles so an odd
    # d_model does not fail with a shape mismatch.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    # Leading dimension of size 1 so the table broadcasts over the batch in forward().
    self.register_buffer('pe', pe.unsqueeze(0))

  def forward(self, x):
    """Add the first ``x.shape[1]`` encodings to ``x`` and apply dropout.

    Assumes ``x`` is (batch, seq, d_model) with seq <= seq_length (TODO confirm with callers).
    """
    x = x + self.pe[:, : x.shape[1], :].requires_grad_(False)
    return self.dropout(x)


## Layer Normalization

In [None]:
class LayerNormalization(nn.Module):
  """Normalize over the last dimension with a learnable scalar gain and bias.

  Computes ``alpha * (x - mean) / (std + eps) + beta``. Mean and std are taken
  over the last dimension; ``Tensor.std`` applies Bessel's correction by default.

  Args:
    eps: added to the standard deviation to avoid division by zero.
  """

  def __init__(self, eps = 10**-6):
    super().__init__()
    self.eps = eps
    self.alpha = nn.Parameter(torch.ones(1))  # multiplicative gain (single scalar)
    self.beta = nn.Parameter(torch.zeros(1))  # additive bias (single scalar)

  def forward(self, x):
    mean = x.mean(dim=-1, keepdim=True)
    std = x.std(dim=-1, keepdim=True)
    # The original read the nonexistent self.std / self.bias and put the bias
    # inside the denominator. The bias belongs outside, added after scaling.
    return self.alpha * (x - mean) / (std + self.eps) + self.beta


## Feed Forward


In [None]:
class FeedForward(nn.Module):
  """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: input and output width.
    d_ff: hidden width.
    dropout: dropout probability, or an already-built ``nn.Module`` to apply.
  """

  def __init__(self, d_model, d_ff, dropout):
    super().__init__()
    self.linear1 = nn.Linear(d_model, d_ff)
    self.linear2 = nn.Linear(d_ff, d_model)
    # A bare float is not callable. Wrap it so forward() can apply it.
    self.dropout = dropout if isinstance(dropout, nn.Module) else nn.Dropout(dropout)

  def forward(self, x):
    # torch has no `ReLU` function (only the nn.ReLU module), so use torch.relu.
    return self.linear2(self.dropout(torch.relu(self.linear1(x))))

## Multi-Head Attention


In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Projects q/k/v with separate linear layers and splits the model width into
  ``h`` heads of width ``d_k = d_model // h``. It attends per head,
  concatenates the heads and applies an output projection.

  Args:
    d_model: model width; must be divisible by ``h``.
    h: number of heads.
    dropout: dropout probability applied to the attention weights, or an
      already-built ``nn.Module``.
  """

  def __init__(self, d_model, h, dropout):
    super().__init__()
    self.d_model = d_model
    self.h = h
    assert d_model % h == 0, "d_model is not divisible by h"
    self.d_k = d_model // h
    # A bare float is not callable. Wrap it so attention() can apply it.
    self.dropout = dropout if isinstance(dropout, nn.Module) else nn.Dropout(dropout)
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)
    self.w_o = nn.Linear(d_model, d_model)

  @staticmethod
  def attention(query, key, value, mask, dropout):
    """Return ``(softmax(Q K^T / sqrt(d_k)) V, attention_weights)``.

    Scores where ``mask == 0`` are set to -1e9 before the softmax, so those
    positions get (near-)zero weight. ``mask`` must broadcast against the
    scores tensor. ``mask`` and ``dropout`` may each be None.
    """
    d_k = query.shape[-1]
    scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
    # Mask before the softmax. Masking afterwards wrote -1e9 into the
    # probabilities themselves.
    if mask is not None:
      scores = scores.masked_fill(mask == 0, -1e9)
    # Normalize over the key axis (last dim), not dim 1.
    scores = scores.softmax(dim=-1)
    if dropout is not None:
      scores = dropout(scores)
    return scores @ value, scores

  def forward(self, q, k, v, mask):
    """Attend from ``q`` to ``k``/``v``; the inputs are (batch, seq, d_model)."""
    query = self.w_q(q)
    key = self.w_k(k)
    value = self.w_v(v)

    # (batch, seq, d_model) -> (batch, h, seq, d_k)
    query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
    key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
    value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

    x, self.attention_scores = MultiHeadAttention.attention(query, key, value, mask, self.dropout)

    # (batch, h, seq, d_k) -> (batch, seq, d_model)
    x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.d_k * self.h)
    return self.w_o(x)

## Residual Connection

In [None]:
class ResidualConnection(nn.Module):
  """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

  Args:
    dropout: dropout probability applied to the sublayer output, or an
      already-built ``nn.Module``.
  """

  def __init__(self, dropout):
    # The original `super.__init__()` (no call on `super`) raises TypeError.
    super().__init__()
    # A bare float is not callable. Wrap it so forward() can apply it.
    self.dropout = dropout if isinstance(dropout, nn.Module) else nn.Dropout(dropout)
    self.norm = LayerNormalization()

  def forward(self, x, sublayer):
    """Apply ``sublayer`` (a one-argument callable) to ``norm(x)`` and add ``x`` back."""
    return x + self.dropout(sublayer(self.norm(x)))


## Encoder Block

In [None]:
class EncoderBlock(nn.Module):
  """One encoder layer: self-attention, then feed-forward, each in a residual connection.

  Args:
    self_attention_block: attention module called as ``(q, k, v, mask)``.
    feed_forward_network: module applied to the normalized attention output.
    dropout: dropout setting passed to each ResidualConnection.
  """

  def __init__(self, self_attention_block: MultiHeadAttention, feed_forward_network: FeedForward, dropout):
    super().__init__()
    self.feed_forward_network = feed_forward_network
    self.self_attention_block = self_attention_block
    # Attribute name kept unchanged so existing state_dict keys still match.
    self.ResidualConnection = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

  def forward(self, x, src_mask):
    x = self.ResidualConnection[0](x, lambda y: self.self_attention_block(y, y, y, src_mask))
    # Pass the module itself as the sublayer. The old `lambda x: self.feed_forward_network`
    # returned the module object instead of calling it on the input.
    x = self.ResidualConnection[1](x, self.feed_forward_network)
    return x

## Encoder

In [None]:
class Encoder(nn.Module):
  """Run a stack of encoder layers in order, then apply a final LayerNormalization.

  Args:
    layers: ModuleList of layers, each called as ``layer(x, mask)``.
  """

  def __init__(self, layers: nn.ModuleList):
    super().__init__()
    self.layers = layers
    self.norm = LayerNormalization()

  def forward(self, x, mask):
    for layer in self.layers:
      x = layer(x, mask)
    return self.norm(x)


## Decoder Block

In [None]:
class DecoderBlock(nn.Module):
  """One decoder layer: self-attention, cross-attention, then feed-forward.

  Self-attention runs under ``tgt_mask``. Cross-attention takes queries from
  the decoder stream and keys/values from ``encoder_output`` under
  ``src_mask``. Each sublayer is wrapped in a ResidualConnection.

  Args:
    self_attention_block: attention module called as ``(q, k, v, mask)``.
    cross_attention_block: attention module called as ``(q, k, v, mask)``.
    feed_forward_network: module applied after the two attention sublayers.
    dropout: dropout setting passed to each ResidualConnection.
  """

  def __init__(self, self_attention_block: MultiHeadAttention, cross_attention_block: MultiHeadAttention, feed_forward_network: FeedForward, dropout):
    super().__init__()
    self.self_attention_block = self_attention_block
    self.cross_attention_block = cross_attention_block
    self.feed_forward_network = feed_forward_network
    # Attribute name kept unchanged so existing state_dict keys still match.
    self.ResidualConnection = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    x = self.ResidualConnection[0](x, lambda y: self.self_attention_block(y, y, y, tgt_mask))
    x = self.ResidualConnection[1](
      x, lambda y: self.cross_attention_block(y, encoder_output, encoder_output, src_mask)
    )
    x = self.ResidualConnection[2](x, self.feed_forward_network)
    return x

## Decoder

In [None]:
class Decoder(nn.Module):
  """Run a stack of decoder layers in order, then apply a final LayerNormalization.

  Args:
    layers: ModuleList of layers, each called as
      ``layer(x, encoder_output, src_mask, tgt_mask)``.
  """

  def __init__(self, layers: nn.ModuleList):
    super().__init__()
    self.layers = layers
    self.norm = LayerNormalization()

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    # The original signature omitted encoder_output. It passed only three
    # arguments to each layer, and Transformer.decode calls this with four.
    for layer in self.layers:
      x = layer(x, encoder_output, src_mask, tgt_mask)
    return self.norm(x)

## Projection Layer

In [None]:
class ProjectionLayer(nn.Module):
  """Project to vocabulary size and return log-probabilities over the last dimension.

  Args:
    d_model: input width.
    vocab_size: number of output classes.
  """

  def __init__(self, d_model, vocab_size):
    super().__init__()
    self.proj = nn.Linear(d_model, vocab_size)

  def forward(self, x):
    return torch.log_softmax(self.proj(x), dim=-1)

## Transformer

In [None]:
class Transformer(nn.Module):
  """Encoder-decoder container wiring embeddings, positional encodings and projection.

  The stages are exposed as separate methods: ``encode``, ``decode`` and ``project``.
  """

  def __init__(self, encoder, decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection: ProjectionLayer):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.src_embed = src_embed
    self.tgt_embed = tgt_embed
    self.src_pos = src_pos
    self.tgt_pos = tgt_pos
    self.projection = projection

  def encode(self, src, src_mask):
    """Embed and position-encode ``src``, then run the encoder."""
    src = self.src_embed(src)
    src = self.src_pos(src)
    return self.encoder(src, src_mask)

  def decode(self, encoder_output, tgt, src_mask, tgt_mask):
    """Embed and position-encode ``tgt``, then run the decoder against ``encoder_output``."""
    tgt = self.tgt_embed(tgt)
    tgt = self.tgt_pos(tgt)
    return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

  def project(self, x):
    """Map decoder output through the projection layer."""
    return self.projection(x)

## Building the Transformer

In [None]:
def build_transformer(src_vocab_size, tgt_vocab_size, src_seq_len, tgt_seq_len, d_model, N, h, dropout, d_ff):
  """Assemble an encoder-decoder Transformer from the components in this file.

  Args:
    src_vocab_size: vocabulary size of the source embedding.
    tgt_vocab_size: vocabulary size of the target embedding and the projection output.
    src_seq_len: number of positions in the source positional-encoding table.
    tgt_seq_len: number of positions in the target positional-encoding table.
    d_model: model width.
    N: number of encoder blocks, and separately of decoder blocks.
    h: number of heads in every attention module.
    dropout: dropout setting passed to every component.
    d_ff: hidden width of each feed-forward network.

  Returns:
    A Transformer in which every parameter with more than one dimension is
    Xavier-uniform initialized.
  """
  src_embed = InputEmbeddings(d_model, src_vocab_size)
  tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

  src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
  tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

  encoder_blocks = []
  for _ in range(N):
    encoder_self_attention = MultiHeadAttention(d_model, h, dropout)
    encoder_feed_forward = FeedForward(d_model, d_ff, dropout)
    encoder_blocks.append(EncoderBlock(encoder_self_attention, encoder_feed_forward, dropout))

  decoder_blocks = []
  for _ in range(N):
    decoder_self_attention = MultiHeadAttention(d_model, h, dropout)
    decoder_cross_attention = MultiHeadAttention(d_model, h, dropout)
    decoder_feed_forward = FeedForward(d_model, d_ff, dropout)
    # Use this decoder block's own modules. The original passed the last
    # encoder block's self-attention and feed-forward, silently sharing weights.
    decoder_blocks.append(
      DecoderBlock(decoder_self_attention, decoder_cross_attention, decoder_feed_forward, dropout)
    )

  # Encoder/Decoder take only the layer list. Passing d_model first was a TypeError.
  encoder = Encoder(nn.ModuleList(encoder_blocks))
  decoder = Decoder(nn.ModuleList(decoder_blocks))

  projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
  transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

  # Xavier-initialize matrices only. 1-D parameters (e.g. Linear biases and the
  # scalar LayerNormalization alpha/beta) keep their defaults.
  for p in transformer.parameters():
    if p.dim() > 1:
      nn.init.xavier_uniform_(p)

  return transformer