In [None]:
# importing required libraries
import torch.nn as nn
import torch
import torch.nn.functional as F
import math,copy,re
import warnings
import pandas as pd
import numpy as np
import seaborn as sns
import torchtext
import matplotlib.pyplot as plt
# Silence every warning category for the rest of the kernel session.
# NOTE(review): this also hides deprecation notices emitted by the libraries
# imported above; consider narrowing the filter while debugging.
warnings.simplefilter("ignore")
# Show the installed PyTorch version so the recorded outputs can be tied to it.
print(torch.__version__)

# Embeddings

## Word Embedding

In [None]:
class Embedding(nn.Module):
    """Learned token-embedding table.

    A thin wrapper around ``nn.Embedding`` that maps integer token indices to
    dense vectors.

    Args:
        vocab_size: number of rows in the lookup table.
        embed_dim: length of each embedding vector.
    """

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

    def forward(self, x):
        """Return the embedding vectors for the indices in ``x``.

        The result has the shape of ``x`` with one trailing axis of size
        ``embed_dim`` appended.
        """
        return self.embed(x)

## Positional Embedding
- 在nn.Module中，register_buffer()是一个方法，用于在PyTorch模型中注册一个缓冲区（buffer）。缓冲区是一种状态，不同于模型的参数，它们不会被优化，但可以在模型中使用。通常，缓冲区用于存储与模型相关的不可训练数据，例如在BatchNormalization中使用的运行统计信息

In [None]:
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding added to scaled embeddings.

    The table follows the standard formulation: for position ``pos`` and an
    even dimension index ``k`` (k = 0, 2, 4, ...)

        pe[pos, k]     = sin(pos / 10000 ** (k / embed_model_dim))
        pe[pos, k + 1] = cos(pos / 10000 ** (k / embed_model_dim))

    so every sine/cosine pair shares one frequency. The table is stored as a
    buffer named ``pe`` of shape ``(1, max_seq_len, embed_model_dim)``: it
    moves with the module across devices and is part of the state dict, but it
    is not a trainable parameter.

    Args:
        max_seq_len: number of positions to precompute.
        embed_model_dim: width of the embeddings; odd widths are supported
            (the last sine simply has no cosine partner).
    """

    def __init__(self, max_seq_len, embed_model_dim):
        super(PositionalEmbedding, self).__init__()
        self.embed_dim = embed_model_dim

        # Column vector of positions: (max_seq_len, 1).
        positions = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
        # Even dimension indices 0, 2, 4, ... -> one frequency per sin/cos pair.
        even_dims = torch.arange(0, embed_model_dim, 2, dtype=torch.float32)
        # 1 / 10000 ** (k / d), computed in log space for numerical stability.
        inv_freq = torch.exp(even_dims * (-math.log(10000.0) / embed_model_dim))
        angles = positions * inv_freq  # (max_seq_len, ceil(d / 2))

        pe = torch.zeros(max_seq_len, embed_model_dim)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd width there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : embed_model_dim // 2])

        # Leading axis of size 1 so the table broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Scale ``x`` by ``sqrt(embed_dim)`` and add the positional table.

        Args:
            x: embeddings whose axis 1 is the sequence axis and whose last
                axis has size ``embed_dim``. The sequence length should not
                exceed ``max_seq_len``, since only that many rows exist.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Make the embeddings relatively larger than the positional signal.
        x = x * math.sqrt(self.embed_dim)

        # Buffers never require grad, so the slice can be added directly.
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]

# Attention

## Multi-Head Self-Attention
- with mask mechanism

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional mask.

    The input features are split into ``n_heads`` chunks of width
    ``dk = embed_dim // n_heads``. The query/key/value projections are
    ``dk x dk`` linear maps applied to the last axis, so the same projection
    weights are shared by every head. The heads are concatenated and passed
    through a final ``embed_dim x embed_dim`` linear layer.

    Args:
        embed_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self,
                 embed_dim: int = 512,
                 n_heads: int = 8):
        super(MultiHeadAttention, self).__init__()

        # Basic attributes
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.dk = embed_dim // n_heads

        # Fail at construction time instead of deep inside forward(): the
        # per-head split below needs n_heads * dk to cover embed_dim exactly.
        if self.dk * n_heads != embed_dim:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})"
            )

        # Per-head projections: input_dim = embed_dim // n_heads = dk
        self.Q = nn.Linear(self.dk, self.dk, bias=False)
        self.K = nn.Linear(self.dk, self.dk, bias=False)
        self.V = nn.Linear(self.dk, self.dk, bias=False)
        self.out = nn.Linear(self.n_heads * self.dk, self.embed_dim)

    def forward(self, key, query, value, mask = None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Note the argument order: ``key`` comes first, then ``query``.

        Args:
            key: ``(batch, key_len, embed_dim)`` tensor.
            query: ``(batch, query_len, embed_dim)`` tensor; ``query_len`` may
                differ from ``key_len``.
            value: tensor with the same shape as ``key``.
            mask: optional tensor broadcastable to
                ``(batch, n_heads, query_len, key_len)``; entries equal to 0
                mark positions that must not be attended to.

        Returns:
            ``(batch, query_len, embed_dim)`` tensor.
        """
        batch_size = key.size(0)
        seq_length = key.size(1)

        # The query length can differ from the key length (cross-attention,
        # incremental decoding).
        seq_length_query = query.size(1)

        # Split the feature axis into heads: (batch, len, n_heads, dk).
        # reshape() also accepts non-contiguous inputs, unlike view().
        key = key.reshape(batch_size, seq_length, self.n_heads, self.dk)
        query = query.reshape(batch_size, seq_length_query, self.n_heads, self.dk)
        value = value.reshape(batch_size, seq_length, self.n_heads, self.dk)

        k = self.K(key)
        q = self.Q(query)
        v = self.V(value)

        # (batch, n_heads, len, dk)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot products: (batch, n_heads, query_len, key_len).
        k_T = k.transpose(-1, -2)
        product = torch.matmul(q, k_T) / math.sqrt(self.dk)

        if mask is not None:
            # Use the most negative value of the working dtype so the fill
            # stays representable under reduced precision (e.g. float16).
            product = product.masked_fill(mask == 0, torch.finfo(product.dtype).min)

        attention_weights = F.softmax(product, dim=-1)
        context = torch.matmul(attention_weights, v)

        # Merge heads: (batch, n_heads, query_len, dk) -> (batch, query_len, embed_dim)
        concat = context.transpose(1, 2).contiguous().view(
            batch_size, seq_length_query, self.dk * self.n_heads
        )

        output = self.out(concat)

        return output



