In [1]:
import torch
import torch.nn as nn
import numpy as np

# Multi-Head Self-Attention

<img src="https://d2beiqkhq929f0.cloudfront.net/public_assets/assets/000/020/062/original/Screenshot_2022-11-18_at_9.18.24_AM.png?1668744025" width=800>

In [2]:
class MHSelfAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The inputs are linearly projected into keys, values and queries, split into
  `num_heads` heads of size `head_dim = embed_size // num_heads`, attended
  per head, concatenated and passed through a final linear layer.

  Args:
    embed_size: size of the last dimension of the inputs and of the output.
    num_heads: number of attention heads; must divide `embed_size` exactly.
  """

  def __init__(self, embed_size, num_heads):
    super(MHSelfAttention, self).__init__()
    self.embed_size = embed_size
    self.num_heads = num_heads
    self.head_dim = self.embed_size // self.num_heads

    assert self.head_dim * num_heads == embed_size, \
        "embed_size must be divisible by num_heads"

    # K, Q and V projections, each (embed_size -> embed_size)
    self.V = nn.Linear(embed_size, embed_size)
    self.K = nn.Linear(embed_size, embed_size)
    self.Q = nn.Linear(embed_size, embed_size)

    # energies are laid out as [B, heads, query_len, key_len]; normalise over the keys
    self.softmax = nn.Softmax(dim=3)

    # maps the concatenated heads (head_dim * num_heads == embed_size) back to embed_size
    self.fc_out = nn.Linear(self.head_dim * num_heads, embed_size)

  def forward(self, v, k, q, mask=None):
    """Attend from the queries `q` over the keys `k` and values `v`.

    Args:
      v: values, [B, value_len, embed_size].
      k: keys, [B, key_len, embed_size]; key_len must equal value_len.
      q: queries, [B, query_len, embed_size]. query_len may differ from
        key_len (e.g. decoder queries over encoder keys/values).
      mask: optional tensor broadcastable to [B, num_heads, query_len, key_len];
        positions where `mask == 0` are excluded from attention.

    Returns:
      Tensor of shape [B, query_len, embed_size].
    """
    # keys, values, queries: [B, seq, embed_size]
    keys = self.K(k)
    values = self.V(v)
    queries = self.Q(q)

    batch_size = keys.shape[0]
    value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

    # split the embedding dimension into heads: [B, seq, num_heads, head_dim]
    keys = keys.reshape(batch_size, key_len, self.num_heads, self.head_dim)
    values = values.reshape(batch_size, value_len, self.num_heads, self.head_dim)
    queries = queries.reshape(batch_size, query_len, self.num_heads, self.head_dim)

    # dot product of every query with every key, per head: [B, heads, query_len, key_len]
    energies = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
    if mask is not None:
      # NOTE: a query row whose keys are all masked becomes all -inf and softmax yields NaN
      energies = energies.masked_fill(mask == 0, float("-inf"))

    # scale by sqrt(d_k), the per-head key dimension (not the full embedding size)
    attention = self.softmax(energies / (self.head_dim ** 0.5))

    # weighted sum of values per head: [B, query_len, heads, head_dim]
    out = torch.einsum("nhqk,nkhd->nqhd", attention, values)

    # concatenate the heads; use the instance attribute, not a notebook-level global
    out = out.reshape(batch_size, query_len, self.embed_size)

    # out: [B, query_len, embed_size]
    return self.fc_out(out)

# Block

<img src='https://d2beiqkhq929f0.cloudfront.net/public_assets/assets/000/020/094/original/Screenshot_2022-11-18_at_2.50.50_PM.png?1668763324' width=200 height=200>

In [3]:
class Block(nn.Module):
  """Attention sub-layer followed by a position-wise feed-forward sub-layer.

  Each sub-layer output is added to its input (residual connection) and
  layer-normalised. Dropout is applied once, after the first normalisation.

  Args:
    embed_size: size of the last dimension of the inputs and of the output.
    num_heads: number of attention heads passed to MHSelfAttention.
    dropout: dropout probability.
    forward_expansion: the hidden width of the feed-forward network is
      `forward_expansion * embed_size`.
  """

  def __init__(self, embed_size, num_heads, dropout, forward_expansion=4):
    super().__init__()
    hidden_size = forward_expansion * embed_size

    self.attention = MHSelfAttention(embed_size, num_heads)
    self.norm_layer1 = nn.LayerNorm(embed_size)
    self.norm_layer2 = nn.LayerNorm(embed_size)

    # expand -> ReLU -> project back to embed_size
    self.feed_forward = nn.Sequential(
        nn.Linear(embed_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, embed_size),
    )

    self.dropout = nn.Dropout(dropout)

  def forward(self, value, key, query, mask):
    """Run attention and the feed-forward network, each with residual + LayerNorm.

    `query` is used as the residual input of the attention sub-layer, so the
    result has the same shape as `query`.
    """
    attended = self.attention(value, key, query, mask)
    # first residual connection, normalised and then regularised with dropout
    residual = self.dropout(self.norm_layer1(attended + query))
    # second residual connection around the feed-forward network
    return self.norm_layer2(self.feed_forward(residual) + residual)

# Positional Encoding
<img src="https://d2beiqkhq929f0.cloudfront.net/public_assets/assets/000/020/097/original/Screenshot_2022-11-18_at_3.37.35_PM.png?1668766059">

In [4]:
class PositionalEncoding(nn.Module):
  """Fixed sinusoidal positional encoding.

  For position `pos` and even column index `i`:
    encoding[pos, i]     = sin(pos / 10000 ** (i / embed_size))
    encoding[pos, i + 1] = cos(pos / 10000 ** (i / embed_size))

  The table depends only on the constructor arguments, so it is built once
  here instead of on every forward pass.

  Args:
    embeds: embedding size (number of columns of the table); may be odd.
    max_len: largest sequence length that can be encoded.
    device: device on which the table is created.
  """

  def __init__(self, embeds, max_len, device):
    super(PositionalEncoding, self).__init__()
    self.embed_size = embeds
    self.max_len = max_len
    self.device = device

    pos = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)
    i = torch.arange(0, embeds, step=2, device=device).float()
    # [max_len, ceil(embeds / 2)]: one angle per (position, frequency) pair
    angles = pos / (10000 ** (i / embeds))

    encoding = torch.zeros(max_len, embeds, device=device)
    encoding[:, 0::2] = torch.sin(angles)
    # for an odd embedding size there is one cosine column fewer than sine columns
    encoding[:, 1::2] = torch.cos(angles[:, : embeds // 2])

    # non-persistent buffer: follows module.to(...) without entering the state_dict
    self.register_buffer("encoding", encoding, persistent=False)

  def forward(self, x):
    """Return the encodings for the positions of `x`.

    Args:
      x: 2-D tensor [batch_size, seq_len]; only its shape is used.

    Returns:
      Tensor [batch_size, seq_len, embed_size]; an expanded (non-copied) view
      of the precomputed table.

    Raises:
      ValueError: if seq_len exceeds `max_len`.
    """
    batch_size, seq_len = x.size()
    if seq_len > self.max_len:
      raise ValueError(
          f"sequence length {seq_len} exceeds max_len {self.max_len}"
      )
    return self.encoding[:seq_len, :].expand(batch_size, seq_len, self.embed_size)


# Encoder

<img src="https://d2beiqkhq929f0.cloudfront.net/public_assets/assets/000/020/095/original/Screenshot_2022-11-18_at_2.58.42_PM.png?1668763726" width=300 height=400>

In [5]:
class Encoder(nn.Module):
  """Stack of `num_layers` self-attention Blocks on top of token + positional embeddings.

  Args:
    src_vocab_size: number of rows of the token embedding table.
    embed_size: embedding size used throughout the stack.
    num_layers: number of Blocks.
    heads: number of attention heads per Block.
    device: device handed to the positional encoding.
    forward_expansion: feed-forward width multiplier of each Block.
    dropout: dropout probability (embeddings and Blocks).
    max_length: maximum sequence length handed to the positional encoding.
  """

  def __init__(self, src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):
    super().__init__()
    self.embed_size = embed_size
    self.device = device

    self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
    self.positional_encoding = PositionalEncoding(embed_size, max_length, device)

    blocks = []
    for _ in range(num_layers):
      blocks.append(
          Block(embed_size, heads, dropout=dropout, forward_expansion=forward_expansion)
      )
    self.layers = nn.ModuleList(blocks)

    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    """Encode token ids `x` ([batch, seq]) and return [batch, seq, embed_size]."""
    _batch, _seq = x.shape  # fails early unless x is 2-D

    token_vectors = self.word_embedding(x)
    position_vectors = self.positional_encoding(x)
    hidden = self.dropout(token_vectors + position_vectors)

    # self-attention: value, key and query are all the previous layer's output
    for block in self.layers:
      hidden = block(hidden, hidden, hidden, mask)

    return hidden

# Decoder Block

<img src="https://d2beiqkhq929f0.cloudfront.net/public_assets/assets/000/020/099/original/Screenshot_2022-11-18_at_4.02.02_PM.png?1668767521" height=400>

In [6]:
class DecoderBlock(nn.Module):
  """Masked self-attention over the target, followed by a Block over the encoder output.

  Args:
    embed_size: embedding size of inputs and output.
    num_heads: number of attention heads.
    forward_expansion: feed-forward width multiplier of the inner Block.
    dropout: dropout probability.
    device: accepted for signature compatibility; not used by this class.
  """

  def __init__(self, embed_size, num_heads, forward_expansion, dropout, device):
    super().__init__()
    self.norm = nn.LayerNorm(embed_size)
    self.attention = MHSelfAttention(embed_size, num_heads)
    self.block = Block(embed_size, num_heads, dropout, forward_expansion)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, value, key, src_mask, trg_mask):
    """Return the block output for target states `x`.

    `trg_mask` restricts the self-attention over `x`; `src_mask` is passed to
    the attention over `value`/`key` so that masked source positions are ignored.
    """
    # self-attention over the target sequence
    self_attended = self.attention(x, x, x, trg_mask)
    # residual connection + LayerNorm + dropout gives the query for the next stage
    query = self.dropout(self.norm(self_attended + x))
    # attention of that query over the supplied value/key, then feed-forward
    return self.block(value, key, query, src_mask)

# Decoder
<img src="https://d2beiqkhq929f0.cloudfront.net/public_assets/assets/000/020/098/original/Screenshot_2022-11-18_at_4.00.38_PM.png?1668767466" height=500>

In [7]:
class Decoder(nn.Module):
  """Stack of DecoderBlocks followed by a linear projection to the target vocabulary.

  Args:
    trg_vocab_size: number of target tokens (embedding rows and output logits).
    embed_size: embedding size used throughout the stack.
    num_layers: number of DecoderBlocks.
    num_heads: number of attention heads per block.
    forward_expansion: feed-forward width multiplier.
    dropout: dropout probability.
    device: device handed to the positional encoding and the blocks.
    max_length: maximum sequence length handed to the positional encoding.
  """

  def __init__(self, trg_vocab_size, embed_size, num_layers, num_heads, forward_expansion, dropout, device, max_length):
    super().__init__()
    self.device = device
    self.word_embeddings = nn.Embedding(trg_vocab_size, embed_size)
    self.positional_encoding = PositionalEncoding(embed_size, max_length, device)

    blocks = []
    for _ in range(num_layers):
      blocks.append(
          DecoderBlock(embed_size, num_heads, forward_expansion, dropout, device)
      )
    self.layers = nn.ModuleList(blocks)

    self.fc_out = nn.Linear(embed_size, trg_vocab_size)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, encoder_out, src_mask, trg_mask):
    """Map target token ids `x` ([batch, seq]) to logits [batch, seq, trg_vocab_size]."""
    _batch, _seq = x.shape  # fails early unless x is 2-D

    token_vectors = self.word_embeddings(x)
    position_vectors = self.positional_encoding(x)
    hidden = self.dropout(token_vectors + position_vectors)

    # every block attends over the same encoder output (used as value and key)
    for block in self.layers:
      hidden = block(hidden, encoder_out, encoder_out, src_mask, trg_mask)

    return self.fc_out(hidden)


# Combining all together into Transformer

<img src="https://d2beiqkhq929f0.cloudfront.net/public_assets/assets/000/020/124/original/Screenshot_2022-11-18_at_9.23.52_PM.png?1668786833" width=500 height=500>

In [8]:
class Transformer(nn.Module):
  """Encoder-decoder model that builds its own attention masks.

  Args:
    src_vocab_size / trg_vocab_size: source and target vocabulary sizes.
    src_pad_idx: source token id treated as padding when building the source mask.
    trg_pad_idx: stored on the instance; not used when building masks.
    embed_size: embedding size shared by encoder and decoder.
    num_layers: number of layers in each of encoder and decoder.
    forward_expansion: feed-forward width multiplier.
    num_heads: number of attention heads.
    dropout: dropout probability.
    device: device the masks are moved to and the sub-modules are given.
    max_length: maximum sequence length for positional encoding.
  """

  def __init__(
      self, src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, embed_size,
      num_layers, forward_expansion, num_heads, dropout, device, max_length):
    super().__init__()

    self.src_pad_idx = src_pad_idx
    self.trg_pad_idx = trg_pad_idx
    self.device = device
    self.num_heads = num_heads
    self.head_dim = embed_size // num_heads

    self.encoder = Encoder(
        src_vocab_size, embed_size, num_layers, num_heads,
        device, forward_expansion, dropout, max_length,
    )
    self.decoder = Decoder(
        trg_vocab_size, embed_size, num_layers, num_heads,
        forward_expansion, dropout, device, max_length,
    )

  def make_source_mask(self, src):
    """Return a boolean mask (N, 1, 1, src_len) that is False at padding tokens."""
    _batch, _seq = src.shape  # fails early unless src is 2-D
    not_padding = src != self.src_pad_idx
    # two singleton axes so the mask broadcasts over heads and query positions
    src_mask = not_padding[:, None, None, :]
    return src_mask.to(self.device)

  def make_target_mask(self, trg):
    """Return a lower-triangular mask (N, 1, trg_len, trg_len) of ones and zeros.

    Row t has ones in columns 0..t, so position t cannot attend to later positions.
    """
    N, trg_len = trg.shape
    lower_triangle = torch.ones((trg_len, trg_len)).tril()
    trg_mask = lower_triangle.expand(N, 1, trg_len, trg_len)
    return trg_mask.to(self.device)

  def forward(self, source, target):
    """Encode `source`, decode `target` against it, and return the decoder logits."""
    src_mask = self.make_source_mask(source)
    trg_mask = self.make_target_mask(target)
    memory = self.encoder(source, src_mask)
    return self.decoder(target, memory, src_mask, trg_mask)

In [9]:
# Seed the RNG first so the random inputs and the weight initialisation are
# the same on every Restart & Run All.
torch.manual_seed(0)

# --- hyper-parameters ---
seq_len = 12
batch_size = 32
embed_size = 512 # defined in paper
num_heads = 8 # defined in paper
num_layers = 6 # defined in paper
src_vocab_size = 1024
trg_vocab_size = 2048
forward_expansion = 4
src_pad_idx = 0
trg_pad_idx = 0
dropout = 0.10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# max length is needed to perform the positional encoding
max_length = 1000

# --- dummy batch ---
# Token ids are drawn from [0, 10), well below both vocabulary sizes; id 0 is
# the padding index, so the source mask is exercised as well.
# The target is 3 tokens longer than the source to check that differing
# source/target lengths are handled.
inputs = torch.randint(low=0, high=10, size=(batch_size, seq_len)).to(device)
target = torch.randint(low=0, high=10, size=(batch_size, seq_len+3)).to(device)

# --- model and a single forward pass ---
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx,
                    embed_size, num_layers, forward_expansion, num_heads, dropout, device, max_length).to(device)
preds = model(inputs, target)

In [10]:
# Sanity check: output shape, and whether ANY logit is NaN
# (.all() would only be True if every single element were NaN).
preds.shape, preds.isnan().any()

(torch.Size([32, 15, 2048]), tensor(False))