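# Low-Resource English–Vietnamese (En-Vi) Translator

Fine-tunes the pretrained `Helsinki-NLP/opus-mt-en-vi` MarianMT model on the IWSLT'15 En-Vi corpus and evaluates it with SacreBLEU.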

## Installation

In [1]:
# Uncomment to install dependencies (%pip installs into the kernel's own environment).
# %pip install -q torch torchvision torchaudio
# %pip install -q transformers sentencepiece datasets accelerate evaluate sacrebleu

## Libraries

In [2]:
import os
import numpy as np

import torch
from torch.utils.data import Dataset

from datasets import load_dataset
import evaluate
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

  from .autonotebook import tqdm as notebook_tqdm


## Build Dataset

In [3]:
class NMTDataset(Dataset):
    """IWSLT'15 English-Vietnamese split, pre-tokenized into fixed-length tensors.

    Every sentence is padded/truncated to ``cfg.max_length`` tokens at
    construction time, so items from this dataset always have equal length.

    Args:
        cfg: configuration object; reads ``tokenizer``, ``max_length``,
            ``src_lang``, ``tgt_lang`` and ``cache_dir``.
        split: split name passed to ``load_dataset``
            (e.g. ``'train'``, ``'validation'``, ``'test'``).
        prefix: string prepended to every source sentence.
    """

    def __init__(self, cfg, split='train', prefix=''):
        self.cfg = cfg

        src_texts, tgt_texts = self.read_data(split, prefix)

        src_enc = self._encode(src_texts)
        self.src_input_ids = src_enc['input_ids']
        # Passed to the model so the encoder does not attend to padding tokens.
        self.src_attention_mask = src_enc['attention_mask']

        tgt_enc = self._encode(tgt_texts, is_target=True)
        # Padding positions become -100 so the cross-entropy loss ignores them.
        # The attention mask (not pad_token_id) marks padding, which stays
        # correct even for tokenizers whose pad token equals EOS.
        self.labels = tgt_enc['input_ids'].masked_fill(
            tgt_enc['attention_mask'] == 0, -100
        )

    def read_data(self, split, prefix):
        """Load one split and return ``(src_texts, tgt_texts)`` lists of strings."""
        dataset = load_dataset('mt_eng_vietnamese',
                               'iwslt2015-en-vi',
                               split=split,
                               cache_dir=self.cfg.cache_dir)

        src_texts = [prefix + sample['translation'][self.cfg.src_lang] for sample in dataset]
        tgt_texts = [sample['translation'][self.cfg.tgt_lang] for sample in dataset]

        return src_texts, tgt_texts

    def _encode(self, texts, is_target=False):
        """Tokenize ``texts`` to exactly ``cfg.max_length`` tokens as PyTorch tensors.

        Target-side texts are passed via ``text_target=`` so tokenizers with
        separate source/target vocabularies (e.g. MarianMT's two SentencePiece
        models) use the target-side model for them.
        """
        options = dict(
            padding='max_length',
            truncation=True,
            max_length=self.cfg.max_length,
            return_tensors='pt',
        )
        if is_target:
            return self.cfg.tokenizer(text_target=texts, **options)
        return self.cfg.tokenizer(texts, **options)

    def text_to_sequence(self, text, is_target=False):
        """Return only the padded ``input_ids`` tensor for ``text`` (no -100 masking)."""
        return self._encode(text, is_target)['input_ids']

    def __getitem__(self, index):
        return {
            'input_ids': self.src_input_ids[index],
            'attention_mask': self.src_attention_mask[index],
            'labels': self.labels[index],
        }

    def __len__(self):
        return len(self.src_input_ids)

## Configuration

In [4]:
class BaseConfig:
    """Config container whose constructor turns keyword arguments into attributes.

    Instance attributes set here shadow any class-level defaults a subclass declares.
    """

    def __init__(self, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])


class NMTConfig(BaseConfig):
    """Settings for fine-tuning the En->Vi MarianMT model.

    Class attributes are defaults; keyword arguments passed to the constructor
    override them per instance (via ``BaseConfig``).
    """
    # Data
    src_lang = 'en'
    tgt_lang = 'vi'
    max_length = 75             # tokens per sentence after padding/truncation
    add_special_token = True    # NOTE(review): not read anywhere in this notebook

    # Model
    model_name = "Helsinki-NLP/opus-mt-en-vi"
    cache_dir = './.cache/'

    # Training
    device = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')
    # Always defined (False off-MPS) so reading cfg.use_mps_device is safe on CUDA/CPU too.
    use_mps_device = device == 'mps'

    learning_rate = 1e-5
    train_batch_size = 16
    eval_batch_size = 16
    num_train_epochs = 2
    save_total_limit = 1
    ckpt_dir = './checkpoints'
    eval_steps = 1000           # NOTE(review): unused while evaluation runs per epoch

    # Inference
    beam_search = 5             # beam width intended for generation

In [5]:
cfg = NMTConfig()
# Tokenizer and model come from the same checkpoint and share one cache directory.
cfg.tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=cfg.model_name,
    cache_dir=cfg.cache_dir,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    pretrained_model_name_or_path=cfg.model_name,
    cache_dir=cfg.cache_dir,
)



## Set up evaluation metrics

In [6]:
# Corpus-level SacreBLEU, used by compute_metrics during evaluation.
metric = evaluate.load(path='sacrebleu', cache_dir=cfg.cache_dir)

In [7]:
def postprocess_text(preds, labels):
    """Strip surrounding whitespace and shape references the way SacreBLEU expects.

    Returns:
        A tuple ``(cleaned_preds, references)`` where each reference is wrapped
        in a single-element list (one reference translation per prediction).
    """
    cleaned_preds, references = [], []
    for text in preds:
        cleaned_preds.append(text.strip())
    for text in labels:
        references.append([text.strip()])
    return cleaned_preds, references


def compute_metrics(eval_preds):
    """Score generated translations with SacreBLEU.

    Args:
        eval_preds: ``(predictions, label_ids)`` from the Trainer; predictions
            may be a tuple whose first element holds the generated token ids.

    Returns:
        dict with ``bleu`` (corpus SacreBLEU score) and ``gen_len`` (mean number
        of non-pad tokens per prediction), each rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    pad_id = cfg.tokenizer.pad_token_id

    # The Trainer pads predictions and labels across batches with -100, which is
    # not a valid token id (decoding it fails); map it to the pad id first.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = cfg.tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = cfg.tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    bleu = metric.compute(predictions=decoded_preds, references=decoded_labels)
    # Counted after the -100 replacement so cross-batch padding is not counted as tokens.
    gen_len = np.mean([np.count_nonzero(pred != pad_id) for pred in preds])

    return {"bleu": round(bleu["score"], 4), "gen_len": round(gen_len, 4)}

## Trainer

In [8]:
# MarianMT checkpoints such as Helsinki-NLP/opus-mt-en-vi were trained on raw
# source sentences; a "translate English to Vietnamese: " task prefix is a T5
# convention and only adds noise here. Set a prefix only for T5-style models.
prefix = ""

train_dataset = NMTDataset(cfg, 'train', prefix=prefix)
valid_dataset = NMTDataset(cfg, 'validation', prefix=prefix)
test_dataset = NMTDataset(cfg, 'test', prefix=prefix)

Downloading data: 100%|██████████| 17.8M/17.8M [00:01<00:00, 13.6MB/s]
Downloading data: 100%|██████████| 181k/181k [00:00<00:00, 223kB/s]
Downloading data: 100%|██████████| 181k/181k [00:00<00:00, 223kB/s]
Generating train split: 100%|██████████| 133318/133318 [00:00<00:00, 1959745.77 examples/s]
Generating validation split: 100%|██████████| 1269/1269 [00:00<00:00, 529135.28 examples/s]
Generating test split: 100%|██████████| 1269/1269 [00:00<00:00, 841646.39 examples/s]


In [9]:
training_args = Seq2SeqTrainingArguments(
    output_dir=cfg.ckpt_dir,
    # Evaluate and checkpoint once per epoch so load_best_model_at_end can pick
    # the best saved checkpoint. NOTE: newer transformers releases rename
    # `evaluation_strategy` to `eval_strategy`.
    evaluation_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=cfg.save_total_limit,
    load_best_model_at_end=True,
    learning_rate=cfg.learning_rate,
    num_train_epochs=cfg.num_train_epochs,
    per_device_train_batch_size=cfg.train_batch_size,
    per_device_eval_batch_size=cfg.eval_batch_size,
    # Generate during evaluation so compute_metrics can score BLEU. Bound the
    # output length by the training max_length and use the configured beam
    # width instead of the checkpoint's generation defaults.
    predict_with_generate=True,
    generation_max_length=cfg.max_length,
    generation_num_beams=cfg.beam_search,
    # getattr guards configs that only define use_mps_device when running on MPS.
    use_mps_device=getattr(cfg, 'use_mps_device', False),
)

# Pads any remaining ragged fields; label padding uses -100 (ignored by the loss).
data_collator = DataCollatorForSeq2Seq(
    tokenizer=cfg.tokenizer,
    model=model
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=cfg.tokenizer,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)



In [10]:
# Run fine-tuning; per training_args the model is evaluated and checkpointed at
# the end of every epoch, and the best checkpoint is reloaded when training ends.
trainer.train()

  0%|          | 60/83330 [00:45<16:32:11,  1.40it/s]

In [None]:
# Export the model held by the trainer (the best checkpoint, since
# load_best_model_at_end=True) to a stable directory outside ./checkpoints.
model_dir = './models/'
trainer.save_model(model_dir)