Implement KTO into OpenRLHF #201

Merged
merged 3 commits into from
Jan 27, 2024

Dylancer1998
Contributor

This branch integrates the KTO algorithm, following the HALOs reference implementation. It supports two ways of handling positive and negative samples in a batch: balanced (the vanilla version) and unbalanced (the non-vanilla version). The vanilla version requires every batch to contain the same number of positive and negative samples; the non-vanilla version has no such requirement.
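The difference between the two versions is easier to see against a per-sample form of the loss. The following is a minimal sketch, assuming the objective as published in the KTO paper (loss = weight * (1 - sigmoid(beta * (log-ratio - KL))) for desirable samples, with the sign flipped for undesirable ones). The function names, the default `beta`, and the weight parameters are illustrative assumptions and are not the code in this branch.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def kto_loss(log_ratio, kl_estimate, desirable, beta=0.1,
             desirable_weight=1.0, undesirable_weight=1.0):
    """Per-sample KTO loss (illustrative sketch, not the PR's implementation).

    log_ratio:   log pi_policy(y|x) - log pi_ref(y|x) for this sample.
    kl_estimate: batch-level KL reference point, clamped to be non-negative.
    desirable:   True for a positive sample, False for a negative one.
    """
    kl = max(kl_estimate, 0.0)
    if desirable:
        return desirable_weight * (1.0 - sigmoid(beta * (log_ratio - kl)))
    return undesirable_weight * (1.0 - sigmoid(beta * (kl - log_ratio)))


def batch_kto_loss(samples, kl_estimate, **kwargs):
    """Mean loss over (log_ratio, desirable) pairs; no pairing is required."""
    losses = [kto_loss(r, kl_estimate, d, **kwargs) for r, d in samples]
    return sum(losses) / len(losses)


# An unbalanced batch: three positive samples and one negative sample.
unbalanced = [(0.8, True), (0.1, True), (-0.3, True), (0.5, False)]
print(batch_kto_loss(unbalanced, kl_estimate=0.05))
```

Because each sample contributes its own term, nothing in this form requires positives and negatives to appear in equal numbers. The two weights are where an unbalanced setup can compensate for a skewed ratio.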

A lightweight dataset mix was used to validate the algorithm by comparing DPO, vanilla KTO, non-vanilla KTO, and the baseline. The dataset and results are as follows:

  • dataset
--dataset Anthropic/hh-rlhf,tasksource/oasst1_pairwise_rlhf_reward,openai/webgpt_comparisons
--dataset_probs 0.72,0.14,0.14
  • performance (MTBench)

| model | Writing | Roleplay | Reasoning | Math | Coding | Extraction | STEM | Humanities | Average |
|---|---|---|---|---|---|---|---|---|---|
| baseline | 7.125 | 7.425 | 4.05 | 2.6 | 2.85 | 4.475 | 7.475 | 8.475 | 5.559 |
| DPO | 7.4 | 7.39 | 3.9 | 3.05 | 2.475 | 4.875 | 7.2 | 9.075 | 5.670 |
| KTO_with_vanilla_loss | 7.225 | 7.325 | 4.025 | 2.3 | 3.475 | 5.525 | 7.184 | 9.075 | 5.715 |
| KTO | 7.145 | 7.273 | 4.112 | 2.666 | 2.790 | 5.212 | 8.315 | 8.479 | 5.799 |

* baseline model is "OpenLLMAI/Llama-2-7b-sft-model-ocra-500k"
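One way to read the `--dataset_probs` values in the dataset bullet above is as per-example sampling weights over the three datasets. The sketch below assumes each training example is drawn independently from one source according to those weights. The way OpenRLHF actually blends datasets may differ, and `sample_sources` is a hypothetical helper used only for illustration.

```python
import random
from collections import Counter

# Dataset names and weights copied from the flags above.
DATASETS = [
    "Anthropic/hh-rlhf",
    "tasksource/oasst1_pairwise_rlhf_reward",
    "openai/webgpt_comparisons",
]
PROBS = [0.72, 0.14, 0.14]


def sample_sources(n, seed=0):
    """Pick the source dataset for each of n examples (hypothetical helper)."""
    rng = random.Random(seed)
    return rng.choices(DATASETS, weights=PROBS, k=n)


# Empirical share of each dataset in a simulated draw of 10,000 examples.
counts = Counter(sample_sources(10_000))
for name in DATASETS:
    print(f"{name}: {counts[name] / 10_000:.3f}")
```

Under this reading, roughly 72% of the training examples come from Anthropic/hh-rlhf and the remainder is split evenly between the other two sources.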

@hijkzzz
Collaborator

hijkzzz commented Jan 27, 2024

Thank you for your contribution; we will review it as soon as possible.

labels = np.array(self.dataset.labels)
unique_labels = np.unique(labels)
self.label_to_indices = {label: np.where(labels == label)[0] for label in unique_labels}
for label in self.label_to_indices:
Collaborator

It seems we should add a condition here:

if self.shuffle:
    xxxxx

Contributor Author

Good suggestion, I'll include it in the next PR submission.
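For reference, the review suggestion can be sketched as follows. This is a minimal standalone example, assuming the guarded step is the per-label index shuffle (the reviewed snippet is truncated before the loop body, so this is a guess). `balanced_batches` and its parameters are hypothetical names, not the sampler from this PR, and the sketch uses the standard library rather than NumPy so that it runs on its own.

```python
import random
from collections import defaultdict


def balanced_batches(labels, batch_size, shuffle=True, seed=0):
    """Yield index batches holding the same number of samples per label.

    Hypothetical helper for illustration only.
    """
    if not labels:
        return

    # Group sample indices by label, like label_to_indices in the snippet.
    label_to_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        label_to_indices[label].append(idx)

    n_labels = len(label_to_indices)
    if batch_size % n_labels != 0:
        raise ValueError("batch_size must be divisible by the number of labels")
    per_label = batch_size // n_labels

    rng = random.Random(seed)
    for label in label_to_indices:
        # The suggested condition: permute indices only when shuffling is on.
        if shuffle:
            rng.shuffle(label_to_indices[label])

    # Stop once the smallest group can no longer fill its share of a batch.
    n_batches = min(len(v) for v in label_to_indices.values()) // per_label
    for b in range(n_batches):
        batch = []
        for indices in label_to_indices.values():
            batch.extend(indices[b * per_label:(b + 1) * per_label])
        yield batch


labels = [1, 0, 1, 0, 1, 1]
print(list(balanced_batches(labels, batch_size=2, shuffle=False)))
# -> [[0, 1], [2, 3]]
```

With `shuffle=False` the batches follow dataset order, which keeps evaluation runs reproducible. With `shuffle=True` each batch still contains one index per label here, but the order within each label group is permuted.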

@hijkzzz hijkzzz mentioned this pull request Jan 27, 2024
@hijkzzz hijkzzz merged commit a581794 into OpenLLMAI:main Jan 27, 2024
1 check passed