# Transformer
[참고한 커널](https://www.kaggle.com/code/arunmohan003/transformer-from-scratch-using-pytorch)


# Import Libraries

In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import math, copy, re
import warnings
import pandas as pd
import numpy as np
import seaborn as sns
import torchtext
import matplotlib.pyplot as plt
warnings.simplefilter('ignore')
print(torch.__version__)

2.1.0+cpu


# Basic component

## create word embeddings
- input sequence를 vector으로 임베드 해주는 부분이다.
- 본문에서는 embedding dimension 을 512, vocab_size를 100으로 가정한다.
- 만약 batch size를 32로 하고, sequence length를 10으로 설정하면 output은 32 by 10 by 512가 될 것이다.

In [2]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super(Embedding, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
    def forward(self, x):
        out = self.embed(x)
        return out
    

## Positional Encoding
![](https://miro.medium.com/max/524/1*yWGV9ck-0ltfV2wscUeo7Q.png) \
![](https://miro.medium.com/max/564/1*SgNlyFaHH8ljBbpCupDhSQ.png) \
두 식을 통해 구성하며, 완성된 결과는 embed matrix와 동일한 shape를 갖는다. 


In [4]:
class PositionalEmbedding(nn.Module):
    def __init__(self, max_seq_len, embed_model_dim):
        super(PositionalEmbedding, self).__init__()
        self.embed_dim = embed_model_dim

        pe = torch.zeros(max_seq_len, self.embed_dim)
        for pos in range(max_seq_len):
            for i in range(0, self.embed_dim, 2):
                pe[pos, i] = math.sin(pos / (10000**((2*i)/self.embed_dim)))
                pe[pos, i+1] = math.cos(pos / (10000**((2*i)/self.embed_dim)))
            pe = pe.unsqueeze(0)
            self.register_buffer('pe', pe) 
            # 학습시키는 parameter은 state_dict, 학습시키지 않는 텐서는 register_buffer으로 등록한다.

        def forward(self, x):
            # embedding을 상대적으로 크게 만들자(왜지..?)
            x = x*math.sqrt(self.embed_dim)
            seq_len = x.size(1)
            x = x+torch.autograd.Variable(self.pe[:,:seq_len], requires_grad = False)
            return x

    

## MultiHead Attention
- self-attention: input sequence의 각 word가, 다른 position의 word와의 연관성을 파악한다.
- 총 다섯 단계로 구성되는데, 
    1. Query vector, Key vector, Value vector을 계산한다. 각 vector은 1x64의 dim을 가진다.(model의 dim = 512이고 8개의 multihead 사용)
    2. Attention score을 계산한다. 즉, querry matrix와 key matrix를 연산한다. [QxK.t]
    3. Output matrix을 key matrix의 차원의 square root로 나누고, softmax를 적용한다.
    4. 이 결과를 value matrix에 곱한다.
    5. Linear layer을 거친다.


In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim = 512, n_heads = 8):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim # 512
        self.n_heads = n_heads # 8
        self.single_head_dim = self.embed_dim / self.n_heads #512/8 = 64
        self.query_matrix = nn.Linear(self.single_head_dim, self.single_head_dim, bias = False) #64x64
        self.key_matrix = nn.Linear(self.single_head_dim, self.single_head_dim, bias = False) #64x64
        self.value_matrix = nn.Linear(self.single_head_dim, self.single_head_dim, bias = False) #64x64
        self.out = nn.Linear(self.n_heads*self.single_head_dim, self.embed_dim) # 512x512

        def forward(self, key, query, value, mask = None):
            seq_length_query = query.size(1)
            batch_size = key.size(0)
            seq_length = key.size(1)

            key = key.view(batch_size, seq_length, self.n_heads, self.single_head_dim) #32, 10, 8, 64
            query = query.view(batch_size, seq_length_query, self.n_heads, self.single_head_dim)
            value = value.view(batch_size, seq_length_query, self.n_heads, self.single_head_dim)

            k = self.key_matrix(key)
            q = self.query_matrix(query)
            v = self.value_matrix(value)

            k, q, v = k.transpose(1, 2), q.transpose(1, 2), v.transpose(1, 2) #32, 8, 10, 64
            
            k_adjusted = k.transpose(-1, -2) # (batch_size, n_heads, single_head_dim, seq_len) (32, 8, 64, 10)
            product = torch.matmul(q, k_adjusted) #(32, 8, 10, 64)x(32, 8, 64, 10) = (32, 8, 10, 10)

            # 만약 mask 옵션이 켜져있을때는, mask position의 값들을 0으로 만들어준다.
            if mask is not None:
                product = product.masked_fill(mask == 0, float('-1e20'))
            
            product = product / math.sqrt(self.single_head_dim) # /sqrt(64)
            
            scores = F.softmax(product, dim = -1) #applying softmax

            scores = torch.matmul(scores, v) #(32, 8, 10, 10)x(32, 8, 10, 64) = (32, 8, 10, 64)
            # (32, 8, 10, 64) -> (32, 10, 8, 64) -> (32, 10, 512)
            concat = scores.transpose(1, 2).contiguous().view(batch_size, seq_length_query, self.single_head_dim * self.n_heads)
            output = self.out(concat)

            return output
        
            

# Encoder
![](https://www.researchgate.net/profile/Ehsan-Amjadian/publication/352239001/figure/fig1/AS:1033334390013952@1623377525434/Detailed-view-of-a-transformer-encoder-block-It-first-passes-the-input-through-an.jpg)

In [6]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, expansion_factor=4, n_heads=8):
        super(TransformerBlock, self).__init__()
        
        """
        Args:
           embed_dim: dimension of the embedding
           expansion_factor: fator ehich determines output dimension of linear layer
           n_heads: number of attention heads
        
        """
        self.attention = MultiHeadAttention(embed_dim, n_heads)
        
        self.norm1 = nn.LayerNorm(embed_dim) 
        self.norm2 = nn.LayerNorm(embed_dim)
        
        self.feed_forward = nn.Sequential(
                          nn.Linear(embed_dim, expansion_factor*embed_dim),
                          nn.ReLU(),
                          nn.Linear(expansion_factor*embed_dim, embed_dim)
        )

        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.2)

    def forward(self,key,query,value):
        
        """
        Args:
           key: key vector
           query: query vector
           value: value vector
           norm2_out: output of transformer block
        
        """
        
        attention_out = self.attention(key,query,value)  #32x10x512
        attention_residual_out = attention_out + value  #32x10x512
        norm1_out = self.dropout1(self.norm1(attention_residual_out)) #32x10x512

        feed_fwd_out = self.feed_forward(norm1_out) #32x10x512 -> #32x10x2048 -> 32x10x512
        feed_fwd_residual_out = feed_fwd_out + norm1_out #32x10x512
        norm2_out = self.dropout2(self.norm2(feed_fwd_residual_out)) #32x10x512

        return norm2_out



class TransformerEncoder(nn.Module):
    """
    Args:
        seq_len : length of input sequence
        embed_dim: dimension of embedding
        num_layers: number of encoder layers
        expansion_factor: factor which determines number of linear layers in feed forward layer
        n_heads: number of heads in multihead attention
        
    Returns:
        out: output of the encoder
    """
    def __init__(self, seq_len, vocab_size, embed_dim, num_layers=2, expansion_factor=4, n_heads=8):
        super(TransformerEncoder, self).__init__()
        
        self.embedding_layer = Embedding(vocab_size, embed_dim)
        self.positional_encoder = PositionalEmbedding(seq_len, embed_dim)

        self.layers = nn.ModuleList([TransformerBlock(embed_dim, expansion_factor, n_heads) for i in range(num_layers)])
    
    def forward(self, x):
        embed_out = self.embedding_layer(x)
        out = self.positional_encoder(embed_out)
        for layer in self.layers:
            out = layer(out,out,out)

        return out  #32x10x512

In [7]:
class DecoderBlock(nn.Module):
    def __init__(self, embed_dim, expansion_factor=4, n_heads=8):
        super(DecoderBlock, self).__init__()

        """
        Args:
           embed_dim: dimension of the embedding
           expansion_factor: fator ehich determines output dimension of linear layer
           n_heads: number of attention heads
        
        """
        self.attention = MultiHeadAttention(embed_dim, n_heads=8)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.2)
        self.transformer_block = TransformerBlock(embed_dim, expansion_factor, n_heads)
        
    
    def forward(self, key, query, x,mask):
        
        """
        Args:
           key: key vector
           query: query vector
           value: value vector
           mask: mask to be given for multi head attention 
        Returns:
           out: output of transformer block
    
        """
        
        #we need to pass mask mask only to fst attention
        attention = self.attention(x,x,x,mask=mask) #32x10x512
        value = self.dropout(self.norm(attention + x))
        
        out = self.transformer_block(key, query, value)

        
        return out


class TransformerDecoder(nn.Module):
    def __init__(self, target_vocab_size, embed_dim, seq_len, num_layers=2, expansion_factor=4, n_heads=8):
        super(TransformerDecoder, self).__init__()
        """  
        Args:
           target_vocab_size: vocabulary size of taget
           embed_dim: dimension of embedding
           seq_len : length of input sequence
           num_layers: number of encoder layers
           expansion_factor: factor which determines number of linear layers in feed forward layer
           n_heads: number of heads in multihead attention
        
        """
        self.word_embedding = nn.Embedding(target_vocab_size, embed_dim)
        self.position_embedding = PositionalEmbedding(seq_len, embed_dim)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_dim, expansion_factor=4, n_heads=8) 
                for _ in range(num_layers)
            ]

        )
        self.fc_out = nn.Linear(embed_dim, target_vocab_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, enc_out, mask):
        
        """
        Args:
            x: input vector from target
            enc_out : output from encoder layer
            trg_mask: mask for decoder self attention
        Returns:
            out: output vector
        """
            
        
        x = self.word_embedding(x)  #32x10x512
        x = self.position_embedding(x) #32x10x512
        x = self.dropout(x)
     
        for layer in self.layers:
            x = layer(enc_out, x, enc_out, mask) 

        out = F.softmax(self.fc_out(x))

        return out

In [8]:
class Transformer(nn.Module):
    def __init__(self, embed_dim, src_vocab_size, target_vocab_size, seq_length,num_layers=2, expansion_factor=4, n_heads=8):
        super(Transformer, self).__init__()
        
        """  
        Args:
           embed_dim:  dimension of embedding 
           src_vocab_size: vocabulary size of source
           target_vocab_size: vocabulary size of target
           seq_length : length of input sequence
           num_layers: number of encoder layers
           expansion_factor: factor which determines number of linear layers in feed forward layer
           n_heads: number of heads in multihead attention
        
        """
        
        self.target_vocab_size = target_vocab_size

        self.encoder = TransformerEncoder(seq_length, src_vocab_size, embed_dim, num_layers=num_layers, expansion_factor=expansion_factor, n_heads=n_heads)
        self.decoder = TransformerDecoder(target_vocab_size, embed_dim, seq_length, num_layers=num_layers, expansion_factor=expansion_factor, n_heads=n_heads)
        
    
    def make_trg_mask(self, trg):
        """
        Args:
            trg: target sequence
        Returns:
            trg_mask: target mask
        """
        batch_size, trg_len = trg.shape
        # returns the lower triangular part of matrix filled with ones
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            batch_size, 1, trg_len, trg_len
        )
        return trg_mask    

    def decode(self,src,trg):
        """
        for inference
        Args:
            src: input to encoder 
            trg: input to decoder
        out:
            out_labels : returns final prediction of sequence
        """
        trg_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src)
        out_labels = []
        batch_size,seq_len = src.shape[0],src.shape[1]
        #outputs = torch.zeros(seq_len, batch_size, self.target_vocab_size)
        out = trg
        for i in range(seq_len): #10
            out = self.decoder(out,enc_out,trg_mask) #bs x seq_len x vocab_dim
            # taking the last token
            out = out[:,-1,:]
     
            out = out.argmax(-1)
            out_labels.append(out.item())
            out = torch.unsqueeze(out,axis=0)
          
        
        return out_labels
    
    def forward(self, src, trg):
        """
        Args:
            src: input to encoder 
            trg: input to decoder
        out:
            out: final vector which returns probabilities of each target word
        """
        trg_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src)
   
        outputs = self.decoder(trg, enc_out, trg_mask)
        return outputs

In [9]:
src_vocab_size = 11
target_vocab_size = 11
num_layers = 6
seq_length= 12


# let 0 be sos token and 1 be eos token
src = torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1], 
                    [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1]])
target = torch.tensor([[0, 1, 7, 4, 3, 5, 9, 2, 8, 10, 9, 1], 
                       [0, 1, 5, 6, 2, 4, 7, 6, 2, 8, 10, 1]])

print(src.shape,target.shape)
model = Transformer(embed_dim=512, src_vocab_size=src_vocab_size, 
                    target_vocab_size=target_vocab_size, seq_length=seq_length,
                    num_layers=num_layers, expansion_factor=4, n_heads=8)
model

torch.Size([2, 12]) torch.Size([2, 12])


Transformer(
  (encoder): TransformerEncoder(
    (embedding_layer): Embedding(
      (embed): Embedding(11, 512)
    )
    (positional_encoder): PositionalEmbedding()
    (layers): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadAttention(
          (query_matrix): Linear(in_features=64, out_features=64, bias=False)
          (key_matrix): Linear(in_features=64, out_features=64, bias=False)
          (value_matrix): Linear(in_features=64, out_features=64, bias=False)
          (out): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (feed_forward): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): 

In [10]:
out = model(src, target)
out.shape

torch.Size([2, 12, 11])

In [11]:
# inference
model = Transformer(embed_dim=512, src_vocab_size=src_vocab_size, 
                    target_vocab_size=target_vocab_size, seq_length=seq_length, 
                    num_layers=num_layers, expansion_factor=4, n_heads=8)
                  


src = torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1]])
trg = torch.tensor([[0]])
print(src.shape,trg.shape)
out = model.decode(src, trg)
out

torch.Size([1, 12]) torch.Size([1, 1])


[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]