<a href="https://colab.research.google.com/github/ariefpurnamamuharram/Android-Kotlin-Examples/blob/master/Transformers_from_Scratch_in_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformers from Scratch in PyTorch
---
References:
- https://medium.com/the-dl/transformers-from-scratch-in-pytorch-8777e346ca51, with some modifications by APM

In [1]:
# Install the required packages
!pip install numpy
!pip install torch torchvision torchaudio



In [2]:
# Scaled Dot Product Attention

import torch
import torch.nn as nn
import torch.nn.functional as f
from torch import Tensor

class ScaledDotProductAttention(nn.Module):
  """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

  The last two dimensions of every input are (sequence, features); any leading
  dimensions (e.g. batch) are broadcast by ``torch.matmul``.
  """

  def __init__(self):
    super(ScaledDotProductAttention, self).__init__()

  def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    """Attend from queries ``q`` to keys ``k`` and mix the values ``v``.

    Args:
      q: Queries, shape (..., len_q, d_k).
      k: Keys, shape (..., len_k, d_k).
      v: Values, shape (..., len_k, d_v).

    Returns:
      Tensor of shape (..., len_q, d_v).
    """
    # Swap only the last two axes of k (sequence <-> features) to form K^T.
    # Axis 0 is the batch axis, so transposing (0, 1) would mix batch and
    # sequence; (-2, -1) also works for any number of leading batch dims.
    scores = torch.matmul(q, k.transpose(-2, -1))
    # Scale by sqrt(d_k) to keep the softmax out of its saturated region.
    scale = q.size(-1) ** 0.5
    weights = f.softmax(scores / scale, dim=-1)
    return torch.matmul(weights, v)

In [3]:
# Multi-Head Attention

import torch
import torch.nn as nn
from torch import Tensor

class AttentionHead(nn.Module):
  """A single attention head: linearly project query/key/value, then attend.

  Args:
    dim_in: Feature size of the incoming query/key/value tensors.
    dim_k: Projected size of queries and keys (they must match for Q K^T).
    dim_v: Projected size of values; this is the size of the head's output.
  """

  def __init__(self, dim_in: int, dim_k: int, dim_v: int):
    super(AttentionHead, self).__init__()
    self.q = nn.Linear(dim_in, dim_k)
    # Keys are dotted with queries, so they must be projected to dim_k as well
    # (projecting them to dim_v only worked by accident when dim_k == dim_v).
    self.k = nn.Linear(dim_in, dim_k)
    self.v = nn.Linear(dim_in, dim_v)
    self.scaled_dot_product_att = ScaledDotProductAttention()

  def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
    """Project the inputs and apply scaled dot-product attention.

    The output's last dimension is ``dim_v``; the other dimensions follow the
    query (presumably (batch, len_q, ...) — confirm against callers).
    """
    return self.scaled_dot_product_att(self.q(query), self.k(key), self.v(value))

class MultiHeadAttention(nn.Module):
  """Run several independent attention heads and mix their outputs.

  Each head produces ``dim_v`` features; the head outputs are concatenated
  along the last axis (``num_heads * dim_v`` features) and a linear layer
  projects them back to ``dim_in``.
  """

  def __init__(self, num_heads: int, dim_in: int, dim_k: int, dim_v: int):
    super(MultiHeadAttention, self).__init__()
    head_list = [AttentionHead(dim_in, dim_k, dim_v) for _ in range(num_heads)]
    self.heads = nn.ModuleList(head_list)
    self.linear = nn.Linear(num_heads * dim_v, dim_in)

  def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
    """Apply every head to (query, key, value), concatenate, and project."""
    per_head = []
    for head in self.heads:
      per_head.append(head(query, key, value))
    # torch.cat joins the per-head outputs along the feature (last) axis.
    joined = torch.cat(per_head, dim=-1)
    return self.linear(joined)

In [4]:
# Positional Encoding

import torch
import torch.nn as nn
from torch import Tensor

