In [None]:
import json
import copy
import logging
from dataclasses import dataclass, field

import torch
from torch.utils.data import Dataset
import transformers
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModel, BitsAndBytesConfig
from datasets import load_dataset

In [None]:
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

In [None]:
base_model_path = "skt/kogpt2-base-v2"  # Hugging Face Hub id used for the policy, tokenizer and reward pipeline

In [None]:
# Tokenizer for the base checkpoint; model_max_length sets its default truncation length.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_path,
    model_max_length=512,
    # NOTE(review): decoder-only batched generation usually wants "left" padding;
    # "right" is only harmless while queries are generated one at a time.
    padding_side="right",
)

# Policy: causal LM wrapped with an extra value head, the model type TRL's PPOTrainer trains.
model = AutoModelForCausalLMWithValueHead.from_pretrained(base_model_path)

In [None]:
# Data config: loss ignore index and default special-token strings.
IGNORE_INDEX = -100  # matches torch.nn.CrossEntropyLoss's default ignore_index; unused in the visible cells
DEFAULT_PAD_TOKEN, DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_UNK_TOKEN = (
    "[PAD]",
    "<s>",
    "</s>",
    "<UNK>",
)

In [None]:
# Register BOS/EOS/UNK special tokens. The pad token is aliased to EOS below, so no
# separate "[PAD]" entry is registered (it would be overridden immediately anyway and
# would only add an unused row to the vocabulary).
num_added_tokens = tokenizer.add_special_tokens(
    {
        "bos_token": DEFAULT_BOS_TOKEN,
        "eos_token": DEFAULT_EOS_TOKEN,
        "unk_token": DEFAULT_UNK_TOKEN,
    }
)
tokenizer.pad_token = tokenizer.eos_token

# Tokens that were not already in the vocabulary get ids beyond the checkpoint's
# embedding matrix; grow the policy's embeddings so those ids cannot index out of range.
# (Must run before PPOTrainer is built so any reference copy has the same vocab size.)
if num_added_tokens > 0:
    model.pretrained_model.resize_token_embeddings(len(tokenizer))

In [None]:
from transformers import pipeline

# NOTE(review): this loads the *base* kogpt2 checkpoint as a sequence classifier; its
# classification head is presumably freshly initialized, so the scores are not a
# meaningful reward until a trained reward model is substituted here.
reward_model = pipeline(task="text-classification", model=base_model_path)

In [None]:
from datasets import load_dataset

# Prompt source for PPO; the 'system' and 'question' columns are read when building prompts.
data_path = "AIdenU/orca_dpo_data_ko"
dataset = load_dataset(path=data_path)

In [None]:
# Display the loaded dataset (splits and columns) as a quick sanity check.
dataset

In [None]:
# Chat-style prompt templates (Gemma-like turn markers), filled with str.format_map.
# "prompt_input" has a slot for a system message; "prompt_no_input" is the variant
# without one. Both end with the model-turn marker so generation continues from there.
PROMPT_DICT = dict(
    prompt_input=(
        "<start_of_turn>user\n"
        "{system}\n"
        "\n"
        "### Input:\n"
        "{user_input}\n"
        "\n"
        "<start_of_turn>model\n"
    ),
    prompt_no_input=(
        "<start_of_turn>user\n"
        "### Input:\n"
        "{user_input}\n"
        "\n"
        "<start_of_turn>model\n"
    ),
)

In [None]:
# Fetch the two templates as convenience module-level aliases.
prompt_input = PROMPT_DICT["prompt_input"]
prompt_no_input = PROMPT_DICT["prompt_no_input"]

In [None]:
from typing import Optional, Dict, Sequence

class PPO_dataset(Dataset):
    """Prompt-only dataset for PPO rollouts (based on wygo's SFT dataset).

    Each example is rendered into a chat-style prompt from ``PROMPT_DICT`` and
    tokenized. Items expose the prompt token ids (``input_ids``) and the prompt
    text (``query``).

    Args:
        list_data_dict: iterable of dict-like examples (e.g. a HF ``Dataset`` split).
        system: name of the column holding the optional system message. Examples
            whose value is missing or empty use the ``prompt_no_input`` template.
        user_input: name of the column holding the user question; required in
            every example (a missing key raises ``KeyError``).
        tokenizer: tokenizer used to encode the prompts.
        verbose: if True, log the first rendered prompt.
        max_length: prompts are truncated to at most this many tokens.
    """

    def __init__(self, list_data_dict: list, system: str, user_input: str, tokenizer: transformers.PreTrainedTokenizer, verbose=False, max_length: int = 256):
        super(PPO_dataset, self).__init__()
        logging.warning("Loading data...")

        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]

        # Render every example into a prompt string, using the caller's column names.
        sources = []
        for example in list_data_dict:
            # The system message is the optional field: only prompt_input has a slot
            # for it. Treat missing, None and "" alike as "no system message".
            if example.get(system):
                prompt = prompt_input.format_map({
                    'system': example[system],
                    'user_input': example[user_input],
                })
            else:
                prompt = prompt_no_input.format_map({
                    'user_input': example[user_input],
                })
            sources.append(prompt)

        if verbose and sources:
            logging.warning("Example prompt:\n%s", sources[0])

        sources_tokenized = self._tokenize_fn(sources, tokenizer, max_length=max_length)

        self.input_ids = sources_tokenized["input_ids"]
        self.query = sources_tokenized["query"]

    def _tokenize_fn(self, strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int = 256) -> Dict:
        """Tokenize each prompt on its own (no padding; queries stay variable-length).

        Returns:
            dict with ``query`` (the input strings, unchanged) and ``input_ids``
            (one 1-D tensor of token ids per string, truncated to ``max_length``).
        """
        # NOTE(review): if the tokenizer truncates on the right (the transformers
        # default), an over-long prompt loses its trailing "<start_of_turn>model" cue,
        # and `query` still holds the untruncated text that is later fed to the reward
        # model. Consider tokenizer.truncation_side = "left" if prompts exceed max_length.
        input_ids = [
            tokenizer(
                text,
                return_tensors="pt",
                max_length=max_length,
                truncation=True,
            ).input_ids[0]
            for text in strings
        ]

        return dict(
            query=strings,
            input_ids=input_ids,
        )

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, object]:
        """Return the i-th prompt as ``{"input_ids": tensor, "query": str}``."""
        return dict(input_ids=self.input_ids[i],
                    query=self.query[i],
                    )

