In [None]:
# Detect whether this notebook is running on Google Colab; later cells branch on IN_COLAB.
try:
    import google.colab  # noqa: F401 -- imported only as an environment probe
    IN_COLAB = True
except ImportError:  # only "module not found" means "not Colab"; other errors should surface
    IN_COLAB = False

'Process in Colab' if IN_COLAB else 'Process in Local'

In [None]:
# Colab only: mount Google Drive at /content/drive so the project folder under MyDrive is reachable.
if IN_COLAB:
    from google.colab import drive
    drive.mount(mountpoint='/content/drive/')

In [None]:
if IN_COLAB:
    # 프로젝트 디렉토리로 이동: 경우에 맞게 설정
    %cd drive/MyDrive/projects/ClauseSummary

In [None]:
import os
if IN_COLAB:
    !pip install transformers
    !pip install datasets
    !pip install torchtyping
    !pip install wandb
    !pip install git+https://github.com/CarperAI/trlx
    !pip install peft

In [None]:
import datetime
from typing import List
import pickle
import re
import string

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as Fs

from tqdm.notebook import tqdm
from datasets import load_from_disk, load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer
from transformers import AutoModel, AutoModelForSeq2SeqLM
from peft import LoraConfig
from peft.utils.config import TaskType

import trlx
from trlx.trlx import train
from trlx.data.default_configs import (
    ModelConfig,
    OptimizerConfig,
    PPOConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)

In [None]:
def tokenize_text_summary(text: str, summary: str, tokenizer, option=None):
    """Encode a (text, summary) pair as one fixed-length sequence for the reward model.

    The text and the summary are tokenized separately with ``option`` and concatenated
    key by key (``input_ids``, ``attention_mask``, ...). If the pair does not fit in
    ``option['max_length']`` tokens, only the *text* is truncated from its end, and
    only by as much as needed, so the whole summary is always kept. The result is
    right-padded to exactly ``max_length``.

    Args:
        text: Source document. A leading ``'summarization-num_lines-4: '`` task prefix
            is removed before tokenizing.
        summary: Candidate summary to score.
        tokenizer: Callable returning a mapping of token lists (e.g. a Hugging Face
            tokenizer). Its ``pad_token_id`` (when set) pads ``input_ids``.
        option: Keyword arguments forwarded to ``tokenizer``. Must contain
            ``'max_length'``. Defaults to ``{'max_length': 4096, 'truncation': True}``.

    Returns:
        dict mapping every tokenizer output key to a list of exactly ``max_length`` ints.
        ``input_ids`` is padded with the pad token id (1 if the tokenizer has none).
        All other keys, including ``attention_mask``, are padded with 0.

    NOTE(review): earlier versions padded every key with 1, so the model attended to
    padding. If the reward-model checkpoint was trained on that encoding, re-check its
    score calibration after this change.
    """
    if option is None:
        option = {
            'max_length': 4096,
            'truncation': True,
        }
    max_token = option['max_length']

    task_prefix = 'summarization-num_lines-4: '
    if text.startswith(task_prefix):
        text = text[len(task_prefix):]

    tokenized_text = tokenizer(text, **option)
    tokenized_summary = tokenizer(summary, **option)

    # Keep the whole summary; the text gets whatever token budget remains.
    summary_len = len(tokenized_summary['input_ids'])
    text_budget = max(max_token - summary_len, 0)

    pad_token_id = getattr(tokenizer, 'pad_token_id', None)
    if pad_token_id is None:
        pad_token_id = 1  # value previously hard-coded for every key
    pad_values = {'input_ids': pad_token_id}  # any other key (attention_mask, ...) pads with 0

    tokenized_total_text = dict()
    for key in tokenized_text:
        combined = list(tokenized_text[key][:text_budget]) + list(tokenized_summary[key])
        combined = combined[:max_token]  # guards against an untruncated, over-long summary
        pad_value = pad_values.get(key, 0)
        tokenized_total_text[key] = combined + [pad_value] * (max_token - len(combined))
    return tokenized_total_text