class PositionalEncoding(nn.Module):
  """Fixed sinusoidal positional encoding.

  PE[pos, 2i]   = sin(pos / 10000^(2i / dim_model))
  PE[pos, 2i+1] = cos(pos / 10000^(2i / dim_model))

  Args:
    device: Device on which the encoding tensor is created.
  """

  def __init__(self, device: torch.device = torch.device("cpu")):
    super(PositionalEncoding, self).__init__()
    self.device = device

  def forward(self, seq_len: int, dim_model: int) -> Tensor:
    """Return the encoding as a float tensor of shape (1, seq_len, dim_model).

    The leading singleton axis lets the result broadcast over a batch.
    """
    pos = torch.arange(seq_len, dtype=torch.float, device=self.device).reshape(1, -1, 1)
    dim = torch.arange(dim_model, dtype=torch.float, device=self.device).reshape(1, 1, -1)
    # Channels 2i (sin) and 2i+1 (cos) share the frequency index 2i.
    # Note: `dim // dim_model` would be 0 for every channel (dim < dim_model),
    # giving every channel the same frequency, i.e. phase == pos.
    even_dim = dim - dim % 2
    phase = pos / 1e4 ** (even_dim / dim_model)

    return torch.where(dim.long() % 2 == 0, torch.sin(phase), torch.cos(phase))

In [5]:
# Transformer Architecture
# --------------------------

import torch
import torch.nn as nn
from torch import Tensor

class FeedForward(nn.Module):
  """Factory for the position-wise feed-forward sub-network.

  Calling an instance (which dispatches to ``forward``) does not transform a
  tensor: it builds and returns a fresh ``nn.Sequential`` of
  Linear(dim_input -> dim_feedforward), ReLU, Linear(dim_feedforward -> dim_input)
  with newly initialised weights.
  """

  def __init__(self):
    super(FeedForward, self).__init__()

  def forward(self, dim_input: int = 512, dim_feedforward: int = 2048):
    """Build the two-layer MLP: expand to ``dim_feedforward``, ReLU, project back."""
    expand = nn.Linear(dim_input, dim_feedforward)
    activation = nn.ReLU()
    project = nn.Linear(dim_feedforward, dim_input)
    return nn.Sequential(expand, activation, project)

class Residual(nn.Module):
  """Post-norm residual wrapper: ``LayerNorm(skip + Dropout(sublayer(*inputs)))``.

  Args:
    sublayer: Module applied to all positional inputs.
    dimension: Feature size normalised by the LayerNorm.
    dropout: Dropout probability applied to the sublayer output.
  """

  def __init__(self, sublayer: nn.Module, dimension: int, dropout: float = 0.1):
    super(Residual, self).__init__()
    self.sublayer = sublayer
    self.norm = nn.LayerNorm(dimension)
    self.dropout = nn.Dropout(dropout)

  def forward(self, *tensors: Tensor) -> Tensor:
    """Apply the sublayer to ``tensors`` and add the *last* input as the skip term.

    The last positional argument is the residual, matching the
    (query, key, value) order of ``MultiHeadAttention``: the skip term is the
    value tensor. For self-attention (query, key and value identical) and
    single-input sublayers this is the usual residual connection; callers that
    need the query as the skip term must not rely on this wrapper's forward.
    """
    skip = tensors[-1]
    update = self.sublayer(*tensors)
    return self.norm(skip + self.dropout(update))


# Transformer Encoder

class TransformerEncoderLayer(nn.Module):
  """One encoder block: self-attention then feed-forward, each in a Residual.

  Args:
    dim_model: Feature size of the layer's input and output.
    num_heads: Number of attention heads; each head projects to
      ``dim_model // num_heads`` features (floor division).
    dim_feedforward: Hidden size of the feed-forward sub-network.
    dropout: Dropout probability used by both residual wrappers.
  """

  def __init__(
      self,
      dim_model: int = 512,
      num_heads: int = 6,
      dim_feedforward: int = 2048,
      dropout: float = 0.1
  ):
    super(TransformerEncoderLayer, self).__init__()
    head_dim = dim_model // num_heads
    self_attention = MultiHeadAttention(num_heads, dim_model, head_dim, head_dim)
    self.attention = Residual(self_attention, dimension=dim_model, dropout=dropout)
    mlp = FeedForward()(dim_model, dim_feedforward)
    self.feed_forward = Residual(mlp, dimension=dim_model, dropout=dropout)

  def forward(self, src: Tensor) -> Tensor:
    """Run self-attention over ``src`` and then the feed-forward network."""
    attended = self.attention(src, src, src)
    return self.feed_forward(attended)

