Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/input_pipeline/hf_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def preprocessing_pipeline(
)
operations = []
if use_sft:
input_pipeline_utils.verify_chat_template_generation_prompt_logic(tokenizer)
operations.append(
input_pipeline_utils.SFTPromptMasking(
text_column_name=data_column_names[0],
Expand Down
50 changes: 50 additions & 0 deletions src/maxtext/input_pipeline/input_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,56 @@ def is_conversational(features, data_columns):
return False


def verify_chat_template_generation_prompt_logic(tokenizer_model):
  """Verifies the tokenizer's chat template for correct SFT loss masking.

  This function ensures that the tokens added by `add_generation_prompt=True`
  are identical to the tokens that begin an assistant's turn in a complete
  conversation, which is critical for masking prompt tokens during SFT loss
  calculation.

  Example of a mismatch:
    A `ValueError` is raised if the generation prompt and the actual
    assistant prefix do not match. For example:

    - `add_generation_prompt=True` on a user message produces a prompt ending in:
      `...<|im_start|>generation\n`
    - A full turn with an assistant message starts the reply with:
      `...<|im_start|>assistant\n...`

    This function would fail because the tokens for "generation" do not
    match the tokens for "assistant".

  Args:
    tokenizer_model: The Hugging Face tokenizer instance to verify.

  Raises:
    ValueError: If the generation prompt tokens cannot be isolated, or if the
      `add_generation_prompt` tokens do not exactly match the beginning of an
      assistant message in the template.
  """
  # Include a system message so templates that treat the system turn specially
  # (e.g. folding it into the first user turn) are also verified.
  dummy_msgs = [
      {"role": "system", "content": "Test system prompt"},
      {"role": "user", "content": "Test message"},
  ]
  prompt_wo_gen = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
  prompt_with_gen = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=True, tokenize=True)
  # The with-generation encoding must be a pure extension of the without-generation
  # encoding; otherwise slicing below would extract garbage instead of the
  # generation prompt tokens.
  if prompt_with_gen[: len(prompt_wo_gen)] != prompt_wo_gen:
    raise ValueError("Unable to extract generation prompt tokens.")
  # Extract the tokenized generation prompt (the expected assistant prefix)
  assistant_prefix_tokens = prompt_with_gen[len(prompt_wo_gen) :]
  full_turn_tokens = tokenizer_model.apply_chat_template(
      dummy_msgs + [{"role": "assistant", "content": "Dummy response"}], add_generation_prompt=False, tokenize=True
  )
  # Extract the actual tokens that appear right after the user message in the full turn
  actual_prefix_in_full_turn = full_turn_tokens[len(prompt_wo_gen) : len(prompt_wo_gen) + len(assistant_prefix_tokens)]

  if actual_prefix_in_full_turn != assistant_prefix_tokens:
    expected_str = tokenizer_model.decode(assistant_prefix_tokens)
    actual_str = tokenizer_model.decode(actual_prefix_in_full_turn)
    raise ValueError(
        "Chat template generation prompt mismatch!\n"
        f"Expected assistant prefix tokens: {assistant_prefix_tokens} ('{expected_str}')\n"
        f"Actual prefix tokens found: {actual_prefix_in_full_turn} ('{actual_str}')\n"
        "This means the tokenizer's chat template will break the sft masking logic."
    )


def _get_completion_in_chat_template(tokenizer_model, round_msgs):
"""
Calculates the completion part of a conversation turn when formatted with a chat template.
Expand Down
50 changes: 49 additions & 1 deletion tests/unit/sft_data_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,19 @@
import pytest
import numpy as np
import jax
import re
from jax.sharding import Mesh
from jax.experimental import mesh_utils
from datasets import Dataset
import transformers
from parameterized import parameterized_class

from unittest.mock import patch
from maxtext.configs import pyconfig
from maxtext.utils.globals import MAXTEXT_PKG_DIR, MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT
from maxtext.input_pipeline import hf_data_processing
from maxtext.input_pipeline import input_pipeline_interface
from maxtext.input_pipeline.hf_data_processing import _get_pad_id
from maxtext.input_pipeline.input_pipeline_utils import verify_chat_template_generation_prompt_logic

PROMPT_DATA = [
[
Expand Down Expand Up @@ -480,5 +482,51 @@ def test_system_message_not_at_beginning(self):
self.get_data_iterator(dataset, ["messages"])


@pytest.mark.external_training
class SFTChatTemplateLogicTest(unittest.TestCase):
  """Exercises verify_chat_template_generation_prompt_logic against real tokenizers."""

  LLAMA_TOKENIZER_PATH = os.path.join(MAXTEXT_ASSETS_ROOT, "llama2-chat-tokenizer")

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    # Download the llama2 chat tokenizer once per test class if not already cached.
    if not os.path.exists(cls.LLAMA_TOKENIZER_PATH):
      exit_code = subprocess.call(
          [
              "gsutil",
              "cp",
              "-r",
              "gs://maxtext-dataset/hf/llama2-chat-tokenizer",
              os.path.join(MAXTEXT_ASSETS_ROOT, ""),
          ]
      )
      if exit_code != 0:
        raise ValueError("Failed to download llama tokenizer")

  def setUp(self):
    super().setUp()
    self.qwen3_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
    self.llama2_tokenizer = transformers.AutoTokenizer.from_pretrained(self.LLAMA_TOKENIZER_PATH)

  def test_tokenizer_w_generation_prompt(self):
    """A template with an explicit add_generation_prompt block must verify cleanly."""
    verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)

  def test_tokenizer_wo_generation_prompt(self):
    # Renamed from test_tokenizer_wo_generation_promt (typo flagged in review).
    """A template without a dedicated generation-prompt block must also verify cleanly."""
    verify_chat_template_generation_prompt_logic(self.llama2_tokenizer)

  def test_failure_path_with_modified_template(self):
    """Verifies the function correctly raises a ValueError on a bad template."""
    # Replace the role within the existing add_generation_prompt block with a
    # deliberately faulty one.
    fault_chat_template = re.sub(
        r"(\{%-?\s*if add_generation_prompt\s*%\}.*?<\|im_start\|>)assistant(.*?\{%-?\s*endif\s*%\})",
        r"\1wrong_role\2",
        self.qwen3_tokenizer.chat_template,
        flags=re.DOTALL,
    )
    with patch.object(self.qwen3_tokenizer, "chat_template", fault_chat_template):
      # Verify that our function catches the mismatch and raises the expected error.
      with self.assertRaisesRegex(ValueError, "Chat template generation prompt mismatch!"):
        verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)


# Allow running this test module directly (outside of pytest).
if __name__ == "__main__":
  unittest.main()
Loading