In [None]:
# dpo
# from dataclasses import dataclass, field
from typing import Dict, Optional

# import torch
from datasets import Dataset, load_dataset
# from peft import LoraConfig
# from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments

# from trl import DPOTrainer

def extract_anthropic_prompt(prompt_and_response, search_term="\n\nAssistant:"):
    """Return the prefix of ``prompt_and_response`` up to and including the LAST ``search_term``.

    Args:
        prompt_and_response: Full conversation text (prompt followed by the final reply).
        search_term: Marker that precedes the final reply.

    Returns:
        The text up to and including the last occurrence of ``search_term``.

    Raises:
        AssertionError: if ``search_term`` does not occur in ``prompt_and_response``.
            Raised explicitly rather than via ``assert`` so the check is not stripped
            under ``python -O`` (which would silently return a wrong slice); the
            exception type is kept for callers that catch ``AssertionError``.
    """
    search_term_idx = prompt_and_response.rfind(search_term)
    if search_term_idx == -1:
        raise AssertionError(f"Prompt and response does not contain '{search_term}'")
    return prompt_and_response[: search_term_idx + len(search_term)]


def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
    """Load one split of Anthropic HH-RLHF and reshape every row into prompt / chosen / rejected.

    Each returned row has:
        prompt   -- the chosen conversation up to and including its last "Assistant:" marker
        chosen   -- the remainder of the chosen text after that prompt
        rejected -- the rejected text with the same number of leading characters removed

    Args:
        split: Dataset split name passed to ``load_dataset``.
        sanity_check: If True, keep only the first 1000 rows (or fewer if the split is smaller).
        silent: Accepted for interface compatibility; not used.
        cache_dir: Optional cache directory forwarded to ``load_dataset``.

    Returns:
        The mapped ``Dataset``.
    """
    raw = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
    if sanity_check:
        keep = min(len(raw), 1000)
        raw = raw.select(range(keep))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        chosen_text = sample["chosen"]
        prompt = extract_anthropic_prompt(chosen_text)
        # The rejected reply is assumed to share the chosen prompt, so it is cut at the same offset.
        cut = len(prompt)
        return {
            "prompt": prompt,
            "chosen": chosen_text[cut:],
            "rejected": sample["rejected"][cut:],
        }

    return raw.map(split_prompt_and_responses)

In [None]:
# Load the raw HH-RLHF split directly, without the prompt/response reshaping that get_hh applies.
# NOTE(review): this cell has no execution count in the saved notebook, and `dataset` is not
# used by any visible later cell — consider deleting it or calling get_hh instead.
split = 'train'
cache_dir = None
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)

In [2]:
from typing import Dict, Optional
from datasets import Dataset, load_dataset

from fastchat.model.model_adapter import get_conversation_template

def extract_anthropic_prompt(prompt_and_response, search_term="\n\nAssistant:"):
    """Return the prefix of ``prompt_and_response`` up to and including the LAST ``search_term``.

    Args:
        prompt_and_response: Full rendered conversation (prompt followed by the final reply).
        search_term: Marker that precedes the final reply.

    Returns:
        The text up to and including the last occurrence of ``search_term``.

    Raises:
        AssertionError: if ``search_term`` does not occur in ``prompt_and_response``.
            Raised explicitly rather than via ``assert`` so the check is not stripped
            under ``python -O`` (which would silently return a wrong slice); the
            exception type is kept because callers filter rows by catching
            ``AssertionError``.
    """
    search_term_idx = prompt_and_response.rfind(search_term)
    if search_term_idx == -1:
        raise AssertionError(f"Prompt and response does not contain '{search_term}'")
    return prompt_and_response[: search_term_idx + len(search_term)]

