In [None]:
import warnings
warnings.filterwarnings('ignore')
import datetime
import os
import gc
from collections import namedtuple
from pprint import pprint

import numpy as np
import pandas as pd

import tensorboard
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW

from datasets import load_dataset, DatasetDict, Dataset
from transformers import DataCollatorForSeq2Seq
from transformers import BartConfig, T5Config
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

## Settings

In [None]:
# Execution switches: True runs the hand-written PyTorch loop in the corresponding
# section, False delegates that step to the Seq2SeqTrainer.
MANUAL_TRAINING, MANUAL_VALIDATION = True, True

## Loading Tokenizer & Model Checkpoint

In [None]:
# Hub identifiers of the two candidate pretrained seq2seq checkpoints.
kobart_checkpoint, kot5_checkpoint = 'gogamza/kobart-base-v2', 'psyche/KoT5'

# Checkpoint actually used by the rest of the notebook.
checkpoint = kobart_checkpoint

In [None]:
# Pick the configuration that matches the selected checkpoint: the KoBART
# checkpoint gets a BART config, anything else falls back to the KoT5 T5 config.
use_kobart = checkpoint == kobart_checkpoint
config = (
    BartConfig.from_pretrained(kobart_checkpoint)
    if use_kobart
    else T5Config.from_pretrained(kot5_checkpoint)
)

In [None]:
# Keyword arguments forwarded verbatim to the tokenizer loader.
tokenizer_kwargs = dict(max_length=512, truncation=False, padding='max_length')

# Load the tokenizer and the seq2seq model from the same checkpoint; the model
# is built with the config object loaded in the previous cell.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, **tokenizer_kwargs)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, config=config)

In [None]:
# Token ids are used to index the model's embedding matrix, so training only
# breaks when the tokenizer can produce an id beyond the model's vocabulary.
# A model vocabulary that is larger than the tokenizer's (e.g. an embedding
# table padded to a round size) is harmless and must not abort the notebook.
tokenizer_vocab_size = len(tokenizer)
model_vocab_size = model.config.vocab_size
if tokenizer_vocab_size > model_vocab_size:
    raise RuntimeError(
        f'Tokenizer vocab size exceeds model vocab size '
        f'(Tokenizer: {tokenizer_vocab_size}, Model: {model_vocab_size}), '
        f'which would lead to out-of-range token ids during training.'
    )

## Loading Datasets

In [None]:
# Read the JSON records into pandas first, then wrap the frame as a Dataset.
raw_frame = pd.read_json('data/simplified_data.json')
dataset = Dataset.from_pandas(raw_frame)

# Displayed as the cell output: number of loaded examples.
len(dataset)

In [None]:
# Fixed seed so the split is identical on every run. The tokenized copy of this
# split is cached on disk in a later cell, so a fresh random split on each run
# would no longer correspond to the cached data.
SPLIT_SEED = 42

# First hold out 10% of the data, then cut that hold-out in half to obtain the
# validation and test splits.
train_testvalid = dataset.train_test_split(test_size=0.1, seed=SPLIT_SEED)
test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=SPLIT_SEED)
dataset_dict = DatasetDict({
    'train': train_testvalid['train'],
    'valid': test_valid['train'],
    'test': test_valid['test'],
    })

In [None]:
def tokenize(row):
    """Tokenize a batch of source/target text pairs for seq2seq training.

    Args:
        row: Mapping with the keys ``'form'`` (source text) and
            ``'corrected_form'`` (target text), as passed by ``Dataset.map``.

    Returns:
        dict with ``'input_ids'`` and ``'attention_mask'`` taken from the
        encoded source and ``'labels'`` holding the target's ``input_ids``.
        Both sides are truncated/padded to 512 tokens.

    NOTE(review): the labels are padded like the inputs, so padded label
    positions presumably hold the pad token id rather than -100 - confirm
    that downstream loss code accounts for this.
    """
    encode_kwargs = dict(max_length=512, truncation=True, padding='max_length')

    source_encoding = tokenizer(row['form'], **encode_kwargs)
    # Encode the target side inside the tokenizer's target-mode context.
    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer(row['corrected_form'], **encode_kwargs)

    features = {key: source_encoding[key] for key in ('input_ids', 'attention_mask')}
    features['labels'] = target_encoding['input_ids']
    return features

