In [None]:
# Detect whether the notebook runs on Google Colab by probing for the
# `google.colab` package; IN_COLAB gates the install / drive-mount cells below.
try:
    import google.colab  # noqa: F401 -- imported only to test availability
    IN_COLAB = True
except ImportError:
    # Only a failed import means "not on Colab"; a bare `except:` would also
    # swallow KeyboardInterrupt and unrelated errors.
    IN_COLAB = False

# Bare last expression so the cell displays which environment was detected.
'Process in Colab' if IN_COLAB else 'Process in Local'

In [None]:
# Install the third-party packages imported further down, but only when
# IN_COLAB is true. pfrl is installed from a GitHub fork instead of PyPI.
# NOTE(review): no versions are pinned, so a re-run may pick up different
# releases; `%pip` is the recommended form inside notebooks -- consider both.
if IN_COLAB:
    !pip install transformers
    !pip install datasets
    !pip install --upgrade accelerate
    !pip install pfrl@git+https://github.com/voidful/pfrl.git
    !pip install textrl


In [None]:
# On Colab, mount Google Drive under /content/drive/ (the next cell changes
# into a project directory below it).
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive/')

In [None]:
# Should be removed from the version published on GitHub.
# Changes the working directory to the project folder; the relative paths
# used later (./data, ./model) are resolved from here.
# NOTE(review): this path is relative and the cell is not guarded by IN_COLAB,
# so it also runs (and presumably fails) outside Colab -- confirm.
%cd drive/MyDrive/projects/ClauseSummary

In [None]:
from typing import Dict
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import datetime
import re
import sys
import os
import gc
import logging
from pprint import pprint
from tqdm.notebook import tqdm
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s %(message)s', datefmt='%m-%d %H:%M')

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
from torch.optim import AdamW, SGD
from torch.nn import MSELoss
from torch.utils.data import DataLoader

from datasets import load_from_disk, load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, LongformerTokenizer, AutoModel, AutoModelForSeq2SeqLM
from transformers import get_linear_schedule_with_warmup
from transformers import Trainer, TrainingArguments

import pfrl
from textrl import TextRLEnv, TextRLActor, train_agent_with_evaluation
import logging


In [None]:
def add_newline_before_number(text:str) -> str:
    """Put a newline in front of every ``<digits>.`` marker.

    The argument is converted with ``str()`` first, so non-string values are
    accepted as well.
    """
    numbered_marker = re.compile(r'(\d+)\.')
    return numbered_marker.sub(r'\n\1.', str(text))

def change_it_to_a_comma(text:str):
    """Turn ``(1)``, ``(2)``, ... markers into comma separators.

    The text is split on every ``(<digits>)`` marker. The part before the
    first marker is glued onto the part after it (so the first marker simply
    disappears) and the remaining parts are joined with ``,``.
    """
    head, *rest = re.split(r'\(\d+\)', text)
    if not rest:
        # No marker at all: the text comes back unchanged.
        return head
    rest[0] = head + rest[0]
    return ','.join(rest)

def remove_whitespace_after_str(text:str):
    """Drop the single whitespace character that follows 갑, 을, 병 or 정.

    Only occurrences that start at a word boundary are affected. The four
    characters are handled one after another, in this order; the passes are
    deliberately sequential because an earlier removal can change where the
    word boundaries are for a later pass.
    """
    for party in ('갑', '을', '병', '정'):
        text = re.sub(rf"\b{party}\s", party, text)
    return text

def change_number_point(text:str):
    """Delete every ``<digits>.`` marker (``1.``, ``2.``, ...) from ``text``.

    Nothing is inserted in place of a marker; the surrounding text is simply
    concatenated.
    """
    return re.sub(r'\d+\.', '', text)

def summary_preprocessing_func(text: str):
    """Run the summary clean-up steps over ``text`` and return the result.

    Steps, applied in this order: add_newline_before_number,
    change_it_to_a_comma, remove_whitespace_after_str, change_number_point.
    """
    steps = (
        add_newline_before_number,
        change_it_to_a_comma,
        remove_whitespace_after_str,
        change_number_point,
    )
    for step in steps:
        text = step(text)
    return text

