In [1]:
import datasets
from datasets import load_dataset

# we want to use the pytorch dataloader
from torch.utils.data import DataLoader

import torch
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset_name = "Anthropic/hh-rlhf"
dataset = load_dataset(dataset_name)

Found cached dataset json (/home/codespace/.cache/huggingface/datasets/Anthropic___json/Anthropic--hh-rlhf-1cfce6f8a62eee84/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|██████████| 2/2 [00:04<00:00,  2.23s/it]


In [3]:
model_name = "distilgpt2"

# currently we choose the gpt2 tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer.pad_token = tokenizer.eos_token

In [7]:
# now we want to transform the dataset pytorch dataloader
# we have to create a custom collate function
max_length = 512

def compute_input_ids(list_sample_tokenized):
    max_length_chosen = max([len(sample["input_ids"]) for sample in list_sample_tokenized])

    # we limit the max_length to 512
    max_length_chosen = min(max_length_chosen, max_length)

    # we take only the last max_length_chosen tokens of the text
    def filter_last_tokens(sample):
        return {"input_ids" : sample["input_ids"][-max_length_chosen:], "attention_mask" : sample["attention_mask"][-max_length_chosen:]}
    
    print(list_sample_tokenized[0])

    list_sample_tokenized = [filter_last_tokens(sample) for sample in list_sample_tokenized]

    tokenized_pad = [tokenizer.pad(sample, max_length=max_length_chosen, padding="max_length") for sample in list_sample_tokenized]

    tokenized_pad_id = [sample["input_ids"] for sample in tokenized_pad]
    tokenized_pad_id = torch.tensor(tokenized_pad_id)

    chosen_attention_masks = [sample["attention_mask"] for sample in tokenized_pad]
    chosen_attention_masks = torch.tensor(chosen_attention_masks)

    return tokenized_pad_id, chosen_attention_masks

def collate_fn(list_of_samples):
    """
    In this function we define how we want to collate (combine) samples from the dataset
    The dataset return a dict with the following keys:
        - "chosen" : the chosen text
        - "rejected" : the rejected text
    """
    # we tokenize the chosen and rejected text for every sample
    chosen_tokenized = [tokenizer(sample["chosen"], padding=False, truncation=False) for sample in list_of_samples]
    rejected_tokenized = [tokenizer(sample["rejected"], padding=False, truncation=False) for sample in list_of_samples]

    # we compute the input_ids and attention_masks for the chosen text
    chosen_input_ids, chosen_attention_masks = compute_input_ids(chosen_tokenized)

    # we compute the input_ids and attention_masks for the rejected text
    rejected_input_ids, rejected_attention_masks = compute_input_ids(rejected_tokenized)

    # we create a new dict with the input_ids
    chosen_input_ids = {
        "input_ids": chosen_input_ids,
        "attention_mask": chosen_attention_masks
    }

    rejected_input_ids = {
        "input_ids": rejected_input_ids,
        "attention_mask": rejected_attention_masks
    }

    # we create a new dict with the tokenized text
    tokenized = {
        "chosen": chosen_input_ids,
        "rejected": rejected_input_ids
    }

    # we return the tokenized text
    return tokenized

# we create the dataloader
dataloader = DataLoader(dataset["train"], batch_size=4, collate_fn=collate_fn)

# we can now iterate over the dataloader
for batch in dataloader:
    print(batch)
    break


{'input_ids': [198, 198, 20490, 25, 1867, 389, 617, 269, 1046, 2456, 287, 46932, 30, 198, 198, 48902, 25, 3423, 447, 247, 82, 281, 17503, 1351, 13, 198, 198, 8021, 11, 19317, 11, 809, 26679, 11, 18824, 11, 5089, 11, 7510, 11, 21551, 11, 256, 2799, 11, 7510, 2256, 11, 7510, 21454, 11, 629, 10599, 388, 11, 40267, 11, 40107, 11, 5089, 263, 11, 7510, 12, 30041, 11, 10973, 11, 269, 2178, 38811, 11, 5089, 77, 1018, 1136, 11, 475, 400, 2305, 11, 40125, 11, 14509, 562, 11, 269, 3320, 12603, 11, 29836, 11, 43546, 11, 18314, 11, 19311, 11, 6611, 11, 266, 962, 11, 474, 1042, 11, 10973, 12, 82, 19296, 11, 22938, 378, 11, 277, 9460, 313, 11, 24506, 11, 474, 6457, 11, 474, 6457, 12, 75, 7958, 11, 37833, 11, 33526, 11, 1125, 729, 11, 329, 6988, 1352, 11, 781, 2238, 7357, 11, 9583, 1891, 11, 10816, 11, 16949, 11, 32581, 296, 578, 11, 3095, 1136, 11, 285, 1689, 447, 247, 82, 2933, 11, 277, 9460, 313, 11, 583, 1851, 11, 24506, 11, 629, 2178, 363, 11, 21551, 11, 198, 198, 20490, 25, 1867, 338, 534, 4004,

In [8]:
# now we can use the training loop from the transformers library
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# we load the model
model = AutoModelForCausalLM.from_pretrained(model_name)

# we define the training arguments
training_args = TrainingArguments(
    output_dir="./hh-rlhf",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
    logging_dir="./hh-rlhf/logs",
    dataloader_num_workers=4,
    run_name="hh-rlhf"
)

from torch.nn import functional as F

# we need to define a custom Trainer
class DPOTrainer(Trainer):
    """
    Class to train the model with DPO (Direct Preference Optimization)
    """
    beta = 0.001
    # we need to define the compute_loss function
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        This function is called by the Trainer during training
        This is where we compute the (DPO) loss
        """
        # we get the chosen and rejected text
        chosen_text = inputs["chosen"]
        rejected_text = inputs["rejected"]

        print(chosen_text)

        logits_chosen = model(**chosen_text).logits
        logits_rejected = model(**rejected_text).logits

        # try with logits as ref first
        pos_logprob = F.log_softmax(logits_chosen, dim=-1)
        neg_logprob = F.log_softmax(logits_rejected, dim=-1)


        pos_logprob = torch.gather(pos_logprob, 2, chosen_text["input_ids"].unsqueeze(-1))
        neg_logprob = torch.gather(neg_logprob, 2, rejected_text["input_ids"].unsqueeze(-1))

        # we need to compute the logprob of the reference examples
        pos_logits_ref = logits_chosen.detach()
        neg_logits_ref = logits_rejected.detach()

        pos_logprob_ref = F.log_softmax(pos_logits_ref, dim=-1)
        neg_logprob_ref = F.log_softmax(neg_logits_ref, dim=-1)

        pos_logprob_ref = torch.gather(
            pos_logprob_ref, 2, chosen_text["input_ids"].unsqueeze(-1)
        )
        neg_logprob_ref = torch.gather(
            neg_logprob_ref, 2, rejected_text["input_ids"].unsqueeze(-1)
        )

        # compute loss and reward
        pi_logratios = pos_logprob.mean() - neg_logprob.mean()
        ref_logratios = pos_logprob_ref.mean() - neg_logprob_ref.mean()

        loss = -F.logsigmoid(self.beta * (pi_logratios - ref_logratios))

        return loss
    
    def get_train_dataloader(self) -> DataLoader:
        collate_fn = self.data_collator
        train_dataset = self.train_dataset

        return DataLoader(
            train_dataset,
            batch_size=self.args.train_batch_size,
            collate_fn=collate_fn,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            shuffle=True,
        )

dpo_trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=collate_fn
)

# we can now train the model
dpo_trainer.train()




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
{'input_ids': [198, 198, 20490, 25, 314, 423, 257, 10211, 287, 9592, 13, 220, 1867, 318, 262, 1266, 12840, 284, 779, 284, 4929, 340, 30, 198, 198, 4890

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss


{'input_ids': [198, 198, 20490, 25, 7731, 2035, 286, 345, 651, 10785, 938, 1755, 11, 968, 6280, 338, 28001, 30, 198, 198, 48902, 25, 314, 447, 247, 76, 7787, 326, 1692, 3842, 1422, 447, 247, 83, 8749, 2035, 286, 514, 284, 15000, 1997, 26016, 13, 198, 198, 20490, 25, 14690, 347, 11, 644, 750, 345, 2342, 30, 198, 198, 48902, 25, 18689, 11, 1692, 25, 314, 373, 8179, 1804, 1223, 11, 290, 314, 635, 2492, 470, 4964, 5581, 13, 220, 5524, 25, 1406, 644, 466, 345, 588, 284, 466, 319, 17122, 30, 198, 198, 48902, 25, 220, 314, 2883, 17499, 17122, 351, 616, 1641, 13, 220, 775, 477, 651, 1978, 11, 423, 257, 1256, 286, 1257, 11, 290, 3360, 198, 198, 20490, 25, 11161, 345, 1775, 340, 878, 30, 198, 198, 48902, 25, 3363, 11, 314, 7342, 340, 2961, 13, 220, 314, 1053, 1239, 7342, 340, 757, 706, 326, 13, 220, 554, 1109, 11, 611, 314, 10014, 9380, 11, 262, 3807, 373, 2089, 13, 198, 198, 20490, 25, 314, 12546, 11, 340, 373, 262, 1266, 286, 1115, 13, 632, 4966, 1165, 890, 996, 13, 198, 198, 48902, 25, 1867, 

: 

: 