# Import Libraries

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from transformers import  PreTrainedTokenizerFast
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, pre_tokenizers, models
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


# Import Dataset

In [2]:
# Load the English-Vietnamese parallel corpus; the cell output further down
# shows train / validation / test splits, each with an 'en' and a 'vi' column.
ds = load_dataset('thainq107/iwslt2015-en-vi')


In [3]:
# Display the splits with their column names and row counts.
ds

DatasetDict({
    train: Dataset({
        features: ['en', 'vi'],
        num_rows: 133317
    })
    validation: Dataset({
        features: ['en', 'vi'],
        num_rows: 1268
    })
    test: Dataset({
        features: ['en', 'vi'],
        num_rows: 1268
    })
})

# Tokenizer

In [None]:
# One word-level trainer configuration is shared by both languages.
# The encoded samples printed later in the notebook show the special tokens
# ending up as <pad> -> 1, <bos> -> 2 and <eos> -> 3.
trainer = trainers.WordLevelTrainer(
    vocab_size=20000,
    min_frequency=2,
    special_tokens=['<unk>', '<pad>', '<bos>', '<eos>'],
)

# Word-level models: anything outside the vocabulary maps to '<unk>'.
en_tokenizer = Tokenizer(models.WordLevel(unk_token='<unk>'))
vi_tokenizer = Tokenizer(models.WordLevel(unk_token='<unk>'))

# Same recipe for each language: whitespace pre-tokenisation, fit on the
# training column, then save the result as JSON for the next cell to reload.
for tokenizer, column, path in (
    (en_tokenizer, 'en', 'tokenizer_en.json'),
    (vi_tokenizer, 'vi', 'tokenizer_vi.json'),
):
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
    tokenizer.train_from_iterator(ds['train'][column], trainer)
    tokenizer.save(path)

In [5]:
# Reload the saved tokenizer files through the transformers fast-tokenizer
# wrapper, declaring which vocabulary entries act as bos / eos / pad / unk.
special_tokens = dict(
    bos_token='<bos>',
    eos_token='<eos>',
    pad_token='<pad>',
    unk_token='<unk>',
)
en_tokenizer = PreTrainedTokenizerFast(tokenizer_file='tokenizer_en.json', **special_tokens)
vi_tokenizer = PreTrainedTokenizerFast(tokenizer_file='tokenizer_vi.json', **special_tokens)

In [6]:
# Fixed sequence length in tokens: sentences are truncated/padded to this many
# ids during preprocessing, and it also sizes the positional-embedding tables
# of the encoder and decoder.
seq_len=100

# Preprocess

In [69]:
def preprocess(examples):
    """Encode a batch of sentence pairs into fixed-length id sequences.

    Args:
        examples: batch mapping with an 'en' list (source sentences) and a
            'vi' list (target sentences).

    Returns:
        dict with two lists of id lists, every inner list ``seq_len`` long:
        - 'input_ids': source ids from ``en_tokenizer``, truncated and padded;
        - 'labels': target ids laid out as ``<bos> tokens... <eos> <pad>...``.
    """
    src_input_ids = en_tokenizer(
        examples['en'],
        truncation=True,
        max_length=seq_len,
        padding='max_length',
    )['input_ids']

    bos_id = vi_tokenizer.convert_tokens_to_ids('<bos>')
    eos_id = vi_tokenizer.convert_tokens_to_ids('<eos>')
    pad_id = vi_tokenizer.convert_tokens_to_ids('<pad>')

    # Truncate only the sentence body and leave two slots free, so <bos> and
    # <eos> always fit. Truncating after the markers were added to the text
    # cut <eos> off every sentence longer than the limit.
    tgt_bodies = vi_tokenizer(
        examples['vi'],
        truncation=True,
        max_length=seq_len - 2,
        add_special_tokens=False,
    )['input_ids']

    tgt_input_ids = []
    for body in tgt_bodies:
        ids = [bos_id, *body, eos_id]
        ids.extend([pad_id] * (seq_len - len(ids)))
        tgt_input_ids.append(ids)

    return {
        'input_ids': src_input_ids,
        'labels': tgt_input_ids
    }

In [70]:
# Encode every split batch by batch. The 'input_ids' and 'labels' produced by
# `preprocess` are stored next to the original 'en' and 'vi' text columns
# (the `check` helper below reads all four).
preprocessed_ds = ds.map(preprocess, batched=True)

Map: 100%|██████████| 133317/133317 [00:25<00:00, 5269.89 examples/s]
Map: 100%|██████████| 1268/1268 [00:00<00:00, 4977.20 examples/s]
Map: 100%|██████████| 1268/1268 [00:00<00:00, 6318.13 examples/s]


