In [24]:
import torch 
import torch.nn as nn
import math

### Input Embedding 


In [25]:
class InputEmbeddings(nn.Module):
    """Token-id to vector lookup whose output is scaled by ``sqrt(d_model)``."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        # Width of every embedding vector.
        self.d_model = d_model
        # Number of rows in the lookup table.
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Embed integer token ids.

        Args:
            x: index tensor; the comments in this notebook use (batch, seq_len).

        Returns:
            The looked-up vectors multiplied by ``sqrt(d_model)``, with one
            extra trailing dimension of size ``d_model``.
        """
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)
        return vectors * scale

### Positional Encoding


In [26]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    The table is precomputed once for ``seq_len`` positions:

        PE(pos, 2i)     = sin(pos / 10000 ** (2i / d_model))
        PE(pos, 2i + 1) = cos(pos / 10000 ** (2i / d_model))

    and stored as the non-trainable buffer ``pe`` of shape
    (1, seq_len, d_model). Both even and odd ``d_model`` are supported.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        """
        Args:
            d_model: size of each position vector.
            seq_len: number of positions to precompute.
            dropout: drop probability applied to the summed output.
        """
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # One row per position, one column per feature.
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # 10000 ** (-2i / d_model), computed through exp/log; one entry per even column.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (seq_len, ceil(d_model / 2))

        # Even feature columns take the sine ...
        pe[:, 0::2] = torch.sin(angles)
        # ... odd columns the cosine. When d_model is odd there is one fewer
        # odd column than even columns, so the surplus frequency is dropped.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading batch dimension lets the table broadcast over the batch.
        # A buffer follows .to(device) and is saved in state_dict, but is not
        # a parameter, so it never receives gradients.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, seq_len, d_model)

    def forward(self, x):
        """Add the first ``x.shape[1]`` position vectors to ``x``, then apply dropout.

        ``x`` is indexed as (batch, seq_len, d_model); its second dimension
        must not exceed the ``seq_len`` given at construction time.
        """
        return self.dropout(x + self.pe[:, :x.shape[1], :])

### LayerNormalization Class

In [27]:
class LayerNormalization(nn.Module):
    """Normalisation over the last dimension with a learnable scale and shift.

    ``eps`` is added to the standard deviation itself (not to the variance),
    and the standard deviation comes from ``Tensor.std`` with its default
    settings.
    """

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.d_model = d_model
        # Keeps the division finite when the standard deviation is zero.
        self.eps = eps
        # Multiplicative scale, starts at one.
        self.gamma = nn.Parameter(torch.ones(d_model))
        # Additive shift, starts at zero.
        self.beta = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        """Return ``gamma * (x - mean) / (std + eps) + beta`` along the last axis."""
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.gamma * (centered / spread) + self.beta

### FeedForwardBlock Class

In [28]:
class FeedForwardBlock(nn.Module):
    """Two-layer MLP: Linear(d_model -> d_ff), ReLU, dropout, Linear(d_ff -> d_model)."""

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.dropout = nn.Dropout(dropout)
        # Expansion to the hidden width, then projection back to the model width.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``; the output keeps ``x``'s shape."""
        hidden = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(hidden)

### Multi-head attention:
- gồm nhiều self-attention với hi vọng mỗi self attention nhìn ở nhiều đặc điểm khác nhau sẽ hiểu ngữ cảnh tốt hơn. 
- d_k = d_model / head 
    - d_k là số chiều mà 1 head (1 attention nhìn được trong d_model - dimension of feature vector)
    - d_model : là số chiều của feature vector 
    - head: là số self attention được khởi tạo. 
    - VD : Tôi đang đi đâu đó 
         -  10   10  10  10  10   (mỗi từ được đại diện bởi vector 10 dimension - sử dụng 2 head)
    - head1 5    5   5   5   5
    - head2 5    5   5   5   5 
    - -> gom lại rồi trả về shape như đầu vào.

- self-attention : chức năng chính là tổng hợp các feature tại từ đang xét với ngữ cảnh của các từ xung quanh 
- feedforwardblock : Suy diễn (tăng tính biểu diễn của feature).