def collator(data):
    """Turn a list of per-item dicts into a dict of lists, keyed like the first item.

    Values are kept as Python lists (no stacking/padding), so variable-length
    tensors and strings pass through untouched. An empty batch raises IndexError.
    """
    first_item = data[0]
    batch = {}
    for key in first_item:
        batch[key] = [item[key] for item in data]
    return batch

In [None]:
list_data_dict = dataset["train"]  # the training split supplies the PPO prompts

In [None]:
train_dataset = PPO_dataset(
    list_data_dict=list_data_dict,
    system="system",
    user_input="question",
    tokenizer=tokenizer,
)
eval_dataset = None  # no evaluation is run

In [None]:
from trl import PPOConfig
from bitsandbytes.optim import AdamW

batch_size = 1  # one prompt per PPO step; the mini-batch is the whole batch

config = PPOConfig(
    steps=100,
    batch_size=batch_size,
    mini_batch_size=batch_size,
    learning_rate=1.41e-5,
)

In [None]:
save_dir = "./output_3_PPO"  # output directory intended for the trained policy/tokenizer

In [None]:
from trl import PPOTrainer

# Wrap the policy, prompts and tokenizer into TRL's PPO trainer; it also builds the
# dataloader iterated below. No ref_model is passed: presumably TRL creates a frozen
# reference copy of `model` itself — confirm for the installed TRL version.
ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    tokenizer=tokenizer,
    dataset=train_dataset,
    data_collator=collator,
)

In [None]:
# Sampling settings for PPO rollouts: plain sampling over the full vocabulary.
generation_kwargs = {
    "min_length": -1,  # no minimum-length constraint
    "top_k": 0,        # `top_k` is an int parameter; 0 disables top-k filtering
    "top_p": 1.0,      # 1.0 disables nucleus filtering
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,  # the pad token is aliased to EOS
    # "max_new_tokens" is chosen per query by the length sampler in the training loop.
}

In [None]:
from trl.core import LengthSampler

In [None]:
# Response-length bounds for rollouts; the sampler picks a fresh length per query
# (TRL's LengthSampler draws from range(min, max), i.e. max is exclusive).
output_min_length, output_max_length = 4, 99
output_length_sampler = LengthSampler(output_min_length, output_max_length)

In [None]:
from tqdm import tqdm

In [None]:
# Number of batches the PPO loop below will iterate over (one ppo_trainer.step per batch).
len(ppo_trainer.dataloader)

In [None]:
# One pass over the prompts: generate a response per query, score it, take a PPO step.
for batch in tqdm(ppo_trainer.dataloader, total=len(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    #### Get response from the policy
    response_tensors = []
    for query in query_tensors:
        gen_len = output_length_sampler()
        # Per-call copy so the shared generation_kwargs dict is not mutated.
        response = ppo_trainer.generate(query, **{**generation_kwargs, "max_new_tokens": gen_len})
        # For a single query, generate() returns prompt + continuation (TRL's default
        # return_prompt=True). Strip the prompt by its length: keeping only the last
        # gen_len tokens would splice prompt tokens into the response whenever
        # sampling stops early at EOS.
        response_tensors.append(response.squeeze(0)[len(query):])
    batch["response"] = [tokenizer.decode(r) for r in response_tensors]

    #### Compute reward score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = reward_model(texts)
    # NOTE(review): each output is the probability of the top-ranked label, whichever
    # label that is, so the reward has no fixed direction; reward_model is also built
    # from the base LM checkpoint, whose classification head is presumably untrained.
    # TODO: use a trained reward model and score a fixed label/logit.
    rewards = [torch.tensor(output["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

# Persist the tuned policy and its tokenizer so the run is not lost with the kernel.
ppo_trainer.accelerator.unwrap_model(ppo_trainer.model).save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)