In [None]:
class ModelForRewardGeneration(nn.Module):
    """Reward model: pretrained encoder -> MLP ``head1`` -> scalar score ``head2``.

    The encoder's ``pooler_output`` is mapped by ``head1`` to a ``hidden_size``-wide
    representation, which ``head2`` projects to one score per example.

    Args:
        encoder_path: Model name or path passed to ``AutoModel.from_pretrained``.
        hidden_size: Width of the representation produced by ``head1`` (and returned
            by ``representation_forward``).
    """

    def __init__(self, encoder_path, hidden_size=256):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_path)
        self.hidden_size = hidden_size
        # Width of the pooled encoder output, read from the encoder config instead of
        # being hard-coded (falls back to the previous fixed value of 768).
        encoder_dim = getattr(getattr(self.encoder, 'config', None), 'hidden_size', 768)
        # nn.Dropout (element-wise) replaces nn.Dropout1d: on 2-D (batch, features)
        # input, Dropout1d treats the batch axis as channels and zeroes entire examples.
        # Dropout layers hold no parameters, so Sequential indices and state-dict keys
        # are unchanged and existing '-head1.pt' checkpoints still load.
        self.head1 = nn.Sequential(
            nn.Linear(encoder_dim, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, hidden_size, bias=False),
            nn.BatchNorm1d(hidden_size),
            nn.GELU(),
        )
        self.head2 = nn.Sequential(
            nn.Linear(hidden_size, 1),
        )

    def forward(self, input_ids=None, attention_mask=None):
        """Return one score per example, shape (batch, 1)."""
        return self.head2(self.representation_forward(input_ids, attention_mask))

    def representation_forward(self, input_ids=None, attention_mask=None):
        """Return the ``head1`` representation of the pooled encoder output, shape (batch, hidden_size)."""
        pooled = self.encoder(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        return self.head1(pooled)

    def load(self, model_path):
        """Load trained weights saved under the ``model_path`` prefix.

        Reads ``<model_path>-encoder`` (a ``from_pretrained`` directory) plus
        ``<model_path>-head1.pt`` and ``<model_path>-head2.pt`` (state dicts). Head
        tensors are mapped to CPU first, so GPU-saved checkpoints also load on CPU-only
        machines. The new encoder is created on CPU, so move the model to its target
        device after calling this.
        """
        self.encoder = AutoModel.from_pretrained(model_path + '-encoder')
        self.head1.load_state_dict(torch.load(model_path + '-head1.pt', map_location='cpu'))
        self.head2.load_state_dict(torch.load(model_path + '-head2.pt', map_location='cpu'))

In [None]:
# Run timestamp (yy-mm-dd-HH:MM) used to name this run's output directory.
# NOTE(review): ':' is not allowed in Windows paths; fine on Colab/Linux.
SAVE_STR = format(datetime.datetime.now(), '%y-%m-%d-%H:%M')

### Config

In [None]:
# Source JSON files: prefer the already-cleaned one when it exists.
PREPROCESSED_DATASET_JSON = './data/dataset-term-preprocessed.json'
RAW_DATASET_JSON = './data/dataset-term.json'

# (Name kept as-is: later cells read this flag.)
preprocessed_dataset_existance = os.path.exists(PREPROCESSED_DATASET_JSON)

original_dataset_path = PREPROCESSED_DATASET_JSON if preprocessed_dataset_existance else RAW_DATASET_JSON
checkpoint = 'KETI-AIR-Downstream/long-ke-t5-base-summarization'  # policy model + tokenizer

reward_model_checkpoint = 'psyche/kolongformer-4096'  # reward-model encoder + tokenizer
# Prefix of the trained reward-model files ('-encoder', '-head1.pt', '-head2.pt').
reward_model_path = './model/230707-03:06'

# The on-disk Dataset cache is named after the source JSON. A cache built from the raw
# JSON is therefore never reused once the preprocessed JSON exists, and vice versa.
# For the raw JSON this is still './data/dataset-term'.
dataset_path = os.path.splitext(original_dataset_path)[0]
model_save_path = f'./model/{SAVE_STR}-summary-model'

In [None]:
# Training hyper-parameters consumed by the TRLConfig cell.
MAX_EPOCH = 5          # TrainConfig.epochs and PPOConfig.ppo_epochs
TOTAL_STEPS = 100_000  # TrainConfig.total_steps
MAX_SEQ_LEN = 4096     # TrainConfig.seq_length; default prompt token cap in get_prompt_dataset
LR = 2e-4              # AdamW learning rate

### Loading Dataset, Tokenizers & Models

In [None]:
# Policy tokenizer: also used by get_prompt_dataset and for the generation eos token.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Reward model and its tokenizer. hidden_size must match the head width the saved
# '-head1.pt' / '-head2.pt' weights were trained with, or load_state_dict fails.
reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_checkpoint)
reward_model = ModelForRewardGeneration(reward_model_checkpoint, hidden_size=128)
reward_model.load(reward_model_path)

