In [2]:
import datasets
from datasets import load_dataset

# we want to use the pytorch dataloader
from torch.utils.data import DataLoader

import torch
from transformers import AutoTokenizer

In [3]:
dataset_name = "Anthropic/hh-rlhf"
dataset = load_dataset(dataset_name)

Found cached dataset json (/home/codespace/.cache/huggingface/datasets/Anthropic___json/Anthropic--hh-rlhf-1cfce6f8a62eee84/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|██████████| 2/2 [00:04<00:00,  2.31s/it]


In [4]:
model_name = "distilgpt2"

# currently we choose the gpt2 tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer.pad_token = tokenizer.eos_token

In [5]:
# now we want to transform the dataset pytorch dataloader
# we have to create a custom collate function
max_length = 512

def compute_input_ids(list_sample_tokenized):
    max_length_chosen = max([len(sample["input_ids"]) for sample in list_sample_tokenized])

    # we limit the max_length to 512
    max_length_chosen = min(max_length_chosen, max_length)

    # we take only the last max_length_chosen tokens of the text
    def filter_last_tokens(sample):
        return {"input_ids" : sample["input_ids"][-max_length_chosen:], "attention_mask" : sample["attention_mask"][-max_length_chosen:]}


    list_sample_tokenized = [filter_last_tokens(sample) for sample in list_sample_tokenized]

    tokenized_pad = [tokenizer.pad(sample, max_length=max_length_chosen, padding="max_length") for sample in list_sample_tokenized]

    tokenized_pad_id = [sample["input_ids"] for sample in tokenized_pad]
    tokenized_pad_id = torch.tensor(tokenized_pad_id)

    chosen_attention_masks = [sample["attention_mask"] for sample in tokenized_pad]
    chosen_attention_masks = torch.tensor(chosen_attention_masks)

    return tokenized_pad_id, chosen_attention_masks

def collate_fn(list_of_samples):
    """
    In this function we define how we want to collate (combine) samples from the dataset
    The dataset return a dict with the following keys:
        - "chosen" : the chosen text
        - "rejected" : the rejected text
    """
    # we tokenize the chosen and rejected text for every sample
    chosen_tokenized = [tokenizer(sample["chosen"], padding=False, truncation=False) for sample in list_of_samples]
    rejected_tokenized = [tokenizer(sample["rejected"], padding=False, truncation=False) for sample in list_of_samples]

    # we compute the input_ids and attention_masks for the chosen text
    chosen_input_ids, chosen_attention_masks = compute_input_ids(chosen_tokenized)

    # we compute the input_ids and attention_masks for the rejected text
    rejected_input_ids, rejected_attention_masks = compute_input_ids(rejected_tokenized)

    # we create a new dict with the input_ids
    chosen_input_ids = {
        "input_ids": chosen_input_ids,
        "attention_mask": chosen_attention_masks
    }

    rejected_input_ids = {
        "input_ids": rejected_input_ids,
        "attention_mask": rejected_attention_masks
    }

    # we create a new dict with the tokenized text
    tokenized = {
        "chosen": chosen_input_ids,
        "rejected": rejected_input_ids
    }

    # we return the tokenized text
    return tokenized

# we create the dataloader
dataloader = DataLoader(dataset["train"], batch_size=4, collate_fn=collate_fn)

# we can now iterate over the dataloader
for batch in dataloader:
    print(batch)
    break


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'chosen': {'input_ids': tensor([[  198,   198, 20490,    25,  1867,   389,   617,   269,  1046,  2456,
           287, 46932,    30,   198,   198, 48902,    25,  3423,   447,   247,
            82,   281, 17503,  1351,    13,   198,   198,  8021,    11, 19317,
            11,   809, 26679,    11, 18824,    11,  5089,    11,  7510,    11,
         21551,    11,   256,  2799,    11,  7510,  2256,    11,  7510, 21454,
            11,   629, 10599,   388,    11, 40267,    11, 40107,    11,  5089,
           263,    11,  7510,    12, 30041,    11, 10973,    11,   269,  2178,
         38811,    11,  5089,    77,  1018,  1136,    11,   475,   400,  2305,
            11, 40125,    11, 14509,   562,    11,   269,  3320, 12603,    11,
         29836,    11, 43546,    11, 18314,    11, 19311,    11,  6611,    11,
           266,   962,    11,   474,  1042,    11, 10973,    12,    82, 19296,
            11, 22938,   378,    11,   277,  9460,   313,    11, 24506,    11,
           474,  6457,    1

In [6]:
# now we can use the training loop from the transformers library
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# we load the model
model = AutoModelForCausalLM.from_pretrained(model_name)

model_ref = AutoModelForCausalLM.from_pretrained(model_name)


from torch.nn import functional as F
import copy

# we need to define a custom Trainer
class DPOTrainer(Trainer):
    """
    Class to train the model with DPO (Direct Preference Optimization)
    """
    beta = 0.5

    def __init__(self, **kwargs):
        self.model_ref = kwargs.pop("model_ref")
        super().__init__(**kwargs)


        self.model_ref = model_ref
        self.model_ref.eval()

        self.device_ref = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
        self.model_ref.to(self.device_ref)



    # we need to define the compute_loss function
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        This function is called by the Trainer during training
        This is where we compute the (DPO) loss
        """
        """
        This function is called by the Trainer during training
        This is where we compute the (DPO) loss
        """
        # we get the chosen and rejected text
        chosen_text = inputs["chosen"]
        rejected_text = inputs["rejected"]

        # labels chosen is just the chosen text shifted by one
        labels_chosen = torch.zeros_like(chosen_text["input_ids"]).long()
        labels_chosen[:, :-1] = chosen_text["input_ids"][:, 1:]
        
        # labels rejected is just the rejected text shifted by one
        labels_rejected = torch.zeros_like(rejected_text["input_ids"]).long()
        labels_rejected[:, :-1] = rejected_text["input_ids"][:, 1:]

        logits_chosen = model(**chosen_text).logits
        logits_rejected = model(**rejected_text).logits

        # try with logits as ref first
        pos_logprob = F.log_softmax(logits_chosen, dim=-1)
        neg_logprob = F.log_softmax(logits_rejected, dim=-1)


        pos_logprob = torch.gather(pos_logprob, 2, labels_chosen.unsqueeze(-1))
        neg_logprob = torch.gather(neg_logprob, 2, labels_rejected.unsqueeze(-1))

        # we need to compute the logprob of the reference examples
        with torch.no_grad():
          pos_logits_ref = self.model_ref(**chosen_text).logits
          neg_logits_ref = self.model_ref(**rejected_text).logits

        pos_logprob_ref = F.log_softmax(pos_logits_ref, dim=-1)
        neg_logprob_ref = F.log_softmax(neg_logits_ref, dim=-1)

        pos_logprob_ref = torch.gather(
            pos_logprob_ref, 2, labels_chosen.unsqueeze(-1)
        )
        neg_logprob_ref = torch.gather(
            neg_logprob_ref, 2, labels_rejected.unsqueeze(-1)
        )

        # TODO COMPUTE MASK FOR PADDING VALUES


        # compute loss and reward
        pi_logratios = pos_logprob.sum(-1) - neg_logprob.sum(-1)
        ref_logratios = pos_logprob_ref.sum(-1) - neg_logprob_ref.sum(-1)
        ref_logratios = 0

        loss = -F.logsigmoid(self.beta * (pi_logratios - ref_logratios.detach()))
        loss = loss.mean()

        return loss
    
    def get_train_dataloader(self) -> DataLoader:
        collate_fn = self.data_collator
        train_dataset = self.train_dataset

        return DataLoader(
            train_dataset,
            batch_size=self.args.train_batch_size,
            collate_fn=collate_fn,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            shuffle=True,
        )

from transformers import TrainingArguments

# we define the training arguments
training_args = TrainingArguments(
    output_dir="./hh-rlhf",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
    logging_dir="./hh-rlhf/logs_4",
    dataloader_num_workers=4,
    run_name="hh-rlhf_3",
    logging_steps=100,
    #bf16=True,
)

#training_args.set_logging(strategy="steps", steps=100, report_to="tensorboard")

dpo_trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=collate_fn,
    model_ref=model_ref.eval(),
)

# we can now train the model
dpo_trainer.train()



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss


: 

: 