In [1]:
import torch
import torch.nn as nn
import math

In [2]:
#first layer in the transformer: the input embeddings
class InputEmbedding(nn.Module):
    """Token embedding layer whose output is scaled by sqrt(d_model).

    Args:
        d_model: size of each embedding vector.
        vocab_size: number of rows in the embedding table.
    """
    def __init__(self,d_model:int,vocab_size:int):
        super().__init__()
        self.d_model=d_model
        self.vocab_size=vocab_size
        self.embedding=nn.Embedding(vocab_size,d_model)

    def forward(self,x):
        """Look up the embeddings for the indices in ``x`` and scale them.

        Returns the embedding lookup multiplied by ``sqrt(d_model)``.
        """
        # the attribute is `d_model`; `dmodel` does not exist on the module
        return self.embedding(x)*math.sqrt(self.d_model)

In [3]:
#second layer: positional encoding
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to the input, then applies dropout.

    Args:
        d_model: size of the last dimension of the encoding table.
        seq_len: number of positions pre-computed in the table.
        dropout: probability passed to ``nn.Dropout``.
    """
    def __init__(self,d_model:int, seq_len:int,dropout:float) ->None:
        super().__init__()
        self.d_model=d_model
        self.seq_len=seq_len
        self.dropout=nn.Dropout(dropout)

        # encoding table of shape (seq_len, d_model)
        pe=torch.zeros(seq_len,d_model)
        # column vector of positions, shape (seq_len, 1)
        position=torch.arange(0,seq_len,dtype=torch.float32).unsqueeze(1)
        # 1 / 10000^(2i/d_model), computed in log space; the exponent must be
        # negative so that the frequencies decay along the feature dimension
        div_term=torch.exp(torch.arange(0,d_model,2).float()*(-math.log(10000.0)/d_model))
        # sine on the even indices
        pe[:,0::2]=torch.sin(position*div_term)
        # cosine on the odd indices; for an odd d_model there is one fewer odd
        # column than even columns, so the frequency vector is trimmed to match
        pe[:,1::2]=torch.cos(position*div_term[:d_model//2])
        # add a leading dimension: (1, seq_len, d_model)
        pe=pe.unsqueeze(0)
        # register as a buffer: saved with the module but not a trainable parameter
        self.register_buffer('pe',pe)

    def forward(self,x):
        """Add the first ``x.shape[1]`` rows of the table to ``x`` and apply dropout."""
        # `pe` is a buffer, not a parameter, so it carries no gradient
        x=x+self.pe[:, :x.shape[1],:]
        return self.dropout(x)




In [4]:
#layer 3: layer normalization

class LayerNormalization(nn.Module):
    """Normalizes the last dimension using a single learnable gain and shift.

    Args:
        eps: value added to the standard deviation before dividing.
    """
    def __init__(self,eps:float =10**-6)->None:
        super().__init__()
        # learnable multiplicative term (one scalar)
        self.alpha=nn.Parameter(torch.ones(1))
        # learnable additive term (one scalar)
        self.bias=nn.Parameter(torch.zeros(1))
        self.eps=eps

    def forward(self,x):
        """Return ``alpha * (x - mean) / (std + eps) + bias`` over the last dim."""
        mu=x.mean(dim=-1,keepdim=True)
        sigma=x.std(dim=-1,keepdim=True)
        scaled=self.alpha*(x-mu)
        return scaled/(sigma+self.eps)+self.bias


In [5]:
#layer 5 : the feed forward network
class FeedForward(nn.Module):
    """Two linear layers with a ReLU and dropout in between.

    Args:
        d_model: input and output feature size.
        d_ff: feature size of the hidden layer.
        dropout: probability passed to ``nn.Dropout``.
    """
    def __init__(self,d_model:int,d_ff:int,dropout:float)->None:
        super().__init__()
        self.linear_1=nn.Linear(d_model,d_ff)
        self.dropout=nn.Dropout(dropout)
        self.linear_2=nn.Linear(d_ff,d_model)

    def forward(self,x):
        """Compute ``linear_2(dropout(relu(linear_1(x))))``."""
        # linear_1 must be applied to x; passing the module itself to relu fails
        hidden=torch.relu(self.linear_1(x))
        return self.linear_2(self.dropout(hidden))


- Multi-head attention takes the output of the positional encoding and uses it three times,
as its key, query and value inputs.

The key, query and value are each multiplied by a weight matrix to obtain a matrix of the same size as the input.
Each of the new matrices is then split into several parts to create several heads.

Attention is then applied to every head.
The heads are then concatenated and multiplied by an output weight matrix to obtain the output.


In [6]:
#layer 6: multihead attention

class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: feature size of the inputs and of the output.
        h: number of heads; must divide ``d_model``.
        dropout: probability passed to ``nn.Dropout`` (applied to the attention weights).

    Raises:
        AssertionError: if ``d_model`` is not divisible by ``h``.
    """
    def __init__(self,d_model:int,h:int,dropout:float)->None:
        super().__init__()
        self.d_model=d_model
        self.h=h
        # ensure d_model is divisible by h
        assert d_model % h==0, "d_model is not divisible by h"
        # integer division: d_k is used as a size in .view(), which rejects floats
        self.d_k=d_model//h

        # weight matrices
        self.w_q=nn.Linear(d_model,d_model)  # for query
        self.w_k=nn.Linear(d_model,d_model)  # for key
        self.w_v=nn.Linear(d_model,d_model)  # for value
        # output projection applied to the concatenated heads;
        # named w_o to match w_q/w_k/w_v and the use in forward()
        self.w_o=nn.Linear(d_model,d_model)

        self.dropout=nn.Dropout(dropout)

    # attention computed with the formula from Vaswani et al., 2017
    @staticmethod  # static so it can be called without a class instance
    def attention(query,key,value,mask,dropout:nn.Dropout):
        """Scaled dot-product attention.

        Positions where ``mask == 0`` are filled with -1e9 before the softmax.
        ``mask`` and ``dropout`` may each be ``None`` to skip that step.

        Returns:
            A tuple ``(weights @ value, weights)`` where ``weights`` are the
            (optionally dropped-out) softmax attention weights.
        """
        d_k=query.shape[-1]
        attention_scores=(query @ key.transpose(-2,-1))/math.sqrt(d_k)
        if mask is not None:
            attention_scores.masked_fill_(mask==0, -1e9)
        # softmax over the key axis; forward() passes (batch, h, seq_len, d_k)
        # tensors, which makes the scores (batch, h, seq_len, seq_len)
        attention_scores=attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_scores=dropout(attention_scores)
        return (attention_scores @ value), attention_scores


    def forward(self,q,k,v,mask):
        """Project q/k/v, split into heads, attend, merge heads and project out.

        The attention weights of the last call are stored on
        ``self.attention_score``.
        """
        # (batch, seq_len, d_model) => (batch, seq_len, d_model)
        query=self.w_q(q)
        key=self.w_k(k)
        value=self.w_v(v)

        # (batch, seq_len, d_model) => (batch, seq_len, h, d_k) => (batch, h, seq_len, d_k)
        query=query.view(query.shape[0], query.shape[1],self.h, self.d_k).transpose(1,2)
        key=key.view(key.shape[0],key.shape[1],self.h,self.d_k).transpose(1,2)
        value=value.view(value.shape[0],value.shape[1],self.h,self.d_k).transpose(1,2)

        x, self.attention_score = MultiheadAttention.attention(query,key,value,mask,self.dropout)
        # (batch, h, seq_len, d_k) => (batch, seq_len, h, d_k) => (batch, seq_len, d_model)
        x=x.transpose(1,2).contiguous().view(x.shape[0],-1,self.h*self.d_k)

        return self.w_o(x)



        