#### Mask in multi head attention
2 dạng mask được sử dụng : 
- padding mask : được dùng để đảm bảo rằng các padding tokens (khi mà chuẩn hóa input đầu với max input len, đảm bảo các câu trong batch sẽ có cùng len) không ảnh hưởng gì đến cơ chế attention 
- look-ahead mask (mask multi head attention - Causal mask): được dùng để đảm bảo trong quá trình đào tạo và suy luận mỗi VỊ TRÍ trong chuỗi chỉ có thể tham gia vào các vị trí trước đó và vị trí hiện tại chứ không liên quan đến bất kỳ vị trí nào trong tương lai.   

In [29]:
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention.

    ``d_model`` is split evenly across ``h`` heads of width
    ``d_k = d_model // h``. Each head attends independently; the heads are
    then concatenated and mixed by the output projection ``w_o``.
    """

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        """
        Args:
            d_model: feature width of queries, keys, values and the output.
            h: number of heads; must divide ``d_model``.
            dropout: drop probability applied to the attention weights.
        """
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model must be divisible by h"

        self.d_k = d_model // h
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(self.h * self.d_k, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Scaled dot-product attention.

        Args:
            query, key, value: tensors whose last dimension is the head width;
                ``forward`` passes (batch, h, seq_len, d_k).
            mask: optional tensor broadcastable to the score matrix. Positions
                where ``mask == 0`` are excluded from the softmax.
            dropout: optional dropout module applied to the attention weights.

        Returns:
            ``(weights @ value, weights)``.
        """
        d_k = query.shape[-1]

        # (batch, h, seq_len, d_k) @ (batch, h, d_k, seq_len) -> (batch, h, seq_len, seq_len)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Use the most negative finite value of the score dtype rather than
            # a hard-coded -1e9, which is not representable in half precision.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(mask == 0, fill_value)
        attention_scores = attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # (batch, h, seq_len, seq_len) @ (batch, h, seq_len, d_k) -> (batch, h, seq_len, d_k)
        return (attention_scores @ value), attention_scores

    def forward(self, query, key, value, mask):
        """Project the inputs, attend per head, merge the heads and project back.

        The attention weights of the most recent call are kept on
        ``self.attention_scores``.
        """
        # Each projection is a (d_model, d_model) weight applied to the last axis:
        # (batch, seq_len, d_model) -> (batch, seq_len, d_model)
        query = self.w_q(query)
        key = self.w_k(key)
        value = self.w_v(value)

        # Split d_model into h heads of width d_k:
        # (batch, seq_len, d_model) -> (batch, seq_len, h, d_k) -> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], -1, self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], -1, self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], -1, self.h, self.d_k).transpose(1, 2)

        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # Merge the heads back together:
        # (batch, h, seq_len, d_k) -> (batch, seq_len, h, d_k) -> (batch, seq_len, h * d_k)
        # contiguous() is required because view() needs contiguous memory after transpose().
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        # Output projection mixes information across heads.
        return self.w_o(x)
        

### ResidualConnection
- xử lý vấn đề vanishing gradients 
- được dùng mỗi khi qua lớp multi head attention hoặc feed-forward layer -> xong rồi sẽ sử dụng layer normalization. 

In [30]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, d_model: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(d_model)

    def forward(self, x, sublayer):
        """Normalise ``x``, run ``sublayer`` on it, apply dropout and add the skip path.

        Args:
            x: input tensor.
            sublayer: callable taking the normalised tensor and returning a
                tensor that can be added to ``x``.
        """
        normalized = self.norm(x)
        update = self.dropout(sublayer(normalized))
        return x + update

#### Encoder

In [31]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each in its own residual wrapper."""

    def __init__(self, multi_head_attention: MultiHeadAttentionBlock, feed_forward: FeedForwardBlock, dropout: float) -> None:
        """
        Args:
            multi_head_attention: self-attention module; its ``d_model`` sizes
                the layer norms of both residual wrappers.
            feed_forward: position-wise feed-forward module.
            dropout: drop probability used inside both residual wrappers.
        """
        super().__init__()
        self.multi_head_attention = multi_head_attention
        self.feed_forward = feed_forward
        # Each sublayer gets its own residual wrapper, and therefore its own
        # LayerNormalization parameters.
        self.residual_connection = ResidualConnection(multi_head_attention.d_model, dropout)
        self.feed_forward_residual_connection = ResidualConnection(multi_head_attention.d_model, dropout)

    def forward(self, x, src_mask):
        """Run self-attention (restricted by ``src_mask``) followed by the feed-forward block."""
        x = self.residual_connection(x, lambda x: self.multi_head_attention(x, x, x, src_mask))
        # Previously this also went through `residual_connection`, which left
        # `feed_forward_residual_connection` unused and made both sublayers
        # share a single set of layer-norm parameters.
        x = self.feed_forward_residual_connection(x, self.feed_forward)
        return x

class Encoder(nn.Module):
    """Stack of encoder layers followed by a final layer normalisation."""

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        # The model width is read from the first layer's attention module.
        self.norm = LayerNormalization(layers[0].multi_head_attention.d_model)

    def forward(self, x, mask):
        """Feed ``x`` through every layer in order (each receives ``mask``), then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

