In [1]:
import datasets
from datasets import load_dataset

# we want to use the pytorch dataloader
from torch.utils.data import DataLoader

import torch
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset_name = "Anthropic/hh-rlhf"
dataset = load_dataset(dataset_name)

Found cached dataset json (/home/codespace/.cache/huggingface/datasets/Anthropic___json/Anthropic--hh-rlhf-1cfce6f8a62eee84/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|██████████| 2/2 [00:04<00:00,  2.44s/it]


In [3]:
model_name = "distilgpt2"

# currently we choose the gpt2 tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer.pad_token = tokenizer.eos_token

In [4]:
# now we want to transform the dataset pytorch dataloader
# we have to create a custom collate function
max_length = 512

def compute_input_ids(list_sample_tokenized):
    max_length_chosen = max([len(sample["input_ids"]) for sample in list_sample_tokenized])
    tokenized_pad = [tokenizer.pad(sample, max_length=max_length_chosen, padding="max_length") for sample in list_sample_tokenized]

    tokenized_pad_id = [sample["input_ids"] for sample in tokenized_pad]
    tokenized_pad_id = torch.tensor(tokenized_pad_id)

    chosen_attention_masks = [sample["attention_mask"] for sample in tokenized_pad]
    chosen_attention_masks = torch.tensor(chosen_attention_masks)

    return tokenized_pad_id, chosen_attention_masks

def collate_fn(list_of_samples):
    """
    In this function we define how we want to collate (combine) samples from the dataset
    The dataset return a dict with the following keys:
        - "chosen" : the chosen text
        - "rejected" : the rejected text
    """
    # we tokenize the chosen and rejected text for every sample
    chosen_tokenized = [tokenizer(sample["chosen"], padding=False, truncation=False) for sample in list_of_samples]
    rejected_tokenized = [tokenizer(sample["rejected"], padding=False, truncation=False) for sample in list_of_samples]

    # we compute the input_ids and attention_masks for the chosen text
    chosen_input_ids, chosen_attention_masks = compute_input_ids(chosen_tokenized)

    # we compute the input_ids and attention_masks for the rejected text
    rejected_input_ids, rejected_attention_masks = compute_input_ids(rejected_tokenized)

    # we create a new dict with the input_ids
    chosen_input_ids = {
        "input_ids": chosen_input_ids,
        "attention_mask": chosen_attention_masks
    }

    rejected_input_ids = {
        "input_ids": rejected_input_ids,
        "attention_mask": rejected_attention_masks
    }

    # we create a new dict with the tokenized text
    tokenized = {
        "chosen": chosen_input_ids,
        "rejected": rejected_input_ids
    }

    # we return the tokenized text
    return tokenized

# we create the dataloader
dataloader = DataLoader(dataset["train"], batch_size=4, collate_fn=collate_fn)

# we can now iterate over the dataloader
for batch in dataloader:
    print(batch)
    break


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'chosen': {'input_ids': tensor([[  198,   198, 20490,    25,  1867,   389,   617,   269,  1046,  2456,
           287, 46932,    30,   198,   198, 48902,    25,  3423,   447,   247,
            82,   281, 17503,  1351,    13,   198,   198,  8021,    11, 19317,
            11,   809, 26679,    11, 18824,    11,  5089,    11,  7510,    11,
         21551,    11,   256,  2799,    11,  7510,  2256,    11,  7510, 21454,
            11,   629, 10599,   388,    11, 40267,    11, 40107,    11,  5089,
           263,    11,  7510,    12, 30041,    11, 10973,    11,   269,  2178,
         38811,    11,  5089,    77,  1018,  1136,    11,   475,   400,  2305,
            11, 40125,    11, 14509,   562,    11,   269,  3320, 12603,    11,
         29836,    11, 43546,    11, 18314,    11, 19311,    11,  6611,    11,
           266,   962,    11,   474,  1042,    11, 10973,    12,    82, 19296,
            11, 22938,   378,    11,   277,  9460,   313,    11, 24506,    11,
           474,  6457,    1

In [5]:
# now we can use the training loop from the transformers library
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# we load the model
model = AutoModelForCausalLM.from_pretrained(model_name)

# we define the training arguments
training_args = TrainingArguments(
    output_dir="./hh-rlhf",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
    logging_dir="./hh-rlhf/logs",
    dataloader_num_workers=4,
    run_name="hh-rlhf"
)

from torch.nn import functional as F

# we need to define a custom Trainer
class DPOTrainer(Trainer):
    """
    Class to train the model with DPO (Direct Preference Optimization)
    """
    # we need to define the compute_loss function
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        This function is called by the Trainer during training
        This is where we compute the (DPO) loss
        """
        # we get the chosen and rejected text
        chosen_text = inputs["chosen"]
        rejected_text = inputs["rejected"]

        print("before model")

        print(chosen_text['input_ids'].shape)
        print(rejected_text['input_ids'].shape)

        logits_chosen = model(**chosen_text).logits
        logits_rejected = model(**rejected_text).logits

        print("after model")

        # try with logits as ref first

        pos_logprob = F.log_softmax(logits_chosen, dim=-1)
        neg_logprob = F.log_softmax(logits_rejected, dim=-1)

        pos_logprob = torch.gather(pos_logprob, 2, chosen_text.unsqueeze(-1))
        neg_logprob = torch.gather(neg_logprob, 2, rejected_text.unsqueeze(-1))

        # we need to compute the logprob of the reference examples
        pos_logits_ref = logits_chosen.detach()
        neg_logits_ref = logits_rejected.detach()

        pos_logprob_ref = F.log_softmax(pos_logits_ref, dim=-1)
        neg_logprob_ref = F.log_softmax(neg_logits_ref, dim=-1)

        pos_logprob_ref = torch.gather(
            pos_logprob_ref, 2, chosen_text.unsqueeze(-1)
        )
        neg_logprob_ref = torch.gather(
            neg_logprob_ref, 2, rejected_text.unsqueeze(-1)
        )

        # compute loss and reward
        pi_logratios = pos_logprob.mean() - neg_logprob.mean()
        ref_logratios = pos_logprob_ref.mean() - neg_logprob_ref.mean()

        print("yolo")
        print(pi_logratios.shape)
        print(ref_logratios.shape)

        loss = -F.logsigmoid(self.beta * (pi_logratios - ref_logratios))

        return loss

dpo_trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    tokenizer=tokenizer,
)

# try the compute loss function on batch
for batch in dataloader:
    with torch.no_grad():
        loss = dpo_trainer.compute_loss(model, batch)
    print(loss)
    break


before model
torch.Size([4, 202])
torch.Size([4, 196])