In [7]:
#   Connection layer now to do the connection

class ResidualConnection(nn.Module):
    """Skip connection around a sub-layer: ``x + dropout(sublayer(norm(x)))``.

    Args:
        dropout: probability passed to ``nn.Dropout``.
    """
    def __init__(self, dropout:float) ->None:
        super().__init__()
        self.dropout=nn.Dropout(dropout)
        self.norm=LayerNormalization()

    def forward(self,x,sublayer):
        """Normalize ``x``, run ``sublayer`` on it, apply dropout and add ``x`` back.

        ``sublayer`` is any callable taking the normalized tensor.
        """
        normalized=self.norm(x)
        branch=self.dropout(sublayer(normalized))
        return x+branch


In the paper (Vaswani et al., 2017), the multi-head attention layer, the add & norm layer and the feed-forward layer are all placed in one block (the encoder block),
which is then repeated Nx times, so I am now going to place them in a block.

In [8]:
class EncoderBlock(nn.Module):
    """One encoder block: self-attention then feed-forward, each in a residual connection.

    Args:
        self_attention_block: attention module called as ``(q, k, v, mask)``.
        feed_forward_block: module applied after the attention sub-layer.
        dropout: probability passed to each ``ResidualConnection``.
    """
    def __init__(self,self_attention_block:MultiheadAttention, feed_forward_block:FeedForward,dropout:float)->None:
        super().__init__()
        # store the constructor argument (reading self.self_attention_block here
        # would raise AttributeError because it has not been set yet)
        self.self_attention_block=self_attention_block
        self.feed_forward_block=feed_forward_block
        self.residual_connection=nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])
    def forward(self,x,src_mask):
        """Run self-attention (query, key and value are all ``x``) and the feed-forward block."""
        x=self.residual_connection[0](x,lambda x:self.self_attention_block(x,x,x,src_mask))
        x=self.residual_connection[1](x, self.feed_forward_block)
        return x
    