class TransformerEncoder(nn.Module):
  """Stack of encoder layers preceded by a sinusoidal positional encoding.

  Args:
    num_layers: Number of stacked TransformerEncoderLayer modules.
    dim_model: Feature size of the input and of every layer.
    num_heads: Attention heads per layer.
    dim_feedforward: Hidden size of each feed-forward sub-network.
    dropout: Dropout probability for every residual wrapper.
  """

  def __init__(
      self,
      num_layers: int = 6,
      dim_model: int = 512,
      num_heads: int = 8,
      dim_feedforward: int = 2048,
      dropout: float = 0.1,
  ):
    super(TransformerEncoder, self).__init__()
    self.layers = nn.ModuleList([
        TransformerEncoderLayer(dim_model, num_heads, dim_feedforward, dropout)
        for _ in range(num_layers)
    ])
    self.position_encoding = PositionalEncoding()

  def forward(self, src: Tensor) -> Tensor:
    """Encode ``src``.

    Args:
      src: Input embeddings; dim 1 is read as the sequence length and dim 2 as
        the feature size (presumably (batch, seq_len, dim_model)). It is not
        modified.

    Returns:
      The encoded sequence, same shape as ``src`` (given its last dim is dim_model).
    """
    seq_len, dimension = src.size(1), src.size(2)
    # Out-of-place add: `src +=` would overwrite the caller's tensor and fails
    # on leaf tensors that require grad. PositionalEncoding builds its tensor
    # on its own fixed device (CPU by default), so move it to match src.
    position = self.position_encoding(seq_len, dimension).to(device=src.device, dtype=src.dtype)
    src = src + position
    for layer in self.layers:
      src = layer(src)

    return src


# Transformer Decoder

class TransformerDecoderLayer(nn.Module):
  """One decoder block: self-attention, cross-attention to memory, feed-forward.

  Args:
    dim_model: Feature size of the target stream and of the memory.
    num_heads: Number of attention heads; each head projects to
      ``dim_model // num_heads`` features (floor division).
    dim_feedforward: Hidden size of the feed-forward sub-network.
    dropout: Dropout probability used by every residual wrapper.
  """

  def __init__(
      self,
      dim_model: int = 512,
      num_heads: int = 6, 
      dim_feedforward: int = 2048,
      dropout: float = 0.1,
  ):
    super(TransformerDecoderLayer, self).__init__()
    dim_k = dim_v = dim_model // num_heads
    self.attention_1 = Residual(
        MultiHeadAttention(num_heads, dim_model, dim_k, dim_v),
        dimension=dim_model,
        dropout=dropout
    )
    self.attention_2 = Residual(
        MultiHeadAttention(num_heads, dim_model, dim_k, dim_v),
        dimension=dim_model,
        dropout=dropout
    )
    self.feed_forward = Residual(
        FeedForward()(dim_model, dim_feedforward),
        dimension=dim_model,
        dropout=dropout
    )

  def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:
    """Decode one step of ``tgt`` against the encoder output ``memory``.

    The output follows ``tgt``'s shape; ``memory`` may have a different
    sequence length.
    """
    # NOTE(review): no causal mask is applied, so each target position can
    # attend to later target positions; autoregressive training needs one.
    tgt = self.attention_1(tgt, tgt, tgt)
    tgt = self._cross_attend(tgt, memory)
    return self.feed_forward(tgt)

  def _cross_attend(self, tgt: Tensor, memory: Tensor) -> Tensor:
    """Cross-attention: queries from ``tgt``, keys/values from ``memory``, ``tgt`` as skip."""
    # Uses attention_2's parts directly because its Residual.forward would take
    # the last argument (memory) as the skip term; the residual must be tgt.
    block = self.attention_2
    attended = block.sublayer(tgt, memory, memory)
    return block.norm(tgt + block.dropout(attended))