def text_preprocessing_func(text):
    """Collapse a newline followed by any run of newlines/spaces into one newline."""
    blank_run = re.compile(r'\n[\n ]+')
    return blank_run.sub('\n', text)

In [None]:
def preprocessing(row: Dict[str, str]):
    """Clean one record and return a new dict with the same two keys.

    ``row['text']`` goes through text_preprocessing_func and ``row['summary']``
    through summary_preprocessing_func; ``row`` itself is not modified.
    """
    return {
        'text': text_preprocessing_func(row['text']),
        'summary': summary_preprocessing_func(row['summary']),
    }

def df_preprocessing(df: pd.DataFrame):
    """Apply ``preprocessing`` to every row's 'text' and 'summary' columns.

    The cleaned values are written back into ``df`` itself (the caller's frame
    is modified in place) and the same frame is returned.
    """
    columns = ['text', 'summary']
    # result_type='expand' turns the dict returned per row into columns.
    cleaned = df[columns].apply(preprocessing, axis=1, result_type='expand')
    df[columns] = cleaned[columns]
    return df

In [None]:
class TokenizeMapWrapper:
    """Callable that tokenizes one feature of a row, e.g. for ``Dataset.map``.

    Args:
        tokenizer: callable invoked as ``tokenizer(row[feature], **option)``.
        feature: key of the row entry to tokenize.
        option: keyword arguments forwarded to the tokenizer. When ``None``,
            the input is truncated and padded to 4096 tokens.
    """
    def __init__(self, tokenizer, feature, option=None):
        if option is None:
            option = {
                'max_length': 4096,
                'truncation': True,
                'padding': 'max_length',
            }

        # Fix: the options were never stored on the instance, so __call__
        # raised AttributeError on `self.option`.
        self.option = option
        self.feature = feature
        self.tokenizer = tokenizer

    def __call__(self, row):
        """Return the tokenizer output for ``row[self.feature]``."""
        return self.tokenizer(row[self.feature], **self.option)

    def __repr__(self):
        return f'{self.__class__.__name__}(tokenizer={self.tokenizer})'

class RewardTokenizeMapWrapper(TokenizeMapWrapper):
    """Tokenize a "<text>\\nSummary: \\n<summary>" string into one fixed-length sequence.

    The text and the summary are tokenized separately and concatenated. When
    the pair does not fit into ``max_token`` tokens, the tail of the *text* is
    cut so the summary is kept whole; shorter results are right-padded.

    Args:
        tokenizer: callable invoked as ``tokenizer(string, **option)``; its
            result must map keys to lists (``'input_ids'`` is required).
        feature: key of the row entry holding the combined string.
        max_token: target sequence length when ``option`` does not give one.
        option: keyword arguments forwarded to the tokenizer. When ``None``,
            the input is truncated to ``max_token``.
    """
    def __init__(self, tokenizer, feature, max_token=4096, option=None):
        if option is None:
            option = {
                'max_length': max_token,
                'truncation': True,
            }

        # Fix: this used to read option['max_new_tokens'] unconditionally,
        # which raised KeyError for the default options built just above.
        # The legacy key is still honoured when a caller passes it.
        if 'max_new_tokens' in option:
            self.max_token = option['max_new_tokens']
        else:
            self.max_token = option.get('max_length', max_token)
        self.option = option
        self.feature = feature
        self.tokenizer = tokenizer

    def __call__(self, row):
        """Return a dict of lists, each padded (or text-truncated) to ``max_token``."""
        total_text = row[self.feature]
        if len(re.findall('\nSummary: \n', total_text)) == 1:
            # partition() cannot fail to unpack, unlike the previous
            # two-value split on 'Summary: \n', which raised ValueError when
            # an extra 'Summary: \n' without a leading newline was present.
            # The text keeps its trailing newline, as before.
            text, _, summary = total_text.partition('\nSummary: \n')
            text = text + '\n'
            summary = '\nSummary: \n' + summary
        else:
            # NOTE(review): this branch also runs when the marker is missing
            # entirely, and the summary built here has no leading marker,
            # unlike the branch above -- confirm whether that is intended.
            print('warning: more than two summary exists')
            text_split = total_text.split('Summary: \n')
            text = text_split[0]
            summary = '\nSummary: \n'.join(text_split[1:])

        tokenized_text = self.tokenizer(text, **self.option)
        tokenized_summary = self.tokenizer(summary, **self.option)
        summary_len = len(tokenized_summary['input_ids'])
        # Fix: the text used to be cut by the full summary length even when
        # only a few tokens had to go, and that branch was never padded, so
        # results could be shorter than max_token. Now exactly the room left
        # after the summary is kept.
        text_budget = max(self.max_token - summary_len, 0)

        tokenized_total_text = dict()
        for key in tokenized_text:
            merged = tokenized_text[key][:text_budget] + tokenized_summary[key]
            # NOTE(review): every key, attention_mask included, is padded with
            # 1 (presumably the tokenizer's pad id), so padded positions stay
            # attended -- confirm before changing, a trained checkpoint may
            # depend on it.
            merged = merged + [1] * (self.max_token - len(merged))
            tokenized_total_text[key] = merged

        return tokenized_total_text

