From 07fbb310b0faf436ee7f6463e845a5a6f641060f Mon Sep 17 00:00:00 2001 From: Matt Van Horn <455140+mvanhorn@users.noreply.github.com> Date: Sat, 16 May 2026 13:32:36 -0700 Subject: [PATCH] fix(client): add Qwen35Renderer dispatch in _build_mm_features Qwen3.5 and Qwen3-VL share the Qwen2-VL multimodal payload shape (pixel_values + image_grid_thw with merge_size=2 across all seven Qwen3.5 sizes), so the same feature builder can serve both. Added a _build_qwen_vl_features helper and factored the dispatch so the renderer-class check covers both renderer types. Fixes #39 --- renderers/client.py | 280 +++++++++++++++++++++++++++++++ tests/test_client_mm_features.py | 25 +++ 2 files changed, 305 insertions(+) create mode 100644 tests/test_client_mm_features.py diff --git a/renderers/client.py b/renderers/client.py index b4e91fd..7ddb7da 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -14,6 +14,9 @@ import asyncio import base64 import logging +from dataclasses import dataclass +from functools import reduce +from operator import mul from typing import Any, cast import numpy as np @@ -24,6 +27,28 @@ _request_logger = logging.getLogger("renderers.client") +@dataclass(frozen=True) +class _FallbackPlaceholderRange: + offset: int + length: int + is_embed: Any = None + + +@dataclass +class _FallbackMultiModalFieldElem: + data: Any + field: Any = None + + +@dataclass +class _FallbackMultiModalFeatureSpec: + data: dict[str, _FallbackMultiModalFieldElem] | None + modality: str + identifier: str + mm_position: _FallbackPlaceholderRange + mm_hash: str | None = None + + async def _run_pooled(pool: RendererPool, fn): def _work(): with pool.checkout() as r: @@ -32,6 +57,261 @@ def _work(): return await asyncio.to_thread(_work) +def _build_mm_features(renderer_cls: type, mm_data: Any) -> list[Any] | None: + """Build vLLM multimodal feature specs for renderer-native payloads.""" + from renderers.qwen3_vl import Qwen3VLRenderer + from renderers.qwen35 import Qwen35Renderer + + if issubclass(renderer_cls, (Qwen3VLRenderer, Qwen35Renderer)): + # Qwen3-VL and Qwen3.5 both emit Qwen2-VL-family image payloads: + # pixel_values plus image_grid_thw. All seven current Qwen3.5 sizes + # use merge_size=2; move this to renderer metadata when that API lands. + return _build_qwen_vl_features(mm_data, spatial_merge_size=2) + + raise NotImplementedError(f"No multimodal feature builder for {renderer_cls!r}") + + +def _build_qwen_vl_features( + mm_data: Any, *, spatial_merge_size: int +) -> list[Any] | None: + image_payloads = _image_payloads(mm_data) + if not image_payloads: + return None + + try: + from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalFieldConfig, + PlaceholderRange, + ) + except Exception: + return _build_fallback_qwen_vl_features( + image_payloads, spatial_merge_size=spatial_merge_size + ) + + features: list[Any] = [] + next_offset = 0 + for payload_idx, payload in enumerate(image_payloads): + pixel_values = _tensor_data(payload["pixel_values"]) + image_grid_thw = _image_grid_tensor(payload["image_grid_thw"]) + grid_rows = _grid_rows(image_grid_thw) + sizes = [_grid_prod(row) for row in grid_rows] + + field_elems = MultiModalFieldConfig.flat_from_sizes( + "image", _tensor(sizes, like=image_grid_thw) + ).field.build_elems("image", "pixel_values", pixel_values) + grid_elems = MultiModalFieldConfig.batched("image").field.build_elems( + "image", "image_grid_thw", image_grid_thw + ) + + for image_idx, (pixel_elem, grid_elem, grid_row) in enumerate( + zip(field_elems, grid_elems, grid_rows, strict=True) + ): + length = _grid_prod(grid_row) // (spatial_merge_size**2) + mm_position = _placeholder_range( + payload, + image_idx, + default_offset=next_offset, + default_length=length, + placeholder_cls=PlaceholderRange, + ) + next_offset = mm_position.offset + mm_position.length + feature_kwargs = { + "data": { + "pixel_values": pixel_elem, + "image_grid_thw": grid_elem, + }, + "modality": "image", + "identifier": _identifier(payload, payload_idx, image_idx), + "mm_position": mm_position, + "mm_hash": _mm_hash(payload), + } + try: + features.append(MultiModalFeatureSpec(**feature_kwargs)) + except TypeError: + feature_kwargs.pop("mm_hash") + features.append(MultiModalFeatureSpec(**feature_kwargs)) + + return features + + +def _build_fallback_qwen_vl_features( + image_payloads: list[dict[str, Any]], *, spatial_merge_size: int +) -> list[_FallbackMultiModalFeatureSpec]: + features: list[_FallbackMultiModalFeatureSpec] = [] + next_offset = 0 + for payload_idx, payload in enumerate(image_payloads): + for image_idx, grid_row in enumerate(_grid_rows(payload["image_grid_thw"])): + length = _grid_prod(grid_row) // (spatial_merge_size**2) + mm_position = _placeholder_range( + payload, + image_idx, + default_offset=next_offset, + default_length=length, + placeholder_cls=_FallbackPlaceholderRange, + ) + next_offset = mm_position.offset + mm_position.length + features.append( + _FallbackMultiModalFeatureSpec( + data={ + "pixel_values": _FallbackMultiModalFieldElem( + payload["pixel_values"] + ), + "image_grid_thw": _FallbackMultiModalFieldElem(grid_row), + }, + modality="image", + identifier=_identifier(payload, payload_idx, image_idx), + mm_position=mm_position, + mm_hash=_mm_hash(payload), + ) + ) + return features + + +def _image_payloads(mm_data: Any) -> list[dict[str, Any]]: + if mm_data is None: + return [] + + image_data = _get(mm_data, "image") + if image_data is None and _get(mm_data, "pixel_values") is not None: + image_data = mm_data + if image_data is None: + return [] + + if _is_pixel_grid_pair(image_data): + image_data = [image_data] + elif isinstance(image_data, dict) and "pixel_values" in image_data: + image_data = [image_data] + + payloads: list[dict[str, Any]] = [] + for item in image_data: + if _is_pixel_grid_pair(item): + pixel_values, image_grid_thw = item + payloads.append( + {"pixel_values": pixel_values, "image_grid_thw": image_grid_thw} + ) + else: + payloads.append( + { + "pixel_values": _get(item, "pixel_values"), + "image_grid_thw": _get(item, "image_grid_thw"), + "mm_position": _get(item, "mm_position"), + "mm_positions": _get(item, "mm_positions"), + "offset": _get(item, "offset"), + "identifier": _get(item, "identifier"), + "mm_hash": _get(item, "mm_hash"), + } + ) + + return [ + payload + for payload in payloads + if payload["pixel_values"] is not None and payload["image_grid_thw"] is not None + ] + + +def _is_pixel_grid_pair(value: Any) -> bool: + return isinstance(value, tuple) and len(value) == 2 + + +def _get(value: Any, key: str) -> Any: + if isinstance(value, dict): + return value.get(key) + return getattr(value, key, None) + + +def _grid_rows(image_grid_thw: Any) -> list[Any]: + rows = _to_list(image_grid_thw) + if not rows: + return [] + if all(isinstance(x, int | float) for x in rows): + return [rows] + return rows + + +def _grid_prod(grid_row: Any) -> int: + return int(reduce(mul, (int(x) for x in _to_list(grid_row)), 1)) + + +def _to_list(value: Any) -> list[Any]: + if hasattr(value, "tolist"): + value = value.tolist() + if isinstance(value, tuple): + return list(value) + if isinstance(value, list): + return value + return [value] + + +def _tensor(value: list[int], *, like: Any) -> Any: + try: + import torch + + device = getattr(like, "device", None) + return torch.as_tensor(value, device=device) + except Exception: + return value + + +def _tensor_data(value: Any) -> Any: + try: + import torch + + return torch.as_tensor(value) + except Exception: + return value + + +def _image_grid_tensor(value: Any) -> Any: + tensor = _tensor_data(value) + if hasattr(tensor, "ndim") and tensor.ndim == 1: + return tensor.unsqueeze(0) + return tensor + + +def _placeholder_range( + payload: dict[str, Any], + image_idx: int, + *, + default_offset: int, + default_length: int, + placeholder_cls: Any, +) -> Any: + mm_position = _indexed(_get(payload, "mm_positions"), image_idx) or _get( + payload, "mm_position" + ) + if mm_position is not None: + offset = _get(mm_position, "offset") + length = _get(mm_position, "length") + if offset is not None and length is not None: + return placeholder_cls(offset=int(offset), length=int(length)) + + offset = _indexed(_get(payload, "offset"), image_idx) + if offset is None: + offset = default_offset + return placeholder_cls(offset=int(offset), length=default_length) + + +def _indexed(value: Any, idx: int) -> Any: + if value is None or isinstance(value, str): + return value + if isinstance(value, list | tuple): + return value[idx] if idx < len(value) else None + return value + + +def _identifier(payload: dict[str, Any], payload_idx: int, image_idx: int) -> str: + identifier = _indexed(_get(payload, "identifier"), image_idx) + if identifier is not None: + return str(identifier) + return f"image-{payload_idx}-{image_idx}" + + +def _mm_hash(payload: dict[str, Any]) -> str | None: + value = _get(payload, "mm_hash") + return str(value) if value is not None else None + + async def generate( *, client: AsyncOpenAI, diff --git a/tests/test_client_mm_features.py b/tests/test_client_mm_features.py new file mode 100644 index 0000000..7ed66bd --- /dev/null +++ b/tests/test_client_mm_features.py @@ -0,0 +1,25 @@ +from renderers.client import _build_mm_features +from renderers.qwen35 import Qwen35Renderer + + +def test_build_mm_features_dispatches_qwen35_renderer(): + features = _build_mm_features( + Qwen35Renderer, + { + "image": { + "pixel_values": [[1.0], [2.0], [3.0], [4.0]], + "image_grid_thw": [1, 4, 4], + "offset": 7, + "identifier": "image-0", + } + }, + ) + + assert features is not None + assert len(features) == 1 + feature = features[0] + assert feature.modality == "image" + assert feature.identifier == "image-0" + assert feature.mm_position.offset == 7 + assert feature.mm_position.length == 4 + assert set(feature.data) == {"pixel_values", "image_grid_thw"}