diff --git a/docs/en/multi_modal/index.rst b/docs/en/multi_modal/index.rst index a041172edb..ac0e649244 100644 --- a/docs/en/multi_modal/index.rst +++ b/docs/en/multi_modal/index.rst @@ -1,6 +1,12 @@ Vision-Language Models ================================= +.. toctree:: + :maxdepth: 2 + :caption: Guides + + multimodal_inputs.md + .. toctree:: :maxdepth: 2 :caption: Examples diff --git a/docs/en/multi_modal/multimodal_inputs.md b/docs/en/multi_modal/multimodal_inputs.md new file mode 100644 index 0000000000..958888ebbc --- /dev/null +++ b/docs/en/multi_modal/multimodal_inputs.md @@ -0,0 +1,613 @@ +# Multi-Modal Inputs + +LMDeploy uses the OpenAI message format for all modalities. Each content item in a message is a dict with a `type` field that determines how it is decoded. + +**Quick reference:** + +| Modality | `type` key | URL field | +| ----------- | ----------------- | --------------------- | +| Text | `text` | — | +| Image | `image_url` | `image_url.url` | +| Video | `video_url` | `video_url.url` | +| Audio | `audio_url` | `audio_url.url` | +| Time Series | `time_series_url` | `time_series_url.url` | + +All examples below target the lmdeploy OpenAI-compatible API server. Start it with: + +```bash +lmdeploy serve api_server --server-port 23333 +``` + +______________________________________________________________________ + +## Text + +
+Complete example + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [{ + 'type': 'text', + 'text': 'Who are you?', + }], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## Single Image + +
+Complete example + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'image_url', + 'image_url': { + 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', + }, + }, + { + 'type': 'text', + 'text': 'Describe this image.', + }, + ], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## Multiple Images + +
+Complete example + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'image_url', + 'image_url': { + 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', + }, + }, + { + 'type': 'image_url', + 'image_url': { + 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', + }, + }, + { + 'type': 'text', + 'text': 'Compare these two images. What are the similarities and differences?', + }, + ], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## Single Video + +> **Note:** Native video input is currently supported for **Qwen3-VL**, **Qwen3.5**, and **InternS1-Pro** models only. + +
+Complete example + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +video_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'video_url', + 'video_url': { + 'url': video_url, + }, + }, + { + 'type': 'text', + 'text': "What's in this video?", + }, + ], + }], + temperature=0.8, + top_p=0.8, + max_completion_tokens=256, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## Multiple Videos + +> **Note:** Native video input is currently supported for **Qwen3-VL**, **Qwen3.5**, and **InternS1-Pro** models only. + +
+Complete example + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +video_url_1 = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4' +video_url_2 = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'video_url', + 'video_url': {'url': video_url_1}, + }, + { + 'type': 'video_url', + 'video_url': {'url': video_url_2}, + }, + { + 'type': 'text', + 'text': 'Compare these two videos. What are the similarities and differences?', + }, + ], + }], + temperature=0.8, + top_p=0.8, + max_completion_tokens=256, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## Single Audio + +
+Complete example + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'audio_url', + 'audio_url': { + 'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav', + }, + }, + { + 'type': 'text', + 'text': 'Describe this audio.', + }, + ], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## Multiple Audios + +
+Complete example + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +audio_url_1 = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav' +audio_url_2 = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + {'type': 'audio_url', 'audio_url': {'url': audio_url_1}}, + {'type': 'audio_url', 'audio_url': {'url': audio_url_2}}, + { + 'type': 'text', + 'text': 'Compare these two audios. What are the similarities and differences?', + }, + ], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## Mixed Image and Video + +> **Note:** Native video input is currently supported for **Qwen3-VL**, **Qwen3.5**, and **InternS1-Pro** models only. + +
+Complete example + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +image_url = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg' +video_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'image_url', + 'image_url': {'url': image_url}, + }, + { + 'type': 'video_url', + 'video_url': {'url': video_url}, + }, + { + 'type': 'text', + 'text': 'Describe both the image and the video.', + }, + ], + }], + temperature=0.8, + top_p=0.8, + max_completion_tokens=256, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## Time Series + +> **Note:** Time series input is currently supported for the **InternS1-Pro** model only. + +The `time_series_url` content item requires a `sampling_rate` field (in Hz) alongside the URL. + +
+Complete example + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': ('Please determine whether an Earthquake event has occurred. ' + 'If so, specify P-wave and S-wave starting indices.'), + }, + { + 'type': 'time_series_url', + 'time_series_url': { + 'url': 'https://raw.githubusercontent.com/CUHKSZzxy/Online-Data/main/0092638_seism.npy', + 'sampling_rate': 100, + }, + }, + ], + }], + temperature=0.8, + top_p=0.8, + max_completion_tokens=256, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## Local Files and Base64 + +In addition to HTTP URLs, lmdeploy accepts: + +- **Local file paths** via `file://` scheme: `file:///absolute/path/to/file.jpg` +- **Base64-encoded data** via data URLs: `data:;base64,` + +Use the helpers in `lmdeploy.vl.utils` to encode local files: + +
+Local file path example + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'image_url', + 'image_url': { + 'url': 'file:///path/to/your/image.jpg', + }, + }, + {'type': 'text', 'text': 'Describe this image.'}, + ], + }], +) +print(response.choices[0].message.content) +``` + +
+ +
+Base64 encoding example (image) + +```python +from openai import OpenAI +from lmdeploy.vl.utils import encode_image_base64 + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +b64 = encode_image_base64('/path/to/your/image.jpg') +image_url = f'data:image/jpeg;base64,{b64}' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'image_url', + 'image_url': {'url': image_url}, + }, + {'type': 'text', 'text': 'Describe this image.'}, + ], + }], +) +print(response.choices[0].message.content) +``` + +
+ +
+Base64 encoding example (video) + +```python +from openai import OpenAI +from lmdeploy.vl.utils import encode_video_base64 + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +# num_frames controls how many frames to sample before encoding +b64 = encode_video_base64('/path/to/your/video.mp4', num_frames=16) +video_url = f'data:video/mp4;base64,{b64}' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'video_url', + 'video_url': {'url': video_url}, + }, + {'type': 'text', 'text': 'Describe this video.'}, + ], + }], +) +print(response.choices[0].message.content) +``` + +
+ +
+Base64 encoding example (time series) + +```python +from openai import OpenAI +from lmdeploy.vl.utils import encode_time_series_base64 + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +b64 = encode_time_series_base64('/path/to/your/data.npy') +ts_url = f'data:application/octet-stream;base64,{b64}' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + {'type': 'text', 'text': 'Analyze this time series.'}, + { + 'type': 'time_series_url', + 'time_series_url': { + 'url': ts_url, + 'sampling_rate': 100, + }, + }, + ], + }], +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## Processor and IO kwargs + +Two optional parameters let you control media processing: + +- **`mm_processor_kwargs`**: controls vision token resolution (min/max pixels per image or video frame) +- **`media_io_kwargs`**: controls how media is loaded (e.g. video frame sampling rate and count) + +Both are passed as extra fields in the API request body via `extra_body`, or directly to `pipe()` when using the pipeline API. + +
+API server example (extra_body) + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +video_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + {'type': 'video_url', 'video_url': {'url': video_url}}, + {'type': 'text', 'text': 'Describe this video.'}, + ], + }], + max_completion_tokens=256, + extra_body={ + 'mm_processor_kwargs': { + 'video': { + 'min_pixels': 4 * 32 * 32, + 'max_pixels': 256 * 32 * 32, + }, + }, + 'media_io_kwargs': { + 'video': { + 'num_frames': 16, + 'fps': 2, + }, + }, + }, +) +print(response.choices[0].message.content) +``` + +
+ +
+Pipeline API equivalent + +```python +from lmdeploy import pipeline, PytorchEngineConfig + +pipe = pipeline('', backend_config=PytorchEngineConfig(tp=1)) + +video_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4' + +messages = [{ + 'role': 'user', + 'content': [ + {'type': 'video_url', 'video_url': {'url': video_url}}, + {'type': 'text', 'text': 'Describe this video.'}, + ], +}] + +response = pipe( + messages, + mm_processor_kwargs={ + 'video': { + 'min_pixels': 4 * 32 * 32, + 'max_pixels': 256 * 32 * 32, + }, + }, + media_io_kwargs={ + 'video': { + 'num_frames': 16, + 'fps': 2, + }, + }, +) +print(response) +``` + +
diff --git a/docs/en/multi_modal/vl_pipeline.md b/docs/en/multi_modal/vl_pipeline.md index 1e73fb2e06..4972ba91d5 100644 --- a/docs/en/multi_modal/vl_pipeline.md +++ b/docs/en/multi_modal/vl_pipeline.md @@ -10,6 +10,8 @@ Moreover, we will provide practical inference examples tailored to scenarios wit Using the pipeline interface to infer other VLM models is similar, with the main difference being the configuration and installation dependencies of the models. You can read [here](https://lmdeploy.readthedocs.io/en/latest/multi_modal/index.html) for environment installation and configuration methods for different models. +> **See also:** [Multi-Modal Inputs](multimodal_inputs.md) — message format reference for all modalities (image, video, audio, time series) with OpenAI-style examples. + ## A 'Hello, world' example ```python diff --git a/docs/zh_cn/multi_modal/index.rst b/docs/zh_cn/multi_modal/index.rst index 9a61f6efdb..ed33bba8d3 100644 --- a/docs/zh_cn/multi_modal/index.rst +++ b/docs/zh_cn/multi_modal/index.rst @@ -1,6 +1,12 @@ 视觉语言模型 ================================= +.. toctree:: + :maxdepth: 2 + :caption: 指南 + + multimodal_inputs.md + .. toctree:: :maxdepth: 2 :caption: 示例 diff --git a/docs/zh_cn/multi_modal/multimodal_inputs.md b/docs/zh_cn/multi_modal/multimodal_inputs.md new file mode 100644 index 0000000000..35a2769f14 --- /dev/null +++ b/docs/zh_cn/multi_modal/multimodal_inputs.md @@ -0,0 +1,612 @@ +# 多模态输入 + +LMDeploy 使用 OpenAI 消息格式处理所有模态。消息中的每个内容项都是一个包含 `type` 字段的字典,该字段决定了数据的解码方式。 + +**快速参考:** + +| 模态 | `type` 字段 | URL 字段 | +| -------- | ----------------- | --------------------- | +| 文本 | `text` | — | +| 图像 | `image_url` | `image_url.url` | +| 视频 | `video_url` | `video_url.url` | +| 音频 | `audio_url` | `audio_url.url` | +| 时序数据 | `time_series_url` | `time_series_url.url` | + +以下示例均面向 lmdeploy 兼容 OpenAI 的 API 服务。启动服务: + +```bash +lmdeploy serve api_server --server-port 23333 +``` + +______________________________________________________________________ + +## 纯文本 + +
+完整示例 + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [{ + 'type': 'text', + 'text': '你是谁?', + }], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## 单张图像 + +
+完整示例 + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'image_url', + 'image_url': { + 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', + }, + }, + { + 'type': 'text', + 'text': '描述这张图片。', + }, + ], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## 多张图像 + +
+完整示例 + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'image_url', + 'image_url': { + 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', + }, + }, + { + 'type': 'image_url', + 'image_url': { + 'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg', + }, + }, + { + 'type': 'text', + 'text': '比较这两张图片,有哪些相似点和不同点?', + }, + ], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## 单个视频 + +> **注意:** 原生视频输入目前仅支持 **Qwen3-VL**、**Qwen3.5** 和 **InternS1-Pro** 模型。 + +
+完整示例 + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +video_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'video_url', + 'video_url': { + 'url': video_url, + }, + }, + { + 'type': 'text', + 'text': '这个视频里有什么?', + }, + ], + }], + temperature=0.8, + top_p=0.8, + max_completion_tokens=256, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## 多个视频 + +> **注意:** 原生视频输入目前仅支持 **Qwen3-VL**、**Qwen3.5** 和 **InternS1-Pro** 模型。 + +
+完整示例 + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +video_url_1 = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4' +video_url_2 = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'video_url', + 'video_url': {'url': video_url_1}, + }, + { + 'type': 'video_url', + 'video_url': {'url': video_url_2}, + }, + { + 'type': 'text', + 'text': '比较这两个视频,有哪些相似点和不同点?', + }, + ], + }], + temperature=0.8, + top_p=0.8, + max_completion_tokens=256, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## 单个音频 + +
+完整示例 + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'audio_url', + 'audio_url': { + 'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav', + }, + }, + { + 'type': 'text', + 'text': '描述这段音频。', + }, + ], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## 多个音频 + +
+完整示例 + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +audio_url_1 = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav' +audio_url_2 = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + {'type': 'audio_url', 'audio_url': {'url': audio_url_1}}, + {'type': 'audio_url', 'audio_url': {'url': audio_url_2}}, + { + 'type': 'text', + 'text': '比较这两段音频,有哪些相似点和不同点?', + }, + ], + }], + temperature=0.8, + top_p=0.8, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## 图像与视频混合 + +> **注意:** 原生视频输入目前仅支持 **Qwen3-VL**、**Qwen3.5** 和 **InternS1-Pro** 模型。 + +
+完整示例 + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +image_url = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg' +video_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'image_url', + 'image_url': {'url': image_url}, + }, + { + 'type': 'video_url', + 'video_url': {'url': video_url}, + }, + { + 'type': 'text', + 'text': '描述这张图片和这个视频。', + }, + ], + }], + temperature=0.8, + top_p=0.8, + max_completion_tokens=256, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## 时序数据 + +> **注意:** 时序数据输入目前仅支持 **InternS1-Pro** 模型。 + +`time_series_url` 内容项需要在 URL 之外额外提供 `sampling_rate` 字段(单位:Hz)。 + +
+完整示例 + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': '请判断是否发生了地震事件,若有请指出P波和S波的起始索引。', + }, + { + 'type': 'time_series_url', + 'time_series_url': { + 'url': 'https://raw.githubusercontent.com/CUHKSZzxy/Online-Data/main/0092638_seism.npy', + 'sampling_rate': 100, + }, + }, + ], + }], + temperature=0.8, + top_p=0.8, + max_completion_tokens=256, +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## 本地文件与 Base64 + +除 HTTP URL 外,lmdeploy 还支持: + +- **本地文件路径**,使用 `file://` 协议:`file:///absolute/path/to/file.jpg` +- **Base64 编码数据**,使用 data URL:`data:;base64,` + +可使用 `lmdeploy.vl.utils` 中的工具函数对本地文件进行编码: + +
+本地文件路径示例 + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'image_url', + 'image_url': { + 'url': 'file:///path/to/your/image.jpg', + }, + }, + {'type': 'text', 'text': '描述这张图片。'}, + ], + }], +) +print(response.choices[0].message.content) +``` + +
+ +
+Base64 编码示例(图像) + +```python +from openai import OpenAI +from lmdeploy.vl.utils import encode_image_base64 + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +b64 = encode_image_base64('/path/to/your/image.jpg') +image_url = f'data:image/jpeg;base64,{b64}' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'image_url', + 'image_url': {'url': image_url}, + }, + {'type': 'text', 'text': '描述这张图片。'}, + ], + }], +) +print(response.choices[0].message.content) +``` + +
+ +
+Base64 编码示例(视频) + +```python +from openai import OpenAI +from lmdeploy.vl.utils import encode_video_base64 + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +# num_frames 控制编码前采样的帧数 +b64 = encode_video_base64('/path/to/your/video.mp4', num_frames=16) +video_url = f'data:video/mp4;base64,{b64}' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + { + 'type': 'video_url', + 'video_url': {'url': video_url}, + }, + {'type': 'text', 'text': '描述这个视频。'}, + ], + }], +) +print(response.choices[0].message.content) +``` + +
+ +
+Base64 编码示例(时序数据) + +```python +from openai import OpenAI +from lmdeploy.vl.utils import encode_time_series_base64 + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +b64 = encode_time_series_base64('/path/to/your/data.npy') +ts_url = f'data:application/octet-stream;base64,{b64}' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + {'type': 'text', 'text': '分析这段时序数据。'}, + { + 'type': 'time_series_url', + 'time_series_url': { + 'url': ts_url, + 'sampling_rate': 100, + }, + }, + ], + }], +) +print(response.choices[0].message.content) +``` + +
+ +______________________________________________________________________ + +## 处理器与 IO 参数 + +两个可选参数用于控制媒体处理行为: + +- **`mm_processor_kwargs`**:控制视觉 token 的分辨率(每张图片或视频帧的最小/最大像素数) +- **`media_io_kwargs`**:控制媒体加载方式(如视频帧采样率和帧数) + +两者均通过 `extra_body` 作为请求体中的额外字段传入 API,或在使用 pipeline API 时直接传给 `pipe()`。 + +
+API 服务示例(extra_body) + +```python +from openai import OpenAI + +client = OpenAI(api_key='EMPTY', base_url='http://localhost:23333/v1') +model_name = client.models.list().data[0].id + +video_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4' + +response = client.chat.completions.create( + model=model_name, + messages=[{ + 'role': 'user', + 'content': [ + {'type': 'video_url', 'video_url': {'url': video_url}}, + {'type': 'text', 'text': '描述这个视频。'}, + ], + }], + max_completion_tokens=256, + extra_body={ + 'mm_processor_kwargs': { + 'video': { + 'min_pixels': 4 * 32 * 32, + 'max_pixels': 256 * 32 * 32, + }, + }, + 'media_io_kwargs': { + 'video': { + 'num_frames': 16, + 'fps': 2, + }, + }, + }, +) +print(response.choices[0].message.content) +``` + +
+ +
+Pipeline API 等价写法 + +```python +from lmdeploy import pipeline, PytorchEngineConfig + +pipe = pipeline('', backend_config=PytorchEngineConfig(tp=1)) + +video_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4' + +messages = [{ + 'role': 'user', + 'content': [ + {'type': 'video_url', 'video_url': {'url': video_url}}, + {'type': 'text', 'text': '描述这个视频。'}, + ], +}] + +response = pipe( + messages, + mm_processor_kwargs={ + 'video': { + 'min_pixels': 4 * 32 * 32, + 'max_pixels': 256 * 32 * 32, + }, + }, + media_io_kwargs={ + 'video': { + 'num_frames': 16, + 'fps': 2, + }, + }, +) +print(response) +``` + +
diff --git a/docs/zh_cn/multi_modal/vl_pipeline.md b/docs/zh_cn/multi_modal/vl_pipeline.md index 67ff082964..9662bcc569 100644 --- a/docs/zh_cn/multi_modal/vl_pipeline.md +++ b/docs/zh_cn/multi_modal/vl_pipeline.md @@ -10,6 +10,8 @@ LMDeploy 把视觉-语言模型(VLM)复杂的推理过程,抽象为简单 使用 pipeline 接口推理其他 VLM 模型,大同小异,主要区别在于模型依赖的配置和安装。你可以阅读[此处](https://lmdeploy.readthedocs.io/zh-cn/latest/multi_modal/),查看不同模型的环境安装和配置方式 +> **另请参阅:** [多模态输入](multimodal_inputs.md) — 涵盖所有模态(图像、视频、音频、时序数据)的消息格式参考,包含 OpenAI 风格示例。 + ## "Hello, world" 示例 ```python diff --git a/lmdeploy/pytorch/configurations/glm4_1v.py b/lmdeploy/pytorch/configurations/glm4_1v.py new file mode 100644 index 0000000000..fbad2616ed --- /dev/null +++ b/lmdeploy/pytorch/configurations/glm4_1v.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import AutoModelConfigBuilder +from .default import DefaultModelConfigBuilder + + +class Glm4vModelConfigBuilder(AutoModelConfigBuilder): + + @classmethod + def condition(cls, hf_config): + return hf_config.model_type == 'glm4v' + + @classmethod + def build(cls, hf_config, model_path: str = None, **kwargs): + """build.""" + bos_token_id = getattr(hf_config, 'bos_token_id', None) + hf_config.text_config.bos_token_id = bos_token_id + cfg = DefaultModelConfigBuilder.build(hf_config.text_config, model_path, **kwargs) + cfg.hf_config = hf_config + return cfg diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 1ef5caba83..966bd430cd 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -582,7 +582,7 @@ def get_datas(self, start=0, end=-1): for modal_type, modal_datas in self.multimodals.items(): data = [] for modal_data in modal_datas: - if (modal_data.start not in test_range and modal_data.end - 1 not in test_range): + if (modal_data.start not in test_range or modal_data.end - 1 not in test_range): continue data.append(modal_data) if len(data) > 0: diff --git a/lmdeploy/pytorch/models/glm4_1v.py b/lmdeploy/pytorch/models/glm4_1v.py index bc7be1a07b..7ab240208b 100644 --- a/lmdeploy/pytorch/models/glm4_1v.py +++ b/lmdeploy/pytorch/models/glm4_1v.py @@ -2,22 +2,23 @@ # adapted from: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Sequence from typing import Any +import numpy as np import torch import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig -from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor +from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn import ApplyRotaryEmb, FlashAttention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config from lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .glm4 import Glm4DecoderLayer -from .qwen2_vl import Qwen2VLInputProcessor as Glm4vInputProcessor from .utils.cudagraph import CudaGraphMixin from .utils.model import DeployModelMixin, vlm_model @@ -629,7 +630,7 @@ def prepare_inputs_for_generation( pixel_values = torch.cat([data.data for data in image_data]) image_token_id = image_data[0].meta['image_token_id'] image_mask = input_ids == image_token_id - grid_thw = torch.cat([data.meta['grid_thw'] for data in image_data]).cpu() + grid_thw = torch.stack([data.meta['grid_thw'] for data in image_data]).cpu() vis_pos_emb, image_type_ids = self.visual.rot_pos_emb(grid_thw) vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).to(pixel_values.device) @@ -722,3 +723,57 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor + +class Glm4vInputProcessor(BaseModelInputProcessor): + """Glm4v input processor.""" + + def __init__(self, config: PretrainedConfig) -> None: + self.config = config + + @classmethod + def _get_multimodal_pos_ids(cls, grid_thw: Sequence[int]) -> np.ndarray: + """Get mrope ids.""" + t, h, w = grid_thw + h = h // 2 + w = w // 2 + stride = np.array([h * w, w, 1])[None] + size = np.array([t, h, w])[None] + pos_ids = np.arange(t * h * w)[:, None].repeat(3, axis=1) + pos_ids = pos_ids // stride % size + return pos_ids + + @classmethod + def make_mrope(cls, grid_thw: torch.Tensor): + grid_thw = grid_thw.tolist() if grid_thw.dim() == 1 else grid_thw[0].tolist() + img_pos_ids = cls._get_multimodal_pos_ids(grid_thw) + return img_pos_ids + + def preprocess_input(self, + input_ids: list[int], + input_multimodals: list[dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """Prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'] + image_grid_thw = input_mm['image_grid_thw'] + offset = input_mm['offset'] + image_token_id = input_mm['image_token_id'] + + mrope_pos_ids = self.make_mrope(image_grid_thw) + + mm_data = MultiModalData(data=pixel_values, + start=offset[0], + end=offset[1], + mrope_pos_ids=mrope_pos_ids, + meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/interns1_pro.py b/lmdeploy/pytorch/models/interns1_pro.py index a78b9803f5..734a10f480 100644 --- a/lmdeploy/pytorch/models/interns1_pro.py +++ b/lmdeploy/pytorch/models/interns1_pro.py @@ -12,7 +12,7 @@ from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from lmdeploy.vl.constants import Modality -from .interns1_pro_ts import InternS1ProTimeSeriesModel +from .interns1_pro_time_series import InternS1ProTimeSeriesModel from .patch import add_prefix, get_build_model_context from .qwen3_moe import Qwen3MoeModel from .qwen3_vl import Qwen3VLVisionModel @@ -88,14 +88,13 @@ def forward( pixel_values: torch.Tensor = None, vis_cu_seqlens: torch.Tensor = None, vis_pos_emb: torch.Tensor = None, - image_mask: torch.Tensor = None, + multimodal_mask: torch.Tensor = None, pos_embeds: torch.Tensor = None, grid_thw: torch.Tensor = None, # for time series ts_values: torch.Tensor = None, ts_lens: torch.Tensor = None, ts_sr: torch.Tensor = None, - ts_mask: torch.Tensor = None, **kwargs, ): """Model forward, return logits.""" @@ -121,12 +120,11 @@ def forward( image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype) # mask and scatter to create final input embeddings - expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds) - + multimodal_mask = multimodal_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(multimodal_mask, image_embeds) elif ts_values is not None: ts_embeds = self.time_series(ts_values, ts_lens, ts_sr) # [B, T, C] - inputs_embeds = inputs_embeds.masked_scatter_(ts_mask[..., None], ts_embeds) + inputs_embeds = inputs_embeds.masked_scatter(multimodal_mask[..., None], ts_embeds) # router replay all_routed_experts = None @@ -166,14 +164,13 @@ def prepare_inputs_for_generation( pixel_values = None vis_cu_seqlens = None vis_pos_emb = None - image_mask = None + multimodal_mask = None grid_thw = None pos_embeds = None # for time series ts_values = None ts_lens = None ts_sr = None - ts_mask = None if context.input_multimodals is not None: mm_inputs = [input_mm.get('mm_data', []) for input_mm in context.input_multimodals] # flatten batch @@ -181,22 +178,15 @@ def prepare_inputs_for_generation( if len(mm_inputs) > 0: modality = mm_inputs[0].modality - image_token_id = mm_inputs[0].meta.get('image_token_id') - video_token_id = mm_inputs[0].meta.get('video_token_id') - ts_token_id = mm_inputs[0].meta.get('ts_token_id') + multimodal_mask = self.get_multimodal_mask(input_ids, mm_inputs) if modality == Modality.TIME_SERIES: ts_values = torch.cat([inp.data for inp in mm_inputs]) - ts_mask = input_ids == ts_token_id - ts_lens = mm_inputs[0].meta['ts_lens'] ts_sr = mm_inputs[0].meta['ts_sr'] else: pixel_values = torch.cat([inp.data for inp in mm_inputs]) - mm_token_id = image_token_id if modality == Modality.IMAGE else video_token_id - image_mask = (input_ids == mm_token_id) - - grid_thw = torch.cat([data.meta['grid_thw'] for data in mm_inputs]).cpu() + grid_thw = torch.stack([data.meta['grid_thw'] for data in mm_inputs]).cpu() vis_pos_emb = self.visual.rot_pos_emb(grid_thw) pos_embeds = self.visual.fast_pos_embed_interpolate(grid_thw) vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], @@ -223,14 +213,13 @@ def prepare_inputs_for_generation( pixel_values=pixel_values, vis_cu_seqlens=vis_cu_seqlens, vis_pos_emb=vis_pos_emb, - image_mask=image_mask, + multimodal_mask=multimodal_mask, grid_thw=grid_thw, pos_embeds=pos_embeds, # for time series ts_values=ts_values, ts_lens=ts_lens, ts_sr=ts_sr, - ts_mask=ts_mask, ) @classmethod @@ -375,16 +364,12 @@ def _make_image_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: pixel_values = input_mm['pixel_values'].to(self.dtype) image_grid_thw = input_mm['image_grid_thw'] offset = input_mm['offset'] - start = offset image_token_id = input_mm['image_token_id'] - num_pad = input_mm['image_tokens'] - if isinstance(num_pad, torch.Tensor): - num_pad = num_pad.item() mm_data = MultiModalData(modality=Modality.IMAGE, data=pixel_values, - start=start, - end=start + num_pad, + start=offset[0], + end=offset[1], meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) return mm_data @@ -393,16 +378,12 @@ def _make_video_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: pixel_values_videos = input_mm['pixel_values_videos'].to(self.dtype) video_grid_thw = input_mm['video_grid_thw'] offset = input_mm['offset'] - start = offset video_token_id = input_mm['video_token_id'] - num_pad = input_mm['video_tokens'] - if isinstance(num_pad, torch.Tensor): - num_pad = num_pad.item() mm_data = MultiModalData(modality=Modality.VIDEO, data=pixel_values_videos, - start=start, - end=start + num_pad, + start=offset[0], + end=offset[1], meta=dict( grid_thw=video_grid_thw, video_token_id=video_token_id, @@ -416,14 +397,11 @@ def _make_time_series_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: ts_token_id = input_mm['ts_token_id'] ts_lens = input_mm['ts_lens'] ts_sr = input_mm['ts_sr'] - num_pad = input_mm['ts_tokens'] - if isinstance(num_pad, torch.Tensor): - num_pad = num_pad.item() mm_data = MultiModalData(modality=Modality.TIME_SERIES, data=ts_values, - start=offset, - end=offset + num_pad, + start=offset[0], + end=offset[1], meta=dict(ts_lens=ts_lens, ts_sr=ts_sr, ts_token_id=ts_token_id)) return mm_data diff --git a/lmdeploy/pytorch/models/interns1_pro_ts.py b/lmdeploy/pytorch/models/interns1_pro_time_series.py similarity index 100% rename from lmdeploy/pytorch/models/interns1_pro_ts.py rename to lmdeploy/pytorch/models/interns1_pro_time_series.py diff --git a/lmdeploy/pytorch/models/qwen3_5.py b/lmdeploy/pytorch/models/qwen3_5.py index aae94cb6d7..92135321bb 100644 --- a/lmdeploy/pytorch/models/qwen3_5.py +++ b/lmdeploy/pytorch/models/qwen3_5.py @@ -26,7 +26,6 @@ ) from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight -from lmdeploy.vl.constants import Modality from .patch import add_prefix, get_build_model_context from .qwen2_5_vl import Qwen2_5_VisionRotaryEmbedding as Qwen3_5VisionRotaryEmbedding @@ -1022,7 +1021,7 @@ def forward( pixel_values: torch.Tensor | None = None, vis_cu_seqlens: torch.Tensor | None = None, vis_pos_emb: torch.Tensor | None = None, - image_mask: torch.Tensor | None = None, + multimodal_mask: torch.Tensor | None = None, pos_embeds: torch.Tensor | None = None, grid_thw: torch.Tensor | None = None, all_routed_experts: torch.Tensor | None = None, @@ -1051,8 +1050,8 @@ def forward( image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype) # mask and scatter to create final input embeddings - expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds) + multimodal_mask = multimodal_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(multimodal_mask, image_embeds) output_inputs_embeds = inputs_embeds if return_input_embeds else None @@ -1126,7 +1125,7 @@ def forward( pixel_values: torch.Tensor | None = None, vis_cu_seqlens: torch.Tensor | None = None, vis_pos_emb: torch.Tensor | None = None, - image_mask: torch.Tensor | None = None, + multimodal_mask: torch.Tensor | None = None, pos_embeds: torch.Tensor | None = None, grid_thw: torch.Tensor | None = None, return_input_embeds: bool = False, @@ -1151,7 +1150,7 @@ def forward( pixel_values=pixel_values, vis_cu_seqlens=vis_cu_seqlens, vis_pos_emb=vis_pos_emb, - image_mask=image_mask, + multimodal_mask=multimodal_mask, pos_embeds=pos_embeds, grid_thw=grid_thw, all_routed_experts=all_routed_experts, @@ -1192,7 +1191,7 @@ def prepare_inputs_for_generation( pixel_values = None vis_cu_seqlens = None vis_pos_emb = None - image_mask = None + multimodal_mask = None grid_thw = None pos_embeds = None if context.input_multimodals is not None: @@ -1201,15 +1200,10 @@ def prepare_inputs_for_generation( mm_inputs = [item for sublist in mm_inputs for item in sublist] if len(mm_inputs) > 0: - modality = mm_inputs[0].modality pixel_values = torch.cat([inp.data for inp in mm_inputs]) - image_token_id = mm_inputs[0].meta.get('image_token_id') - video_token_id = mm_inputs[0].meta.get('video_token_id') - mm_token_id = image_token_id if modality == Modality.IMAGE else video_token_id - image_mask = (input_ids == mm_token_id) - - grid_thw = torch.cat([data.meta['grid_thw'] for data in mm_inputs]).cpu() + multimodal_mask = self.get_multimodal_mask(input_ids, mm_inputs) + grid_thw = torch.stack([data.meta['grid_thw'] for data in mm_inputs]).cpu() vis_pos_emb = self.model.visual.rot_pos_emb(grid_thw) pos_embeds = self.model.visual.fast_pos_embed_interpolate(grid_thw) vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], @@ -1244,7 +1238,7 @@ def prepare_inputs_for_generation( pixel_values=pixel_values, vis_cu_seqlens=vis_cu_seqlens, vis_pos_emb=vis_pos_emb, - image_mask=image_mask, + multimodal_mask=multimodal_mask, grid_thw=grid_thw, pos_embeds=pos_embeds, return_input_embeds=return_input_embeds, diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index 8a12c878e5..ca8e359ded 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -551,7 +551,7 @@ def forward( pixel_values: torch.Tensor = None, vis_cu_seqlens: torch.Tensor = None, vis_pos_emb: torch.Tensor = None, - image_mask: torch.Tensor = None, + multimodal_mask: torch.Tensor = None, pos_embeds: torch.Tensor = None, grid_thw: torch.Tensor = None, **kwargs, @@ -580,10 +580,9 @@ def forward( image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype) # mask and scatter to create final input embeddings - expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds) - - visual_pos_masks = expanded_image_mask + multimodal_mask = multimodal_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(multimodal_mask, image_embeds) + visual_pos_masks = multimodal_mask hidden_states = self.language_model( input_ids=input_ids, @@ -618,7 +617,7 @@ def prepare_inputs_for_generation( pixel_values = None vis_cu_seqlens = None vis_pos_emb = None - image_mask = None + multimodal_mask = None grid_thw = None pos_embeds = None if context.input_multimodals is not None: @@ -627,15 +626,9 @@ def prepare_inputs_for_generation( mm_inputs = [item for sublist in mm_inputs for item in sublist] if len(mm_inputs) > 0: - modality = mm_inputs[0].modality pixel_values = torch.cat([inp.data for inp in mm_inputs]) - - image_token_id = mm_inputs[0].meta.get('image_token_id') - video_token_id = mm_inputs[0].meta.get('video_token_id') - mm_token_id = image_token_id if modality == Modality.IMAGE else video_token_id - image_mask = (input_ids == mm_token_id) - - grid_thw = torch.cat([data.meta['grid_thw'] for data in mm_inputs]).cpu() + multimodal_mask = self.get_multimodal_mask(input_ids, mm_inputs) + grid_thw = torch.stack([data.meta['grid_thw'] for data in mm_inputs]).cpu() vis_pos_emb = self.visual.rot_pos_emb(grid_thw) pos_embeds = self.visual.fast_pos_embed_interpolate(grid_thw) vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], @@ -665,7 +658,7 @@ def prepare_inputs_for_generation( pixel_values=pixel_values, vis_cu_seqlens=vis_cu_seqlens, vis_pos_emb=vis_pos_emb, - image_mask=image_mask, + multimodal_mask=multimodal_mask, grid_thw=grid_thw, pos_embeds=pos_embeds, ) @@ -744,7 +737,8 @@ def _get_multimodal_pos_ids(cls, grid_thw: Sequence[int]) -> np.ndarray: @classmethod def make_mrope(cls, grid_thw: torch.Tensor): - img_pos_ids = cls._get_multimodal_pos_ids(grid_thw[0].tolist()) + grid_thw = grid_thw.tolist() if grid_thw.dim() == 1 else grid_thw[0].tolist() + img_pos_ids = cls._get_multimodal_pos_ids(grid_thw) return img_pos_ids def _make_image_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: @@ -752,18 +746,14 @@ def _make_image_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: pixel_values = input_mm['pixel_values'] image_grid_thw = input_mm['image_grid_thw'] offset = input_mm['offset'] - start = offset image_token_id = input_mm['image_token_id'] - num_pad = input_mm['image_tokens'] - if isinstance(num_pad, torch.Tensor): - num_pad = num_pad.item() mrope_pos_ids = self.make_mrope(image_grid_thw) mm_data = MultiModalData(modality=Modality.IMAGE, data=pixel_values, - start=start, - end=start + num_pad, + start=offset[0], + end=offset[1], mrope_pos_ids=mrope_pos_ids, meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) return mm_data @@ -773,18 +763,14 @@ def _make_video_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData: pixel_values_videos = input_mm['pixel_values_videos'] video_grid_thw = input_mm['video_grid_thw'] offset = input_mm['offset'] - start = offset video_token_id = input_mm['video_token_id'] - num_pad = input_mm['video_tokens'] - if isinstance(num_pad, torch.Tensor): - num_pad = num_pad.item() mrope_pos_ids = self.make_mrope(video_grid_thw) mm_data = MultiModalData(modality=Modality.VIDEO, data=pixel_values_videos, - start=start, - end=start + num_pad, + start=offset[0], + end=offset[1], mrope_pos_ids=mrope_pos_ids, meta=dict( grid_thw=video_grid_thw, diff --git a/lmdeploy/pytorch/models/qwen3_vl_moe.py b/lmdeploy/pytorch/models/qwen3_vl_moe.py index 8de27b1b3b..9dd8263c4a 100644 --- a/lmdeploy/pytorch/models/qwen3_vl_moe.py +++ b/lmdeploy/pytorch/models/qwen3_vl_moe.py @@ -146,7 +146,7 @@ def forward( pixel_values: torch.Tensor = None, vis_cu_seqlens: torch.Tensor = None, vis_pos_emb: torch.Tensor = None, - image_mask: torch.Tensor = None, + multimodal_mask: torch.Tensor = None, pos_embeds: torch.Tensor = None, grid_thw: torch.Tensor = None, **kwargs, @@ -175,10 +175,9 @@ def forward( image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype) # mask and scatter to create final input embeddings - expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds) - - visual_pos_masks = expanded_image_mask + multimodal_mask = multimodal_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(multimodal_mask, image_embeds) + visual_pos_masks = multimodal_mask # router replay all_routed_experts = None diff --git a/lmdeploy/pytorch/models/utils/model.py b/lmdeploy/pytorch/models/utils/model.py index 3c99240f07..ffa9b546e5 100644 --- a/lmdeploy/pytorch/models/utils/model.py +++ b/lmdeploy/pytorch/models/utils/model.py @@ -8,8 +8,10 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta, StepContext from lmdeploy.pytorch.models.patch import get_build_model_context +from lmdeploy.pytorch.multimodal.data_type import MultiModalData from lmdeploy.pytorch.nn.embedding import ParallelEmbedding from lmdeploy.pytorch.nn.linear import build_rowwise_linear +from lmdeploy.vl.constants import Modality class BaseModelMetaProcessor: @@ -150,6 +152,31 @@ def build_lm_head(self, ) return lm_head + def get_multimodal_mask(self, input_ids: torch.Tensor, mm_inputs: list[MultiModalData]) -> torch.Tensor: + """Get position masks for vision tokens.""" + image_token_id = next((m.meta.get('image_token_id') for m in mm_inputs if m.modality == Modality.IMAGE), None) + video_token_id = next((m.meta.get('video_token_id') for m in mm_inputs if m.modality == Modality.VIDEO), None) + ts_token_id = next((m.meta.get('ts_token_id') for m in mm_inputs if m.modality == Modality.TIME_SERIES), None) + + image_mask, video_mask, ts_mask = None, None, None + if image_token_id is not None: + image_mask = (input_ids == image_token_id) + if video_token_id is not None: + video_mask = (input_ids == video_token_id) + if ts_token_id is not None: + ts_mask = (input_ids == ts_token_id) + + multimodal_mask = None + if image_mask is not None and video_mask is not None: + multimodal_mask = image_mask | video_mask + elif image_mask is not None: + multimodal_mask = image_mask + elif video_mask is not None: + multimodal_mask = video_mask + elif ts_mask is not None: + multimodal_mask = ts_mask + + return multimodal_mask def vlm_model(vlm_cls): if not issubclass(vlm_cls, torch.nn.Module): diff --git a/lmdeploy/serve/core/async_engine.py b/lmdeploy/serve/core/async_engine.py index a259fbdd90..f8bb9ccda2 100644 --- a/lmdeploy/serve/core/async_engine.py +++ b/lmdeploy/serve/core/async_engine.py @@ -396,8 +396,8 @@ async def generate( media_io_kwargs=media_io_kwargs, mm_processor_kwargs=mm_processor_kwargs, **kwargs) - prompt = prompt_input['prompt'] - input_ids = prompt_input['input_ids'] + prompt = prompt_input.get('prompt') + input_ids = prompt_input.get('input_ids') self.request_logger.log_inputs(session, prompt=prompt, prompt_token_ids=input_ids, diff --git a/lmdeploy/serve/processors/multimodal.py b/lmdeploy/serve/processors/multimodal.py index 8847c6f2c1..5be6b91241 100644 --- a/lmdeploy/serve/processors/multimodal.py +++ b/lmdeploy/serve/processors/multimodal.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import asyncio +import inspect from typing import Any, Literal import PIL @@ -37,6 +38,8 @@ def __init__(self, self.chat_template = chat_template self.vl_encoder = vl_encoder self.backend = backend + _sig = inspect.signature(vl_encoder.model.preprocess).parameters if vl_encoder else {} + self._uses_new_preprocess = 'input_prompt' in _sig and 'mm_processor_kwargs' in _sig @staticmethod def merge_message_content(msg: dict) -> dict: @@ -131,7 +134,7 @@ def _parse_multimodal_item(i: int, in_messages: list[dict], out_messages: list[d else: raise NotImplementedError(f'unknown type: {item_type}') - out_message['content'].append({'type': modality, 'data': data, **item_params}) + out_message['content'].append({'type': modality.value, 'data': data, **item_params}) out_messages[i] = out_message @@ -356,14 +359,9 @@ async def _get_multimodal_prompt_input(self, engines.""" chat_template = self.chat_template if do_preprocess else BaseChatTemplate() messages = await self.async_parse_multimodal_item(messages, media_io_kwargs) - results = await self.vl_encoder.preprocess(messages, mm_processor_kwargs) if self.backend == 'turbomind': - # for tm engine, this module perform vision embedding after image - # preprocessing. It utilizes the hf model's vision embeddings - # functions and returns the input_ids, input_embeddings, - # embedding_ranges and so on. All the returned values are passed - # to tm engine for token generation + results = await self.vl_encoder.preprocess(messages, mm_processor_kwargs) results = await self.vl_encoder.async_infer(results) results = await self.vl_encoder.wrap_for_turbomind(messages=results, chat_template=chat_template, @@ -372,12 +370,19 @@ async def _get_multimodal_prompt_input(self, tools=tools, chat_template_kwargs=chat_template_kwargs) elif self.backend == 'pytorch': - # for pt engine, this module only conduct the image preprocessing - # It leaves the vision embedding to the pt engine - results = await self.vl_encoder.wrap_for_pytorch(messages=results, - chat_template=chat_template, - tokenizer=self.tokenizer, - sequence_start=sequence_start, - tools=tools, - chat_template_kwargs=chat_template_kwargs) + if self._uses_new_preprocess: + input_prompt = self.vl_encoder.apply_chat_template(messages=messages, + chat_template=chat_template, + sequence_start=sequence_start, + chat_template_kwargs=chat_template_kwargs) + results = await self.vl_encoder.preprocess(messages, input_prompt, mm_processor_kwargs) + else: + results = await self.vl_encoder.preprocess(messages, mm_processor_kwargs) + results = await self.vl_encoder.wrap_for_pytorch(messages=results, + chat_template=chat_template, + tokenizer=self.tokenizer, + sequence_start=sequence_start, + tools=tools, + chat_template_kwargs=chat_template_kwargs) + return results diff --git a/lmdeploy/vl/constants.py b/lmdeploy/vl/constants.py index e0cd744f15..4a9b486ecf 100644 --- a/lmdeploy/vl/constants.py +++ b/lmdeploy/vl/constants.py @@ -4,8 +4,18 @@ IMAGE_TOKEN = '' -class Modality(str, Enum): +class Modality(Enum): IMAGE = 'image' VIDEO = 'video' AUDIO = 'audio' TIME_SERIES = 'time_series' + + def __eq__(self, other): + if isinstance(other, Modality): + return self.value == other.value + if isinstance(other, str): + return self.value == other + return NotImplemented + + def __hash__(self): + return hash(self.value) diff --git a/lmdeploy/vl/engine.py b/lmdeploy/vl/engine.py index 8cd179df8a..b284b6caac 100644 --- a/lmdeploy/vl/engine.py +++ b/lmdeploy/vl/engine.py @@ -24,11 +24,6 @@ def _raise_exception_on_finish(task: asyncio.Task) -> None: raise e -def _accepts_arg(func, arg_name: str) -> bool: - """Check if a function accepts a specific keyword argument.""" - return arg_name in inspect.signature(func).parameters - - class ImageEncoder: """Image encoder.""" @@ -47,15 +42,23 @@ def __init__( self.executor = ThreadPoolExecutor(max_workers=1) torch.cuda.empty_cache() + def apply_chat_template(self, messages, chat_template, sequence_start, chat_template_kwargs=None): + if self.model.has_input_ids(messages): + return messages[0]['content'][0]['text'] + return self.model.apply_chat_template( + messages, chat_template, sequence_start, chat_template_kwargs + ) + async def preprocess(self, messages: list[dict], + input_prompt: str | list[int] | None = None, mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]: """Preprocess multimodal data in the messages.""" - if _accepts_arg(self.model.preprocess, 'mm_processor_kwargs'): - future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.preprocess, messages, - mm_processor_kwargs) - else: - future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.preprocess, messages) + sig_params = inspect.signature(self.model.preprocess).parameters + kwargs = {k: v for k, v in [('input_prompt', input_prompt), ('mm_processor_kwargs', mm_processor_kwargs)] + if k in sig_params} + future = asyncio.get_event_loop().run_in_executor( + self.executor, lambda: self.model.preprocess(messages, **kwargs)) future.add_done_callback(_raise_exception_on_finish) outputs = await future return outputs diff --git a/lmdeploy/vl/model/base.py b/lmdeploy/vl/model/base.py index 51ebb44419..8c3ab4aee6 100644 --- a/lmdeploy/vl/model/base.py +++ b/lmdeploy/vl/model/base.py @@ -1,15 +1,22 @@ # Copyright (c) OpenMMLab. All rights reserved. +import dataclasses from abc import ABC, abstractmethod from itertools import groupby +from typing import Any import numpy as np +import torch from mmengine import Registry from transformers import AutoConfig, AutoTokenizer from lmdeploy.archs import get_model_arch +from lmdeploy.utils import get_logger +from lmdeploy.vl.constants import Modality VISION_MODELS = Registry('vision_model') +logger = get_logger('lmdeploy') + class VisionModel(ABC): """Visual model which extract image feature.""" @@ -31,6 +38,51 @@ def __init__(self, self.hf_config = hf_config self.image_token_id = self.get_pad_token_id(model_path, hf_config) or 0 + # mapping from attribute names to modality types + self.ATTR_NAME_TO_MODALITY = { + # image-related attributes + 'pixel_values': Modality.IMAGE, + 'image_sizes': Modality.IMAGE, + 'image_grid_thw': Modality.IMAGE, + 'image_attention_mask': Modality.IMAGE, + 'image_emb_mask': Modality.IMAGE, + 'images_spatial_crop': Modality.IMAGE, + 'images_crop': Modality.IMAGE, + 'has_local_crops': Modality.IMAGE, + 'has_images': Modality.IMAGE, + 'tgt_size': Modality.IMAGE, + 'image_grid_hws': Modality.IMAGE, + 'aspect_ratio_ids': Modality.IMAGE, + 'aspect_ratio_mask': Modality.IMAGE, + 'num_patches': Modality.IMAGE, + 'patch_pixel_values': Modality.IMAGE, + 'block_sizes': Modality.IMAGE, + # audio-related attributes + 'audio_features': Modality.AUDIO, + 'audio_feature_lens': Modality.AUDIO, + 'input_features': Modality.AUDIO, + 'input_features_mask': Modality.AUDIO, + 'audio_attention_mask': Modality.AUDIO, + 'feature_attention_mask': Modality.AUDIO, + # video-related attributes + 'pixel_values_videos': Modality.VIDEO, + 'second_per_grid_ts': Modality.VIDEO, + 'video_grid_thw': Modality.VIDEO, + # time series-related attributes + 'ts_values': Modality.TIME_SERIES, + 'ts_sr': Modality.TIME_SERIES, + 'ts_lens': Modality.TIME_SERIES, + } + + # name of the feature filed + self.FEATURE_NAMES = [ + 'pixel_values', + 'pixel_values_videos', + 'audio_features', + 'input_features', + 'ts_values', + ] + def get_pad_token_id(self, model_path, hf_config): """Get pad_token_id from hf_config or tokenizer.""" pad_token_id = getattr(hf_config, 'pad_token_id', None) @@ -60,51 +112,313 @@ def build_model(self, ): if self.backend == 'turbomind' or self.with_llm: raise NotImplementedError() - @abstractmethod - def preprocess(self, messages: list[dict]) -> list[dict]: - """Preprocess multimodal data in the messages. - - The derived class, - i.e., a specific vision model, takes the charge of image preprocessing - and the result management. - It can integrate the result into the messages list, or insert it to - the individual image item. - Args: - message(dict): multimodal data in a dict, which is as follows: - [ - {'role': 'user', 'content': 'user prompt'}, - {'role': 'assisant', 'content': 'AI reponse'}, - { - 'role': 'user', - 'content': [ - { - 'type': 'text', - 'text': 'string', - }, - { - 'type': 'image', - 'image': pillow.Image, - 'key1': value1, - ... - }, - { - 'type': 'image', - 'image': pillow.Image, - 'key1': value1, - ... - }, - ... - ] - } - {....} - ] - Returns: - the message list with preprocessing results included, which is - determined by the derived classes - """ # noqa - raise NotImplementedError() + @staticmethod + def get_mm_items_offset( + input_ids: torch.Tensor, mm_token_id: int + ) -> list[tuple[int, int]]: + """ + Get a set of range for mm_items from input_ids + Example: + input_ids = [1, 2, 3, 3, 3, 4, 3, 3] + mm_token_id = 3 + return result = [(2,4),(6,7)] + """ + mask = input_ids == mm_token_id + start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0] + end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0] + end_positions += 1 # convert to exclusive end index, compatible with legacy pytorch implementation + return list(zip(start_positions.tolist(), end_positions.tolist())) + + def get_override_size(self, processor, mm_processor_kwargs: dict[str, Any] | None = None, modality: str = ''): + if not mm_processor_kwargs: + return None + try: + default_min = processor.size['shortest_edge'] + default_max = processor.size['longest_edge'] + except (AttributeError, KeyError, TypeError): + tag = f'[{modality}] ' if modality else '' + logger.warning(f'{tag}processor does not expose size[shortest_edge/longest_edge], ' + f'mm_processor_kwargs size override will be skipped.') + return None + override_min = mm_processor_kwargs.get('min_pixels', default_min) + override_max = mm_processor_kwargs.get('max_pixels', default_max) + tag = f'[{modality}] ' if modality else '' + if override_min > override_max: + logger.warning( + f'{tag}Overriding min_pixels {override_min} > max_pixels {override_max}, ' \ + f'falling back to defaults, min_pixels={default_min} and max_pixels={default_max}.' + ) + return None + logger.info(f'{tag}Overriding processor size with min_pixels={override_min} and max_pixels={override_max}.') + return {'shortest_edge': override_min, 'longest_edge': override_max} + + def get_expanded_input_ids(self, input_prompt, collected_mm_items) -> torch.Tensor: + """Get input_ids with multimodal tokens expanded.""" + image_grid_thw = collected_mm_items.get(Modality.IMAGE, {}).get('image_grid_thw', None) + merge_length = self.processor.image_processor.merge_size ** 2 + image_index = 0 + input_ids = [] + for token in input_prompt: + if token == self.image_token_id: + image_tokens = image_grid_thw[image_index].prod() // merge_length + input_ids.extend([self.image_token_id] * image_tokens) + image_index += 1 + else: + input_ids.append(token) + input_ids = torch.tensor(input_ids) + return input_ids + + # adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/mm_utils.py + def get_expanded_mm_items(self, collected_mm_items): + """Hf processor outputs produced bundled data for multiple + images/videos we need to expand them into per-image/video entries for + better cache locality and fine-grained scheduling.""" + expanded_mm_items = [] + for modality, item in collected_mm_items.items(): + is_bundled = item.get('offset', None) is not None and len(item['offset']) > 1 + + # non-bundled case + if not is_bundled: + if modality == Modality.IMAGE: + expanded_mm_items.append( + dict( + modality=modality, + pixel_values=item['feature'], + image_grid_thw=item['image_grid_thw'][0], + offset=item['offset'][0], + image_token_id=self.image_token_id + ) + ) + elif modality == Modality.TIME_SERIES: + expanded_mm_items.append( + dict( + modality=modality, + ts_values=item['feature'], + ts_sr=item['ts_sr'], + ts_lens=item['ts_lens'], + offset=item['offset'][0], + ts_token_id=self.ts_token_id + ) + ) + continue + + # bundled case + num_items = len(item['offset']) + if modality == Modality.IMAGE: + image_grid_thw = item['image_grid_thw'] + grid_len = image_grid_thw.shape[0] + + patches_per_item = [] + for grid in image_grid_thw: + grid_tensor = torch.as_tensor(grid, dtype=torch.long) + patches_per_item.append(int(torch.prod(grid_tensor).item())) + + cumulative = torch.cumsum( + torch.tensor(patches_per_item, dtype=torch.long), dim=0 + ) + slice_indices = [0] + cumulative.tolist() + + # expand each image into a separate item + for i in range(num_items): + start_idx, end_idx = slice_indices[i], slice_indices[i + 1] + # TODO: zhouxinyu, compute mask and avoid passing token id + expanded_mm_items.append( + dict( + modality=modality, + pixel_values=item['feature'][start_idx:end_idx], + image_grid_thw=image_grid_thw[i], + offset=item['offset'][i], + image_token_id=self.image_token_id, + ) + ) + elif modality == Modality.VIDEO: + video_grid_thw = item['video_grid_thw'] + + # video_grid_thw shape: [num_videos, 3] where each row is [T, H, W] + # When T > 1, item.offsets contains frames (num_items = total frames) + # grid_len = num_videos, num_items = sum(T for each video) = total frames + grid_len = video_grid_thw.shape[0] + num_videos = grid_len + + # calculate total frames and frames per video + frames_per_video = [] + total_frames = 0 + for i in range(num_videos): + grid = video_grid_thw[i] + if isinstance(grid, torch.Tensor): + T = int(grid[0].item()) # T is the first element [T, H, W] + else: + grid_tensor = torch.as_tensor(grid, dtype=torch.long) + T = int(grid_tensor[0].item()) + frames_per_video.append(T) + total_frames += T + + # num_items should equal total_frames when T > 1 + if num_items != total_frames: + expanded_mm_items.append(item) + continue - def has_input_ids(self, messages: list[dict]) -> bool: + # calculate patches per video: T * H * W for each video + patches_per_video = [] + for i in range(num_videos): + grid = video_grid_thw[i] + if isinstance(grid, torch.Tensor): + patches_per_video.append(int(torch.prod(grid).item())) + else: + grid_tensor = torch.as_tensor(grid, dtype=torch.long) + patches_per_video.append(int(torch.prod(grid_tensor).item())) + + # calculate cumulative patches to get slice indices for each video + cumulative = torch.cumsum( + torch.tensor(patches_per_video, dtype=torch.long), dim=0 + ) + slice_indices = [0] + cumulative.tolist() + + # group frames by video, calculate frame indices for each video + frame_start_indices = [0] + for i in range(num_videos): + frame_start_indices.append( + frame_start_indices[-1] + frames_per_video[i] + ) + + # expand each video into a separate item + for video_idx in range(num_videos): + start, end = ( + slice_indices[video_idx], + slice_indices[video_idx + 1], + ) + frame_start, frame_end = ( + frame_start_indices[video_idx], + frame_start_indices[video_idx + 1], + ) + + # expand each frame into a separate item + # TODO: zhouxinyu, not sure per-frame split is good or not + # TODO: zhouxinyu, grid_thw [1, h, w] is only for qwen3vl + t, h, w = video_grid_thw[video_idx].tolist() + for frame_idx in range(t): + video_feature = item['feature'][start:end] + expanded_mm_items.append( + dict( + modality=modality, + pixel_values_videos=video_feature[frame_idx * h * w:(frame_idx + 1) * h * w], + video_grid_thw=torch.tensor([1, h, w]), + offset=item['offset'][frame_start:frame_end][frame_idx], + video_token_id=self.video_token_id, + ) + ) + + return expanded_mm_items + + def preprocess(self, + messages: list[dict], + input_prompt: str | list[int], + mm_processor_kwargs: dict[str, Any] | None = None) -> dict[str, Any]: + """Preprocess multimodal data and return a dict with ``input_ids`` and + multimodal features. + + New-style models inherit this implementation. Legacy models override with `def preprocess(self, messages)`. + """ + + mm_items = self.collect_multimodal_items(messages) + + raw_images, raw_videos, video_metadatas = [], [], [] + raw_time_series, sampling_rates = [], [] + for modality, data, params in mm_items: + if modality == Modality.IMAGE: + raw_images.append(data) + elif modality == Modality.VIDEO: + raw_videos.append(data) + video_metadatas.append(params.get('video_metadata', None)) + elif modality == Modality.TIME_SERIES: + raw_time_series.append(data) + sampling_rates.append(params.get('sampling_rate', None)) + else: + raise ValueError(f'unsupported modality {modality}') + + # get kwargs for processor + kwargs = {} + images_kwargs = {} + videos_kwargs = {} + mm_processor_kwargs = mm_processor_kwargs or {} + if raw_images: + kwargs['images'] = raw_images + image_size = self.get_override_size(self.processor.image_processor, + mm_processor_kwargs.get('image'), + modality='image') + if image_size is not None: + images_kwargs['size'] = image_size + if raw_videos: + kwargs['videos'] = raw_videos + videos_kwargs['video_metadata'] = video_metadatas + # perform resize in hf processor, while sample frames has been done in video loader + videos_kwargs['do_resize'] = True + videos_kwargs['do_sample_frames'] = False + video_size = self.get_override_size(self.processor.video_processor, + mm_processor_kwargs.get('video'), + modality='video') + if video_size is not None: + videos_kwargs['size'] = video_size + if images_kwargs: + kwargs['images_kwargs'] = images_kwargs + if videos_kwargs: + kwargs['videos_kwargs'] = videos_kwargs + if raw_time_series: + assert hasattr(self, 'time_series_processor'), \ + 'time series processor is not defined for time series input' + assert not raw_images and not raw_videos, \ + 'time series is not compatible with image/video input' + self.tokenizer = self.processor.tokenizer + time_series_processor = self.time_series_processor + kwargs['time_series'] = raw_time_series + kwargs['sampling_rate'] = sampling_rates + + # process raw items with hf processor + input_text = input_prompt if isinstance(input_prompt, str) else '' + processor_outputs = (time_series_processor if raw_time_series else self.processor)( + text=[input_text], + padding=True, + return_tensors='pt', + **kwargs, + ) + + # collect from processor outputs and categorized by modality + collected_mm_items: dict[Modality, dict[str, Any]] = {} + for attr_name, value in processor_outputs.items(): + if attr_name == 'input_ids': + continue + + current_modality = self.ATTR_NAME_TO_MODALITY.get(attr_name) + if current_modality: + if current_modality not in collected_mm_items: + collected_mm_items[current_modality] = {} + + if attr_name in self.FEATURE_NAMES: + attr_name = 'feature' + + collected_mm_items[current_modality][attr_name] = value + + # get input_ids + if isinstance(input_prompt, str): + input_ids = processor_outputs['input_ids'].flatten() + else: + input_ids = self.get_expanded_input_ids(input_prompt, collected_mm_items) + + # compute offsets for all items + for modality, item in collected_mm_items.items(): + mm_token_id = self.mm_tokens.get_token_id_by_modality(modality) + item['offset'] = self.get_mm_items_offset( + input_ids=input_ids, + mm_token_id=mm_token_id, + ) + + # expand bundled hf processor outputs into per-image/video entry + expanded_mm_items = self.get_expanded_mm_items(collected_mm_items) + + return dict(input_ids=input_ids.tolist(), multimodal=expanded_mm_items) + + @staticmethod + def has_input_ids(messages: list[dict]) -> bool: """Check whether the messages contain input_ids directly. Args: @@ -329,3 +643,25 @@ def match(cls, config: AutoConfig): if arch and (arch == cls._arch or arch in cls._arch): return True return False + + +@dataclasses.dataclass +class MultimodalSpecialTokens: + image_token: str | list[str] | None = None + video_token: str | list[str] | None = None + audio_token: str | list[str] | None = None + ts_token: str | list[str] | None = None + + image_token_id: int | None = None + video_token_id: int | None = None + audio_token_id: int | None = None + ts_token_id: int | None = None + + def get_token_id_by_modality(self, modality: Modality) -> int | None: + """Get token ID for a given modality.""" + return { + Modality.IMAGE: self.image_token_id, + Modality.VIDEO: self.video_token_id, + Modality.AUDIO: self.audio_token_id, + Modality.TIME_SERIES: self.ts_token_id, + }.get(modality) diff --git a/lmdeploy/vl/model/glm4_1v.py b/lmdeploy/vl/model/glm4_1v.py index 6a796105fc..feeed84145 100644 --- a/lmdeploy/vl/model/glm4_1v.py +++ b/lmdeploy/vl/model/glm4_1v.py @@ -3,7 +3,7 @@ from transformers import AutoConfig from lmdeploy.utils import get_logger -from lmdeploy.vl.model.base import VISION_MODELS, VisionModel +from lmdeploy.vl.model.base import VISION_MODELS, MultimodalSpecialTokens, VisionModel logger = get_logger('lmdeploy') @@ -25,52 +25,15 @@ def match(cls, config: AutoConfig): def build_preprocessor(self): from transformers import AutoProcessor self.processor = AutoProcessor.from_pretrained(self.model_path) - tokenizer = self.processor.tokenizer - image_token = self.processor.image_token - self.image_token_id = tokenizer.encode(image_token)[-1] - def build_model(self): - raise NotImplementedError('turbomind has not supported glm4v yet') + self.image_token = self.processor.image_token + self.image_token_id = self.processor.image_token_id - def preprocess(self, messages: list[dict]) -> list[dict]: - """Refer to `super().preprocess()` for spec.""" - images = self.collect_multimodal_items(messages) - optional_keys = {'resized_height', 'resized_width', 'min_pixels', 'max_pixels'} - outputs = [] - for modality, image, params in images: - item = dict(type='image', image=image) - item.update({key: params[key] for key in params.keys() if key in optional_keys}) - result = self.processor.image_processor(images=image, videos=None, return_tensors='pt') - merge_length = self.processor.image_processor.merge_size**2 - image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length - result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id)) - outputs.append(result) - messages.append(dict(role='preprocess', content=outputs)) - return messages + self.mm_tokens = MultimodalSpecialTokens( + image_token=self.image_token, + image_token_id=self.image_token_id + ) - @staticmethod - def proc_messages(messages, chat_template, sequence_start): + def apply_chat_template(self, messages, chat_template, sequence_start, chat_template_kwargs=None): """Apply chat template to get the prompt.""" - prompt_messages = [] - IMAGE_TOKEN = '' - for message in messages: - if isinstance(message['content'], str): - prompt_messages.append(message) - continue - elif message['role'] in ['images', 'preprocess', 'forward']: - continue - n_images = len([1 for x in message['content'] if x['type'] == 'image']) - content = [item['text'] for item in message['content'] if item['type'] == 'text'] - prompt = content[0] - if IMAGE_TOKEN in prompt and '<|begin_of_image|>' not in prompt: - prompt = prompt.replace(IMAGE_TOKEN, f'<|begin_of_image|>{IMAGE_TOKEN}<|end_of_image|>') - else: - prompt = f'<|begin_of_image|>{IMAGE_TOKEN}<|end_of_image|>' * \ - n_images + prompt - prompt_messages.append(dict(role=message['role'], content=prompt)) - prompt = chat_template.messages2prompt(prompt_messages, sequence_start) - return prompt, IMAGE_TOKEN - - def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs): - prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start) - return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start) + return chat_template.messages2prompt(messages, sequence_start, **(chat_template_kwargs or {})) diff --git a/lmdeploy/vl/model/interns1_pro.py b/lmdeploy/vl/model/interns1_pro.py index 6534886017..66b21a233b 100644 --- a/lmdeploy/vl/model/interns1_pro.py +++ b/lmdeploy/vl/model/interns1_pro.py @@ -6,7 +6,7 @@ from lmdeploy.utils import get_logger from lmdeploy.vl.constants import Modality -from lmdeploy.vl.model.base import VISION_MODELS, VisionModel +from lmdeploy.vl.model.base import VISION_MODELS, MultimodalSpecialTokens, VisionModel from lmdeploy.vl.model.qwen3 import Qwen3VLModel logger = get_logger('lmdeploy') @@ -28,14 +28,27 @@ def build_preprocessor(self): # time series tokens self.ts_token = getattr(self.processor, 'ts_token', None) self.ts_token_id = getattr(self.processor, 'ts_token_id', None) - - def _preprocess_time_series(self, - data: list[Any], - params: dict[str, Any], - mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]: - - ts_input = data - sr = params.get('sampling_rate') if params is not None else None + self.ts_start_token = getattr(self.processor, 'ts_start_token', None) + self.ts_end_token = getattr(self.processor, 'ts_end_token', None) + + # special tokens + self.mm_tokens = MultimodalSpecialTokens( + image_token=self.image_token, + video_token=self.video_token, + ts_token=self.ts_token, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + ts_token_id=self.ts_token_id + ) + + def time_series_processor(self, + text: str, + time_series: list[Any], + sampling_rate: float | None = None, + **kwargs): + + ts_input = time_series[0] if isinstance(time_series, list) else time_series + sampling_rate = sampling_rate[0] if isinstance(sampling_rate, list) else sampling_rate if not isinstance(ts_input, np.ndarray): ts_input = np.array(ts_input, dtype=np.float32) @@ -55,46 +68,37 @@ def _preprocess_time_series(self, ts_len = ts_input.shape[0] # set the default value to ts_len / 4 if sr is not provided or invalid - if sr is None or sr <= 0: - sr = max(ts_len / 4, 1.0) + if sampling_rate is None or sampling_rate <= 0: + sampling_rate = max(ts_len / 4, 1.0) # compute num ts tokens - stride = np.floor(160 / ((1 + np.exp(-sr / 100))**6)) + stride = np.floor(160 / ((1 + np.exp(-sampling_rate / 100))**6)) patch_size = stride * 2 embed_length = (np.ceil((ts_len - patch_size) / stride) + 1) ts_tokens = int((embed_length // 2 + 1) // 2) - return dict(ts_values=[ts_input], - ts_sr=[sr], - ts_lens=[ts_len], - ts_tokens=[ts_tokens], + # generate text with ts tokens + for i in range(len(text)): + if f'{self.ts_start_token}{self.ts_token}{self.ts_end_token}' in text[i]: + ts_placeholder = self.ts_start_token + self.ts_token * ts_tokens + self.ts_end_token + text[i] = text[i].replace( + f'{self.ts_start_token}{self.ts_token}{self.ts_end_token}', ts_placeholder, 1 + ) + elif self.ts_token in text[i]: + text[i] = text[i].replace(self.ts_token, self.ts_token * ts_tokens) + + input_ids = self.tokenizer(text, add_special_tokens=False, **kwargs)['input_ids'] + + ts_input = torch.from_numpy(np.array([ts_input])).to(dtype=torch.bfloat16) + ts_sr = torch.tensor([sampling_rate]) + ts_lens = torch.tensor([ts_len]) + return dict(input_ids=input_ids, + ts_values=ts_input, + ts_sr=ts_sr, + ts_lens=ts_lens, ts_token_id=self.ts_token_id) - def preprocess(self, messages: list[dict], mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]: - """Refer to `super().preprocess()` for spec.""" - outputs = [] - self.contains_video_input = False - self.contains_ts_input = False - - mm_items = self.collect_multimodal_items(messages) - for modality, data, params in mm_items: - result = {} - if modality == Modality.IMAGE: - result = self._preprocess_image(data, params, mm_processor_kwargs) - elif modality == Modality.VIDEO: - self.contains_video_input = True - result = self._preprocess_video(data, params, mm_processor_kwargs) - elif modality == Modality.TIME_SERIES: - self.contains_ts_input = True - result = self._preprocess_time_series(data, params, mm_processor_kwargs) - - result.update(modality=modality) - outputs.append(result) - - messages.append(dict(role='preprocess', content=outputs)) - return messages - - def proc_messages(self, + def apply_chat_template(self, messages, chat_template, sequence_start, @@ -119,65 +123,9 @@ def proc_messages(self, else: prompt_messages = messages - # time series input requires enabling_thinking = False - if self.contains_ts_input: + # time series requires enabling_thinking = False + if any(m == Modality.TIME_SERIES for m, _, _ in self.collect_multimodal_items(messages)): chat_template_kwargs['enable_thinking'] = False prompt = chat_template.messages2prompt(prompt_messages, sequence_start, tools=tools, **chat_template_kwargs) - return prompt, None - - def to_pytorch_aux_ts(self, messages, prompt, TS_TOKEN, tokenizer, sequence_start): - """Pack the time series input to the compatible format with pytorch - engine.""" - # collect all preprocessing result from messages - preps = [x['content'] for x in messages if x['role'] == 'preprocess'] - assert len(preps) == 1 - preps = preps[0] - - # split prompt into segments and validate data - segs = prompt.split(TS_TOKEN) - assert len(segs) == len(preps) + 1, (f'the number of {TS_TOKEN} is not equal ' - f'to input time series data, {len(segs) - 1} vs {len(preps)}') - - input_ids = [] - for i, seg in enumerate(segs): - if i > 0 and i <= len(preps): - preps[i - 1].update(offset=len(input_ids)) - ts_tokens = preps[i - 1]['ts_tokens'] - - ts_tokens = ts_tokens[0] - ts_array = np.array(preps[i - 1]['ts_values']) - - preps[i - 1].update(ts_tokens=ts_tokens) - preps[i - 1].update(ts_values=torch.from_numpy(ts_array).to(dtype=torch.bfloat16)) - preps[i - 1].update(ts_lens=torch.tensor(preps[i - 1]['ts_lens'])) - preps[i - 1].update(ts_sr=torch.tensor(preps[i - 1]['ts_sr'])) - - assert self.ts_token_id == preps[i - 1]['ts_token_id'] - input_ids.extend([self.ts_token_id] * ts_tokens) - token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start)) - input_ids.extend(token_ids) - - return dict(prompt=prompt, input_ids=input_ids, multimodal=preps) - - def to_pytorch(self, - messages, - chat_template, - tokenizer, - sequence_start, - tools: list[object] | None = None, - chat_template_kwargs: dict | None = None, - **kwargs): - """Return to the information needed by pytorch engine.""" - prompt, _ = self.proc_messages(messages, - chat_template, - sequence_start, - tools=tools, - chat_template_kwargs=chat_template_kwargs) - - if self.contains_video_input: - return self.to_pytorch_aux_video(messages, prompt, self.video_token, tokenizer, sequence_start) - elif self.contains_ts_input: - return self.to_pytorch_aux_ts(messages, prompt, self.ts_token, tokenizer, sequence_start) - else: - return self.to_pytorch_aux(messages, prompt, self.image_token, tokenizer, sequence_start) + return prompt diff --git a/lmdeploy/vl/model/qwen3.py b/lmdeploy/vl/model/qwen3.py index e43dad838c..2efa35df34 100644 --- a/lmdeploy/vl/model/qwen3.py +++ b/lmdeploy/vl/model/qwen3.py @@ -1,12 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any -import torch from transformers import AutoProcessor from lmdeploy.utils import get_logger -from lmdeploy.vl.constants import Modality -from lmdeploy.vl.model.base import VISION_MODELS, VisionModel +from lmdeploy.vl.model.base import VISION_MODELS, MultimodalSpecialTokens, VisionModel logger = get_logger('lmdeploy') @@ -41,88 +38,15 @@ def build_preprocessor(self): self.vision_start_token = self.processor.vision_start_token self.vision_end_token = self.processor.vision_end_token - def resolve_size_params(self, processor, mm_processor_kwargs: dict[str, Any] | None = None): - default_min = processor.size['shortest_edge'] - default_max = processor.size['longest_edge'] - - if not mm_processor_kwargs: - return {'shortest_edge': default_min, 'longest_edge': default_max} - - min_pixels = mm_processor_kwargs.get('min_pixels', default_min) - max_pixels = mm_processor_kwargs.get('max_pixels', default_max) - - if min_pixels > max_pixels: - logger.warning(f'min_pixels {min_pixels} > max_pixels {max_pixels}, falling back to defaults.') - return {'shortest_edge': default_min, 'longest_edge': default_max} - - return {'shortest_edge': min_pixels, 'longest_edge': max_pixels} - - def _preprocess_image(self, - data: list[Any], - params: dict[str, Any], - mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]: - - size = self.resolve_size_params(self.processor.image_processor, mm_processor_kwargs) - result = self.processor.image_processor(images=data, size=size, return_tensors='pt') - merge_length = self.processor.image_processor.merge_size**2 - image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length - result.update(dict(image_size=data.size, image_tokens=image_tokens, image_token_id=self.image_token_id)) - return result - - def _preprocess_video(self, - data: list[Any], - params: dict[str, Any], - mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]: - - metadata = params['video_metadata'] - if metadata.get('fps') is None or metadata['fps'] <= 0: - logger.warning('fps not found or invalid, fallback to 24.') - metadata['fps'] = 24 - size = self.resolve_size_params(self.processor.video_processor, mm_processor_kwargs) - - # do_resize = True, we leave resize to hf processor - # do_sample_frames = False, we already sample frames in video loader, avoid duplicates in hf processor - result = self.processor.video_processor(videos=data, - size=size, - return_metadata=True, - do_resize=True, - do_sample_frames=False, - video_metadata=metadata, - return_tensors='pt') - - merge_length = self.processor.video_processor.merge_size**2 - video_grid_thw = result['video_grid_thw'] - frame_seqlen = video_grid_thw[0][1:].prod() // merge_length - curr_timestamp = self.processor._calculate_timestamps( - metadata['frames_indices'], - metadata['fps'], - self.processor.video_processor.merge_size, + # special tokens + self.mm_tokens = MultimodalSpecialTokens( + image_token=self.image_token, + video_token=self.video_token, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, ) - result.update(curr_timestamp=curr_timestamp, frame_seqlen=frame_seqlen, video_token_id=self.video_token_id) - return result - - def preprocess(self, messages: list[dict], mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]: - """Refer to `super().preprocess()` for spec.""" - outputs = [] - self.contains_video_input = False - - mm_items = self.collect_multimodal_items(messages) - for modality, data, params in mm_items: - result = {} - if modality == Modality.IMAGE: - result = self._preprocess_image(data, params, mm_processor_kwargs) - elif modality == Modality.VIDEO: - self.contains_video_input = True - result = self._preprocess_video(data, params, mm_processor_kwargs) - - result.update(modality=modality) - outputs.append(result) - - messages.append(dict(role='preprocess', content=outputs)) - return messages - - def proc_messages(self, messages, chat_template, sequence_start, chat_template_kwargs=None): + def apply_chat_template(self, messages, chat_template, sequence_start, chat_template_kwargs=None): """Apply chat template to get the prompt.""" chat_template_kwargs = chat_template_kwargs or {} prompt_messages = [] @@ -141,107 +65,6 @@ def proc_messages(self, messages, chat_template, sequence_start, chat_template_k prompt_messages.append(dict(role='user', content=prompt)) else: prompt_messages = messages - prompt = chat_template.messages2prompt(prompt_messages, sequence_start, **chat_template_kwargs) - return prompt, None - - def to_pytorch_aux_video(self, messages, prompt, VIDEO_TOKEN, tokenizer, sequence_start): - """Pack the video input to the compatible format with pytorch engine. - - Each video is split into per-frame (temporal step) entries so that the timestamp text tokens between frames get - sequential mrope positions and each frame's video-pad tokens get independent 3D spatial positions. - """ - # collect all preprocessing result from messages - preps = [x['content'] for x in messages if x['role'] == 'preprocess'] - assert len(preps) == 1 - preps = preps[0] - - # split prompt into segments and validate data - segs = prompt.split(self.vision_start_token + self.video_token + self.vision_end_token) - assert len(segs) == len(preps) + 1, (f'the number of {self.video_token} is not equal ' - f'to input videos, {len(segs) - 1} vs {len(preps)}') - - # calculate the video token offset for each frame - input_ids = [] - frame_preps = [] - - for i, seg in enumerate(segs): - if i > 0 and i <= len(preps): - video_prep = preps[i - 1] - frame_seqlen = video_prep['frame_seqlen'] - curr_timestamp = video_prep['curr_timestamp'] - video_grid_thw = video_prep['video_grid_thw'] - pixel_values_videos = video_prep['pixel_values_videos'] - assert self.video_token_id == video_prep['video_token_id'] - - t, h, w = video_grid_thw[0].tolist() - - # each temporal step becomes an independent multimodal entry - for frame_idx in range(t): - curr_time = curr_timestamp[frame_idx] - - # timestamp text + vision_start (regular text tokens) - prefix = f'<{curr_time:.1f} seconds>' + self.vision_start_token - prefix_ids = tokenizer.encode(prefix, add_bos=False) - input_ids.extend(prefix_ids) - - # video pad tokens for this frame - frame_offset = len(input_ids) - input_ids.extend([self.video_token_id] * frame_seqlen) - - # vision_end (regular text token) - suffix_ids = tokenizer.encode(self.vision_end_token, add_bos=False) - input_ids.extend(suffix_ids) - - # since we use timestamps to separate videos - # like - # the video_grid_thw should also be split, becomes [1, h, w] for each frame - frame_preps.append( - dict( - offset=frame_offset, - video_tokens=frame_seqlen, - pixel_values_videos=pixel_values_videos[frame_idx * h * w:(frame_idx + 1) * h * w], - video_grid_thw=torch.tensor([[1, h, w]]), - video_token_id=self.video_token_id, - modality=video_prep['modality'], - ) - ) - - token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start)) - input_ids.extend(token_ids) - - return dict(prompt=prompt, input_ids=input_ids, multimodal=frame_preps) - - def to_pytorch(self, - messages, - chat_template, - tokenizer, - sequence_start, - chat_template_kwargs: dict | None = None, - **kwargs): - """Return to the information needed by pytorch engine.""" - prompt, _ = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs) - - if self.contains_video_input: - return self.to_pytorch_aux_video(messages, prompt, self.video_token, tokenizer, sequence_start) - else: - return self.to_pytorch_aux(messages, prompt, self.image_token, tokenizer, sequence_start) - - def build_model(self): - # TODO: implement for turbomind - pass - - @torch.no_grad() - def forward(self, messages: list[dict], max_batch_size: int = 1) -> list[dict]: - # TODO: implement for turbomind - pass - - def to_turbomind(self, - messages, - chat_template, - tokenizer, - sequence_start, - chat_template_kwargs: dict | None = None, - **kwargs): - # TODO: implement for turbomind - pass + prompt = chat_template.messages2prompt(prompt_messages, sequence_start, **chat_template_kwargs) + return prompt diff --git a/lmdeploy/vl/model/qwen3_5.py b/lmdeploy/vl/model/qwen3_5.py index f030de5e5e..3f46f0a38d 100644 --- a/lmdeploy/vl/model/qwen3_5.py +++ b/lmdeploy/vl/model/qwen3_5.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from transformers import AutoProcessor from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS @@ -25,17 +24,4 @@ class Qwen3_5Model(Qwen3VLModel): def build_preprocessor(self): check_transformers() - - self.processor = AutoProcessor.from_pretrained(self.model_path) - - # image tokens - self.image_token = self.processor.image_token - self.image_token_id = self.processor.image_token_id - - # video tokens - self.video_token = self.processor.video_token - self.video_token_id = self.processor.video_token_id - - # vision start and end tokens - self.vision_start_token = self.processor.vision_start_token - self.vision_end_token = self.processor.vision_end_token + super().build_preprocessor() diff --git a/tests/test_lmdeploy/test_vl/test_hf_chat_template.py b/tests/test_lmdeploy/test_vl/test_hf_chat_template.py index fdd3d193af..9a7a488081 100644 --- a/tests/test_lmdeploy/test_vl/test_hf_chat_template.py +++ b/tests/test_lmdeploy/test_vl/test_hf_chat_template.py @@ -117,6 +117,46 @@ def models(self): 'Qwen/Qwen2.5-VL-7B-Instruct', 'Qwen/Qwen2.5-VL-32B-Instruct', 'Qwen/Qwen2.5-VL-72B-Instruct', + ] + models = [get_model_and_chat_template(model_path) for model_path in model_list] + return models + + def test_proc_messages(self, models, mock_messages): + for model, chat_template in models: + model.build_preprocessor() + reference = model.processor.apply_chat_template(mock_messages, + add_generation_prompt=True, + tokenize=False, + return_dict=True) + prompt, _ = model.proc_messages(mock_messages, chat_template, sequence_start=True) + assert prompt == reference + + def test_pure_img_messages(self, models, mock_pure_img_messages): + for model, chat_template in models: + model.build_preprocessor() + reference = model.processor.apply_chat_template(mock_pure_img_messages, + add_generation_prompt=True, + tokenize=False, + return_dict=True) + prompt, _ = model.proc_messages(mock_pure_img_messages, chat_template, sequence_start=True) + assert prompt == reference + + def test_pure_text_messages(self, models, mock_pure_text_messages): + for model, chat_template in models: + model.build_preprocessor() + reference = model.processor.apply_chat_template(mock_pure_text_messages, + add_generation_prompt=True, + tokenize=False, + return_dict=True) + prompt, _ = model.proc_messages(mock_pure_text_messages, chat_template, sequence_start=True) + assert prompt == reference + + +class TestQwen3VLChatTemplate: + + @pytest.fixture(scope='module') + def models(self): + model_list = [ 'Qwen/Qwen3-VL-2B-Instruct', 'Qwen/Qwen3-VL-2B-Thinking', 'Qwen/Qwen3-VL-4B-Instruct', @@ -133,14 +173,14 @@ def models(self): models = [get_model_and_chat_template(model_path) for model_path in model_list] return models - def test_proc_messages(self, models, mock_messages): + def test_apply_chat_template(self, models, mock_messages): for model, chat_template in models: model.build_preprocessor() reference = model.processor.apply_chat_template(mock_messages, add_generation_prompt=True, tokenize=False, return_dict=True) - prompt, _ = model.proc_messages(mock_messages, chat_template, sequence_start=True) + prompt = model.apply_chat_template(mock_messages, chat_template, sequence_start=True) assert prompt == reference def test_pure_img_messages(self, models, mock_pure_img_messages): @@ -150,7 +190,7 @@ def test_pure_img_messages(self, models, mock_pure_img_messages): add_generation_prompt=True, tokenize=False, return_dict=True) - prompt, _ = model.proc_messages(mock_pure_img_messages, chat_template, sequence_start=True) + prompt = model.apply_chat_template(mock_pure_img_messages, chat_template, sequence_start=True) assert prompt == reference def test_pure_text_messages(self, models, mock_pure_text_messages): @@ -160,5 +200,5 @@ def test_pure_text_messages(self, models, mock_pure_text_messages): add_generation_prompt=True, tokenize=False, return_dict=True) - prompt, _ = model.proc_messages(mock_pure_text_messages, chat_template, sequence_start=True) + prompt = model.apply_chat_template(mock_pure_text_messages, chat_template, sequence_start=True) assert prompt == reference diff --git a/tests/test_lmdeploy/test_vl/test_qwen3vl_processor.py b/tests/test_lmdeploy/test_vl/test_qwen3vl_processor.py index 3d7b0d4458..1ed9a8f790 100644 --- a/tests/test_lmdeploy/test_vl/test_qwen3vl_processor.py +++ b/tests/test_lmdeploy/test_vl/test_qwen3vl_processor.py @@ -1,6 +1,7 @@ import pytest from lmdeploy.vl import load_image, load_video +from lmdeploy.vl.constants import Modality from lmdeploy.vl.model.qwen3 import Qwen3VLModel QWEN3VL_MODELS = [ @@ -39,9 +40,17 @@ def sample_video_messages(video_data): return [{'role': 'user', 'content': [{'type': 'video', 'data': frames, 'video_metadata': metadata}]}] -def _preprocess(model, messages, **kwargs): - result = model.preprocess(messages=list(messages), **kwargs) - return result[-1]['content'][0] +def _preprocess(model, messages, mm_processor_kwargs=None): + """Call model.preprocess following the same flow as the engine: + + apply_chat_template → input_prompt → preprocess. + """ + from lmdeploy.model import MODELS + chat_template = MODELS.module_dict['hf'](model_path=model.model_path) + input_prompt = model.apply_chat_template(messages, chat_template, sequence_start=True) + result = model.preprocess(messages=list(messages), input_prompt=input_prompt, + mm_processor_kwargs=mm_processor_kwargs) + return result['multimodal'][0] def test_image_with_custom_pixels(qwen3vl_model, sample_messages): @@ -56,14 +65,14 @@ def test_image_with_custom_pixels(qwen3vl_model, sample_messages): default_shape = _preprocess(qwen3vl_model, sample_messages)['pixel_values'].shape # [60, 1536] + small_kwargs = {'image': {'min_pixels': 10 * 32 * 32, 'max_pixels': 20 * 32 * 32}} small_shape = _preprocess(qwen3vl_model, sample_messages, - mm_processor_kwargs={'min_pixels': 10 * 32 * 32, - 'max_pixels': 20 * 32 * 32})['pixel_values'].shape + mm_processor_kwargs=small_kwargs)['pixel_values'].shape # [468, 1536] + large_kwargs = {'image': {'min_pixels': 100 * 32 * 32, 'max_pixels': 20000 * 32 * 32}} large_shape = _preprocess(qwen3vl_model, sample_messages, - mm_processor_kwargs={'min_pixels': 100 * 32 * 32, - 'max_pixels': 20000 * 32 * 32})['pixel_values'].shape + mm_processor_kwargs=large_kwargs)['pixel_values'].shape assert small_shape[0] < default_shape[0] < large_shape[0] @@ -72,20 +81,82 @@ def test_video_with_custom_pixels(qwen3vl_model, sample_video_messages): """Test that mm_processor_kwargs min/max pixels affect video preprocessing. Videos process at native resolution by default, so we compare two constrained ranges rather than comparing against - the default. + the default. Per-frame shapes are compared (each multimodal item is one frame). """ # [28160, 1536] default_shape = _preprocess(qwen3vl_model, sample_video_messages)['pixel_values_videos'].shape - # [32, 1536] + # [4, 1536] + small_kwargs = {'video': {'min_pixels': 10 * 32 * 32, 'max_pixels': 20 * 32 * 32}} small_shape = _preprocess(qwen3vl_model, sample_video_messages, - mm_processor_kwargs={'min_pixels': 10 * 32 * 32, - 'max_pixels': 20 * 32 * 32})['pixel_values_videos'].shape + mm_processor_kwargs=small_kwargs)['pixel_values_videos'].shape - # [256, 1536] + # [32, 1536] + medium_kwargs = {'video': {'min_pixels': 50 * 32 * 32, 'max_pixels': 200 * 32 * 32}} medium_shape = _preprocess(qwen3vl_model, sample_video_messages, - mm_processor_kwargs={'min_pixels': 50 * 32 * 32, - 'max_pixels': 200 * 32 * 32})['pixel_values_videos'].shape + mm_processor_kwargs=medium_kwargs)['pixel_values_videos'].shape assert small_shape[0] < medium_shape[0] <= default_shape[0] + + +@pytest.fixture +def sample_mixed_messages(pil_image, video_data): + frames, metadata = video_data + return [{ + 'role': 'user', + 'content': [ + {'type': 'image', 'data': pil_image}, + {'type': 'video', 'data': frames, 'video_metadata': metadata}, + ] + }] + + +def _preprocess_by_modality(model, messages, mm_processor_kwargs=None): + """Like _preprocess but returns all multimodal items grouped by + modality.""" + from lmdeploy.model import MODELS + chat_template = MODELS.module_dict['hf'](model_path=model.model_path) + input_prompt = model.apply_chat_template(messages, chat_template, sequence_start=True) + result = model.preprocess(messages=list(messages), input_prompt=input_prompt, + mm_processor_kwargs=mm_processor_kwargs) + by_modality = {} + for item in result['multimodal']: + by_modality.setdefault(item['modality'], []).append(item) + return by_modality + + +def test_mixed_image_video_independent_size(qwen3vl_model, sample_mixed_messages): + """Per-modality mm_processor_kwargs must not bleed across image and video. + + Shrinking image budget must not change video token count, and vice versa. + """ + default = _preprocess_by_modality(qwen3vl_model, sample_mixed_messages) + default_image_patches = default[Modality.IMAGE][0]['pixel_values'].shape[0] + default_video_patches = sum(item['pixel_values_videos'].shape[0] for item in default[Modality.VIDEO]) + + # shrink image only — video must be unchanged + small_image = _preprocess_by_modality(qwen3vl_model, sample_mixed_messages, + mm_processor_kwargs={'image': {'min_pixels': 10 * 32 * 32, + 'max_pixels': 20 * 32 * 32}}) + assert small_image[Modality.IMAGE][0]['pixel_values'].shape[0] < default_image_patches + assert sum(item['pixel_values_videos'].shape[0] + for item in small_image[Modality.VIDEO]) == default_video_patches + + # shrink video only — image must be unchanged + small_video = _preprocess_by_modality(qwen3vl_model, sample_mixed_messages, + mm_processor_kwargs={'video': {'min_pixels': 10 * 32 * 32, + 'max_pixels': 20 * 32 * 32}}) + assert small_video[Modality.IMAGE][0]['pixel_values'].shape[0] == default_image_patches + assert sum(item['pixel_values_videos'].shape[0] + for item in small_video[Modality.VIDEO]) < default_video_patches + + # shrink both simultaneously — both must decrease independently + small_both = _preprocess_by_modality(qwen3vl_model, sample_mixed_messages, + mm_processor_kwargs={ + 'image': {'min_pixels': 10 * 32 * 32, 'max_pixels': 20 * 32 * 32}, + 'video': {'min_pixels': 10 * 32 * 32, 'max_pixels': 20 * 32 * 32}, + }) + assert small_both[Modality.IMAGE][0]['pixel_values'].shape[0] < default_image_patches + assert sum(item['pixel_values_videos'].shape[0] + for item in small_both[Modality.VIDEO]) < default_video_patches