-
Notifications
You must be signed in to change notification settings - Fork 483
Add chat template check for sft #3350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -172,6 +172,56 @@ def is_conversational(features, data_columns): | |
| return False | ||
|
|
||
|
|
||
def verify_chat_template_generation_prompt_logic(tokenizer_model):
  """Verifies the tokenizer's chat template for correct SFT loss masking.

  This function ensures that the tokens added by `add_generation_prompt=True`
  are identical to the tokens that begin an assistant's turn in a complete
  conversation, which is critical for masking prompt tokens during SFT loss
  calculation.

  Example of a mismatch:
    A `ValueError` is raised if the generation prompt and the actual
    assistant prefix do not match. For example:

    - `add_generation_prompt=True` on a user message produces a prompt ending in:
      `...<|im_start|>generation\n`
    - A full turn with an assistant message starts the reply with:
      `...<|im_start|>assistant\n...`

    This function would fail because the tokens for "generation" do not
    match the tokens for "assistant".

  Args:
    tokenizer_model: The Hugging Face tokenizer instance to verify.

  Raises:
    ValueError: If the `add_generation_prompt=True` output does not extend the
      `add_generation_prompt=False` output, or if the `add_generation_prompt`
      tokens do not exactly match the beginning of an assistant message in
      the template.
  """
  dummy_msgs = [{"role": "user", "content": "Test message"}]

  prompt_wo_gen = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
  prompt_with_gen = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=True, tokenize=True)

  # Sanity check before slicing: the generation-prompt variant must be a pure
  # extension of the plain prompt. If the template restructures the prompt
  # instead of appending to it, the slice below would compare unrelated tokens
  # and produce a misleading error.
  if prompt_with_gen[: len(prompt_wo_gen)] != prompt_wo_gen:
    raise ValueError(
        "Chat template generation prompt mismatch!\n"
        "Applying the chat template with add_generation_prompt=True did not produce a token "
        "sequence that starts with the add_generation_prompt=False token sequence.\n"
        "This means the tokenizer's chat template will break the sft masking logic."
    )

  # Extract the tokenized generation prompt (the expected assistant prefix)
  assistant_prefix_tokens = prompt_with_gen[len(prompt_wo_gen) :]

  full_turn_tokens = tokenizer_model.apply_chat_template(
      dummy_msgs + [{"role": "assistant", "content": "Dummy response"}], add_generation_prompt=False, tokenize=True
  )
  # Extract the actual tokens that appear right after the user message in the full turn
  actual_prefix_in_full_turn = full_turn_tokens[len(prompt_wo_gen) : len(prompt_wo_gen) + len(assistant_prefix_tokens)]

  if actual_prefix_in_full_turn != assistant_prefix_tokens:
    expected_str = tokenizer_model.decode(assistant_prefix_tokens)
    actual_str = tokenizer_model.decode(actual_prefix_in_full_turn)
    raise ValueError(
        "Chat template generation prompt mismatch!\n"
        f"Expected assistant prefix tokens: {assistant_prefix_tokens} ('{expected_str}')\n"
        f"Actual prefix tokens found: {actual_prefix_in_full_turn} ('{actual_str}')\n"
        "This means the tokenizer's chat template will break the sft masking logic."
    )
|
|
||
|
|
||
| def _get_completion_in_chat_template(tokenizer_model, round_msgs): | ||
| """ | ||
| Calculates the completion part of a conversation turn when formatted with a chat template. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,17 +19,19 @@ | |
| import pytest | ||
| import numpy as np | ||
| import jax | ||
| import re | ||
| from jax.sharding import Mesh | ||
| from jax.experimental import mesh_utils | ||
| from datasets import Dataset | ||
| import transformers | ||
| from parameterized import parameterized_class | ||
|
|
||
| from unittest.mock import patch | ||
| from maxtext.configs import pyconfig | ||
| from maxtext.utils.globals import MAXTEXT_PKG_DIR, MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT | ||
| from maxtext.input_pipeline import hf_data_processing | ||
| from maxtext.input_pipeline import input_pipeline_interface | ||
| from maxtext.input_pipeline.hf_data_processing import _get_pad_id | ||
| from maxtext.input_pipeline.input_pipeline_utils import verify_chat_template_generation_prompt_logic | ||
|
|
||
| PROMPT_DATA = [ | ||
| [ | ||
|
|
@@ -480,5 +482,51 @@ def test_system_message_not_at_beginning(self): | |
| self.get_data_iterator(dataset, ["messages"]) | ||
|
|
||
|
|
||
@pytest.mark.external_training
class SFTChatTemplateLogicTest(unittest.TestCase):
  """Tests for `verify_chat_template_generation_prompt_logic`.

  Uses Qwen3 (a template where the generation prompt matches the assistant
  prefix) as the passing case, Llama2-chat (no explicit generation-prompt
  block) as a second tokenizer, and a deliberately corrupted Qwen3 template
  to exercise the failure path.
  """

  LLAMA_TOKENIZER_PATH = os.path.join(MAXTEXT_ASSETS_ROOT, "llama2-chat-tokenizer")

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    if not os.path.exists(cls.LLAMA_TOKENIZER_PATH):
      # TODO(review): is it okay to use a private gs:// location? Consider a
      # public fixture or a HF hub download instead.
      exit_code = subprocess.call(
          [
              "gsutil",
              "cp",
              "-r",
              "gs://maxtext-dataset/hf/llama2-chat-tokenizer",
              os.path.join(MAXTEXT_ASSETS_ROOT, ""),
          ]
      )
      if exit_code != 0:
        raise ValueError("Failed to download llama tokenizer")

  def setUp(self):
    super().setUp()
    self.qwen3_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
    self.llama2_tokenizer = transformers.AutoTokenizer.from_pretrained(self.LLAMA_TOKENIZER_PATH)

  def test_tokenizer_w_generation_prompt(self):
    verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)

  def test_tokenizer_wo_generation_prompt(self):
    # Fixed reviewer-flagged typo in method name ("promt" -> "prompt").
    verify_chat_template_generation_prompt_logic(self.llama2_tokenizer)

  def test_failure_path_with_modified_template(self):
    """Verifies the function correctly raises a ValueError on a bad template."""
    # Replace the role within the existing add_generation_prompt block with a deliberately faulty one.
    fault_chat_template = re.sub(
        r"(\{%-?\s*if add_generation_prompt\s*%\}.*?<\|im_start\|>)assistant(.*?\{%-?\s*endif\s*%\})",
        r"\1wrong_role\2",
        self.qwen3_tokenizer.chat_template,
        flags=re.DOTALL,
    )
    with patch.object(self.qwen3_tokenizer, "chat_template", fault_chat_template):
      # Verify that our function catches the mismatch and raises the expected error
      with self.assertRaisesRegex(ValueError, "Chat template generation prompt mismatch!"):
        verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)


if __name__ == "__main__":
  unittest.main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also include a system prompt (a leading "system"-role message) in the verification?