<a href="https://colab.research.google.com/github/ZhengyuanCui/Deep-Learning/blob/master/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

In [28]:
class Attention(nn.Module):
  """Multi-head scaled dot-product attention.

  The last dimension of every input is split into ``heads`` chunks of
  ``head_dim = embed_size // heads`` features. The key/query/value
  projections are bias-free ``head_dim -> head_dim`` linear maps applied to
  each chunk (the same weights are used for every head); the re-joined heads
  are passed through ``fc_out``.
  """

  def __init__(self, embed_size: int, heads: int) -> None:
    """Create the projections.

    Args:
      embed_size: Size of the last dimension of the inputs.
      heads: Number of attention heads; must divide ``embed_size``.
    """
    super().__init__()
    self.embed_size = embed_size
    self.num_heads = heads
    self.head_dim = embed_size // heads

    assert embed_size == heads * self.head_dim, (
        "embed_size must be divisible by heads"
    )

    self.key = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.query = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.value = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.fc_out = nn.Linear(self.embed_size, self.embed_size)

  def forward(self, key, query, value, mask):
    """Attend from ``query`` positions over ``key``/``value`` positions.

    Args:
      key: Tensor of shape (N, S_k, embed_size).
      query: Tensor of shape (N, S_q, embed_size).
      value: Tensor of shape (N, S_v, embed_size); S_v must equal S_k.
      mask: ``None`` or a tensor broadcastable to (N, heads, S_q, S_k).
        Positions where ``mask == 0`` receive no attention weight.

    Returns:
      Tensor of shape (N, S_q, embed_size).
    """
    N, S_k, _ = key.shape
    _, S_q, _ = query.shape  # minibatch size, sequence length, embedding size
    _, S_v, _ = value.shape  # in the decoder S_k == S_v, but S_q may differ

    k = self.key(key.reshape(N, S_k, self.num_heads, self.head_dim))
    q = self.query(query.reshape(N, S_q, self.num_heads, self.head_dim))
    v = self.value(value.reshape(N, S_v, self.num_heads, self.head_dim))

    # (N, S_k, heads, head_dim) x (N, S_q, heads, head_dim) -> (N, heads, S_q, S_k)
    # The contraction requires heads and head_dim to match between the
    # source and target sequences.
    energy = torch.einsum("ijkl,imkl->ikmj", k, q)

    if mask is not None:
      # Use the most negative value representable in the score dtype rather
      # than a hard-coded -1e20, which overflows float16 and makes
      # masked_fill fail under half precision.
      energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

    attention = torch.softmax(energy / self.head_dim ** 0.5, dim=3)

    # (N, heads, S_q, S_k) x (N, S_v, heads, head_dim) -> (N, S_q, heads, head_dim)
    out = torch.einsum("ikjm,imkl->ijkl", attention, v)
    out = out.reshape(N, S_q, self.embed_size)

    return self.fc_out(out)

In [3]:
class TransformerBlock(nn.Module):
  """Attention sub-layer followed by a position-wise feed-forward sub-layer.

  Each sub-layer output is added to its input, layer-normalised, and then
  passed through dropout.
  """

  def __init__(self, embed_size, heads, dropout, forward_expansion):
    """Build the attention, the two layer norms, the MLP and the dropout.

    Args:
      embed_size: Feature size of the block's inputs and outputs.
      heads: Number of attention heads.
      dropout: Dropout probability.
      forward_expansion: Multiplier giving the hidden width of the MLP.
    """
    super().__init__()
    hidden_size = forward_expansion * embed_size

    self.attention = Attention(embed_size, heads)
    self.norm1 = nn.LayerNorm(embed_size)
    self.norm2 = nn.LayerNorm(embed_size)
    self.linear = nn.Sequential(
        nn.Linear(embed_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, embed_size),
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, key, query, value, mask):
    """Run attention and the feed-forward network with residual connections.

    The residual of the attention sub-layer is taken from ``query``; the
    arguments are forwarded unchanged to ``self.attention``.
    """
    attended = self.attention(key, query, value, mask)
    hidden = self.norm1(query + attended)
    hidden = self.dropout(hidden)

    projected = self.linear(hidden)
    return self.dropout(self.norm2(projected + hidden))