In [81]:
# Train on a small prefix of the corpus to keep experiments quick; raise the
# cap to use more data.
n_train_samples = 4000
train_split = preprocessed_ds['train']
# `select` keeps the subset a Dataset, i.e. the same kind of object as the
# validation and test splits (a plain slice would turn it into a dict).
train_ds = train_split.select(range(min(n_train_samples, len(train_split))))
val_ds = preprocessed_ds['validation']
test_ds = preprocessed_ds['test']


In [84]:
def check(preprocessed_data, num_samples=10):
    """Print the first few sentence pairs next to their encoded ids.

    Args:
        preprocessed_data: column-addressable data (a dict of lists or a
            dataset split) with 'en', 'vi', 'input_ids' and 'labels' columns.
        num_samples: maximum number of examples to print.
    """
    # Slice each column once. Taking len() of a row slice counts columns rather
    # than rows, and re-reading a whole column on every iteration is wasteful.
    en_texts = preprocessed_data['en'][:num_samples]
    src_ids = preprocessed_data['input_ids'][:num_samples]
    vi_texts = preprocessed_data['vi'][:num_samples]
    tgt_ids = preprocessed_data['labels'][:num_samples]

    for en_text, src, vi_text, tgt in zip(en_texts, src_ids, vi_texts, tgt_ids):
        print('*'*100)
        print('English: ', en_text)
        print('Input_ids: ', src)
        print('Vietnam: ', vi_text)
        print('Input_ids: ', tgt)
        

In [85]:
# Eyeball a few test examples: raw text next to the ids it was encoded to.
check(test_ds)

