Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SFT for D2L + Pre-Training (rename of the previous SFT) #102

Open
wants to merge 16 commits into
base: main
Choose a base branch
from

Conversation

llauraa23
Copy link
Collaborator

Implement SFT and use D2L as a demo case. Rename previous SFT to Pre-training and modify corresponding scripts/notebooks.

llauraa23 and others added 11 commits December 4, 2023 23:34
execute with "python -m example.rlhf.supervised_finetuning_d2l"
Temporarily use all entries in the dataset as training dataset
(i.e., no eval)
…ction, whether to disable evalution configurable
… Use trl DataCollatorForCompletionOnlyLM instead of customized one. Debug: cannot use ConstantLengthDataset or packing when using DataCollatorForCompletionOnly
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq: what is this .ipynb file for?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq: what is this .ipynb file for?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is used to generate synthetic immigration data by rephrasing.

@@ -5,6 +5,7 @@

from accelerate import Accelerator
from peft import LoraConfig, TaskType
# TODO: DH: num_train_epochs=20,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: what is this comment code for?

@@ -0,0 +1,40 @@
from typing import Any, Dict, List, Union
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq: do we still need this customized collator per our discussion,

seq_length=args.max_seq_length,
# chars_per_token=chars_per_token,
)
return {"train": train_dataset, "eval": eval_dataset}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: need a new line. Make sure you setup your linter properly as we discussed.

f" Answer: {example[self._rlhf_config.answer_title]}")
return text

def prepare_d2l_text(self, example):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: let's rename this method because it can be used for other things.

@CambioML
Copy link
Owner

Also, please add what you have tested for this PR.

from pykoi.chat.db.constants import (QA_CSV_HEADER_ANSWER, QA_CSV_HEADER_ID,
QA_CSV_HEADER_QUESTION,
QA_CSV_HEADER_VOTE_STATUS)
from pykoi.chat.db.qa_database import QuestionAnswerDatabase
from pykoi.rlhf.config import RLHFConfig
from pykoi.telemetry.events import SFTStartEvent, SFTStopEvent
from pykoi.telemetry.telemetry import Telemetry
from trl import DataCollatorForCompletionOnlyLM
# from pykoi.rlhf.customize_data_collator import DataCollatorForCompletionOnlyLM
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: let's remove non-used code.

# resize the token embeddings to include the added special tokens
self.model.resize_token_embeddings(len(self.tokenizer))
data_collator = None
if self._rlhf_config.data_collator == "DataCollatorForCompletionOnlyLM":
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should consider to set data_collator to DataCollatorForCompletionOnlyLM class instead of a string for SFT training argument.

Then, here you should check None. Also, it looks like a bug for me that if people use SFT without passing in data_collator. Therefore, you should set proper default value in the config.py

Copy link
Collaborator Author

@llauraa23 llauraa23 Feb 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that a class is better than a string in the argument file.
I believe if set to None, the default Datacollator will be used when "None" is passed to trl.SFTTrainer. Since default Datacollator also depends on other parameters such as "pack", setting it to None by default makes more sense than a fixed class.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants