In [None]:
# Detect whether the notebook is running on Google Colab by probing for the
# Colab-only `google.colab` package.
try:
    import google.colab  # noqa: F401  (imported only to test availability)
    IN_COLAB = True
except ImportError:
    # Only a missing package means "not Colab"; any other error is a real
    # problem and should surface instead of being silently swallowed.
    IN_COLAB = False

'Process in Colab' if IN_COLAB else 'Process in Local'

In [None]:
# On Colab only: mount Google Drive at /content/drive/ so that files stored on
# Drive are reachable through the local filesystem.
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive/')

In [None]:
%cd drive/MyDrive/projects/ClauseSummary

In [None]:
import os

In [None]:
if IN_COLAB:
    !pip install transformers
    !pip install datasets
    !pip install --upgrade accelerate
    if not os.path.exists('RL4LMs'):
        !git clone https://github.com/allenai/RL4LMs.git
    %cd RL4LMs
    !pip install -e .

In [None]:
# Token budget used by SummaryRewardFunction: default tokenizer limit and the
# fixed length that the concatenated (text + summary) reward input is padded to.
MAX_TOKEN = 4096

In [None]:
from typing import Dict, Any
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import datetime
import re
import sys
import os
import gc
import logging
from pprint import pprint
from tqdm.notebook import tqdm
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s %(message)s', datefmt='%m-%d %H:%M')

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
from torch.optim import AdamW, SGD
from torch.nn import MSELoss
from torch.utils.data import DataLoader

from datasets import load_from_disk, load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, LongformerTokenizer, AutoModel, AutoModelForSeq2SeqLM

from rl4lms.envs.text_generation.observation import Observation
from rl4lms.envs.text_generation.reward import RewardFunction

In [None]:
def add_newline_before_number(text:str) -> str:
    """Insert a newline in front of every ``<digits>.`` token (e.g. ``1.``).

    The argument is converted with ``str()`` first, so non-string values are
    accepted and processed as their string form.
    """
    numbered_item = re.compile(r'(\d+)\.')
    return numbered_item.sub(r'\n\1.', str(text))

def change_it_to_a_comma(text:str):
    """Turn parenthesised item markers such as ``(1)``, ``(2)`` into commas.

    The first marker is simply removed (the text on both sides of it is glued
    together); every later marker is replaced by ``,``. Text without any marker
    is returned unchanged.
    """
    head, *tail = re.split(r'\(\d+\)', text)
    if not tail:
        return head
    tail[0] = head + tail[0]
    return ','.join(tail)

def remove_whitespace_after_str(text:str):
    """Delete the single whitespace character that follows 갑 / 을 / 병 / 정.

    Each character is handled in its own pass, in that order, and only when it
    starts at a word boundary. The passes are intentionally sequential: an
    earlier pass can join two characters and thereby remove the word boundary
    that a later pass would have needed.
    """
    for party in ('갑', '을', '병', '정'):
        text = re.sub(rf"\b{party}\s", party, text)
    return text

def change_number_point(text:str):
    """Remove every ``<digits>.`` token (e.g. ``1.``, ``12.``) from ``text``.

    Nothing is inserted in place of the removed token; the surrounding text is
    joined back together as-is.
    """
    return re.sub(r'\d+\.', '', text)

def summary_preprocessing_func(text: str):
    """Clean a summary string by running the summary-specific steps in order.

    Steps: newline before ``N.`` tokens, ``(N)`` markers to commas, whitespace
    removal after 갑/을/병/정, then removal of the ``N.`` tokens themselves.
    """
    steps = (
        add_newline_before_number,
        change_it_to_a_comma,
        remove_whitespace_after_str,
        change_number_point,
    )
    for step in steps:
        text = step(text)
    return text

def text_preprocessing_func(text):
    """Collapse a newline followed by any run of newlines/spaces into one newline."""
    blank_run = re.compile(r'\n[\n ]+')
    return blank_run.sub('\n', text)

In [None]:
def preprocessing(row: Dict[str, str]):
    """Clean one record and return its processed ``text`` and ``summary``.

    Args:
        row: Mapping with at least the keys ``'text'`` and ``'summary'``.

    Returns:
        A new dict with the cleaned ``'text'`` and ``'summary'`` values.
    """
    raw_text, raw_summary = row['text'], row['summary']
    return dict(
        text=text_preprocessing_func(raw_text),
        summary=summary_preprocessing_func(raw_summary),
    )

def df_preprocessing(df: pd.DataFrame):
    """Apply ``preprocessing`` row by row to the ``text``/``summary`` columns.

    The two columns are overwritten on the DataFrame that was passed in, and
    that same DataFrame is returned.
    """
    target_cols = ['text', 'summary']
    cleaned = df[target_cols].apply(preprocessing, axis=1, result_type='expand')
    df[target_cols] = cleaned[target_cols]
    return df

