In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

In [None]:
#Input embeddings
class InputEmbeddings(nn.Module):
    """Token-embedding lookup whose output is multiplied by sqrt(d_model).

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: width of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        # One d_model-wide row per vocabulary index.
        self.embed = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Embed the indices in ``x`` and scale the vectors.

        (batch, seq_len) --> (batch, seq_len, d_model)
        """
        scale = np.sqrt(self.d_model)
        vectors = self.embed(x)
        return vectors * scale

#positional Encoding
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    Even columns of the table hold sines and odd columns hold cosines of
    ``position * 10000^(-2i / d_model)``. The table is precomputed once for
    ``seq_len`` positions and stored as a buffer (saved with the model, not
    trained).

    Args:
        d_model: width of each embedding vector; may be even or odd.
        seq_len: number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, seq_len):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # One frequency per sin/cos column pair.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # (seq_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the angles to the number of odd columns before assigning.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # (1, seq_len, d_model), broadcasts over the batch
        self.register_buffer('pe', pe)  # included in state_dict, excluded from parameters

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.shape[1]`` positions.

        (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        """
        # The buffer never requires grad, so the slice can be added directly.
        return x + self.pe[:, :x.shape[1], :]

class ProjectionLayer(nn.Module):
    """Linear map from model width to vocabulary size.

    Args:
        d_model: size of the last dimension of the input.
        vocab_size: size of the last dimension of the output.
    """

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x) -> torch.Tensor:
        """Apply the projection to the last dimension of ``x``.

        (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        """
        return self.proj(x)




In [None]:
#feed forward network
class FFN(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input and output width.
        d_ff: width of the hidden layer.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.d_model, self.d_ff, self.dropout = d_model, d_ff, dropout

        # Built in the same order they run, so parameter initialisation
        # consumes the random stream in that order too.
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU(inplace=True)
        drop = nn.Dropout(p=dropout)
        contract = nn.Linear(d_ff, d_model)
        self.model = nn.Sequential(expand, activation, drop, contract)

    def forward(self, x):
        """Run ``x`` through the four stacked layers."""
        return self.model(x)


# Skip connection: mitigates the vanishing-gradient problem by passing the input signal straight through.
class ResidualConnection(nn.Module):
    """Residual wrapper that normalises after the addition.

    Computes ``LayerNorm(x + sublayer(x))``.

    Args:
        d_model: size of the last dimension, passed to ``nn.LayerNorm``.
    """

    def __init__(self, d_model):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to ``x``, add the input back, then layer-normalise."""
        branch = sublayer(x)
        return self.norm(x + branch)


In [None]:
#Attention block
class SelfAttention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Args:
        d_model: input and output width of each bias-free projection.
    """

    def __init__(self, d_model):
        super(SelfAttention, self).__init__()
        self.w_q = nn.Linear(d_model, d_model, bias=False)  # Wq
        self.w_k = nn.Linear(d_model, d_model, bias=False)  # Wk
        self.w_v = nn.Linear(d_model, d_model, bias=False)  # Wv

    def forward(self, q, k, v, flag):
        """Project the inputs and attend.

        Args:
            q: tensor fed to the query projection.
            k: tensor fed to the key projection.
            v: tensor fed to the value projection.
            flag: ``'Encoder'`` for unmasked attention, ``'Decoder'`` for
                causally masked attention.

        Returns:
            The attention output from ``F.scaled_dot_product_attention``.

        Raises:
            ValueError: if ``flag`` is neither ``'Encoder'`` nor ``'Decoder'``.
        """
        # Fail fast with a clear message; an unknown flag previously fell
        # through to an UnboundLocalError at the return statement.
        if flag not in ('Encoder', 'Decoder'):
            raise ValueError(f"flag must be 'Encoder' or 'Decoder', got {flag!r}")

        Q = self.w_q(q)
        K = self.w_k(k)
        V = self.w_v(v)
        # Only the decoder mode hides later positions from earlier ones.
        return F.scaled_dot_product_attention(Q, K, V, is_causal=(flag == 'Decoder'))


In [None]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention, then feed-forward.

    Each of the two sublayers runs inside its own ResidualConnection.

    Args:
        d_model: model width shared by all sublayers.
        d_ff: hidden width of the feed-forward network.
        dropout: dropout probability passed to the feed-forward network.
    """

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.self_attention = SelfAttention(d_model)
        self.ffn = FFN(d_model, d_ff, dropout)
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(d_model) for _ in range(2)]
        )

    def forward(self, x):
        """Pass ``x`` through the attention sublayer, then the feed-forward sublayer."""
        attention_residual, ffn_residual = self.residual_connections

        def attend(t):
            # Query, key and value all come from the same tensor.
            return self.self_attention(t, t, t, 'Encoder')

        attended = attention_residual(x, attend)
        return ffn_residual(attended, self.ffn)


In [None]:
class Encoder(nn.Module):
    """Encoder made of a single EncoderLayer.

    Args:
        d_model: model width.
        d_ff: hidden width of the layer's feed-forward network.
        dropout: dropout probability passed to the layer.
    """

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.encoderlayer = EncoderLayer(d_model, d_ff, dropout)

    def forward(self, x):
        """Return the output of the wrapped encoder layer for ``x``."""
        return self.encoderlayer(x)

