### Set up imports and environment variables.

In [None]:
# ! pip install -r ../requirements.txt

In [None]:
import os

from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset, load_dataset
from torch.optim import Adam
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    LlamaTokenizer,
    pipeline
)

from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from trl import AutoModelForCausalLMWithValueHead, DPOTrainer
from trl import DPOTrainer

In [None]:
# Enable tqdm's pandas integration (progress bars for pandas operations such as `progress_apply`).
tqdm.pandas()

In [None]:
# Placeholders: replace each `###...###` marker with your own credential string before running.
# Do not commit real secrets to version control.
huggingface_hub_token = ###YOUR_TOKEN_HERE###
wandb_api_key = ###YOUR_KEY_HERE###

In [None]:
# Export the key as WANDB_API_KEY for this process; presumably the wandb client reads it from
# the environment when reporting is enabled — confirm against the wandb documentation.
os.environ['WANDB_API_KEY'] = wandb_api_key

### Set up the Llama reward model

In [None]:
def setup_llama_reward_model():
    """Build the Llama-based HH-RLHF reward-model pipeline and smoke-test it.

    Loads the tokenizer and model for ``weqweasdas/hh_rlhf_rm_open_llama_3b``
    into a ``transformers`` pipeline (task ``"sentiment-analysis"``) on the
    CUDA device with bfloat16 weights, then scores three sample dialogues to
    check that the pipeline runs end to end.

    Returns:
        A ``(rm_pipe, pipe_kwargs, rewards)`` tuple:

        * ``rm_pipe`` -- the constructed pipeline.
        * ``pipe_kwargs`` -- the keyword arguments used for each pipeline
          call, so callers can score new text the same way.
        * ``rewards`` -- the scores of the three sample dialogues, in order.

        Previously the function returned ``None`` and discarded everything it
        had built, so nothing it set up could be used by the caller.
    """
    # Single source of truth for the checkpoint used by both tokenizer and model.
    reward_model_name = "weqweasdas/hh_rlhf_rm_open_llama_3b"

    rm_tokenizer = AutoTokenizer.from_pretrained(reward_model_name)

    rm_pipe = pipeline(
        "sentiment-analysis",
        model=reward_model_name,
        device="cuda",
        tokenizer=rm_tokenizer,
        model_kwargs={"torch_dtype": torch.bfloat16}
    )

    # "function_to_apply": "none" asks the pipeline not to post-process the
    # model output, so the score is used directly as the reward.
    # NOTE(review): `return_all_scores` is reported as deprecated in newer
    # transformers releases in favour of `top_k` — confirm before upgrading.
    pipe_kwargs = {
        "return_all_scores": True,
        "function_to_apply": "none",
        "batch_size": 1
    }

    # Sample dialogues in the "###Human: ... ###Assistant: ..." format.
    test_texts = [
        "###Human: My daughter wants to know how to convert fractions to decimals, but I'm not sure how to explain it. Can you help? ###Assistant: Sure. So one way of converting fractions to decimals is to ask “how many halves are there?” and then write this as a decimal number. But that's a little tricky. Here's a simpler way:  if a fraction is expressed as a/b, then it's decimal equivalent is just a/b * 1.0  So, for example, the decimal equivalent of 1/2 is 1/2 * 1.0 = 0.5.",
        "###Human: I have fresh whole chicken in my fridge. What dish can I prepare using it that will take me less than an hour to cook? ###Assistant: Are you interested in a quick and easy recipe you can prepare with chicken you have on hand, or something more involved?  In terms of both effort and time, what are you looking for?",
        "###Human: My daughter wants to know how to convert fractions to decimals, but I'm not sure how to explain it. Can you help? ###Assistant: Yes, of course. Here you go."
    ]

    pipe_outputs = rm_pipe(test_texts, **pipe_kwargs)
    # Each output is a list of label/score dicts; the first entry's score is the reward.
    rewards = [output[0]["score"] for output in pipe_outputs]

    return rm_pipe, pipe_kwargs, rewards

### Set up DPO arguments.

In [None]:
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments

from trl import DPOTrainer


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.

    Every field carries its own ``help`` text in ``field(metadata=...)``.
    All fields have defaults, so ``ScriptArguments()`` can be instantiated
    directly without any arguments.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    model_name_or_path: Optional[str] = field(default="EleutherAI/gpt-neo-125m", metadata={"help": "the model name"})
    learning_rate: Optional[float] = field(default=1.5e-4, metadata={"help": "optimizer learning rate"})
    per_device_train_batch_size: Optional[int] = field(default=8, metadata={"help": "batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    max_length: Optional[int] = field(default=256, metadata={"help": "max length of each sample"})
    max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
    max_target_length: Optional[int] = field(
        default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
    )
    label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"})
    max_steps: Optional[int] = field(default=20050, metadata={"help": "max number of training steps"})

    # instrumentation
    # The help text states 300 to match the subset size that `get_hh` selects
    # when `sanity_check` is true.
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 300 samples"})
    report_to: Optional[str] = field(
        default="wandb",
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    # The two help fragments are joined by implicit string concatenation, so
    # the first one must end with a space to keep "See" apart from the URL.
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See "
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use gradient checkpointing or no"}
    )
    gradient_checkpointing_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
        },
    )

    # Hub repository id for the trained model.
    huggingface_hub_name: Optional[str] = field(
        default='amirabdullah19852020/gpt-neo-125m_hh_reward', metadata={"help": "Huggingface repo name"}
    )