In [10]:
class Encoder(nn.Module):
  """Embeds source tokens and runs them through a stack of TransformerBlocks."""

  def __init__(
      self,
      src_vocab_size,
      embed_size,
      num_layers,
      heads,
      dropout,
      forward_expansion,
      device,
      max_length
  ):
    """Build the embeddings and the layer stack.

    Args:
      src_vocab_size: Number of rows in the token embedding table.
      embed_size: Embedding / model feature size.
      num_layers: Number of TransformerBlock layers.
      heads: Number of attention heads per layer.
      dropout: Dropout probability.
      forward_expansion: Feed-forward expansion factor of each layer.
      device: Stored as ``self.device`` for backward compatibility; the
        forward pass takes its device from the input tensor instead.
      max_length: Number of rows in the learned positional embedding table.
    """
    super().__init__()
    self.embed_size = embed_size
    self.device = device
    self.content_embedding = nn.Embedding(src_vocab_size, embed_size)
    self.positional_embedding = nn.Embedding(max_length, embed_size)
    self.layers = nn.ModuleList([TransformerBlock(embed_size, heads, dropout, forward_expansion) for _ in range(num_layers)])
    self.dropout = nn.Dropout(dropout)

  def forward(self, input, mask):
    """Encode a batch of token indices.

    Args:
      input: Integer tensor of shape (N, seq_length).
      mask: Attention mask handed to every layer (may be ``None``).

    Returns:
      The output of the last layer; with an empty layer stack this is the
      embedded input of shape (N, seq_length, embed_size).
    """
    N, seq_length = input.shape
    # Build the position indices directly on the input's device. Moving them
    # to self.device breaks whenever that differs from where the data and
    # the weights actually live (e.g. a "cuda" default on a CPU-only host).
    positions = torch.arange(seq_length, device=input.device).expand(N, seq_length)
    out = self.dropout(self.content_embedding(input) + self.positional_embedding(positions))
    for layer in self.layers:
      # Self-attention: keys, queries and values are all the running output.
      out = layer(out, out, out, mask)
    return out

In [20]:
class DecoderBlock(nn.Module):
  """Masked self-attention followed by a TransformerBlock over encoder output."""

  def __init__(self, embed_size, heads, dropout, forward_expansion):
    """Create the self-attention, its layer norm, the inner block and dropout.

    Args:
      embed_size: Feature size of the block's inputs and outputs.
      heads: Number of attention heads.
      dropout: Dropout probability.
      forward_expansion: Feed-forward expansion factor of the inner block.
    """
    super().__init__()
    self.attention = Attention(embed_size, heads)
    self.norm = nn.LayerNorm(embed_size)
    self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
    self.dropout = nn.Dropout(dropout)

  def forward(self, input, keys_en, values_en, src_mask, trg_mask):
    """Apply self-attention to ``input``, then attend over the encoder tensors.

    ``trg_mask`` is used for the self-attention step. ``src_mask`` (which
    hides source padding added to equalise sequence lengths) is used when
    attending over ``keys_en`` / ``values_en``.
    """
    self_attended = self.attention(input, input, input, trg_mask)
    normed = self.norm(self_attended + input)
    query = self.dropout(normed)
    return self.transformer_block(keys_en, query, values_en, src_mask)

class Decoder(nn.Module):
  """Embeds target tokens, runs the DecoderBlock stack, and projects to logits."""

  def __init__(
      self,
      trg_vocab_size,
      embed_size,
      num_layers,
      heads,
      dropout,
      forward_expansion,
      device,
      max_length
  ):
    """Build the embeddings, the layer stack and the output projection.

    Args:
      trg_vocab_size: Number of rows in the token embedding table and
        size of the output projection.
      embed_size: Embedding / model feature size.
      num_layers: Number of DecoderBlock layers.
      heads: Number of attention heads per layer.
      dropout: Dropout probability.
      forward_expansion: Feed-forward expansion factor of each layer.
      device: Stored as ``self.device`` for backward compatibility; the
        forward pass takes its device from the input tensor instead.
      max_length: Number of rows in the learned positional embedding table.
    """
    super().__init__()
    self.device = device
    self.context_embedding = nn.Embedding(trg_vocab_size, embed_size)
    self.positional_embedding = nn.Embedding(max_length, embed_size)
    self.layers = nn.ModuleList(
        [DecoderBlock(embed_size, heads, dropout, forward_expansion) for _ in range(num_layers)]
    )

    self.linear = nn.Linear(embed_size, trg_vocab_size)
    self.dropout = nn.Dropout(dropout)

  def forward(self, input, encoder_out, src_mask, trg_mask):
    """Decode a batch of target token indices.

    Args:
      input: Integer tensor of shape (N, seq_length).
      encoder_out: Encoder output, passed to every layer as both the
        encoder keys and the encoder values.
      src_mask: Mask forwarded to every layer for attending over
        ``encoder_out``.
      trg_mask: Mask forwarded to every layer for target self-attention.

    Returns:
      Tensor of shape (N, seq_length, trg_vocab_size) produced by the final
      linear layer (no softmax is applied).
    """
    N, seq_length = input.shape
    # Build the position indices directly on the input's device. Moving them
    # to self.device breaks whenever that differs from where the data and
    # the weights actually live (e.g. a "cuda" default on a CPU-only host).
    positions = torch.arange(seq_length, device=input.device).expand(N, seq_length)

    out = self.dropout(self.context_embedding(input) + self.positional_embedding(positions))

    for layer in self.layers:
      out = layer(out, encoder_out, encoder_out, src_mask, trg_mask)
    out = self.linear(out)
    return out

