In [None]:
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
from typing import Iterable, List
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from timeit import default_timer as timer
from torch.nn import Transformer
from torch import Tensor
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
import math
import os
import pandas as pd
import matplotlib.pyplot as plt
import json
from torch.optim import AdamW

In [None]:
# Set seeds so reruns of the notebook are reproducible.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
# Explicitly seed every visible GPU (a no-op on CPU-only machines).
torch.cuda.manual_seed_all(seed)
# Ask cuDNN for deterministic kernels. benchmark=True would let cuDNN auto-tune
# and possibly pick different algorithms from run to run, undermining
# deterministic=True, so it is switched off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Pretrained many-to-many mBART-50 checkpoint and its fast tokenizer.
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)

# Data location, [source, target] column names, and held-out fraction.
training_file_path = "/kaggle/input/sam-data/sampleData.json"
cols = ["ko_text", "en_text"]
test_size = 0.25

# Training hyperparameters.
BATCH_SIZE = 1
NUM_EPOCHS = 2

# Use the GPU whenever PyTorch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {DEVICE}")
model = MBartForConditionalGeneration.from_pretrained(model_name).to(DEVICE)


In [None]:
# Report the overall parameter count, then only the parameters that receive gradients.
total_params = sum(param.numel() for param in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    param.numel() for param in filter(lambda t: t.requires_grad, model.parameters())
)
print(f"{total_trainable_params:,} training parameters.")
print(model)

In [None]:
# Load the JSON corpus; each entry of its "Text" list holds one sentence pair.
with open(training_file_path, encoding='utf-8') as f:
    json_data = json.load(f)

# Keep only the configured source/target columns from every entry.
texts = []
for record in json_data["Text"]:
    texts.append({col: record[col] for col in cols})

df = pd.DataFrame(texts)
df.head()

In [None]:
# mBART-50 marks languages with special tokens: Korean source, English target.
tokenizer.src_lang, tokenizer.tgt_lang = "ko_KR", "en_XX"
# Begin decoding from the English language token.
english_token_id = tokenizer.convert_tokens_to_ids("en_XX")
model.config.decoder_start_token_id = english_token_id

In [None]:
# Custom Dataset class.
class TranslationDataset(Dataset):
    """Map-style dataset of tokenized (source, target) sentence pairs.

    Each item is a dict with fixed-length ``input_ids`` and ``attention_mask``
    for the source sentence and ``labels`` for the target sentence. Padding
    positions in ``labels`` are replaced by -100 so the loss ignores them.

    Args:
        df: DataFrame with one sentence pair per row.
        tokenizer: callable tokenizer exposing ``pad_token_id`` and accepting
            ``text_target=`` (the mBART-50 tokenizer in this notebook).
        max_length: length every sequence is padded/truncated to.
        src_col: column holding the source text; defaults to ``cols[0]``.
        tgt_col: column holding the target text; defaults to ``cols[1]``.
    """

    def __init__(self, df, tokenizer, max_length=128, src_col=None, tgt_col=None):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        # The notebook-level ``cols`` is only consulted when no explicit
        # column names are given.
        self.src_col = cols[0] if src_col is None else src_col
        self.tgt_col = cols[1] if tgt_col is None else tgt_col

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        src = self.df[self.src_col].iloc[idx]
        tgt = self.df[self.tgt_col].iloc[idx]

        encode_kwargs = dict(
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )
        src_enc = self.tokenizer(src, **encode_kwargs)
        # Encode the target through ``text_target`` so the tokenizer applies its
        # target-language special tokens; a plain call would mark the English
        # target with the source-language (ko_KR) code.
        tgt_enc = self.tokenizer(text_target=tgt, **encode_kwargs)

        # Drop the batch dimension of size 1 added by return_tensors="pt".
        # Indexing with [0] (unlike .squeeze()) always yields a 1-D tensor.
        input_ids = src_enc["input_ids"][0]
        attention_mask = src_enc["attention_mask"][0]
        labels = tgt_enc["input_ids"][0]
        # Ignore padding in the loss calculation.
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


In [None]:
# Split the dataset into train and test sets.
# Pass the notebook seed explicitly so the split does not depend on how much of
# NumPy's global RNG earlier cells happened to consume.
train_data, test_data = train_test_split(df, test_size=test_size, random_state=seed)

In [None]:
# Wrap both splits in the custom Dataset, then peek at the first training example.
train_dataset, valid_dataset = (
    TranslationDataset(split, tokenizer) for split in (train_data, test_data)
)
iterator = iter(train_dataset)
print(next(iterator))

In [None]:
# AdamW over all model parameters, with beta2=0.98 and eps=1e-9.
optimizer = AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.98),
    eps=1e-9,
)

In [None]:
# Reshuffle training examples every epoch (the shuffle order comes from torch's
# RNG); keep the validation order fixed.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)
train_dataset[0]