In [None]:
#Decoder block
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention, then feed-forward.

    Each of the three sublayers runs inside its own ResidualConnection.

    Args:
        d_model: model width shared by all sublayers.
        d_ff: hidden width of the feed-forward network.
        dropout: dropout probability passed to the feed-forward network.
    """

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.self_attention = SelfAttention(d_model)
        self.cross_attention = SelfAttention(d_model)
        self.ffn = FFN(d_model, d_ff, dropout)
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(d_model) for _ in range(3)]
        )

    def forward(self, x, encoder_output):
        """Run ``x`` through the three sublayers in order.

        Args:
            x: decoder-side input.
            encoder_output: tensor used as key and value in cross-attention.
        """
        self_residual, cross_residual, ffn_residual = self.residual_connections

        def attend_to_self(t):
            # 'Decoder' selects the causally masked attention path.
            return self.self_attention(t, t, t, 'Decoder')

        def attend_to_encoder(t):
            # Queries come from the decoder; keys and values from the encoder.
            # 'Encoder' selects the unmasked attention path.
            return self.cross_attention(t, encoder_output, encoder_output, 'Encoder')

        after_self = self_residual(x, attend_to_self)
        after_cross = cross_residual(after_self, attend_to_encoder)
        return ffn_residual(after_cross, self.ffn)

In [None]:
class Decoder(nn.Module):
    """Decoder made of a single DecoderBlock.

    Args:
        d_model: model width.
        d_ff: hidden width of the block's feed-forward network.
        dropout: dropout probability passed to the block.
    """

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.decoderlayer = DecoderBlock(d_model, d_ff, dropout)

    def forward(self, x, encoder_output):
        """Return the wrapped block's output for ``x`` and ``encoder_output``."""
        return self.decoderlayer(x, encoder_output)

In [None]:
#Testing Block
# End-to-end smoke test: embeddings -> positional encoding -> encoder ->
# decoder -> projection -> softmax, printing each intermediate result.
# No .eval() is called, so dropout inside the feed-forward networks is active.
torch.manual_seed(0)

# Vocab and model dims (defined first so the sample inputs can use them)
src_vocab_size = 1000
target_vocab_size = 1500
d_model = 4

# Feed forward neural network configs
d_ff = 2048
dropout = 0.1

# Sample input: batch size 1, random vocabulary indices
src_seq_len = 5
target_seq_len = 6
src_seq = torch.randint(0, src_vocab_size, (1, src_seq_len))  # (1, src_seq_len)
target_seq = torch.randint(0, target_vocab_size, (1, target_seq_len))  # (1, target_seq_len)
print(src_seq)
print(target_seq)
print('#######################################################################')

# Instantiate InputEmbeddings
src_input_embeddings = InputEmbeddings(src_vocab_size, d_model)
target_input_embeddings = InputEmbeddings(target_vocab_size, d_model)

# Positional encodings for source and target
src_positional_encoding = PositionalEncoding(d_model, src_seq_len)
target_positional_encoding = PositionalEncoding(d_model, target_seq_len)

# Generate input embeddings
src_embeddings = src_input_embeddings(src_seq)
target_embeddings = target_input_embeddings(target_seq)

# Add positional encoding
src_encoded_input = src_positional_encoding(src_embeddings)
target_encoded_input = target_positional_encoding(target_embeddings)

# Instantiate Encoder
encoder_layer = Encoder(d_model, d_ff, dropout)

# Pass encoded input through the encoder
encoder_output = encoder_layer(src_encoded_input)

# Print the encoder output and its shape
print("Output shape encoder:", encoder_output.shape)
print(encoder_output)
print('#######################################################################')

# Instantiate Decoder
decoder_layer = Decoder(d_model, d_ff, dropout)

# Pass encoded target input and the encoder output through the decoder
decoder_output = decoder_layer(target_encoded_input, encoder_output)

# Print the decoder output and its shape
print("Output shape decoder:", decoder_output.shape)
print(decoder_output)
print('#######################################################################')

# Projection layer: d_model -> target vocabulary size
projection_layer = ProjectionLayer(d_model, target_vocab_size)

# Pass decoder output through the projection layer
proj_output = projection_layer(decoder_output)

# Print the projection output (the logits) and its shape
print("Output shape projection:", proj_output.shape)
print(proj_output)
print('#######################################################################')

# Apply softmax over the vocabulary dimension; the result is a probability
# distribution per position, not logits, so it is labelled accordingly.
softmax_output = F.softmax(proj_output, dim=-1)
print(f'softmax_output_shape : {softmax_output.shape}')
print(f'probabilities : {softmax_output}')
print('#######################################################################')

tensor([[ 44, 239, 933, 760, 963]])
tensor([[ 879,  427, 1003,  997,  183,  101]])
#######################################################################
In encoder
Output shape encoder: torch.Size([1, 5, 4])
tensor([[[-4.2172e-01,  7.7449e-01, -1.4334e+00,  1.0806e+00],
         [ 1.5940e+00,  1.0219e-01, -9.1965e-01, -7.7652e-01],
         [ 9.0996e-01, -1.6295e+00,  7.9224e-04,  7.1876e-01],
         [-8.0718e-01,  1.4484e+00,  4.0221e-01, -1.0434e+00],
         [ 1.3678e+00, -1.4276e+00, -1.8148e-01,  2.4123e-01]]],
       grad_fn=<NativeLayerNormBackward0>)
#######################################################################
In decoder
In encoder
Output shape decoder: torch.Size([1, 6, 4])
tensor([[[ 1.0217,  0.3031, -1.6590,  0.3342],
         [-0.4455, -0.7246, -0.5534,  1.7234],
         [ 0.3227, -0.7446, -1.0643,  1.4861],
         [-0.6521, -1.2934,  0.9045,  1.0410],
         [-0.8802,  1.0959, -1.1081,  0.8924],
         [-1.1076,  0.1991,  1.5315, -0.6230]]],
       g