In [None]:
 ! pip install -U bitsandbytes accelerate transformers datasets trl peft evaluate rouge_score

In [None]:
from transformers import (

    AutoModelForCausalLM,

    AutoModelForSequenceClassification,

    AutoTokenizer,

    BitsAndBytesConfig,

    DistilBertTokenizer,

    TrainingArguments,

    pipeline,

)

import evaluate

from datasets import load_dataset, Dataset

from trl import (

    SFTTrainer,

    PPOTrainer,

    RewardTrainer,

    PPOConfig,

    RewardConfig,

    AutoModelForCausalLMWithValueHead,

)

from peft import LoraConfig, get_peft_model

from bitsandbytes.optim import AdamW8bit

from tqdm import tqdm

import torch

from torch.utils.data import DataLoader, Dataset as torchDataset

import numpy as np

# Hugging Face and Weights & Biases (wandb) login

In [None]:
import os
from getpass import getpass

import wandb
from huggingface_hub import login

# Credentials must never be stored in the notebook: read them from the
# environment and fall back to an interactive prompt that does not echo.
# NOTE(review): the tokens that were previously hard-coded in this cell have
# been exposed and should be revoked and rotated.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
wandb_api_key = os.environ.get("WANDB_API_KEY") or getpass("Weights & Biases API key: ")

login(token=hf_token)
wandb.login(key=wandb_api_key)

# Hyperparameters

In [None]:
# ---- Data -------------------------------------------------------------------
# Human-preference comparisons used to build the PPO prompts.
dataset = load_dataset("openai/summarize_from_feedback", "comparisons")

# ---- Checkpoints and Hub repositories ---------------------------------------
base_reward_model_checkpoint = "google/gemma-2-2b"
reward_model_repo_name = "reward_model"
reward_model_checkpoint = f"JaishreeramCoder/{reward_model_repo_name}"

base_sft_model_checkpoint = "meta-llama/Llama-3.1-8B"
sft_model_repo_name = "sft_model"
sft_model_checkpoint = f"JaishreeramCoder/{sft_model_repo_name}"

rlhf_model_repo_name = "ppo_model"
rlhf_model_checkpoint = f"JaishreeramCoder/{rlhf_model_repo_name}"

# ---- Output location --------------------------------------------------------
output_dir = "/content/sample_data"

# ---- Training schedule ------------------------------------------------------
num_train_epochs_reward_model = 5
num_train_epochs_sft = 5
num_train_epochs_ppo_outer = 5

# ---- Batch sizes ------------------------------------------------------------
ppo_training_batch_size = 8
eval_batch_size = 8

