In [None]:
# Detect whether the notebook runs on Google Colab: the probe is simply
# whether the `google.colab` package can be imported.
try:
    import google.colab  # noqa: F401 -- imported only to probe availability
    IN_COLAB = True
except ImportError:
    # Only a missing package means "not Colab". A bare `except:` would also
    # hide unrelated failures (including KeyboardInterrupt).
    IN_COLAB = False

# Last expression of the cell: displayed as the cell output.
'Process in Colab' if IN_COLAB else 'Process in Local'

In [None]:
# Install the third-party dependencies, but only when running on Colab
# (IN_COLAB is set by the detection cell above).
# NOTE(review): presumably local environments already provide these packages;
# versions are not pinned, so results may differ between runs -- confirm.
if IN_COLAB:
    !pip install transformers
    !pip install datasets
    !pip install --upgrade accelerate

In [None]:
# On Colab, mount Google Drive at /content/drive/ so that files stored there
# are reachable from the notebook.
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive/')

In [None]:
# 깃허브에서는 빼야됨
%cd drive/MyDrive/projects/ClauseSummary

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import datetime
import re
import os
import gc
from pprint import pprint
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
from torch.optim import AdamW, SGD
from torch.nn import MSELoss
from torch.utils.data import DataLoader

from datasets import load_from_disk, load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, LongformerTokenizer, AutoModel, AutoModelForMaskedLM
from transformers import get_linear_schedule_with_warmup
from transformers import Trainer, TrainingArguments

In [None]:
class BaseTokenizeMapWrapper:
    """Callable that tokenizes one column of a row; usable with ``Dataset.map``.

    Previously this class was also called ``TokenizeMapWrapper`` and was
    shadowed by its own subclass; it now has a distinct name.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, **option)``.
        feature: Key of the row entry that holds the text.
        option: Keyword arguments forwarded to the tokenizer. Defaults to
            truncating and padding to 512 tokens.
    """

    def __init__(self, tokenizer, feature, option=None):
        if option is None:
            option = {
                'max_length': 512,
                'truncation': True,
                'padding': 'max_length',
            }

        # The options used to be built but never stored, so __call__ failed
        # with AttributeError on `self.option`.
        self.option = option
        self.feature = feature
        self.tokenizer = tokenizer

    def __call__(self, row):
        """Tokenize ``row[self.feature]`` with the stored options."""
        return self.tokenizer(row[self.feature], **self.option)

    def __repr__(self):
        return f'{self.__class__.__name__}(tokenizer={self.tokenizer})'


class TokenizeMapWrapper(BaseTokenizeMapWrapper):
    """Tokenize "text + summary" documents so that the summary is never cut off.

    Each document is split at the first ``'Summary: \\n'`` marker. The text and
    the summary are tokenized separately, concatenated, and brought to exactly
    ``max_token`` entries: when the pair is too long the *text* is shortened,
    when it is too short the result is padded.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, **option)`` that
            returns a mapping of lists (e.g. ``input_ids``, ``attention_mask``).
        feature: Key of the row entry that holds the document(s).
        max_token: Length used when ``option`` is omitted or carries no
            ``'max_length'`` entry.
        option: Keyword arguments forwarded to the tokenizer. Defaults to
            truncation at ``max_token``.

    NOTE(review): text and summary are tokenized independently, so any special
    tokens the tokenizer adds appear once per part -- confirm this is intended.
    """

    def __init__(self, tokenizer, feature, max_token=4096, option=None):
        if option is None:
            option = {
                'max_length': max_token,
                'truncation': True,
            }
        super().__init__(tokenizer, feature, option)

        # The original looked up option['max_new_tokens'], a key that is never
        # set (KeyError). The length limit lives under 'max_length'.
        self.max_token = option.get('max_length', max_token)

    def __call__(self, row):
        """Tokenize one document, or a batch of documents.

        ``row[self.feature]`` may be a single string (``batched=False``) or a
        sequence of strings (``batched=True``); in the latter case every value
        of the returned mapping is a list with one entry per document.
        """
        value = row[self.feature]
        if isinstance(value, str):
            return self._tokenize_document(value)

        encoded_documents = [self._tokenize_document(document) for document in value]
        if not encoded_documents:
            return {}
        return {
            key: [encoded[key] for encoded in encoded_documents]
            for key in encoded_documents[0]
        }

    def _pad_value(self, key):
        """Return the filler for ``key``: the pad token id for ids, 0 for masks."""
        if key == 'input_ids':
            pad_id = getattr(self.tokenizer, 'pad_token_id', None)
            # Fall back to 1, the filler the original implementation used.
            return 1 if pad_id is None else pad_id
        # Padding positions must not be attended to, hence 0 (the original
        # padded attention_mask with 1).
        return 0

    def _split_document(self, total_text):
        """Split a document into (text, summary) at the first summary marker."""
        marker_count = total_text.count('\nSummary: \n')
        if marker_count != 1:
            print(f'warning: expected exactly one summary marker, found {marker_count}')

        text, marker, remainder = total_text.partition('Summary: \n')
        # Without any marker the summary part is empty, as before.
        summary = '\nSummary: \n' + remainder if marker else ''
        return text, summary

    def _tokenize_document(self, total_text):
        """Tokenize one document into sequences of exactly ``max_token`` entries."""
        text, summary = self._split_document(total_text)
        tokenized_text = self.tokenizer(text, **self.option)
        tokenized_summary = self.tokenizer(summary, **self.option)

        # Reserve room for the whole summary; the text gets what is left.
        summary_length = len(tokenized_summary['input_ids'])
        text_budget = max(self.max_token - summary_length, 0)

        tokenized_total_text = dict()
        for key in tokenized_text:
            merged = list(tokenized_text[key][:text_budget]) + list(tokenized_summary[key])
            # Only relevant when the summary alone exceeds the limit.
            merged = merged[:self.max_token]
            missing = self.max_token - len(merged)
            if missing > 0:
                merged = merged + [self._pad_value(key)] * missing
            tokenized_total_text[key] = merged

        return tokenized_total_text