In [None]:
# Display the names of the splits held by the DatasetDict.
dataset_dict.keys()

In [None]:
# The tokenized splits are cached on disk, keyed by the checkpoint name
# ('/' is replaced so the name forms a single path component).
replaced_checkpoint = checkpoint.replace('/', '-')
tokenized_dataset_path = f'data/{replaced_checkpoint}_tokenized_dataset'

if os.path.exists(tokenized_dataset_path):
    # A directory written by `save_to_disk` must be read back with
    # `load_from_disk`; `load_dataset` is meant for Hub datasets and raw data
    # files and does not understand this on-disk layout.
    tokenized_dataset = DatasetDict.load_from_disk(tokenized_dataset_path)
else:
    # Tokenize every split in batches and drop the raw text columns so only
    # the model inputs remain.
    tokenized_dataset = dataset_dict.map(
        tokenize,
        batched=True,
        remove_columns=['form', 'corrected_form'],
        batch_size=512,
    )
    tokenized_dataset.save_to_disk(tokenized_dataset_path)

In [None]:
# Display one tokenized training example (row index 10) as a sanity check.
tokenized_dataset['train'][10]

In [None]:
# Deliberate stop: keeps "Run All" from continuing into the training section.
# Run the cells below manually (or remove this cell) to start training.
raise RuntimeError('Execution stopped on purpose before the training section.')

## Training

In [None]:
# Batch collator for seq2seq features; it receives the tokenizer and the model
# and is configured to return PyTorch tensors.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    return_tensors='pt',
)

In [None]:
# Hyper-parameters for the Trainer-based path, collected in one mapping so they
# can be inspected or tweaked before the arguments object is built.
training_kwargs = {
    'output_dir': "./results",
    'evaluation_strategy': "epoch",
    'learning_rate': 2e-5,
    'per_device_train_batch_size': 64,
    'per_device_eval_batch_size': 64,
    'num_train_epochs': 2,
    'weight_decay': 0.01,
    'report_to': "tensorboard",
    'push_to_hub': False,
}
training_args = Seq2SeqTrainingArguments(**training_kwargs)

In [None]:
# The Trainer must be fed the tokenized splits: the raw `dataset_dict` splits
# only contain the text columns ('form', 'corrected_form') and none of the
# model inputs (input_ids / attention_mask / labels).
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['valid'],
    data_collator=data_collator,
)

In [None]:
# Resolve the device before branching so that later cells can rely on `device`
# even when training is delegated to the Trainer.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

if not MANUAL_TRAINING:
    trainer.train()
