# Transformer implementation in PyTorch

## Imports

In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from collections import Counter
import numpy as np
import re
import math

In [16]:
# Seed PyTorch's global RNG so later random draws in this notebook
# (e.g. nn.Linear initialisation, torch.randint test inputs) are reproducible.
SEED = 23
torch.manual_seed(SEED)

<torch._C.Generator at 0x7bbcc38c58d0>

In [49]:
# Prefer CUDA when it is available; otherwise run on the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print(device)

cpu


In [18]:
# Longest sequence the positional-encoding table is precomputed for; used as
# the default `max_seq_len` / `max_len` of the modules defined below.
MAX_SEQ_LEN: int = 30

## Transformer modules definitions

In [62]:
class PositionalEmbedding(nn.Module):
  """Add fixed sinusoidal positional encodings to batch-first embeddings.

  Encoding (Vaswani et al., 2017):
    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

  Args:
    d_model: embedding dimension of the inputs.
    max_seq_len: longest sequence length the encoding table is built for.
  """

  def __init__(self, d_model, max_seq_len=MAX_SEQ_LEN):
    super().__init__()
    token_pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float()
                         * (-math.log(10000.0) / d_model))
    pos_embed_matrix = torch.zeros(max_seq_len, d_model)
    pos_embed_matrix[:, 0::2] = torch.sin(token_pos * div_term)
    # With an odd d_model there is one fewer cosine column than sine columns.
    pos_embed_matrix[:, 1::2] = torch.cos(token_pos * div_term[:d_model // 2])
    # Stored as (1, max_seq_len, d_model) so it broadcasts over the batch axis
    # of (batch, seq_len, d_model) inputs. A buffer follows module.to(device);
    # persistent=False keeps it out of state_dict, since it is deterministic.
    self.register_buffer("pos_embed_matrix", pos_embed_matrix.unsqueeze(0),
                         persistent=False)

  def forward(self, x):
    """Return x plus the encodings of positions 0 .. seq_len-1.

    Args:
      x: batch-first tensor of shape (batch, seq_len, d_model).

    Raises:
      ValueError: if seq_len exceeds the precomputed max_seq_len.
    """
    seq_len = x.size(1)
    max_seq_len = self.pos_embed_matrix.size(1)
    if seq_len > max_seq_len:
      raise ValueError(
          f"sequence length {seq_len} exceeds max_seq_len {max_seq_len}")
    # Index by the sequence axis (dim 1), not the batch axis.
    return x + self.pos_embed_matrix[:, :seq_len]

In [51]:
class PositionFeedForward(nn.Module):
  """Position-wise feed-forward network: Linear -> ReLU -> Linear.

  The same two linear maps are applied over the last dimension of the input,
  expanding d_model to d_ff and projecting back to d_model.
  """

  def __init__(self, d_model, d_ff):
    super().__init__()
    self.linear1 = nn.Linear(d_model, d_ff)   # expand: d_model -> d_ff
    self.linear2 = nn.Linear(d_ff, d_model)   # project back: d_ff -> d_model

  def forward(self, x):
    hidden = F.relu(self.linear1(x))
    return self.linear2(hidden)

In [65]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention for batch-first inputs.

  Args:
    d_model: model dimension of queries, keys and values; must be divisible
      by num_heads.
    num_heads: number of attention heads; each head works on
      d_model // num_heads features.
  """

  def __init__(self, d_model=512, num_heads=8):
    super().__init__()
    assert d_model % num_heads == 0, 'Embedding size not compatible with num_heads'
    self.d_v = d_model // num_heads
    self.d_k = self.d_v
    self.num_heads = num_heads

    self.W_q = nn.Linear(d_model, d_model)
    self.W_k = nn.Linear(d_model, d_model)
    self.W_v = nn.Linear(d_model, d_model)
    self.W_o = nn.Linear(d_model, d_model)

  def forward(self, Q, K, V, mask=None):
    """Project inputs, attend independently per head, and recombine.

    Args:
      Q: queries, shape (batch, query_len, d_model).
      K: keys, shape (batch, key_len, d_model).
      V: values, shape (batch, key_len, d_model).
      mask: optional tensor broadcastable to
        (batch, num_heads, query_len, key_len); entries equal to 0 block
        attention to the corresponding key.

    Returns:
      (output, attention): output has shape (batch, query_len, d_model);
      attention holds the per-head softmax weights with shape
      (batch, num_heads, query_len, key_len).
    """
    batch_size = Q.size(0)
    # (batch, len, d_model) -> (batch, num_heads, len, head_dim)
    Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_v).transpose(1, 2)

    weighted_values, attention = self.scale_dot_product(Q, K, V, mask)
    # Merge heads: (batch, num_heads, query_len, d_v) -> (batch, query_len, d_model)
    weighted_values = (weighted_values.transpose(1, 2).contiguous()
                       .view(batch_size, -1, self.num_heads * self.d_v))
    return self.W_o(weighted_values), attention

  def scale_dot_product(self, Q, K, V, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Returns:
      (weighted_values, attention), where attention sums to 1 over its last
      (key) axis.
    """
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
    if mask is not None:
      scores = self._mask_scores(scores, mask)

    attention = F.softmax(scores, dim=-1)
    return torch.matmul(attention, V), attention

  @staticmethod
  def _mask_scores(scores, mask):
    """Replace scores where mask == 0 with the dtype's most negative value.

    A literal -1e9 is not representable in float16 (largest magnitude 65504),
    so masked_fill would raise under half precision / autocast. Using
    torch.finfo keeps the fill value in range for every floating dtype while
    still driving the masked softmax weights to zero.
    """
    return scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

In [53]:
class EncoderSubLayer(nn.Module):
  """One post-LayerNorm Transformer encoder layer.

  Two residual sub-blocks, each computed as norm(x + dropout(sublayer(x))):
  multi-head self-attention, then a position-wise feed-forward network.
  """

  def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
    super().__init__()
    # Construction order and attribute names are significant: they fix the
    # RNG consumption during initialisation and the state_dict keys.
    self.self_attention = MultiHeadAttention(d_model, num_heads)
    self.ffn = PositionFeedForward(d_model, d_ff)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)

  def forward(self, x, mask=None):
    """Apply self-attention and feed-forward sub-blocks; mask goes to attention."""
    attended, _ = self.self_attention(x, x, x, mask)
    hidden = self.norm1(x + self.dropout1(attended))
    return self.norm2(hidden + self.dropout2(self.ffn(hidden)))


In [54]:
class Encoder(nn.Module):
  """A stack of `num_layers` EncoderSubLayer modules followed by a LayerNorm.

  NOTE(review): each EncoderSubLayer in this file already ends with a
  LayerNorm, so `self.norm` re-normalizes an already-normalized tensor. It is
  kept because removing it would change outputs and state_dict keys.
  """

  def __init__(self, d_model, num_heads, d_ff, num_layers, dropout=0.1):
    super().__init__()
    self.layers = nn.ModuleList(
        EncoderSubLayer(d_model, num_heads, d_ff, dropout)
        for _ in range(num_layers)
    )
    self.norm = nn.LayerNorm(d_model)

  def forward(self, x, mask=None):
    """Feed x through every layer in order (same mask for all), then normalize."""
    hidden = x
    for encoder_layer in self.layers:
      hidden = encoder_layer(hidden, mask)
    return self.norm(hidden)

In [55]:
class DecoderSubLayer(nn.Module):
  """One post-LayerNorm Transformer decoder layer.

  Three residual sub-blocks, each computed as norm(x + dropout(sublayer(x))):
  masked self-attention, cross-attention over the encoder output, and a
  position-wise feed-forward network.
  """

  def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
    super().__init__()
    # Construction order and attribute names are significant: they fix the
    # RNG consumption during initialisation and the state_dict keys.
    self.self_attention = MultiHeadAttention(d_model, num_heads)
    self.cross_attention = MultiHeadAttention(d_model, num_heads)
    self.feed_forward = PositionFeedForward(d_model, d_ff)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)
    self.dropout3 = nn.Dropout(dropout)

  def forward(self, x, encoder_output, target_mask=None, encoder_mask=None):
    """Run self-attention (target_mask), cross-attention (encoder_mask), then FFN."""
    self_attended, _ = self.self_attention(x, x, x, target_mask)
    x = self.norm1(x + self.dropout1(self_attended))

    # Queries come from the decoder; keys and values from the encoder output.
    cross_attended, _ = self.cross_attention(
        x, encoder_output, encoder_output, encoder_mask)
    x = self.norm2(x + self.dropout2(cross_attended))

    return self.norm3(x + self.dropout3(self.feed_forward(x)))