In [None]:
# Load dataset
df = pd.read_json(original_dataset_path)
df['text'] = 'summarization-num_lines-4: ' + df['text'] + '\nSummary: '
df = df[['text']]
if not os.path.exists(dataset_path):
    dataset = Dataset.from_pandas(df)
    dataset.save_to_disk(dataset_path)
else:
    dataset = load_from_disk(dataset_path)

dataset_dict = dataset.train_test_split(test_size=0.1, seed=42)

### PPO

In [None]:
# trlX configuration for PPO fine-tuning of the seq2seq summarizer `checkpoint`.
config = TRLConfig(
    train=TrainConfig(
        seq_length=MAX_SEQ_LEN,
        epochs=MAX_EPOCH,
        total_steps=TOTAL_STEPS,
        batch_size=4,
        checkpoint_interval=1000,
        eval_interval=100,
        pipeline="PromptPipeline",
        trainer="AcceleratePPOTrainer",
        save_best=True,
    ),
    model=ModelConfig(
        model_path=checkpoint,
        model_arch_type="seq2seq",
        num_layers_unfrozen=-1,
    ),
    tokenizer=TokenizerConfig(
        tokenizer_path=checkpoint,
        padding_side="right",
        truncation_side="right",
    ),
    optimizer=OptimizerConfig(
        name="adamw",
        kwargs=dict(lr=LR, betas=[0.9, 0.999], eps=1.0e-8, weight_decay=1.0e-4),
    ),
    # NOTE(review): no scheduler kwargs are given, so the scheduler runs with its
    # constructor defaults. If trlX maps "linear" to torch's LinearLR, that means a short
    # ramp from LR/3 up to LR over 5 steps followed by a constant LR -- confirm intended.
    scheduler=SchedulerConfig(name="linear", kwargs=dict()),
    method=PPOConfig(
        name="PPOConfig",
        num_rollouts=128,
        chunk_size=8,
        ppo_epochs=MAX_EPOCH,
        init_kl_coef=0.05,
        target=6,
        horizon=10000,
        gamma=0.99,
        lam=0.95,
        cliprange=0.2,
        cliprange_value=0.2,
        vf_coef=1,
        scale_reward=None,
        ref_mean=None,
        ref_std=None,
        cliprange_reward=10,
        gen_kwargs=dict(
            max_new_tokens=400,
            do_sample=True,
            top_k=0,  # 0 disables top-k filtering in HF generate; nucleus sampling only
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
        ),
    ),
)

# LoRA adapters on the T5 attention 'q' and 'v' projections.
config.model.peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=['q', 'v'],
)

In [None]:
# Put the reward model on the GPU when one is available and switch it to eval mode
# (dropout off, BatchNorm uses its running statistics).
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
reward_model = reward_model.to(device).eval()

def get_reward(samples: List[str], **kwargs):
    reward_lt = []
    for sample in samples:
        sample = {
            'prompt': sample[:sample.find('Summary: ')],
            'output': sample[sample.find('Summary: ') + len('Summary: '):]
        }
        tokenized_total_text = tokenize_text_summary(sample['prompt'], sample['output'], reward_tokenizer)
        with torch.no_grad():
            score = reward_model(
                input_ids=torch.tensor(tokenized_total_text['input_ids']).repeat(2, 1).to(device),
                attention_mask=torch.tensor(tokenized_total_text['attention_mask']).repeat(2, 1).to(device)
            )
        reward_lt.append(score[0] * 10)

    rewards = torch.cat(reward_lt, dim=0)
    return rewards