### Decoder

In [32]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, feed-forward.

    Every sublayer has its own residual wrapper (and so its own layer-norm
    parameters), and cross-attention has its own projection weights instead
    of reusing the self-attention module.

    Note: compared with a layer that shares one attention module and one
    residual wrapper, this adds the ``cross_attention`` and
    ``cross_attention_residual_connection`` entries to the state_dict.
    """

    def __init__(self, multi_head_attention: MultiHeadAttentionBlock, feed_forward: FeedForwardBlock, dropout: float, cross_attention: MultiHeadAttentionBlock = None) -> None:
        """
        Args:
            multi_head_attention: module used for masked self-attention.
            feed_forward: position-wise feed-forward module.
            dropout: drop probability used inside the residual wrappers.
            cross_attention: optional module for attention over the encoder
                output. When omitted, a new block is created with the same
                ``d_model``, ``h`` and dropout probability as
                ``multi_head_attention``.
        """
        super().__init__()
        d_model = multi_head_attention.d_model
        self.multi_head_attention = multi_head_attention
        if cross_attention is None:
            cross_attention = MultiHeadAttentionBlock(
                d_model=d_model,
                h=multi_head_attention.h,
                dropout=multi_head_attention.dropout.p,
            )
        self.cross_attention = cross_attention
        self.feed_forward = feed_forward
        self.residual_connection = ResidualConnection(d_model, dropout)
        self.cross_attention_residual_connection = ResidualConnection(d_model, dropout)
        self.feed_forward_residual_connection = ResidualConnection(d_model, dropout)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """
        Args:
            x: decoder-side input.
            encoder_output: encoder result, used as keys and values of cross-attention.
            src_mask: mask applied to the cross-attention scores.
            tgt_mask: mask applied to the self-attention scores.
        """
        # Masked self-attention over the target sequence.
        x = self.residual_connection(x, lambda x: self.multi_head_attention(x, x, x, tgt_mask))
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        x = self.cross_attention_residual_connection(
            x, lambda x: self.cross_attention(x, encoder_output, encoder_output, src_mask)
        )
        x = self.feed_forward_residual_connection(x, self.feed_forward)
        return x

class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalisation."""

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        # The model width is read from the first layer's attention module.
        self.norm = LayerNormalization(layers[0].multi_head_attention.d_model)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Feed ``x`` through every layer in order, then normalise.

        Each layer receives the same ``encoder_output``, ``src_mask`` and
        ``tgt_mask``.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

### ProjectionLayers

In [33]:
class ProjectionLayer(nn.Module):
    """Linear map from model features to one score per vocabulary entry."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """
        Args:
            d_model: size of the incoming feature vectors.
            vocab_size: number of output scores per position.
        """
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x) -> torch.Tensor:
        """Return unnormalised scores: (batch, seq_len, d_model) -> (batch, seq_len, vocab_size)."""
        return self.proj(x)

### Transformer


In [34]:
class Transformer(nn.Module):
    """Encoder-decoder model assembled from prebuilt components.

    The three stages (``encode``, ``decode``, ``project``) are exposed as
    separate methods, so the encoder output can be computed once and reused
    for several decoder calls.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """Embed the source ids, add positions and run the encoder stack."""
        embedded = self.src_pos(self.src_embed(src))  # (batch, seq_len, d_model)
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        """Embed the target ids, add positions and run the decoder stack over ``encoder_output``."""
        embedded = self.tgt_pos(self.tgt_embed(tgt))  # (batch, seq_len, d_model)
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        """Map decoder features to vocabulary scores: (batch, seq_len, vocab_size)."""
        return self.projection_layer(x)