In [9]:
class Encoder(nn.Module):
    """Stack of encoder layers followed by a final layer normalization.

    Args:
        layers: the layers to apply in order; each is called as ``layer(x, mask)``.
    """
    def __init__(self, layers:nn.ModuleList)->None:
        super().__init__()
        self.layer=layers
        self.norm=LayerNormalization()
    def forward(self,x,mask):
        """Pass ``x`` through every layer, then normalize the result."""
        for layer in self.layer:
            x=layer(x,mask)
        # apply the final normalization that __init__ creates; without this the
        # norm module (and its parameters) would never be used
        return self.norm(x)

Part 2:   The Decoder Layer

In [10]:
class DecoderBlock(nn.Module):
    """One decoder block: self-attention, cross-attention and feed-forward sub-layers.

    Each sub-layer is wrapped in its own ``ResidualConnection``.

    Args:
        self_attention_block: attention module used with the decoder input as q, k and v.
        cross_attention_block: attention module used with the decoder input as q
            and the encoder output as k and v.
        feed_forward_block: module applied as the last sub-layer.
        dropout: probability passed to each ``ResidualConnection``.
    """
    def __init__(self,self_attention_block:MultiheadAttention, cross_attention_block:MultiheadAttention,feed_forward_block:FeedForward,dropout:float )->None:
        super().__init__()
        self.self_attention_block=self_attention_block
        self.cross_attention_block=cross_attention_block
        self.feed_forward_block=feed_forward_block

        connections=[ResidualConnection(dropout) for _ in range(3)]
        self.residual_connections=nn.ModuleList(connections)

    def forward(self, x, encoder_output,src_mask,tgt_mask):
        """Apply the three sub-layers in order and return the result."""
        self_attn_residual, cross_attn_residual, ff_residual = self.residual_connections

        def self_attend(h):
            # decoder input attends to itself under the target mask
            return self.self_attention_block(h,h,h,tgt_mask)

        def cross_attend(h):
            # decoder input queries the encoder output under the source mask
            return self.cross_attention_block(h,encoder_output,encoder_output,src_mask)

        x=self_attn_residual(x, self_attend)
        x=cross_attn_residual(x, cross_attend)
        x=ff_residual(x,self.feed_forward_block)
        return x
    

