support sft training on d2l #100

Closed
wants to merge 3 commits into from

Changes from all commits
4,698 changes: 4,698 additions & 0 deletions example/autorate/auto-rater.ipynb
Owner

could you please describe why this is called auto-rater.ipynb? It looks like this file is used for generating a QA dataset.

Collaborator Author

This script for auto evaluation is incomplete, and it still contains redundant code from the earlier QA generation work. Since the latest version of pykoi/uniflow already has an auto-rater, shall we remove my related commits?

Large diffs are not rendered by default.

Binary file added example/autorate/data/Chapter 5 Rome.docx
Binary file not shown.
497 changes: 497 additions & 0 deletions example/autorate/data/rome.txt

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions example/rlhf/supervised_finetuning_d2l.py
Owner

let's rename this file to supervised_finetuning_demo_d2l.py to indicate this is a demo file.

Collaborator Author

okay

@@ -0,0 +1,45 @@
"""Demo for the supervised fine tuning.

python -m example.rlhf.supervised_finetuning_demo
Owner

nit: it should be python -m example.rlhf.supervised_finetuning_demo_d2l.

Collaborator Author

corrected

"""

from pykoi.chat import QuestionAnswerDatabase
from pykoi.chat.db.constants import (QA_CSV_HEADER_ANSWER, QA_CSV_HEADER_ID,
                                     QA_CSV_HEADER_QUESTION,
                                     QA_CSV_HEADER_VOTE_STATUS)
from pykoi.rlhf import RLHFConfig, SupervisedFinetuning

# get data from local database
qa_database = QuestionAnswerDatabase()
my_data_pd = qa_database.retrieve_all_question_answers_as_pandas()
my_data_pd = my_data_pd[
    [
        QA_CSV_HEADER_ID,
        QA_CSV_HEADER_QUESTION,
        QA_CSV_HEADER_ANSWER,
        QA_CSV_HEADER_VOTE_STATUS,
    ]
]

# analyze the data
print(my_data_pd)
print("My local database has {} samples in total".format(my_data_pd.shape[0]))

# run supervised finetuning
from peft import LoraConfig
config = RLHFConfig(base_model_path="mistralai/Mistral-7B-Instruct-v0.1",
dataset_type="local_csv", dataset_name="data/chapter22_trnvalfromseed_data_processed.csv",
Owner

qq: I do not see this file in the data folder. Is it missing?

Collaborator Author

I am not sure if I should add these data files.

    train_test_split_ratio=0.1,
    max_seq_length=896,
    per_device_eval_batch_size=1,
    lora_config_rl=LoraConfig(
        r=512,
        lora_alpha=1024,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # optionally also: "gate_proj", "up_proj", "down_proj", "lm_head"
Owner

qq: what is the target_modules parameter used for?

Collaborator Author

It specifies which modules (e.g., which components of the attention layers) are updated during PEFT training.
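For illustration, a minimal sketch of how target_modules is used with peft (not from this PR; the model name and hyperparameters are placeholders):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Hypothetical example: LoRA adapters are attached only to the listed
# attention projection modules; every other weight stays frozen.
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()  # only the LoRA weights are trainable
```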

bias="none",
task_type="CAUSAL_LM"
),
)
rlhf_step1_sft = SupervisedFinetuning(config)
rlhf_step1_sft.train_and_save("./models/rlhf_step1_sft")
1 change: 1 addition & 0 deletions pykoi/rlhf/config.py
@@ -5,6 +5,7 @@

from accelerate import Accelerator
from peft import LoraConfig, TaskType
# TODO: DH: num_train_epochs=20,


@dataclass
35 changes: 35 additions & 0 deletions pykoi/rlhf/customize_data_collator.py
@@ -0,0 +1,35 @@
from typing import Any, Dict, List, Tuple, Union

import numpy as np
from transformers import DataCollatorForLanguageModeling


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
Owner

qq: in this example, https://huggingface.co/docs/trl/sft_trainer#advanced-usage, it directly imports `from trl import SFTTrainer, DataCollatorForCompletionOnlyLM`.

It looks like DataCollatorForCompletionOnlyLM does what we want, namely masking out the question from the training objective, and trl already has an implementation. I am curious why we wrote our own version of DataCollatorForCompletionOnlyLM here.
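For reference, the pattern from the linked trl docs looks roughly like this (a sketch; exact signatures vary by trl version, and trl's collator requires packing to be disabled):

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# Toy dataset; loss is computed only on tokens after the response template.
dataset = Dataset.from_dict({"text": ["### Question: What is 2+2?\n### Answer: 4"]})
collator = DataCollatorForCompletionOnlyLM("### Answer:", tokenizer=tokenizer)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    data_collator=collator,
)
trainer.train()
```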

Collaborator Author

It is Rachel and Yunfan's customized implementation.

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)

        # The prompt ends with the response key plus a newline. We encode this and then try to find it in the
        # sequence of tokens. This should just be a single token.
        RESPONSE_KEY = "### Response:"
        RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
        response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)

        labels = batch["labels"].clone()

        for i in range(len(examples)):
            response_token_ids_start_idx = None
            for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
                response_token_ids_start_idx = idx
                break

            if response_token_ids_start_idx is None:
                raise RuntimeError(
                    f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}'
                )

            response_token_ids_end_idx = response_token_ids_start_idx + 1

            # Make the PyTorch loss function ignore all tokens up through the end of the response key
            labels[i, :response_token_ids_end_idx] = -100

        batch["labels"] = labels

        return batch
Owner

nit: add a newline at the end of the file. If you have not done so, please set up your dev environment following https://www.notion.so/goldpiggy/Python-Linter-and-formatter-Setup-30fb3b81f0904af889832e4c697c5ec9?pvs=4

Collaborator Author

Thanks! I resolved it and ran pylint on other files as well.

87 changes: 83 additions & 4 deletions pykoi/rlhf/supervised_finetuning.py
@@ -20,7 +20,7 @@
from pykoi.rlhf.config import RLHFConfig
from pykoi.telemetry.events import SFTStartEvent, SFTStopEvent
from pykoi.telemetry.telemetry import Telemetry

from pykoi.rlhf.customize_data_collator import DataCollatorForCompletionOnlyLM

class SupervisedFinetuning:
"""
@@ -48,6 +48,13 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> None:
        self._telemetry = Telemetry(enable_telemetry)
        self._rlhf_config = rlhf_config
        self.tokenizer = AutoTokenizer.from_pretrained(rlhf_config.base_model_path)
        # dh: add special tokens to the tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token
        END_KEY = "### End"
        INSTRUCTION_KEY = "### Instruction:"
        RESPONSE_KEY = "### Response:"
        RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
        self.tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})
        self.num_proc = (
            self._rlhf_config.num_workers if not self._rlhf_config.streaming else None
        )
@@ -83,13 +90,23 @@ def __init__(self, rlhf_config: RLHFConfig, enable_telemetry: bool = True) -> None:
            load_in_8bit=self._rlhf_config.load_in_8bit,
            device_map=self._rlhf_config.device_map,
        )
        # resize the token embeddings to include the added special tokens
        self.model.resize_token_embeddings(len(self.tokenizer))

        # dh: try the customized data collator that only predicts the answer part
        data_collator = DataCollatorForCompletionOnlyLM(
Owner

qq: shall we make this configurable to avoid breaking the old way of running the code?

Collaborator Author

done

            tokenizer=self.tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
        )

        self.trainer = SFTTrainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.dataset["train"],
            eval_dataset=self.dataset["eval"],
            peft_config=self._rlhf_config.lora_config_rl,
            peft_config=self._rlhf_config.lora_config_rl,  # TODO: DH: LoraConfig may be ignored if using from_pretrained
            packing=True,
            data_collator=data_collator,
Owner

qq: could you please explain in the PR description why we added this data_collator when we did not need it before?

Owner

Got it: this is for training the instruction-following objective by masking out the query, rather than the causal language model objective of predicting every next token.
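Concretely, the effect on the labels is something like this (made-up token IDs for illustration):

```python
# Hypothetical: suppose token 500 is the first token of "### Response:\n".
input_ids = [15, 27, 33, 42, 500, 88, 91, 99]           # query ... response
labels    = [-100, -100, -100, -100, -100, 88, 91, 99]  # loss only on the response
```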

Collaborator Author

okay

            dataset_text_field="text",
        )

    def train(self):
@@ -103,6 +120,8 @@ def load_lora(
        base_model_path: Optional[str] = None,
        lora_model_path: Optional[str] = None,
    ):
        # import pdb; pdb.set_trace()
        # dh: not used
        if base_model_path is None:
            base_model_path = self._rlhf_config.base_model_path

@@ -163,6 +182,65 @@ def prepare_sample_text(self, example):
            f" Answer: {example[self._rlhf_config.answer_title]}"
        )
        return text


    def prepare_d2l_text(self, example):
        """Prepare the text from a sample of the d2l dataset."""
        INTRO_BLURB = (
            "Below is an instruction that describes a task. Write a response that appropriately completes the request."
        )
Comment on lines +189 to +191
Owner

qq: is INTRO_BLURB needed for SFT?

Unless users always include this as part of their system prompt, they may forget it at inference time, which might hurt performance.

Collaborator Author

This is Yunfei's prompt and I kept it in order to reproduce his result in pykoi. I agree with the case you mentioned.
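For context, a no-input sample rendered through this template (with hypothetical field values) looks like:

```
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
What year was Julius Caesar born?
### Response:
100 BC.
### End
```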

        INSTRUCTION_KEY = "### Instruction:"
        INPUT_KEY = "Input:"
        RESPONSE_KEY = "### Response:"
        END_KEY = "### End"
        RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
        DEFAULT_SEED = 42

        # This is a training prompt that does not contain an input string. The instruction by itself has enough
        # information to respond. For example, the instruction might ask for the year a historic figure was born.
        PROMPT_NO_INPUT_FORMAT = """{intro}
{instruction_key}
{instruction}
{response_key}
{response}
{end_key}""".format(
            intro=INTRO_BLURB,
            instruction_key=INSTRUCTION_KEY,
            instruction="{instruction}",
            response_key=RESPONSE_KEY,
            response="{response}",
            end_key=END_KEY,
        )

        # This is a training prompt that contains an input string that serves as context for the instruction. For
        # example, the input might be a passage from Wikipedia and the instruction is to extract some information
        # from it.
        PROMPT_WITH_INPUT_FORMAT = """{intro}
{instruction_key}
{instruction}
{input_key}
{input}
{response_key}
{response}
{end_key}""".format(
            intro=INTRO_BLURB,
            instruction_key=INSTRUCTION_KEY,
            instruction="{instruction}",
            input_key=INPUT_KEY,
            input="{input}",
            response_key=RESPONSE_KEY,
            response="{response}",
            end_key=END_KEY,
        )

        context = example.get("context")
        if context:
            text = PROMPT_WITH_INPUT_FORMAT.format(instruction=example["instruction"], response=example["response"], input=context)
        else:
            text = PROMPT_NO_INPUT_FORMAT.format(instruction=example["instruction"], response=example["response"])

        return text

    def create_datasets(self, tokenizer, args):
        if args.dataset_type == "local_db":
@@ -181,6 +259,7 @@ def create_datasets(self, tokenizer, args):
        elif args.dataset_type == "local_csv":
            dataset = load_dataset("csv", data_files=args.dataset_name)
            dataset = dataset[args.split]  # Convert DatasetDict to Dataset
            dataset2 = load_dataset("csv", data_files=args.dataset_name, split="train[:10%]")
        elif args.dataset_type == "huggingface":
            dataset = load_dataset(
                args.dataset_name,
@@ -208,15 +287,15 @@
        train_dataset = ConstantLengthDataset(
Owner

One caveat here is that ConstantLengthDataset always packs your dataset into seq_length-sized chunks, breaking one coherent c + q + a sample into multiple data points whenever len(c + q + a) > seq_length.

Meanwhile, the DataCollatorForCompletionOnlyLM implementation for SFTTrainer masks out the query and trains on the response only.

However, I am a bit confused: with this packing, your dataset is not prepared to train on the response only (query masked out), but still follows the causal language model objective for next-token prediction.
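A tiny self-contained illustration of the packing behavior (made-up token IDs):

```python
# Illustrative only: constant-length packing concatenates the formatted samples
# and slices the stream into fixed-size chunks, so one c + q + a sample can be
# split across chunk boundaries and response-only masking no longer lines up.
seq_length = 8
sample_token_streams = [[1, 2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12]]
stream = [tok for sample in sample_token_streams for tok in sample]
chunks = [stream[i : i + seq_length] for i in range(0, len(stream), seq_length)]
print(chunks)  # [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12]] -- sample 2 is split
```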

Collaborator Author

Does the collator mask out the query?

            tokenizer,
            dataset["train"],
            formatting_func=self.prepare_sample_text,
            formatting_func=self.prepare_d2l_text,
Owner

qq: same as my comments above, we should make this configurable to maintain the old functionality.

Collaborator Author

done

            infinite=True,
            seq_length=args.max_seq_length,
            # chars_per_token=chars_per_token,
        )
        eval_dataset = ConstantLengthDataset(
            tokenizer,
            dataset["test"],
            formatting_func=self.prepare_sample_text,
            formatting_func=self.prepare_d2l_text,
            infinite=False,
            seq_length=args.max_seq_length,
            # chars_per_token=chars_per_token,