class hankang_DPODataset:
    """Build DPO ``prompt`` / ``chosen`` / ``rejected`` datasets from a chat-style preference dataset.

    Each example is read as having ``chosen`` and ``rejected`` columns, each a list of turn
    dicts with ``role`` (``'user'`` or ``'assistant'``) and ``content_kr`` keys. Each side is
    rendered with a FastChat conversation template and cut at the LAST occurrence of
    ``search_term``: everything up to and including it is the prompt, the rest is the response.

    Args:
        dataset_path: Passed to ``load_dataset``; the result must expose ``'train'`` and
            ``'test'`` splits. NOTE(review): the default is an absolute local path.
        data_format: Conversation-template name given to ``get_conversation_template``.
        search_term: Marker preceding the final assistant reply in the rendered text.
        num_train: If not None, keep at most this many train examples (before filtering).
        num_eval: If not None, keep at most this many eval examples (before filtering).
    """

    def __init__(
        self,
        dataset_path="/data/llm_datasets/Ultrafeedback_binarized.ko.hankang/",
        data_format='chat-orca',
        search_term='\n\n### Assistant:',
        num_train=None,
        num_eval=None,
    ):
        self.dataset_path = dataset_path
        self.data_format = data_format
        self.search_term = search_term
        self.num_train = num_train
        self.num_eval = num_eval

    def get_prompt_and_response(self, data):
        """Render a list of turns with the conversation template and return the full text.

        Turns are expected to alternate user (even index) / assistant (odd index). A turn
        that breaks this pattern is skipped and a warning is printed; warnings are also
        printed when no turn was added or when an odd number of turns was added.
        """
        conv = get_conversation_template(self.data_format)

        for idx, _conv in enumerate(data):
            role = _conv['role']
            content = _conv['content_kr']
            if idx % 2 == 0 and role == 'user':
                conv.append_message(conv.roles[0], content)
            elif idx % 2 == 1 and role == 'assistant':
                conv.append_message(conv.roles[1], content)
            else:
                print("Warning: data type invaild")

        if len(conv.messages) == 0:
            print("Warning: data is empty")
        if len(conv.messages) % 2 != 0:
            print("Warning: data has weird pair")

        return conv.get_prompt()

    def _split_example(self, data) -> Dict[str, str]:
        """Split one preference example into ``prompt``, ``chosen`` and ``rejected`` strings.

        The prompt comes from the chosen conversation; the rejected response is cut at the
        rejected conversation's own prompt boundary.

        Raises:
            AssertionError: if either rendered conversation lacks ``self.search_term``
                (propagated from ``extract_anthropic_prompt``).
        """
        chosen_text = self.get_prompt_and_response(data['chosen'])
        rejected_text = self.get_prompt_and_response(data['rejected'])
        prompt = extract_anthropic_prompt(chosen_text, self.search_term)
        rejected_prompt = extract_anthropic_prompt(rejected_text, self.search_term)
        # NOTE(review): `prompt` and `rejected_prompt` are not compared; if they can differ,
        # the pair is inconsistent for DPO. TODO decide whether such rows should be filtered.
        return {
            "prompt": prompt,
            "chosen": chosen_text[len(prompt):],
            "rejected": rejected_text[len(rejected_prompt):],
        }

    def make_dpo_data_module(self):
        """Load the dataset and return DPO-ready train/eval splits.

        Returns:
            ``dict(train_dataset=..., eval_dataset=...)``. In each split, rows whose chosen
            or rejected side lacks ``search_term`` are dropped, and the split's original
            columns are replaced by ``prompt`` / ``chosen`` / ``rejected``.
        """
        def validate_prompt_and_responses(data) -> bool:
            # Keep a row only if both sides can be split at the search term.
            try:
                self._split_example(data)
            except AssertionError:
                return False
            return True

        def split_prompt_and_responses(data) -> Dict[str, str]:
            return self._split_example(data)

        def to_dpo_split(split_dataset):
            # Use each split's own columns, so eval is not cleaned with the train schema.
            # Per the `datasets` Dataset.map docs, remove_columns are dropped before the
            # function output is written, so the re-used `chosen`/`rejected` names survive.
            original_columns = list(split_dataset.features.keys())
            split_dataset = split_dataset.filter(validate_prompt_and_responses)
            return split_dataset.map(split_prompt_and_responses, remove_columns=original_columns)

        dataset = load_dataset(self.dataset_path)
        train_dataset = dataset['train']
        eval_dataset = dataset['test']

        if self.num_train is not None:
            train_dataset = train_dataset.select(range(min(len(train_dataset), self.num_train)))
        if self.num_eval is not None:
            # Cap by the eval split's own size. The previous cap used len(train_dataset),
            # which made `select` index past the end whenever num_eval > len(eval).
            eval_dataset = eval_dataset.select(range(min(len(eval_dataset), self.num_eval)))

        return dict(
            train_dataset=to_dpo_split(train_dataset),
            eval_dataset=to_dpo_split(eval_dataset),
        )

In [1]:
# Use the packaged copy of hankang_DPODataset from fastchat rather than a notebook definition.
# NOTE(review): this cell ran as In [1], before the cell that defines the class locally (In [2]),
# so the saved output below comes from the packaged version — confirm the two copies match.
from fastchat.train.data_modules.dpo_dataset import hankang_DPODataset

# Build train/eval DPO splits using the packaged class's constructor defaults.
# Per the saved output, the train filter kept 61,734 of 61,752 rows and eval kept all 1,999.
dpo_dataset = hankang_DPODataset()
dpo_datamodule = dpo_dataset.make_dpo_data_module()

  from .autonotebook import tqdm as notebook_tqdm
Filter: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 61752/61752 [00:06<00:00, 9603.56 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 61734/61734 [00:12<00:00, 5061.69 examples/s]
Filter: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1999/1999 [00:00<00:00, 9628.81 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1999/1999 [00:00<00:00, 4990.62 examples/s]


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from trl import DPOTrainer
from transformers import Trainer

In [None]:
# NOTE(review): unterminated call left over from editing — this cell is a SyntaxError as written
# and has never run (no execution count). TODO: finish the DPOTrainer(...) arguments or delete it.
DPOTrainer(

In [None]:
# Exploratory: list the attributes of TransformersArguments' class, presumably while deciding
# what to pass to DPOTrainer. TODO(review): remove before sharing — not needed for Run All.
import transformers
dir(transformers.TrainingArguments)

In [3]:
# Back-of-envelope training-time estimate.
# NOTE(review): units are not recorded; presumably 853 of 46302 steps had completed, and
# "16:08:08" below is either a wall-clock timestamp or an elapsed time — confirm.
a = 853/46302 
#16:08:08

In [4]:
# Scale 16 (presumably the elapsed time so far, in unrecorded units) by the completed fraction
# to project the total run time in the same units. TODO(review): confirm what 16 measures.
x = 16/ (a)

In [6]:
# Convert 159182 (presumably seconds) to hours; the saved output shows ~44.2.
# NOTE(review): where 159182 came from is not recorded — add its source.
159182/60/60

44.21722222222222