In [56]:
class Decoder(nn.Module):
  """A stack of `num_layers` DecoderSubLayer modules followed by a LayerNorm."""

  def __init__(self, d_model, num_heads, d_ff, num_layers, dropout=0.1):
    super().__init__()
    self.layers = nn.ModuleList(
        DecoderSubLayer(d_model, num_heads, d_ff, dropout)
        for _ in range(num_layers)
    )
    self.norm = nn.LayerNorm(d_model)

  def forward(self, x, encoder_output, target_mask, encoder_mask):
    """Feed x through every layer, each attending to the same encoder output."""
    hidden = x
    for decoder_layer in self.layers:
      hidden = decoder_layer(hidden, encoder_output, target_mask, encoder_mask)
    return self.norm(hidden)

In [60]:
class Transformer(nn.Module):
  """Encoder-decoder Transformer operating on batch-first token-id tensors."""

  def __init__(self, d_model, num_heads, d_ff, num_layers, input_vocab_size,
               target_vocab_size, max_len=MAX_SEQ_LEN, dropout=0.1, pad_idx=0):
    """
      d_model: dimension of the input embeddings
      num_heads: number of attention heads
      d_ff: dimension of the feedforward layer
      num_layers: number of transformer layers
      input_vocab_size: size of the input vocabulary
      target_vocab_size: size of the target vocabulary
      max_len: maximum length of the input sequence
      dropout: dropout rate
      pad_idx: token id treated as padding when building attention masks
               (default 0)
    """
    super().__init__()
    self.pad_idx = pad_idx
    self.encoder_embedding = nn.Embedding(input_vocab_size, d_model)
    self.decoder_embedding = nn.Embedding(target_vocab_size, d_model)
    # Sinusoidal table shared by encoder and decoder (it has no parameters).
    self.pos_embedding = PositionalEmbedding(d_model, max_len)
    self.encoder = Encoder(d_model, num_heads, d_ff, num_layers, dropout)
    self.decoder = Decoder(d_model, num_heads, d_ff, num_layers, dropout)
    self.output_layer = nn.Linear(d_model, target_vocab_size)

  def forward(self, source, target):
    """Return unnormalized logits of shape (batch, tgt_len, target_vocab_size).

    Args:
      source: token ids, shape (batch, src_len).
      target: token ids, shape (batch, tgt_len).
    """
    # Attention masks (padding for both sides, plus causal for the target)
    source_mask, target_mask = self.mask(source, target)
    # Encoder embedding (scaled by sqrt(d_model)) and positional encoding
    source = self.encoder_embedding(source) * math.sqrt(self.encoder_embedding.embedding_dim)
    source = self.pos_embedding(source)
    # Encoder
    encoder_output = self.encoder(source, source_mask)

    # Decoder embedding (scaled by sqrt(d_model)) and positional encoding
    target = self.decoder_embedding(target) * math.sqrt(self.decoder_embedding.embedding_dim)
    target = self.pos_embedding(target)
    # Decoder
    decoder_output = self.decoder(target, encoder_output, target_mask, source_mask)

    return self.output_layer(decoder_output)

  def mask(self, source, target):
    """Build boolean attention masks (True = may attend).

    Returns:
      source_mask: shape (batch, 1, 1, src_len); False at padding keys.
      target_mask: shape (batch, 1, tgt_len, tgt_len); query i may attend to
        key j only when j <= i and target[:, j] is not padding.
    """
    source_mask = (source != self.pad_idx).unsqueeze(1).unsqueeze(2)
    target_pad_mask = (target != self.pad_idx).unsqueeze(1).unsqueeze(2)
    size = target.size(1)
    # Build on the input's device rather than a global, so the module works
    # wherever its inputs live.
    causal_mask = torch.ones((1, size, size), dtype=torch.bool,
                             device=target.device).tril()
    return source_mask, target_pad_mask & causal_mask

## Testing the transformer code

In [66]:
seq_len_source = 10
seq_len_target = 10
batch_size = 2
input_vocab_size = 50
target_vocab_size = 50

# Token ids start at 1 so no position is treated as padding (id 0).
source = torch.randint(1, input_vocab_size, (batch_size, seq_len_source))
target = torch.randint(1, target_vocab_size, (batch_size, seq_len_target))

d_model = 512
num_heads = 8
d_ff = 2048
num_layers = 6

model = Transformer(d_model, num_heads, d_ff, num_layers, input_vocab_size,
                    target_vocab_size, max_len=MAX_SEQ_LEN, dropout=0.1)
model.to(device)
model.eval()  # disable dropout for a deterministic forward pass
source = source.to(device)
target = target.to(device)
with torch.no_grad():  # inference only; no autograd graph needed
  output = model(source, target)

# Expected output shape [batch, seq len target, target vocab size] i.e. [2, 10, 50]
expected_shape = (batch_size, seq_len_target, target_vocab_size)
assert output.shape == expected_shape, f"unexpected output shape {tuple(output.shape)}"
print(f"output.shape: {output.shape}")

output.shape: torch.Size([2, 10, 50])