In [22]:
class Transformer(nn.Module):
  """Encoder-decoder model built from the Encoder and Decoder classes."""

  def __init__(
      self,
      src_vocab_size,
      trg_vocab_size,
      src_pad_idx,
      embed_size = 256,
      num_encoder_layers = 6,
      num_decoder_layers = 6,
      forward_expansion = 4,
      heads = 8,
      dropout = 0,
      device = "cuda",
      max_length = 100
  ):
    """Build the encoder and decoder.

    Args:
      src_vocab_size: Source vocabulary size (encoder embedding rows).
      trg_vocab_size: Target vocabulary size (decoder embedding rows and
        output width).
      src_pad_idx: Token index treated as padding in the source batch.
      embed_size: Model feature size.
      num_encoder_layers: Number of encoder layers.
      num_decoder_layers: Number of decoder layers.
      forward_expansion: Feed-forward expansion factor.
      heads: Number of attention heads.
      dropout: Dropout probability.
      device: Forwarded to the encoder and decoder and stored as
        ``self.device``. Masks are created on the device of the tensors
        they are derived from, not on this value.
      max_length: Size of the positional embedding tables.
    """
    super().__init__()
    self.encoder = Encoder(
        src_vocab_size,
        embed_size,
        num_encoder_layers,
        heads,
        dropout,
        forward_expansion,
        device,
        max_length
    )
    self.decoder = Decoder(
        trg_vocab_size,
        embed_size,
        num_decoder_layers,
        heads,
        dropout,
        forward_expansion,
        device,
        max_length
    )

    self.src_pad_idx = src_pad_idx
    self.device = device

  def make_src_mask(self, src):
    """Return a boolean mask of shape (N, 1, 1, src_len), False at padding.

    The two singleton dimensions let the mask broadcast over attention
    heads and query positions.
    """
    # The comparison result already lives on src's device, so no transfer is
    # needed (the original relied on a global `device` defined outside the
    # class, which raises NameError when the class is used on its own).
    return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

  def make_trg_mask(self, target):
    """Return a lower-triangular mask of shape (N, 1, trg_len, trg_len).

    Entry [i, j] is 1 when j <= i, so a target position can only attend to
    itself and earlier positions.
    """
    N, trg_len = target.shape
    lower = torch.tril(torch.ones(trg_len, trg_len, device=target.device))
    return lower.expand(N, 1, trg_len, trg_len)

  def forward(self, src, trg):
    """Encode ``src`` and decode ``trg`` against it.

    Args:
      src: Integer tensor of shape (N, src_len).
      trg: Integer tensor of shape (N, trg_len).

    Returns:
      The decoder output for ``trg``.
    """
    src_mask = self.make_src_mask(src)
    trg_mask = self.make_trg_mask(trg)
    out = self.encoder(src, src_mask)
    out = self.decoder(trg, out, src_mask, trg_mask)
    return out

In [29]:
# Example: smoke test that the model runs end to end.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# A small hand-written batch; index 0 is the padding token in `src`.
src = torch.tensor([[1, 5, 3, 2, 4, 8, 6, 8, 7, 9], [2, 3, 4, 6, 2, 9, 8, 7, 0, 0]]).to(device)
trg = torch.tensor([[4, 11, 8], [3, 5, 7]]).to(device)

src_pad_idx = 0
src_vocab_size = 10
trg_vocab_size = 15
# Pass the selected device explicitly: the constructor defaults to "cuda",
# which would break this example on a CPU-only machine.
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, device=device).to(device)
out = model(src, trg)
print(out)
print(out.shape)

tensor([[[ 5.0398e-01, -1.0023e+00,  5.6857e-02, -2.6799e-01, -6.8624e-01,
          -8.7852e-01,  1.0655e-01, -2.1558e-01,  8.1713e-02,  4.1557e-01,
           6.1862e-02,  1.1059e+00,  7.6783e-01, -4.6962e-02, -4.9102e-01],
         [ 7.2262e-01, -8.1058e-01, -3.1346e-01,  1.9133e-01, -2.6969e-01,
           7.3957e-01,  4.8730e-01, -5.3206e-01, -1.3396e-01,  3.1679e-01,
          -7.5704e-01,  6.7768e-01, -3.3729e-01,  7.5567e-01,  6.1345e-01],
         [-1.0473e-01,  1.8362e-03, -6.5824e-01, -5.5786e-02, -1.0601e+00,
          -1.2123e+00,  7.6936e-01,  1.5073e-01, -6.2403e-03,  1.1964e-02,
          -3.0024e-02,  1.9799e+00,  1.8055e-01,  2.4357e-01,  9.7782e-01]],

        [[ 7.8722e-01, -1.2320e+00,  4.7462e-01,  4.7971e-01, -1.0315e+00,
           2.4525e-02,  1.7244e-01,  7.5950e-02, -7.6394e-02,  7.2079e-01,
           4.5045e-01,  4.7763e-01,  1.1796e-01, -8.4385e-01, -4.9978e-01],
         [ 3.9467e-01, -4.5988e-01, -2.0479e-02,  2.6329e-01, -7.0352e-01,
          -2.5410e-