# ---- Hardware ---------------------------------------------------------------
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def count_parameters(model):
    """Count the parameter elements of ``model``.

    Args:
        model: Object exposing ``parameters()`` that yields tensors.

    Returns:
        Tuple ``(total, trainable)``: the number of elements over all
        parameters, and over those with ``requires_grad`` set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable

# RLHF-based fine-tuning

In [None]:
# Half-precision compute dtype used by the 4-bit kernels.
compute_dtype = torch.float16

# 4-bit NF4 quantisation without double quantisation.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=False,
)

# LoRA adapter settings for a causal language model.
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)

# The tokenizer comes from the SFT checkpoint; EOS doubles as the pad token.
rlhf_tokenizer = AutoTokenizer.from_pretrained(sft_model_checkpoint)
rlhf_tokenizer.pad_token = rlhf_tokenizer.eos_token

# Policy with a value head, initialised from the SFT checkpoint, quantised and
# wrapped with the LoRA adapter configuration above.
rlhf_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    sft_model_checkpoint,
    quantization_config=quantization_config,
    peft_config=lora_config,
)
rlhf_model.train()

In [None]:
def reward_fn(response):
    """Score a single generated response with the reward model.

    Args:
        response: 1-D sequence of token ids produced by the policy
            (``rlhf_model``), i.e. ids from ``rlhf_tokenizer``'s vocabulary.

    Returns:
        A float32 CPU tensor holding the reward for this response. It has one
        element per input text, so one element here.

    Notes:
        Relies on the notebook-level ``reward_model`` and ``reward_tokenizer``.
        NOTE(review): neither is created in the visible cells — confirm they
        are loaded before this function is called.
    """
    reward_model.eval()

    # The ids come from the policy, so they must be turned back into text with
    # the policy tokenizer. Decoding them with the reward tokenizer would
    # misread the ids whenever the two vocabularies differ.
    response_text = rlhf_tokenizer.decode(response, skip_special_tokens=True)

    encoded = reward_tokenizer(
        response_text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    # Keep the inputs on the same device as the reward model's weights.
    encoded = encoded.to(reward_model.device)

    with torch.no_grad():
        logits = reward_model(**encoded).logits

    if logits.shape[-1] == 1:
        # Single-output reward head: the logit itself is the score. Taking
        # argmax over one column would always give 0, i.e. a constant reward.
        score = logits[..., 0]
    else:
        # Multi-class head: keep the original behaviour (predicted class id).
        score = logits.argmax(dim=-1)

    # Return on CPU so callers can convert the score to a Python/NumPy number.
    return score.detach().to(device="cpu", dtype=torch.float32)

In [None]:
def get_rlhf_dataset(data):
    """Turn raw comparison rows into tokenized summarization prompts.

    Args:
        data: Column-oriented mapping with a ``"choice"`` column (used only
            for the row count) and an ``"info"`` column whose entries carry
            the source text under ``"post"``.

    Returns:
        A ``Dataset`` with ``"input_ids"`` and ``"attention_mask"`` columns,
        each row left-padded or truncated to 512 tokens.
    """
    token_id_rows = []
    mask_rows = []
    for row_idx, _ in enumerate(data["choice"]):
        prompt = f"Summarize the following text:\n\n{data['info'][row_idx]['post']}"
        encoding = rlhf_tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=512,
            padding_side="left",
        )
        token_id_rows.append(encoding.input_ids)
        mask_rows.append(encoding.attention_mask)
    return Dataset.from_dict(
        {"input_ids": token_id_rows, "attention_mask": mask_rows}
    )

In [None]:
# The same 1000-row window (rows 2000-2999) is taken from each split.
_rlhf_rows = slice(2000, 3000)
rlhf_train_dataset = get_rlhf_dataset(dataset["train"][_rlhf_rows])
rlhf_eval_dataset = get_rlhf_dataset(dataset["validation"][_rlhf_rows])

In [None]:
# Sampling settings shared by every generate() call in this notebook.
generation_kwargs = dict(
    min_length=-1,  # never suppress EOS to enforce a minimum length
    top_k=0.0,  # top-k filtering off
    top_p=1.0,  # nucleus filtering off
    do_sample=True,  # sample rather than decode greedily
    eos_token_id=rlhf_tokenizer.eos_token_id,
    bos_token_id=rlhf_tokenizer.bos_token_id,
    pad_token_id=rlhf_tokenizer.eos_token_id,  # EOS stands in for the missing pad token
    max_new_tokens=32,  # upper bound on generated tokens per prompt
)

In [None]:
def evaluate(model, data, reward_model):
    """Return the mean reward of responses sampled from ``model``.

    Args:
        model: Policy exposing ``eval()`` and ``generate()``.
        data: Column-oriented dataset with an ``"input_ids"`` column of
            equal-length token-id lists. An ``"attention_mask"`` column is
            forwarded to ``generate`` when present.
        reward_model: Reward model, switched to eval mode here. The scoring
            itself is delegated to ``reward_fn``.

    Returns:
        Mean per-response reward as a NumPy float, or NaN when ``data`` has
        no rows.
    """
    model.eval()
    reward_model.eval()

    # Read each column once instead of re-reading it on every batch.
    all_input_ids = data["input_ids"]
    try:
        all_attention_masks = data["attention_mask"]
    except KeyError:
        all_attention_masks = None

    rewards = []
    with torch.no_grad():
        for start in tqdm(range(0, len(all_input_ids), eval_batch_size)):
            stop = start + eval_batch_size
            prompts = torch.tensor(all_input_ids[start:stop]).to(device)
            generate_inputs = {"input_ids": prompts}
            if all_attention_masks is not None:
                # Prompts are padded with the EOS id, so without an explicit
                # mask the padding cannot be told apart from real tokens.
                generate_inputs["attention_mask"] = torch.tensor(
                    all_attention_masks[start:stop]
                ).to(device)

            generated = model.generate(**generate_inputs, **generation_kwargs)
            # Drop the prompt columns so only newly generated tokens are scored.
            completions = generated[:, prompts.shape[1]:]

            # Iterate over the rows actually present; the last batch may be short.
            for completion in completions:
                rewards.append(float(reward_fn(completion)))

    if not rewards:
        return float("nan")
    return np.mean(rewards)

# Baseline before any PPO update: `rlhf_model` was initialised from the SFT
# checkpoint, so its reward here is the supervised fine-tuned reference.
# (`ppo_model` is not defined anywhere in this notebook.)
sft_avg_reward = evaluate(rlhf_model, rlhf_eval_dataset, reward_model)
print(f"Average Reward for supervised finetuned model: {sft_avg_reward}")

# Keep one sample completion from the SFT policy for a later side-by-side comparison.
with torch.no_grad():
    sample = rlhf_eval_dataset[0]
    # generate() works on (batch, seq_len) tensors, so add a batch axis of size 1.
    cur_data = torch.tensor(sample["input_ids"]).unsqueeze(0).to(device)
    cur_mask = torch.tensor(sample["attention_mask"]).unsqueeze(0).to(device)
    response = rlhf_model.generate(
        input_ids=cur_data, attention_mask=cur_mask, **generation_kwargs
    )
    # decode() takes a single sequence: use row 0, and for the response keep
    # only the tokens after the prompt. Special tokens (padding) are skipped.
    input_text = rlhf_tokenizer.decode(cur_data[0], skip_special_tokens=True)
    sft_response_text = rlhf_tokenizer.decode(
        response[0, cur_data.shape[1]:], skip_special_tokens=True
    )

In [None]:
# Show the (total, trainable) parameter counts of the policy.
param_counts = count_parameters(rlhf_model)
print(param_counts)

In [None]:
# PPO settings: each optimisation step consumes `ppo_training_batch_size`
# samples, processed one at a time with gradients accumulated across them.
ppo_settings = dict(
    model_name=rlhf_model_checkpoint,
    learning_rate=1e-5,
    batch_size=ppo_training_batch_size,
    mini_batch_size=1,
    gradient_accumulation_steps=ppo_training_batch_size,
    is_peft_model=True,
)
ppo_config = PPOConfig(**ppo_settings)
# NOTE(review): `ppo_trainer`, used by the training loop below, is not created
# in any visible cell — confirm where the PPOTrainer is instantiated.

In [None]:
# Train on the training split. The evaluation split is kept for measuring
# reward before and after PPO, so it must not be used for updates.
train_input_ids = rlhf_train_dataset["input_ids"]
train_attention_masks = rlhf_train_dataset["attention_mask"]

# Each step is built from exactly `ppo_training_batch_size` samples, so a
# trailing partial batch is skipped rather than indexed past its end.
num_full_batch_rows = len(train_input_ids) - len(train_input_ids) % ppo_training_batch_size

for epoch in tqdm(range(num_train_epochs_ppo_outer)):
    for start in range(0, num_full_batch_rows, ppo_training_batch_size):
        stop = start + ppo_training_batch_size

        # One 1-D tensor per prompt, with the padding removed via the
        # attention mask so pad (EOS) tokens are not treated as prompt text.
        # NOTE(review): assumes the trainer re-pads variable-length queries
        # itself — confirm against the installed TRL version.
        queries = [
            torch.tensor(ids, device=device)[torch.tensor(mask, device=device).bool()]
            for ids, mask in zip(
                train_input_ids[start:stop], train_attention_masks[start:stop]
            )
        ]

        responses = ppo_trainer.generate(
            queries, return_prompt=False, **generation_kwargs
        )
        rewards = [reward_fn(response) for response in responses]
        ppo_trainer.step(queries=queries, responses=responses, scores=rewards)

In [None]:
# Sample one completion from the PPO-tuned policy for the same prompt that was
# used for the SFT sample, then print both side by side.
rlhf_model.eval()  # the model may still be in training mode after the PPO updates
with torch.no_grad():
    sample = rlhf_eval_dataset[0]
    # generate() works on (batch, seq_len) tensors, so add a batch axis of size 1.
    cur_data = torch.tensor(sample["input_ids"]).unsqueeze(0).to(device)
    cur_mask = torch.tensor(sample["attention_mask"]).unsqueeze(0).to(device)
    response = rlhf_model.generate(
        input_ids=cur_data, attention_mask=cur_mask, **generation_kwargs
    )
    # decode() takes a single sequence: use row 0, and for the response keep
    # only the tokens after the prompt. Special tokens (padding) are skipped.
    input_text = rlhf_tokenizer.decode(cur_data[0], skip_special_tokens=True)
    rlhf_response_text = rlhf_tokenizer.decode(
        response[0, cur_data.shape[1]:], skip_special_tokens=True
    )

print(f"Input Text:\n{input_text}")
print(f"\nSFT model Response:\n{sft_response_text}")
print(f"\nRLHF model Response:\n{rlhf_response_text}")
    

In [None]:
# Score the PPO-trained policy on the evaluation prompts and report it next
# to the reward measured before PPO.
rlhf_avg_reward = evaluate(ppo_trainer.model, rlhf_eval_dataset, reward_model)
print(
    f"Average Reward for Supervised finetuned model: {sft_avg_reward}",
    f"Average Reward for RLHF finetuned model: {rlhf_avg_reward}",
    sep="\n",
)

In [None]:
# By this point the name `evaluate` refers to the reward-evaluation function
# defined earlier in this notebook, not the Hugging Face `evaluate` package,
# so the package is imported again under an unambiguous alias.
import evaluate as hf_evaluate

rouge_metric = hf_evaluate.load("rouge")


def compute_metrics(decoded_preds, decoded_actual_labels):
    """Compute, print and return ROUGE scores for generated summaries.

    Args:
        decoded_preds: Generated summaries as plain strings.
        decoded_actual_labels: Reference summaries as plain strings, aligned
            one-to-one with ``decoded_preds``.

    Returns:
        The result object produced by ``rouge_metric.compute``.
    """
    result = rouge_metric.compute(
        predictions=decoded_preds, references=decoded_actual_labels
    )
    print(f"RLHF Model ROUGE values: {result}")
    # Hand the scores back so callers can log or compare them.
    return result

In [None]:
def evaluate_sft_model(sft_model, sft_eval_dataset, tokenizer=None):
    """Generate summaries for an evaluation set and score them with ROUGE.

    Args:
        sft_model: Model exposing ``eval()`` and ``generate()``.
        sft_eval_dataset: Column-oriented dataset with an ``"input_ids"``
            column (equal-length prompt token ids) and a ``"labels"`` column
            (reference-summary token ids). An ``"attention_mask"`` column is
            forwarded to ``generate`` when present.
        tokenizer: Tokenizer used to decode predictions and references.
            Defaults to the notebook-level ``rlhf_tokenizer``, which is loaded
            from the SFT checkpoint.

    Returns:
        Whatever ``compute_metrics`` returns for the decoded predictions and
        references.
    """
    if tokenizer is None:
        tokenizer = rlhf_tokenizer

    sft_model.eval()

    # Read each column once instead of re-reading it on every batch.
    all_input_ids = sft_eval_dataset["input_ids"]
    all_labels = sft_eval_dataset["labels"]
    try:
        all_attention_masks = sft_eval_dataset["attention_mask"]
    except KeyError:
        all_attention_masks = None

    decoded_preds = []
    decoded_actual_labels = []

    with torch.no_grad():
        for start in tqdm(range(0, len(all_input_ids), eval_batch_size)):
            stop = start + eval_batch_size
            prompts = torch.tensor(all_input_ids[start:stop]).to(device)
            generate_inputs = {"input_ids": prompts}
            if all_attention_masks is not None:
                generate_inputs["attention_mask"] = torch.tensor(
                    all_attention_masks[start:stop]
                ).to(device)

            generated = sft_model.generate(**generate_inputs, **generation_kwargs)
            # Keep only the newly generated tokens (drop the prompt columns).
            completions = generated[:, prompts.shape[1]:]

            # Iterate over the rows actually present; the last batch may be short.
            for completion in completions:
                decoded_preds.append(
                    tokenizer.decode(completion, skip_special_tokens=True)
                )
            # Decode each reference row on its own, so the references do not
            # have to be stacked into a rectangular tensor first.
            for label_ids in all_labels[start:stop]:
                decoded_actual_labels.append(
                    tokenizer.decode(label_ids, skip_special_tokens=True)
                )

    return compute_metrics(
        decoded_preds=decoded_preds, decoded_actual_labels=decoded_actual_labels
    )





# evaluate_sft_model reads a "labels" column of reference-summary token ids,
# but rlhf_eval_dataset only carries "input_ids" and "attention_mask". Attach
# the human-preferred summary of each validation row as the reference.
# NOTE(review): assumes each row's "summaries" entry is a list of dicts with a
# "text" field and that "choice" indexes the preferred one — confirm against
# the dataset card.
_rouge_rows = dataset["validation"][2000:3000]
_reference_texts = [
    summaries[choice]["text"]
    for summaries, choice in zip(_rouge_rows["summaries"], _rouge_rows["choice"])
]
rouge_eval_dataset = Dataset.from_dict(
    {
        "input_ids": rlhf_eval_dataset["input_ids"],
        "attention_mask": rlhf_eval_dataset["attention_mask"],
        # padding=True pads every reference to the longest one, so the rows
        # all have the same length.
        "labels": rlhf_tokenizer(_reference_texts, padding=True)["input_ids"],
    }
)

evaluate_sft_model(rlhf_model, rouge_eval_dataset)

In [None]:
# Publish the PPO-trained policy and its tokenizer to the same Hub repository.
# Merging the adapter before upload is left disabled:
#rlhf_model.merge_and_unload()
rlhf_model.push_to_hub(rlhf_model_repo_name)
rlhf_tokenizer.push_to_hub(rlhf_model_repo_name)