### Set up the dataset.

In [None]:
def extract_anthropic_prompt(prompt_and_response):
    """Return the prompt part of a combined prompt-and-response string.

    The prompt is everything up to and including the last occurrence of the
    assistant turn marker (``search_term`` below); whatever follows that
    marker is the response and is dropped.

    Raises:
        AssertionError: if the marker does not occur in the input.
    """
    search_term = "\n\nAssistant:"
    # rpartition splits on the LAST occurrence; `marker` is empty when absent.
    head, marker, _response = prompt_and_response.rpartition(search_term)
    assert marker, f"Prompt and response does not contain '{search_term}'"
    return head + marker


def get_hh(split: str = 'train', sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
    """Load Anthropic Helpful-Harmless from the Hugging Face Hub in prompt/chosen/rejected form.

    Every record of the returned dataset gains (or has overwritten) three
    string columns:

        'prompt'   -- the dialogue up to and including the last assistant marker,
        'chosen'   -- the preferred response, with the prompt stripped,
        'rejected' -- the dispreferred response, with the same number of
                      leading characters stripped.

    The prompt is extracted from the 'chosen' text only; 'rejected' is cut at
    the same character offset.

    Args:
        split: dataset split name passed to ``load_dataset``.
        sanity_check: when true, keep only the first 300 records (or fewer if
            the split is smaller).
        silent: accepted but not used by this function.
        cache_dir: cache directory passed to ``load_dataset``.
    """

    def _to_preference_record(example) -> Dict[str, str]:
        # Split one raw record into the shared prompt and the two responses.
        shared_prompt = extract_anthropic_prompt(example["chosen"])
        cut = len(shared_prompt)
        return {
            "prompt": shared_prompt,
            "chosen": example["chosen"][cut:],
            "rejected": example["rejected"][cut:],
        }

    raw = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)

    if sanity_check:
        subset_size = min(len(raw), 300)
        raw = raw.select(range(subset_size))

    return raw.map(_to_preference_record)

# Load and convert the full train split with default arguments.
# NOTE(review): `dataset` is not referenced again in the cells visible here (training uses
# `train_dataset`) — confirm whether this call is still needed.
dataset = get_hh()

In [None]:
# Instantiate the arguments with their declared defaults (no command-line parsing happens here).
script_args = ScriptArguments()

In [None]:
# Policy model that will be trained; load it, then move it onto the GPU.
model = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
model = model.cuda()

if script_args.ignore_bias_buffers:
    # torch distributed hack: record the names of all boolean buffers so DDP ignores them
    model._ddp_params_and_buffers_to_ignore = [
        buf_name for buf_name, buf in model.named_buffers() if buf.dtype == torch.bool
    ]

# Reference model loaded from the same checkpoint; explicitly placed on the CPU.
model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
model_ref = model_ref.cpu()

tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
# Reuse the end-of-sequence token for padding when no pad token is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Training split; reduced to the sanity-check subset only when `script_args.sanity_check` is set.
train_dataset = get_hh("train", sanity_check=script_args.sanity_check)

In [None]:
# Evaluation always uses the sanity-check subset of the test split (at most 300 records, per `get_hh`).
eval_dataset = get_hh("test", sanity_check=True)

In [None]:
# Trainer configuration, assembled from `script_args` plus notebook-specific constants.
training_args = TrainingArguments(
    # Where checkpoints and logs are written.
    output_dir="./test",
    # Batch size and schedule.
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    max_steps=script_args.max_steps,
    # Optimisation.
    optim="adamw_hf",
    learning_rate=script_args.learning_rate,
    warmup_steps=150,
    bf16=True,
    gradient_checkpointing=script_args.gradient_checkpointing,
    # Keep every dataset column; the prompt/chosen/rejected fields are not model inputs by name.
    remove_unused_columns=False,
    # Logging and evaluation cadence.
    report_to=script_args.report_to,
    logging_first_step=True,
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=2000,
    # Publishing to the Hugging Face Hub.
    push_to_hub=True,
    hub_model_id=script_args.huggingface_hub_name,
)

In [None]:
# Build the DPO trainer from the policy model, the reference model and the prepared datasets.
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=model_ref,
    args=training_args,
    beta=script_args.beta,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # Length limits, all taken from the script arguments.
    max_length=script_args.max_length,
    max_prompt_length=script_args.max_prompt_length,
    max_target_length=script_args.max_target_length,
    generate_during_eval=True,
)

In [None]:
! nvidia-smi

In [None]:
# Run training; the step budget comes from `max_steps` in `training_args`.
dpo_trainer.train()

In [None]:
# Authenticate with the Hugging Face Hub using the token defined earlier in the notebook.
# NOTE(review): `push_to_hub=True` is already set on `training_args`, which is created before this
# cell runs — the trainer may need Hub credentials earlier than this login; confirm the cell order.
import huggingface_hub
huggingface_hub.login(huggingface_hub_token)

In [None]:
# Upload the trained model to the Hub; the target repo is presumably `hub_model_id` from `training_args`.
dpo_trainer.push_to_hub()