In [87]:
# Importing packages/libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
def positional_encoding(d_model, max_seq_len=5000):
  """Sinusoidal positional encodings ("Attention Is All You Need").

  PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
  PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

  Args:
    d_model: embedding size (number of columns of the result). Odd values are
      supported: the trailing cos column that would overflow is dropped.
    max_seq_len: number of positions (rows of the result).

  Returns:
    Float tensor of shape (max_seq_len, d_model) with sin/cos interleaved
    along the last dimension.
  """
  even_idx = torch.arange(0, d_model, step=2).float()                    # the "2i" exponents
  denominator = torch.pow(10000, even_idx / d_model)                     # (ceil(d_model/2),)
  positions = torch.arange(0, max_seq_len).reshape(max_seq_len, 1).float()
  angles = positions / denominator                                       # (max_seq_len, ceil(d_model/2))
  # stack on a new last axis -> (..., 2) then flatten so columns go sin, cos, sin, cos, ...
  pe = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1).flatten(start_dim=1)
  # For odd d_model there is one column too many; keep exactly d_model.
  return pe[:, :d_model]

In [118]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Self-attention (``cross=False``): queries, keys and values are all projected
  from ``x`` with a single fused linear layer (one matmul instead of three).

  Cross-attention (``cross=True``, used by the decoder): queries are projected
  from ``x`` (the decoder stream) and keys/values from ``y`` (the encoder
  output), so ``x`` and ``y`` may have different sequence lengths.

  Args:
    d_model: model dimension (called ``embed_dim`` in PyTorch's
      ``nn.MultiheadAttention``); must be divisible by ``num_heads``.
    num_heads: number of attention heads.
    cross: True for cross-attention, False for self-attention.

  Raises:
    ValueError: if ``d_model`` is not divisible by ``num_heads``.
  """
  def __init__(self, d_model, num_heads, cross=False):
    super(MultiHeadAttention, self).__init__()
    if d_model % num_heads != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads  # d_k: per-head size of query/key/value
    self.d_model = d_model
    self.cross = cross
    if cross:
      self.linear_q = nn.Linear(d_model, d_model)       # queries from x (decoder)
      self.linear_kv = nn.Linear(d_model, 2 * d_model)  # keys and values from y (encoder)
    else:
      self.linear_qkv = nn.Linear(d_model, 3 * d_model)
    self.linear_output = nn.Linear(d_model, d_model)

  def calculate_weights(self, q, k):
    """Return attention logits ``q @ k^T / sqrt(head_dim)``.

    The scale uses the per-head dimension d_k (= head_dim), not d_model.
    """
    att_weights = torch.matmul(q, k.transpose(-1, -2))
    return att_weights / math.sqrt(self.head_dim)

  def _split(self, t, parts):
    """Split a (batch, seq, parts*d_model) projection into ``parts`` tensors of
    shape (batch, num_heads, seq, head_dim).

    Channels are read as [num_heads, parts, head_dim], i.e. each head's q/k/v
    slices are contiguous.
    """
    batch_size, seq_len, _ = t.size()
    t = t.view(batch_size, seq_len, self.num_heads, parts, self.head_dim)
    # (batch, seq, heads, parts, head_dim) -> (parts, batch, heads, seq, head_dim)
    return t.permute(3, 0, 2, 1, 4).unbind(0)

  def forward(self, x, y=None, mask=None):
    """Attend and project.

    Args:
      x: (batch, q_len, d_model) input; source of queries (and of keys/values
        in self-attention mode).
      y: (batch, kv_len, d_model) keys/values source; required when
        ``cross=True``, ignored otherwise.
      mask: optional tensor broadcastable to (batch, num_heads, q_len, kv_len);
        score entries where ``mask == 0`` are excluded from the softmax.

    Returns:
      (batch, q_len, d_model) tensor.

    Raises:
      ValueError: if ``cross=True`` and ``y`` is None.
    """
    if self.cross:
      if y is None:
        raise ValueError("cross-attention requires y (the encoder output)")
      (q,) = self._split(self.linear_q(x), 1)
      k, v = self._split(self.linear_kv(y), 2)
    else:
      q, k, v = self._split(self.linear_qkv(x), 3)

    weights = self.calculate_weights(q, k)  # (batch, heads, q_len, kv_len)
    if mask is not None:
      # finfo.min instead of a literal like -1e9, which overflows in float16.
      weights = weights.masked_fill(mask == 0, torch.finfo(weights.dtype).min)
    weights = torch.softmax(weights, dim=-1)

    updated_values = torch.matmul(weights, v)  # (batch, heads, q_len, head_dim)
    batch_size, _, q_len, _ = updated_values.size()
    # Move heads next to head_dim BEFORE merging them, otherwise sequence
    # positions and heads get mixed together.
    updated_values = updated_values.transpose(1, 2).reshape(batch_size, q_len, self.d_model)
    return self.linear_output(updated_values)

In [119]:
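# This class normalizes only along the last dimension, i.e. the embedding dimension.
# It could be made more general with a parameter that selects the dimensions the mean and
# variance are computed over (e.g. several trailing dimensions); also averaging across the
# batch dimension would turn it into something closer to batch normalization.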

class LayerNormalization(nn.Module):
  """Layer normalization over the last dimension with a learnable scale and shift.

  output = gammas * (x - mean) / sqrt(var + epsilon) + betas,
  where mean and (biased) variance are taken over the last dimension.

  Args:
    d_model: size of the last dimension of the input.
    epsilon: small constant added to the variance for numerical stability.
  """
  def __init__(self, d_model, epsilon = 1e-05):
    super().__init__()
    self.d_model = d_model
    self.epsilon = epsilon
    # Start as the identity affine map: scale 1, shift 0 (a shift of 1 would
    # bias every normalized activation at initialization).
    self.gammas = nn.Parameter(torch.ones(d_model))
    self.betas = nn.Parameter(torch.zeros(d_model))

  def forward(self, input_tensor):
    """Normalize ``input_tensor`` along its last dimension; shape is preserved."""
    mean = input_tensor.mean(dim=-1, keepdim=True)
    centered = input_tensor - mean
    variance = (centered ** 2).mean(dim=-1, keepdim=True)  # biased (population) variance
    normalized = centered / torch.sqrt(variance + self.epsilon)
    return self.gammas * normalized + self.betas


In [120]:
class FeedForward(nn.Module):
  """Position-wise feed-forward sub-layer: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: input and output feature size.
    hidden: size of the inner (expanded) layer.
    dropout_prob: dropout probability applied after the ReLU.
  """
  def __init__(self, d_model, hidden, dropout_prob =0.1):
    super(FeedForward, self).__init__()
    self.d_model, self.hidden = d_model, hidden
    self.linear1 = nn.Linear(d_model, hidden)  # expand d_model -> hidden
    self.linear2 = nn.Linear(hidden, d_model)  # project hidden -> d_model
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(p=dropout_prob)

  def forward(self, x):
    """Apply the two-layer MLP to the last dimension of ``x``."""
    activated = self.relu(self.linear1(x))
    return self.linear2(self.dropout(activated))
  

In [121]:
class EncoderBlock(nn.Module):
  """One encoder layer: self-attention and a feed-forward sub-layer, each
  followed by dropout, a residual connection and layer normalization
  (post-norm: ``norm(x + dropout(sublayer(x)))``).

  Args:
    d_model: model dimension.
    hidden: inner size of the feed-forward sub-layer.
    num_heads: number of attention heads.
    dropout_prob: dropout probability.
  """
  def __init__(self, d_model, hidden, num_heads, dropout_prob):
    super().__init__()
    self.mul_head_att = MultiHeadAttention(d_model, num_heads)

    self.norm1 = LayerNormalization(d_model)
    self.norm2 = LayerNormalization(d_model)

    self.ff_layers = FeedForward(d_model, hidden, dropout_prob)

    self.dropout1 = nn.Dropout(p=dropout_prob)
    self.dropout2 = nn.Dropout(p=dropout_prob)

  def forward(self, x, mask):
    """Run the layer.

    Args:
      x: encoder hidden states — presumably (batch, seq_len, d_model).
      mask: attention mask forwarded to self-attention (entries == 0 are masked).
    """
    # mask MUST be passed by keyword: the second positional parameter of
    # MultiHeadAttention.forward is `y`, so a positional call drops the mask.
    attended = self.mul_head_att(x, mask=mask)
    # x is rebound, never mutated in place, so no clone() is needed for the residual.
    x = self.norm1(self.dropout1(attended) + x)
    x = self.norm2(self.dropout2(self.ff_layers(x)) + x)
    return x



In [122]:
class Encoder(nn.Module):
  """Token embedding + sinusoidal positional encoding + dropout, followed by a
  stack of ``num_layers`` EncoderBlocks.

  Args:
    d_model: model dimension.
    hidden: inner size of each feed-forward sub-layer.
    num_heads: number of attention heads.
    dropout_prob: dropout probability.
    num_layers: number of stacked EncoderBlocks.
    vocab_size: size of the input vocabulary.
    max_seq_len: longest sequence the positional-encoding table covers.
  """
  def __init__(self, d_model, hidden, num_heads, dropout_prob, num_layers, vocab_size, max_seq_len):
    super().__init__()
    self.max_seq_len = max_seq_len
    self.d_model = d_model

    self.embedding = nn.Embedding(vocab_size, d_model)
    self.dropout = nn.Dropout(p=dropout_prob)

    self.layers = nn.ModuleList([
      EncoderBlock(d_model, hidden, num_heads, dropout_prob)
      for _ in range(num_layers)
    ])
    # Computed once instead of on every forward. A buffer follows .to(device)/.half()
    # with the module; persistent=False keeps it out of the state_dict.
    self.register_buffer("pos_encoding", positional_encoding(d_model, max_seq_len), persistent=False)

  def forward(self, x, mask_encoder):
    """Encode token ids.

    Args:
      x: (batch, seq_len) token ids with seq_len <= max_seq_len.
      mask_encoder: attention mask passed to every EncoderBlock.

    Returns:
      (batch, seq_len, d_model) encoder output.

    Raises:
      ValueError: if seq_len exceeds max_seq_len.
    """
    seq_len = x.size(1)
    if seq_len > self.max_seq_len:
      raise ValueError(f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
    # Slice the table so sequences shorter than max_seq_len also work.
    x = self.dropout(self.embedding(x) + self.pos_encoding[:seq_len])
    for layer in self.layers:
      x = layer(x, mask_encoder)
    return x


In [123]:
class DecoderBlock(nn.Module):
  """One decoder layer with three post-norm sub-layers:
  masked self-attention, cross-attention over the encoder output, and a
  feed-forward network. Each is followed by dropout, a residual connection
  and layer normalization.

  Args:
    d_model: model dimension.
    hidden: inner size of the feed-forward sub-layer.
    num_heads: number of attention heads.
    dropout_prob: dropout probability.
  """
  def __init__(self, d_model, hidden, num_heads, dropout_prob):
    super(DecoderBlock, self).__init__()

    self.mul_head_att1 = MultiHeadAttention(d_model, num_heads)              # self-attention
    self.mul_head_att2 = MultiHeadAttention(d_model, num_heads, cross=True)  # cross-attention

    self.norm1 = LayerNormalization(d_model)
    self.norm2 = LayerNormalization(d_model)
    self.norm3 = LayerNormalization(d_model)

    self.ff_layers = FeedForward(d_model, hidden, dropout_prob)

    self.dropout1 = nn.Dropout(p=dropout_prob)
    self.dropout2 = nn.Dropout(p=dropout_prob)
    self.dropout3 = nn.Dropout(p=dropout_prob)

  def forward(self, x, encoder_output, mask_encoder, mask_decoder):
    """Run the layer.

    Args:
      x: decoder hidden states.
      encoder_output: output of the encoder, used as ``y`` for cross-attention.
      mask_encoder: mask for the cross-attention scores.
      mask_decoder: mask for the decoder self-attention scores.
    """
    # Residual inputs are just the current x: x is rebound, never mutated in place.
    self_attended = self.mul_head_att1(x=x, mask=mask_decoder)
    x = self.norm1(x + self.dropout1(self_attended))

    cross_attended = self.mul_head_att2(x=x, y=encoder_output, mask=mask_encoder)
    x = self.norm2(x + self.dropout2(cross_attended))

    transformed = self.ff_layers(x)
    return self.norm3(x + self.dropout3(transformed))

class Decoder(nn.Module):
  """Token embedding + sinusoidal positional encoding + dropout, followed by a
  stack of ``num_layers`` DecoderBlocks.

  Args:
    d_model: model dimension.
    hidden: inner size of each feed-forward sub-layer.
    num_heads: number of attention heads.
    dropout_prob: dropout probability.
    num_layers: number of stacked DecoderBlocks.
    vocab_size: size of the output vocabulary.
    max_seq_len: longest sequence the positional-encoding table covers.
  """
  def __init__(self, d_model, hidden, num_heads, dropout_prob, num_layers, vocab_size, max_seq_len):
    super().__init__()
    self.max_seq_len = max_seq_len
    self.d_model = d_model

    self.embedding = nn.Embedding(vocab_size, d_model)
    self.dropout = nn.Dropout(p=dropout_prob)
    self.layers = nn.ModuleList([DecoderBlock(d_model, hidden, num_heads, dropout_prob) for _ in range(num_layers)])
    # Computed once instead of on every forward. A buffer follows .to(device)/.half()
    # with the module; persistent=False keeps it out of the state_dict.
    self.register_buffer("pos_encoding", positional_encoding(d_model, max_seq_len), persistent=False)

  def forward(self, x, encoder_output, mask_encoder, mask_decoder):
    """Decode target token ids against the encoder output.

    Args:
      x: (batch, seq_len) target token ids with seq_len <= max_seq_len.
      encoder_output: encoder output passed to every DecoderBlock.
      mask_encoder: mask for cross-attention.
      mask_decoder: mask for decoder self-attention.

    Returns:
      (batch, seq_len, d_model) decoder output.

    Raises:
      ValueError: if seq_len exceeds max_seq_len.
    """
    seq_len = x.size(1)
    if seq_len > self.max_seq_len:
      raise ValueError(f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
    # Slice the table so sequences shorter than max_seq_len also work.
    x = self.dropout(self.embedding(x) + self.pos_encoding[:seq_len])
    for layer in self.layers:
      x = layer(x, encoder_output, mask_encoder, mask_decoder)
    return x


In [125]:
class Transformer(nn.Module):
  """Encoder-decoder Transformer: Encoder, Decoder, and a linear + softmax head
  over the output vocabulary.

  Args:
    d_model: model dimension.
    hidden: inner size of the feed-forward sub-layers.
    num_heads: number of attention heads.
    dropout_prob: dropout probability.
    num_layers: number of layers in both encoder and decoder.
    input_vocab_size: source vocabulary size.
    output_vocab_size: target vocabulary size.
    max_seq_len: longest supported source/target length.
  """
  def __init__(self,d_model, hidden, num_heads, dropout_prob, num_layers, input_vocab_size, output_vocab_size, max_seq_len):
    super().__init__()
    self.d_model = d_model
    self.max_seq_len = max_seq_len
    self.dropout = nn.Dropout(p = dropout_prob)  # NOTE(review): not used in forward
    self.encoder = Encoder(d_model, hidden, num_heads, dropout_prob, num_layers, input_vocab_size, max_seq_len)
    self.decoder = Decoder(d_model, hidden, num_heads, dropout_prob, num_layers, output_vocab_size, max_seq_len)
    self.linear = nn.Linear(d_model, output_vocab_size)
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, x, y):
    """Run the full model.

    Args:
      x: (batch, src_len) source token ids (already tokenized); id 0 = padding.
      y: (batch, tgt_len) target token ids; id 0 = padding.

    Returns:
      (batch, tgt_len, output_vocab_size) probabilities (softmax over the
      vocabulary). For training with nn.CrossEntropyLoss, use the logits
      before ``self.softmax`` instead.
    """
    x_mask = self.mask_gen(x, self.max_seq_len, False)
    y_mask = self.mask_gen(y, self.max_seq_len, True)
    encoder_output = self.encoder(x, x_mask)
    decoder_output = self.decoder(y, encoder_output, x_mask, y_mask)
    return self.softmax(self.linear(decoder_output))

  def mask_gen(self, x, max_seq_len, lookahead):
    """Build a boolean attention mask from token ids (True = may attend).

    Token id 0 is treated as padding and masked out as a *key* (a column of
    the score matrix), so no query attends to padded positions.

    Args:
      x: (batch, seq_len) token ids.
      max_seq_len: kept for backward compatibility; the causal mask is sized
        from ``x.size(1)`` so sequences shorter than max_seq_len work.
      lookahead: if True, additionally forbid attending to later positions.

    Returns:
      (batch, 1, 1, seq_len) padding mask, or (batch, 1, seq_len, seq_len)
      when ``lookahead`` is True; on the same device as ``x``.
    """
    key_mask = (x != 0).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)
    if not lookahead:
      return key_mask
    seq_len = x.size(1)
    # Lower-triangular (incl. diagonal): position i may see positions <= i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device))
    return key_mask & causal



In [127]:
# Fix the RNG so the random inputs, weight initialisation and dropout are reproducible
# under Restart & Run All.
torch.manual_seed(0)

d_model = 256
num_heads = 8
dropout_prob = 0.1
batch_size = 16
max_sequence_length = 200
hidden = 1024
num_layers = 1
input_vocab_size = 200
output_vocab_size = 156

# x is the source (input) sequence, y the target (output) sequence, both (batch_size, seq_length).
# Ids are drawn from [1, vocab_size), so no position is id 0 (which Transformer.mask_gen treats as padding).
x = torch.randint(1, input_vocab_size, (batch_size, max_sequence_length))
y = torch.randint(1, output_vocab_size, (batch_size, max_sequence_length))

transformer = Transformer(d_model, hidden, num_heads, dropout_prob, num_layers, input_vocab_size, output_vocab_size, max_sequence_length)
output = transformer(x, y)
assert output.shape == (batch_size, max_sequence_length, output_vocab_size)
print(output.shape)  # (batch_size, max_seq_length, output_vocab_size)


torch.Size([16, 200, 156])