else:
    # Hyper-parameters of the manual loop.
    NUM_EPOCHS = 2
    TRAIN_BATCH_SIZE = 64
    LEARNING_RATE = 2e-5

    total_loss_lt = []  # every batch loss of the whole run, in order

    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
    trainset = tokenized_dataset['train'].with_format("torch", device=device)
    dataloader = DataLoader(trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
    model.to(device)

    model.train()
    for epoch in range(NUM_EPOCHS):
        # Reset per epoch: otherwise the losses of earlier epochs are appended
        # to `total_loss_lt` again and distort the per-epoch mean.
        batch_loss_lt = []

        for batch in dataloader:
            X = {
                    'input_ids': batch['input_ids'],
                    'attention_mask': batch['attention_mask'],
                }
            # Labels are padded to a fixed length with the pad token. Replace
            # those positions with -100 so the loss ignores them instead of
            # being dominated by padding.
            y = batch['labels'].masked_fill(batch['labels'] == tokenizer.pad_token_id, -100)
            outputs = model(**X, labels=y)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            batch_loss_lt.append(loss.item())

        total_loss_lt += batch_loss_lt
        # Release cached memory once per epoch rather than after every batch.
        gc.collect()
        torch.cuda.empty_cache()
        batch_loss_series = pd.Series(batch_loss_lt)
        print(f'epoch {epoch + 1} loss: {loss.item()} mean: {batch_loss_series.mean()}')

    # Loss curve over all training steps.
    total_loss_series = pd.Series(total_loss_lt)
    total_loss_series.plot.line()

## Validation

In [None]:
# Resolve the device here as well so this cell does not depend on which branch
# the training cell took.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

if not MANUAL_VALIDATION:
    # Evaluate on the tokenized split (the raw split has no model inputs) and
    # show the returned metrics instead of discarding them.
    pprint(trainer.evaluate(tokenized_dataset['valid']))
else:
    loss_lt = []

    model.to(device)
    model.eval()
    validset = tokenized_dataset['valid'].with_format("torch", device=device)
    dataloader = DataLoader(validset, batch_size=1, shuffle=True)

    try:
        with torch.no_grad():
            for batch in dataloader:
                X = {
                        'input_ids': batch['input_ids'],
                        'attention_mask': batch['attention_mask'],
                    }
                # Ignore padded label positions in the loss (-100 is skipped).
                y = batch['labels'].masked_fill(batch['labels'] == tokenizer.pad_token_id, -100)
                outputs = model(**X, labels=y)
                loss = outputs.loss
                loss_lt.append(loss.item())
    except KeyboardInterrupt:
        # Validation may be stopped by hand; the partial result is still reported.
        print(f'Validation interrupted after {len(loss_lt)} batches; reporting partial loss.')
    except Exception as err:
        # Stay best-effort, but make the failure visible instead of hiding it.
        print(f'Validation aborted after {len(loss_lt)} batches: {err!r}')

    # Release cached memory once, after the loop.
    gc.collect()
    torch.cuda.empty_cache()

    # An explicit dtype keeps the mean well-defined (NaN) when no batch was evaluated.
    loss_series = pd.Series(loss_lt, dtype='float64')
    print(f'loss: {loss_series.mean()}')

In [None]:
# Qualitative check: correct one random validation example and compare the
# model output with the ground truth.

# Use the device the model currently lives on, so this cell works no matter
# which training / validation branch ran before it.
model_device = next(model.parameters()).device
# Generation must not run with dropout active.
model.eval()

# Upper bound for the generated sequence, matching the tokenization length.
# Without it `generate` falls back to its short default limit and truncates
# the corrected sentence.
GENERATION_MAX_LENGTH = 512

validset = tokenized_dataset['valid'].with_format("torch", device=model_device)
test_sample = validset.shuffle().select(range(1))[0]
test_sample_gt = test_sample['labels']

# `generate` expects a batch dimension.
test_sample_input = {
    'input_ids': test_sample['input_ids'].unsqueeze(0),
    'attention_mask': test_sample['attention_mask'].unsqueeze(0),
}
output = model.generate(**test_sample_input, max_length=GENERATION_MAX_LENGTH)

# Skip special tokens so the printed text is not buried under padding tokens.
input_text = tokenizer.decode(test_sample_input['input_ids'].squeeze(0), skip_special_tokens=True)
output_text = tokenizer.decode(output.squeeze(0), skip_special_tokens=True)
gt_text = tokenizer.decode(test_sample_gt, skip_special_tokens=True)

print(input_text, output_text, gt_text, sep='\n\n')

## Saving

In [None]:
# To prevent unwanted saves: this cell deliberately stops "Run All" here.
# Run the saving cell below manually when the model should be persisted.
raise RuntimeError('Execution stopped on purpose to prevent unwanted saves.')

In [None]:
# Timestamp used as the save directory name. It contains no ':' because that
# character is not allowed in file names on Windows.
NOW_STR = datetime.datetime.now().strftime('%y%m%d-%H%M')

# `create_model_card` takes the model's name as `model_name`; it has no `model`
# keyword, so passing `model=` raises a TypeError.
trainer.create_model_card(
    language='Korean',
    tags='Grammar',
    model_name='KoGrammar',
    finetuned_from=checkpoint
)
trainer.save_model(f"./models/{NOW_STR}")