In [35]:
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1 , d_ff: int = 2048) -> Transformer: 
    """Assemble a Transformer and Xavier-initialise its weight matrices.

    Args:
        src_vocab_size: size of the source embedding table.
        tgt_vocab_size: size of the target embedding table and of the output projection.
        src_seq_len: number of positions precomputed for the source side.
        tgt_seq_len: number of positions precomputed for the target side.
        d_model: feature width used throughout the model.
        N: number of encoder layers and of decoder layers.
        h: attention heads per attention block.
        dropout: drop probability passed to every component.
        d_ff: hidden width of the feed-forward blocks.

    Returns:
        The assembled model. Every parameter with more than one dimension has
        been re-initialised with ``xavier_uniform_``.
    """
    # Token embeddings for both languages.
    src_embed = InputEmbeddings(d_model=d_model, vocab_size=src_vocab_size)
    tgt_embed = InputEmbeddings(d_model=d_model, vocab_size=tgt_vocab_size)

    # Positional encodings for both sides.
    src_pos = PositionalEncoding(d_model=d_model, seq_len=src_seq_len, dropout=dropout)
    tgt_pos = PositionalEncoding(d_model=d_model, seq_len=tgt_seq_len, dropout=dropout)

    # N encoder layers; every layer owns fresh attention and feed-forward modules.
    encoder_blocks = [
        EncoderBlock(
            multi_head_attention=MultiHeadAttentionBlock(d_model=d_model, h=h, dropout=dropout),
            feed_forward=FeedForwardBlock(d_model=d_model, d_ff=d_ff, dropout=dropout),
            dropout=dropout,
        )
        for _ in range(N)
    ]

    # N decoder layers, built the same way.
    decoder_blocks = [
        DecoderBlock(
            multi_head_attention=MultiHeadAttentionBlock(d_model=d_model, h=h, dropout=dropout),
            feed_forward=FeedForwardBlock(d_model=d_model, d_ff=d_ff, dropout=dropout),
            dropout=dropout,
        )
        for _ in range(N)
    ]

    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))

    # Output head over the target vocabulary.
    projection_layer = ProjectionLayer(d_model=d_model, vocab_size=tgt_vocab_size)
    transformer = Transformer(
        encoder=encoder,
        decoder=decoder,
        src_embed=src_embed,
        tgt_embed=tgt_embed,
        src_pos=src_pos,
        tgt_pos=tgt_pos,
        projection_layer=projection_layer,
    )

    # Only tensors with more than one dimension are re-initialised;
    # one-dimensional parameters keep the values they were constructed with.
    weight_matrices = (param for param in transformer.parameters() if param.dim() > 1)
    for weight in weight_matrices:
        nn.init.xavier_uniform_(weight)

    return transformer
    
    

### Transformer cho dịch máy từ en -> vi

In [None]:

import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

import numpy as np
import matplotlib.pyplot as plt
import glob
from sklearn.model_selection import train_test_split
import unicodedata
import time
from tqdm import tqdm 

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Parallel corpora: line i of the English file pairs with line i of the Vietnamese file.
en_path = '/home/hoang.minh.an/anhalu-data/learning/deep-learning-from-scratch/nlp/seq2seq/data/en_sents.txt' 
vi_path = '/home/hoang.minh.an/anhalu-data/learning/deep-learning-from-scratch/nlp/seq2seq/data/vi_sents.txt'

# Decode explicitly as UTF-8: without `encoding=` Python falls back to the
# locale's default codec, which can fail or corrupt the Vietnamese text on
# systems whose locale is not UTF-8.
with open(en_path, 'r', encoding='utf-8') as f:
    en = f.readlines()

with open(vi_path, 'r', encoding='utf-8') as f:
    vi = f.readlines()

print(f"Len en : {len(en)}") 
print(f"Len vi : {len(vi)}") 
print(en[0], vi[0])

# Keep only the first 50,000 sentence pairs.
en = en[:50000]
vi = vi[:50000]

In [None]:
import re
# Ids reserved for the special tokens. They match the ids hard-coded in
# Lang.word2index / Lang.index2word below.
# PAD fills sequences up to a common length.
PAD_token = 0
# SOS is placed at the start of every encoded sentence.
SOS_token = 1
# EOS is placed at the end of every encoded sentence.
EOS_token = 2
# UNK is returned for words missing from a vocabulary.
UNK_token = 3 