In [None]:
class ModelForRewardGeneration(nn.Module):
    """Pretrained encoder followed by a small MLP head that outputs one scalar.

    Args:
        encoder_path: Name or path handed to ``AutoModel.from_pretrained``.
        hidden_size: Width of the hidden layer inside the head.
    """

    def __init__(self, encoder_path, hidden_size=256):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_path)
        self.hidden_size = hidden_size
        # Take the head's input width from the encoder config instead of
        # hard-coding 768, so encoders of other sizes work; 768 stays as the
        # fallback for configs that do not expose `hidden_size`.
        encoder_dim = getattr(self.encoder.config, 'hidden_size', 768)
        # TODO: design the head
        self.head = nn.Sequential(
            nn.Linear(encoder_dim, hidden_size, bias=False),
            nn.BatchNorm1d(hidden_size),
            nn.GELU(),
            # Element-wise dropout: the head input is 2-D (batch, features),
            # which nn.Dropout1d would read as (channels, length) and then zero
            # out whole rows. Dropout has no parameters, so state_dict keys of
            # previously saved heads are unaffected.
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, input_ids=None, attention_mask=None):
        """Return the head output computed from the encoder's ``pooler_output``."""
        pooled = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).pooler_output
        return self.head(pooled)

def reference_reward_loss(reward, pred):
    """Return ``-log10(1 + exp(-reward * pred))`` element-wise.

    This equals ``log10(sigmoid(reward * pred))``. It is computed through
    ``logsigmoid`` so that large negative products no longer overflow ``exp``
    and produce ``-inf``.

    NOTE(review): the value is always <= 0 and grows as ``reward * pred``
    grows, and no reduction is applied - confirm how the caller signs and
    reduces it before using it as a loss.
    """
    import math

    return F.logsigmoid(reward * pred) / math.log(10.0)

In [None]:
class SummaryRewardFunction(RewardFunction):
    """Episode-end reward that scores a generated summary with a reward model.

    While an episode is running the reward is ``0.0``. When ``done`` is true,
    the source text and the generated summary are tokenized, concatenated,
    padded to ``MAX_TOKEN`` and fed to ``reward_model``; its scalar output
    times 10 is returned.

    Args:
        reward_tokenizer: Tokenizer matching the reward model's encoder.
        reward_model: Callable model taking ``input_ids`` / ``attention_mask``
            tensors and returning a single score. It should be in ``eval()``
            mode, because inputs are scored one at a time.
        option: Keyword arguments forwarded to ``reward_tokenizer``. Defaults
            to truncating at ``MAX_TOKEN`` tokens.
    """

    def __init__(self, reward_tokenizer, reward_model, option=None, *args) -> None:
        super().__init__()
        if option is None:
            # Tokenizers limit length through `max_length`; `max_new_tokens`
            # is a generation argument and is not a tokenizer option.
            option = {
                'max_length': MAX_TOKEN,
                'truncation': True,
            }
        self.option = option
        self.reward_tokenizer = reward_tokenizer
        self.reward_model = reward_model

    def _build_model_inputs(self, text: str, summary: str) -> Dict[str, torch.Tensor]:
        """Tokenize and join text + summary into padded batch-of-one tensors."""
        text_ids = list(self.reward_tokenizer(text, **self.option)['input_ids'])
        summary_ids = list(self.reward_tokenizer(summary, **self.option)['input_ids'])

        if len(text_ids) + len(summary_ids) < MAX_TOKEN:
            input_ids = text_ids + summary_ids
        else:
            # Same rule as before: drop as many trailing text tokens as the
            # summary is long, so the summary is always kept in full.
            kept = max(len(text_ids) - len(summary_ids), 0)
            input_ids = text_ids[:kept] + summary_ids
        # Safety net for custom `option`s that do not truncate.
        input_ids = input_ids[:MAX_TOKEN]

        attention_mask = [1] * len(input_ids)
        pad_id = self.reward_tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 1  # value that was previously hard-coded for padding
        n_pad = MAX_TOKEN - len(input_ids)
        input_ids = input_ids + [pad_id] * n_pad
        # Padding positions must be masked out (0), not attended to.
        attention_mask = attention_mask + [0] * n_pad

        device = next(self.reward_model.parameters()).device
        return {
            'input_ids': torch.tensor([input_ids], dtype=torch.long, device=device),
            'attention_mask': torch.tensor([attention_mask], dtype=torch.long, device=device),
        }

    def __call__(self, prev_observation: Observation,
                 action: int,
                 current_observation: Observation,
                 done: bool,
                 meta_info: Dict[str, Any] = None) -> float:
        """Return the scaled reward-model score at episode end, else ``0.0``."""
        if not done:
            return 0.0

        # RL4LMs' Observation exposes the prompt/source as
        # `prompt_or_input_text` and the text generated so far as `context_text`.
        # NOTE(review): the prompt text may include the datapool's prompt prefix
        # - confirm that matches what the reward model was trained on.
        model_inputs = self._build_model_inputs(
            current_observation.prompt_or_input_text,
            current_observation.context_text,
        )
        with torch.no_grad():  # scoring only; no gradients through the reward model
            score = self.reward_model(**model_inputs)
        return float(score.squeeze()) * 10

