## Import

In [None]:
!pip install transformers
from typing import Dict, List
import csv
import torch
from transformers import (
    EncoderDecoderModel,
    GPT2Tokenizer as BaseGPT2Tokenizer,
    PreTrainedTokenizer, BertTokenizerFast,
    PreTrainedTokenizerFast,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    AutoTokenizer,
    XLMRobertaTokenizerFast,
    Trainer
)
from torch.utils.data import DataLoader
from transformers.models.encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel

# Checkpoint names handed to the `from_pretrained` loaders below: the first is
# used for the encoder and source tokenizer, the second for the decoder and
# target tokenizer.
encoder_model_name = "xlm-roberta-base"
decoder_model_name = "skt/kogpt2-base-v2"

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Bare expression: displayed as the cell output (chosen device, CUDA device count).
device, torch.cuda.device_count()

In [None]:
class GPT2Tokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer whose special-token step only appends the EOS id."""

    def build_inputs_with_special_tokens(self, token_ids: List[int]) -> List[int]:
        """Return ``token_ids`` followed by ``self.eos_token_id``.

        A new list is built; the list passed in is left untouched.
        """
        eos_suffix = [self.eos_token_id]
        return token_ids + eos_suffix

# Source side: tokenizer loaded from the encoder checkpoint.
src_tokenizer = XLMRobertaTokenizerFast.from_pretrained(encoder_model_name)

# Target side: tokenizer loaded from the decoder checkpoint with its special
# tokens spelled out; BOS and EOS are both mapped to '</s>'.
trg_special_tokens = {
    'bos_token': '</s>',
    'eos_token': '</s>',
    'unk_token': '<unk>',
    'pad_token': '<pad>',
    'mask_token': '<mask>',
}
trg_tokenizer = GPT2Tokenizer.from_pretrained(decoder_model_name, **trg_special_tokens)

## Data

In [None]:
class PairedDataset:
    """In-memory dataset of (source, target) sentence pairs read from a CSV file.

    The first CSV row is treated as a header and skipped. Every remaining
    non-empty row is expected to hold exactly two columns: the source sentence
    followed by the target sentence.
    """

    def __init__(
        self,
        src_tokenizer: PreTrainedTokenizerFast,
        tgt_tokenizer: PreTrainedTokenizerFast,
        file_path: str,
        encoding: str = 'utf-8',
    ):
        """Load every sentence pair from ``file_path``.

        Args:
            src_tokenizer: Tokenizer applied to the source column.
            tgt_tokenizer: Tokenizer applied to the target column; it must also
                provide ``build_inputs_with_special_tokens``.
            file_path: Path of the CSV corpus.
            encoding: Text encoding of the corpus. It is passed explicitly so
                that reading non-ASCII text does not depend on the platform's
                default locale.
        """
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = tgt_tokenizer
        # newline='' lets the csv module handle line endings itself, which is
        # required for quoted fields that contain line breaks.
        with open(file_path, 'r', encoding=encoding, newline='') as fd:
            reader = csv.reader(fd)
            # Skip the header row; the default keeps an empty file from raising.
            next(reader, None)
            # csv.reader yields [] for blank lines; drop them so that every
            # stored row can be unpacked in __getitem__.
            self.data = [row for row in reader if row]

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        """Tokenize the pair at ``index``.

        Returns:
            The source tokenizer's output for the source sentence, extended
            with a ``'labels'`` entry holding the target ids after
            ``build_inputs_with_special_tokens`` has been applied.

        Raises:
            IndexError: If ``index`` is out of range.
            ValueError: If the stored row does not have exactly two columns.
        """
        src, trg = self.data[index]
        features = self.src_tokenizer(src, return_attention_mask=False, return_token_type_ids=False)
        target_ids = self.trg_tokenizer(trg, return_attention_mask=False)['input_ids']
        features['labels'] = self.trg_tokenizer.build_inputs_with_special_tokens(target_ids)

        return features

    def __len__(self):
        """Return the number of sentence pairs that were loaded."""
        return len(self.data)
# Build the training and evaluation datasets from their CSV corpora.
train_dataset = PairedDataset(
    src_tokenizer=src_tokenizer,
    tgt_tokenizer=trg_tokenizer,
    file_path='일-한 언어 corpus_train.csv',
)
eval_dataset = PairedDataset(
    src_tokenizer=src_tokenizer,
    tgt_tokenizer=trg_tokenizer,
    file_path='일-한 언어 corpus_eval.csv',
)

## Model

In [None]:
# Assemble the encoder-decoder model from the two pretrained checkpoints.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_model_name,
    decoder_model_name,
    # NOTE(review): the padding id is set to the target tokenizer's BOS id (the
    # '</s>' token), not to the id of its '<pad>' token — confirm this is intended.
    pad_token_id=trg_tokenizer.bos_token_id,
)
# Decoding starts from the target tokenizer's BOS token.
model.config.decoder_start_token_id = trg_tokenizer.bos_token_id

In [None]:
# for Trainer

# wandb is called below but is not imported anywhere else in this notebook;
# importing it here keeps the cell from failing with a NameError.
import wandb

# Batch collator built from the source tokenizer; the model is handed over too
# (presumably so decoder inputs can be derived from the labels — confirm).
collate_fn = DataCollatorForSeq2Seq(src_tokenizer, model)
wandb.init(project="temp", name='roberta+kogpt2')

arguments = Seq2SeqTrainingArguments(
    output_dir='dump',
    do_train=True,
    do_eval=True,
    # Evaluate and checkpoint on the same schedule (once per epoch).
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=4,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    warmup_ratio=0.1,
    gradient_accumulation_steps=4,
    save_total_limit=5,
    dataloader_num_workers=1,
    # Request half precision only when CUDA is present, so the notebook's CPU
    # fallback (see the device cell) does not request fp16 without a GPU.
    fp16=torch.cuda.is_available(),
    load_best_model_at_end=True,
    report_to='wandb'
)

trainer = Trainer(
    model,
    arguments,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

## Training

In [None]:
# Reload a saved model from disk, rebinding the global `model`. The `trainer`
# created earlier still holds the model object it was constructed with.
# NOTE(review): this reads "best_model", whereas the training cell writes to
# "dump/best_model" — confirm which directory is intended.
model = EncoderDecoderModel.from_pretrained("best_model")

In [None]:
# Run the training loop configured on the trainer.
trainer.train()

# Save the model the trainer actually optimised. The global name `model` can be
# rebound after the trainer is created (the reload cell above does exactly
# that), in which case `model.save_pretrained` would write the wrong weights.
trainer.model.save_pretrained("dump/best_model")

## Translation (번역)

In [None]:
# Translate a single example sentence.
text = "もんじゃ焼き"
embeddings = src_tokenizer(text, return_attention_mask=False, return_token_type_ids=False, return_tensors='pt')
# Put the input tensors on the same device as the model; otherwise generation
# fails whenever the model lives on a GPU.
embeddings = {k: v.to(model.device) for k, v in embeddings.items()}
# Keep the first (only) sequence of the batch without its first and last
# positions (presumably the decoder start token and the closing EOS — confirm).
output = model.generate(**embeddings)[0, 1:-1]

In [None]:
# Turn the generated ids back into text with the target tokenizer; as a bare
# expression the result is displayed as the cell output.
trg_tokenizer.decode(output.cpu())

## Evaluating model

In [None]:
from nltk.translate.bleu_score import sentence_bleu

In [None]:
from tqdm import tqdm
from statistics import mean

# Per-sentence unigram BLEU scores, filled by the loop below.
bleu = []
# Kept because the original cell defines it; nothing in this cell fills it.
f1 = []

# Read every test pair up front so the file is closed before the (slow)
# generation loop starts. The encoding is explicit so that reading the
# non-ASCII corpus does not depend on the platform locale.
with open('test.csv', 'r', encoding='utf-8', newline='') as fd:
    reader = csv.reader(fd)
    next(reader, None)  # skip the header row; tolerate an empty file
    datas = [row for row in reader if row]  # blank lines come back as []

# Switch off dropout and gradient tracking for inference.
model.eval()
with torch.no_grad():
    # Rows are unpacked into names that do not shadow the builtin `input`.
    for source_text, label in tqdm(datas, "Testing"):
        embeddings = src_tokenizer(source_text, return_attention_mask=False, return_token_type_ids=False, return_tensors='pt')
        # Inputs must live on the same device as the model.
        embeddings = {k: v.to(model.device) for k, v in embeddings.items()}
        # First sequence of the batch without its first and last positions
        # (presumably the start token and the closing EOS — confirm).
        output = model.generate(**embeddings)[0, 1:-1]
        preds = trg_tokenizer.decode(output.cpu())

        # weights=[1,0,0,0] scores unigram overlap only, on whitespace tokens.
        bleu.append(sentence_bleu([label.split()], preds.split(), weights=[1, 0, 0, 0]))

In [None]:
# statistics.mean raises StatisticsError on an empty sequence, so report the
# "nothing was scored" case explicitly instead of crashing.
if bleu:
    print(f"Bleu score: {mean(bleu)}")
else:
    print("Bleu score: n/a (no test sentences were scored)")