In [None]:
class ModelForRewardGeneration(nn.Module):
    """Pretrained encoder followed by a small MLP head that outputs one scalar.

    Args:
        encoder_path: name or path handed to ``AutoModel.from_pretrained``.
        hidden_size: width of the head's hidden layer.
    """
    def __init__(self, encoder_path, hidden_size=256):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_path)
        self.hidden_size = hidden_size
        # Read the encoder width from its config instead of hard-coding 768,
        # falling back to the previous value when the config has no such field.
        encoder_dim = getattr(self.encoder.config, 'hidden_size', 768)
        # TODO: design the head
        self.head = nn.Sequential(
            nn.Linear(encoder_dim, hidden_size, bias=False),
            nn.BatchNorm1d(hidden_size),
            nn.GELU(),
            # Fix: nn.Dropout1d treats a 2-D input as an unbatched (C, L)
            # tensor and zeroes whole rows, i.e. entire samples of a
            # (batch, features) input. Element-wise dropout is what a feature
            # vector needs. Neither layer has parameters, so saved head state
            # dicts still load.
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, input_ids=None, attention_mask=None):
        """Return the head output computed from the encoder's ``pooler_output``."""
        # Keyword arguments avoid relying on the encoder's positional order.
        x = self.encoder(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        x = self.head(x)
        return x

def reference_reward_loss(reward, pred):
    """Return ``-log10(1 + exp(-reward * pred))`` element-wise (no reduction).

    Computed through softplus so a large negative ``reward * pred`` no longer
    overflows ``exp`` to ``inf`` (which made the result ``-inf``).

    NOTE(review): the value equals ``log10(sigmoid(reward * pred))`` and is
    always <= 0; minimising it pushes ``reward * pred`` towards negative
    values. Confirm the sign convention with the training code that uses it.
    """
    import math

    margin = reward * pred
    if not torch.is_floating_point(margin):
        # softplus only accepts floating-point input; the old exp/log10 form
        # promoted integer tensors implicitly.
        margin = margin.to(torch.get_default_dtype())
    # softplus(-m) == ln(1 + exp(-m)); dividing by ln(10) converts to log10.
    return -torch.nn.functional.softplus(-margin) / math.log(10.0)

## Setting

In [None]:
# Switches and intervals that are not read by any other cell shown in this
# notebook -- presumably left over from a manual training loop; confirm before
# removing.
MANUAL_TRAINING = True
MANUAL_VALIDATION = True
MID_CHECKPOINT_NUM = 2
MID_PROCESS_PRINT_NUM = 35

# NUM_EPOCHS sizes the LR schedule, MAX_TOKEN is the sequence-length limit put
# into the tokenizer `option` dict, and learning_rate / decay are the AdamW
# learning rate and weight decay.
NUM_EPOCHS = 1
MAX_TOKEN = 4096
learning_rate = 2e-5
decay = 0.01

In [None]:
# Checkpoint names/paths handed to `from_pretrained` / `torch.load` below:
#   reward_model_checkpoint  - source of the reward tokenizer
#   reward_model_path        - prefix of the saved reward model files; the
#                              loading cell appends '-encoder-final' and
#                              '-head-final.pt' to it
#   summary_model_checkpoint - summarization model and tokenizer to fine-tune
reward_model_checkpoint = 'psyche/kolongformer-4096'
reward_model_path = './model/230705-10 59'
summary_model_checkpoint = 'KETI-AIR-Downstream/long-ke-t5-base-summarization'

print(f'reward_model_checkpoint: {reward_model_checkpoint}\nsummary_model_checkpoint: {summary_model_checkpoint}')

In [None]:
# Raw JSON dataset, and the on-disk cache of its tokenized form. The cache
# name embeds the summary checkpoint name ('/' replaced by '-'), so each
# checkpoint gets its own cache directory.
original_dataset_path = './data/dataset-term.json'
tokenized_dataset_path = f'./data/{summary_model_checkpoint.replace("/", "-")}-tokenized-dataset'

In [None]:
# Timestamped output directory for this run (passed as `outdir` to the
# training call at the end of the notebook).
# NOTE(review): the timestamp contains ':', which is not a legal file-name
# character on Windows; the existing reward_model_path uses a space in that
# position -- confirm which format is wanted.
SAVE_STR = datetime.datetime.now().strftime('%y%m%d-%H:%M')
model_save_path = f"./model/{SAVE_STR}"

## Load Tokenizer & Model checkpoint

In [None]:
# Tokenizers: one per model, since the summarizer and the reward model use
# different vocabularies.
summary_tokenizer = AutoTokenizer.from_pretrained(summary_model_checkpoint)
reward_tokenizer = LongformerTokenizer.from_pretrained(reward_model_checkpoint)

summary_model = AutoModelForSeq2SeqLM.from_pretrained(summary_model_checkpoint)

# Build the reward model directly on the fine-tuned encoder. Previously the
# base checkpoint was loaded first and then immediately replaced, which cost
# an extra download and load for weights that were thrown away.
reward_model = ModelForRewardGeneration(reward_model_path + '-encoder-final')
# map_location='cpu' lets a head saved from a GPU run load on a CPU-only
# machine; load_state_dict copies the values into the existing parameters.
reward_model.head.load_state_dict(
    torch.load(reward_model_path + '-head-final.pt', map_location='cpu')
)

## Loading Dataset

In [None]:
df = pd.read_json(original_dataset_path, encoding='utf-8')

# Fix: `dataset` used to be created only when the tokenized cache was missing,
# but a later cell reads it, so every run with an existing cache failed with
# NameError. It is now built unconditionally.
dataset = Dataset.from_pandas(df[['text']])

if not os.path.exists(tokenized_dataset_path):
    # Cache miss: tokenize the 'text' column, split 90/10 and store the result.
    tokenizer_wrapper = TokenizeMapWrapper(summary_tokenizer, 'text')

    tokenized_dataset = (dataset
                         .map(tokenizer_wrapper,
                              batched=True,
                              batch_size=128,
                              num_proc=10
                              )
                         .remove_columns(['text'])
                         )

    # NOTE(review): the split is shuffled without a seed, so it differs
    # whenever the cache is rebuilt.
    tokenized_dataset_dict = tokenized_dataset.train_test_split(test_size=0.1, shuffle=True)
    tokenized_dataset_dict.save_to_disk(tokenized_dataset_path)
else:
    tokenized_dataset_dict = load_from_disk(tokenized_dataset_path)

# The RL environment reads the source document from an 'input' column, so the
# 'text' column is copied under that name and then dropped.
df['input'] = df['text']
df = df.drop(columns=['text'], axis=1)

## Training

In [None]:
## finding the best parameters
# Free Python garbage and cached CUDA memory left over from earlier cells.
gc.collect()
torch.cuda.empty_cache()

# Loss-history lists; not read by any later cell shown in this notebook.
total_loss = []
epoch_loss = []
batch_loss = []

# The summarizer is the model being trained; the reward model only scores, so
# it is switched to eval mode.
summary_model.train()
reward_model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainset = tokenized_dataset_dict['train'].with_format('torch', device=device)
testset = tokenized_dataset_dict['test'].with_format('torch', device=device)
dataloader = DataLoader(trainset, batch_size=12, shuffle=False) # TODO: tune the batch size

# TODO: Minor Hyperparameter Tuning
# NOTE(review): within the cells shown here the dataloader is only used to
# size the LR schedule, and criterion / optimizer / scheduler are not passed
# to the PPO agent created further down -- confirm whether they are still
# needed.
criterion = MSELoss()
optimizer = AdamW(summary_model.parameters(), lr=learning_rate, weight_decay=decay)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=NUM_EPOCHS * len(dataloader))
training_stats = []