In [None]:
class ModelForRewardGeneration(nn.Module):
    """Scalar reward regressor: an encoder followed by a small MLP head.

    The head reads the hidden state of the first token of every sequence and
    maps it to a single value.

    Args:
        encoder: Module called as ``encoder(input_ids, attention_mask)`` whose
            result supports ``['last_hidden_state']``.
        hidden_size: Width of the intermediate layer of the head.

    Attributes:
        config: The encoder's ``config`` (``None`` if it has none), exposed so
            that ``model.config.vocab_size``-style checks work on the wrapper.
    """

    def __init__(self, encoder, hidden_size):
        super().__init__()
        self.encoder = encoder
        self.hidden_size = hidden_size
        self.config = getattr(encoder, 'config', None)

        # Take the input width from the encoder instead of hard-coding 768;
        # 768 remains the fallback when the encoder does not advertise it.
        encoder_width = getattr(self.config, 'hidden_size', 768)

        self.head = nn.Sequential(
            nn.Linear(encoder_width, hidden_size, bias=False),
            nn.BatchNorm1d(hidden_size),
            nn.GELU(),
            # nn.Dropout1d treats a 2-D (batch, features) input as an unbatched
            # (channels, length) tensor and zeroes whole rows, i.e. entire
            # samples. Element-wise dropout is what this head needs.
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, input_ids=None, attention_mask=None):
        """Return one reward per sequence, shaped ``(batch, 1)``."""
        encoded = self.encoder(input_ids, attention_mask)
        first_token_state = encoded['last_hidden_state'][:, 0, :]
        return self.head(first_token_state)

def reference_reward_loss(reward, pred):
    """Element-wise ``-log10(1 + exp(-reward * pred))``.

    Computed as ``logsigmoid(reward * pred) / ln(10)``, which is the same
    quantity but does not overflow: the naive form evaluates ``exp`` of a large
    positive number for strongly negative margins, yielding ``-inf`` and NaN
    gradients.

    Args:
        reward: Tensor of reference rewards.
        pred: Tensor of predictions, broadcastable against ``reward``.

    Returns:
        Tensor of the broadcast shape; no reduction is applied.

    NOTE(review): the value is never positive (it is a log-sigmoid), so it
    would have to be negated before being minimised as a loss -- confirm the
    intended sign with the author.
    """
    import math  # local import: the notebook's import cell does not provide it

    margin = reward * pred
    if not torch.is_floating_point(margin):
        # torch.exp promoted integer inputs to float; keep accepting them.
        margin = margin.to(torch.get_default_dtype())
    return torch.nn.functional.logsigmoid(margin) / math.log(10.0)

## Setting

In [None]:
# Run-level switches. NOTE(review): MANUAL_TRAINING and MANUAL_VALIDATION are
# not read anywhere in the visible part of the notebook -- confirm they are
# still needed.
MANUAL_TRAINING = True
MANUAL_VALIDATION = True
# The training loop divides the number of steps per epoch by these two values
# to decide how often to save a checkpoint and how often to record statistics.
MID_CHECKPOINT_NUM = 2
MID_PROCESS_PRINT_NUM = 35