def get_prompt_dataset(prompts, max_length=MAX_SEQ_LEN):
    """Normalise prompts by round-tripping them through the policy tokenizer.

    For each prompt:

    1. Keep the text before its first ``"Summary: "``.
    2. Tokenize it without special tokens, truncated to ``max_length - 5`` tokens so
       the summary marker added next still fits.
    3. Decode it (dropping special tokens) and strip it.
    4. Append ``"\\nSummary: "``, round-trip it once more (capped at ``max_length``
       tokens) and strip again.

    Because of the final strip, the returned prompts end in ``"Summary:"`` without a
    trailing space. The round trip presumably makes the stored strings match what trlX
    decodes from token ids.

    Args:
        prompts: Sequence of prompt strings.
        max_length: Token cap for the final prompt.

    Returns:
        List of normalised prompt strings, in input order.
    """
    def _round_trip(text, limit):
        # Encode without special tokens (truncated to `limit`), then decode and strip.
        token_ids = tokenizer(
            text,
            add_special_tokens=False,
            truncation=True,
            max_length=limit,
        )["input_ids"]
        return tokenizer.decode(token_ids, skip_special_tokens=True).strip()

    formatted_prompts = []
    for prompt in tqdm(prompts):
        body = _round_trip(prompt.split("Summary: ")[0], max_length - 5)
        formatted_prompts.append(_round_trip(body + "\nSummary: ", max_length))
    return formatted_prompts

In [None]:
# Characters to keep:
#   - Hangul syllables, ASCII letters and digits
#   - whitespace, including the plain space (the previous pattern omitted it, so every
#     space was deleted)
#   - ASCII punctuation, passed through re.escape so ']', '\\', '^' and '-' are literal
#     inside the character class
# The pattern is compiled once at import time.
_UNWANTED_CHAR_PATTERN = re.compile(
    r'[^가-힣a-zA-Z0-9 \r\t\n\f' + re.escape(string.punctuation) + ']'
)


def remove_unwanted_char(row):
    """Remove every character outside the allowed set from ``row['text']``.

    Intended for ``Dataset.map``. Prompt caches built with the old pattern (e.g.
    data/pseudo_train_set.pkl) are not refreshed automatically.

    Args:
        row: Mapping with a ``'text'`` string.

    Returns:
        ``{'text': cleaned_text}``.
    """
    return {'text': _UNWANTED_CHAR_PATTERN.sub('', row['text'])}

In [None]:
# Only data loaded from the raw JSON is cleaned; the preprocessed JSON is presumably clean already.
dataset_dict = dataset_dict if preprocessed_dataset_existance else dataset_dict.map(remove_unwanted_char)

In [None]:
def _load_or_build_prompt_set(split_name, cache_path):
    """Return formatted prompts for ``dataset_dict[split_name]``, cached as a pickle.

    If ``cache_path`` exists it is unpickled and returned. Otherwise the split's
    ``'text'`` values go through ``get_prompt_dataset`` and the result is pickled to
    ``cache_path`` before being returned.

    NOTE(review): the cache is keyed only by file name. It is not invalidated when the
    dataset, tokenizer or MAX_SEQ_LEN change, so delete the .pkl files after such
    changes. Only unpickle cache files you created yourself: pickle can execute
    arbitrary code.
    """
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    prompts = get_prompt_dataset([sample["text"] for sample in dataset_dict[split_name]])
    with open(cache_path, 'wb') as f:
        pickle.dump(prompts, f)
    return prompts


# Raw-JSON runs (text cleaned by remove_unwanted_char) cache under a separate 'pseudo_'
# prefix, so they never share files with runs on the preprocessed JSON.
_prompt_cache_prefix = 'data/' if preprocessed_dataset_existance else 'data/pseudo_'
train_set = _load_or_build_prompt_set('train', _prompt_cache_prefix + 'train_set.pkl')
val_set = _load_or_build_prompt_set('test', _prompt_cache_prefix + 'val_set.pkl')

In [None]:
# Launch PPO training: generated summaries for `train_set` are scored by get_reward,
# and `val_set` prompts are used for evaluation.
trainer = train(
    config=config,
    reward_fn=get_reward,
    prompts=train_set,
    eval_prompts=val_set,
)

In [None]:
# Save the PPO-trained model to model_save_path (./model/<SAVE_STR>-summary-model).
trainer.save_pretrained(model_save_path)