# Transformer

## Encoder

In [None]:
class EncodeBlock(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Each sub-layer is wrapped as ``dropout(layer_norm(sublayer(x) + residual))``
    with a dropout probability of 0.1.

    Args:
        embed_dim: model width.
        expansion_factor: the feed-forward hidden width is
            ``embed_dim * expansion_factor``.
        n_heads: number of attention heads.
    """

    def __init__(self, embed_dim, expansion_factor=4, n_heads=8):
        super().__init__()
        hidden_dim = embed_dim * expansion_factor  # e.g. 512 -> 2048

        self.attention = MultiHeadAttention(embed_dim, n_heads)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )

        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)

    def forward(self, key, query, value):
        """Run both sub-layers.

        The three arguments are forwarded positionally to the attention
        module, and the attention residual is taken from ``value``; for
        self-attention all three arguments are the same tensor.

        Returns:
            Tensor with the shape of the attention output.
        """
        # Sub-layer 1: attention, residual on `value`, norm, dropout.
        hidden = self.dropout1(
            self.norm1(self.attention(key, query, value) + value)
        )
        # Sub-layer 2: feed-forward, residual, norm, dropout.
        return self.dropout2(
            self.norm2(self.feed_forward(hidden) + hidden)
        )

In [None]:
class Encoder(nn.Module):
    """Stack of ``EncodeBlock`` layers on top of token + positional embeddings.

    Args:
        seq_len: number of positions precomputed by the positional embedding.
        vocab_size: size of the source vocabulary.
        embed_dim: model width.
        num_layers: number of stacked blocks.
        expansion_factor: feed-forward expansion inside each block.
        n_heads: attention heads per block.
    """

    def __init__(self, seq_len, vocab_size, embed_dim, num_layers=6, expansion_factor=4, n_heads=8):
        super().__init__()

        self.embedding_layer = Embedding(vocab_size, embed_dim)
        self.positional_encoder = PositionalEmbedding(seq_len, embed_dim)

        self.layers = nn.ModuleList(
            EncodeBlock(embed_dim, expansion_factor, n_heads)
            for _ in range(num_layers)
        )
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Encode token indices ``x`` into contextual representations."""
        # Embed, add positions, then apply dropout before the first block.
        hidden = self.dropout(self.positional_encoder(self.embedding_layer(x)))

        # Self-attention: every block receives the same tensor three times.
        for block in self.layers:
            hidden = block(hidden, hidden, hidden)

        return hidden

## Decoder

In [None]:
class DecoderBlock(nn.Module):
    """Decoder layer: masked self-attention, cross-attention, feed-forward.

    The cross-attention and feed-forward sub-layers are the modules owned by
    ``self.transformer_block`` (an ``EncodeBlock``), so parameter names are
    unchanged.

    Args:
        embed_dim: model width.
        expansion_factor: feed-forward expansion inside the inner block.
        n_heads: number of attention heads.
    """

    def __init__(self, embed_dim, expansion_factor=4, n_heads=8):
        super(DecoderBlock, self).__init__()

        self.attention = MultiHeadAttention(embed_dim, n_heads)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.1)
        self.transformer_block = EncodeBlock(embed_dim, expansion_factor, n_heads)

    def forward(self, key, value, x, mask):
        """Run one decoder layer.

        Args:
            key: encoder output used as attention keys.
            value: encoder output used as attention values.
            x: decoder-side input, ``(batch, target_len, embed_dim)``.
            mask: mask for the self-attention over ``x`` (0 = blocked).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Only the self-attention over the target sequence is masked.
        self_attended = self.attention(x, x, x, mask)
        query = self.dropout(self.norm(self_attended + x))

        # Cross-attention: queries come from the decoder state, keys/values
        # from the encoder. The sub-layers are invoked directly rather than
        # through EncodeBlock.forward, because that method forwards its
        # arguments positionally and takes the residual from `value`, which is
        # only appropriate for self-attention. Here the residual must follow
        # the decoder state so the output keeps the target length.
        block = self.transformer_block
        cross_attended = block.attention(key, query, value)
        hidden = block.dropout1(block.norm1(cross_attended + query))

        # Position-wise feed-forward with its own residual connection.
        out = block.dropout2(block.norm2(block.feed_forward(hidden) + hidden))

        return out

In [None]:
class Decoder(nn.Module):
    """Stack of ``DecoderBlock`` layers with an output projection.

    Args:
        t_vocab_size: size of the target vocabulary.
        embed_dim: model width.
        seq_len: number of positions precomputed by the positional embedding.
        num_layers: number of stacked decoder blocks.
        expansion_factor: feed-forward expansion inside each block.
        n_heads: attention heads per block.
    """

    def __init__(self, t_vocab_size, embed_dim, seq_len, num_layers=6, expansion_factor=4, n_heads=8):
        super(Decoder, self).__init__()

        self.word_embedding = Embedding(t_vocab_size, embed_dim)
        self.pos_embedding = PositionalEmbedding(seq_len, embed_dim)
        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_dim, expansion_factor, n_heads)
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_dim, t_vocab_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, encoder_out, mask):
        """Decode target tokens against the encoder output.

        Args:
            x: target token indices, ``(batch, target_len)``.
            encoder_out: encoder output, passed to every layer as both key
                and value.
            mask: self-attention mask forwarded to every layer.

        Returns:
            Probabilities (not logits) over the target vocabulary, shape
            ``(batch, target_len, t_vocab_size)``; each vector along the last
            axis sums to 1.
        """
        x = self.word_embedding(x)
        x = self.pos_embedding(x)

        # Dropout on the summed embeddings, before the first sub-layer.
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(encoder_out, encoder_out, x, mask)  # sub-layers apply their own dropout

        # Normalise explicitly over the vocabulary axis. Without `dim`,
        # softmax on a 3-D tensor falls back to dim=0 (the batch axis).
        out = F.softmax(self.fc_out(x), dim=-1)
        return out
        

## Transformer

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence tasks.

    Args:
        embed_dim: model width.
        s_vocab_size: size of the source vocabulary.
        t_vocab_size: size of the target vocabulary.
        seq_len: number of positions precomputed by both positional embeddings.
        num_layers: number of encoder layers and of decoder layers.
        expansion_factor: feed-forward expansion inside each block.
        n_heads: attention heads per block.
    """

    def __init__(self, embed_dim, s_vocab_size, t_vocab_size, seq_len,
                 num_layers = 6,
                 expansion_factor = 4,
                 n_heads = 8):
        super(Transformer, self).__init__()

        self.t_vocab_size = t_vocab_size
        self.encoder = Encoder(seq_len = seq_len,
                               vocab_size = s_vocab_size,
                               embed_dim = embed_dim,
                               num_layers = num_layers,
                               expansion_factor = expansion_factor,
                               n_heads = n_heads)
        self.decoder = Decoder(t_vocab_size = t_vocab_size,
                               embed_dim = embed_dim,
                               seq_len = seq_len,
                               num_layers = num_layers,
                               expansion_factor = expansion_factor,
                               n_heads = n_heads)

    def make_trg_mask(self, trg):
        """Build the lower-triangular (causal) mask for a target batch.

        Args:
            trg: ``(batch, target_len)`` tensor; only its shape and device
                are used.

        Returns:
            ``(batch, 1, target_len, target_len)`` tensor of ones on and below
            the diagonal and zeros above it, on the same device as ``trg``.
        """
        batch_size, trg_len = trg.shape
        # Create the mask on trg's device so it can be combined with the
        # attention scores without a device mismatch.
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        trg_mask = causal.expand(batch_size, 1, trg_len, trg_len)
        return trg_mask

    def decode(self, src, trg):
        """Greedy inference; returns the predicted token ids as a list.

        Runs ``trg.shape[1]`` decoding steps. The first step feeds ``trg`` to
        the decoder; each later step feeds the token predicted by the
        previous step. Only batch size 1 is supported, because every
        prediction is converted with ``.item()``.

        NOTE(review): after the first step only the single most recent
        prediction is fed back, so later steps do not see earlier tokens.
        This is kept as in the original; confirm whether full-history greedy
        decoding was intended.

        Args:
            src: source token indices, ``(1, source_len)``.
            trg: initial decoder input, ``(1, target_len)``.

        Returns:
            List of ``target_len`` predicted token ids (Python ints).
        """
        num_steps = trg.shape[1]
        out_labels = []

        # Inference only: no autograd graph is needed for integer outputs.
        with torch.no_grad():
            encoder_out = self.encoder(src)

            step_input = trg
            for _ in range(num_steps):
                # Rebuild the mask for the current input: after the first
                # step the input is a single token, so a mask sized for the
                # full `trg` would no longer match the attention scores.
                step_mask = self.make_trg_mask(step_input)
                probs = self.decoder(step_input, encoder_out, step_mask)

                # Most likely token at the last position.
                next_token = probs[:, -1, :].argmax(-1)
                out_labels.append(next_token.item())

                # (batch,) -> (batch, 1) so it can be fed back as a sequence.
                step_input = next_token.unsqueeze(1)

        return out_labels

    def forward(self, src, trg):
        """Training pass with teacher forcing.

        Args:
            src: source token indices, ``(batch, source_len)``.
            trg: target token indices, ``(batch, target_len)``.

        Returns:
            The decoder output for ``trg`` given the encoded ``src``.
        """
        trg_mask = self.make_trg_mask(trg)
        encoder_out = self.encoder(src)
        outputs = self.decoder(trg, encoder_out, trg_mask)
        return outputs


