diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index 4a1e4ef647f..339164fca1d 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -671,10 +671,7 @@ def to_dict_for_infer(self, request_id=None):
         if request_id is not None:
             req_dict["request_id"] = request_id
 
-        if "prompt_token_ids" in req_dict:
-            if "messages" in req_dict:
-                del req_dict["messages"]
-        else:
+        if "prompt_token_ids" not in req_dict or not req_dict["prompt_token_ids"]:
             # If disable_chat_template is set, then the first message in messages will be used as the prompt.
             assert (
                 len(req_dict["messages"]) > 0
diff --git a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
index 0fe724af53f..3376c846fd8 100644
--- a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
+++ b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
@@ -219,7 +219,13 @@ def process_request_dict(self, request, max_model_len=None):
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request["bad_words_token_ids"] = bad_words_token_ids
 
-        if request.get("prompt"):
+        if request.get("prompt_token_ids"):
+            messages = request.get("messages")
+            if messages:
+                self._check_mm_limits(messages)
+            request.setdefault("enable_thinking", True)
+            outputs = self.ernie4_5_processor.prompt_token_ids2outputs(request)
+        elif request.get("prompt"):
             multimodal_data = request.get("multimodal_data")
             if multimodal_data is None:
                 multimodal_data = {}
@@ -256,7 +262,9 @@ def process_request_dict(self, request, max_model_len=None):
                 self.append_completion_tokens(outputs, request["completion_token_ids"])
 
             outputs = self.pack_outputs(outputs)
-        request["prompt_token_ids"] = outputs["input_ids"].tolist()
+        request["prompt_token_ids"] = (
+            outputs["input_ids"].tolist() if "prompt_token_ids" not in request else request["prompt_token_ids"]
+        )
         request["prompt_token_ids_len"] = len(request["prompt_token_ids"])
         request["multimodal_inputs"] = outputs
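With these two pieces, a request that already carries `prompt_token_ids` keeps its `messages` (previously they were deleted) and flows through without re-tokenization: the ids are validated against the multimodal items instead of being regenerated by the chat template. A minimal sketch of the new contract, assuming `processor` is an initialized `Ernie4_5_VLProcessor` and that the concrete ids below are stand-ins for the real special tokens:

```python
# Hypothetical pre-tokenized request: 101/102 stand in for BOS/EOS,
# 1002/1003 for <|IMAGE_START|>/<|IMAGE_END|>, 1001 for <|IMAGE_PLACEHOLDER|>.
request = {
    "request_id": "req-0",
    "prompt_token_ids": [101, 1002, 1001, 1001, 1003, 102],
    "messages": [
        {"role": "user", "content": [{"type": "image_url", "image_url": {"url": "file:///tmp/cat.jpg"}}]}
    ],
}
processor.process_request_dict(request, 1024)

assert request["prompt_token_ids"] == [101, 1002, 1001, 1001, 1003, 102]  # preserved verbatim
assert request["prompt_token_ids_len"] == 6
# request["multimodal_inputs"] now carries input_ids/images/grid_thw/position_ids.
```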
- """ - + def extract_mm_items(self, request: Dict[str, Any]): messages = parse_chat_messages(request.get("messages")) mm_items = [] for msg in messages: @@ -273,6 +268,7 @@ def request2ids( if len(missing_hashes) > 0 and not self.enable_processor_cache: raise ValueError("Missing items cannot be retrieved without processor cache.") + dealer = None if self.enable_processor_cache: context = zmq.Context() dealer = context.socket(zmq.DEALER) @@ -295,6 +291,16 @@ def request2ids( video_uuid.append(item["uuid"]) else: raise ValueError(f"Unsupported multimodal type: {item.get('type')}") + return images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items + + def request2ids( + self, request: Dict[str, Any], tgts: List[str] = None + ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]: + """ + Convert chat messages into model inputs. + Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels. + """ + images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request) if self.tokenizer.chat_template is None: raise ValueError("This model does not support chat template.") @@ -329,6 +335,115 @@ def request2ids( return outputs + def prompt_token_ids2outputs( + self, request: Dict[str, Any], tgts: List[str] = None + ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]: + outputs = { + "input_ids": [], + "token_type_ids": [], + "position_ids": [], + "images": [], + "grid_thw": [], + "image_type_ids": [], + "labels": [], + "cur_position": 0, + "video_cnt": 0, + "num_input_image_tokens": 0, + "num_input_video_tokens": 0, + "mm_positions": [], + "mm_hashes": [], + } + prompt_token_ids = request.get("prompt_token_ids", []) + prompt_token_ids_len = len(prompt_token_ids) + if not request.get("messages"): + outputs["input_ids"].extend(prompt_token_ids) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * prompt_token_ids_len) + for i in range(prompt_token_ids_len): + outputs["position_ids"].append([i] * 3) + outputs["cur_position"] += prompt_token_ids_len + return outputs + images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request) + st, image_idx, video_idx = 0, 0, 0 + while st < prompt_token_ids_len: + cur_token_id = prompt_token_ids[st] + if cur_token_id == self.image_start_id: + if image_idx >= len(images): + raise ValueError("prompt token ids has more image placeholder than in messages") + # append image_start_id + outputs["input_ids"].extend([cur_token_id]) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]]) + outputs["position_ids"].append([outputs["cur_position"]] * 3) + outputs["cur_position"] += 1 + st += 1 + # process placeholder token ids + cur_idx = st + while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.image_end_id: + cur_idx += 1 + if cur_idx >= prompt_token_ids_len: + raise ValueError("image token ids not complete") + image = images[image_idx] + uuid = image_uuid[image_idx] if image_uuid else None + token_len = cur_idx - st + if not isinstance(image, tuple): + self._add_image(image, outputs, uuid, token_len) + else: + self._add_processed_image(image, outputs, uuid, token_len) + image_idx += 1 + st = cur_idx + elif cur_token_id == self.video_start_id: + if video_idx >= len(videos): + raise ValueError("prompt token ids has more video placeholder than in messages") + # append video_start_id + outputs["input_ids"].extend([cur_token_id]) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]]) + 
outputs["position_ids"].append([outputs["cur_position"]] * 3) + outputs["cur_position"] += 1 + st += 1 + # process placeholder token ids + cur_idx = st + while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.video_end_id: + cur_idx += 1 + if cur_idx >= prompt_token_ids_len: + raise ValueError("video token ids not complete") + video = videos[video_idx] + uuid = video_uuid[video_idx] if video_uuid else None + token_len = cur_idx - st + if not isinstance(video, tuple): + if isinstance(video, dict): + frames = self._load_and_process_video(video["video"], video) + else: + frames = self._load_and_process_video(video, {}) + self._add_video(frames, outputs, uuid, token_len) + else: + self._add_processed_video(video, outputs, uuid, token_len) + video_idx += 1 + st = cur_idx + else: + outputs["input_ids"].extend([cur_token_id]) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]]) + outputs["position_ids"].append([outputs["cur_position"]] * 3) + outputs["cur_position"] += 1 + st += 1 + if image_idx != len(images): + raise ValueError("number of images does not match") + if video_idx != len(videos): + raise ValueError("number of videos does not match") + + if self.enable_processor_cache: + missing_idx = set(missing_idx) + hashes_to_cache, items_to_cache = [], [] + for idx in range(len(mm_items)): + if idx in missing_idx: + continue + meta = {} + t, h, w = outputs["grid_thw"][idx][0] + meta["thw"] = (t, h, w) + hashes_to_cache.append(outputs["mm_hashes"][idx]) + items_to_cache.append((outputs["images"][idx], meta)) + self.update_processor_cache(dealer, hashes_to_cache, items_to_cache) + + return outputs + def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None: token_id = token if isinstance(token, int) else self.tokenizer.convert_tokens_to_ids(token) outputs["input_ids"].append(token_id) @@ -348,7 +463,7 @@ def _add_text(self, tokens, outputs: Dict) -> None: outputs["position_ids"].append([start + i] * 3) outputs["cur_position"] += len(tokens) - def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None: + def _add_image(self, img, outputs: Dict, uuid: Optional[str], token_len=None) -> None: patches_h, patches_w = self.image_preprocessor.get_smarted_resize( img.height, img.width, @@ -356,6 +471,8 @@ def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None: max_pixels=self.image_max_pixels, )[1] num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2) + if token_len and token_len != num_tokens: + raise ValueError("image tokens num not match the size") outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) outputs["input_ids"].extend([self.image_patch_id] * num_tokens) @@ -383,9 +500,13 @@ def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None: outputs["grid_thw"].append(ret["image_grid_thw"]) outputs["image_type_ids"].append(0) - def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None: + def _add_processed_image( + self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None + ) -> None: img, meta = img_cache num_tokens = img.shape[0] // (self.spatial_conv_size**2) + if token_len and num_tokens != token_len: + raise ValueError("image tokens num not match the size") outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) outputs["input_ids"].extend([self.image_patch_id] * num_tokens) @@ -401,7 +522,7 @@ def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict 
outputs["grid_thw"].append(np.array([[1, h, w]])) outputs["image_type_ids"].append(0) - def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None: + def _add_video(self, frames, outputs: Dict, uuid: Optional[str], token_len=None) -> None: patches_h, patches_w = self.image_preprocessor.get_smarted_resize( frames[0].height, frames[0].width, @@ -410,6 +531,8 @@ def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None: )[1] num_frames = len(frames) num_tokens = (num_frames * patches_h * patches_w) // (self.spatial_conv_size**2 * self.temporal_conv_size) + if token_len and num_tokens != token_len: + raise ValueError("video tokens num not match the size") pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0) ret = self.image_preprocessor.preprocess( @@ -438,9 +561,13 @@ def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None: outputs["position_ids"].extend(pos_ids) outputs["cur_position"] = np.max(pos_ids) + 1 - def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None: + def _add_processed_video( + self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None + ) -> None: frames, meta = frames_cache num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size) + if token_len and num_tokens != token_len: + raise ValueError("video tokens num not match the size") t, h, w = meta["thw"] outputs["images"].append(frames) diff --git a/tests/input/test_ernie_vl_processor.py b/tests/input/test_ernie_vl_processor.py index 92d24d5b96f..ee4d0f195f8 100644 --- a/tests/input/test_ernie_vl_processor.py +++ b/tests/input/test_ernie_vl_processor.py @@ -1,7 +1,15 @@ import unittest from unittest.mock import MagicMock, patch +import numpy as np + +from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor +from fastdeploy.input.ernie4_5_vl_processor.image_preprocessor.image_preprocessor_adaptive import ( + AdaptiveImageProcessor, +) +from fastdeploy.input.ernie4_5_vl_processor.process import DataProcessor +from fastdeploy.input.utils import IDS_TYPE_FLAG class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase): @@ -77,7 +85,7 @@ def test_process_request_dict_with_options(self): "prompt_token_ids": [1, 1, 1], } self.processor.process_request_dict(request_dict, 100) - self.assertEqual(request_dict["enable_thinking"], False) + self.assertEqual(request_dict["enable_thinking"], True) request_dict = { "messages": [{"role": "user", "content": "Hello"}], @@ -93,7 +101,7 @@ def test_process_request_dict_with_options(self): "prompt_token_ids": [1, 1, 1], } self.processor.process_request_dict(request_dict, 100) - self.assertEqual(request_dict["enable_thinking"], False) + self.assertEqual(request_dict["enable_thinking"], True) request_dict = { "messages": [{"role": "user", "content": "Hello"}], @@ -101,7 +109,7 @@ def test_process_request_dict_with_options(self): "prompt_token_ids": [1, 1, 1], } self.processor.process_request_dict(request_dict, 100) - self.assertEqual(request_dict["enable_thinking"], False) + self.assertEqual(request_dict["enable_thinking"], True) request_dict = { "messages": [{"role": "user", "content": "Hello"}], @@ -111,6 +119,446 @@ def test_process_request_dict_with_options(self): self.processor.process_request_dict(request_dict, 100) self.assertEqual(request_dict["enable_thinking"], True) + request_dict = { + "messages": [{"role": "user", 
"content": "Hello"}], + "chat_template_kwargs": {"options": {"thinking_mode": "close"}}, + } + self.processor.process_request_dict(request_dict, 100) + self.assertEqual(request_dict["enable_thinking"], False) + + request_dict = { + "messages": [{"role": "user", "content": "Hello"}], + "chat_template_kwargs": {"options": {"thinking_mode": "false"}}, + } + self.processor.process_request_dict(request_dict, 100) + self.assertEqual(request_dict["enable_thinking"], False) + + request_dict = { + "messages": [{"role": "user", "content": "Hello"}], + "chat_template_kwargs": {"enable_thinking": False}, + } + self.processor.process_request_dict(request_dict, 100) + self.assertEqual(request_dict["enable_thinking"], False) + + +class TestDataProcessorTargetMethods(unittest.TestCase): + def setUp(self): + self.mock_tokenizer = MagicMock(spec=Ernie4_5Tokenizer) + self.mock_tokenizer.ignored_index = -100 + self.mock_tokenizer.convert_tokens_to_ids.side_effect = self._mock_convert_tokens_to_ids + self.mock_tokenizer.chat_template = "mock_template" + self.mock_tokenizer.apply_chat_template.return_value = "User: Hello<|image@placeholder|>" + + def mock_load_tokenizer(dp_instance): + dp_instance.tokenizer = self.mock_tokenizer + + with patch.object(DataProcessor, "_load_tokenizer", side_effect=mock_load_tokenizer, autospec=True): + with patch.object(AdaptiveImageProcessor, "from_pretrained") as mock_image_preprocessor: + mock_image_preprocessor.return_value = MagicMock() + self.data_processor = DataProcessor( + tokenizer_name="mock_tokenizer", + image_preprocessor_name="mock_image_preprocessor", + enable_processor_cache=False, + ) + self.data_processor.image_patch_id = 1001 + self.data_processor.image_start_id = 1002 + self.data_processor.image_end_id = 1003 + self.data_processor.video_start_id = 1004 + self.data_processor.video_end_id = 1005 + self.data_processor.role_prefixes = {"user": "User: ", "assistant": "Assistant: "} + self.data_processor.enable_processor_cache = False + self.data_processor.extract_mm_items = MagicMock(return_value=([], [], [], [], None, [], [])) + + def _mock_convert_tokens_to_ids(self, token): + token_id_map = { + "<|begin_of_sentence|>": 101, + "<|end_of_sentence|>": 102, + "": 103, + "<|IMAGE_PLACEHOLDER|>": 1001, + "<|IMAGE_START|>": 1002, + "<|IMAGE_END|>": 1003, + "<|VIDEO_START|>": 1004, + "<|VIDEO_END|>": 1005, + } + return token_id_map.get(token, 999) + + def test_prompt_token_ids2outputs_only_prompt_token_ids(self): + test_prompt_token_ids = [101, 999, 998, 997, 102] + request = { + "prompt_token_ids": test_prompt_token_ids, + } + + outputs = self.data_processor.prompt_token_ids2outputs(request) + + prompt_len = len(test_prompt_token_ids) + + self.assertEqual( + outputs["input_ids"], + test_prompt_token_ids, + f"input_ids 不匹配:实际{outputs['input_ids']},预期[{test_prompt_token_ids}]", + ) + + self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len) + + expected_position_ids = [[i] * 3 for i in range(prompt_len)] + self.assertEqual(outputs["position_ids"], expected_position_ids) + + self.assertEqual(outputs["cur_position"], prompt_len) + + self.assertEqual(len(outputs["images"]), 0) + self.assertEqual(len(outputs["grid_thw"]), 0) + self.assertEqual(len(outputs["mm_positions"]), 0) + self.assertEqual(len(outputs["mm_hashes"]), 0) + self.assertEqual(outputs["video_cnt"], 0) + self.assertEqual(outputs["num_input_image_tokens"], 0) + self.assertEqual(outputs["num_input_video_tokens"], 0) + + def test_prompt_token_ids2outputs_with_messages_no_mm(self): + 
+class TestDataProcessorTargetMethods(unittest.TestCase):
+    def setUp(self):
+        self.mock_tokenizer = MagicMock(spec=Ernie4_5Tokenizer)
+        self.mock_tokenizer.ignored_index = -100
+        self.mock_tokenizer.convert_tokens_to_ids.side_effect = self._mock_convert_tokens_to_ids
+        self.mock_tokenizer.chat_template = "mock_template"
+        self.mock_tokenizer.apply_chat_template.return_value = "User: Hello<|image@placeholder|>"
+
+        def mock_load_tokenizer(dp_instance):
+            dp_instance.tokenizer = self.mock_tokenizer
+
+        with patch.object(DataProcessor, "_load_tokenizer", side_effect=mock_load_tokenizer, autospec=True):
+            with patch.object(AdaptiveImageProcessor, "from_pretrained") as mock_image_preprocessor:
+                mock_image_preprocessor.return_value = MagicMock()
+                self.data_processor = DataProcessor(
+                    tokenizer_name="mock_tokenizer",
+                    image_preprocessor_name="mock_image_preprocessor",
+                    enable_processor_cache=False,
+                )
+        self.data_processor.image_patch_id = 1001
+        self.data_processor.image_start_id = 1002
+        self.data_processor.image_end_id = 1003
+        self.data_processor.video_start_id = 1004
+        self.data_processor.video_end_id = 1005
+        self.data_processor.role_prefixes = {"user": "User: ", "assistant": "Assistant: "}
+        self.data_processor.enable_processor_cache = False
+        self.data_processor.extract_mm_items = MagicMock(return_value=([], [], [], [], None, [], []))
+
+    def _mock_convert_tokens_to_ids(self, token):
+        token_id_map = {
+            "<|begin_of_sentence|>": 101,
+            "<|end_of_sentence|>": 102,
+            "": 103,
+            "<|IMAGE_PLACEHOLDER|>": 1001,
+            "<|IMAGE_START|>": 1002,
+            "<|IMAGE_END|>": 1003,
+            "<|VIDEO_START|>": 1004,
+            "<|VIDEO_END|>": 1005,
+        }
+        return token_id_map.get(token, 999)
+
+    def test_prompt_token_ids2outputs_only_prompt_token_ids(self):
+        test_prompt_token_ids = [101, 999, 998, 997, 102]
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+        }
+
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+
+        prompt_len = len(test_prompt_token_ids)
+
+        self.assertEqual(
+            outputs["input_ids"],
+            test_prompt_token_ids,
+            f"input_ids mismatch: got {outputs['input_ids']}, expected {test_prompt_token_ids}",
+        )
+
+        self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len)
+
+        expected_position_ids = [[i] * 3 for i in range(prompt_len)]
+        self.assertEqual(outputs["position_ids"], expected_position_ids)
+
+        self.assertEqual(outputs["cur_position"], prompt_len)
+
+        self.assertEqual(len(outputs["images"]), 0)
+        self.assertEqual(len(outputs["grid_thw"]), 0)
+        self.assertEqual(len(outputs["mm_positions"]), 0)
+        self.assertEqual(len(outputs["mm_hashes"]), 0)
+        self.assertEqual(outputs["video_cnt"], 0)
+        self.assertEqual(outputs["num_input_image_tokens"], 0)
+        self.assertEqual(outputs["num_input_video_tokens"], 0)
+
+    def test_prompt_token_ids2outputs_with_messages_no_mm(self):
+        test_prompt_token_ids = [101, 999, 998, 997, 102]
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [{"role": "user", "content": "Hello World"}],
+        }
+
+        self.data_processor.extract_mm_items.return_value = ([], [], [], [], None, [], [])
+
+        outputs = self.data_processor.prompt_token_ids2outputs(request)
+
+        prompt_len = len(test_prompt_token_ids)
+
+        self.assertEqual(outputs["input_ids"], test_prompt_token_ids)
+
+        self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len)
+
+        expected_position_ids = [[i] * 3 for i in range(prompt_len)]
+        self.assertEqual(outputs["position_ids"], expected_position_ids)
+
+        self.assertEqual(outputs["cur_position"], prompt_len)
+
+        self.assertEqual(len(outputs["images"]), 0)
+        self.assertEqual(outputs["video_cnt"], 0)
+        self.assertEqual(outputs["num_input_image_tokens"], 0)
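Both text-only paths pin down the position encoding: every text token advances the three (temporal, height, width) axes together, so a pure-text prompt gets `[[0, 0, 0], [1, 1, 1], ...]`. A tiny standalone restatement of that expectation:

```python
# The 3-axis position convention the two tests above assert, in isolation.
prompt_len = 5
expected_position_ids = [[i] * 3 for i in range(prompt_len)]
assert expected_position_ids[0] == [0, 0, 0]
assert expected_position_ids[-1] == [4, 4, 4]
assert len(expected_position_ids) == prompt_len
```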
IDS_TYPE_FLAG["text"], + ], + ) + self.assertEqual(len(outputs["position_ids"]), 20) + self.assertEqual(outputs["cur_position"], 8) + self.assertEqual(len(outputs["images"]), 1) + self.assertIsNotNone(outputs["images"][0]) + self.assertEqual(len(outputs["mm_positions"]), 1) + self.assertEqual(outputs["mm_hashes"][0], "img_uuid") + self.assertEqual(len(outputs["grid_thw"]), 1) + self.assertEqual(len(outputs["image_type_ids"]), 1) + + def test_prompt_token_ids2outputs_add_video(self): + test_prompt_token_ids = [101, 1004, 1001, 1001, 1001, 1001, 1005, 102] + mock_frame1 = MagicMock() + mock_frame1.height = 224 + mock_frame1.width = 224 + mock_frame1.convert.return_value = mock_frame1 + mock_frame2 = MagicMock() + mock_frame2.height = 224 + mock_frame2.width = 224 + mock_frame2.convert.return_value = mock_frame2 + frames = [mock_frame1, mock_frame2] + request = { + "prompt_token_ids": test_prompt_token_ids, + "messages": [ + {"role": "user", "content": [{"type": "video_url", "video_url": frames, "uuid": "vid_uuid"}]} + ], + } + self.data_processor.extract_mm_items.return_value = ( + [], + [frames], + [], + ["vid_uuid"], + None, + [], + [{"type": "video", "data": frames}], + ) + self.data_processor._load_and_process_video = MagicMock(return_value=frames) + patches_h, patches_w = 4, 4 + self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w)) + mock_preprocess = { + "pixel_values_videos": np.random.randn(2, patches_h, patches_w, 3), + "video_grid_thw": np.array([[patches_h, patches_w]] * 2), + } + self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess + outputs = self.data_processor.prompt_token_ids2outputs(request) + self.assertEqual(outputs["input_ids"], [101, 1004, 1001, 1001, 1001, 1001, 1005, 102]) + self.assertEqual( + outputs["token_type_ids"], + [ + IDS_TYPE_FLAG["text"], + IDS_TYPE_FLAG["text"], + IDS_TYPE_FLAG["video"], + IDS_TYPE_FLAG["video"], + IDS_TYPE_FLAG["video"], + IDS_TYPE_FLAG["video"], + IDS_TYPE_FLAG["text"], + IDS_TYPE_FLAG["text"], + ], + ) + self.assertEqual(len(outputs["position_ids"]), 8) + self.assertEqual(outputs["cur_position"], 6) + self.assertEqual(len(outputs["images"]), 1) + self.assertIsNotNone(outputs["images"][0]) + self.assertEqual(len(outputs["mm_positions"]), 1) + self.assertEqual(outputs["mm_hashes"][0], "vid_uuid") + self.assertEqual(len(outputs["grid_thw"]), 1) + self.assertEqual(len(outputs["image_type_ids"]), 2) + self.assertEqual(outputs["num_input_video_tokens"], 4) + + def test_prompt_token_ids2outputs_add_processed_video(self): + test_prompt_token_ids = [101, 1004, 1001, 1001, 1001, 1001, 1005, 102] + t, h, w = 2, 4, 4 + spatial_conv_size = self.data_processor.spatial_conv_size + temporal_conv_size = self.data_processor.temporal_conv_size + token_per_frame = (h // spatial_conv_size) * (w // spatial_conv_size) + num_tokens = (t // temporal_conv_size) * token_per_frame + mock_frames_data = np.random.randn(num_tokens * spatial_conv_size**2 * temporal_conv_size, 28, 28) + mock_frames_cache = (mock_frames_data, {"thw": (t, h, w)}) + request = { + "prompt_token_ids": test_prompt_token_ids, + "messages": [ + {"role": "user", "content": [{"type": "video", "data": mock_frames_cache, "uuid": "vid_uuid"}]} + ], + } + self.data_processor.extract_mm_items.return_value = ( + [], + [mock_frames_cache], + [], + ["vid_uuid"], + None, + [], + [{"type": "video", "data": mock_frames_cache}], + ) + outputs = self.data_processor.prompt_token_ids2outputs(request) + 
self.assertEqual(outputs["input_ids"], [101, 1004, 1001, 1001, 1001, 1001, 1005, 102]) + self.assertEqual( + outputs["token_type_ids"], + [ + IDS_TYPE_FLAG["text"], + IDS_TYPE_FLAG["text"], + IDS_TYPE_FLAG["video"], + IDS_TYPE_FLAG["video"], + IDS_TYPE_FLAG["video"], + IDS_TYPE_FLAG["video"], + IDS_TYPE_FLAG["text"], + IDS_TYPE_FLAG["text"], + ], + ) + self.assertEqual(len(outputs["position_ids"]), 8) + self.assertEqual(outputs["cur_position"], 6) + self.assertEqual(len(outputs["images"]), 1) + self.assertIsNotNone(outputs["images"][0]) + self.assertEqual(len(outputs["mm_positions"]), 1) + self.assertEqual(outputs["mm_hashes"][0], "vid_uuid") + self.assertEqual(len(outputs["grid_thw"]), 1) + self.assertEqual(len(outputs["image_type_ids"]), 2) + + def test_prompt_token_ids2outputs_add_image_token_len_mismatch(self): + test_prompt_token_ids = [101, 1002, 1001, 1001, 1001, 1003, 102] + mock_img = MagicMock() + mock_img.height = 224 + mock_img.width = 224 + mock_img.convert.return_value = mock_img + request = { + "prompt_token_ids": test_prompt_token_ids, + "messages": [ + {"role": "user", "content": [{"type": "image_url", "image_url": mock_img, "uuid": "img_uuid"}]} + ], + } + self.data_processor.extract_mm_items.return_value = ( + [mock_img], + [], + ["img_uuid"], + [], + None, + [], + [{"type": "image", "data": mock_img}], + ) + patches_h, patches_w = 8, 8 + self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w)) + mock_preprocess = { + "pixel_values": np.random.randn(1, patches_h, patches_w, 3), + "image_grid_thw": np.array([[patches_h, patches_w]]), + } + self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess + with self.assertRaises(ValueError) as ctx: + self.data_processor.prompt_token_ids2outputs(request) + self.assertIn("image tokens num not match the size", str(ctx.exception)) + + def test_prompt_token_ids2outputs_add_processed_image_token_len_mismatch(self): + test_prompt_token_ids = [101, 1002, 1001, 1001, 1003, 102] + spatial_conv_size = self.data_processor.spatial_conv_size + num_tokens = 4 + mock_img_data = np.random.randn(num_tokens * (spatial_conv_size**2), 28, 28) + mock_img_cache = (mock_img_data, {"thw": (1, 8, 8)}) + request = { + "prompt_token_ids": test_prompt_token_ids, + "messages": [ + {"role": "user", "content": [{"type": "image_url", "image_url": mock_img_cache, "uuid": "img_uuid"}]} + ], + } + self.data_processor.extract_mm_items.return_value = ( + [mock_img_cache], + [], + ["img_uuid"], + [], + None, + [], + [{"type": "image", "data": mock_img_cache}], + ) + with self.assertRaises(ValueError) as ctx: + self.data_processor.prompt_token_ids2outputs(request) + self.assertIn("image tokens num not match the size", str(ctx.exception)) + + def test_prompt_token_ids2outputs_add_video_token_len_mismatch(self): + test_prompt_token_ids = [101, 1004, 1001, 1001, 1005, 102] + mock_frame1 = MagicMock() + mock_frame1.height = 224 + mock_frame1.width = 224 + mock_frame1.convert.return_value = mock_frame1 + mock_frame2 = MagicMock() + mock_frame2.height = 224 + mock_frame2.width = 224 + mock_frame2.convert.return_value = mock_frame2 + frames = [mock_frame1, mock_frame2] + request = { + "prompt_token_ids": test_prompt_token_ids, + "messages": [ + {"role": "user", "content": [{"type": "video_url", "video_url": frames, "uuid": "vid_uuid"}]} + ], + } + self.data_processor.extract_mm_items.return_value = ( + [], + [frames], + [], + ["vid_uuid"], + None, + [], + [{"type": "video", "data": frames}], + ) + 
+    def test_prompt_token_ids2outputs_add_video_token_len_mismatch(self):
+        test_prompt_token_ids = [101, 1004, 1001, 1001, 1005, 102]
+        mock_frame1 = MagicMock()
+        mock_frame1.height = 224
+        mock_frame1.width = 224
+        mock_frame1.convert.return_value = mock_frame1
+        mock_frame2 = MagicMock()
+        mock_frame2.height = 224
+        mock_frame2.width = 224
+        mock_frame2.convert.return_value = mock_frame2
+        frames = [mock_frame1, mock_frame2]
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "video_url", "video_url": frames, "uuid": "vid_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [],
+            [frames],
+            [],
+            ["vid_uuid"],
+            None,
+            [],
+            [{"type": "video", "data": frames}],
+        )
+        self.data_processor._load_and_process_video = MagicMock(return_value=frames)
+        patches_h, patches_w = 8, 8
+        self.data_processor.image_preprocessor.get_smarted_resize.return_value = (None, (patches_h, patches_w))
+        mock_preprocess = {
+            "pixel_values_videos": np.random.randn(2, patches_h, patches_w, 3),
+            "video_grid_thw": np.array([[patches_h, patches_w]] * 2),
+        }
+        self.data_processor.image_preprocessor.preprocess.return_value = mock_preprocess
+        with self.assertRaises(ValueError) as ctx:
+            self.data_processor.prompt_token_ids2outputs(request)
+        self.assertIn("video placeholder count does not match the video size", str(ctx.exception))
+
+    def test_prompt_token_ids2outputs_add_processed_video_token_len_mismatch(self):
+        test_prompt_token_ids = [101, 1004, 1001, 1005, 102]
+        t, h, w = 2, 8, 8
+        spatial_conv_size = self.data_processor.spatial_conv_size
+        temporal_conv_size = self.data_processor.temporal_conv_size
+
+        num_tokens = 4
+        mock_frames_data = np.random.randn(num_tokens * spatial_conv_size**2 * temporal_conv_size, 28, 28)
+        mock_frames_cache = (mock_frames_data, {"thw": (t, h, w)})
+        request = {
+            "prompt_token_ids": test_prompt_token_ids,
+            "messages": [
+                {"role": "user", "content": [{"type": "video", "data": mock_frames_cache, "uuid": "vid_uuid"}]}
+            ],
+        }
+        self.data_processor.extract_mm_items.return_value = (
+            [],
+            [mock_frames_cache],
+            [],
+            ["vid_uuid"],
+            None,
+            [],
+            [{"type": "video", "data": mock_frames_cache}],
+        )
+        with self.assertRaises(ValueError) as ctx:
+            self.data_processor.prompt_token_ids2outputs(request)
+        self.assertIn("video placeholder count does not match the video size", str(ctx.exception))
+
 
 if __name__ == "__main__":
     unittest.main()
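Assuming the repository's usual test layout, the new cases run under the standard unittest runner; a programmatic equivalent of `python -m unittest tests.input.test_ernie_vl_processor`:

```python
import unittest

# Load and run just this module; exit=False keeps the interpreter alive so the
# runner can be invoked from a REPL or another script.
unittest.main(module="tests.input.test_ernie_vl_processor", exit=False, verbosity=2)
```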