# Optimisation settings: epochs, learning rate and weight decay are passed to
# AdamW, the LR scheduler and TrainingArguments below.
NUM_EPOCHS = 1
# NOTE(review): MAX_TOKEN is not referenced in the visible code (the tokenizer
# wrapper has its own max_token default of 4096) -- confirm.
MAX_TOKEN = 4096
learning_rate = 2e-5
decay = 0.01

In [None]:
# Identifier of the pretrained checkpoint that the tokenizer and the encoder
# are loaded from; `checkpoint` is the name used by the rest of the notebook.
kolongformer_checkpoint = "psyche/kolongformer-4096"
checkpoint = kolongformer_checkpoint
print(f'Using Checkpoint: {checkpoint}')

In [None]:
# Raw dataset (JSON) and the on-disk cache of its tokenized form. The cache
# name embeds the checkpoint id, with "/" replaced so it is a single path
# component.
original_dataset_path = './data/dataset-term-reward.json'
tokenized_dataset_path = f'./data/{checkpoint.replace("/", "-")}-tokenized-dataset'

In [None]:
# Timestamp that tags everything saved by this run. No ":" between hours and
# minutes: colons are invalid in Windows file names and awkward in paths on
# other platforms.
SAVE_STR = datetime.datetime.now().strftime('%y%m%d-%H%M')
model_save_path = f"./model/{SAVE_STR}"

## Load Tokenizer & Model Checkpoint

In [None]:
# Load the tokenizer that belongs to the chosen checkpoint.
tokenizer = LongformerTokenizer.from_pretrained(checkpoint)
#tokenizer = LongformerTokenizer(vocab_file, merges_file, errors='replace', bos_token='<s>', eos_token='</s>', sep_token='</s>', cls_token='<s>', unk_token='<unk>', pad_token='<pad>', mask_token='<mask>', add_prefix_space=False, **kwargs)

# Load the pretrained encoder and wrap it with the reward head
# (256 is the width of the head's intermediate layer).
encoder_model = AutoModel.from_pretrained(checkpoint)
model = ModelForRewardGeneration(encoder_model, 256)

In [None]:
# Sanity check: the tokenizer must not produce ids outside the encoder's
# vocabulary. The size is read from the encoder's config because `model` is a
# plain nn.Module wrapper and is not guaranteed to have a `config` attribute.
encoder_vocab_size = encoder_model.config.vocab_size
if len(tokenizer) != encoder_vocab_size:
    raise RuntimeError(f'Tokenizer vocab size and model vocab size do not match(Tokenizer:{len(tokenizer)} Model: {encoder_vocab_size}). Which would lead to further error in training.')

## Loading Dataset

In [None]:
# Build the tokenized dataset once and cache it on disk; later runs load the
# cache instead of re-tokenizing.
if not os.path.exists(tokenized_dataset_path):
    df = pd.read_json(original_dataset_path, encoding='utf-8')
    # Scalar + Series concatenation: unlike adding freshly built pd.Series
    # objects, this does not depend on df having a 0..n-1 index.
    df['total_text'] = 'Text: \n' + df['text'] + '\nSummary: \n' + df['summary']
    dataset = Dataset.from_pandas(df[['total_text', 'reward']], preserve_index=False)
    tokenizer_wrapper = TokenizeMapWrapper(tokenizer, 'total_text')

    # The wrapper processes one document per call, so map row by row
    # (with batched=True it would receive a list of strings instead).
    tokenized_dataset = (
        dataset
        .map(tokenizer_wrapper, batched=False, num_proc=10)
        .remove_columns(['total_text'])
    )

    # Fixed seed so that the train/test split is reproducible.
    tokenized_dataset_dict = tokenized_dataset.train_test_split(test_size=0.1, shuffle=True, seed=42)
    tokenized_dataset_dict.save_to_disk(tokenized_dataset_path)
else:
    tokenized_dataset_dict = load_from_disk(tokenized_dataset_path)

## Training

In [None]:
# Echo the checkpoint id once more right before training starts.
print(checkpoint)

In [None]:
# Hugging Face Trainer configuration. In the visible part of this notebook the
# Trainer is only used for `create_model_card` and `save_model`; the
# optimisation itself is done by the manual loop further down.
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy = "epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=decay,
    report_to="tensorboard",
    push_to_hub=False,
)

# No datasets are passed here: `trainer.train()` is never called below.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args
)

In [None]:
## finding the best parameters
# Release whatever earlier cells or runs left behind before allocating anew.
gc.collect()
torch.cuda.empty_cache()

# Loss bookkeeping: every batch of the run / of the current epoch / since the
# last statistics record.
total_loss = []
epoch_loss = []
batch_loss = []