In [None]:
# Keyword arguments forwarded to the reward tokenizer when a reward is computed.
# Fix: the length limit was passed as 'max_new_tokens', which is a text
# generation argument and not a tokenizer one, so MAX_TOKEN never reached the
# tokenizer's truncation. The tokenizer argument for this is 'max_length'.
option = {
    'max_length': MAX_TOKEN,
    'truncation': True,
}

In [None]:
class SummaryRLEnv(TextRLEnv):
    """TextRL environment that scores a finished summary with ``reward_model``."""

    def get_reward(self, input_item, predicted_list, finish):
        """Return ``[0]`` until generation is finished, then ``[score * 10]``.

        The source document and the generated summary are tokenized with
        ``reward_tokenizer``, concatenated into one sequence of ``max_token``
        tokens (the tail of the document is cut if needed, the summary is kept
        whole) and passed through ``reward_model``.
        """
        reward = [0]
        if not finish:
            return reward

        # Fix: `self.max_token` is not set anywhere in this notebook; fall
        # back to the notebook-level MAX_TOKEN when the attribute is missing.
        max_token = getattr(self, 'max_token', MAX_TOKEN)

        # Fix: the dataset cell moves the document from the 'text' column to
        # 'input', so input_item['text'] no longer exists at this point.
        source_text = input_item['text'] if 'text' in input_item else input_item['input']

        summary_text = predicted_list[0]
        if not isinstance(summary_text, str):
            # Presumably TextRL passes the prediction as a list of token
            # strings (its README converts it the same way) -- TODO confirm.
            summary_text = summary_tokenizer.convert_tokens_to_string(list(summary_text))

        tokenized_text = reward_tokenizer(source_text, **option)
        tokenized_summary = reward_tokenizer(summary_text, **option)
        summary_len = len(tokenized_summary['input_ids'])
        # Fix: keep exactly the room left after the summary. The old check
        # used `<`, so a pair that fitted exactly still lost text tokens, and
        # the overflow branch cut more text than necessary.
        text_budget = max(max_token - summary_len, 0)

        # Fix: the model was called with plain Python lists. It needs batched
        # tensors on the same device as its weights.
        model_device = next(reward_model.parameters()).device
        model_inputs = dict()
        for key in ('input_ids', 'attention_mask'):
            merged = tokenized_text[key][:text_budget] + tokenized_summary[key]
            # NOTE(review): both keys are right-padded with 1, as before, so
            # padded positions stay attended -- confirm this matches how the
            # reward model was trained.
            merged = merged + [1] * (max_token - len(merged))
            model_inputs[key] = torch.tensor([merged], device=model_device)

        # Scoring only: no gradients are needed through the reward model.
        with torch.no_grad():
            score = reward_model(**model_inputs)
        return [float(score.squeeze()) * 10]

In [None]:
# Fix: df.to_dict() returns {column: {index: value}}, while the environment
# takes a list of per-sample dicts; 'records' yields one dict per row.
# len(df) replaces len(dataset): the row count is the same, but `dataset` is
# not defined when the tokenized cache already exists.
# NOTE(review): compare_sample equal to the whole dataset size looks
# unintended (TextRL's examples use 1 or 2) -- confirm the value.
env = SummaryRLEnv(summary_model, summary_tokenizer, observation_input=df.to_dict('records'), compare_sample=len(df))
actor = TextRLActor(env, summary_model, summary_tokenizer)
agent = actor.agent_ppo(update_interval=100, minibatch_size=3, epochs=3)

In [None]:
# Run PPO training of `agent` in `env`, with the step / episode / evaluation
# limits given below; `outdir` is the timestamped model_save_path.
train_agent_with_evaluation(
    agent=agent,
    env=env,
    steps=300,
    eval_n_steps=None,
    eval_n_episodes=1,
    train_max_episode_len=300,
    eval_interval=10,
    outdir=model_save_path,
)