## Setting & Config

In [None]:
# Training configuration kept as a YAML document inside a Python string, with
# sections for datapool, tokenizer, env, alg (PPO + policy) and
# train_evaluation.
# NOTE(review): the datapool id (`cnn_daily_mail`) and the policy model
# (`t5-base`) differ from the Korean summary checkpoint loaded elsewhere in
# this notebook - confirm these values are intended before training.
yaml_str = '''
datapool:
  id: cnn_daily_mail
  args:
    prompt_prefix: "summarization-num_lines-4: "
tokenizer:
  model_name: KETI-AIR-Downstream/long-ke-t5-base-summarization
  padding_side: left
  truncation_side: left
  pad_token_as_eos_token: False
  max_new_tokens: 4096
  truncation: True
env:
  n_envs: 10
  args:
    max_prompt_length: 4096
    max_episode_length: 450
    terminate_on_eos: True
    prompt_truncation_side: "right"
    context_start_token: 0
alg:
  id: ppo
  args: 
    n_steps: 512
    batch_size: 64
    verbose: 1
    learning_rate: 2e-5
    n_epochs: 5
    ent_coef: 0.0
  kl_div:
    coeff: 0.001
    target_kl: 0.2
  policy:
    id: seq2seq_lm_actor_critic_policy
    args:
      model_name: t5-base
      apply_model_parallel: True
      prompt_truncation_side: "right"
      generation_kwargs:
        do_sample: True
        top_k: 50
        min_length: 50
        max_new_tokens: 100
train_evaluation:
  eval_batch_size: 100
  n_iters: 100
  eval_every: 10
  generation_kwargs: 
    do_sample: True
    top_k: 0
    temperature: 0.7
    min_length: 50
    max_new_tokens: 100
'''

In [None]:
# Checkpoint name used for the reward tokenizer and base reward encoder.
reward_model_checkpoint = 'psyche/kolongformer-4096'
# Path prefix of the saved reward model; the loading cell appends
# '-encoder-final' and '-head-final.pt' to it.
reward_model_path = './model/230705-10 59'
# Checkpoint name used for the summary tokenizer and seq2seq model.
summary_model_checkpoint = 'KETI-AIR-Downstream/long-ke-t5-base-summarization'

print(f'reward_model_checkpoint: {reward_model_checkpoint}\nsummary_model_checkpoint: {summary_model_checkpoint}')

In [None]:
# Raw dataset location, and a tokenized-dataset location whose name embeds the
# summary checkpoint with every "/" turned into "-". Both are relative to the
# current working directory.
original_dataset_path = './data/dataset-term.json'
tokenized_dataset_path = f'./data/{"-".join(summary_model_checkpoint.split("/"))}-tokenized-dataset'

In [None]:
# Timestamp tag (yymmdd-HHMM) identifying this run's output directory. No ':'
# is used between hour and minute: it is not a valid file-name character on
# Windows and is awkward in paths on other systems.
SAVE_STR = datetime.datetime.now().strftime('%y%m%d-%H%M')
model_save_path = f"./model/{SAVE_STR}"

## Load tokenizer & Model checkpoint

In [None]:
summary_tokenizer = AutoTokenizer.from_pretrained(summary_model_checkpoint)
reward_tokenizer = LongformerTokenizer.from_pretrained(reward_model_checkpoint)

summary_model = AutoModelForSeq2SeqLM.from_pretrained(summary_model_checkpoint)
# Build the reward model directly from the fine-tuned encoder instead of first
# loading the base checkpoint and then replacing the encoder, which loaded the
# encoder weights twice.
reward_model = ModelForRewardGeneration(reward_model_path + '-encoder-final')
# map_location='cpu' lets a head that was saved from a GPU load on a CPU-only
# machine; load_state_dict copies the values into the existing parameters.
reward_model.head.load_state_dict(
    torch.load(reward_model_path + '-head-final.pt', map_location='cpu')
)

## Load Dataset

In [None]:
# Read the raw dataset and expose its 'text' column under the name 'input'
# (added as the last column); the original 'text' column is dropped.
df = pd.read_json(original_dataset_path, encoding='utf-8')
df = df.assign(input=df['text']).drop(columns=['text'])

## Training

In [None]:
# Free unreferenced Python objects and release cached CUDA memory before training.
gc.collect()
torch.cuda.empty_cache()

# The summary model is put in training mode; the reward model is put in
# evaluation mode.
summary_model.train()
reward_model.eval()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')