****************************************************************************************************
English:  When I was little , I thought my country was the best on the planet , and I grew up singing a song called &quot; Nothing To Envy . &quot;
Input_ids:  [219, 15, 25, 131, 4, 15, 199, 47, 280, 25, 6, 301, 30, 6, 510, 4, 12, 15, 1040, 71, 2168, 13, 1003, 172, 8, 24, 7, 3192, 664, 0, 5, 8, 24, 7, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Vietnam:  Khi tôi còn nhỏ , Tôi nghĩ rằng BắcTriều Tiên là đất nước tốt nhất trên thế giới và tôi thường hát bài &quot; Chúng ta chẳng có gì phải ghen tị . &quot;
Input_ids:  [2, 316, 7, 122, 235, 4, 44, 80, 50, 0, 3945, 6, 280, 152, 173, 92, 70, 42, 97, 10, 7, 208, 793, 301, 26, 28, 23, 76, 15, 497, 9, 53, 49, 3264, 1789, 5, 26, 28, 23, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

# Dataset

In [86]:
class MyDataset(Dataset):
    """Map-style dataset pairing source token ids with target token ids."""

    def __init__(self, data):
        """Keep the 'input_ids' and 'labels' columns of ``data``."""
        self.input_ids, self.labels = data['input_ids'], data['labels']

    def __len__(self):
        """Number of examples, measured on the labels column."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(source_ids, target_ids)`` of example ``idx`` as tensors."""
        source = torch.tensor(self.input_ids[idx])
        target = torch.tensor(self.labels[idx])
        return source, target

In [None]:
# Wrap each split in MyDataset so the DataLoader receives (source, target)
# tensor pairs.
# NOTE: the names are rebound in place. After this cell they refer to MyDataset
# objects, so running the cell a second time without re-running the split cell
# would feed MyDataset objects back into MyDataset.
train_ds = MyDataset(train_ds)
val_ds = MyDataset(val_ds)
test_ds = MyDataset(test_ds)

In [88]:
# Only the training batches are shuffled. Validation and test are read in a
# fixed order so that repeated evaluations see exactly the same batches.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)


# Model

In [134]:


class Encoder(nn.Module):
    """Transformer encoder over source token ids.

    Token embeddings plus learned position embeddings go through a stack of
    ``nn.TransformerEncoderLayer`` blocks; a feed-forward head then projects
    every position to 1024 features, the width the decoder uses by default.

    Args:
        vocab_size: size of the source vocabulary.
        embedding_dim: width of the embeddings and of the encoder layers.
        num_layers: number of encoder layers.
        n_heads: attention heads per layer.
        pad_idx: id of the padding token (masked out in self-attention).
    """

    def __init__(self, vocab_size, embedding_dim, num_layers, n_heads, pad_idx=1):
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Learned position table: supports inputs of up to `seq_len` tokens.
        self.positional_encoding = nn.Embedding(seq_len, embedding_dim)
        # NOTE(review): dim_feedforward=4092 looks like a typo for 4096; kept
        # unchanged so the architecture stays the same. Confirm before changing.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=n_heads,
            dim_feedforward=4092,
            dropout=0.2,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(embedding_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 1024),
            nn.GELU()
        )

    def forward(self, inputs):
        """Encode ``inputs`` (batch, length) into (batch, length, 1024) features."""
        sz = inputs.shape[1]
        # All-False attention mask: no position is blocked, only padding is.
        mask = torch.zeros((sz, sz), dtype=torch.bool, device=inputs.device)
        padding_mask = inputs == self.pad_idx

        # Position ids follow the actual input length instead of the global
        # `seq_len`, so shorter (e.g. dynamically padded) batches also work.
        positions = torch.arange(sz, device=inputs.device).unsqueeze(0)
        embeddings = self.embedding(inputs) + self.positional_encoding(positions)  # NxLxC

        hidden = self.encoder(embeddings, mask=mask, src_key_padding_mask=padding_mask)  # NxLxC
        return self.fc(hidden)


class Decoder(nn.Module):
    """Transformer decoder that predicts target-vocabulary logits.

    Args:
        vocab_size: size of the target vocabulary.
        num_layers: number of decoder layers.
        n_heads: attention heads per layer.
        embedding_dim: decoder width; must equal the width of the encoder
            features it attends to (1024 by default).
        pad_idx: id of the padding token (masked out in self-attention).
    """

    def __init__(self, vocab_size, num_layers, n_heads, embedding_dim=1024, pad_idx=1):
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.positional_encoding = nn.Embedding(seq_len, embedding_dim)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embedding_dim,
            nhead=n_heads,
            dim_feedforward=2048,
            dropout=0.2,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(embedding_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, vocab_size)
        )

    def forward(self, tgt_inputs, abstract_features, memory_key_padding_mask=None):
        """Return logits of shape (batch, length, vocab_size).

        Args:
            tgt_inputs: target token ids, (batch, length).
            abstract_features: encoder output to attend to.
            memory_key_padding_mask: optional boolean (batch, source_length)
                mask, True where the source position is padding, so that
                cross-attention ignores those positions.
        """
        sz = tgt_inputs.shape[1]
        # Boolean causal mask, True = blocked: position i cannot see j > i.
        # Boolean on purpose, to match the dtype of the padding mask below.
        mask = torch.ones((sz, sz), device=tgt_inputs.device).triu(diagonal=1).bool()
        tgt_padding_mask = tgt_inputs == self.pad_idx

        positions = torch.arange(sz, device=tgt_inputs.device).unsqueeze(0)
        embeddings = self.embedding(tgt_inputs) + self.positional_encoding(positions)

        outputs = self.decoder(
            embeddings,
            abstract_features,
            tgt_mask=mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.fc(outputs)


class Model(nn.Module):
    """Encoder-decoder translation model (English source, Vietnamese target).

    Args:
        en_tokenizer: source tokenizer; only ``vocab_size`` is read.
        vi_tokenizer: target tokenizer; ``vocab_size`` and the id of '<bos>'
            are read.
        num_layers: number of layers in both the encoder and the decoder.
        n_heads: attention heads per layer in both stacks.
        embedding_dim: encoder width (the decoder keeps its default width).
    """

    def __init__(self, en_tokenizer, vi_tokenizer, num_layers, n_heads, embedding_dim):
        super().__init__()
        # <bos> starts the *target* sequence, so it is looked up in the target
        # vocabulary rather than in the source one.
        self.bos_idx = vi_tokenizer.convert_tokens_to_ids('<bos>')
        self.encoder = Encoder(en_tokenizer.vocab_size, embedding_dim, num_layers, n_heads)
        self.decoder = Decoder(vi_tokenizer.vocab_size, num_layers, n_heads)

    def forward(self, input_ids, labels):
        """Return logits shaped (batch, vocab_size, length).

        The class dimension is moved to axis 1, the layout
        ``nn.CrossEntropyLoss`` expects for sequence targets.
        """
        abstract_features = self.encoder(input_ids)
        # Hide padded source positions from the decoder's cross-attention.
        source_padding = input_ids == self.encoder.pad_idx
        logits = self.decoder(
            labels,
            abstract_features,
            memory_key_padding_mask=source_padding,
        )

        return logits.permute(0, 2, 1)
        
        

In [135]:
import tqdm

# Use the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The model is moved to the device once, before the optimizer is created, so
# the optimizer tracks the on-device parameters.
model = Model(en_tokenizer, vi_tokenizer, num_layers=2, n_heads=2, embedding_dim=128).to(device)
epochs = 1
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Targets equal to 1 (the id used for padding throughout this notebook) do not
# contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=1)

In [91]:
def eval(model, val_loader, criterion, device):
    """Compute the mean per-batch loss of ``model`` over ``val_loader``.

    Note: the name shadows Python's built-in ``eval`` inside this notebook.

    Args:
        model: callable as ``model(input_ids, labels)`` returning logits
            shaped (batch, vocab, length).
        val_loader: iterable of ``(input_ids, labels)`` batches.
        criterion: loss taking ``(logits, targets)``.
        device: device the batches are moved to.

    Returns:
        The average of the batch losses as a Python float, or NaN when the
        loader yields no batch. The model is left in evaluation mode.
    """
    batch_losses = []

    model.eval()
    with torch.no_grad():
        for input_ids, labels in tqdm.tqdm(val_loader, desc='validation'):
            input_ids = input_ids.to(device)
            labels = labels.to(device)

            preds = model(input_ids, labels)
            # The logit at position t is scored against the token at t + 1.
            batch_loss = criterion(preds[:, :, :-1], labels[:, 1:])

            # Store plain floats rather than tensors living on the device.
            batch_losses.append(batch_loss.item())

    loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
    print(f'Valid loss: {loss}')
    return loss

In [None]:
def fit(model, train_loader, val_loader,  optimizer, criterion, device, num_epochs=None):
    """Train ``model`` and run validation after every epoch.

    Args:
        model: callable as ``model(input_ids, labels)`` returning logits
            shaped (batch, vocab, length).
        train_loader: iterable of ``(input_ids, labels)`` training batches.
        val_loader: batches handed to ``eval`` at the end of each epoch.
        optimizer: optimizer over the model parameters.
        criterion: loss taking ``(logits, targets)``.
        device: device the batches are moved to.
        num_epochs: number of epochs; when omitted, the notebook-level
            ``epochs`` variable is used.
    """
    if num_epochs is None:
        num_epochs = epochs

    for epoch in tqdm.tqdm(range(num_epochs), desc='Epoch'):
        batch_losses = []
        model.train()
        for input_ids, labels in train_loader:
            input_ids = input_ids.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(input_ids, labels)
            # The logit at position t is scored against the token at t + 1.
            loss = criterion(logits[:, :, :-1], labels[:, 1:])
            loss.backward()
            optimizer.step()

            # Keep a float, not the tensor: the tensor still carries its
            # grad_fn, and a list of them piles up on the device all epoch.
            batch_losses.append(loss.item())

        train_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
        print(f'EPOCH {epoch+1}\t Training Loss: {train_loss}')
        eval(model, val_loader, criterion, device)

In [92]:
# Train for `epochs` epochs; the validation loss is printed after each one.
fit(model, train_loader, val_loader, optimizer, criterion, device)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

validation: 100%|██████████| 20/20 [00:14<00:00,  1.42it/s]
Epoch: 100%|██████████| 1/1 [08:17<00:00, 497.88s/it]

Valid loss: 6.069946765899658
EPOCH 1	 Training Loss: 6.998533725738525





# Inference

In [146]:
def infer(model, sample):
    """Greedily translate one English sentence and print the result.

    Decoding starts from '<bos>'; at each step the most likely next token is
    appended, until '<eos>' is produced or ``seq_len`` steps have run.

    Args:
        model: translation model called as ``model(input_ids, target_ids)``
            and returning logits shaped (batch, vocab, length).
        sample: English sentence to translate.

    Returns:
        The translation as a space-joined token string, starting with '<bos>'
        (the same string that is printed).
    """
    model.eval()
    # nn.Module has no `.device` attribute; take the device of the weights so
    # every tensor built here lands where the model lives.
    model_device = next(model.parameters()).device

    # The tokenizer returns a mapping; only its 'input_ids' go into the tensor.
    src_ids = en_tokenizer(
        sample,
        padding='max_length',
        truncation=True,
        max_length=seq_len,
    )['input_ids']
    input_ids = torch.tensor(src_ids, device=model_device).unsqueeze(0)

    pad_id = vi_tokenizer.convert_tokens_to_ids('<pad>')
    generated_ids = [vi_tokenizer.convert_tokens_to_ids('<bos>')]
    words = ['<bos>']

    with torch.no_grad():
        for i in range(seq_len):
            # Feed the ids generated so far, padded to the fixed length, so the
            # predictions are not altered by re-tokenising a growing string.
            padded = generated_ids + [pad_id] * (seq_len - len(generated_ids))
            tokenized_sentence = torch.tensor(padded, device=model_device).unsqueeze(0)

            preds = model(input_ids, tokenized_sentence)
            # Logits at position i score the token that follows the i-th input.
            token_idx = preds[0, :, i].argmax().item()
            word = vi_tokenizer.convert_ids_to_tokens(token_idx)

            generated_ids.append(token_idx)
            words.append(word)

            if word =='<eos>':
                break

    translated_sentence = ' '.join(words)
    print(f'Origin: {sample}')
    print(f'Translation: {translated_sentence}')
    return translated_sentence