In [None]:
 ! pip install -U bitsandbytes accelerate transformers datasets trl peft evaluate rouge_score

In [None]:
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    DistilBertTokenizer,
    TrainingArguments,
    pipeline,
)
import evaluate
from datasets import load_dataset, Dataset
from trl import (
    SFTTrainer,
    PPOTrainer,
    RewardTrainer,
    PPOConfig,
    RewardConfig,
    AutoModelForCausalLMWithValueHead,
)
from peft import LoraConfig, get_peft_model
from bitsandbytes.optim import AdamW8bit
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, Dataset as torchDataset
import numpy as np

# Hugging Face login

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# Read the Hub token from the HF_TOKEN environment variable and fall back to a
# blocking, non-echoing prompt. Never paste a token into notebook source: it
# ends up in version control and in every shared copy of the notebook.
# NOTE(review): a token was previously committed in plain text in this cell;
# it should be revoked on huggingface.co and replaced.
login(token=os.environ.get("HF_TOKEN") or getpass("Hugging Face token: "))

# Hyperparameters and configuration

In [None]:
# Human preference comparisons: each row holds a post and two candidate
# summaries plus the index of the one the labeler chose.
dataset = load_dataset("openai/summarize_from_feedback", "comparisons")

# --- Base checkpoints -------------------------------------------------------
base_sft_model_checkpoint = "meta-llama/Llama-3.1-8B"
base_reward_model_checkpoint = "google/gemma-2-2b"

# --- Hub repositories for the trained artifacts -----------------------------
sft_model_repo_name = "sft_model"
reward_model_repo_name = "reward_model"
# NOTE(review): the name mentions gpt2 while the models used here are
# Llama/Gemma; confirm this repo name is intended.
rlhf_model_repo_name = "ppo_gpt2_summary"

sft_model_checkpoint = f"JaishreeramCoder/{sft_model_repo_name}"
reward_model_checkpoint = f"JaishreeramCoder/{reward_model_repo_name}"
rlhf_model_checkpoint = f"JaishreeramCoder/{rlhf_model_repo_name}"

# Local directory for trainer checkpoints.
output_dir = "/content/sample_data"

# --- Training schedule -------------------------------------------------------
num_train_epochs_sft = 5
num_train_epochs_reward_model = 5
num_train_epochs_ppo_outer = 5
ppo_training_batch_size = 8
eval_batch_size = 8

# Supervised fine-tuned (SFT) model

In [None]:
# Tokenizer for the SFT base model. Padding reuses the EOS token (the
# generation settings later in the notebook likewise pass EOS as pad_token_id).
sft_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=base_sft_model_checkpoint
)
sft_tokenizer.pad_token = sft_tokenizer.eos_token

In [None]:
def _sft_source_text(info):
    """Return the text to be summarized for one comparison record.

    Uses ``info["post"]``. When that is None, falls back to ``info["article"]``
    (NOTE(review): presumably the non-Reddit records of this dataset store
    their text there instead of in ``post`` -- confirm), or "" if neither is set.
    """
    post = info.get("post")
    if post is not None:
        return post
    return info.get("article") or ""


def get_sft_dataset(data, max_length=512):
    """Build a tokenized SFT dataset from summarize_from_feedback comparisons.

    Each row becomes a left-padded prompt ("Summarize the following text: ...")
    paired with the tokens of the summary the labeler preferred.

    Args:
        data: Column-oriented batch (dict of lists), e.g.
            ``dataset["train"][1000:2000]``, with ``"info"``, ``"summaries"``
            and ``"choice"`` columns.
        max_length: Length every prompt and summary is padded/truncated to.
            Defaults to 512, the previously hard-coded value.

    Returns:
        A ``datasets.Dataset`` with ``input_ids``, ``attention_mask`` and
        ``labels`` columns, each row a list of ``max_length`` ints.
    """
    tokenize_kwargs = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_length,
        "padding_side": "left",
    }
    input_ids, attention_mask, label_ids = [], [], []
    for info, summaries, choice in zip(data["info"], data["summaries"], data["choice"]):
        prompt = f"Summarize the following text:\n\n{_sft_source_text(info)}"
        prompt_enc = sft_tokenizer(prompt, **tokenize_kwargs)

        # `choice` is the index of the preferred summary among the two candidates.
        completion = summaries[1]["text"] if choice == 1 else summaries[0]["text"]
        # NOTE(review): `labels` is the summary tokenized on its own, so label
        # position k is not the continuation of input position k, as a
        # causal-LM loss expects, and pad positions are not masked with -100.
        # Depending on the TRL version, SFTTrainer's collator may also rebuild
        # labels from input_ids and ignore this column. The usual SFT layout is
        # prompt + summary in one sequence with the prompt masked. It is kept
        # as-is because evaluate_sft_model decodes this column as the reference.
        completion_ids = sft_tokenizer(completion, **tokenize_kwargs).input_ids

        input_ids.append(prompt_enc.input_ids)
        attention_mask.append(prompt_enc.attention_mask)
        label_ids.append(completion_ids)

    return Dataset.from_dict(
        {
            "input_ids": input_ids,
            # "attention_mask" (singular) is the keyword model.forward accepts.
            # A column named "attention_masks" is not a forward argument, so it
            # never reaches the model as a padding mask.
            "attention_mask": attention_mask,
            "labels": label_ids,
        }
    )