class Lang:
    """Vocabulary of one language: word <-> index maps plus word counts.

    Indices 0-3 are reserved for PAD, SOS, EOS and UNK; real words are
    numbered from 4 in order of first appearance.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {"SOS": 1, "EOS": 2, "PAD": 0, "UNK": 3}
        self.word2count = {}
        self.index2word = {1: "SOS", 2: "EOS", 0: "PAD", 3: "UNK"}
        self.n_words = 4  # Count SOS, EOS, PAD, UNK

    def addSentence(self, sentence):
        """Register every space-separated word of ``sentence``.

        Empty strings produced by consecutive, leading or trailing spaces are
        skipped, so they never occupy a vocabulary slot. This matches
        ``word2index``, which drops empty tokens when encoding.
        """
        for word in sentence.split(' '):
            if word:
                self.addWord(word)

    def addWord(self, word):
        """Add ``word`` to the vocabulary, or bump its count if already known."""
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def get_idx(self, word):
        """Return the index of ``word``, or the UNK index when it is unknown."""
        return self.word2index.get(word, self.word2index['UNK'])
            
# Strip combining accent marks from a Unicode string (NFD-decompose, then drop category 'Mn').
def unicodeToAscii(s):
    """Return ``s`` NFD-decomposed with every combining mark (category 'Mn') removed."""
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

# Lowercase and trim surrounding whitespace (accent stripping and punctuation handling are currently disabled)
def normalizeString(s):
    """Lower-case ``s`` and strip leading/trailing whitespace.

    Accents and punctuation are left untouched.
    """
    lowered = s.lower()
    return lowered.strip()


def readLangs(lang1, lang2, en, vi, reverse=False):
    """Normalise two parallel sentence lists and build a vocabulary for each side.

    Args:
        lang1: name given to the vocabulary of the ``en`` sentences.
        lang2: name given to the vocabulary of the ``vi`` sentences.
        en: sentences of the first language.
        vi: sentences of the second language, indexed in parallel with ``en``.
        reverse: when true, swap the two sides so the second language becomes the input.

    Returns:
        ``(input_lang, output_lang, pairs)`` where ``pairs`` is a list of
        ``[input_sentence, output_sentence]`` lists of normalised strings.
    """
    pairs = [
        [normalizeString(en[i]), normalizeString(vi[i])]
        for i in range(len(en))
    ]

    # Decide which language plays the input role.
    input_name, output_name = lang1, lang2
    if reverse:
        pairs = [p[::-1] for p in pairs]
        input_name, output_name = lang2, lang1
    input_lang = Lang(input_name)
    output_lang = Lang(output_name)

    # Grow both vocabularies from every sentence pair.
    for source_sentence, target_sentence in pairs:
        input_lang.addSentence(source_sentence)
        output_lang.addSentence(target_sentence)

    print("Count words in language:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)   

    return input_lang, output_lang, pairs

# Upper bound (exclusive) on the number of space-separated tokens per sentence.
MAX_LENGTH = 60

# Build both vocabularies and the list of normalised [en, vi] pairs.
en_lang, vi_lang, pairs = readLangs('en', 'vi', en, vi, False)

# Keep a pair only when both of its sentences are shorter than MAX_LENGTH tokens.
pairs = [
    pair
    for pair in pairs
    if max(len(pair[0].split(" ")), len(pair[1].split(" "))) < MAX_LENGTH
]
print(f"Len pairs: {len(pairs)}")
print(pairs[0])

In [38]:
def word2index(lang: Lang, sentence: str): 
    """Encode ``sentence`` as a NumPy array of ids: SOS, one id per word, EOS.

    The sentence is normalised first. Empty tokens (from repeated spaces) are
    skipped, and words unknown to ``lang`` map to its UNK index.
    """
    words = normalizeString(sentence).split(" ")
    word_ids = [lang.get_idx(word) for word in words if word]
    return np.array([SOS_token, *word_ids, EOS_token])


def get_dataloader(pairs, input_lang: Lang, output_lang: Lang, batch_size, max_lenght) -> DataLoader:
    """Encode sentence pairs into padded id tensors and wrap them in a shuffling DataLoader.

    Args:
        pairs: sequence of ``[input_sentence, output_sentence]``.
        input_lang: vocabulary used for the first sentence of each pair.
        output_lang: vocabulary used for the second sentence of each pair.
        batch_size: number of pairs per batch.
        max_lenght: row width; each encoded sentence (including SOS/EOS) must
            fit in it, and the remainder of the row stays PAD.

    Returns:
        A DataLoader yielding ``(input_ids, target_ids)`` long tensors that
        were created on the module-level ``device``.
    """
    shape = (len(pairs), max_lenght)
    source_ids = np.full(shape, PAD_token, dtype=np.int32)
    target_ids = np.full(shape, PAD_token, dtype=np.int32)

    # Left-align each encoded sentence in its row; the tail stays PAD.
    for row, pair in enumerate(pairs):
        source_seq = word2index(input_lang, pair[0])
        target_seq = word2index(output_lang, pair[1])
        source_ids[row, :len(source_seq)] = source_seq
        target_ids[row, :len(target_seq)] = target_seq

    dataset = TensorDataset(
        torch.tensor(source_ids, dtype=torch.long, device=device),
        torch.tensor(target_ids, dtype=torch.long, device=device),
    )
    # RandomSampler draws the examples in a new random order on every pass.
    return DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=batch_size)




In [None]:
# init hyperparameters
# NOTE(review): this cell is not idempotent -- each re-run adds another 2 to
# MAX_LENGTH. Re-run the cell that defines MAX_LENGTH first.
MAX_LENGTH += 2  # add 2 for SOS and EOS token
src_seq_len = tgt_seq_len = MAX_LENGTH

# Vocabulary sizes come from the Lang objects built earlier.
src_vocab_size, tgt_vocab_size = en_lang.n_words, vi_lang.n_words
print(f"Src vocab size: {src_vocab_size}")
print(f"Tgt vocab size: {tgt_vocab_size}")

d_model = 128   # model / embedding width
N = 2           # number of encoder and decoder layers
h = 4           # number of attention heads
dropout = 0.1
d_ff = 128      # feed forward hidden layer size
lr = 5e-4       # learning rate

transformer = build_transformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    src_seq_len=src_seq_len,
    tgt_seq_len=tgt_seq_len,
    d_model=d_model,
    N=N,
    h=h,
    dropout=dropout,
    d_ff=d_ff,
)
transformer = transformer.to(device)

In [None]:
batch_size = 4
num_epochs = 20

# Hold out 5% of the pairs for evaluation; the fixed seed keeps the split reproducible.
train_pairs, test_pairs = train_test_split(pairs, test_size=0.05, random_state=42)
print(f"Len train: {len(train_pairs)}") 
print(f"Len test: {len(test_pairs)}")

# Both loaders share the same batching and padding settings.
loader_kwargs = dict(batch_size=batch_size, max_lenght=MAX_LENGTH)
train_dataloader = get_dataloader(train_pairs, en_lang, vi_lang, **loader_kwargs)
test_dataloader = get_dataloader(test_pairs, en_lang, vi_lang, **loader_kwargs)

In [74]:
class EarlyStopping():
    """Stop signal based on the gap between validation loss and training loss.

    Each call in which ``validation_loss - train_loss`` exceeds ``min_delta``
    increments ``counter``. Once ``counter`` reaches ``tolerance``,
    ``early_stop`` becomes True. The counter is never reset.
    """

    def __init__(self, tolerance=5, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        gap = validation_loss - train_loss
        if not (gap > self.min_delta):
            # Gap still acceptable: nothing to record.
            return
        self.counter += 1
        if self.counter >= self.tolerance:
            self.early_stop = True

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Cross-entropy over the vocabulary; targets equal to PAD_token are ignored.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_token)
# Adam over all model parameters.
optimizer = optim.Adam(
    transformer.parameters(),
    lr=lr,
    betas=(0.9, 0.98),
    eps=1e-9,
)
def build_masks(src, tgt_input, pad_token=0):
    """Build the boolean attention masks for one batch.

    Args:
        src: (batch, src_len) source token ids.
        tgt_input: (batch, tgt_len) decoder input token ids.
        pad_token: id treated as padding.

    Returns:
        ``(src_mask, tgt_mask)``:
        - ``src_mask`` (batch, 1, 1, src_len) is True at non-padding source positions.
        - ``tgt_mask`` (batch, 1, tgt_len, tgt_len) is True where query
          position i may attend key position j, i.e. where j <= i and key j
          is not padding. For tgt_len = 4 without padding:

              [[True,  False, False, False],
               [True,  True,  False, False],
               [True,  True,  True,  False],
               [True,  True,  True,  True ]]
    """
    src_mask = (src != pad_token).unsqueeze(-2).unsqueeze(-2)  # (batch, 1, 1, src_len)

    tgt_len = tgt_input.size(1)
    # Lower-triangular = causal. An upper-triangular matrix (torch.triu) would
    # instead let every position attend to *future* tokens, including the
    # very token it is asked to predict.
    causal_mask = torch.tril(
        torch.ones((1, tgt_len, tgt_len), device=tgt_input.device)
    ).bool()
    padding_mask = (tgt_input != pad_token).unsqueeze(-2)  # (batch, 1, tgt_len)
    tgt_mask = (padding_mask & causal_mask).unsqueeze(1)   # (batch, 1, tgt_len, tgt_len)
    return src_mask, tgt_mask


def _run_batch(transformer, batch, criterion):
    """Forward one ``(src, tgt)`` batch.

    Returns ``(loss, n_correct, n_scored)`` where the two counts cover only
    non-padding target positions.
    """
    src, tgt = batch
    src, tgt = src.to(device), tgt.to(device)  # tgt: (batch, seq_len)
    # Teacher forcing: feed tokens [0, n-1) and predict tokens [1, n).
    tgt_input = tgt[:, :-1]   # (batch, seq_len - 1)
    tgt_output = tgt[:, 1:]   # (batch, seq_len - 1)
    src_mask, tgt_mask = build_masks(src, tgt_input)

    encoder_output = transformer.encode(src, src_mask)
    decoder_output = transformer.decode(encoder_output, src_mask, tgt_input, tgt_mask)
    output = transformer.project(decoder_output)  # (batch, seq_len - 1, vocab_size)

    # Flatten to (batch * (seq_len - 1), vocab_size) against (batch * (seq_len - 1),).
    loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))

    # Softmax is monotonic, so the argmax of the raw scores is the prediction.
    predictions = output.detach().argmax(dim=-1)
    non_pad_mask = tgt_output != 0
    n_correct = (predictions[non_pad_mask] == tgt_output[non_pad_mask]).sum().item()
    n_scored = non_pad_mask.sum().item()
    return loss, n_correct, n_scored


def train(transformer, dataloader, optimizer, criterion, num_epochs, val_dataloader=None):
    """Train ``transformer`` for up to ``num_epochs`` epochs with early stopping.

    Args:
        transformer: model exposing ``encode``, ``decode`` and ``project``.
        dataloader: yields ``(src, tgt)`` id tensors used for training.
        optimizer: optimizer over the model's parameters.
        criterion: loss taking ``(scores, targets)``.
        num_epochs: maximum number of epochs.
        val_dataloader: data evaluated after each epoch; defaults to the
            module-level ``test_dataloader``.

    After each epoch the mean training loss and token accuracy are printed,
    the model is evaluated, and training stops once EarlyStopping
    (tolerance=5, min_delta=0.1) fires.
    """
    if val_dataloader is None:
        val_dataloader = test_dataloader
    early_stopping = EarlyStopping(tolerance=5, min_delta=0.1)
    for epoch in range(num_epochs):
        transformer.train()
        correct_predictions = 0
        total_predictions = 0
        epoch_loss = 0
        for batch in tqdm(dataloader, "transformer training"):
            optimizer.zero_grad()
            loss, n_correct, n_scored = _run_batch(transformer, batch, criterion)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            correct_predictions += n_correct
            total_predictions += n_scored

        # Average over the loader that was actually iterated; max(..., 1)
        # keeps empty inputs from raising ZeroDivisionError.
        train_loss = epoch_loss / max(len(dataloader), 1)
        train_accuracy = correct_predictions / max(total_predictions, 1)

        print(f"Epoch {epoch} Loss: {train_loss}, accuracy: {train_accuracy}")
        eval_loss, accuracy = evaluate(transformer, val_dataloader, criterion)
        early_stopping(train_loss=train_loss, validation_loss=eval_loss)
        if early_stopping.early_stop:
            print("Early stopping")
            break


def evaluate(transformer, dataloader, criterion):
    """Compute mean loss and token accuracy over ``dataloader`` without gradients.

    Puts the model in eval mode. Returns ``(mean_loss, accuracy)``; accuracy
    counts only non-padding target positions.
    """
    transformer.eval()
    eval_loss = 0
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, "transformer evaluation"):
            loss, n_correct, n_scored = _run_batch(transformer, batch, criterion)
            eval_loss += loss.item()
            correct_predictions += n_correct
            total_predictions += n_scored

    mean_loss = eval_loss / max(len(dataloader), 1)
    accuracy = correct_predictions / max(total_predictions, 1)
    # accuracy is already a ratio, so it is reported as-is (it must not be
    # divided by the number of batches).
    print(f"Eval loss: {mean_loss}, accuracy: {accuracy}")
    return mean_loss, accuracy

# Start training with the model, loader, optimizer, loss and epoch count defined above.
train(transformer, train_dataloader, optimizer, criterion, num_epochs)

In [43]:
# Save only the model's state_dict (not the whole module object) to a fixed path.
torch.save(transformer.state_dict(), "/home/hoang.minh.an/anhalu-data/learning/deep-learning-from-scratch/nlp/seq2seq/save_model/transformer_en2vi_v2.pth")

In [None]:
# Load the saved weights into the existing model.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
path = "/home/hoang.minh.an/anhalu-data/learning/deep-learning-from-scratch/nlp/seq2seq/save_model/transformer_en2vi_v2.pth"
transformer.load_state_dict(torch.load(path, map_location=device))

In [None]:
def get_accuracy_in_dataloader(model, dataloader):
    """Teacher-forced token accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode for the measurement (so dropout is
    disabled) and restored to its previous mode afterwards. Only non-padding
    target positions are scored.

    Args:
        model: object exposing ``encode``, ``decode`` and ``project``.
        dataloader: yields ``(src, tgt)`` id tensors.

    Returns:
        Fraction of non-padding target tokens predicted correctly, or 0.0
        when there is nothing to score.
    """
    was_training = model.training
    model.eval()
    correct_predictions_all = 0
    total_predictions_all = 0
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, "transformer evaluation"):
                src, tgt = batch
                src, tgt = src.to(device), tgt.to(device)  # tgt: (batch, seq_len)
                # Teacher forcing: feed tokens [0, n-1) and predict tokens [1, n).
                tgt_input = tgt[:, :-1]   # (batch, seq_len - 1)
                tgt_output = tgt[:, 1:]   # (batch, seq_len - 1)
                # True at non-padding source positions: (batch, 1, 1, seq_len)
                src_mask = (src != 0).unsqueeze(-2).unsqueeze(-2)

                # Causal mask: query i may attend key j only when j <= i, hence
                # lower-triangular. torch.triu would expose future tokens instead.
                tgt_len = tgt_input.size(1)
                causal_mask = torch.tril(torch.ones((1, tgt_len, tgt_len), device=device)).bool()
                padding_mask = (tgt_input != 0).unsqueeze(-2)
                tgt_mask = (padding_mask & causal_mask).unsqueeze(1)  # (batch, 1, seq_len - 1, seq_len - 1)

                encoder_output = model.encode(src, src_mask)
                decoder_output = model.decode(encoder_output, src_mask, tgt_input, tgt_mask)
                output = model.project(decoder_output)
                # Softmax is monotonic, so the argmax of the raw scores is the prediction.
                predictions = output.argmax(dim=-1)
                non_pad_mask = tgt_output != 0

                correct_predictions_all += (predictions[non_pad_mask] == tgt_output[non_pad_mask]).sum().item()
                total_predictions_all += non_pad_mask.sum().item()
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    accuracy = correct_predictions_all / total_predictions_all if total_predictions_all else 0.0
    print("Correct predictions: ", correct_predictions_all)
    print("Total predictions: ", total_predictions_all)
    print("Accuracy : ", accuracy)
    return accuracy

# Measure token accuracy on the held-out loader and print the returned value.
a = get_accuracy_in_dataloader(transformer, test_dataloader)
print(a)