# Test

In [1]:
# Smoke test: build a small Transformer and display its module tree.
# NOTE(review): the wildcard import pulls in every public name of the local
# `transformer` module; `torch` and `Transformer` used below presumably come
# from it (likely an exported copy of the classes defined above) - confirm,
# and prefer explicit imports once later cells are known not to rely on it.
from transformer import *

# Vocabulary sizes cover the token ids 0..10 used in the toy batches below.
src_vocab_size = 11
target_vocab_size = 11
num_layers = 6
# Matches the length of every toy sequence below.
seq_length= 12


# Toy batches of two sequences each; 0 is the start-of-sequence (sos) token
# and 1 the end-of-sequence (eos) token.
src = torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1], 
                    [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1]])
target = torch.tensor([[0, 1, 7, 4, 3, 5, 9, 2, 8, 10, 9, 1], 
                       [0, 1, 5, 6, 2, 4, 7, 6, 2, 8, 10, 1]])

# Both tensors are (batch=2, seq_length=12).
print(src.shape,target.shape)
model = Transformer(embed_dim = 512, 
                    s_vocab_size = src_vocab_size, 
                    t_vocab_size = target_vocab_size, seq_len = seq_length, 
                    num_layers = num_layers, 
                    expansion_factor = 4,
                    n_heads = 8)
# Bare last expression: the notebook renders the module tree as cell output.
model

torch.Size([2, 12]) torch.Size([2, 12])


Transformer(
  (encoder): Encoder(
    (embedding_layer): Embedding(
      (embed): Embedding(11, 512)
    )
    (positional_encoder): PositionalEmbedding()
    (layers): ModuleList(
      (0-5): 6 x EncodeBlock(
        (attention): MultiHeadAttention(
          (Q): Linear(in_features=64, out_features=64, bias=False)
          (K): Linear(in_features=64, out_features=64, bias=False)
          (V): Linear(in_features=64, out_features=64, bias=False)
          (out): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (feed_forward): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
   