In [None]:
# The same 1,000-row window (rows 1000..1999) of each split keeps SFT runs short.
sft_train_dataset, sft_eval_dataset = (
    get_sft_dataset(dataset[split_name][1000:2000])
    for split_name in ("train", "validation")
)

In [None]:
# Load the 8B SFT base model with 4-bit NF4 weights (no double quantization)
# and float16 compute, to reduce weight memory for LoRA fine-tuning.
compute_dtype = torch.float16

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=compute_dtype,
)

sft_model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=base_sft_model_checkpoint,
    quantization_config=quantization_config,
)

In [None]:
# LoRA adapters for causal-LM fine-tuning: rank 64, scaling lora_alpha / r
# (16 / 64), 10% LoRA dropout, bias terms left untrained. No target_modules
# are given, so PEFT picks its default target modules for this architecture.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)
sft_model = get_peft_model(model=sft_model, peft_config=lora_config)

In [None]:
# `count_parameters` is never defined or imported in this notebook, so this
# cell raised NameError on a fresh kernel. A PEFT-wrapped model reports its
# trainable vs. total parameter counts directly.
sft_model.print_trainable_parameters()

In [None]:
# LoRA SFT run configuration. Per-device batch 1 with 8 gradient-accumulation
# steps gives an effective batch of 8 sequences per device per optimizer step.
#
# A custom optimizer is handed to the trainer below. The Trainer then ignores
# `optim`, and the optimizer's own lr drives the schedule. The optimizer is
# therefore built from `sft_training_args.learning_rate`, so this is the only
# place the SFT learning rate is set. 2e-5 is the value that was already in
# effect through the optimizer.
sft_training_args = TrainingArguments(
    output_dir=output_dir,
    seed=42,
    num_train_epochs=num_train_epochs_sft,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    max_grad_norm=0.3,
    fp16=True,
    group_by_length=True,
    logging_steps=1,
    # `eval_strategy` replaces `evaluation_strategy`, which current transformers
    # no longer accepts (the notebook installs with `-U`). `eval_steps` is not
    # set: it only applies when eval_strategy="steps".
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    load_best_model_at_end=True,
    save_safetensors=True,
    # NOTE(review): without hub_model_id, the Hub repo name is derived from
    # output_dir ("sample_data"); set hub_model_id if checkpoints belong elsewhere.
    push_to_hub=True,
)

# After get_peft_model, only the LoRA adapter weights require gradients.
param_to_update = [param for param in sft_model.parameters() if param.requires_grad]
optimizers = AdamW8bit(param_to_update, lr=sft_training_args.learning_rate)

model_trainer = SFTTrainer(
    model=sft_model,
    # `processing_class` replaces the `tokenizer` keyword dropped from recent TRL.
    processing_class=sft_tokenizer,
    train_dataset=sft_train_dataset,
    eval_dataset=sft_eval_dataset,
    args=sft_training_args,
    # (optimizer, lr_scheduler); None lets the Trainer build the cosine scheduler.
    optimizers=(optimizers, None),
)

In [None]:
# Run LoRA SFT with the trainer configured in the previous cell. Because
# load_best_model_at_end=True, the best checkpoint (by eval loss unless
# metric_for_best_model is set) is reloaded when training finishes.
model_trainer.train()

In [None]:
rouge_metric = evaluate.load("rouge")