In [None]:
class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalization.

    Args:
        layers: the layers to apply in order; each is called as
            ``layer(x, encoder_output, src_mask, tgt_mask)``.
    """
    def __init__(self,layers:nn.ModuleList):
        super().__init__()
        self.layers=layers
        # instantiate the norm; storing the class itself would make
        # self.norm(x) construct a new module instead of normalizing x
        self.norm=LayerNormalization()
    def forward(self,x,encoder_output,src_mask,tgt_mask):
        """Pass ``x`` through every layer, then normalize the result."""
        for layer in self.layers:
            x=layer(x,encoder_output,src_mask,tgt_mask)
        return self.norm(x)


In [12]:
class ProjectionLayer(nn.Module):
    """Linear projection from ``d_model`` features to ``vocab_size`` log-probabilities.

    Args:
        d_model: input feature size of the linear layer.
        vocab_size: output feature size of the linear layer.
    """
    def __init__(self,d_model:int, vocab_size:int )->None:
        super().__init__()
        self.proj=nn.Linear(d_model,vocab_size)
    def forward(self,x):
        """Project ``x`` and apply ``log_softmax`` over the last dimension."""
        # last dim goes from d_model to vocab_size
        logits=self.proj(x)
        return torch.log_softmax(logits,dim=-1)



In [13]:
#transformer block now!

class Transformer(nn.Module):
    """Container wiring embeddings, positional encodings, encoder, decoder and projection.

    Args:
        encoder: module called as ``encoder(src, src_mask)``.
        decoder: module called as ``decoder(tgt, encoder_output, src_mask, tgt_mask)``.
        src_embed: embedding applied to the source input.
        tgt_embed: embedding applied to the target input.
        src_pos: positional encoding applied after ``src_embed``.
        tgt_pos: positional encoding applied after ``tgt_embed``.
        projection_layer: module applied by :meth:`project`.
    """
    def __init__(self, encoder:Encoder, decoder:Decoder,src_embed:InputEmbedding, tgt_embed:InputEmbedding, src_pos:PositionalEncoding, tgt_pos:PositionalEncoding, projection_layer:ProjectionLayer)->None:
        super().__init__()
        self.encoder=encoder
        self.decoder=decoder
        self.src_embed=src_embed
        self.tgt_embed=tgt_embed
        self.src_pos=src_pos
        self.tgt_pos=tgt_pos
        self.projection_layer=projection_layer

    # The methods are named encode/decode rather than encoder/decoder: a method
    # with the same name as a registered submodule wins attribute lookup, so
    # `self.encoder(...)` inside it would call the method itself, not the module.
    def encode(self,src,src_mask):
        """Embed ``src``, add positional encodings and run the encoder."""
        src=self.src_embed(src)
        src=self.src_pos(src)
        return self.encoder(src,src_mask)
    def decode(self,encoder_output,src_mask,tgt,tgt_mask):
        """Embed ``tgt``, add positional encodings and run the decoder."""
        tgt=self.tgt_embed(tgt)
        tgt=self.tgt_pos(tgt)
        return self.decoder(tgt,encoder_output,src_mask,tgt_mask)
    def project(self,x):
        """Apply the projection layer to ``x``."""
        return self.projection_layer(x)

In [14]:
def build_transformer(src_vocab_size:int,tgt_vocab_size:int,src_seq_len:int,tgt_seq_len:int, d_model:int = 512, N:int=6,h:int=8,dropout:float=0.1, d_ff:int=2048):
    """Assemble a ``Transformer`` and initialize its weight matrices.

    Args:
        src_vocab_size: vocabulary size for the source embedding.
        tgt_vocab_size: vocabulary size for the target embedding and the projection layer.
        src_seq_len: number of positions in the source positional encoding.
        tgt_seq_len: number of positions in the target positional encoding.
        d_model: feature size used throughout the model.
        N: number of encoder blocks and of decoder blocks.
        h: number of attention heads in every attention module.
        dropout: dropout probability passed to every sub-module that takes one.
        d_ff: hidden size of the feed-forward blocks.

    Returns:
        The constructed ``Transformer``; every parameter with more than one
        dimension has been initialized with ``nn.init.xavier_uniform_``.
    """
    src_embed=InputEmbedding(d_model,src_vocab_size)
    tgt_embed=InputEmbedding(d_model,tgt_vocab_size)

    src_pos=PositionalEncoding(d_model,src_seq_len,dropout)
    tgt_pos=PositionalEncoding(d_model,tgt_seq_len,dropout)

    encoder_blocks = []
    for _ in range(N):
        # the head count must be passed explicitly, as for the decoder blocks
        encoder_self_attention_block=MultiheadAttention(d_model,h,dropout)
        feed_forward_block=FeedForward(d_model,d_ff,dropout)
        encoder_block=EncoderBlock(encoder_self_attention_block,feed_forward_block,dropout)
        encoder_blocks.append(encoder_block)

    decoder_blocks=[]
    for _ in range(N):
        decoder_self_attention_block=MultiheadAttention(d_model,h,dropout)
        decoder_cross_attention_block=MultiheadAttention(d_model,h,dropout)
        feed_forward_block=FeedForward(d_model,d_ff,dropout)
        decoder_block=DecoderBlock(decoder_self_attention_block,decoder_cross_attention_block,feed_forward_block,dropout)
        decoder_blocks.append(decoder_block)


    encoder=Encoder(nn.ModuleList(encoder_blocks))
    decoder=Decoder(nn.ModuleList(decoder_blocks))

    projection_layer=ProjectionLayer(d_model,tgt_vocab_size)

    transformer=Transformer(encoder,decoder,src_embed,tgt_embed,src_pos,tgt_pos,projection_layer)

    # only tensors with more than one dimension are re-initialized
    # (1-D parameters such as biases keep their default initialization)
    for p in transformer.parameters():
        if p.dim()>1:
            nn.init.xavier_uniform_(p)
    return transformer