In [None]:
def train(model, optimizer, num_epochs, dataloader, device=None, log_every=1):
    """Train ``model`` for ``num_epochs`` passes over ``dataloader``.

    Each batch must be a mapping with ``input_ids``, ``attention_mask`` and
    ``labels``. The model is called with those as keyword arguments and must
    return an object exposing ``.loss``.

    Args:
        model: model to train in place (moved to ``device`` and put in train mode).
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of full passes over ``dataloader``.
        dataloader: iterable of batch dicts (tensors or nested lists).
        device: device to train on; defaults to the notebook-level ``DEVICE``.
        log_every: print the loss every ``log_every`` iterations; 1 logs every
            iteration (the original behaviour), 0 or None disables logging.

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    device = DEVICE if device is None else device
    print('Training')
    model.to(device)
    model.train()

    train_loss = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_iter = 0
        for batch in dataloader:
            # as_tensor avoids the copy + UserWarning that torch.tensor() incurs
            # when the batch already holds tensors.
            input_ids = torch.as_tensor(batch['input_ids'], device=device)
            attention_mask = torch.as_tensor(batch['attention_mask'], device=device)
            labels = torch.as_tensor(batch['labels'], device=device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            step_loss = loss.item()
            epoch_loss += step_loss
            num_iter += 1
            if log_every and num_iter % log_every == 0:
                print(f"Epoch {epoch}, iter = {num_iter}, Loss: {step_loss}")
        if num_iter == 0:
            raise ValueError("dataloader yielded no batches; cannot compute an epoch loss")
        train_loss.append(epoch_loss / num_iter)
    return train_loss

def evaluate(model, dataloader, device=None, log_every=1):
    """Compute the mean loss of ``model`` over ``dataloader`` without training.

    Batches have the same layout as in ``train``: mappings with ``input_ids``,
    ``attention_mask`` and ``labels``. The model must return an object with
    ``.loss``.

    Args:
        model: model to evaluate (moved to ``device`` and put in eval mode).
        dataloader: iterable of batch dicts (tensors or nested lists).
        device: device to evaluate on; defaults to the notebook-level ``DEVICE``.
        log_every: print the loss every ``log_every`` iterations; 1 logs every
            iteration (the original behaviour), 0 or None disables logging.

    Returns:
        Mean per-batch loss as a float.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    device = DEVICE if device is None else device
    print('Validating')
    model.to(device)
    model.eval()
    losses = 0.0
    num_iter = 0
    # Validation needs no gradients; without no_grad every forward pass would
    # build an autograd graph and hold its activations in memory.
    with torch.no_grad():
        for batch in dataloader:
            input_ids = torch.as_tensor(batch['input_ids'], device=device)
            attention_mask = torch.as_tensor(batch['attention_mask'], device=device)
            labels = torch.as_tensor(batch['labels'], device=device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            step_loss = outputs.loss.item()

            losses += step_loss
            num_iter += 1
            if log_every and num_iter % log_every == 0:
                print(f"iter = {num_iter}, Loss: {step_loss}")
    if num_iter == 0:
        raise ValueError("dataloader yielded no batches; cannot compute a validation loss")
    return losses / num_iter

In [None]:

# Sanity-check how the tokenizer encodes a sample Korean sentence.
src_text = ["안녕하세요. 만나서 반갑습니다."]
inputs = tokenizer(src_text, padding=True, return_tensors="pt")
inputs

In [None]:
# Fine-tune; the cell output is the list of per-epoch mean training losses.
train(
    model=model,
    optimizer=optimizer,
    num_epochs=NUM_EPOCHS,
    dataloader=train_dataloader,
)

In [None]:
# Mean validation loss over the held-out split (shown as the cell output).
evaluate(model=model, dataloader=valid_dataloader)

In [None]:
# Persist the fine-tuned weights to Kaggle's working directory...
model.save_pretrained(
    "/kaggle/working/facebook/mbart-large-50-many-to-many-mmt-finetuning"
)
# ...and the tokenizer files to a sibling directory.
tokenizer.save_pretrained(
    "/kaggle/working/facebook/mbart-large-50-many-to-many-mmt-finetuning-token"
)

In [None]:
# Reload the fine-tuned weights (placed on the active device) and the saved tokenizer.
finetuned_model = MBartForConditionalGeneration.from_pretrained(
    "/kaggle/working/facebook/mbart-large-50-many-to-many-mmt-finetuning"
).to(DEVICE)
finetuned_tokenizer = MBart50TokenizerFast.from_pretrained(
    "/kaggle/working/facebook/mbart-large-50-many-to-many-mmt-finetuning-token"
)

In [None]:
# Translate a sample sentence with the reloaded artifacts, so this cell checks
# what was written to disk rather than the in-memory tokenizer.
src_text = ["안녕하세요. 만나서 반갑습니다."]
# Set the source language explicitly rather than relying on it having been
# persisted alongside the tokenizer files.
finetuned_tokenizer.src_lang = "ko_KR"
inputs = finetuned_tokenizer(src_text, return_tensors="pt", padding=True).to(DEVICE)

outputs = finetuned_model.generate(**inputs)

translation = finetuned_tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
# Display the decoded translation of the sample sentence as the cell output.
translation