Skip to content

Commit

Permalink
Fix DPO prompts (#142)
Browse files Browse the repository at this point in the history
  • Loading branch information
natolambert committed Jun 12, 2024
1 parent 6414f35 commit 5bd6a4a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 15 deletions.
21 changes: 7 additions & 14 deletions rewardbench/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def tokenize_row(self, feature) -> Dict:
chosen = feature["text_chosen"] # modified from source
rejected = feature["text_rejected"] # modified from source

# remove prompt tokens from chosen + rejected, otherwise they repeat
# see issue 140 https://github.com/allenai/reward-bench/issues/140
# should impact results only slightly thanks to per token dpo reward math!
chosen = chosen.replace(prompt, "")
rejected = rejected.replace(prompt, "")

if not self.is_encoder_decoder:
# Check issues below for more details
# 1. https://github.com/huggingface/trl/issues/907
Expand Down Expand Up @@ -140,20 +146,7 @@ def tokenize_row(self, feature) -> Dict:
batch[f"{k}{type_key}"] = tokens

else:
chosen_tokens = self.tokenizer(
chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
)
rejected_tokens = self.tokenizer(
rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
)
prompt_tokens = self.tokenizer(
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
)

batch["chosen_labels"] = chosen_tokens["input_ids"]
batch["rejected_labels"] = rejected_tokens["input_ids"]
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
raise ValueError("Encoder-decoder models are not supported yet.")

return batch

Expand Down
3 changes: 2 additions & 1 deletion scripts/run_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_args():
parser.add_argument(
"--trust_remote_code", action="store_true", default=False, help="directly load model instead of pipeline"
)
parser.add_argument("--debug", type=bool, default=False, help="use only 10 examples")
parser.add_argument("--debug", action="store_true", default=False, help="use only 10 examples")
parser.add_argument(
"--disable_beaker_save", action="store_true", help="disable saving the main results in a file for AI2 Beaker"
)
Expand Down Expand Up @@ -182,6 +182,7 @@ def main():
column_names = list(dataset.features)

tokenized_dataset = dataset.map(dpo.tokenize_row, remove_columns=column_names)

dataloader = torch.utils.data.DataLoader(
tokenized_dataset,
batch_size=BATCH_SIZE,
Expand Down

0 comments on commit 5bd6a4a

Please sign in to comment.