def compute_metrics(decoded_preds, decoded_actual_labels):
    """Score generated summaries against reference summaries with ROUGE.

    Args:
        decoded_preds: Generated summary strings.
        decoded_actual_labels: Reference summary strings, aligned index-for-index
            with ``decoded_preds``.

    Returns:
        The score dict from ``rouge_metric.compute``. It is also printed. It was
        previously only printed, so callers always received None.
    """
    result = rouge_metric.compute(
        predictions=decoded_preds, references=decoded_actual_labels
    )
    print(f"SFT Model ROUGE values: {result}")
    return result

In [None]:
# Decoding settings for summary generation: pure sampling from the full
# distribution (no top-k, no nucleus truncation), at most 32 new tokens.
# NOTE(review): with do_sample=True, ROUGE scores computed from these
# generations vary from run to run unless the RNG is seeded beforehand.
generation_kwargs = {
    "min_length": -1,  # don't ignore the EOS token
    "top_k": 0,  # no top-k sampling (an int; 0 disables the filter)
    "top_p": 1.0,  # no nucleus sampling
    "do_sample": True,  # yes, we want to sample
    "eos_token_id": sft_tokenizer.eos_token_id,
    "bos_token_id": sft_tokenizer.bos_token_id,
    # The tokenizer pads with EOS, so EOS doubles as the pad id here.
    "pad_token_id": sft_tokenizer.eos_token_id,
    "max_new_tokens": 32,  # upper bound on generated summary length
}

In [None]:
def evaluate_sft_model(sft_model, sft_eval_dataset, batch_size=None):
    """Generate summaries for the eval prompts and score them with ROUGE.

    Args:
        sft_model: Causal LM (PEFT-wrapped or merged) exposing ``generate`` and
            ``device``.
        sft_eval_dataset: Dataset (or dict of lists) with ``"input_ids"``
            (left-padded prompt token ids) and ``"labels"`` (reference summary
            token ids) columns.
        batch_size: Generation batch size; defaults to the notebook-level
            ``eval_batch_size``.

    Returns:
        The value of ``compute_metrics`` for the decoded generations and
        references.
    """
    if batch_size is None:
        batch_size = eval_batch_size

    # Fetch each column once rather than re-indexing the dataset per batch
    # (older `datasets` versions rebuild the whole column on every access).
    all_input_ids = sft_eval_dataset["input_ids"]
    all_label_ids = sft_eval_dataset["labels"]
    device = sft_model.device
    pad_id = sft_tokenizer.pad_token_id

    decoded_preds, decoded_actual_labels = [], []
    sft_model.eval()
    with torch.no_grad():
        for start in tqdm(range(0, len(all_input_ids), batch_size)):
            stop = start + batch_size
            # Inputs must live on the model's device or generate() fails with
            # a device-mismatch error.
            prompt_ids = torch.tensor(all_input_ids[start:stop]).to(device)
            # Prompts are left-padded with the pad token (== EOS); mask those
            # positions so generation does not attend to padding.
            attention_mask = (prompt_ids != pad_id).long()
            generated = sft_model.generate(
                input_ids=prompt_ids, attention_mask=attention_mask, **generation_kwargs
            )
            # Keep only the tokens produced after the prompt.
            new_tokens = generated[:, prompt_ids.shape[1]:]

            # batch_decode iterates over the rows actually present, so a short
            # final batch no longer raises IndexError.
            decoded_preds.extend(
                sft_tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
            )
            decoded_actual_labels.extend(
                sft_tokenizer.batch_decode(
                    all_label_ids[start:stop], skip_special_tokens=True
                )
            )

    return compute_metrics(
        decoded_preds=decoded_preds, decoded_actual_labels=decoded_actual_labels
    )


sft_eval_rouge = evaluate_sft_model(sft_model, sft_eval_dataset)

# Push to the Hugging Face Hub

In [None]:
# Fold the LoRA adapter weights into the base model so the Hub repo holds a
# standalone model rather than an adapter, then upload the model and its
# tokenizer to the same repository.
# NOTE(review): the base model was loaded in 4-bit, so this merges into
# quantized weights (PEFT may warn about rounding). Merging into a
# full-precision reload of the base checkpoint is the usual alternative;
# confirm intent.
sft_model = sft_model.merge_and_unload()
sft_model.push_to_hub(repo_id=sft_model_repo_name)
sft_tokenizer.push_to_hub(repo_id=sft_model_repo_name)