In [7]:
import datetime
import shutil
from pathlib import Path

import numpy as np
import pandas as pd

import torch.nn as nn
import torch.nn.functional as F

from datasets import load_dataset, DatasetDict
from transformers import DataCollatorForSeq2Seq
from transformers import BartConfig, T5Config
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

In [None]:
# Text columns used for training: the input sentence and its corrected form.
TEXT_COLUMNS = ['original_form', 'corrected_form']

# `split='train'` returns a single `Dataset` rather than a `DatasetDict`, which
# is what `train_test_split` in the next cell needs.
# `field` is not a column selector. It takes ONE string naming a top-level key
# under which the records are nested. If the JSON file is laid out that way,
# add `field='<key>'` here (NOTE(review): confirm the file layout).
dataset = load_dataset(
    'json',
    data_files='./data/grammar-correction.json',
    split='train',
    encoding='utf-8',
)

missing_columns = set(TEXT_COLUMNS) - set(dataset.column_names)
assert not missing_columns, f'missing expected columns: {missing_columns}'

# Keep only the two text columns so later tokenization sees a known schema.
dataset = dataset.remove_columns(
    [col for col in dataset.column_names if col not in TEXT_COLUMNS]
)

len(dataset)

In [None]:
# Split 90 / 5 / 5 into train / valid / test. First hold out 10%, then halve
# the held-out part. A fixed seed makes the split reproducible across runs.
SPLIT_SEED = 42

# Accept either a single Dataset or a DatasetDict (load_dataset without
# `split=` returns the latter, which has no `train_test_split`).
source_dataset = dataset['train'] if isinstance(dataset, DatasetDict) else dataset

train_testvalid = source_dataset.train_test_split(test_size=0.1, seed=SPLIT_SEED)
test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=SPLIT_SEED)
dataset_dict = DatasetDict({
    'train': train_testvalid['train'],
    'valid': test_valid['train'],
    'test': test_valid['test'],
})

dataset_dict

In [8]:
# Candidate pretrained Korean seq2seq checkpoints (loaded via `from_pretrained`).
# `checkpoint` selects the one used by every cell below.
kot5_checkpoint = 'psyche/KoT5'
kobart_checkpoint = 'gogamza/kobart-base-v2'
checkpoint = kobart_checkpoint

In [None]:
# Architecture config for the selected checkpoint.
# A bare `BartConfig()` / `T5Config()` holds library defaults (vocabulary size,
# layer counts, ...) that do not describe KoBART / KoT5, so load the
# checkpoint's own config instead.
# NOTE(review): `config` is not passed to the model in this notebook; it is kept
# for inspection / future customization.
if checkpoint == kobart_checkpoint:
    config = BartConfig.from_pretrained(checkpoint)
else:
    config = T5Config.from_pretrained(checkpoint)

In [9]:
# `model_max_length` is the tokenizer init argument that caps sequence length.
# Truncation is a per-call option (`truncation=True` when tokenizing), so it
# does not belong here. The former `trunicated=True` was a misspelling and not
# a tokenizer option.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, model_max_length=512)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.


In [None]:
# Seq2seq collator: builds padded batches for the trainer. It is given the
# model as well as the tokenizer.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
# Training hyperparameters. Checkpoints and logs go to ./results.
# NOTE(review): newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy`; update the keyword if the installed version rejects it.
training_args = Seq2SeqTrainingArguments(
    # Output / hub
    output_dir="./results",
    push_to_hub=False,
    # Optimization
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=1,
    # Batching
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Evaluate on the validation split once per epoch
    evaluation_strategy="epoch",
)

In [None]:
MAX_SOURCE_LENGTH = 512
MAX_TARGET_LENGTH = 512


def preprocess(batch):
    """Tokenize a batch of examples for seq2seq training.

    `original_form` is tokenized as the model input. `corrected_form` is
    tokenized as the target, and its token ids are stored under `labels`, the
    key the seq2seq collator pads and the model uses for the loss.
    """
    model_inputs = tokenizer(
        batch['original_form'],
        max_length=MAX_SOURCE_LENGTH,
        truncation=True,
    )
    targets = tokenizer(
        text_target=batch['corrected_form'],
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
    )
    model_inputs['labels'] = targets['input_ids']
    return model_inputs


# The model cannot consume raw text columns, so drop them after tokenization.
# Only input_ids / attention_mask / labels remain.
tokenized_dict = dataset_dict.map(
    preprocess,
    batched=True,
    remove_columns=dataset_dict['train'].column_names,
)

trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_dict['train'],
    eval_dataset=tokenized_dict['valid'],
    data_collator=data_collator,
)

In [None]:
# Fine-tune for `num_train_epochs`. With evaluation_strategy="epoch", the
# trainer's eval dataset is also evaluated at the end of each epoch.
trainer.train()

In [None]:
# Evaluate on the trainer's own (preprocessed) eval dataset. Passing a raw-text
# split here would bypass tokenization.
trainer.evaluate()

In [None]:
# Guard cell: deliberately stops "Run All" here so the next cell does not save
# a model unintentionally. Run the save cell manually when results look good.
raise RuntimeError('Stopped before saving the model; run the next cell manually to save it.')

In [None]:
# Timestamped save directory. The format has no ':' because that character is
# not allowed in Windows paths.
NOW_STR = datetime.datetime.now().strftime('%y%m%d-%H%M')
save_dir = Path('./models') / NOW_STR

trainer.save_model(str(save_dir))

# `create_model_card` takes `model_name` (there is no `model` argument) and an
# ISO 639-1 language code. It writes README.md into `trainer.args.output_dir`,
# so copy the card next to the saved weights when it was produced.
trainer.create_model_card(
    language='ko',
    tags='Grammar',
    model_name='KoGrammar',
    finetuned_from=checkpoint,
)
card_path = Path(trainer.args.output_dir) / 'README.md'
if card_path.exists():
    shutil.copy(card_path, save_dir / 'README.md')

save_dir