From 01372bcc156aab59d466b7ef3caf3ba9998e0dcf Mon Sep 17 00:00:00 2001 From: WintersMontagne10335 Date: Mon, 17 Nov 2025 17:05:38 +0000 Subject: [PATCH 1/3] Add unit tests for ernie4_5_processor --- tests/input/test_ernie4_5_processor.py | 485 +++++++++++++++++++++++++ 1 file changed, 485 insertions(+) create mode 100644 tests/input/test_ernie4_5_processor.py diff --git a/tests/input/test_ernie4_5_processor.py b/tests/input/test_ernie4_5_processor.py new file mode 100644 index 00000000000..01c4f53c203 --- /dev/null +++ b/tests/input/test_ernie4_5_processor.py @@ -0,0 +1,485 @@ +import sys +import types +import unittest +from unittest.mock import patch + +import numpy as np + +# Fake opentelemetry to avoid import errors when importing fastdeploy +opentelemetry_mod = types.ModuleType("opentelemetry") +instrumentation_mod = types.ModuleType("opentelemetry.instrumentation") +logging_mod = types.ModuleType("opentelemetry.instrumentation.logging") + + +class DummyLoggingInstrumentor: + @staticmethod + def instrument(*args, **kwargs): + pass + + +logging_mod.LoggingInstrumentor = DummyLoggingInstrumentor +sys.modules["opentelemetry"] = opentelemetry_mod +sys.modules["opentelemetry.instrumentation"] = instrumentation_mod +sys.modules["opentelemetry.instrumentation.logging"] = logging_mod + +# Now import the module under test +from fastdeploy.input.ernie4_5_processor import ( # noqa: E402 + _SAMPLING_EPS, + Ernie4_5Processor, +) + +MODULE_PATH = "fastdeploy.input.ernie4_5_processor" + + +class DummyTokenizer: + """Simple fake tokenizer used for unit tests.""" + + def __init__(self): + self.bos_token = "" + self.bos_token_id = 1 + self.eos_token = "" + self.eos_token_id = 2 + self.pad_token_id = 0 + self.vocab_size = 1000 + self.chat_template = "dummy_template" + + def tokenize(self, text): + # Treat the whole string as a single token + if text is None: + return [] + return [text] + + def convert_tokens_to_ids(self, tokens): + # Map token -> length-based id; special BIG token to be > vocab_size + ids = [] + for t in tokens: + if t == "BIG": + ids.append(self.vocab_size + 1) + else: + ids.append(len(t) % self.vocab_size) + return ids + + def decode(self, token_ids, **kwargs): + # Join ids into a string so we can assert easily + return "|".join(str(i) for i in token_ids) + + def decode_token(self, token_ids, prefix_offset, read_offset): + # Streaming decode: return new part from read_offset onward + new_text = "|".join(str(i) for i in token_ids[read_offset:]) + new_read = len(token_ids) + return new_text, prefix_offset, new_read + + def apply_chat_template(self, request_or_messages, **kwargs): + # Join all message contents into a single text + if isinstance(request_or_messages, dict): + messages = request_or_messages.get("messages", []) + else: + messages = request_or_messages + contents = [] + for m in messages: + if isinstance(m, dict): + contents.append(m.get("content", "")) + else: + contents.append(str(m)) + return "\n".join(contents) + + +class DummyRequest: + """Simple request-like object with get/set methods.""" + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def get(self, key, default=None): + return getattr(self, key, default) + + def set(self, key, value): + setattr(self, key, value) + + def to_dict(self): + return self.__dict__ + + +class DummyOutputs: + """Simple outputs-like container for process_response.""" + + def __init__(self, token_ids, index=0): + self.token_ids = token_ids + self.index = index + self.text = "" + self.reasoning_content = 
"" + + +class DummyReasoningParser: + """Fake reasoning parser.""" + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def extract_reasoning_content(self, full_text, response_dict): + # Return dummy reasoning + plain text + return "reasoning_part", "final_text" + + def extract_reasoning_content_streaming( + self, + previous_texts, + full_text, + delta_text, + previous_token_ids, + full_token_ids, + delta_token_ids, + ): + # Return an object with reasoning_content attribute + msg = types.SimpleNamespace() + msg.reasoning_content = "stream_reasoning" + return msg + + +class DummyToolParser: + """Fake tool parser.""" + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + self.called_times = 0 + + def extract_tool_calls(self, full_text, response_dict): + # When called first time, pretend we found a tool call + self.called_times += 1 + if self.called_times == 1: + msg = types.SimpleNamespace() + msg.tools_called = True + msg.tool_calls = [{"name": "tool1"}] + msg.content = "tool_content" + return msg + msg = types.SimpleNamespace() + msg.tools_called = False + msg.tool_calls = [] + msg.content = full_text + return msg + + def extract_tool_calls_streaming( + self, + previous_texts, + full_text, + delta_text, + previous_token_ids, + full_token_ids, + delta_token_ids, + response_dict, + ): + msg = types.SimpleNamespace() + msg.tools_called = True + msg.tool_calls = [{"name": "stream_tool"}] + msg.content = full_text + return msg + + +class TestErnie45ProcessorInit(unittest.TestCase): + """Tests for __init__ and tokenizer loading.""" + + @patch(f"{MODULE_PATH}.GenerationConfig.from_pretrained") + @patch(f"{MODULE_PATH}.Ernie4_5Tokenizer.from_pretrained") + @patch(f"{MODULE_PATH}.get_eos_token_id") + def test_init_with_generation_config( + self, + mock_get_eos_token_id, + mock_from_pretrained_tokenizer, + mock_from_pretrained_config, + ): + dummy_tokenizer = DummyTokenizer() + mock_from_pretrained_tokenizer.return_value = dummy_tokenizer + mock_get_eos_token_id.return_value = [2, 3] + mock_from_pretrained_config.return_value = object() + + processor = Ernie4_5Processor("dummy_model_path") + + self.assertIsNotNone(processor.generation_config) + self.assertEqual(processor.tokenizer, dummy_tokenizer) + self.assertEqual(processor.eos_token_ids, [2, 3]) + self.assertEqual(processor.pad_token_id, dummy_tokenizer.pad_token_id) + + @patch(f"{MODULE_PATH}.GenerationConfig.from_pretrained") + @patch(f"{MODULE_PATH}.Ernie4_5Tokenizer.from_pretrained") + @patch(f"{MODULE_PATH}.get_eos_token_id") + def test_init_without_generation_config( + self, + mock_get_eos_token_id, + mock_from_pretrained_tokenizer, + mock_from_pretrained_config, + ): + dummy_tokenizer = DummyTokenizer() + mock_from_pretrained_tokenizer.return_value = dummy_tokenizer + mock_get_eos_token_id.return_value = [2] + mock_from_pretrained_config.side_effect = Exception("no config") + + processor = Ernie4_5Processor("dummy_model_path") + + self.assertIsNone(processor.generation_config) + self.assertEqual(processor.eos_token_ids, [2]) + self.assertEqual(processor.pad_token_id, dummy_tokenizer.pad_token_id) + + +class TestErnie45ProcessorRequest(unittest.TestCase): + """Tests for process_request and process_request_dict.""" + + def setUp(self): + patcher_tok = patch(f"{MODULE_PATH}.Ernie4_5Tokenizer.from_pretrained", return_value=DummyTokenizer()) + patcher_cfg = patch(f"{MODULE_PATH}.GenerationConfig.from_pretrained", return_value=object()) + patcher_eos = patch(f"{MODULE_PATH}.get_eos_token_id", return_value=[2]) + + 
self.addCleanup(patcher_tok.stop) + self.addCleanup(patcher_cfg.stop) + self.addCleanup(patcher_eos.stop) + + self.mock_tok = patcher_tok.start() + self.mock_cfg = patcher_cfg.start() + self.mock_eos = patcher_eos.start() + + self.processor = Ernie4_5Processor("dummy_model_path") + + @patch.object(Ernie4_5Processor, "_apply_default_parameters", side_effect=lambda self, r: r) + def test_process_request_with_prompt_string(self, _mock_apply_default): + req = DummyRequest( + request_id="req1", + prompt="hello", + prompt_token_ids=[], + messages=None, + eos_token_ids=None, + stop=[], + bad_words=None, + bad_words_token_ids=None, + max_tokens=None, + temperature=0.0, + top_p=0.0, + ) + + processed = self.processor.process_request(req, max_model_len=10) + + self.assertEqual(processed.eos_token_ids, [2]) + self.assertTrue(len(processed.prompt_token_ids) > 0) + self.assertEqual(processed.temperature, 1) + self.assertEqual(processed.top_p, _SAMPLING_EPS) + self.assertEqual(processed.max_tokens, 10 - len(processed.prompt_token_ids)) + + @patch.object(Ernie4_5Processor, "_apply_default_parameters", side_effect=lambda self, r: r) + def test_process_request_with_messages(self, _mock_apply_default): + messages = [{"role": "user", "content": "hi"}] + req = DummyRequest( + request_id="req2", + prompt=None, + prompt_token_ids=[], + messages=messages, + eos_token_ids=None, + stop=[], + bad_words=None, + bad_words_token_ids=None, + max_tokens=None, + temperature=1.0, + top_p=1.0, + ) + + processed = self.processor.process_request( + req, + max_model_len=50, + chat_template_kwargs={"extra": "value"}, + ) + + self.assertTrue(len(processed.prompt_token_ids) > 0) + self.assertEqual(processed.max_tokens, 50 - len(processed.prompt_token_ids)) + + @patch.object(Ernie4_5Processor, "_apply_default_parameters", side_effect=lambda self, r: r) + def test_process_request_missing_all_prompt_sources_raises(self, _mock_apply_default): + req = DummyRequest( + request_id="req3", + prompt=None, + prompt_token_ids=[], + messages=None, + eos_token_ids=None, + stop=[], + bad_words=None, + bad_words_token_ids=None, + max_tokens=None, + temperature=1.0, + top_p=1.0, + ) + + with self.assertRaises(ValueError): + self.processor.process_request(req, max_model_len=10) + + @patch.object(Ernie4_5Processor, "_apply_default_parameters", side_effect=lambda self, r: r) + def test_process_request_empty_prompt_token_ids_raises(self, _mock_apply_default): + req = DummyRequest( + request_id="req4", + prompt="", + prompt_token_ids=[], + messages=None, + eos_token_ids=None, + stop=[], + bad_words=None, + bad_words_token_ids=None, + max_tokens=None, + temperature=1.0, + top_p=1.0, + ) + + with self.assertRaises(ValueError): + self.processor.process_request(req, max_model_len=10) + + @patch.object(Ernie4_5Processor, "_apply_default_parameters", side_effect=lambda self, r: r) + def test_process_request_dict_with_prompt_list(self, _mock_apply_default): + req = { + "request_id": "req5", + "prompt_token_ids": [], + "prompt": [1, 2, 3], + "messages": None, + "eos_token_ids": None, + "stop": [], + "bad_words": None, + "bad_words_token_ids": None, + "max_tokens": None, + "temperature": 0.0, + "top_p": 0.0, + } + + processed = self.processor.process_request_dict(req, max_model_len=20) + + self.assertEqual(processed["prompt_token_ids"], [1, 2, 3]) + self.assertEqual(processed["temperature"], 1) + self.assertEqual(processed["top_p"], _SAMPLING_EPS) + self.assertEqual(processed["max_tokens"], 20 - len(processed["prompt_token_ids"])) + + +class 
TestErnie45ProcessorResponse(unittest.TestCase): + """Tests for process_response and response_dict helpers.""" + + def setUp(self): + patcher_tok = patch(f"{MODULE_PATH}.Ernie4_5Tokenizer.from_pretrained", return_value=DummyTokenizer()) + patcher_cfg = patch(f"{MODULE_PATH}.GenerationConfig.from_pretrained", return_value=object()) + patcher_eos = patch(f"{MODULE_PATH}.get_eos_token_id", return_value=[2]) + + self.addCleanup(patcher_tok.stop) + self.addCleanup(patcher_cfg.stop) + self.addCleanup(patcher_eos.stop) + + self.mock_tok = patcher_tok.start() + self.mock_cfg = patcher_cfg.start() + self.mock_eos = patcher_eos.start() + + def test_process_response_basic(self): + processor = Ernie4_5Processor("dummy_model_path") + outputs = DummyOutputs(token_ids=[10, 11, 2], index=2) + response = types.SimpleNamespace(request_id="req1", outputs=outputs) + + result = processor.process_response(response) + + self.assertIsNotNone(result) + # eos token should be stripped + self.assertEqual(result.outputs.text, "10|11") + self.assertEqual(result.usage["completion_tokens"], 3) + + def test_process_response_with_reasoning_and_tool(self): + processor = Ernie4_5Processor( + "dummy_model_path", + reasoning_parser_obj=DummyReasoningParser, + tool_parser_obj=DummyToolParser, + ) + outputs = DummyOutputs(token_ids=[10, 11, 2], index=1) + response = types.SimpleNamespace(request_id="req2", outputs=outputs) + + result = processor.process_response(response) + + self.assertEqual(result.outputs.reasoning_content, "reasoning_part") + self.assertEqual(result.outputs.text, "tool_content") + self.assertTrue(hasattr(result.outputs, "tool_calls")) + + def test_process_response_dict_normal_end(self): + processor = Ernie4_5Processor("dummy_model_path") + resp = { + "request_id": "req3", + "finished": True, + "outputs": {"token_ids": [5, 6, 7]}, + } + + result = processor.process_response_dict_normal(resp) + + self.assertEqual(result["outputs"]["text"], "5|6|7") + self.assertNotIn("req3", processor.decode_status) + + def test_process_response_dict_streaming_end(self): + processor = Ernie4_5Processor("dummy_model_path") + resp = { + "request_id": "req4", + "finished": True, + "outputs": {"token_ids": [8, 9]}, + } + + result = processor.process_response_dict_streaming(resp) + + self.assertEqual(result["outputs"]["text"], "8|9") + self.assertNotIn("req4", processor.decode_status) + + +class TestErnie45ProcessorUtilities(unittest.TestCase): + """Tests for pad_batch_data, update_stop_seq, process_logprob_response, update_bad_words.""" + + def setUp(self): + patcher_tok = patch(f"{MODULE_PATH}.Ernie4_5Tokenizer.from_pretrained", return_value=DummyTokenizer()) + patcher_cfg = patch(f"{MODULE_PATH}.GenerationConfig.from_pretrained", return_value=object()) + patcher_eos = patch(f"{MODULE_PATH}.get_eos_token_id", return_value=[2]) + + self.addCleanup(patcher_tok.stop) + self.addCleanup(patcher_cfg.stop) + self.addCleanup(patcher_eos.stop) + + self.mock_tok = patcher_tok.start() + self.mock_cfg = patcher_cfg.start() + self.mock_eos = patcher_eos.start() + + self.processor = Ernie4_5Processor("dummy_model_path") + + def test_pad_batch_data_right_and_left(self): + insts = [[1, 2], [3]] + padded_right, seq_len = self.processor.pad_batch_data( + insts, + pad_id=0, + return_seq_len=True, + return_array=True, + pad_style="right", + ) + padded_left = self.processor.pad_batch_data( + insts, + pad_id=0, + return_seq_len=False, + return_array=True, + pad_style="left", + ) + + self.assertEqual(padded_right.shape, (2, 2)) + 
self.assertTrue(np.array_equal(seq_len.reshape(-1).tolist(), [2, 1])) + self.assertEqual(padded_left.shape, (2, 2)) + self.assertTrue((padded_left[1, 0] == 0) and (padded_left[1, 1] == 3)) + + def test_update_stop_seq(self): + stop_seqs, stop_lens = self.processor.update_stop_seq(["stop1", "stop2"]) + + self.assertEqual(len(stop_seqs), 2) + self.assertEqual(len(stop_lens), 2) + self.assertTrue(all(l > 0 for l in stop_lens)) + + def test_process_logprob_response(self): + token_ids = [1, 2, 3] + text = self.processor.process_logprob_response(token_ids) + + self.assertEqual(text, "1|2|3") + + def test_update_bad_words_valid_and_invalid(self): + token_ids = self.processor.update_bad_words(["bad", "BIG"], bad_words_token_ids=None) + + self.assertTrue(len(token_ids) > 0) + # "BIG" maps to id > vocab_size and should be skipped, so only "bad" is counted once + self.assertEqual(len(token_ids), 1) + + +if __name__ == "__main__": + unittest.main() From 97130ce9540639f585c79813e0d93dd1b54dcd12 Mon Sep 17 00:00:00 2001 From: WintersMontagne10335 Date: Tue, 18 Nov 2025 02:28:15 +0000 Subject: [PATCH 2/3] update --- tests/input/test_ernie4_5_processor.py | 553 +++++++------------------ 1 file changed, 156 insertions(+), 397 deletions(-) diff --git a/tests/input/test_ernie4_5_processor.py b/tests/input/test_ernie4_5_processor.py index 01c4f53c203..c802ebd5b79 100644 --- a/tests/input/test_ernie4_5_processor.py +++ b/tests/input/test_ernie4_5_processor.py @@ -1,162 +1,100 @@ -import sys -import types import unittest -from unittest.mock import patch +from unittest.mock import MagicMock, patch import numpy as np -# Fake opentelemetry to avoid import errors when importing fastdeploy -opentelemetry_mod = types.ModuleType("opentelemetry") -instrumentation_mod = types.ModuleType("opentelemetry.instrumentation") -logging_mod = types.ModuleType("opentelemetry.instrumentation.logging") - - -class DummyLoggingInstrumentor: - @staticmethod - def instrument(*args, **kwargs): - pass - - -logging_mod.LoggingInstrumentor = DummyLoggingInstrumentor -sys.modules["opentelemetry"] = opentelemetry_mod -sys.modules["opentelemetry.instrumentation"] = instrumentation_mod -sys.modules["opentelemetry.instrumentation.logging"] = logging_mod - -# Now import the module under test -from fastdeploy.input.ernie4_5_processor import ( # noqa: E402 - _SAMPLING_EPS, - Ernie4_5Processor, -) - MODULE_PATH = "fastdeploy.input.ernie4_5_processor" +from fastdeploy.input.ernie4_5_processor import _SAMPLING_EPS, Ernie4_5Processor -class DummyTokenizer: - """Simple fake tokenizer used for unit tests.""" + +class MockTokenizer: + """Simple fake tokenizer for unit tests.""" def __init__(self): self.bos_token = "" - self.bos_token_id = 1 + self.bos_token_id = 101 self.eos_token = "" - self.eos_token_id = 2 + self.eos_token_id = 102 self.pad_token_id = 0 - self.vocab_size = 1000 - self.chat_template = "dummy_template" + self.vocab_size = 200 + # Any non-None value means chat_template is supported + self.chat_template = "dummy" def tokenize(self, text): - # Treat the whole string as a single token - if text is None: - return [] + """ + Make “multi” return multiple tokens to cover multi-token branch + All other texts return single-token + """ + if text.startswith("multi"): + return ["multi", "word"] return [text] def convert_tokens_to_ids(self, tokens): - # Map token -> length-based id; special BIG token to be > vocab_size - ids = [] - for t in tokens: - if t == "BIG": - ids.append(self.vocab_size + 1) - else: - ids.append(len(t) % self.vocab_size) - 
return ids + """Token → ID mapping used for specific branch coverage.""" + mapping = { + "bad": 5, + " bad": 6, + "multi": 7, + "word": 8, + "oov": 250, # > vocab_size → out-of-range branch + " oov": 251, + "hello": 9, + "REASON": 42, + } + return [mapping.get(t, 1) for t in tokens] def decode(self, token_ids, **kwargs): - # Join ids into a string so we can assert easily - return "|".join(str(i) for i in token_ids) + """Simple decode implementation.""" + return " ".join(str(t) for t in token_ids) def decode_token(self, token_ids, prefix_offset, read_offset): - # Streaming decode: return new part from read_offset onward - new_text = "|".join(str(i) for i in token_ids[read_offset:]) - new_read = len(token_ids) - return new_text, prefix_offset, new_read - - def apply_chat_template(self, request_or_messages, **kwargs): - # Join all message contents into a single text - if isinstance(request_or_messages, dict): - messages = request_or_messages.get("messages", []) - else: - messages = request_or_messages - contents = [] - for m in messages: - if isinstance(m, dict): - contents.append(m.get("content", "")) - else: - contents.append(str(m)) - return "\n".join(contents) - - -class DummyRequest: - """Simple request-like object with get/set methods.""" - - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - def get(self, key, default=None): - return getattr(self, key, default) - - def set(self, key, value): - setattr(self, key, value) - - def to_dict(self): - return self.__dict__ - - -class DummyOutputs: - """Simple outputs-like container for process_response.""" - - def __init__(self, token_ids, index=0): - self.token_ids = token_ids - self.index = index - self.text = "" - self.reasoning_content = "" - - -class DummyReasoningParser: - """Fake reasoning parser.""" + """ + Incremental decode: + - Use read_offset to get new tokens + - Return new string and updated read_offset + """ + new_tokens = token_ids[read_offset:] + decode_str = " ".join(str(t) for t in new_tokens) + new_read_offset = len(token_ids) + return decode_str, prefix_offset, new_read_offset + + def apply_chat_template(self, request_or_messages, tokenize, split_special_tokens, add_special_tokens, **kwargs): + """Minimal chat template behavior.""" + if isinstance(request_or_messages, dict) and "messages" in request_or_messages: + return " | ".join(m["content"] for m in request_or_messages["messages"]) + return str(request_or_messages) + + +class ErnieX1ReasoningParser: + """Fake reasoning parser used to trigger reasoning branch in streaming mode.""" def __init__(self, tokenizer): self.tokenizer = tokenizer - def extract_reasoning_content(self, full_text, response_dict): - # Return dummy reasoning + plain text - return "reasoning_part", "final_text" - def extract_reasoning_content_streaming( self, previous_texts, full_text, delta_text, previous_token_ids, - full_token_ids, + all_token_ids, delta_token_ids, ): - # Return an object with reasoning_content attribute - msg = types.SimpleNamespace() - msg.reasoning_content = "stream_reasoning" - return msg + """Return a minimal reasoning object.""" + + class ReasoningDelta: + def __init__(self, content): + self.reasoning_content = content + return ReasoningDelta("REASON") -class DummyToolParser: - """Fake tool parser.""" + +class MockToolParser: + """Fake tool parser used to trigger tool-calling branch.""" def __init__(self, tokenizer): self.tokenizer = tokenizer - self.called_times = 0 - - def extract_tool_calls(self, full_text, response_dict): - # When called 
first time, pretend we found a tool call - self.called_times += 1 - if self.called_times == 1: - msg = types.SimpleNamespace() - msg.tools_called = True - msg.tool_calls = [{"name": "tool1"}] - msg.content = "tool_content" - return msg - msg = types.SimpleNamespace() - msg.tools_called = False - msg.tool_calls = [] - msg.content = full_text - return msg def extract_tool_calls_streaming( self, @@ -164,321 +102,142 @@ def extract_tool_calls_streaming( full_text, delta_text, previous_token_ids, - full_token_ids, + all_token_ids, delta_token_ids, response_dict, ): - msg = types.SimpleNamespace() - msg.tools_called = True - msg.tool_calls = [{"name": "stream_tool"}] - msg.content = full_text - return msg - - -class TestErnie45ProcessorInit(unittest.TestCase): - """Tests for __init__ and tokenizer loading.""" - - @patch(f"{MODULE_PATH}.GenerationConfig.from_pretrained") - @patch(f"{MODULE_PATH}.Ernie4_5Tokenizer.from_pretrained") - @patch(f"{MODULE_PATH}.get_eos_token_id") - def test_init_with_generation_config( - self, - mock_get_eos_token_id, - mock_from_pretrained_tokenizer, - mock_from_pretrained_config, - ): - dummy_tokenizer = DummyTokenizer() - mock_from_pretrained_tokenizer.return_value = dummy_tokenizer - mock_get_eos_token_id.return_value = [2, 3] - mock_from_pretrained_config.return_value = object() - - processor = Ernie4_5Processor("dummy_model_path") - - self.assertIsNotNone(processor.generation_config) - self.assertEqual(processor.tokenizer, dummy_tokenizer) - self.assertEqual(processor.eos_token_ids, [2, 3]) - self.assertEqual(processor.pad_token_id, dummy_tokenizer.pad_token_id) - - @patch(f"{MODULE_PATH}.GenerationConfig.from_pretrained") - @patch(f"{MODULE_PATH}.Ernie4_5Tokenizer.from_pretrained") - @patch(f"{MODULE_PATH}.get_eos_token_id") - def test_init_without_generation_config( - self, - mock_get_eos_token_id, - mock_from_pretrained_tokenizer, - mock_from_pretrained_config, - ): - dummy_tokenizer = DummyTokenizer() - mock_from_pretrained_tokenizer.return_value = dummy_tokenizer - mock_get_eos_token_id.return_value = [2] - mock_from_pretrained_config.side_effect = Exception("no config") - - processor = Ernie4_5Processor("dummy_model_path") + """Return minimal tool-calling object.""" - self.assertIsNone(processor.generation_config) - self.assertEqual(processor.eos_token_ids, [2]) - self.assertEqual(processor.pad_token_id, dummy_tokenizer.pad_token_id) + class ToolDelta: + def __init__(self): + self.tool_calls = [{"name": "fake_tool"}] + return ToolDelta() -class TestErnie45ProcessorRequest(unittest.TestCase): - """Tests for process_request and process_request_dict.""" +class TestErnie4_5Processor(unittest.TestCase): def setUp(self): - patcher_tok = patch(f"{MODULE_PATH}.Ernie4_5Tokenizer.from_pretrained", return_value=DummyTokenizer()) - patcher_cfg = patch(f"{MODULE_PATH}.GenerationConfig.from_pretrained", return_value=object()) - patcher_eos = patch(f"{MODULE_PATH}.get_eos_token_id", return_value=[2]) - - self.addCleanup(patcher_tok.stop) - self.addCleanup(patcher_cfg.stop) - self.addCleanup(patcher_eos.stop) - - self.mock_tok = patcher_tok.start() - self.mock_cfg = patcher_cfg.start() - self.mock_eos = patcher_eos.start() - - self.processor = Ernie4_5Processor("dummy_model_path") - - @patch.object(Ernie4_5Processor, "_apply_default_parameters", side_effect=lambda self, r: r) - def test_process_request_with_prompt_string(self, _mock_apply_default): - req = DummyRequest( - request_id="req1", - prompt="hello", - prompt_token_ids=[], - messages=None, - eos_token_ids=None, - 
stop=[], - bad_words=None, - bad_words_token_ids=None, - max_tokens=None, - temperature=0.0, - top_p=0.0, + """Patch GenerationConfig, Tokenizer, and get_eos_token_id.""" + self.gen_patcher = patch(f"{MODULE_PATH}.GenerationConfig.from_pretrained", return_value=MagicMock()) + self.tokenizer_patcher = patch( + f"{MODULE_PATH}.Ernie4_5Tokenizer.from_pretrained", side_effect=lambda path: MockTokenizer() ) - - processed = self.processor.process_request(req, max_model_len=10) - - self.assertEqual(processed.eos_token_ids, [2]) - self.assertTrue(len(processed.prompt_token_ids) > 0) - self.assertEqual(processed.temperature, 1) - self.assertEqual(processed.top_p, _SAMPLING_EPS) - self.assertEqual(processed.max_tokens, 10 - len(processed.prompt_token_ids)) - - @patch.object(Ernie4_5Processor, "_apply_default_parameters", side_effect=lambda self, r: r) - def test_process_request_with_messages(self, _mock_apply_default): - messages = [{"role": "user", "content": "hi"}] - req = DummyRequest( - request_id="req2", - prompt=None, - prompt_token_ids=[], - messages=messages, - eos_token_ids=None, - stop=[], - bad_words=None, - bad_words_token_ids=None, - max_tokens=None, - temperature=1.0, - top_p=1.0, + self.eos_patcher = patch( + "paddleformers.trl.llm_utils.get_eos_token_id", + side_effect=lambda tokenizer, cfg: [tokenizer.eos_token_id], ) - processed = self.processor.process_request( - req, - max_model_len=50, - chat_template_kwargs={"extra": "value"}, - ) + self.gen_patcher.start() + self.tokenizer_patcher.start() + self.eos_patcher.start() - self.assertTrue(len(processed.prompt_token_ids) > 0) - self.assertEqual(processed.max_tokens, 50 - len(processed.prompt_token_ids)) - - @patch.object(Ernie4_5Processor, "_apply_default_parameters", side_effect=lambda self, r: r) - def test_process_request_missing_all_prompt_sources_raises(self, _mock_apply_default): - req = DummyRequest( - request_id="req3", - prompt=None, - prompt_token_ids=[], - messages=None, - eos_token_ids=None, - stop=[], - bad_words=None, - bad_words_token_ids=None, - max_tokens=None, - temperature=1.0, - top_p=1.0, - ) + def tearDown(self): + self.gen_patcher.stop() + self.tokenizer_patcher.stop() + self.eos_patcher.stop() - with self.assertRaises(ValueError): - self.processor.process_request(req, max_model_len=10) - - @patch.object(Ernie4_5Processor, "_apply_default_parameters", side_effect=lambda self, r: r) - def test_process_request_empty_prompt_token_ids_raises(self, _mock_apply_default): - req = DummyRequest( - request_id="req4", - prompt="", - prompt_token_ids=[], - messages=None, - eos_token_ids=None, - stop=[], - bad_words=None, - bad_words_token_ids=None, - max_tokens=None, - temperature=1.0, - top_p=1.0, - ) + def _make_processor(self, reasoning=False, tool=False): + """Helper to construct Ernie4_5Processor with mocked tokenizer.""" + reasoning_cls = ErnieX1ReasoningParser if reasoning else None + tool_cls = MockToolParser if tool else None + proc = Ernie4_5Processor("dummy-model", reasoning_parser_obj=reasoning_cls, tool_parser_obj=tool_cls) + proc._apply_default_parameters = lambda req: req # avoid dependency on parent class + return proc - with self.assertRaises(ValueError): - self.processor.process_request(req, max_model_len=10) + # 1) update_bad_words + def test_update_bad_words(self): + proc = self._make_processor() - @patch.object(Ernie4_5Processor, "_apply_default_parameters", side_effect=lambda self, r: r) - def test_process_request_dict_with_prompt_list(self, _mock_apply_default): + bad_words = ["bad", "multi", "oov"] 
# single → multi → OOV + token_ids = proc.update_bad_words(bad_words, bad_words_token_ids=None) + + # Only “bad” and its prefixed-space version should remain + self.assertEqual(token_ids, [5, 6, 1]) + + # 2) process_request_dict → prompt branch + def test_process_request_dict_with_prompt_string(self): + proc = self._make_processor() req = { - "request_id": "req5", - "prompt_token_ids": [], - "prompt": [1, 2, 3], - "messages": None, - "eos_token_ids": None, - "stop": [], - "bad_words": None, - "bad_words_token_ids": None, - "max_tokens": None, + "prompt": "hello", "temperature": 0.0, "top_p": 0.0, } - processed = self.processor.process_request_dict(req, max_model_len=20) + processed = proc.process_request_dict(req, max_model_len=10) - self.assertEqual(processed["prompt_token_ids"], [1, 2, 3]) - self.assertEqual(processed["temperature"], 1) - self.assertEqual(processed["top_p"], _SAMPLING_EPS) - self.assertEqual(processed["max_tokens"], 20 - len(processed["prompt_token_ids"])) + self.assertIn("eos_token_ids", processed) + self.assertEqual(processed["eos_token_ids"], [proc.tokenizer.eos_token_id]) + expected_ids = proc.tokenizer.convert_tokens_to_ids(proc.tokenizer.tokenize("hello")) + self.assertEqual(processed["prompt_token_ids"], expected_ids) -class TestErnie45ProcessorResponse(unittest.TestCase): - """Tests for process_response and response_dict helpers.""" + self.assertEqual(processed["max_tokens"], max(1, 10 - len(expected_ids))) + self.assertEqual(processed["temperature"], 1) + self.assertAlmostEqual(processed["top_p"], _SAMPLING_EPS) - def setUp(self): - patcher_tok = patch(f"{MODULE_PATH}.Ernie4_5Tokenizer.from_pretrained", return_value=DummyTokenizer()) - patcher_cfg = patch(f"{MODULE_PATH}.GenerationConfig.from_pretrained", return_value=object()) - patcher_eos = patch(f"{MODULE_PATH}.get_eos_token_id", return_value=[2]) - - self.addCleanup(patcher_tok.stop) - self.addCleanup(patcher_cfg.stop) - self.addCleanup(patcher_eos.stop) - - self.mock_tok = patcher_tok.start() - self.mock_cfg = patcher_cfg.start() - self.mock_eos = patcher_eos.start() - - def test_process_response_basic(self): - processor = Ernie4_5Processor("dummy_model_path") - outputs = DummyOutputs(token_ids=[10, 11, 2], index=2) - response = types.SimpleNamespace(request_id="req1", outputs=outputs) - - result = processor.process_response(response) - - self.assertIsNotNone(result) - # eos token should be stripped - self.assertEqual(result.outputs.text, "10|11") - self.assertEqual(result.usage["completion_tokens"], 3) - - def test_process_response_with_reasoning_and_tool(self): - processor = Ernie4_5Processor( - "dummy_model_path", - reasoning_parser_obj=DummyReasoningParser, - tool_parser_obj=DummyToolParser, - ) - outputs = DummyOutputs(token_ids=[10, 11, 2], index=1) - response = types.SimpleNamespace(request_id="req2", outputs=outputs) + self.assertEqual(processed["prompt_tokens"], "hello") - result = processor.process_response(response) + # 3) pad_batch_data + def test_pad_batch_data_right_and_left_and_empty(self): + proc = self._make_processor() - self.assertEqual(result.outputs.reasoning_content, "reasoning_part") - self.assertEqual(result.outputs.text, "tool_content") - self.assertTrue(hasattr(result.outputs, "tool_calls")) + insts = [[1, 2], [3]] - def test_process_response_dict_normal_end(self): - processor = Ernie4_5Processor("dummy_model_path") - resp = { - "request_id": "req3", - "finished": True, - "outputs": {"token_ids": [5, 6, 7]}, - } + # right pad + padded, seq_len = proc.pad_batch_data( + insts, 
pad_id=0, return_seq_len=True, return_array=True, pad_style="right" + ) + np.testing.assert_array_equal(padded, np.array([[1, 2], [3, 0]], dtype=np.int64)) + np.testing.assert_array_equal(seq_len, np.array([[2], [1]], dtype=np.int64)) + + # left pad + padded_left, seq_len_left = proc.pad_batch_data( + insts, pad_id=0, return_seq_len=True, return_array=True, pad_style="left" + ) + np.testing.assert_array_equal(padded_left, np.array([[1, 2], [0, 3]], dtype=np.int64)) + np.testing.assert_array_equal(seq_len_left, np.array([[2], [1]], dtype=np.int64)) - result = processor.process_response_dict_normal(resp) + # empty + padded_empty, seq_len_empty = proc.pad_batch_data( + [], pad_id=0, return_seq_len=True, return_array=True, pad_style="right" + ) + np.testing.assert_array_equal(padded_empty, np.array([[]], dtype=np.int64)) + np.testing.assert_array_equal(seq_len_empty, np.array([], dtype=np.int64)) - self.assertEqual(result["outputs"]["text"], "5|6|7") - self.assertNotIn("req3", processor.decode_status) + # 4) process_response_dict_streaming reasoning + tool branches + def test_process_response_dict_streaming_with_reasoning_and_tool(self): + proc = self._make_processor(reasoning=True, tool=True) - def test_process_response_dict_streaming_end(self): - processor = Ernie4_5Processor("dummy_model_path") - resp = { - "request_id": "req4", + response = { "finished": True, - "outputs": {"token_ids": [8, 9]}, + "request_id": "req-1", + "outputs": { + "token_ids": [10, 11], + }, } - result = processor.process_response_dict_streaming(resp) - - self.assertEqual(result["outputs"]["text"], "8|9") - self.assertNotIn("req4", processor.decode_status) - - -class TestErnie45ProcessorUtilities(unittest.TestCase): - """Tests for pad_batch_data, update_stop_seq, process_logprob_response, update_bad_words.""" - - def setUp(self): - patcher_tok = patch(f"{MODULE_PATH}.Ernie4_5Tokenizer.from_pretrained", return_value=DummyTokenizer()) - patcher_cfg = patch(f"{MODULE_PATH}.GenerationConfig.from_pretrained", return_value=object()) - patcher_eos = patch(f"{MODULE_PATH}.get_eos_token_id", return_value=[2]) - - self.addCleanup(patcher_tok.stop) - self.addCleanup(patcher_cfg.stop) - self.addCleanup(patcher_eos.stop) - - self.mock_tok = patcher_tok.start() - self.mock_cfg = patcher_cfg.start() - self.mock_eos = patcher_eos.start() - - self.processor = Ernie4_5Processor("dummy_model_path") - - def test_pad_batch_data_right_and_left(self): - insts = [[1, 2], [3]] - padded_right, seq_len = self.processor.pad_batch_data( - insts, - pad_id=0, - return_seq_len=True, - return_array=True, - pad_style="right", + result = proc.process_response_dict_streaming( + response, + enable_thinking=False, # reasoning still enabled because class name matches + include_stop_str_in_output=False, ) - padded_left = self.processor.pad_batch_data( - insts, - pad_id=0, - return_seq_len=False, - return_array=True, - pad_style="left", - ) - - self.assertEqual(padded_right.shape, (2, 2)) - self.assertTrue(np.array_equal(seq_len.reshape(-1).tolist(), [2, 1])) - self.assertEqual(padded_left.shape, (2, 2)) - self.assertTrue((padded_left[1, 0] == 0) and (padded_left[1, 1] == 3)) - - def test_update_stop_seq(self): - stop_seqs, stop_lens = self.processor.update_stop_seq(["stop1", "stop2"]) - self.assertEqual(len(stop_seqs), 2) - self.assertEqual(len(stop_lens), 2) - self.assertTrue(all(l > 0 for l in stop_lens)) + outputs = result["outputs"] - def test_process_logprob_response(self): - token_ids = [1, 2, 3] - text = 
self.processor.process_logprob_response(token_ids) + self.assertIn("completion_tokens", outputs) + self.assertIn("text", outputs) + self.assertEqual(outputs["completion_tokens"], outputs["text"]) - self.assertEqual(text, "1|2|3") + self.assertIn("reasoning_token_num", outputs) + self.assertGreaterEqual(outputs["reasoning_token_num"], 0) - def test_update_bad_words_valid_and_invalid(self): - token_ids = self.processor.update_bad_words(["bad", "BIG"], bad_words_token_ids=None) + self.assertIn("delta_message", outputs) + delta_msg = outputs["delta_message"] + self.assertTrue(hasattr(delta_msg, "tool_calls")) + self.assertEqual(delta_msg.tool_calls[0]["name"], "fake_tool") - self.assertTrue(len(token_ids) > 0) - # "BIG" maps to id > vocab_size and should be skipped, so only "bad" is counted once - self.assertEqual(len(token_ids), 1) + self.assertNotIn("req-1", proc.decode_status) + self.assertNotIn("req-1", proc.tool_parser_dict) if __name__ == "__main__": From c962f6386fddca54046d8ba59ff4fd55b95411f3 Mon Sep 17 00:00:00 2001 From: WintersMontagne10335 Date: Tue, 18 Nov 2025 03:04:24 +0000 Subject: [PATCH 3/3] update --- tests/input/test_ernie4_5_processor.py | 187 +++++++++++++++++++------ 1 file changed, 142 insertions(+), 45 deletions(-) diff --git a/tests/input/test_ernie4_5_processor.py b/tests/input/test_ernie4_5_processor.py index c802ebd5b79..8c7386fef85 100644 --- a/tests/input/test_ernie4_5_processor.py +++ b/tests/input/test_ernie4_5_processor.py @@ -9,7 +9,7 @@ class MockTokenizer: - """Simple fake tokenizer for unit tests.""" + """A simple mock tokenizer used to simulate tokenization behavior in unit tests.""" def __init__(self): self.bos_token = "" @@ -18,26 +18,23 @@ def __init__(self): self.eos_token_id = 102 self.pad_token_id = 0 self.vocab_size = 200 - # Any non-None value means chat_template is supported + # Non-None value indicates chat_template support self.chat_template = "dummy" def tokenize(self, text): - """ - Make “multi” return multiple tokens to cover multi-token branch - All other texts return single-token - """ + """Return multi-token output for 'multi*' to test branching; otherwise return single-token.""" if text.startswith("multi"): return ["multi", "word"] return [text] def convert_tokens_to_ids(self, tokens): - """Token → ID mapping used for specific branch coverage.""" + """Map tokens to synthetic IDs for branch coverage.""" mapping = { "bad": 5, " bad": 6, "multi": 7, "word": 8, - "oov": 250, # > vocab_size → out-of-range branch + "oov": 250, " oov": 251, "hello": 9, "REASON": 42, @@ -45,29 +42,25 @@ def convert_tokens_to_ids(self, tokens): return [mapping.get(t, 1) for t in tokens] def decode(self, token_ids, **kwargs): - """Simple decode implementation.""" + """Simple decode implementation returning a space-separated string.""" return " ".join(str(t) for t in token_ids) def decode_token(self, token_ids, prefix_offset, read_offset): - """ - Incremental decode: - - Use read_offset to get new tokens - - Return new string and updated read_offset - """ + """Incremental decode used to test streaming behavior.""" new_tokens = token_ids[read_offset:] decode_str = " ".join(str(t) for t in new_tokens) new_read_offset = len(token_ids) return decode_str, prefix_offset, new_read_offset def apply_chat_template(self, request_or_messages, tokenize, split_special_tokens, add_special_tokens, **kwargs): - """Minimal chat template behavior.""" + """Minimal chat template implementation used by messages2ids.""" if isinstance(request_or_messages, dict) and "messages" in 
request_or_messages: return " | ".join(m["content"] for m in request_or_messages["messages"]) return str(request_or_messages) class ErnieX1ReasoningParser: - """Fake reasoning parser used to trigger reasoning branch in streaming mode.""" + """Mock reasoning parser to trigger reasoning-related branches during streaming.""" def __init__(self, tokenizer): self.tokenizer = tokenizer @@ -81,7 +74,7 @@ def extract_reasoning_content_streaming( all_token_ids, delta_token_ids, ): - """Return a minimal reasoning object.""" + """Return a simple object with reasoning_content to cover reasoning branch.""" class ReasoningDelta: def __init__(self, content): @@ -91,11 +84,23 @@ def __init__(self, content): class MockToolParser: - """Fake tool parser used to trigger tool-calling branch.""" + """Mock tool parser to cover tool-related branches in both normal and streaming responses.""" def __init__(self, tokenizer): self.tokenizer = tokenizer + class ToolDelta: + """Simple container representing detected tool calls.""" + + def __init__(self): + self.tool_calls = [{"name": "fake_tool"}] + self.tools_called = True + self.content = "tool_content" + + def extract_tool_calls(self, full_text, response_dict): + """Used in process_response and process_response_dict_normal.""" + return MockToolParser.ToolDelta() + def extract_tool_calls_streaming( self, previous_texts, @@ -106,18 +111,15 @@ def extract_tool_calls_streaming( delta_token_ids, response_dict, ): - """Return minimal tool-calling object.""" - - class ToolDelta: - def __init__(self): - self.tool_calls = [{"name": "fake_tool"}] - - return ToolDelta() + """Used in process_response_dict_streaming.""" + return MockToolParser.ToolDelta() class TestErnie4_5Processor(unittest.TestCase): + """Unit tests for Ernie4_5Processor focusing on preprocessing and postprocessing logic.""" + def setUp(self): - """Patch GenerationConfig, Tokenizer, and get_eos_token_id.""" + """Patch external dependencies: tokenizer, generation config, eos token resolution.""" self.gen_patcher = patch(f"{MODULE_PATH}.GenerationConfig.from_pretrained", return_value=MagicMock()) self.tokenizer_patcher = patch( f"{MODULE_PATH}.Ernie4_5Tokenizer.from_pretrained", side_effect=lambda path: MockTokenizer() @@ -132,30 +134,30 @@ def setUp(self): self.eos_patcher.start() def tearDown(self): + """Stop all patches after each test.""" self.gen_patcher.stop() self.tokenizer_patcher.stop() self.eos_patcher.stop() def _make_processor(self, reasoning=False, tool=False): - """Helper to construct Ernie4_5Processor with mocked tokenizer.""" + """Helper for creating a processor with optional reasoning/tool parser support.""" reasoning_cls = ErnieX1ReasoningParser if reasoning else None tool_cls = MockToolParser if tool else None proc = Ernie4_5Processor("dummy-model", reasoning_parser_obj=reasoning_cls, tool_parser_obj=tool_cls) - proc._apply_default_parameters = lambda req: req # avoid dependency on parent class + proc._apply_default_parameters = lambda req: req return proc - # 1) update_bad_words def test_update_bad_words(self): + """Verify filtering, multi-token skipping, and OOV behavior in update_bad_words.""" proc = self._make_processor() - bad_words = ["bad", "multi", "oov"] # single → multi → OOV + bad_words = ["bad", "multi", "oov"] token_ids = proc.update_bad_words(bad_words, bad_words_token_ids=None) - # Only “bad” and its prefixed-space version should remain self.assertEqual(token_ids, [5, 6, 1]) - # 2) process_request_dict → prompt branch def test_process_request_dict_with_prompt_string(self): + 
"""Test prompt-based tokenization, truncation, and temperature/top_p correction.""" proc = self._make_processor() req = { "prompt": "hello", @@ -174,52 +176,44 @@ def test_process_request_dict_with_prompt_string(self): self.assertEqual(processed["max_tokens"], max(1, 10 - len(expected_ids))) self.assertEqual(processed["temperature"], 1) self.assertAlmostEqual(processed["top_p"], _SAMPLING_EPS) - self.assertEqual(processed["prompt_tokens"], "hello") - # 3) pad_batch_data def test_pad_batch_data_right_and_left_and_empty(self): + """Test left/right padding and empty input behavior.""" proc = self._make_processor() insts = [[1, 2], [3]] - # right pad padded, seq_len = proc.pad_batch_data( insts, pad_id=0, return_seq_len=True, return_array=True, pad_style="right" ) np.testing.assert_array_equal(padded, np.array([[1, 2], [3, 0]], dtype=np.int64)) np.testing.assert_array_equal(seq_len, np.array([[2], [1]], dtype=np.int64)) - # left pad padded_left, seq_len_left = proc.pad_batch_data( insts, pad_id=0, return_seq_len=True, return_array=True, pad_style="left" ) np.testing.assert_array_equal(padded_left, np.array([[1, 2], [0, 3]], dtype=np.int64)) np.testing.assert_array_equal(seq_len_left, np.array([[2], [1]], dtype=np.int64)) - # empty padded_empty, seq_len_empty = proc.pad_batch_data( [], pad_id=0, return_seq_len=True, return_array=True, pad_style="right" ) np.testing.assert_array_equal(padded_empty, np.array([[]], dtype=np.int64)) np.testing.assert_array_equal(seq_len_empty, np.array([], dtype=np.int64)) - # 4) process_response_dict_streaming reasoning + tool branches def test_process_response_dict_streaming_with_reasoning_and_tool(self): + """Ensure streaming mode handles reasoning and tool-call parsing correctly.""" proc = self._make_processor(reasoning=True, tool=True) response = { "finished": True, "request_id": "req-1", - "outputs": { - "token_ids": [10, 11], - }, + "outputs": {"token_ids": [10, 11]}, } result = proc.process_response_dict_streaming( - response, - enable_thinking=False, # reasoning still enabled because class name matches - include_stop_str_in_output=False, + response, enable_thinking=False, include_stop_str_in_output=False ) outputs = result["outputs"] @@ -234,11 +228,114 @@ def test_process_response_dict_streaming_with_reasoning_and_tool(self): self.assertIn("delta_message", outputs) delta_msg = outputs["delta_message"] self.assertTrue(hasattr(delta_msg, "tool_calls")) - self.assertEqual(delta_msg.tool_calls[0]["name"], "fake_tool") self.assertNotIn("req-1", proc.decode_status) self.assertNotIn("req-1", proc.tool_parser_dict) + def test_update_stop_seq(self): + """Test stop sequence tokenization and padding.""" + proc = self._make_processor() + + stop_seqs, stop_lens = proc.update_stop_seq("stop") + self.assertIsInstance(stop_seqs, list) + self.assertIsInstance(stop_lens, list) + + stop_seqs2, stop_lens2 = proc.update_stop_seq(["stop", "hello"]) + self.assertEqual(len(stop_seqs2), 2) + self.assertEqual(len(stop_lens2), 2) + + def test_process_request_chat_template_kwargs(self): + """Test chat_template_kwargs application inside process_request.""" + + proc = self._make_processor() + + class ReqObj(dict): + """Mock request object supporting attributes, set(), and to_dict().""" + + def set(self, k, v): + self[k] = v + + def __getattr__(self, item): + return self.get(item, None) + + def to_dict(self): + return dict(self) + + request = ReqObj( + { + "messages": [{"role": "user", "content": "hello"}], + "temperature": 0.5, + "top_p": 0.5, + } + ) + + processed = 
proc.process_request(request, max_model_len=20, chat_template_kwargs={"extra": "VALUE"}) + + self.assertEqual(processed.eos_token_ids, [proc.tokenizer.eos_token_id]) + + expected_ids = proc.tokenizer.convert_tokens_to_ids(proc.tokenizer.tokenize("hello")) + self.assertIsNotNone(processed.prompt_token_ids) + self.assertEqual(processed.prompt_token_ids, expected_ids) + + self.assertIn("max_tokens", processed) + self.assertEqual(processed["max_tokens"], max(1, 20 - len(expected_ids))) + + def test_process_request_dict_chat_template_kwargs(self): + """Test chat_template_kwargs insertion in process_request_dict.""" + proc = self._make_processor() + + req = { + "messages": [{"role": "user", "content": "hey"}], + "chat_template_kwargs": {"A": "B"}, + "temperature": 0.5, + "top_p": 0.5, + } + + result = proc.process_request_dict(req, max_model_len=30) + + self.assertIn("prompt_token_ids", result) + self.assertEqual(result["A"], "B") + + def test_init_generation_config_exception(self): + """Test fallback behavior when GenerationConfig loading fails.""" + with patch(f"{MODULE_PATH}.GenerationConfig.from_pretrained", side_effect=Exception("fail")): + proc = self._make_processor() + self.assertIsNone(proc.generation_config) + + def test_process_response_with_tool_parser(self): + """Verify tool_call extraction in process_response.""" + proc = self._make_processor(tool=True) + + class RespObj: + """Mock response carrying token_ids and index for testing.""" + + def __init__(self): + self.request_id = "reqx" + self.outputs = MagicMock() + self.outputs.token_ids = [9, proc.tokenizer.eos_token_id] + self.outputs.index = 0 + + resp = RespObj() + result = proc.process_response(resp) + + self.assertTrue(hasattr(result.outputs, "tool_calls")) + self.assertEqual(result.outputs.tool_calls[0]["name"], "fake_tool") + + def test_process_response_dict_normal_with_tool(self): + """Verify tool_call extraction in normal (non-streaming) response mode.""" + proc = self._make_processor(tool=True) + + resp = { + "finished": True, + "request_id": "task-99", + "outputs": {"token_ids": [10, 11], "text": ""}, + } + + result = proc.process_response_dict_normal(resp, enable_thinking=False, include_stop_str_in_output=False) + + self.assertIn("tool_call", result["outputs"]) + self.assertEqual(result["outputs"]["tool_call"][0]["name"], "fake_tool") + if __name__ == "__main__": unittest.main()
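
Note on the contracts these mocks encode (editorial addition, not part of the patch content applied by git):

test_update_bad_words pins update_bad_words to [5, 6, 1] for ["bad", "multi", "oov"]. Under MockTokenizer's mapping, that expectation corresponds to the following filtering rules. The helper below is a minimal sketch of the assumed contract only, not the FastDeploy implementation; the helper name and the strict ">" vocab-size check are illustrative assumptions:

    def update_bad_words_sketch(tokenizer, bad_words):
        """Sketch of the filtering that test_update_bad_words assumes."""
        token_ids = []
        for word in bad_words:
            # The expected [5, 6, ...] implies both the bare word and its
            # space-prefixed variant are considered.
            for variant in (word, " " + word):
                tokens = tokenizer.tokenize(variant)
                if len(tokens) != 1:
                    continue  # multi-token bad words are skipped ("multi")
                token_id = tokenizer.convert_tokens_to_ids(tokens)[0]
                if token_id > tokenizer.vocab_size:
                    continue  # ids outside the vocab are skipped ("oov", " oov")
                token_ids.append(token_id)
        return token_ids

With MockTokenizer this yields [5, 6, 1]: "bad" and " bad" map to 5 and 6; the two-token "multi" is dropped while single-token " multi" falls back to the default id 1; both "oov" variants exceed vocab_size (200) and are dropped.

The streaming tests similarly rely on decode_token's incremental contract: each call decodes only the tokens past read_offset and returns the advanced offset. A short demonstration against the MockTokenizer defined in this patch:

    tok = MockTokenizer()
    text1, prefix, read = tok.decode_token([10], 0, 0)              # "10", read == 1
    text2, prefix, read = tok.decode_token([10, 11], prefix, read)  # "11", read == 2
    assert (text1, text2, read) == ("10", "11", 2)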