class TransformerDecoder(nn.Module):
  """Stack of decoder layers followed by a linear projection and a softmax.

  Args:
    num_layers: Number of stacked TransformerDecoderLayer modules.
    dim_model: Feature size of the target stream and of the memory.
    num_heads: Attention heads per layer.
    dim_feedforward: Hidden size of each feed-forward sub-network.
    dropout: Dropout probability for every residual wrapper.
  """

  def __init__(
      self,
      num_layers: int = 6,
      dim_model: int = 512,
      num_heads: int = 8,
      dim_feedforward: int = 2048,
      dropout: float = 0.1,
  ):
    super(TransformerDecoder, self).__init__()
    self.layers = nn.ModuleList([
        TransformerDecoderLayer(dim_model, num_heads, dim_feedforward, dropout)
        for _ in range(num_layers)
    ])
    self.linear = nn.Linear(dim_model, dim_model)
    self.position_encoding = PositionalEncoding()

  def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:
    """Decode ``tgt`` against the encoder output ``memory``.

    Args:
      tgt: Target embeddings; dim 1 is read as the sequence length and dim 2 as
        the feature size (presumably (batch, seq_len, dim_model)). It is not
        modified.
      memory: Encoder output, passed unchanged to every layer.

    Returns:
      Softmax probabilities over the last (feature/output) axis, i.e. one
      distribution per target position.
    """
    seq_len, dimension = tgt.size(1), tgt.size(2)
    # Out-of-place add so the caller's tensor is left untouched; match the
    # encoding to tgt's device/dtype (PositionalEncoding defaults to CPU).
    position = self.position_encoding(seq_len, dimension).to(device=tgt.device, dtype=tgt.dtype)
    tgt = tgt + position
    for layer in self.layers:
      tgt = layer(tgt, memory)

    # Normalise over the output axis (-1); dim=1 would normalise across
    # sequence positions instead of producing per-position distributions.
    return torch.softmax(self.linear(tgt), dim=-1)


# Transformer

class Transformer(nn.Module):
  """Encoder-decoder model built from TransformerEncoder and TransformerDecoder.

  Args:
    num_encoder_layers: Number of encoder layers.
    num_decoder_layers: Number of decoder layers.
    dim_model: Feature size shared by encoder and decoder.
    num_heads: Attention heads per layer.
    dim_feedforward: Hidden size of each feed-forward sub-network.
    dropout: Dropout probability for every residual wrapper.
    activation: Accepted for API compatibility but not used by this class.
  """

  def __init__(
      self, 
      num_encoder_layers: int = 6,
      num_decoder_layers: int = 6,
      dim_model: int = 512,
      num_heads: int = 6,
      dim_feedforward: int = 2048,
      dropout: float = 0.1,
      activation: nn.Module = nn.ReLU(),
  ): 
    super(Transformer, self).__init__()
    # Hyper-parameters shared verbatim by the encoder and decoder stacks.
    shared = dict(
        dim_model=dim_model,
        num_heads=num_heads,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
    )
    self.encoder = TransformerEncoder(num_layers=num_encoder_layers, **shared)
    self.decoder = TransformerDecoder(num_layers=num_decoder_layers, **shared)

  def forward(self, src: Tensor, tgt: Tensor) -> Tensor:
    """Encode ``src`` and decode ``tgt`` against the resulting memory."""
    memory = self.encoder(src)
    return self.decoder(tgt, memory)

In [6]:
# Smoke test: push random source/target batches through the model and check the output shape.
torch.manual_seed(0)  # reproducible weights and inputs
src = torch.rand(64, 16, 512)  # dim 1 = sequence length, dim 2 = dim_model
tgt = torch.rand(64, 16, 512)
model = Transformer().eval()  # disable dropout for a deterministic forward pass
with torch.no_grad():
  out = model(src, tgt)
assert out.shape == (64, 16, 512)
print(out.shape)

  from ipykernel import kernelapp as app


torch.Size([64, 16, 512])
