# Improving Your LLMs with RLHF on SageMaker

Reinforcement learning from human feedback (RLHF) has proven essential to the impressive capabilities and fast adoption of recent large language models (LLMs) such as ChatGPT and Claude. Gone are the days when you needed unnatural prompt engineering to get a base model such as GPT-3 to solve your task. Thanks to RLHF, LLMs are now much better aligned with human values.

However, as the reinforcement learning community knows well, RLHF is notoriously hard to get right, and until very recently only a small number of ML scientists had mastered it. In this notebook, we demystify the technique and put it at the disposal of any ML scientist by describing, step by step, how to train a base model with RLHF on Amazon SageMaker.


## Install Prerequisites

In [None]:
!git clone https://github.com/CarperAI/trlx.git
!pip install torch==2.0.0 --extra-index-url https://download.pytorch.org/whl/cu116 # for cuda
%cd trlx
!git checkout 355c9741f2e606de796f5c6f9b682f7dd00f97c5
!pip install -e .
!pip install transformers==4.27.1 accelerate==0.19.0
%cd ..
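
Before moving on, it can help to confirm that the pinned versions were actually installed. The following is a minimal sketch using only the standard library; `check_pins` is a hypothetical helper written for this notebook, not part of trlx or pip.

```python
from importlib import metadata


def check_pins(pins):
    """Return {package: (expected, found)} for every pin that is not satisfied."""
    mismatches = {}
    for package, expected in pins.items():
        try:
            found = metadata.version(package)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed at all
        # Compare only the part before "+" so local build tags are ignored.
        if found is None or found.split("+")[0] != expected:
            mismatches[package] = (expected, found)
    return mismatches


if __name__ == "__main__":
    # Versions pinned in the install cell above.
    pins = {"torch": "2.0.0", "transformers": "4.27.1", "accelerate": "0.19.0"}
    for package, (expected, found) in check_pins(pins).items():
        print(f"{package}: expected {expected}, found {found}")
```

If nothing is printed, every pinned package matches its expected version.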

## Supervised Fine-tuning a Base LLM

We first make some changes to the code in `trlx/examples/hh/sft_hh.py` so that the trained model weights are stored in the `checkpoints/sft_hh_gptj_6b` folder:

```
        from itertools import islice

        ...
        
        train=TrainConfig(
            seq_length=1024,
            epochs=100,
            total_steps=10000,
            batch_size=1,
            checkpoint_interval=10000,
            eval_interval=1000,
            pipeline="PromptPipeline",
            trainer="AccelerateSFTTrainer",
            checkpoint_dir="checkpoints/sft_hh_gptj_6b",           # <-- changes
            tracker="tensorboard",                                 # <-- changes
            logging_dir="checkpoints/sft_hh_gptj_6b"               # <-- changes
        ),
        
        ...

        trlx.train(
            config=config,
            samples=dataset["train"]["chosen_sample"],
            eval_prompts=[{"prompt": x["prompt"], "original_output": x["chosen"]} for x in islice(dataset["test"], 280)],  # <-- changes
            metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)},
            stop_sequences=["Human:", "human:", "Assistant:", "assistant:"],
        )
```
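
The `eval_prompts` change caps evaluation at the first 280 test records and keeps the chosen response alongside each prompt. The sketch below shows the same pattern on toy data; the records and the `build_eval_prompts` helper are made up for illustration and stand in for `dataset["test"]`.

```python
from itertools import islice

# Toy stand-in for dataset["test"]; the real records come from the HH dataset.
toy_test_split = [
    {"prompt": f"Human: question {i}\n\nAssistant:", "chosen": f" answer {i}"}
    for i in range(5)
]


def build_eval_prompts(records, limit):
    """Keep the first `limit` records, pairing each prompt with its reference answer."""
    return [
        {"prompt": x["prompt"], "original_output": x["chosen"]}
        for x in islice(records, limit)
    ]


eval_prompts = build_eval_prompts(toy_test_split, limit=3)
print(len(eval_prompts))
print(eval_prompts[0])
```

`islice` stops after `limit` items without materializing the rest of the split, and it simply returns fewer items when the split is shorter than the limit.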

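Insert the following line immediately before line 600 of `trlx/trainer/accelerate_base_trainer.py` (at the commit checked out above), so that each saved checkpoint also contains the model weights in a `pretrained` subfolder: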

```
        self.save_pretrained(directory+'/pretrained')             # <-- changes
```

We then perform training using the following command:

In [None]:
!cd trlx/examples/hh ; accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml sft_hh.py

## RLHF Training

With a supervised fine-tuned (SFT) base model and a reward model (RM) in place, we have all the components required for RLHF training and can begin optimizing the policy. To do this, we point the SFT model path in `examples/hh/ppo_hh.py` to the weights trained in the previous section (i.e. `checkpoints/sft_hh_gptj_6b/best_checkpoint`) and store the final weights in `checkpoints/ppo_hh_6B`:

```
    elif config_name == "6B":
        ...
        default_config.model.model_path = "checkpoints/sft_hh_gptj_6b/best_checkpoint"     # <-- changes
        default_config.train.checkpoint_dir = "checkpoints/ppo_hh_6B"      # <-- changes
        ...
```
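
Because the SFT path is relative and only read once training starts, a quick check before launching can save a failed run. This is a minimal sketch; `checkpoint_problem` is a hypothetical helper, and it only verifies that the directory exists and is non-empty, not that the weights inside are valid.

```python
from pathlib import Path


def checkpoint_problem(path):
    """Return a short description of what is wrong with `path`, or None if it looks usable."""
    p = Path(path)
    if not p.is_dir():
        return f"{p} is not a directory"
    if not any(p.iterdir()):
        return f"{p} is empty"
    return None


if __name__ == "__main__":
    # Relative to trlx/examples/hh, where the training commands are run.
    problem = checkpoint_problem("checkpoints/sft_hh_gptj_6b/best_checkpoint")
    print(problem or "SFT checkpoint found")
```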

We then run the training command to start RLHF training:

In [None]:
!cd trlx/examples/hh ; CONFIG_NAME=6B accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py

## Running Inference

In [None]:
import os
import json
import math
import torch
import random
import numpy as np
from itertools import islice
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig


def set_seed(seed_val=42):
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)


def create_hf_model(model_class, model_name_or_path, tokenizer, disable_dropout=False):
    model_config = AutoConfig.from_pretrained(model_name_or_path)
    if disable_dropout:
        model_config.dropout = 0.0

    model = model_class.from_pretrained(
        model_name_or_path,
        from_tf=bool(".ckpt" in model_name_or_path),
        config=model_config,
        torch_dtype=torch.float16,
    )

    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id
    model.resize_token_embeddings(
        int(8 * math.ceil(len(tokenizer) / 8.0))
    )  # make the vocab size multiple of 8

    return model


def load_hf_tokenizer(model_name_or_path, truncation_side="left", padding_side="left"):
    if os.path.exists(model_name_or_path):
        # Loading the tokenizer directly from a local checkpoint can fail, so use
        # the model path recorded in config.json when it is available.
        model_name = model_name_or_path  # fallback when config.json is missing
        model_json = os.path.join(model_name_or_path, "config.json")
        if os.path.exists(model_json):
            with open(model_json) as f:
                model_json_file = json.load(f)
            model_name = (
                "/home/ec2-user/SageMaker/trlx/examples/hh/" + model_json_file["_name_or_path"]
            )
        print(model_name)
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=True,
            truncation_side=truncation_side,
            padding_side=padding_side,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            use_fast=True,
            truncation_side=truncation_side,
            padding_side=padding_side,
        )

    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer


def generate(
    model,
    tokenizer,
    inputs,
    num_beams=1,
    num_beam_groups=1,
    num_return_sequences=1,
    max_new_tokens=128,
):
    prompt_length = inputs.input_ids.shape[1]
    generate_ids = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        num_return_sequences=num_return_sequences,
        max_new_tokens=max_new_tokens,
        top_k=0,
        top_p=1.0,
        do_sample=True,
    )

    result = tokenizer.batch_decode(
        generate_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return result


def run_inference(model, tokenizer, device, prompts, max_prompt_length):
    for prompt in prompts:
        inputs = tokenizer(
            prompt["prompt"],
            return_tensors="pt",
            truncation=True,
            padding=False,
            max_length=max_prompt_length,
            add_special_tokens=False,
        ).to(device)

        print(f"\n\n\n\tPrompt------------------------\n  {prompt['prompt']}")
        r_base = generate(
            model, tokenizer, inputs, num_beams=1, num_return_sequences=1, max_new_tokens=128
        )[0]
        for stop in ["Human:", "human:", "Assistant:", "assistant:"]:
            stop_ix = r_base.find(stop)
            if stop_ix >= 0:
                r_base = r_base[:stop_ix].rstrip()
        print(f"\tResponse------------------------\n {r_base}")

        prompt["machine_output"] = r_base


def main(model_name_or_path_baseline):
    set_seed(seed_val=42)

    device = torch.device("cuda:0")

    max_new_tokens = 128
    seq_length = 512
    max_prompt_length = seq_length - max_new_tokens
    tokenizer_pth = model_name_or_path_baseline

    tokenizer = load_hf_tokenizer(tokenizer_pth, truncation_side="left", padding_side="left")

    model = create_hf_model(AutoModelForCausalLM, model_name_or_path_baseline, tokenizer=tokenizer)
    model.to(device)
    dataset = load_dataset("Dahoas/rm-static")
    eval_prompts = [
        {"prompt": x["prompt"], "original_output": x["chosen"]}
        for x in islice(dataset["test"], 100)
    ]

    run_inference(model, tokenizer, device, eval_prompts, max_prompt_length=max_prompt_length)
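
The stop-sequence handling in `run_inference` can be tried in isolation. The sketch below lifts that loop into a standalone function; `truncate_at_stop` is a name introduced here for illustration, and the sample text is made up.

```python
STOP_SEQUENCES = ["Human:", "human:", "Assistant:", "assistant:"]


def truncate_at_stop(text, stops=STOP_SEQUENCES):
    """Cut `text` at the earliest stop sequence, dropping whitespace before the cut."""
    for stop in stops:
        stop_ix = text.find(stop)
        if stop_ix >= 0:
            # Each pass can only shorten the text, so the earliest stop wins.
            text = text[:stop_ix].rstrip()
    return text


raw = " Sure, here is a recipe.\n\nHuman: thanks!\n\nAssistant: you're welcome"
print(repr(truncate_at_stop(raw)))
```

Text that contains none of the stop sequences is returned unchanged, so the model's reply is only trimmed when it starts to write the next dialogue turn itself.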

The following code generates outputs for the first 100 prompts of the HH test set:

In [None]:
model_name_or_path_baseline = (
    "/home/ec2-user/SageMaker/trlx/examples/hh/checkpoints/ppo_hh_6B/best_checkpoint/pretrained/"
)
main(model_name_or_path_baseline)
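
In `main`, the prompt is limited to `seq_length - max_new_tokens` tokens, and the tokenizer is configured with `truncation_side="left"` so that the most recent dialogue turns survive. The sketch below illustrates that arithmetic on plain lists; `prompt_budget` and `truncate_left` are hypothetical helpers, and the integer ids are toy values, not real tokenizer output.

```python
def prompt_budget(seq_length, max_new_tokens):
    """Tokens left for the prompt once room for the generated reply is reserved."""
    return seq_length - max_new_tokens


def truncate_left(token_ids, max_length):
    """Keep the last `max_length` tokens, mirroring truncation_side="left"."""
    if max_length <= 0:
        return []
    return token_ids[-max_length:] if len(token_ids) > max_length else token_ids


budget = prompt_budget(seq_length=512, max_new_tokens=128)
print(budget)  # 384

# Toy token ids standing in for a tokenized 400-token dialogue.
toy_ids = list(range(400))
kept = truncate_left(toy_ids, budget)
print(len(kept), kept[0], kept[-1])  # 384 16 399
```

Dropping tokens from the left discards the oldest part of the conversation first, which keeps the final `Assistant:` cue that the model is expected to continue.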