model.train()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The batches are created on `device`, so place the model there explicitly
# instead of relying on another object having moved it.
model.to(device)
trainset = tokenized_dataset_dict['train'].with_format('torch', device=device)
testset = tokenized_dataset_dict['test'].with_format('torch', device=device)
dataloader = DataLoader(trainset, batch_size=12, shuffle=False) # TODO: tune the batch size

# TODO: Minor Hyperparameter Tuning
criterion = MSELoss()
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=decay)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=300, num_training_steps=NUM_EPOCHS * len(dataloader))
# One dict per statistics record; turned into a DataFrame after training.
training_stats = []

In [None]:
# Manual training loop: one optimiser and scheduler step per batch, a
# statistics record every `print_divisor` batches and a checkpoint every
# `save_divisor` batches.
for epoch in range(NUM_EPOCHS):
    total_steps = len(dataloader)
    # max(1, ...) keeps the modulo checks below from dividing by zero when an
    # epoch has fewer steps than the requested number of records/checkpoints.
    save_divisor = max(1, total_steps // MID_CHECKPOINT_NUM)
    print_divisor = max(1, total_steps // MID_PROCESS_PRINT_NUM)
    with tqdm(dataloader, leave=False, desc='Batch', position=0, postfix={'Epoch': epoch + 1, 'Batch': 1, 'loss': 0, 'loss_mean': 0}) as tqdm_bar:
        for i, batch in enumerate(tqdm_bar):
            tqdm_bar.set_description(f'Batch: {i + 1}')
            X = {
                    'input_ids': batch['input_ids'],
                    'attention_mask': batch['attention_mask'],
                }
            # MSELoss needs floating-point targets.
            y = batch['reward'].float()

            # The model returns (batch, 1) while the targets are (batch,).
            # Without the squeeze MSELoss broadcasts both to (batch, batch)
            # and compares every prediction with every target.
            outputs = model(**X).squeeze(-1)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            batch_loss.append(loss.item())

            if i % print_divisor == print_divisor - 1:
                epoch_loss += batch_loss
                batch_loss_series = pd.Series(batch_loss, dtype=np.float64)
                training_stats.append(
                    {
                        'Epoch': epoch + 1,
                        'Batch': i + 1,
                        'loss': loss.item(),
                        'loss_mean': batch_loss_series.mean()
                    }
                )
                tqdm_bar.set_postfix(training_stats[-1])
                batch_loss = []

            if i % save_divisor == save_divisor - 1:
                trainer.create_model_card(
                    language='Korean',
                    finetuned_from=checkpoint
                )
                # Separate the run prefix from the checkpoint id and flatten
                # the "/" in the id, so every checkpoint lands next to the
                # "-final" model instead of in an accidental sub-directory.
                trainer.save_model(
                    f"{model_save_path}-{checkpoint.replace('/', '-')}"
                    f"-epoch-{epoch + 1}-batch-{i + 1}"
                )

    # End of epoch (this bookkeeping used to run after every batch): flush the
    # batches recorded since the last statistics record, then fold the epoch
    # into the run-level history.
    epoch_loss += batch_loss
    batch_loss = []
    total_loss += epoch_loss
    epoch_loss = []

In [None]:
# Persist the statistics recorded during training as a CSV file.
training_stats_df = pd.DataFrame(training_stats)
training_stats_df.to_csv('./training_stats.csv')

# Write the model card and save the final model under "<run prefix>-final".
trainer.create_model_card(
    language='Korean',
    finetuned_from=checkpoint
)
trainer.save_model(model_save_path + '-final')

## Analysis

In [None]:
# Summary statistics of the values recorded during training.
training_stats_df.describe()

In [None]:
# Loss of every recorded batch over the whole run. The Series gets its own
# name: rebinding `total_loss` (a list that the training loop extends with
# `+=`) would make both this cell and the training cell unsafe to re-run.
total_loss_series = pd.Series(total_loss, dtype=np.float64)
total_loss_series.plot.line()

In [None]:
# Mean loss of each statistics window, in recording order.
training_stats_df['loss_mean'].plot.line()

In [None]:
# The training loop above records only Epoch/Batch/loss/loss_mean, so
# selecting BLEU and ROUGE unconditionally raises KeyError. Plot whichever of
# the two metrics is actually present.
available_metrics = [name for name in ('BLEU', 'ROUGE') if name in training_stats_df.columns]
if available_metrics:
    training_stats_df[available_metrics].plot.line()
else:
    print('No BLEU/ROUGE columns were recorded in training_stats; nothing to plot.')