diff --git a/renderers/__init__.py b/renderers/__init__.py index 60cc271..7c82510 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -28,6 +28,7 @@ ToolCallParseStatus, ToolSpec, VideoPart, + attribute_text_segments, build_training_sample, build_trajectory_step, create_renderer, @@ -90,6 +91,7 @@ "ToolSpec", "VideoPart", "__version__", + "attribute_text_segments", "build_training_sample", "build_trajectory_step", "create_renderer", diff --git a/renderers/base.py b/renderers/base.py index 2da9dbf..defb267 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -169,6 +169,32 @@ class RenderedTokens: masking. ``DefaultRenderer`` leaves it empty because the Jinja template is opaque; hand-coded renderers populate it. + ``is_content`` is a per-token signal generalizing the "scaffold vs + body" distinction across all roles: ``True`` iff the token was + produced from message-body bytes (caller-provided ``content`` / + ``tool_calls`` / ``reasoning_content``, or the model's sampled + emission for the assistant role), ``False`` iff it is template + scaffolding the renderer added around message bodies — role-tag + openers, closers when not model-sampled, inter-turn separators, + tool-response wraps, the tools-header block, the generation prompt. + Generalises ``sampled_mask``: where ``sampled_mask`` answers "would + the model emit this?" (useful for assistant tokens; uniformly + ``False`` elsewhere), ``is_content`` answers "is this from caller + or model data?" (meaningful on every role). By construction + ``is_content[k] == sampled_mask[k]`` over every token attributed to + an assistant message; on other roles ``is_content`` carries new + information that ``sampled_mask`` does not. + + The use case: SFT on tool response bodies while applying RL only to + assistant tokens. The trainer wants the model to anticipate tool + outputs but never to emit ``<|tool_response>`` itself (that would + interrupt the rollout), so the SFT loss mask is + ``message_role == "tool" AND is_content``. + + Empty ``is_content`` (``[]``) — like ``sampled_mask`` — means the + renderer doesn't provide the signal. ``DefaultRenderer`` leaves it + empty for the same reason. + ``multi_modal_data`` is populated by multimodal renderers (e.g. ``Qwen3VLRenderer``) when image / video content parts are present; text-only renderers leave it as ``None``. @@ -177,6 +203,7 @@ class RenderedTokens: token_ids: list[int] = field(default_factory=list) message_indices: list[int] = field(default_factory=list) sampled_mask: list[bool] = field(default_factory=list) + is_content: list[bool] = field(default_factory=list) message_roles: list[str] = field(default_factory=list) multi_modal_data: "MultiModalData | None" = None @@ -333,6 +360,94 @@ def tokens_by_role(self, *, sampled_only: bool = False) -> dict[str, int]: out[role] = out.get(role, 0) + n return out + def content_token_spans_by_role(self) -> dict[str, list[tuple[int, int]]]: + """Per-role spans of contiguous body-only tokens (``is_content=True``). + + Maps each role appearing in :attr:`message_roles` to a list of + half-open ``[start, end)`` slices into :attr:`token_ids` over + which every token satisfies ``is_content=True`` AND belongs to + a message of that role. Spans never cross message boundaries: + a tool message contributes its own runs; an immediately + adjacent assistant message contributes separate runs even when + the bodies abut on the token axis. + + Returns an empty dict when :attr:`is_content` or + :attr:`message_roles` is empty (renderer didn't populate the + signal — e.g. ``DefaultRenderer``). + + Intended for selective loss masking: SFT on tool response + bodies while RL acts only on assistant turns is the canonical + case:: + + spans = rendered.content_token_spans_by_role() + tool_sft_mask = [False] * len(rendered.token_ids) + for s, e in spans.get("tool", []): + for k in range(s, e): + tool_sft_mask[k] = True + + See also :meth:`content_mask_for_roles` for the same + computation returned as a per-token bool list. + """ + out: dict[str, list[tuple[int, int]]] = {} + if not self.is_content or not self.message_roles: + return out + n = len(self.token_ids) + if len(self.is_content) != n or len(self.message_indices) != n: + return out + + msg_spans = self.message_token_spans() + for role, span in zip(self.message_roles, msg_spans): + bucket = out.setdefault(role, []) + if span is None: + continue + start, end = span + run_start: int | None = None + for k in range(start, end): + if self.is_content[k]: + if run_start is None: + run_start = k + else: + if run_start is not None: + bucket.append((run_start, k)) + run_start = None + if run_start is not None: + bucket.append((run_start, end)) + return out + + def content_mask_for_roles(self, roles: "set[str] | frozenset[str]") -> list[bool]: + """Per-token bool list: ``True`` iff the token is body of a + message whose role is in ``roles``. + + Length matches :attr:`token_ids`. Returns an all-``False`` + list of that length when :attr:`is_content` or + :attr:`message_roles` is empty — consumers can AND this with + their own attribution masks without length checks. + + ``role_to_mask`` style helpers in :func:`build_training_sample` + cover the trainable-role question; this one covers the + complementary "body-only" question. The two compose: SFT mask + on tool body is + ``rendered.content_mask_for_roles({"tool"})``; RL mask on + assistant tokens stays + ``[s and (mi >= 0 and rendered.message_roles[mi] == "assistant") + for s, mi in zip(rendered.sampled_mask, rendered.message_indices)]``. + """ + n = len(self.token_ids) + mask = [False] * n + if not self.is_content or not self.message_roles: + return mask + if len(self.is_content) != n or len(self.message_indices) != n: + return mask + + for k, msg_idx in enumerate(self.message_indices): + if msg_idx < 0: + continue + if msg_idx >= len(self.message_roles): + continue + if self.message_roles[msg_idx] in roles and self.is_content[k]: + mask[k] = True + return mask + class ToolCallParseStatus(str, enum.Enum): """Per-attempt outcome of parsing a single ```` block. @@ -530,6 +645,15 @@ def bridge_to_next_turn( caller needs that distinction for the prior portion, they have it directly: every token in ``prev_completion_ids`` was sampled; every token in ``prev_prompt_ids`` was not. + - ``is_content`` mirrors ``sampled_mask``'s scheme for the + prior portion (uniformly ``False`` — body-vs-wrap + attribution can't be recovered from raw token ids), and on + the bridge-added portion the renderer populates it the same + way as in :meth:`render`: ``True`` over the body bytes of + each new message, ``False`` over the surrounding scaffold. + Consumers walk the trajectory and read each step's own + ``is_content`` for full-conversation body masks; the bridge + output covers only the *new* tokens this turn adds. Text-only renderers return :class:`RenderedTokens` with ``multi_modal_data=None``. Multimodal renderers (see @@ -1208,6 +1332,7 @@ def build_training_sample( *, role_to_mask: Callable[[Message], bool], tools: list[ToolSpec] | None = None, + content_sft_roles: "set[str] | frozenset[str] | None" = None, ) -> tuple[list[int], list[bool]]: """Build (token_ids, loss_mask) for supervised training. @@ -1223,17 +1348,53 @@ def build_training_sample( back to attribution-only masking — every token attributed to a trainable role is trained on, including template-injected ``<|im_start|>role\\n`` openers. + + ``content_sft_roles`` opts in additional roles for "body-only" + supervision: for every message whose role is in this set, tokens + with ``is_content=True`` are marked trainable even though the + ``sampled_mask`` gate excludes them (the model never samples + tool / user / system tokens). Template scaffolding around those + messages — ``<|im_start|>role\\n`` openers, ``<|im_end|>`` + closers, ``<|tool_response>`` wraps, inter-turn ``\\n`` — stays + masked out, so the model learns to anticipate the body text + without producing the surrounding special tokens (which would + interrupt a real rollout). The canonical use case is RL on + assistant tokens (``role_to_mask=lambda m: m["role"] == + "assistant"``) plus SFT on tool response bodies + (``content_sft_roles={"tool"}``). + + Requires the renderer to populate ``is_content`` for the body-only + path to fire. Renderers that leave it empty (``DefaultRenderer``, + or hand-coded renderers that haven't been wired up yet) ignore + ``content_sft_roles`` silently — falling back to the original + ``role_to_mask`` + ``sampled_mask`` behaviour. """ rendered = renderer.render(messages, tools=tools) has_sampled_info = len(rendered.sampled_mask) == len(rendered.token_ids) + has_content_info = len(rendered.is_content) == len(rendered.token_ids) + body_roles: "frozenset[str]" + if content_sft_roles and has_content_info: + body_roles = frozenset(content_sft_roles) + else: + body_roles = frozenset() + loss_mask: list[bool] = [] for k, msg_idx in enumerate(rendered.message_indices): if msg_idx < 0: loss_mask.append(False) - elif has_sampled_info and not rendered.sampled_mask[k]: + continue + msg = messages[msg_idx] + # Body-only path for opt-in roles. Fires only on tokens whose + # is_content bit is set; never adds the scaffolding around the + # message, so the model isn't supervised on emitting the role + # tags / wraps that would derail a rollout. + if body_roles and msg.get("role") in body_roles: + loss_mask.append(rendered.is_content[k]) + continue + if has_sampled_info and not rendered.sampled_mask[k]: loss_mask.append(False) else: - loss_mask.append(role_to_mask(messages[msg_idx])) + loss_mask.append(role_to_mask(msg)) return rendered.token_ids, loss_mask @@ -1280,6 +1441,157 @@ def trim_to_turn_close( return previous_ids +# Per-model offset-aware tokenizer cache. ``attribute_text_segments`` +# uses the fast HuggingFace tokenizer's ``offset_mapping`` to attribute +# each token to its source text segment under one BPE pass. Fastokens +# (the Rust BPE we patch in by default for ~10x faster encode) does not +# track character offsets — the patched tokenizer's +# ``return_offsets_mapping=True`` raises ``NotImplementedError``. So we +# keep a parallel vanilla tokenizer per model purely for offset queries. +# Memory cost is one extra tokenizer per *unique* model name across all +# pools / renderers (the cache is process-global), independent of pool +# size. +_offset_tokenizers: dict[str, Any] = {} +_offset_tokenizers_lock = threading.Lock() + + +def _get_offset_tokenizer(tokenizer): + """Return a tokenizer that supports ``return_offsets_mapping=True``. + + If ``tokenizer`` itself supports offsets, returns it unchanged. + Otherwise loads a vanilla (non-fastokens) tokenizer from + ``tokenizer.name_or_path`` and caches it. Raises if the tokenizer + has no usable ``name_or_path`` — hand-coded renderers always pass + a tokenizer loaded via ``load_tokenizer`` which does set it. + """ + # Cheap probe: does this tokenizer already provide offsets? + try: + tokenizer("a", add_special_tokens=False, return_offsets_mapping=True) + return tokenizer + except (NotImplementedError, ValueError, TypeError): + pass + + name_or_path = getattr(tokenizer, "name_or_path", "") + if not name_or_path: + raise RuntimeError( + "Cannot construct an offset-aware tokenizer: the supplied " + "tokenizer has no ``name_or_path`` to fall back on. Pass a " + "tokenizer loaded via ``renderers.base.load_tokenizer``." + ) + + with _offset_tokenizers_lock: + cached = _offset_tokenizers.get(name_or_path) + if cached is not None: + return cached + from transformers import AutoTokenizer + + kwargs: dict[str, Any] = {} + revision = TRUSTED_REVISIONS.get(name_or_path) + if revision is not None: + kwargs = {"trust_remote_code": True, "revision": revision} + else: + kwargs = {"trust_remote_code": False} + # Explicitly vanilla — we want HF's Rust tokenizer with offset + # tracking, not the fastokens shim. ``load_tokenizer`` would + # patch fastokens in by default; calling + # ``AutoTokenizer.from_pretrained`` directly here keeps the + # fastokens patch out of this code path entirely. + offset_tok = AutoTokenizer.from_pretrained(name_or_path, **kwargs) + if not getattr(offset_tok, "is_fast", False): + raise RuntimeError( + f"Vanilla tokenizer for {name_or_path!r} is not a fast " + "tokenizer; offset_mapping is unavailable. Hand-coded " + "renderers require a fast tokenizer for body/scaffold " + "attribution." + ) + _offset_tokenizers[name_or_path] = offset_tok + return offset_tok + + +def attribute_text_segments( + tokenizer, + segments: "list[tuple[str, bool]]", +) -> "list[tuple[int, bool]]": + """Tokenize concatenated segments as a single BPE pass and return + ``(token_id, is_content)`` pairs. + + ``segments`` is a list of ``(text, is_content)`` chunks the renderer + wants to emit contiguously — for example ``[("user\\n", False), + (content, True)]`` for a user message. Concatenation is done before + encoding to preserve BPE merges across the wrap/body boundary; the + resulting tokens are then attributed back to their source segment + via the fast tokenizer's ``offset_mapping``. + + A token is attributed to the segment containing its first source + character (``offset_mapping[k][0]``). Tokens whose first character + falls exactly on a segment boundary are attributed to the segment + that *starts* at that offset (the "later" segment). Zero-length + tokens (rare; usually pre-tokenizer artefacts) are attributed to + the most recently entered segment. + + Requires a HuggingFace fast tokenizer with offset tracking. The + ``fastokens`` patch ``load_tokenizer`` applies by default does + **not** track offsets — when that's the case we transparently load + a vanilla offset-capable tokenizer for the same model and cache it + (see :func:`_get_offset_tokenizer`). Hand-coded renderers are only + registered for model families that ship a fast tokenizer, so a + silent slow-tokenizer fallback isn't supported — BPE drift at the + wrap/body boundary would defeat the whole point. + + Empty input or empty joined text returns an empty list. + """ + if not segments: + return [] + full_text = "".join(text for text, _ in segments) + if not full_text: + return [] + + offset_tokenizer = _get_offset_tokenizer(tokenizer) + encoding = offset_tokenizer( + full_text, + add_special_tokens=False, + return_offsets_mapping=True, + ) + token_ids = list(encoding["input_ids"]) + offsets = list(encoding["offset_mapping"]) + + # Build segment char-span lookup. Track the half-open span + # [seg_start, seg_end) of each segment and its is_content bit. + spans: list[tuple[int, int, bool]] = [] + pos = 0 + for text, is_content in segments: + spans.append((pos, pos + len(text), is_content)) + pos += len(text) + total_len = pos + + out: list[tuple[int, bool]] = [] + last_is_content = spans[-1][2] if spans else False + for tok_id, (start, _end) in zip(token_ids, offsets): + if start >= total_len: + # Token's character offset is past every segment (shouldn't + # normally happen for add_special_tokens=False, but defensive + # against tokenizer-specific edge cases). + out.append((tok_id, last_is_content)) + continue + # Find the segment that contains `start`. Segments are + # contiguous and ordered, so a linear scan is fine — the inner + # loop runs at most len(segments) times per token and segments + # is typically 2-3 in practice. + is_content = last_is_content + for seg_start, seg_end, seg_is_content in spans: + if seg_start <= start < seg_end: + is_content = seg_is_content + break + else: + # start == total_len handled above; the remaining case is + # an empty segment in the middle. Empty segments emit no + # characters, so no token can land in them; fall through to + # the last non-empty segment's bit. + pass + out.append((tok_id, is_content)) + return out + + def reject_assistant_in_extension(new_messages: list[Message]) -> bool: """Return True if any message in ``new_messages`` is an assistant turn. diff --git a/renderers/client.py b/renderers/client.py index ae1722f..576e5c2 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -23,6 +23,7 @@ from renderers.base import ( Message, MultiModalData, + RenderedTokens, Renderer, RendererPool, ToolCallParseStatus, @@ -127,6 +128,7 @@ async def generate( model: str, prompt_ids: list[int] | None = None, multi_modal_data: MultiModalData | None = None, + prompt_attribution: RenderedTokens | None = None, tools: list[ToolSpec] | None = None, sampling_params: dict[str, Any] | None = None, cache_salt: str | None = None, @@ -141,7 +143,11 @@ async def generate( renderer) and ``logprobs=1`` (we always emit completion_logprobs). Pass ``prompt_ids`` to skip rendering and use a prebuilt token sequence — pair it with ``multi_modal_data`` when the prebuilt prompt has image / - video placeholders that need engine-side mm payload. + video placeholders that need engine-side mm payload, and with + ``prompt_attribution`` (a :class:`RenderedTokens` whose ``token_ids`` + match the passed-in ``prompt_ids``) to carry the renderer's per-token + attribution (``is_content`` / ``sampled_mask`` / ``message_indices`` / + ``message_roles``) into the result without re-rendering. For multimodal renderers (e.g. ``Qwen3VLRenderer``), the call goes through ``renderer.render(...)`` to recover the ``multi_modal_data`` @@ -161,7 +167,19 @@ async def generate( Returns a dict with: request_id, prompt_ids, completion_ids, completion_logprobs, content, reasoning_content, tool_calls, - finish_reason, routed_experts. + finish_reason, routed_experts, multi_modal_data, prompt_attribution. + + ``prompt_attribution`` is the renderer's :class:`RenderedTokens` for + the prompt — either the one this call computed via + ``renderer.render(...)`` or the one the caller threaded in alongside + ``prompt_ids``. Carries ``token_ids``, ``message_indices``, + ``sampled_mask``, ``is_content``, ``message_roles``, and + ``multi_modal_data``, so downstream consumers (verifiers + ``RendererClient`` → prime-rl) can build per-token loss masks + (``content_mask_for_roles({"tool"})`` for SFT-on-tool-body, + ``sampled_mask`` for RL trainable spans) without a second render + pass. ``None`` when the caller passed pre-built ``prompt_ids`` + without attribution. """ if tools and not getattr(renderer, "supports_tools", True): raise ValueError( @@ -171,15 +189,26 @@ async def generate( def _prepare(): if prompt_ids is not None: - return list(prompt_ids), renderer.get_stop_token_ids(), multi_modal_data + # Caller-supplied prompt; if they also gave us pre-computed + # attribution (e.g. the bridge path in verifiers), thread it + # through unchanged. + return ( + list(prompt_ids), + renderer.get_stop_token_ids(), + multi_modal_data, + prompt_attribution, + ) rendered = renderer.render(messages, tools=tools, add_generation_prompt=True) return ( rendered.token_ids, renderer.get_stop_token_ids(), rendered.multi_modal_data, + rendered, ) - prompt_ids, stop_token_ids, mm_data = await _maybe_offload(renderer, _prepare) + prompt_ids, stop_token_ids, mm_data, prompt_attr = await _maybe_offload( + renderer, _prepare + ) if max_prompt_len is None: max_prompt_len = await _resolve_max_prompt_len(client, model) @@ -279,6 +308,14 @@ def _prepare(): # callers can persist it on the trajectory step for downstream # multi-turn bridging and training-sample construction. "multi_modal_data": mm_data, + # The renderer's per-token attribution for the prompt — either + # the RenderedTokens computed here via renderer.render(...) or + # the one threaded in by the caller alongside prompt_ids (the + # bridge path). Lets downstream consumers (verifiers + # RendererClient → prime-rl) build SFT-on-tool-body and other + # selective loss masks without a second render pass. ``None`` + # when the caller passed prompt_ids without attribution. + "prompt_attribution": prompt_attr, } diff --git a/renderers/deepseek_v3.py b/renderers/deepseek_v3.py index b3c843b..507d81d 100644 --- a/renderers/deepseek_v3.py +++ b/renderers/deepseek_v3.py @@ -21,6 +21,7 @@ ParsedResponse, RenderedTokens, ToolSpec, + attribute_text_segments, reject_assistant_in_extension, trim_to_turn_close, ) @@ -114,22 +115,47 @@ def render( tokens: list[int] = [] indices: list[int] = [] sampled: list[bool] = [] + content_mask: list[bool] = [] - def emit_ids(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None: + def emit_ids( + ids: list[int], msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) - def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None: + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) + content_mask.append(is_content) - def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: - emit_ids(self._encode(text), msg_idx, is_sampled=is_sampled) + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: + emit_ids( + self._encode(text), + msg_idx, + is_sampled=is_sampled, + is_content=is_content, + ) + + def emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool + ) -> None: + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + tokens.append(tok_id) + indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(is_content) # ── 1. BOS token ───────────────────────────────────────────── - emit_special(self._bos, -1, is_sampled=False) + emit_special(self._bos, -1, is_sampled=False, is_content=False) # ── 2. Collect system messages at the start ─────────────────── # All leading system messages are concatenated with "\n\n" and emitted @@ -151,7 +177,8 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: if sys_parts: # Attribute the concatenated system text to the first system message (index 0). - emit_text("\n\n".join(sys_parts), 0, is_sampled=False) + # The system content is the caller's body — mark is_content=True. + emit_text("\n\n".join(sys_parts), 0, is_sampled=False, is_content=True) # ── 3. Render non-system messages ───────────────────────────── num_messages = len(messages) @@ -166,8 +193,8 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: content = "".join( p.get("text", "") for p in content if isinstance(p, dict) ) - emit_special(self._user_token, i, is_sampled=False) - emit_text(str(content), i, is_sampled=False) + emit_special(self._user_token, i, is_sampled=False, is_content=False) + emit_text(str(content), i, is_sampled=False, is_content=True) elif role == "user": content = msg.get("content") or "" @@ -180,8 +207,8 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: else "" for p in content ) - emit_special(self._user_token, i, is_sampled=False) - emit_text(str(content), i, is_sampled=False) + emit_special(self._user_token, i, is_sampled=False, is_content=False) + emit_text(str(content), i, is_sampled=False, is_content=True) elif role == "assistant": self._render_assistant( @@ -190,6 +217,7 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: messages, emit_special=emit_special, emit_text=emit_text, + emit_text_segments=emit_text_segments, ) elif role == "tool": @@ -198,6 +226,7 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: i, emit_special=emit_special, emit_text=emit_text, + emit_text_segments=emit_text_segments, ) # ── 4. Generation prompt ────────────────────────────────────── @@ -205,14 +234,17 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: # Don't add <|Assistant|> after tool outputs — content flows directly. last_role = messages[-1]["role"] if messages else None if last_role != "tool": - emit_special(self._assistant_token, -1, is_sampled=False) + emit_special( + self._assistant_token, -1, is_sampled=False, is_content=False + ) if self._enable_thinking: - emit_text("\n", -1, is_sampled=False) + emit_text("\n", -1, is_sampled=False, is_content=False) return RenderedTokens( token_ids=tokens, message_indices=indices, sampled_mask=sampled, + is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], ) @@ -276,27 +308,38 @@ def bridge_to_next_turn( ext: list[int] = [] ext_indices: list[int] = [] ext_sampled: list[bool] = [] + ext_content: list[bool] = [] # Bridge populates ``message_indices`` (relative to ``new_messages``) # and ``sampled_mask`` (uniformly ``False`` — every token the # bridge emits is template scaffolding for the next prompt, not - # something the model sampled). Downstream consumers can run - # :meth:`RenderedTokens.tokens_per_message` on the bridge output - # to get per-new-message token counts without re-rendering. + # something the model sampled). ``is_content`` follows the same + # rules as in :meth:`render` so consumers can walk the trajectory + # and read each step's own body mask. def emit_special( - token_id: int, msg_idx: int = -1, *, is_sampled: bool = False + token_id: int, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ext.append(token_id) ext_indices.append(msg_idx) ext_sampled.append(is_sampled) + ext_content.append(is_content) def emit_text( - text: str, msg_idx: int = -1, *, is_sampled: bool = False + text: str, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ids = self._encode(text) ext.extend(ids) ext_indices.extend([msg_idx] * len(ids)) ext_sampled.extend([is_sampled] * len(ids)) + ext_content.extend([is_content] * len(ids)) for i, msg in enumerate(new_messages): role = msg.get("role") @@ -309,11 +352,11 @@ def emit_text( if role == "user": emit_special(self._user_token, i) - emit_text(content, i) + emit_text(content, i, is_content=True) elif role == "system": # Post-initial system messages render as user turns. emit_special(self._user_token, i) - emit_text(content, i) + emit_text(content, i, is_content=True) elif role == "tool": prev_is_tool = i > 0 and new_messages[i - 1].get("role") == "tool" next_is_tool = ( @@ -323,7 +366,7 @@ def emit_text( if not prev_is_tool: emit_special(self._tool_outputs_begin, i) emit_special(self._tool_output_begin, i) - emit_text(content, i) + emit_text(content, i, is_content=True) emit_special(self._tool_output_end, i) if not next_is_tool: emit_special(self._tool_outputs_end, i) @@ -344,6 +387,7 @@ def emit_text( token_ids=previous_ids + ext, message_indices=[-1] * len(previous_ids) + ext_indices, sampled_mask=[False] * total_len, + is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], ) @@ -359,6 +403,7 @@ def _render_assistant( *, emit_special, emit_text, + emit_text_segments, ) -> None: # Determine whether this message follows a tool output sequence. # The HF template emits <|tool▁outputs▁end|> before the assistant content @@ -391,18 +436,23 @@ def _render_assistant( # keeps the SFT loss mask aligned with what the model would # actually have produced. When the previous message is a tool # response, the template skips this token entirely (content - # flows directly out of ``<|tool▁outputs▁end|>``). + # flows directly out of ``<|tool▁outputs▁end|>``). On assistant + # the invariant ``is_content == sampled_mask`` holds. if not prev_is_tool: - emit_special(self._assistant_token, msg_idx, is_sampled=False) + emit_special( + self._assistant_token, msg_idx, is_sampled=False, is_content=False + ) if not tool_calls: - emit_text(content, msg_idx, is_sampled=True) + emit_text(content, msg_idx, is_sampled=True, is_content=True) else: # Emit any pre-tool-call content first. - emit_text(content, msg_idx, is_sampled=True) + emit_text(content, msg_idx, is_sampled=True, is_content=True) # Tool call section. - emit_special(self._tool_calls_begin, msg_idx, is_sampled=True) + emit_special( + self._tool_calls_begin, msg_idx, is_sampled=True, is_content=True + ) for tc in tool_calls: func = tc.get("function") or tc name = func.get("name", "") @@ -414,17 +464,28 @@ def _render_assistant( ) # Format: <|tool▁call▁begin|>function<|tool▁sep|>{name}\n```json\n{args}\n```<|tool▁call▁end|> # tool_sep is a special token; type ("function") and name+args are plain text. - emit_special(self._tool_call_begin, msg_idx, is_sampled=True) - emit_text("function", msg_idx, is_sampled=True) - emit_special(self._tool_sep, msg_idx, is_sampled=True) - emit_text(f"{name}\n```json\n{args_str}\n```", msg_idx, is_sampled=True) - emit_special(self._tool_call_end, msg_idx, is_sampled=True) - emit_special(self._tool_calls_end, msg_idx, is_sampled=True) + emit_special( + self._tool_call_begin, msg_idx, is_sampled=True, is_content=True + ) + emit_text("function", msg_idx, is_sampled=True, is_content=True) + emit_special(self._tool_sep, msg_idx, is_sampled=True, is_content=True) + emit_text( + f"{name}\n```json\n{args_str}\n```", + msg_idx, + is_sampled=True, + is_content=True, + ) + emit_special( + self._tool_call_end, msg_idx, is_sampled=True, is_content=True + ) + emit_special( + self._tool_calls_end, msg_idx, is_sampled=True, is_content=True + ) # ``<|end▁of▁sentence|>`` is the model's stop signal — it # samples this to end its turn, so it is part of the sampled # stream. - emit_special(self._eos, msg_idx, is_sampled=True) + emit_special(self._eos, msg_idx, is_sampled=True, is_content=True) # ------------------------------------------------------------------ # Tool (tool-response) rendering @@ -437,10 +498,13 @@ def _render_tool( *, emit_special, emit_text, + emit_text_segments, ) -> None: # Tool messages are conversation history injected by the runtime # between assistant turns — the model never samples any of these - # tokens, so every emission is is_sampled=False. + # tokens, so every emission is is_sampled=False. The ``content`` + # body bytes get ``is_content=True``; the surrounding section + # specials are scaffold. prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool" next_is_tool = ( msg_idx + 1 < len(messages) and messages[msg_idx + 1]["role"] == "tool" @@ -451,11 +515,17 @@ def _render_tool( content = "".join(p.get("text", "") for p in content if isinstance(p, dict)) if not prev_is_tool: - emit_special(self._tool_outputs_begin, msg_idx, is_sampled=False) + emit_special( + self._tool_outputs_begin, msg_idx, is_sampled=False, is_content=False + ) - emit_special(self._tool_output_begin, msg_idx, is_sampled=False) - emit_text(str(content), msg_idx, is_sampled=False) - emit_special(self._tool_output_end, msg_idx, is_sampled=False) + emit_special( + self._tool_output_begin, msg_idx, is_sampled=False, is_content=False + ) + emit_text(str(content), msg_idx, is_sampled=False, is_content=True) + emit_special(self._tool_output_end, msg_idx, is_sampled=False, is_content=False) if not next_is_tool: - emit_special(self._tool_outputs_end, msg_idx, is_sampled=False) + emit_special( + self._tool_outputs_end, msg_idx, is_sampled=False, is_content=False + ) diff --git a/renderers/glm45.py b/renderers/glm45.py index 4a80ab3..206f366 100644 --- a/renderers/glm45.py +++ b/renderers/glm45.py @@ -20,6 +20,7 @@ ParsedResponse, RenderedTokens, ToolSpec, + attribute_text_segments, reject_assistant_in_extension, should_preserve_past_thinking, ) @@ -128,30 +129,58 @@ def render( tokens: list[int] = [] indices: list[int] = [] sampled: list[bool] = [] + content_mask: list[bool] = [] - def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None: + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) + content_mask.append(is_content) - def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: ids = self._encode(text) tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool + ) -> None: + """Tokenize concatenated segments as one BPE pass; per-token + ``is_content`` follows each token's source segment. + + Lets call sites express "this wrap + this body, joined the + same way as the chat template, but attributed separately" + without splitting the encode call (which could shift BPE + merges at the boundary).""" + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + tokens.append(tok_id) + indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(is_content) # ── Prefix ────────────────────────────────────────────────── - emit_special(self._gmask, -1, is_sampled=False) - emit_special(self._sop, -1, is_sampled=False) + emit_special(self._gmask, -1, is_sampled=False, is_content=False) + emit_special(self._sop, -1, is_sampled=False, is_content=False) # ── Tools in system prompt ────────────────────────────────── + # The tools-header block is all scaffold by design — the tools + # dict is recoverable from the ``tools`` argument; don't + # re-attribute the embedded JSON specs as message body. if tools: - emit_special(self._system, -1, is_sampled=False) + emit_special(self._system, -1, is_sampled=False, is_content=False) tool_text = _TOOLS_HEADER for tool in tools: tool_text += json.dumps(tool, ensure_ascii=False) + "\n" tool_text += _TOOLS_FOOTER - emit_text(tool_text, -1, is_sampled=False) + emit_text(tool_text, -1, is_sampled=False, is_content=False) # ── Compute last_user_index ───────────────────────────────── last_ui = self._last_user_index(messages) @@ -162,15 +191,22 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: content = self._visible_text(msg.get("content")) if role == "system": - emit_special(self._system, i, is_sampled=False) - emit_text("\n" + content, i, is_sampled=False) + emit_special(self._system, i, is_sampled=False, is_content=False) + # ``\n`` is the scaffold separator after the role tag; + # the body proper is the caller-provided content. + emit_text_segments( + [("\n", False), (content, True)], i, is_sampled=False + ) elif role == "user": - emit_special(self._user, i, is_sampled=False) - user_text = "\n" + content + emit_special(self._user, i, is_sampled=False, is_content=False) + # ``\n`` is scaffold; ``content`` is body; the optional + # ``/nothink`` suffix is scaffold the renderer injects + # when ``enable_thinking=False``. + user_segments: list[tuple[str, bool]] = [("\n", False), (content, True)] if not self._enable_thinking and not content.endswith("/nothink"): - user_text += "/nothink" - emit_text(user_text, i, is_sampled=False) + user_segments.append(("/nothink", False)) + emit_text_segments(user_segments, i, is_sampled=False) elif role == "assistant": preserve_thinking = should_preserve_past_thinking( @@ -187,25 +223,32 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: preserve_thinking=preserve_thinking, emit_special=emit_special, emit_text=emit_text, + emit_text_segments=emit_text_segments, ) elif role == "tool": self._render_tool( - messages, i, content, emit_special=emit_special, emit_text=emit_text + messages, + i, + content, + emit_special=emit_special, + emit_text=emit_text, + emit_text_segments=emit_text_segments, ) # ── Generation prompt ─────────────────────────────────────── if add_generation_prompt: - emit_special(self._assistant, -1, is_sampled=False) + emit_special(self._assistant, -1, is_sampled=False, is_content=False) if not self._enable_thinking: - emit_text("\n", -1, is_sampled=False) - emit_special(self._think, -1, is_sampled=False) - emit_special(self._think_end, -1, is_sampled=False) + emit_text("\n", -1, is_sampled=False, is_content=False) + emit_special(self._think, -1, is_sampled=False, is_content=False) + emit_special(self._think_end, -1, is_sampled=False, is_content=False) return RenderedTokens( token_ids=tokens, message_indices=indices, sampled_mask=sampled, + is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], ) @@ -276,27 +319,54 @@ def bridge_to_next_turn( ext: list[int] = [] ext_indices: list[int] = [] ext_sampled: list[bool] = [] + ext_content: list[bool] = [] # Bridge populates ``message_indices`` (relative to ``new_messages``) # and ``sampled_mask`` (uniformly ``False`` — every token the # bridge emits is template scaffolding for the next prompt, not - # something the model sampled). Downstream consumers can run - # :meth:`RenderedTokens.tokens_per_message` on the bridge output - # to get per-new-message token counts without re-rendering. + # something the model sampled). ``is_content`` follows the same + # rules as in :meth:`render` so consumers can walk the trajectory + # and read each step's own body mask. Downstream consumers can + # run :meth:`RenderedTokens.tokens_per_message` on the bridge + # output to get per-new-message token counts without re-rendering. def emit_special( - token_id: int, msg_idx: int = -1, *, is_sampled: bool = False + token_id: int, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ext.append(token_id) ext_indices.append(msg_idx) ext_sampled.append(is_sampled) + ext_content.append(is_content) def emit_text( - text: str, msg_idx: int = -1, *, is_sampled: bool = False + text: str, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ids = self._encode(text) ext.extend(ids) ext_indices.extend([msg_idx] * len(ids)) ext_sampled.extend([is_sampled] * len(ids)) + ext_content.extend([is_content] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], + msg_idx: int = -1, + *, + is_sampled: bool = False, + ) -> None: + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + ext.append(tok_id) + ext_indices.append(msg_idx) + ext_sampled.append(is_sampled) + ext_content.append(is_content) for i, msg in enumerate(new_messages): role = msg.get("role") @@ -304,20 +374,30 @@ def emit_text( if role == "user": if not (i == 0 and last_prev == self._user): emit_special(self._user, i) - user_text = "\n" + content + user_segments: list[tuple[str, bool]] = [ + ("\n", False), + (content, True), + ] if not self._enable_thinking and not content.endswith("/nothink"): - user_text += "/nothink" - emit_text(user_text, i) + user_segments.append(("/nothink", False)) + emit_text_segments(user_segments, i) elif role == "system": emit_special(self._system, i) - emit_text("\n" + content, i) + emit_text_segments([("\n", False), (content, True)], i) elif role == "tool": prev_is_tool = i > 0 and new_messages[i - 1].get("role") == "tool" if i == 0 and last_prev == self._observation: pass elif not prev_is_tool: emit_special(self._observation, i) - emit_text("\n\n" + content + "\n", i) + emit_text_segments( + [ + ("\n\n", False), + (content, True), + ("\n", False), + ], + i, + ) else: return None @@ -333,6 +413,7 @@ def emit_text( token_ids=previous_ids + ext, message_indices=[-1] * len(previous_ids) + ext_indices, sampled_mask=[False] * total_len, + is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], ) @@ -346,6 +427,7 @@ def _render_assistant( preserve_thinking: bool = False, emit_special, emit_text, + emit_text_segments, ): reasoning_content = "" if isinstance(msg.get("reasoning_content"), str): @@ -373,23 +455,31 @@ def _render_assistant( # turn). So no sampled stop-signal token lives inside this # assistant span — content / think / tool_calls carry the # is_sampled=True signal. - emit_special(self._assistant, msg_idx, is_sampled=False) - emit_text("\n", msg_idx, is_sampled=False) + # + # Invariant on assistant tokens: ``is_content == sampled_mask``. + # Every scaffold token here gets ``is_sampled=False/is_content=False``; + # every model-sampled emit gets both True. + emit_special(self._assistant, msg_idx, is_sampled=False, is_content=False) + emit_text("\n", msg_idx, is_sampled=False, is_content=False) if (msg_idx > last_user_index or preserve_thinking) and reasoning_content: - emit_special(self._think, msg_idx, is_sampled=True) - emit_text(reasoning_content.strip(), msg_idx, is_sampled=True) - emit_special(self._think_end, msg_idx, is_sampled=True) + emit_special(self._think, msg_idx, is_sampled=True, is_content=True) + emit_text( + reasoning_content.strip(), msg_idx, is_sampled=True, is_content=True + ) + emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) else: - emit_special(self._think, msg_idx, is_sampled=True) - emit_special(self._think_end, msg_idx, is_sampled=True) + emit_special(self._think, msg_idx, is_sampled=True, is_content=True) + emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) # Tool calls — keep content + \n contiguous to preserve BPE merges tool_calls = msg.get("tool_calls") or [] if content.strip() and tool_calls: - emit_text("\n" + content.strip() + "\n", msg_idx, is_sampled=True) + emit_text( + "\n" + content.strip() + "\n", msg_idx, is_sampled=True, is_content=True + ) elif content.strip(): - emit_text("\n" + content.strip(), msg_idx, is_sampled=True) + emit_text("\n" + content.strip(), msg_idx, is_sampled=True, is_content=True) for tc in tool_calls: func = tc.get("function") or tc @@ -397,9 +487,9 @@ def _render_assistant( arguments = func.get("arguments", {}) if not content.strip(): - emit_text("\n", msg_idx, is_sampled=True) - emit_special(self._tool_call_tok, msg_idx, is_sampled=True) - emit_text(name + "\n", msg_idx, is_sampled=True) + emit_text("\n", msg_idx, is_sampled=True, is_content=True) + emit_special(self._tool_call_tok, msg_idx, is_sampled=True, is_content=True) + emit_text(name + "\n", msg_idx, is_sampled=True, is_content=True) # OpenAI canonical form: arguments is a JSON string. Parse it so the # per-argument rendering below still works. if isinstance(arguments, str): @@ -409,22 +499,33 @@ def _render_assistant( arguments = {} if isinstance(arguments, dict): for arg_name, arg_value in arguments.items(): - emit_special(self._arg_key, msg_idx, is_sampled=True) - emit_text(arg_name, msg_idx, is_sampled=True) - emit_special(self._arg_key_end, msg_idx, is_sampled=True) - emit_text("\n", msg_idx, is_sampled=True) - emit_special(self._arg_value, msg_idx, is_sampled=True) + emit_special( + self._arg_key, msg_idx, is_sampled=True, is_content=True + ) + emit_text(arg_name, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._arg_key_end, msg_idx, is_sampled=True, is_content=True + ) + emit_text("\n", msg_idx, is_sampled=True, is_content=True) + emit_special( + self._arg_value, msg_idx, is_sampled=True, is_content=True + ) if isinstance(arg_value, str): - emit_text(arg_value, msg_idx, is_sampled=True) + emit_text(arg_value, msg_idx, is_sampled=True, is_content=True) else: emit_text( json.dumps(arg_value, ensure_ascii=False), msg_idx, is_sampled=True, + is_content=True, ) - emit_special(self._arg_value_end, msg_idx, is_sampled=True) - emit_text("\n", msg_idx, is_sampled=True) - emit_special(self._tool_call_end_tok, msg_idx, is_sampled=True) + emit_special( + self._arg_value_end, msg_idx, is_sampled=True, is_content=True + ) + emit_text("\n", msg_idx, is_sampled=True, is_content=True) + emit_special( + self._tool_call_end_tok, msg_idx, is_sampled=True, is_content=True + ) def _render_tool( self, @@ -434,17 +535,30 @@ def _render_tool( *, emit_special, emit_text, + emit_text_segments, ) -> None: # Tool messages are conversation history injected by the runtime # between assistant turns — the model never samples any of these - # tokens, so every emission is is_sampled=False. + # tokens, so every emission is is_sampled=False. The body bytes + # get ``is_content=True``; the ``\n\n`` / + # ``\n`` wraps and the ``<|observation|>`` role + # tag are scaffold so the SFT mask for tool body never trains + # the model to emit them. Single BPE pass over the joined text + # preserves boundary merges (the tool body's leading/trailing + # chars can merge with the wrap's ``\n``s if the tokenizer would + # do so; we route through ``emit_text_segments`` so the + # attribution is offset-driven and tokenizer-agnostic). prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool" if not prev_is_tool: - emit_special(self._observation, msg_idx, is_sampled=False) - - emit_text( - "\n\n" + content + "\n", + emit_special(self._observation, msg_idx, is_sampled=False, is_content=False) + + emit_text_segments( + [ + ("\n\n", False), + (content, True), + ("\n", False), + ], msg_idx, is_sampled=False, ) diff --git a/renderers/glm5.py b/renderers/glm5.py index 63599f2..6de6ba3 100644 --- a/renderers/glm5.py +++ b/renderers/glm5.py @@ -21,6 +21,7 @@ ParsedResponse, RenderedTokens, ToolSpec, + attribute_text_segments, reject_assistant_in_extension, should_preserve_past_thinking, ) @@ -143,32 +144,61 @@ def render( tokens: list[int] = [] indices: list[int] = [] sampled: list[bool] = [] + content_mask: list[bool] = [] - def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None: + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) + content_mask.append(is_content) - def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: ids = self._encode(text) tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool + ) -> None: + """Tokenize concatenated segments as one BPE pass; per-token + ``is_content`` follows each token's source segment. + + Lets call sites express "this wrap + this body, joined the + same way as the chat template, but attributed separately" + without splitting the encode call (which could shift BPE + merges at the boundary).""" + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + tokens.append(tok_id) + indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(is_content) # ── Prefix ────────────────────────────────────────────────── # ``[gMASK]`` is unconditional template scaffolding at the - # very start of the stream — the model never samples these. - emit_special(self._gmask, -1, is_sampled=False) - emit_special(self._sop, -1, is_sampled=False) + # very start of the stream — the model never samples these and + # they are not part of any message body. + emit_special(self._gmask, -1, is_sampled=False, is_content=False) + emit_special(self._sop, -1, is_sampled=False, is_content=False) # ── Tools in system prompt ────────────────────────────────── + # The tools-header block is all scaffold by design — the tools + # dict is recoverable from the ``tools`` argument; don't + # re-attribute the embedded JSON specs as message body. if tools: - emit_special(self._system, -1, is_sampled=False) + emit_special(self._system, -1, is_sampled=False, is_content=False) tool_text = _TOOLS_HEADER for tool in tools: tool_text += self._format_tool_spec(tool) + "\n" tool_text += _TOOLS_FOOTER - emit_text(tool_text, -1, is_sampled=False) + emit_text(tool_text, -1, is_sampled=False, is_content=False) # ── Compute last_user_index ───────────────────────────────── last_ui = self._last_user_index(messages) @@ -179,12 +209,12 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: content = self._visible_text(msg.get("content")) if role == "system": - emit_special(self._system, i, is_sampled=False) - emit_text(content, i, is_sampled=False) + emit_special(self._system, i, is_sampled=False, is_content=False) + emit_text(content, i, is_sampled=False, is_content=True) elif role == "user": - emit_special(self._user, i, is_sampled=False) - emit_text(content, i, is_sampled=False) + emit_special(self._user, i, is_sampled=False, is_content=False) + emit_text(content, i, is_sampled=False, is_content=True) elif role == "assistant": preserve_thinking = should_preserve_past_thinking( @@ -201,28 +231,35 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: preserve_thinking=preserve_thinking, emit_special=emit_special, emit_text=emit_text, + emit_text_segments=emit_text_segments, ) elif role == "tool": self._render_tool( - messages, i, content, emit_special=emit_special, emit_text=emit_text + messages, + i, + content, + emit_special=emit_special, + emit_text=emit_text, + emit_text_segments=emit_text_segments, ) # ── Generation prompt ─────────────────────────────────────── # Gen prompt tokens are what the chat template prepends before # sampling starts — the model continues from these, never emits - # them. Always is_sampled=False. + # them. Always is_sampled=False / is_content=False. if add_generation_prompt: - emit_special(self._assistant, -1, is_sampled=False) + emit_special(self._assistant, -1, is_sampled=False, is_content=False) if self._enable_thinking: - emit_special(self._think, -1, is_sampled=False) + emit_special(self._think, -1, is_sampled=False, is_content=False) else: - emit_special(self._think_end, -1, is_sampled=False) + emit_special(self._think_end, -1, is_sampled=False, is_content=False) return RenderedTokens( token_ids=tokens, message_indices=indices, sampled_mask=sampled, + is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], ) @@ -297,27 +334,54 @@ def bridge_to_next_turn( ext: list[int] = [] ext_indices: list[int] = [] ext_sampled: list[bool] = [] + ext_content: list[bool] = [] # Bridge populates ``message_indices`` (relative to ``new_messages``) # and ``sampled_mask`` (uniformly ``False`` — every token the # bridge emits is template scaffolding for the next prompt, not - # something the model sampled). Downstream consumers can run - # :meth:`RenderedTokens.tokens_per_message` on the bridge output - # to get per-new-message token counts without re-rendering. + # something the model sampled). ``is_content`` follows the same + # rules as in :meth:`render` so consumers can walk the trajectory + # and read each step's own body mask. Downstream consumers can + # run :meth:`RenderedTokens.tokens_per_message` on the bridge + # output to get per-new-message token counts without re-rendering. def emit_special( - token_id: int, msg_idx: int = -1, *, is_sampled: bool = False + token_id: int, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ext.append(token_id) ext_indices.append(msg_idx) ext_sampled.append(is_sampled) + ext_content.append(is_content) def emit_text( - text: str, msg_idx: int = -1, *, is_sampled: bool = False + text: str, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ids = self._encode(text) ext.extend(ids) ext_indices.extend([msg_idx] * len(ids)) ext_sampled.extend([is_sampled] * len(ids)) + ext_content.extend([is_content] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], + msg_idx: int = -1, + *, + is_sampled: bool = False, + ) -> None: + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + ext.append(tok_id) + ext_indices.append(msg_idx) + ext_sampled.append(is_sampled) + ext_content.append(is_content) for i, msg in enumerate(new_messages): role = msg.get("role") @@ -326,10 +390,10 @@ def emit_text( # Dedup: model already emitted <|user|> as its stop token. if not (i == 0 and last_prev == self._user): emit_special(self._user, i) - emit_text(content, i) + emit_text(content, i, is_content=True) elif role == "system": emit_special(self._system, i) - emit_text(content, i) + emit_text(content, i, is_content=True) elif role == "tool": prev_is_tool = i > 0 and new_messages[i - 1].get("role") == "tool" if i == 0 and last_prev == self._observation: @@ -338,7 +402,7 @@ def emit_text( elif not prev_is_tool: emit_special(self._observation, i) emit_special(self._tool_response_tok, i) - emit_text(content, i) + emit_text(content, i, is_content=True) emit_special(self._tool_response_end_tok, i) else: return None @@ -355,6 +419,7 @@ def emit_text( token_ids=previous_ids + ext, message_indices=[-1] * len(previous_ids) + ext_indices, sampled_mask=[False] * total_len, + is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], ) @@ -368,6 +433,7 @@ def _render_assistant( preserve_thinking: bool = False, emit_special, emit_text, + emit_text_segments, ): reasoning_content = "" if isinstance(msg.get("reasoning_content"), str): @@ -386,7 +452,11 @@ def _render_assistant( # samples it. Same for the ```` open / standalone # ```` separator that the template wraps around the # assistant body — see the per-branch comments below. - emit_special(self._assistant, msg_idx, is_sampled=False) + # + # Invariant on assistant tokens: ``is_content == sampled_mask``. + # Every scaffold token here gets ``is_sampled=False/is_content=False``; + # every model-sampled emit gets both True. + emit_special(self._assistant, msg_idx, is_sampled=False, is_content=False) # Chat-template default: keep ```` only on the in-flight cycle # (post-last-user). Past-cycle assistants drop their reasoning. @@ -403,9 +473,11 @@ def _render_assistant( # inference (gen prompt = ``<|assistant|>``), so it's # template-injected scaffolding. The reasoning text and the # closing ```` are what the model actually samples. - emit_special(self._think, msg_idx, is_sampled=False) - emit_text(reasoning_content.strip(), msg_idx, is_sampled=True) - emit_special(self._think_end, msg_idx, is_sampled=True) + emit_special(self._think, msg_idx, is_sampled=False, is_content=False) + emit_text( + reasoning_content.strip(), msg_idx, is_sampled=True, is_content=True + ) + emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) elif self.empty_think_on_last_assistant and msg_idx > last_user_index: # GLM-5.1: wrap the last assistant with an empty # even without reasoning, matching the Jinja template. With @@ -413,26 +485,28 @@ def _render_assistant( # ````; the model then samples ```` to close an # empty think block. So ```` is scaffolding, # ```` is sampled. - emit_special(self._think, msg_idx, is_sampled=False) - emit_special(self._think_end, msg_idx, is_sampled=True) + emit_special(self._think, msg_idx, is_sampled=False, is_content=False) + emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) else: # Lone ```` separator the template injects when no # reasoning is rendered (historical assistants, GLM-5 default # with no thinking). Not sampled. - emit_special(self._think_end, msg_idx, is_sampled=False) + emit_special(self._think_end, msg_idx, is_sampled=False, is_content=False) if content.strip(): - emit_text(content.strip(), msg_idx, is_sampled=True) + emit_text(content.strip(), msg_idx, is_sampled=True, is_content=True) - # Tool calls (directly after content, no newlines) + # Tool calls (directly after content, no newlines). All of these + # are the model's sampled output — both is_sampled and is_content + # are True across the entire tool-call span. tool_calls = msg.get("tool_calls") or [] for tc in tool_calls: func = tc.get("function") or tc name = func.get("name", "") arguments = func.get("arguments", {}) - emit_special(self._tool_call_tok, msg_idx, is_sampled=True) - emit_text(name, msg_idx, is_sampled=True) + emit_special(self._tool_call_tok, msg_idx, is_sampled=True, is_content=True) + emit_text(name, msg_idx, is_sampled=True, is_content=True) # OpenAI canonical form: arguments is a JSON string. Parse it so the # per-argument rendering below still works. if isinstance(arguments, str): @@ -442,20 +516,31 @@ def _render_assistant( arguments = {} if isinstance(arguments, dict): for arg_name, arg_value in arguments.items(): - emit_special(self._arg_key, msg_idx, is_sampled=True) - emit_text(arg_name, msg_idx, is_sampled=True) - emit_special(self._arg_key_end, msg_idx, is_sampled=True) - emit_special(self._arg_value, msg_idx, is_sampled=True) + emit_special( + self._arg_key, msg_idx, is_sampled=True, is_content=True + ) + emit_text(arg_name, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._arg_key_end, msg_idx, is_sampled=True, is_content=True + ) + emit_special( + self._arg_value, msg_idx, is_sampled=True, is_content=True + ) if isinstance(arg_value, str): - emit_text(arg_value, msg_idx, is_sampled=True) + emit_text(arg_value, msg_idx, is_sampled=True, is_content=True) else: emit_text( json.dumps(arg_value, ensure_ascii=False), msg_idx, is_sampled=True, + is_content=True, ) - emit_special(self._arg_value_end, msg_idx, is_sampled=True) - emit_special(self._tool_call_end_tok, msg_idx, is_sampled=True) + emit_special( + self._arg_value_end, msg_idx, is_sampled=True, is_content=True + ) + emit_special( + self._tool_call_end_tok, msg_idx, is_sampled=True, is_content=True + ) def _render_tool( self, @@ -465,18 +550,26 @@ def _render_tool( *, emit_special, emit_text, + emit_text_segments, ) -> None: # Tool messages are conversation history injected by the runtime # between assistant turns — the model never samples any of these - # tokens, so every emission is is_sampled=False. + # tokens, so every emission is is_sampled=False. The tool body + # bytes get ``is_content=True``; the ``<|observation|>`` / + # ```` wraps are scaffold so the SFT mask for + # tool body never trains the model to emit them. prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool" if not prev_is_tool: - emit_special(self._observation, msg_idx, is_sampled=False) + emit_special(self._observation, msg_idx, is_sampled=False, is_content=False) - emit_special(self._tool_response_tok, msg_idx, is_sampled=False) - emit_text(content, msg_idx, is_sampled=False) - emit_special(self._tool_response_end_tok, msg_idx, is_sampled=False) + emit_special( + self._tool_response_tok, msg_idx, is_sampled=False, is_content=False + ) + emit_text(content, msg_idx, is_sampled=False, is_content=True) + emit_special( + self._tool_response_end_tok, msg_idx, is_sampled=False, is_content=False + ) class GLM51Renderer(GLM5Renderer): diff --git a/renderers/gpt_oss.py b/renderers/gpt_oss.py index 7a5b26e..9939de1 100644 --- a/renderers/gpt_oss.py +++ b/renderers/gpt_oss.py @@ -184,6 +184,86 @@ def _encode(self, text: str) -> list[int]: return [] return self._tokenizer.encode(text, add_special_tokens=False) + def _prefix_content_mask( + self, + prefix_tokens: list[int], + first_system_idx: int | None, + messages: list[Message], + tools: list[ToolSpec] | None, + ) -> list[bool]: + """Per-token is_content mask over the rendered system+developer prefix. + + Harmony's prefix is one opaque block. The caller's system content + lands inside the developer message as ``# Instructions\\n\\n{content}``. + To attribute body bytes back, we render the same prefix with empty + instructions and diff: the unique-to-with-instructions span is the + body region. Falls back to all-False (whole prefix scaffold) if + the caller didn't supply a system message — there's no body to + attribute in that case. + """ + n = len(prefix_tokens) + mask = [False] * n + if first_system_idx is None: + return mask + instructions = _content_text(messages[first_system_idx].get("content")) + if not instructions: + return mask + + # Build the same prefix with empty instructions. + empty_prefix_msgs: list[HarmonyMessage] = [] + if self._use_system_prompt: + sys_content = SystemContent.new().with_reasoning_effort( + self._reasoning_effort + ) + sys_content = sys_content.with_conversation_start_date( + self._conversation_start_date + ) + if self._knowledge_cutoff is not None: + sys_content = sys_content.with_knowledge_cutoff(self._knowledge_cutoff) + if self._model_identity is not None: + sys_content = sys_content.with_model_identity(self._model_identity) + empty_prefix_msgs.append( + HarmonyMessage.from_role_and_content(Role.SYSTEM, sys_content) + ) + dev = DeveloperContent.new() + if tools: + dev = dev.with_function_tools([_tool_to_description(t) for t in tools]) + empty_prefix_msgs.append( + HarmonyMessage.from_role_and_content(Role.DEVELOPER, dev) + ) + try: + empty_tokens = self._enc.render_conversation( + Conversation.from_messages(empty_prefix_msgs) + ) + except Exception: + return mask + + # Longest common prefix. + i_start = 0 + n_empty = len(empty_tokens) + while ( + i_start < min(n, n_empty) + and prefix_tokens[i_start] == empty_tokens[i_start] + ): + i_start += 1 + # Longest common suffix. + j_full = n + j_empty = n_empty + while ( + j_full > i_start + and j_empty > i_start + and prefix_tokens[j_full - 1] == empty_tokens[j_empty - 1] + ): + j_full -= 1 + j_empty -= 1 + # Tokens [i_start:j_full] in prefix_tokens are unique to the + # with-instructions render — that's the body span (includes the + # ``# Instructions\n\n`` scaffolding header, which the substring + # match in the body-decode test ignores). + for k in range(i_start, j_full): + mask[k] = True + return mask + # ── public interface ───────────────────────────────────────────────────── def render( @@ -199,11 +279,15 @@ def render( tokens: list[int] = [] indices: list[int] = [] sampled: list[bool] = [] + content_mask: list[bool] = [] - def emit(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None: + def emit( + ids: list[int], msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) def emit_harmony_message( hm_ids: list[int], msg_idx: int, *, is_assistant: bool @@ -225,18 +309,53 @@ def emit_harmony_message( developer, tool) the whole message is conversation history the model never samples — every token is ``is_sampled=False``. + + ``is_content`` further splits the body: the trailing + terminator (``<|end|>`` / ``<|return|>`` / ``<|call|>``) is + scaffold on non-assistant turns; on assistant turns it's the + model's stop signal, so ``is_content=True`` mirrors + ``sampled_mask`` as the invariant on assistant requires. + The body bytes between ``<|message|>`` and the terminator + are body (``is_content=True``) on every role — that's the + caller-provided content (or, for assistant, the model's + sampled emission). The header (``<|start|>`` ... ``<|message|>``) + — including any ``functions.{name}`` recipient text on tool + results, which comes from a prior assistant's tool_calls + rather than this tool message's own content — is scaffold. """ try: msg_marker = hm_ids.index(self._message) except ValueError: # Defensive: a harmony message without <|message|> is # malformed. Treat the whole thing as scaffolding. - emit(hm_ids, msg_idx, is_sampled=False) + emit(hm_ids, msg_idx, is_sampled=False, is_content=False) return header = hm_ids[: msg_marker + 1] body = hm_ids[msg_marker + 1 :] - emit(header, msg_idx, is_sampled=False) - emit(body, msg_idx, is_sampled=is_assistant) + emit(header, msg_idx, is_sampled=False, is_content=False) + # Split body into content + terminator. The terminator (if + # present) is the last token of the body and is one of the + # three harmony stop tokens. + terminator_ids = {self._end, self._return, self._call} + if body and body[-1] in terminator_ids: + body_content = body[:-1] + terminator = body[-1:] + else: + body_content = body + terminator = [] + emit( + body_content, + msg_idx, + is_sampled=is_assistant, + is_content=True, + ) + if terminator: + emit( + terminator, + msg_idx, + is_sampled=is_assistant, + is_content=is_assistant, + ) # ── Build harmony prefix (system + developer) ─────────────────── # When tools are present, harmony's conversation-level renderer @@ -285,7 +404,23 @@ def emit_harmony_message( # caller-relative attribution); otherwise to -1 (pure scaffolding). # The whole prefix is pure template scaffolding — never sampled. prefix_origin = first_system_idx if first_system_idx is not None else -1 - emit(prefix_tokens, prefix_origin, is_sampled=False) + # Compute the body-token span inside the prefix by diffing + # against the same render with empty developer instructions. + # Tokens unique to the with-instructions render are the body + # span (``# Instructions\n\n{caller_system_content}``). Marking + # those is_content=True so the caller's system text is + # recoverable from ``content_token_spans_by_role()["system"]``. + # The scaffolding ``# Instructions\n\n`` prefix bleeds into + # the body run; consumers reading the body do a substring + # check rather than expecting an exact match. + prefix_content_mask = self._prefix_content_mask( + prefix_tokens, first_system_idx, messages, tools + ) + for tid, is_content in zip(prefix_tokens, prefix_content_mask): + tokens.append(tid) + indices.append(prefix_origin) + sampled.append(False) + content_mask.append(is_content) # ── Iterate the rest of the messages ──────────────────────────── last_idx = len(messages) - 1 @@ -326,16 +461,17 @@ def emit_harmony_message( # ── Generation prompt: <|start|>assistant<|channel|>analysis<|message|> # Pure template scaffolding the model continues from — never sampled. if add_generation_prompt: - emit([self._start], -1, is_sampled=False) - emit(self._encode("assistant"), -1, is_sampled=False) - emit([self._channel], -1, is_sampled=False) - emit(self._encode("analysis"), -1, is_sampled=False) - emit([self._message], -1, is_sampled=False) + emit([self._start], -1, is_sampled=False, is_content=False) + emit(self._encode("assistant"), -1, is_sampled=False, is_content=False) + emit([self._channel], -1, is_sampled=False, is_content=False) + emit(self._encode("analysis"), -1, is_sampled=False, is_content=False) + emit([self._message], -1, is_sampled=False, is_content=False) return RenderedTokens( token_ids=tokens, message_indices=indices, sampled_mask=sampled, + is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], ) @@ -407,17 +543,47 @@ def bridge_to_next_turn( # and ``sampled_mask`` (uniformly ``False``). The harmony encoder # renders each ``new_messages[i]`` as a single block, so every # token in that block carries index ``i``; the trailing - # generation prompt uses ``-1``. + # generation prompt uses ``-1``. ``is_content`` follows the same + # rules as :meth:`render`'s ``emit_harmony_message``: header is + # scaffold, body bytes are body, terminator scaffold (the bridge + # never carries assistant turns, so terminators are always + # scaffold on the non-assistant roles the bridge accepts). + terminator_ids = {self._end, self._return, self._call} ext: list[int] = [] ext_indices: list[int] = [] + ext_content: list[bool] = [] for i, msg in enumerate(new_messages): role = msg.get("role") if role not in ("tool", "user", "system", "developer"): return None for hm in self._to_harmony_messages(msg): ids = self._enc.render(hm) - ext.extend(ids) - ext_indices.extend([i] * len(ids)) + try: + msg_marker = ids.index(self._message) + except ValueError: + # Defensive: treat as scaffolding. + ext.extend(ids) + ext_indices.extend([i] * len(ids)) + ext_content.extend([False] * len(ids)) + continue + header = ids[: msg_marker + 1] + body = ids[msg_marker + 1 :] + ext.extend(header) + ext_indices.extend([i] * len(header)) + ext_content.extend([False] * len(header)) + if body and body[-1] in terminator_ids: + body_content = body[:-1] + terminator = body[-1:] + else: + body_content = body + terminator = [] + ext.extend(body_content) + ext_indices.extend([i] * len(body_content)) + ext_content.extend([True] * len(body_content)) + if terminator: + ext.extend(terminator) + ext_indices.extend([i] * len(terminator)) + ext_content.extend([False] * len(terminator)) # Generation prompt: <|start|>assistant<|channel|>analysis<|message|> gen_before = len(ext) @@ -427,12 +593,14 @@ def bridge_to_next_turn( ext.extend(self._encode("analysis")) ext.append(self._message) ext_indices.extend([-1] * (len(ext) - gen_before)) + ext_content.extend([False] * (len(ext) - gen_before)) total_len = len(previous_ids) + len(ext) return RenderedTokens( token_ids=previous_ids + ext, message_indices=[-1] * len(previous_ids) + ext_indices, sampled_mask=[False] * total_len, + is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], ) diff --git a/renderers/kimi_k2.py b/renderers/kimi_k2.py index b0106e4..9e08141 100644 --- a/renderers/kimi_k2.py +++ b/renderers/kimi_k2.py @@ -118,6 +118,11 @@ def render( if not messages: raise ValueError("No messages provided.") + # Preserve the caller's list — ``message_roles`` and per-token + # attribution refer to this frame (not the post-normalisation + # list that includes auto-injected system / tool_declare). + caller_messages = list(messages) + # Inject tools as tool_declare message + ensure system message. # The Jinja template emits the tools list directly (no # ``{"type":"function","function":...}`` wrapper) using @@ -166,19 +171,33 @@ def orig_idx(i: int) -> int: token_ids: list[int] = [] indices: list[int] = [] sampled: list[bool] = [] + content_mask: list[bool] = [] - def emit_ids(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None: + def emit_ids( + ids: list[int], msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: token_ids.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) - def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None: + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: token_ids.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) + content_mask.append(is_content) - def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: - emit_ids(self._encode(text), msg_idx, is_sampled=is_sampled) + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: + emit_ids( + self._encode(text), + msg_idx, + is_sampled=is_sampled, + is_content=is_content, + ) # Compute last non-tool-call assistant index to determine thinking preservation last_plain_assistant_idx = -1 @@ -206,31 +225,40 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: content = "".join(parts) oi = orig_idx(i) + # Auto-injected system / tool_declare messages have ``oi == -1``. + # Their text isn't from the caller's input, so we treat the + # whole emission as scaffold (``is_content=False`` everywhere). + # The test contract is that ``msg_idx == -1`` runs are + # template-only and ``is_content=False``. + body_is_content = oi >= 0 if role == "system": - emit_special(self._im_system, oi, is_sampled=False) - emit_text("system", oi, is_sampled=False) - emit_special(self._im_middle, oi, is_sampled=False) - emit_text(content, oi, is_sampled=False) - emit_special(self._im_end, oi, is_sampled=False) + emit_special(self._im_system, oi, is_sampled=False, is_content=False) + emit_text("system", oi, is_sampled=False, is_content=False) + emit_special(self._im_middle, oi, is_sampled=False, is_content=False) + emit_text(content, oi, is_sampled=False, is_content=body_is_content) + emit_special(self._im_end, oi, is_sampled=False, is_content=False) # Jinja emits a literal newline only after the auto-injected # system's <|im_end|> (see _ensure_system_message's contract). if i == auto_system_idx: - emit_text("\n", oi, is_sampled=False) + emit_text("\n", oi, is_sampled=False, is_content=False) elif role == "tool_declare": - emit_special(self._im_system, oi, is_sampled=False) - emit_text("tool_declare", oi, is_sampled=False) - emit_special(self._im_middle, oi, is_sampled=False) - emit_text(content, oi, is_sampled=False) - emit_special(self._im_end, oi, is_sampled=False) + # The tool_declare body is the tools JSON — recoverable + # from the caller's ``tools`` argument, so we treat it as + # scaffold (consistent with Qwen3's tools-header block). + emit_special(self._im_system, oi, is_sampled=False, is_content=False) + emit_text("tool_declare", oi, is_sampled=False, is_content=False) + emit_special(self._im_middle, oi, is_sampled=False, is_content=False) + emit_text(content, oi, is_sampled=False, is_content=False) + emit_special(self._im_end, oi, is_sampled=False, is_content=False) elif role == "user": - emit_special(self._im_user, oi, is_sampled=False) - emit_text("user", oi, is_sampled=False) - emit_special(self._im_middle, oi, is_sampled=False) - emit_text(content, oi, is_sampled=False) - emit_special(self._im_end, oi, is_sampled=False) + emit_special(self._im_user, oi, is_sampled=False, is_content=False) + emit_text("user", oi, is_sampled=False, is_content=False) + emit_special(self._im_middle, oi, is_sampled=False, is_content=False) + emit_text(content, oi, is_sampled=False, is_content=body_is_content) + emit_special(self._im_end, oi, is_sampled=False, is_content=False) elif role == "assistant": # Kimi strips reasoning from historical assistant turns and @@ -250,30 +278,35 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: elif role == "tool": self._render_tool( - msg, oi, content, emit_special=emit_special, emit_text=emit_text + msg, + oi, + content, + emit_special=emit_special, + emit_text=emit_text, ) else: # Unknown role: use system-style formatting. Not a sampled # assistant turn — every token is template-injected from the # caller's POV, so is_sampled=False across the whole emission. - emit_special(self._im_system, oi, is_sampled=False) - emit_text(role, oi, is_sampled=False) - emit_special(self._im_middle, oi, is_sampled=False) - emit_text(content, oi, is_sampled=False) - emit_special(self._im_end, oi, is_sampled=False) + emit_special(self._im_system, oi, is_sampled=False, is_content=False) + emit_text(role, oi, is_sampled=False, is_content=False) + emit_special(self._im_middle, oi, is_sampled=False, is_content=False) + emit_text(content, oi, is_sampled=False, is_content=body_is_content) + emit_special(self._im_end, oi, is_sampled=False, is_content=False) # Generation prompt if add_generation_prompt: - emit_special(self._im_assistant, -1, is_sampled=False) - emit_text("assistant", -1, is_sampled=False) - emit_special(self._im_middle, -1, is_sampled=False) + emit_special(self._im_assistant, -1, is_sampled=False, is_content=False) + emit_text("assistant", -1, is_sampled=False, is_content=False) + emit_special(self._im_middle, -1, is_sampled=False, is_content=False) return RenderedTokens( token_ids=token_ids, message_indices=indices, sampled_mask=sampled, - message_roles=[m.get("role") or "" for m in messages], + is_content=content_mask, + message_roles=[m.get("role") or "" for m in caller_messages], ) def render_ids( @@ -336,27 +369,40 @@ def bridge_to_next_turn( ext: list[int] = [] ext_indices: list[int] = [] ext_sampled: list[bool] = [] + ext_content: list[bool] = [] # Bridge populates ``message_indices`` (relative to ``new_messages``) # and ``sampled_mask`` (uniformly ``False`` — every token the # bridge emits is template scaffolding for the next prompt, not - # something the model sampled). Downstream consumers can run - # :meth:`RenderedTokens.tokens_per_message` on the bridge output - # to get per-new-message token counts without re-rendering. + # something the model sampled). ``is_content`` follows the same + # rules as in :meth:`render` so consumers can walk the trajectory + # and read each step's own body mask. Downstream consumers can + # run :meth:`RenderedTokens.tokens_per_message` on the bridge + # output to get per-new-message token counts without re-rendering. def emit_special( - token_id: int, msg_idx: int = -1, *, is_sampled: bool = False + token_id: int, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ext.append(token_id) ext_indices.append(msg_idx) ext_sampled.append(is_sampled) + ext_content.append(is_content) def emit_text( - text: str, msg_idx: int = -1, *, is_sampled: bool = False + text: str, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ids = self._encode(text) ext.extend(ids) ext_indices.extend([msg_idx] * len(ids)) ext_sampled.extend([is_sampled] * len(ids)) + ext_content.extend([is_content] * len(ids)) for i, msg in enumerate(new_messages): role = msg.get("role") @@ -376,34 +422,39 @@ def emit_text( content = "".join(parts) if role == "user": - emit_special(self._im_user, i, is_sampled=False) - emit_text("user", i, is_sampled=False) - emit_special(self._im_middle, i, is_sampled=False) - emit_text(content, i, is_sampled=False) - emit_special(self._im_end, i, is_sampled=False) + emit_special(self._im_user, i) + emit_text("user", i) + emit_special(self._im_middle, i) + emit_text(content, i, is_content=True) + emit_special(self._im_end, i) elif role == "system": - emit_special(self._im_system, i, is_sampled=False) - emit_text("system", i, is_sampled=False) - emit_special(self._im_middle, i, is_sampled=False) - emit_text(content, i, is_sampled=False) - emit_special(self._im_end, i, is_sampled=False) + emit_special(self._im_system, i) + emit_text("system", i) + emit_special(self._im_middle, i) + emit_text(content, i, is_content=True) + emit_special(self._im_end, i) elif role == "tool": self._render_tool( - msg, i, content, emit_special=emit_special, emit_text=emit_text + msg, + i, + content, + emit_special=emit_special, + emit_text=emit_text, ) else: return None # Generation prompt. - emit_special(self._im_assistant, -1, is_sampled=False) - emit_text("assistant", -1, is_sampled=False) - emit_special(self._im_middle, -1, is_sampled=False) + emit_special(self._im_assistant, -1) + emit_text("assistant", -1) + emit_special(self._im_middle, -1) total_len = len(previous_ids) + len(ext) return RenderedTokens( token_ids=previous_ids + ext, message_indices=[-1] * len(previous_ids) + ext_indices, sampled_mask=[False] * total_len, + is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], ) @@ -421,23 +472,32 @@ def _render_assistant( # scaffolding — at inference the chat template emits these as the # generation prompt and the model never samples them. Marking the # role tag as ``is_sampled=False`` keeps the SFT loss mask aligned - # with what the model would actually have produced. - emit_special(self._im_assistant, msg_idx, is_sampled=False) - emit_text("assistant", msg_idx, is_sampled=False) - emit_special(self._im_middle, msg_idx, is_sampled=False) + # with what the model would actually have produced. ``is_content`` + # is also False here — the role tag isn't part of any message's + # body, on any role. + emit_special(self._im_assistant, msg_idx, is_sampled=False, is_content=False) + emit_text("assistant", msg_idx, is_sampled=False, is_content=False) + emit_special(self._im_middle, msg_idx, is_sampled=False, is_content=False) # Kimi K2's Jinja template has no reasoning-content support: the # assistant turn renders its ``content`` verbatim, including any # inline ``...`` tags. The separate # ``reasoning_content`` field is dropped (the template never reads # it). ``is_last_turn`` is unused here for the same reason. + # On assistant tokens, ``is_content == sampled_mask`` by construction + # — every sampled token is body, every scaffold token isn't. _ = is_last_turn - emit_text(content, msg_idx, is_sampled=True) + emit_text(content, msg_idx, is_sampled=True, is_content=True) - # Tool calls + # Tool calls — model-sampled markup carrying caller / model body. tool_calls = msg.get("tool_calls") or [] if tool_calls: - emit_special(self._tool_calls_section_begin, msg_idx, is_sampled=True) + emit_special( + self._tool_calls_section_begin, + msg_idx, + is_sampled=True, + is_content=True, + ) for tc in tool_calls: func = tc.get("function") or tc arguments = func.get("arguments", {}) @@ -451,16 +511,37 @@ def _render_assistant( # caller to provide an id in ``functions.{name}:{idx}`` form # (that's where the Kimi parser recovers the function name). tc_id = tc.get("id") or "" - emit_special(self._tool_call_begin, msg_idx, is_sampled=True) - emit_text(tc_id, msg_idx, is_sampled=True) - emit_special(self._tool_call_argument_begin, msg_idx, is_sampled=True) - emit_text(args_str, msg_idx, is_sampled=True) - emit_special(self._tool_call_end, msg_idx, is_sampled=True) - emit_special(self._tool_calls_section_end, msg_idx, is_sampled=True) + emit_special( + self._tool_call_begin, + msg_idx, + is_sampled=True, + is_content=True, + ) + emit_text(tc_id, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._tool_call_argument_begin, + msg_idx, + is_sampled=True, + is_content=True, + ) + emit_text(args_str, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._tool_call_end, + msg_idx, + is_sampled=True, + is_content=True, + ) + emit_special( + self._tool_calls_section_end, + msg_idx, + is_sampled=True, + is_content=True, + ) # ``<|im_end|>`` is the model's stop signal — it samples this to - # end its turn, so it is part of the sampled stream. - emit_special(self._im_end, msg_idx, is_sampled=True) + # end its turn, so it is part of the sampled stream (and the + # assistant's body). + emit_special(self._im_end, msg_idx, is_sampled=True, is_content=True) def _render_tool( self, @@ -473,13 +554,28 @@ def _render_tool( ) -> None: # Tool messages are conversation history injected by the runtime # between assistant turns — the model never samples any of these - # tokens, so every emission is is_sampled=False. + # tokens, so every emission is is_sampled=False. The ``content`` + # field's body bytes get ``is_content=True``; everything else — + # the ``<|im_system|>name<|im_middle|>`` wrap, the ``## Return of + # …\n`` header (template-synthesised, not part of the body) — + # is scaffold so the SFT mask for tool body never trains the + # model to emit them. + # + # We keep the original kimi_k2 emit boundaries — the header and + # the content are encoded separately, which preserves the + # template's byte-identity since the original code also emitted + # them as separate ``encode`` calls. name = msg.get("name") or "tool" tool_call_id = msg.get("tool_call_id") or "" - emit_special(self._im_system, msg_idx, is_sampled=False) - emit_text(name, msg_idx, is_sampled=False) - emit_special(self._im_middle, msg_idx, is_sampled=False) - emit_text(f"## Return of {tool_call_id}\n", msg_idx, is_sampled=False) - emit_text(content, msg_idx, is_sampled=False) - emit_special(self._im_end, msg_idx, is_sampled=False) + emit_special(self._im_system, msg_idx, is_sampled=False, is_content=False) + emit_text(name, msg_idx, is_sampled=False, is_content=False) + emit_special(self._im_middle, msg_idx, is_sampled=False, is_content=False) + emit_text( + f"## Return of {tool_call_id}\n", + msg_idx, + is_sampled=False, + is_content=False, + ) + emit_text(content, msg_idx, is_sampled=False, is_content=True) + emit_special(self._im_end, msg_idx, is_sampled=False, is_content=False) diff --git a/renderers/kimi_k25.py b/renderers/kimi_k25.py index 31ee759..d3729ac 100644 --- a/renderers/kimi_k25.py +++ b/renderers/kimi_k25.py @@ -749,24 +749,44 @@ def render( tokens: list[int] = [] indices: list[int] = [] sampled: list[bool] = [] + content_mask: list[bool] = [] mm_hashes: dict[str, list[str]] = {} mm_placeholders: dict[str, list[PlaceholderRange]] = {} mm_items: dict[str, list[dict[str, Any]]] = {} - def emit_ids(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None: + def emit_ids( + ids: list[int], msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) - def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None: + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) + content_mask.append(is_content) - def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: - emit_ids(self._encode(text), msg_idx, is_sampled=is_sampled) + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: + emit_ids( + self._encode(text), + msg_idx, + is_sampled=is_sampled, + is_content=is_content, + ) - def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None: + def emit_image( + part: dict[str, Any], + msg_idx: int, + *, + is_sampled: bool, + is_content: bool, + ) -> None: """Emit Kimi K2.5's image wrap and accumulate ``mm_data``. Template-equivalent expansion per image: @@ -779,15 +799,32 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None: Kimi's chat template after every image — kept here verbatim for byte-parity, regardless of what follows (more images, text, or the ``<|im_end|>`` close). + + ``is_content`` attribution: ``<|media_pad|>`` represents + caller-provided image data — body. The wrap tokens + (``<|media_begin|>``, the literal ``"image"`` prose, + ``<|media_content|>``, ``<|media_end|>``, the trailing + ``\\n``) are template-injected scaffold. """ _, out, _num_patches, h = self._process_image(part) - emit_special(self._media_begin, msg_idx, is_sampled=is_sampled) - emit_text("image", msg_idx, is_sampled=is_sampled) - emit_special(self._media_content, msg_idx, is_sampled=is_sampled) + emit_special( + self._media_begin, msg_idx, is_sampled=is_sampled, is_content=False + ) + emit_text("image", msg_idx, is_sampled=is_sampled, is_content=False) + emit_special( + self._media_content, msg_idx, is_sampled=is_sampled, is_content=False + ) offset = len(tokens) - emit_special(self._media_pad, msg_idx, is_sampled=is_sampled) - emit_special(self._media_end, msg_idx, is_sampled=is_sampled) - emit_text("\n", msg_idx, is_sampled=is_sampled) + emit_special( + self._media_pad, + msg_idx, + is_sampled=is_sampled, + is_content=is_content, + ) + emit_special( + self._media_end, msg_idx, is_sampled=is_sampled, is_content=False + ) + emit_text("\n", msg_idx, is_sampled=is_sampled, is_content=False) mm_hashes.setdefault("image", []).append(h) mm_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=1) @@ -807,13 +844,16 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None: # K2.5/K2.6's tokenizer auto-computes ``tools_ts_str`` and threads # it into apply_chat_template, so the template's TS branch always # fires when tools are present. Match that with our own TS encoder. + # The tools-TS body is recoverable from the ``tools`` argument, so + # we attribute the entire tool_declare emission as scaffold — + # consistent with Qwen3's tools-header treatment. if tools: tools_ts = _encode_tools_typescript(tools) - emit_special(self._im_system, -1, is_sampled=False) - emit_text("tool_declare", -1, is_sampled=False) - emit_special(self._im_middle, -1, is_sampled=False) - emit_text(tools_ts, -1, is_sampled=False) - emit_special(self._im_end, -1, is_sampled=False) + emit_special(self._im_system, -1, is_sampled=False, is_content=False) + emit_text("tool_declare", -1, is_sampled=False, is_content=False) + emit_special(self._im_middle, -1, is_sampled=False, is_content=False) + emit_text(tools_ts, -1, is_sampled=False, is_content=False) + emit_special(self._im_end, -1, is_sampled=False, is_content=False) # ── Iterate messages ───────────────────────────────────────── for i, msg in enumerate(messages): @@ -824,14 +864,14 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None: # generation prompt emits them at inference for assistants; # user / system / tool roles are conversation history). if role == "user": - emit_special(self._im_user, i, is_sampled=False) + emit_special(self._im_user, i, is_sampled=False, is_content=False) elif role == "assistant": - emit_special(self._im_assistant, i, is_sampled=False) + emit_special(self._im_assistant, i, is_sampled=False, is_content=False) else: - emit_special(self._im_system, i, is_sampled=False) + emit_special(self._im_system, i, is_sampled=False, is_content=False) role_name = msg.get("name") or role - emit_text(role_name, i, is_sampled=False) - emit_special(self._im_middle, i, is_sampled=False) + emit_text(role_name, i, is_sampled=False, is_content=False) + emit_special(self._im_middle, i, is_sampled=False, is_content=False) # Body if role == "assistant": @@ -852,10 +892,10 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None: ) # ``<|im_end|>`` is the model's stop signal — it samples # this to end its turn, so it is part of the sampled - # stream. Kimi K2.5 has no inter-turn ``\n`` separator - # (unlike Qwen3), so the turn-close token is the last - # sampled token. - emit_special(self._im_end, i, is_sampled=True) + # stream (and the assistant's body). Kimi K2.5 has no + # inter-turn ``\n`` separator (unlike Qwen3), so the + # turn-close token is the last sampled token. + emit_special(self._im_end, i, is_sampled=True, is_content=True) continue elif role == "tool": self._render_tool_body( @@ -869,7 +909,8 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None: elif msg.get("content") is not None: # User / other content branches — images allowed. All # non-assistant content is conversation history, never - # sampled by the model. + # sampled by the model. The body is caller-provided, so + # ``is_content=True`` over the content emit. self._emit_content( msg.get("content"), i, @@ -878,21 +919,22 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None: emit_ids, emit_image=emit_image, is_sampled=False, + is_content=True, ) - emit_special(self._im_end, i, is_sampled=False) + emit_special(self._im_end, i, is_sampled=False, is_content=False) # ── Generation prompt ──────────────────────────────────────── if add_generation_prompt: - emit_special(self._im_assistant, -1, is_sampled=False) - emit_text("assistant", -1, is_sampled=False) - emit_special(self._im_middle, -1, is_sampled=False) + emit_special(self._im_assistant, -1, is_sampled=False, is_content=False) + emit_text("assistant", -1, is_sampled=False, is_content=False) + emit_special(self._im_middle, -1, is_sampled=False, is_content=False) if self._enable_thinking: # Prefill open tag to trigger thinking mode - emit_text("", -1, is_sampled=False) + emit_text("", -1, is_sampled=False, is_content=False) else: # Empty to disable thinking - emit_text("", -1, is_sampled=False) + emit_text("", -1, is_sampled=False, is_content=False) mm_data: MultiModalData | None = None if mm_hashes or mm_placeholders or mm_items: @@ -906,6 +948,7 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None: token_ids=tokens, message_indices=indices, sampled_mask=sampled, + is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], multi_modal_data=mm_data, ) @@ -997,49 +1040,74 @@ def bridge_to_next_turn( # Seed combined-token list with prior turn so placeholder offsets # are absolute in the bridged sequence. Parallel - # ``indices``/``sampled`` are seeded with ``-1``/``False`` for the - # prior portion — the bridge has no attribution info for - # ``previous_ids``. Bridge-added tokens get proper ``msg_idx`` - # (relative to ``new_messages``) and uniformly ``False`` - # ``sampled``: nothing the bridge emits was model-sampled. + # ``indices``/``sampled``/``content_mask`` are seeded with + # ``-1``/``False``/``False`` for the prior portion — the bridge + # has no attribution info for ``previous_ids``. Bridge-added + # tokens get proper ``msg_idx`` (relative to ``new_messages``) + # and uniformly ``False`` ``sampled``: nothing the bridge emits + # was model-sampled. ``is_content`` follows the same rules as in + # :meth:`render` so consumers can walk the trajectory and read + # each step's own body mask. tokens: list[int] = list(previous_ids) indices: list[int] = [-1] * len(previous_ids) sampled: list[bool] = [False] * len(previous_ids) + content_mask: list[bool] = [False] * len(previous_ids) new_hashes: dict[str, list[str]] = {} new_placeholders: dict[str, list[PlaceholderRange]] = {} new_items: dict[str, list[dict[str, Any]]] = {} def emit_special( - token_id: int, msg_idx: int = -1, *, is_sampled: bool = False + token_id: int, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) + content_mask.append(is_content) def emit_text( - text: str, msg_idx: int = -1, *, is_sampled: bool = False + text: str, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ids = self._encode(text) tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) def emit_ids( - ids: list[int], msg_idx: int = -1, *, is_sampled: bool = False + ids: list[int], + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) def emit_image( - part: dict[str, Any], msg_idx: int = -1, *, is_sampled: bool = False + part: dict[str, Any], + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: _, out, _num_patches, h = self._process_image(part) emit_special(self._media_begin, msg_idx) emit_text("image", msg_idx) emit_special(self._media_content, msg_idx) offset = len(tokens) - emit_special(self._media_pad, msg_idx) + # ``<|media_pad|>`` stands in for caller-provided image data — + # mark it as body when the surrounding content is body. + emit_special(self._media_pad, msg_idx, is_content=is_content) emit_special(self._media_end, msg_idx) emit_text("\n", msg_idx) new_hashes.setdefault("image", []).append(h) @@ -1086,6 +1154,7 @@ def emit_image( emit_ids, emit_image=emit_image, is_sampled=False, + is_content=True, ) emit_special(self._im_end, i) @@ -1128,6 +1197,7 @@ def emit_image( token_ids=tokens, message_indices=indices, sampled_mask=sampled, + is_content=content_mask, message_roles=bridge_roles, ) @@ -1140,6 +1210,7 @@ def emit_image( token_ids=tokens, message_indices=indices, sampled_mask=sampled, + is_content=content_mask, message_roles=bridge_roles, multi_modal_data=mm_data, ) @@ -1158,6 +1229,7 @@ def _emit_content( *, emit_image=None, is_sampled: bool, + is_content: bool = True, ) -> None: """Emit message content, handling strings, multipart lists, and (when ``emit_image`` is supplied) image parts. @@ -1172,11 +1244,15 @@ def _emit_content( ``bridge_to_next_turn``), so consecutive images naturally produce the template's ``...<|media_end|>\\n<|media_begin|>...`` pattern without an inter-image separator here. + + ``is_content`` propagates to text emits and to the + ``<|media_pad|>`` body token of each image; the surrounding + media wrap tokens are always scaffold (handled by ``emit_image``). """ if content is None: return if isinstance(content, str): - emit_text(content, msg_idx, is_sampled=is_sampled) + emit_text(content, msg_idx, is_sampled=is_sampled, is_content=is_content) return if isinstance(content, list): for part in content: @@ -1189,20 +1265,30 @@ def _emit_content( if emit_image is None: # Silently drop — caller didn't opt into multimodal. continue - emit_image(part, msg_idx, is_sampled=is_sampled) + emit_image( + part, msg_idx, is_sampled=is_sampled, is_content=is_content + ) continue if is_video: raise NotImplementedError( "Video parts are not yet supported by KimiK25Renderer." ) if ptype == "text": - emit_text(part.get("text", ""), msg_idx, is_sampled=is_sampled) + emit_text( + part.get("text", ""), + msg_idx, + is_sampled=is_sampled, + is_content=is_content, + ) elif ptype == "thinking": # Thinking parts in non-assistant roles are rendered as text thinking = part.get("thinking", "") if thinking: emit_text( - f"{thinking}", msg_idx, is_sampled=is_sampled + f"{thinking}", + msg_idx, + is_sampled=is_sampled, + is_content=is_content, ) # Other part types are silently skipped @@ -1265,16 +1351,27 @@ def _render_assistant_body( # block, text content, and any tool_calls section all live in # the sampled stream. The closing ``<|im_end|>`` is emitted by # ``render`` (also is_sampled=True; it's the model's stop - # signal). + # signal). On assistant tokens ``is_content == sampled_mask`` by + # construction. if is_suffix or (preserve_thinking and reasoning_content): - emit_text(f"{reasoning_content}", msg_idx, is_sampled=True) + emit_text( + f"{reasoning_content}", + msg_idx, + is_sampled=True, + is_content=True, + ) else: - emit_text("", msg_idx, is_sampled=True) - emit_text(text_content, msg_idx, is_sampled=True) + emit_text("", msg_idx, is_sampled=True, is_content=True) + emit_text(text_content, msg_idx, is_sampled=True, is_content=True) tool_calls = msg.get("tool_calls") or [] if tool_calls: - emit_special(self._tool_calls_section_begin, msg_idx, is_sampled=True) + emit_special( + self._tool_calls_section_begin, + msg_idx, + is_sampled=True, + is_content=True, + ) for tc in tool_calls: func = tc.get("function") or tc arguments = func.get("arguments", {}) @@ -1288,12 +1385,32 @@ def _render_assistant_body( # ``functions.{name}:{idx}`` form (Kimi's parser recovers # the function name from that field). tool_id = tc.get("id") or "" - emit_special(self._tool_call_begin, msg_idx, is_sampled=True) - emit_text(tool_id, msg_idx, is_sampled=True) - emit_special(self._tool_call_argument_begin, msg_idx, is_sampled=True) - emit_text(args_str, msg_idx, is_sampled=True) - emit_special(self._tool_call_end, msg_idx, is_sampled=True) - emit_special(self._tool_calls_section_end, msg_idx, is_sampled=True) + emit_special( + self._tool_call_begin, + msg_idx, + is_sampled=True, + is_content=True, + ) + emit_text(tool_id, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._tool_call_argument_begin, + msg_idx, + is_sampled=True, + is_content=True, + ) + emit_text(args_str, msg_idx, is_sampled=True, is_content=True) + emit_special( + self._tool_call_end, + msg_idx, + is_sampled=True, + is_content=True, + ) + emit_special( + self._tool_calls_section_end, + msg_idx, + is_sampled=True, + is_content=True, + ) def _render_tool_body( self, @@ -1319,9 +1436,16 @@ def _render_tool_body( """ # Tool messages are conversation-history injected by the runtime # between assistant turns — the model never samples any of these - # tokens, so every emission is is_sampled=False. + # tokens, so every emission is is_sampled=False. The ``## Return + # of …\n`` header is template-synthesised scaffold; the + # ``content`` body bytes get ``is_content=True``. tool_call_id = msg.get("tool_call_id") or "" - emit_text(f"## Return of {tool_call_id}\n", msg_idx, is_sampled=False) + emit_text( + f"## Return of {tool_call_id}\n", + msg_idx, + is_sampled=False, + is_content=False, + ) content = msg.get("content") if content is not None: self._emit_content( @@ -1332,6 +1456,7 @@ def _render_tool_body( emit_ids, emit_image=emit_image, is_sampled=False, + is_content=True, ) def _normalize_response_tokens(self, response: list[int]) -> list[int]: diff --git a/renderers/laguna_xs2.py b/renderers/laguna_xs2.py index df62e07..ce85037 100644 --- a/renderers/laguna_xs2.py +++ b/renderers/laguna_xs2.py @@ -35,6 +35,7 @@ ParsedResponse, RenderedTokens, ToolSpec, + attribute_text_segments, reject_assistant_in_extension, ) from renderers.parsing import parse_laguna_xs2 @@ -154,24 +155,43 @@ def render( tokens: list[int] = [] indices: list[int] = [] sampled: list[bool] = [] + content_mask: list[bool] = [] - def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None: + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) + content_mask.append(is_content) - def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: ids = self._encode(text) tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) - emit_special(self._eos, -1, is_sampled=False) + def emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool + ) -> None: + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + tokens.append(tok_id) + indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(is_content) + + emit_special(self._eos, -1, is_sampled=False, is_content=False) # ── System header (absorbs messages[0] if it's a system message) ── system_content = _DEFAULT_SYSTEM_MESSAGE system_msg_idx = -1 - if messages and messages[0].get("role") == "system": + caller_has_system = bool(messages and messages[0].get("role") == "system") + if caller_has_system: system_content = self._visible_text(messages[0].get("content")) system_msg_idx = 0 @@ -187,9 +207,18 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: "\n\n" if has_sys_content else "\n", -1, is_sampled=False, + is_content=False, ) if has_sys_content: - emit_text(system_content.rstrip(), system_msg_idx, is_sampled=False) + # If the caller provided system content, it's body bytes; + # otherwise this is the default system prompt (scaffold). + sys_is_content = caller_has_system + emit_text( + system_content.rstrip(), + system_msg_idx, + is_sampled=False, + is_content=sys_is_content, + ) if tools: tool_text = _TOOLS_HEADER for tool in tools: @@ -199,8 +228,8 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: if self._enable_thinking else _TOOLS_FOOTER_NO_THINKING ) - emit_text(tool_text, -1, is_sampled=False) - emit_text("\n\n", -1, is_sampled=False) + emit_text(tool_text, -1, is_sampled=False, is_content=False) + emit_text("\n\n", -1, is_sampled=False, is_content=False) # ── Per-message loop ────────────────────────────────────────── for i, msg in enumerate(messages): @@ -211,35 +240,49 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: # Already consumed in the header block. if i == 0: continue - emit_text( - "\n" + content + "\n\n", i, is_sampled=False - ) + # Body = caller's content; the ``...`` + # wrap and surrounding ``\n``s are scaffold. + sys_segs: list[tuple[str, bool]] = [("\n", False)] + if content: + sys_segs.append((content, True)) + sys_segs.append(("\n\n", False)) + emit_text_segments(sys_segs, i, is_sampled=False) case "user": - emit_text("\n" + content + "\n\n", i, is_sampled=False) + user_segs: list[tuple[str, bool]] = [("\n", False)] + if content: + user_segs.append((content, True)) + user_segs.append(("\n\n", False)) + emit_text_segments(user_segs, i, is_sampled=False) case "assistant": self._render_assistant( - msg, i, content, emit_special=emit_special, emit_text=emit_text - ) - case "tool": - emit_text( - "\n" + content + "\n\n", + msg, i, - is_sampled=False, + content, + emit_special=emit_special, + emit_text=emit_text, + emit_text_segments=emit_text_segments, ) + case "tool": + tool_segs: list[tuple[str, bool]] = [("\n", False)] + if content: + tool_segs.append((content, True)) + tool_segs.append(("\n\n", False)) + emit_text_segments(tool_segs, i, is_sampled=False) # ── Generation prompt ───────────────────────────────────────── if add_generation_prompt: - emit_special(self._assistant, -1, is_sampled=False) - emit_text("\n", -1, is_sampled=False) + emit_special(self._assistant, -1, is_sampled=False, is_content=False) + emit_text("\n", -1, is_sampled=False, is_content=False) if self._enable_thinking: - emit_special(self._think, -1, is_sampled=False) + emit_special(self._think, -1, is_sampled=False, is_content=False) else: - emit_special(self._think_end, -1, is_sampled=False) + emit_special(self._think_end, -1, is_sampled=False, is_content=False) return RenderedTokens( token_ids=tokens, message_indices=indices, sampled_mask=sampled, + is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], ) @@ -307,37 +350,74 @@ def bridge_to_next_turn( ext: list[int] = [] ext_indices: list[int] = [] ext_sampled: list[bool] = [] + ext_content: list[bool] = [] # Bridge populates ``message_indices`` (relative to ``new_messages``) # and ``sampled_mask`` (uniformly ``False`` — every token the # bridge emits is template scaffolding for the next prompt, not - # something the model sampled). Downstream consumers can run - # :meth:`RenderedTokens.tokens_per_message` on the bridge output - # to get per-new-message token counts without re-rendering. + # something the model sampled). ``is_content`` follows the same + # rules as in :meth:`render` so consumers can walk the trajectory + # and read each step's own body mask. def emit_special( - token_id: int, msg_idx: int = -1, *, is_sampled: bool = False + token_id: int, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ext.append(token_id) ext_indices.append(msg_idx) ext_sampled.append(is_sampled) + ext_content.append(is_content) def emit_text( - text: str, msg_idx: int = -1, *, is_sampled: bool = False + text: str, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ids = self._encode(text) ext.extend(ids) ext_indices.extend([msg_idx] * len(ids)) ext_sampled.extend([is_sampled] * len(ids)) + ext_content.extend([is_content] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], + msg_idx: int = -1, + *, + is_sampled: bool = False, + ) -> None: + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + ext.append(tok_id) + ext_indices.append(msg_idx) + ext_sampled.append(is_sampled) + ext_content.append(is_content) for i, msg in enumerate(new_messages): role = msg.get("role") content = self._visible_text(msg.get("content")) if role == "user": - emit_text("\n" + content + "\n\n", i) + segs: list[tuple[str, bool]] = [("\n", False)] + if content: + segs.append((content, True)) + segs.append(("\n\n", False)) + emit_text_segments(segs, i) elif role == "system": - emit_text("\n" + content + "\n\n", i) + segs = [("\n", False)] + if content: + segs.append((content, True)) + segs.append(("\n\n", False)) + emit_text_segments(segs, i) elif role == "tool": - emit_text("\n" + content + "\n\n", i) + segs = [("\n", False)] + if content: + segs.append((content, True)) + segs.append(("\n\n", False)) + emit_text_segments(segs, i) else: return None @@ -353,6 +433,7 @@ def emit_text( token_ids=previous_ids + ext, message_indices=[-1] * len(previous_ids) + ext_indices, sampled_mask=[False] * total_len, + is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], ) @@ -364,6 +445,7 @@ def _render_assistant( *, emit_special, emit_text, + emit_text_segments, ) -> None: reasoning_content = "" if isinstance(msg.get("reasoning_content"), str): @@ -380,23 +462,30 @@ def _render_assistant( # template emits these as the generation prompt at inference and # the model never samples them. Marking the role tag as # ``is_sampled=False`` keeps the SFT loss mask aligned with what - # the model would actually have produced. - emit_special(self._assistant, msg_idx, is_sampled=False) - emit_text("\n", msg_idx, is_sampled=False) + # the model would actually have produced. ``is_content`` is also + # False on the role tag. On assistant the invariant + # ``is_content == sampled_mask`` holds. + emit_special(self._assistant, msg_idx, is_sampled=False, is_content=False) + emit_text("\n", msg_idx, is_sampled=False, is_content=False) if reasoning_content: - emit_special(self._think, msg_idx, is_sampled=True) - emit_text("\n" + reasoning_content.strip() + "\n", msg_idx, is_sampled=True) - emit_special(self._think_end, msg_idx, is_sampled=True) + emit_special(self._think, msg_idx, is_sampled=True, is_content=True) + emit_text( + "\n" + reasoning_content.strip() + "\n", + msg_idx, + is_sampled=True, + is_content=True, + ) + emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) else: - emit_special(self._think_end, msg_idx, is_sampled=True) + emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) # Combined newline-after- with optional content. Bundling # preserves BPE merges across the boundary. post_think_text = "\n" if content.strip(): post_think_text += content.strip() + "\n" - emit_text(post_think_text, msg_idx, is_sampled=True) + emit_text(post_think_text, msg_idx, is_sampled=True, is_content=True) tool_calls = msg.get("tool_calls") or [] for tc in tool_calls: @@ -409,7 +498,7 @@ def _render_assistant( except json.JSONDecodeError: arguments = {} - emit_special(self._tool_call, msg_idx, is_sampled=True) + emit_special(self._tool_call, msg_idx, is_sampled=True, is_content=True) inner = name + "\n" if isinstance(arguments, dict): for k, v in arguments.items(): @@ -419,13 +508,13 @@ def _render_assistant( else: val_text = json.dumps(v, ensure_ascii=False) inner += "" + val_text + "\n" - emit_text(inner, msg_idx, is_sampled=True) - emit_special(self._tool_call_end, msg_idx, is_sampled=True) - emit_text("\n", msg_idx, is_sampled=True) + emit_text(inner, msg_idx, is_sampled=True, is_content=True) + emit_special(self._tool_call_end, msg_idx, is_sampled=True, is_content=True) + emit_text("\n", msg_idx, is_sampled=True, is_content=True) # ```` is the model's stop signal (alongside # ``〈|EOS|〉``) — it samples this to end its turn, so it's part of # the sampled stream. The trailing ``\n`` is template-appended # between turns and never sampled. - emit_special(self._assistant_end, msg_idx, is_sampled=True) - emit_text("\n", msg_idx, is_sampled=False) + emit_special(self._assistant_end, msg_idx, is_sampled=True, is_content=True) + emit_text("\n", msg_idx, is_sampled=False, is_content=False) diff --git a/renderers/minimax_m2.py b/renderers/minimax_m2.py index 45357fd..f3c26c8 100644 --- a/renderers/minimax_m2.py +++ b/renderers/minimax_m2.py @@ -21,6 +21,7 @@ ParsedResponse, RenderedTokens, ToolSpec, + attribute_text_segments, reject_assistant_in_extension, should_preserve_past_thinking, trim_to_turn_close, @@ -119,17 +120,69 @@ def render( tokens: list[int] = [] indices: list[int] = [] sampled: list[bool] = [] + content_mask: list[bool] = [] - def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None: + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) + content_mask.append(is_content) - def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: ids = self._encode(text) tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool + ) -> None: + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + tokens.append(tok_id) + indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(is_content) + + def emit_token_overlap_body( + full_text: str, + body_start: int, + body_end: int, + msg_idx: int, + *, + is_sampled: bool, + ) -> None: + """Tokenize ``full_text`` and mark tokens that overlap the body + char span as ``is_content=True``. + + Differs from :func:`attribute_text_segments` only in the + boundary-token rule: a token straddling scaffold→body gets + ``True`` if any of its bytes are body bytes (overlap rule), + rather than being attributed to whichever segment its first + char belongs to. The body's first byte is preserved even when + BPE merges it with the wrap's trailing byte (``>The`` → + single token). + """ + from renderers.base import _get_offset_tokenizer + + offset_tok = _get_offset_tokenizer(self._tokenizer) + encoding = offset_tok( + full_text, add_special_tokens=False, return_offsets_mapping=True + ) + for tok_id, (start, end) in zip( + encoding["input_ids"], encoding["offset_mapping"] + ): + overlaps = start < body_end and end > body_start + tokens.append(tok_id) + indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(overlaps) # ── Extract system message ────────────────────────────────── first_is_system = messages[0].get("role") == "system" @@ -137,27 +190,38 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: conversation: list[Message] = messages[1:] if first_is_system else messages # ── System block (always present) ─────────────────────────── - emit_special(self._bos, sys_idx, is_sampled=False) - emit_special(self._role, sys_idx, is_sampled=False) + emit_special(self._bos, sys_idx, is_sampled=False, is_content=False) + emit_special(self._role, sys_idx, is_sampled=False, is_content=False) sys_content = ( self._visible_text(messages[0].get("content")) if first_is_system else "" ) - system_text = "system\n" + (sys_content or self._default_system) + # Body = caller's system content (if any). Default system message + # is template-injected scaffold; tools header / per-tool JSON / + # footer / instructions are scaffold too (the tools dict is + # recoverable from the ``tools`` arg). + sys_segments: list[tuple[str, bool]] = [("system\n", False)] + if sys_content: + sys_segments.append((sys_content, True)) + else: + sys_segments.append((self._default_system, False)) if tools: - system_text += _TOOLS_HEADER + sys_segments.append((_TOOLS_HEADER, False)) for tool in tools: func = tool.get("function", tool) - system_text += ( - "" + json.dumps(func, ensure_ascii=False) + "\n" + sys_segments.append( + ( + "" + json.dumps(func, ensure_ascii=False) + "\n", + False, + ) ) - system_text += _TOOLS_FOOTER_PREFIX - system_text += _TOOLS_INSTRUCTIONS + sys_segments.append((_TOOLS_FOOTER_PREFIX, False)) + sys_segments.append((_TOOLS_INSTRUCTIONS, False)) - emit_text(system_text, sys_idx, is_sampled=False) - emit_special(self._eos, sys_idx, is_sampled=False) - emit_text("\n", sys_idx, is_sampled=False) + emit_text_segments(sys_segments, sys_idx, is_sampled=False) + emit_special(self._eos, sys_idx, is_sampled=False, is_content=False) + emit_text("\n", sys_idx, is_sampled=False, is_content=False) # ── Compute last_user_index (relative to conversation) ────── last_ui = -1 @@ -172,14 +236,14 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: orig_idx = ci + (1 if first_is_system else 0) if role == "user": - emit_special(self._role, orig_idx, is_sampled=False) - emit_text( - "user\n" + self._visible_text(msg.get("content")), - orig_idx, - is_sampled=False, - ) - emit_special(self._eos, orig_idx, is_sampled=False) - emit_text("\n", orig_idx, is_sampled=False) + emit_special(self._role, orig_idx, is_sampled=False, is_content=False) + user_content = self._visible_text(msg.get("content")) + user_segments: list[tuple[str, bool]] = [("user\n", False)] + if user_content: + user_segments.append((user_content, True)) + emit_text_segments(user_segments, orig_idx, is_sampled=False) + emit_special(self._eos, orig_idx, is_sampled=False, is_content=False) + emit_text("\n", orig_idx, is_sampled=False, is_content=False) elif role == "assistant": preserve_thinking = should_preserve_past_thinking( @@ -196,6 +260,7 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: preserve_thinking=preserve_thinking, emit_special=emit_special, emit_text=emit_text, + emit_text_segments=emit_text_segments, ) elif role == "tool": @@ -206,19 +271,22 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: msg, emit_special=emit_special, emit_text=emit_text, + emit_text_segments=emit_text_segments, + emit_token_overlap_body=emit_token_overlap_body, ) # ── Generation prompt ─────────────────────────────────────── if add_generation_prompt: - emit_special(self._role, -1, is_sampled=False) - emit_text("ai\n", -1, is_sampled=False) - emit_special(self._think, -1, is_sampled=False) - emit_text("\n", -1, is_sampled=False) + emit_special(self._role, -1, is_sampled=False, is_content=False) + emit_text("ai\n", -1, is_sampled=False, is_content=False) + emit_special(self._think, -1, is_sampled=False, is_content=False) + emit_text("\n", -1, is_sampled=False, is_content=False) return RenderedTokens( token_ids=tokens, message_indices=indices, sampled_mask=sampled, + is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], ) @@ -282,27 +350,75 @@ def bridge_to_next_turn( ext: list[int] = [] ext_indices: list[int] = [] ext_sampled: list[bool] = [] + ext_content: list[bool] = [] # Bridge populates ``message_indices`` (relative to ``new_messages``) # and ``sampled_mask`` (uniformly ``False`` — every token the # bridge emits is template scaffolding for the next prompt, not - # something the model sampled). Downstream consumers can run - # :meth:`RenderedTokens.tokens_per_message` on the bridge output - # to get per-new-message token counts without re-rendering. + # something the model sampled). ``is_content`` follows the same + # rules as in :meth:`render` so consumers can walk the trajectory + # and read each step's own body mask. def emit_special( - token_id: int, msg_idx: int = -1, *, is_sampled: bool = False + token_id: int, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ext.append(token_id) ext_indices.append(msg_idx) ext_sampled.append(is_sampled) + ext_content.append(is_content) def emit_text( - text: str, msg_idx: int = -1, *, is_sampled: bool = False + text: str, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ids = self._encode(text) ext.extend(ids) ext_indices.extend([msg_idx] * len(ids)) ext_sampled.extend([is_sampled] * len(ids)) + ext_content.extend([is_content] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], + msg_idx: int = -1, + *, + is_sampled: bool = False, + ) -> None: + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + ext.append(tok_id) + ext_indices.append(msg_idx) + ext_sampled.append(is_sampled) + ext_content.append(is_content) + + def emit_token_overlap_body( + full_text: str, + body_start: int, + body_end: int, + msg_idx: int, + *, + is_sampled: bool, + ) -> None: + from renderers.base import _get_offset_tokenizer + + offset_tok = _get_offset_tokenizer(self._tokenizer) + encoding = offset_tok( + full_text, add_special_tokens=False, return_offsets_mapping=True + ) + for tok_id, (start, end) in zip( + encoding["input_ids"], encoding["offset_mapping"] + ): + overlaps = start < body_end and end > body_start + ext.append(tok_id) + ext_indices.append(msg_idx) + ext_sampled.append(is_sampled) + ext_content.append(overlaps) # Trailing ``\n`` after the ``[e~[`` turn close — see ``render()``. emit_text("\n", -1) @@ -312,12 +428,18 @@ def emit_text( content = self._visible_text(msg.get("content")) if role == "user": emit_special(self._role, i) - emit_text("user\n" + content, i) + user_segments: list[tuple[str, bool]] = [("user\n", False)] + if content: + user_segments.append((content, True)) + emit_text_segments(user_segments, i) emit_special(self._eos, i) emit_text("\n", i) elif role == "system": emit_special(self._role, i) - emit_text("system\n" + content, i) + sys_segments: list[tuple[str, bool]] = [("system\n", False)] + if content: + sys_segments.append((content, True)) + emit_text_segments(sys_segments, i) emit_special(self._eos, i) emit_text("\n", i) elif role == "tool": @@ -328,6 +450,8 @@ def emit_text( msg, emit_special=emit_special, emit_text=emit_text, + emit_text_segments=emit_text_segments, + emit_token_overlap_body=emit_token_overlap_body, ) else: return None @@ -343,6 +467,7 @@ def emit_text( token_ids=previous_ids + ext, message_indices=[-1] * len(previous_ids) + ext_indices, sampled_mask=[False] * total_len, + is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], ) @@ -356,6 +481,7 @@ def _render_assistant( preserve_thinking: bool = False, emit_special, emit_text, + emit_text_segments, ): content = self._visible_text(msg.get("content")) @@ -374,12 +500,14 @@ def _render_assistant( # the chat template emits these as the generation prompt and the # model never samples them. Marking the role marker and tag as # ``is_sampled=False`` keeps the SFT loss mask aligned with what - # the model would actually have produced. - emit_special(self._role, orig_idx, is_sampled=False) + # the model would actually have produced. ``is_content`` is also + # False here — the role tag isn't part of any message's body. + emit_special(self._role, orig_idx, is_sampled=False, is_content=False) # Build the model-sampled portion (think block + content + tool - # calls). Text segments stay contiguous within each is_sampled - # span to preserve BPE merges. + # calls). For assistant messages the invariant + # ``is_content == sampled_mask`` holds — every sampled token is + # body, every scaffold token isn't. tool_calls = msg.get("tool_calls") or [] emit_thinking = reasoning_content and ( conv_idx > last_user_index or preserve_thinking @@ -390,10 +518,15 @@ def _render_assistant( # immediately after ``ai\n``, which forces a tokenizer # boundary — splitting ``ai\n`` (not_sampled) from the # ````-led body (sampled) is BPE-safe. - emit_text("ai\n", orig_idx, is_sampled=False) - emit_special(self._think, orig_idx, is_sampled=True) - emit_text("\n" + reasoning_content + "\n", orig_idx, is_sampled=True) - emit_special(self._think_end, orig_idx, is_sampled=True) + emit_text("ai\n", orig_idx, is_sampled=False, is_content=False) + emit_special(self._think, orig_idx, is_sampled=True, is_content=True) + emit_text( + "\n" + reasoning_content + "\n", + orig_idx, + is_sampled=True, + is_content=True, + ) + emit_special(self._think_end, orig_idx, is_sampled=True, is_content=True) # \n\n + content must be contiguous for BPE body = "\n\n" + content if content else "\n\n" else: @@ -406,17 +539,19 @@ def _render_assistant( # choice for SFT (don't train on a token whose first byte # is template scaffolding). if tool_calls and not body: - emit_text("ai\n\n", orig_idx, is_sampled=False) + emit_text("ai\n\n", orig_idx, is_sampled=False, is_content=False) else: - emit_text("ai\n", orig_idx, is_sampled=False) + emit_text("ai\n", orig_idx, is_sampled=False, is_content=False) if tool_calls: # \n before must be contiguous with preceding text. # The empty-body / non-thinking case folded the leading \n # into the role-tag emission above; skip it here. if emit_thinking or body: - emit_text(body + "\n", orig_idx, is_sampled=True) - emit_special(self._tool_call_tok, orig_idx, is_sampled=True) + emit_text(body + "\n", orig_idx, is_sampled=True, is_content=True) + emit_special( + self._tool_call_tok, orig_idx, is_sampled=True, is_content=True + ) invoke_block = "\n" for tc in tool_calls: @@ -448,16 +583,18 @@ def _render_assistant( ) invoke_block += "\n" - emit_text(invoke_block, orig_idx, is_sampled=True) - emit_special(self._tool_call_end_tok, orig_idx, is_sampled=True) + emit_text(invoke_block, orig_idx, is_sampled=True, is_content=True) + emit_special( + self._tool_call_end_tok, orig_idx, is_sampled=True, is_content=True + ) elif body: - emit_text(body, orig_idx, is_sampled=True) + emit_text(body, orig_idx, is_sampled=True, is_content=True) # ``[e~[`` is the model's stop signal — it samples this to end # its turn, so it is part of the sampled stream. The trailing # ``\n`` is template-appended between turns and never sampled. - emit_special(self._eos, orig_idx, is_sampled=True) - emit_text("\n", orig_idx, is_sampled=False) + emit_special(self._eos, orig_idx, is_sampled=True, is_content=True) + emit_text("\n", orig_idx, is_sampled=False, is_content=False) def _render_tool( self, @@ -468,10 +605,15 @@ def _render_tool( *, emit_special, emit_text, + emit_text_segments, + emit_token_overlap_body=None, ) -> None: # Tool messages are conversation history injected by the runtime # between assistant turns — the model never samples any of these - # tokens, so every emission is is_sampled=False. + # tokens, so every emission is is_sampled=False. The ``content`` + # body bytes get ``is_content=True``; the surrounding ```` + # wrap, role tag and separators are scaffold so an SFT mask over + # tool body never trains the model to emit those. prev_is_tool = conv_idx > 0 and conversation[conv_idx - 1]["role"] == "tool" next_is_tool = ( conv_idx + 1 < len(conversation) @@ -479,8 +621,8 @@ def _render_tool( ) if not prev_is_tool: - emit_special(self._role, orig_idx, is_sampled=False) - emit_text("tool", orig_idx, is_sampled=False) + emit_special(self._role, orig_idx, is_sampled=False, is_content=False) + emit_text("tool", orig_idx, is_sampled=False, is_content=False) content = self._visible_text(msg.get("content")) # Leading ``\n`` before ```` only on the first of a @@ -489,12 +631,39 @@ def _render_tool( # through a single emit_text call instead of splitting the merge. prefix = "" if prev_is_tool else "\n" suffix = "\n" if next_is_tool else "" - emit_text( - prefix + "" + content + "" + suffix, - orig_idx, - is_sampled=False, - ) + + # ```` is plain text with no separator between the + # closing ``>`` and ``content``'s first byte, so BPE can merge + # them into a single token (e.g., ``>The``). The shared + # ``attribute_text_segments`` helper picks the segment of a + # boundary-spanning token by its *first* char (here scaffold), + # which would drop the body's leading letter out of the body + # run. We instead use an "intersects body" rule: any token whose + # ``[start, end)`` char range overlaps the body span gets + # ``is_content=True``. A few scaffold bytes (the leading ``>`` + # or trailing ``<``) bleed into the body run, but body bytes are + # recoverable as a substring of the decoded body span. + body_text = prefix + "" + content + "" + suffix + body_start = len(prefix) + len("") + body_end = body_start + len(content) + if content and emit_token_overlap_body is not None: + emit_token_overlap_body( + body_text, body_start, body_end, orig_idx, is_sampled=False + ) + else: + # Empty body or no overlap-aware emitter available — fall back + # to the standard segments path. + tool_segments: list[tuple[str, bool]] = [] + if prefix: + tool_segments.append((prefix, False)) + tool_segments.append(("", False)) + if content: + tool_segments.append((content, True)) + tool_segments.append(("", False)) + if suffix: + tool_segments.append((suffix, False)) + emit_text_segments(tool_segments, orig_idx, is_sampled=False) if not next_is_tool: - emit_special(self._eos, orig_idx, is_sampled=False) - emit_text("\n", orig_idx, is_sampled=False) + emit_special(self._eos, orig_idx, is_sampled=False, is_content=False) + emit_text("\n", orig_idx, is_sampled=False, is_content=False) diff --git a/renderers/nemotron3.py b/renderers/nemotron3.py index e17e7ee..e97790d 100644 --- a/renderers/nemotron3.py +++ b/renderers/nemotron3.py @@ -24,6 +24,7 @@ ParsedResponse, RenderedTokens, ToolSpec, + attribute_text_segments, reject_assistant_in_extension, should_preserve_past_thinking, trim_to_turn_close, @@ -241,19 +242,44 @@ def orig_idx(i: int) -> int: tokens: list[int] = [] indices: list[int] = [] sampled: list[bool] = [] + content_mask: list[bool] = [] - def emit_ids(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None: + def emit_ids( + ids: list[int], msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) - def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None: + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) + content_mask.append(is_content) + + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: + emit_ids( + self._encode(text), + msg_idx, + is_sampled=is_sampled, + is_content=is_content, + ) - def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: - emit_ids(self._encode(text), msg_idx, is_sampled=is_sampled) + def emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool + ) -> None: + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + tokens.append(tok_id) + indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(is_content) # ── 1. System message + optional tools ────────────────────── first_is_system = messages[0].get("role") == "system" @@ -262,8 +288,7 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: # Nemotron 3: system prompt BEFORE tools block sys_idx = orig_idx(0) if first_is_system else -1 - emit_special(self._im_start, sys_idx, is_sampled=False) - emit_text("system\n", sys_idx, is_sampled=False) + emit_special(self._im_start, sys_idx, is_sampled=False, is_content=False) # Build system content: user's system text first, then tools if first_is_system: @@ -282,22 +307,27 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: + _TOOLS_INSTRUCTIONS ) + # Body = caller's system text only; tools block (header, per- + # tool XML, footer, instructions) is scaffold. + sys_segments: list[tuple[str, bool]] = [("system\n", False)] if sys_content: - full_sys = sys_content + "\n\n" + tools_block - else: - full_sys = tools_block - - emit_text(full_sys, sys_idx, is_sampled=False) - emit_special(self._im_end, sys_idx, is_sampled=False) - emit_text("\n", sys_idx, is_sampled=False) + sys_segments.append((sys_content, True)) + sys_segments.append(("\n\n", False)) + sys_segments.append((tools_block, False)) + emit_text_segments(sys_segments, sys_idx, is_sampled=False) + emit_special(self._im_end, sys_idx, is_sampled=False, is_content=False) + emit_text("\n", sys_idx, is_sampled=False, is_content=False) elif first_is_system: sys_idx = orig_idx(0) sys_content = self._render_content(messages[0].get("content")).strip() - emit_special(self._im_start, sys_idx, is_sampled=False) - emit_text("system\n" + sys_content, sys_idx, is_sampled=False) - emit_special(self._im_end, sys_idx, is_sampled=False) - emit_text("\n", sys_idx, is_sampled=False) + emit_special(self._im_start, sys_idx, is_sampled=False, is_content=False) + sys_segments2: list[tuple[str, bool]] = [("system\n", False)] + if sys_content: + sys_segments2.append((sys_content, True)) + emit_text_segments(sys_segments2, sys_idx, is_sampled=False) + emit_special(self._im_end, sys_idx, is_sampled=False, is_content=False) + emit_text("\n", sys_idx, is_sampled=False, is_content=False) # Track the most-recent plain (non-tool-call) assistant so we can # preserve its reasoning while stripping reasoning from earlier @@ -322,10 +352,17 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: continue # Already handled above elif role == "user": - emit_special(self._im_start, msg_orig_idx, is_sampled=False) - emit_text("user\n" + content, msg_orig_idx, is_sampled=False) - emit_special(self._im_end, msg_orig_idx, is_sampled=False) - emit_text("\n", msg_orig_idx, is_sampled=False) + emit_special( + self._im_start, msg_orig_idx, is_sampled=False, is_content=False + ) + user_segments: list[tuple[str, bool]] = [("user\n", False)] + if content: + user_segments.append((content, True)) + emit_text_segments(user_segments, msg_orig_idx, is_sampled=False) + emit_special( + self._im_end, msg_orig_idx, is_sampled=False, is_content=False + ) + emit_text("\n", msg_orig_idx, is_sampled=False, is_content=False) elif role == "assistant": is_last_turn = i >= last_plain_assistant_idx @@ -344,6 +381,7 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: emit_special=emit_special, emit_text=emit_text, emit_ids=emit_ids, + emit_text_segments=emit_text_segments, ) elif role == "tool": @@ -355,6 +393,7 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: auto_system_injected=auto_system_injected, emit_special=emit_special, emit_text=emit_text, + emit_text_segments=emit_text_segments, ) else: @@ -362,21 +401,22 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: # ── 3. Generation prompt ──────────────────────────────────── if add_generation_prompt: - emit_special(self._im_start, -1, is_sampled=False) - emit_text("assistant\n", -1, is_sampled=False) + emit_special(self._im_start, -1, is_sampled=False, is_content=False) + emit_text("assistant\n", -1, is_sampled=False, is_content=False) if self._enable_thinking: - emit_special(self._think, -1, is_sampled=False) - emit_text("\n", -1, is_sampled=False) + emit_special(self._think, -1, is_sampled=False, is_content=False) + emit_text("\n", -1, is_sampled=False, is_content=False) else: # Disable-thinking suffix: with no trailing newlines - emit_special(self._think, -1, is_sampled=False) - emit_special(self._think_end, -1, is_sampled=False) + emit_special(self._think, -1, is_sampled=False, is_content=False) + emit_special(self._think_end, -1, is_sampled=False, is_content=False) return RenderedTokens( token_ids=tokens, message_indices=indices, sampled_mask=sampled, - message_roles=[m.get("role") or "" for m in messages], + is_content=content_mask, + message_roles=[m.get("role") or "" for m in original_messages], ) def render_ids( @@ -447,27 +487,52 @@ def bridge_to_next_turn( ext: list[int] = [] ext_indices: list[int] = [] ext_sampled: list[bool] = [] + ext_content: list[bool] = [] # Bridge populates ``message_indices`` (relative to ``new_messages``) # and ``sampled_mask`` (uniformly ``False`` — every token the # bridge emits is template scaffolding for the next prompt, not - # something the model sampled). Downstream consumers can run - # :meth:`RenderedTokens.tokens_per_message` on the bridge output - # to get per-new-message token counts without re-rendering. + # something the model sampled). ``is_content`` follows the same + # rules as in :meth:`render` so consumers can walk the trajectory + # and read each step's own body mask. def emit_special( - token_id: int, msg_idx: int = -1, *, is_sampled: bool = False + token_id: int, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ext.append(token_id) ext_indices.append(msg_idx) ext_sampled.append(is_sampled) + ext_content.append(is_content) def emit_text( - text: str, msg_idx: int = -1, *, is_sampled: bool = False + text: str, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ids = self._encode(text) ext.extend(ids) ext_indices.extend([msg_idx] * len(ids)) ext_sampled.extend([is_sampled] * len(ids)) + ext_content.extend([is_content] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], + msg_idx: int = -1, + *, + is_sampled: bool = False, + ) -> None: + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + ext.append(tok_id) + ext_indices.append(msg_idx) + ext_sampled.append(is_sampled) + ext_content.append(is_content) emit_text("\n", -1) @@ -476,12 +541,18 @@ def emit_text( content = self._render_content(msg.get("content")).strip() if role == "user": emit_special(self._im_start, i) - emit_text("user\n" + content, i) + user_segments: list[tuple[str, bool]] = [("user\n", False)] + if content: + user_segments.append((content, True)) + emit_text_segments(user_segments, i) emit_special(self._im_end, i) emit_text("\n", i) elif role == "system": emit_special(self._im_start, i) - emit_text("system\n" + content, i) + sys_segments: list[tuple[str, bool]] = [("system\n", False)] + if content: + sys_segments.append((content, True)) + emit_text_segments(sys_segments, i) emit_special(self._im_end, i) emit_text("\n", i) elif role == "tool": @@ -493,6 +564,7 @@ def emit_text( auto_system_injected=False, emit_special=emit_special, emit_text=emit_text, + emit_text_segments=emit_text_segments, ) else: return None @@ -512,6 +584,7 @@ def emit_text( token_ids=previous_ids + ext, message_indices=[-1] * len(previous_ids) + ext_indices, sampled_mask=[False] * total_len, + is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], ) @@ -530,6 +603,7 @@ def _render_assistant( emit_special, emit_text, emit_ids, + emit_text_segments, ) -> None: # Extract reasoning_content reasoning_content = "" @@ -550,9 +624,10 @@ def _render_assistant( # at inference the chat template emits these as the generation # prompt and the model never samples them. Marking the role tag # as ``is_sampled=False`` keeps the SFT loss mask aligned with - # what the model would actually have produced. - emit_special(self._im_start, msg_idx, is_sampled=False) - emit_text("assistant\n", msg_idx, is_sampled=False) + # what the model would actually have produced. On assistant the + # invariant ``is_content == sampled_mask`` holds. + emit_special(self._im_start, msg_idx, is_sampled=False, is_content=False) + emit_text("assistant\n", msg_idx, is_sampled=False, is_content=False) # Nemotron 3 keeps reasoning on the most-recent plain assistant but # strips it from historical turns, which collapse to an empty @@ -567,23 +642,43 @@ def _render_assistant( content_suffix = "\n" if tool_calls else "" if reasoning_content and (is_last_turn or preserve_thinking): - emit_special(self._think, msg_idx, is_sampled=True) - emit_text("\n" + reasoning_content + "\n", msg_idx, is_sampled=True) - emit_special(self._think_end, msg_idx, is_sampled=True) + emit_special(self._think, msg_idx, is_sampled=True, is_content=True) + emit_text( + "\n" + reasoning_content + "\n", + msg_idx, + is_sampled=True, + is_content=True, + ) + emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) # Single \n separator (not \n\n like Qwen3.5) - emit_text("\n" + content + content_suffix, msg_idx, is_sampled=True) + emit_text( + "\n" + content + content_suffix, + msg_idx, + is_sampled=True, + is_content=True, + ) elif reasoning_content: # Historical assistant whose reasoning got stripped — template # keeps a single \n between the collapsed and # the content as a marker that reasoning existed. - emit_special(self._think, msg_idx, is_sampled=True) - emit_special(self._think_end, msg_idx, is_sampled=True) - emit_text("\n" + content + content_suffix, msg_idx, is_sampled=True) + emit_special(self._think, msg_idx, is_sampled=True, is_content=True) + emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) + emit_text( + "\n" + content + content_suffix, + msg_idx, + is_sampled=True, + is_content=True, + ) else: # No reasoning ever — glued directly to content. - emit_special(self._think, msg_idx, is_sampled=True) - emit_special(self._think_end, msg_idx, is_sampled=True) - emit_text(content + content_suffix, msg_idx, is_sampled=True) + emit_special(self._think, msg_idx, is_sampled=True, is_content=True) + emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) + emit_text( + content + content_suffix, + msg_idx, + is_sampled=True, + is_content=True, + ) # Tool calls (leading \n was glued to the content above; each # iteration's trailing \n after handles the @@ -594,8 +689,13 @@ def _render_assistant( name = func.get("name", "") arguments = func.get("arguments", {}) - emit_special(self._tool_call, msg_idx, is_sampled=True) - emit_text("\n\n", msg_idx, is_sampled=True) + emit_special(self._tool_call, msg_idx, is_sampled=True, is_content=True) + emit_text( + "\n\n", + msg_idx, + is_sampled=True, + is_content=True, + ) # Render arguments # OpenAI canonical form: arguments is a JSON string. Parse it so the @@ -619,18 +719,21 @@ def _render_assistant( + "\n\n", msg_idx, is_sampled=True, + is_content=True, ) - emit_text("\n", msg_idx, is_sampled=True) - emit_special(self._tool_call_end, msg_idx, is_sampled=True) + emit_text("\n", msg_idx, is_sampled=True, is_content=True) + emit_special( + self._tool_call_end, msg_idx, is_sampled=True, is_content=True + ) # Trailing \n after (Nemotron 3 specific) - emit_text("\n", msg_idx, is_sampled=True) + emit_text("\n", msg_idx, is_sampled=True, is_content=True) # ``<|im_end|>`` is the model's stop signal — it samples this to # end its turn, so it is part of the sampled stream. The trailing # ``\n`` is template-appended between turns and never sampled. - emit_special(self._im_end, msg_idx, is_sampled=True) - emit_text("\n", msg_idx, is_sampled=False) + emit_special(self._im_end, msg_idx, is_sampled=True, is_content=True) + emit_text("\n", msg_idx, is_sampled=False, is_content=False) # ------------------------------------------------------------------ # Tool message rendering @@ -646,10 +749,13 @@ def _render_tool( auto_system_injected: bool, emit_special, emit_text, + emit_text_segments, ) -> None: # Tool messages are conversation history injected by the runtime # between assistant turns — the model never samples any of these - # tokens, so every emission is is_sampled=False. + # tokens, so every emission is is_sampled=False. The ``content`` + # body bytes get ``is_content=True``; the surrounding wrap is + # scaffold. prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool" next_is_tool = ( msg_idx + 1 < len(messages) and messages[msg_idx + 1]["role"] == "tool" @@ -657,17 +763,19 @@ def _render_tool( oi = msg_orig_idx if not prev_is_tool: - emit_special(self._im_start, oi, is_sampled=False) - emit_text("user\n", oi, is_sampled=False) + emit_special(self._im_start, oi, is_sampled=False, is_content=False) + emit_text("user\n", oi, is_sampled=False, is_content=False) # else: the previous tool's trailing \n already provides the # separator into this block. - emit_special(self._tool_response, oi, is_sampled=False) - emit_text("\n" + content + "\n", oi, is_sampled=False) - emit_special(self._tool_response_end, oi, is_sampled=False) + emit_special(self._tool_response, oi, is_sampled=False, is_content=False) + emit_text_segments( + [("\n", False), (content, True), ("\n", False)], oi, is_sampled=False + ) + emit_special(self._tool_response_end, oi, is_sampled=False, is_content=False) # Nemotron 3: trailing \n after - emit_text("\n", oi, is_sampled=False) + emit_text("\n", oi, is_sampled=False, is_content=False) if not next_is_tool: - emit_special(self._im_end, oi, is_sampled=False) - emit_text("\n", oi, is_sampled=False) + emit_special(self._im_end, oi, is_sampled=False, is_content=False) + emit_text("\n", oi, is_sampled=False, is_content=False) diff --git a/renderers/qwen3.py b/renderers/qwen3.py index 8f5f17a..4562546 100644 --- a/renderers/qwen3.py +++ b/renderers/qwen3.py @@ -18,6 +18,7 @@ ParsedResponse, RenderedTokens, ToolSpec, + attribute_text_segments, reject_assistant_in_extension, should_preserve_past_thinking, trim_to_turn_close, @@ -108,43 +109,76 @@ def render( tokens: list[int] = [] indices: list[int] = [] sampled: list[bool] = [] + content_mask: list[bool] = [] - def emit_ids(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None: - tokens.extend(ids) - indices.extend([msg_idx] * len(ids)) - sampled.extend([is_sampled] * len(ids)) - - def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None: + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) + content_mask.append(is_content) - def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: - emit_ids(self._encode(text), msg_idx, is_sampled=is_sampled) + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: + ids = self._encode(text) + tokens.extend(ids) + indices.extend([msg_idx] * len(ids)) + sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool + ) -> None: + """Tokenize concatenated segments as one BPE pass; per-token + ``is_content`` follows each token's source segment. + + Lets call sites express "this wrap + this body, joined the + same way as the chat template, but attributed separately" + without splitting the encode call (which could shift BPE + merges at the boundary).""" + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + tokens.append(tok_id) + indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(is_content) # ── 1. System + tools ─────────────────────────────────────── first_is_system = messages[0].get("role") == "system" if tools: sys_idx = 0 if first_is_system else -1 - emit_special(self._im_start, sys_idx, is_sampled=False) - tool_text = "system\n" + emit_special(self._im_start, sys_idx, is_sampled=False, is_content=False) + # Body = system content (if any). Everything else in this + # block — role tag, tools header / footer, the JSON tool + # specs — is scaffold. The tools dict is recoverable from + # the ``tools`` argument; don't re-attribute its embedded + # JSON as message body. + segments: list[tuple[str, bool]] = [("system\n", False)] if first_is_system: - tool_text += (messages[0].get("content") or "") + "\n\n" - tool_text += _TOOLS_HEADER + sys_content = messages[0].get("content") or "" + if sys_content: + segments.append((sys_content, True)) + segments.append(("\n\n", False)) + segments.append((_TOOLS_HEADER, False)) for tool in tools: - tool_text += "\n" + json.dumps(tool, ensure_ascii=False) - tool_text += _TOOLS_FOOTER - emit_text(tool_text, sys_idx, is_sampled=False) - emit_special(self._im_end, sys_idx, is_sampled=False) - emit_text("\n", sys_idx, is_sampled=False) + segments.append(("\n" + json.dumps(tool, ensure_ascii=False), False)) + segments.append((_TOOLS_FOOTER, False)) + emit_text_segments(segments, sys_idx, is_sampled=False) + emit_special(self._im_end, sys_idx, is_sampled=False, is_content=False) + emit_text("\n", sys_idx, is_sampled=False, is_content=False) elif first_is_system: - emit_special(self._im_start, 0, is_sampled=False) - emit_text( - "system\n" + (messages[0].get("content") or ""), 0, is_sampled=False - ) - emit_special(self._im_end, 0, is_sampled=False) - emit_text("\n", 0, is_sampled=False) + emit_special(self._im_start, 0, is_sampled=False, is_content=False) + sys_content = messages[0].get("content") or "" + sys_segments: list[tuple[str, bool]] = [("system\n", False)] + if sys_content: + sys_segments.append((sys_content, True)) + emit_text_segments(sys_segments, 0, is_sampled=False) + emit_special(self._im_end, 0, is_sampled=False, is_content=False) + emit_text("\n", 0, is_sampled=False, is_content=False) # ── 2. Compute last_query_index ───────────────────────────── last_qi = self._last_query_index(messages) @@ -158,16 +192,22 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: if role == "system": if i == 0: continue - emit_special(self._im_start, i, is_sampled=False) - emit_text(role + "\n" + content, i, is_sampled=False) - emit_special(self._im_end, i, is_sampled=False) - emit_text("\n", i, is_sampled=False) + emit_special(self._im_start, i, is_sampled=False, is_content=False) + msg_segments: list[tuple[str, bool]] = [(role + "\n", False)] + if content: + msg_segments.append((content, True)) + emit_text_segments(msg_segments, i, is_sampled=False) + emit_special(self._im_end, i, is_sampled=False, is_content=False) + emit_text("\n", i, is_sampled=False, is_content=False) elif role == "user": - emit_special(self._im_start, i, is_sampled=False) - emit_text("user\n" + content, i, is_sampled=False) - emit_special(self._im_end, i, is_sampled=False) - emit_text("\n", i, is_sampled=False) + emit_special(self._im_start, i, is_sampled=False, is_content=False) + user_segments: list[tuple[str, bool]] = [("user\n", False)] + if content: + user_segments.append((content, True)) + emit_text_segments(user_segments, i, is_sampled=False) + emit_special(self._im_end, i, is_sampled=False, is_content=False) + emit_text("\n", i, is_sampled=False, is_content=False) elif role == "assistant": preserve_thinking = should_preserve_past_thinking( @@ -185,24 +225,33 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: preserve_thinking=preserve_thinking, emit_special=emit_special, emit_text=emit_text, + emit_text_segments=emit_text_segments, ) elif role == "tool": self._render_tool( - messages, i, content, emit_special=emit_special, emit_text=emit_text + messages, + i, + content, + emit_special=emit_special, + emit_text=emit_text, + emit_text_segments=emit_text_segments, ) # ── 4. Generation prompt ──────────────────────────────────── if add_generation_prompt: - emit_special(self._im_start, -1, is_sampled=False) - emit_text("assistant\n", -1, is_sampled=False) + emit_special(self._im_start, -1, is_sampled=False, is_content=False) + emit_text("assistant\n", -1, is_sampled=False, is_content=False) if not self._enable_thinking: - emit_text("\n\n\n\n", -1, is_sampled=False) + emit_text( + "\n\n\n\n", -1, is_sampled=False, is_content=False + ) return RenderedTokens( token_ids=tokens, message_indices=indices, sampled_mask=sampled, + is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], ) @@ -263,27 +312,54 @@ def bridge_to_next_turn( ext: list[int] = [] ext_indices: list[int] = [] ext_sampled: list[bool] = [] + ext_content: list[bool] = [] # Bridge populates ``message_indices`` (relative to ``new_messages``) # and ``sampled_mask`` (uniformly ``False`` — every token the # bridge emits is template scaffolding for the next prompt, not - # something the model sampled). Downstream consumers can run - # :meth:`RenderedTokens.tokens_per_message` on the bridge output - # to get per-new-message token counts without re-rendering. + # something the model sampled). ``is_content`` follows the same + # rules as in :meth:`render` so consumers can walk the trajectory + # and read each step's own body mask. Downstream consumers can + # run :meth:`RenderedTokens.tokens_per_message` on the bridge + # output to get per-new-message token counts without re-rendering. def emit_special( - token_id: int, msg_idx: int = -1, *, is_sampled: bool = False + token_id: int, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ext.append(token_id) ext_indices.append(msg_idx) ext_sampled.append(is_sampled) + ext_content.append(is_content) def emit_text( - text: str, msg_idx: int = -1, *, is_sampled: bool = False + text: str, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ids = self._encode(text) ext.extend(ids) ext_indices.extend([msg_idx] * len(ids)) ext_sampled.extend([is_sampled] * len(ids)) + ext_content.extend([is_content] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], + msg_idx: int = -1, + *, + is_sampled: bool = False, + ) -> None: + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + ext.append(tok_id) + ext_indices.append(msg_idx) + ext_sampled.append(is_sampled) + ext_content.append(is_content) # Trailing ``\n`` after the turn-close token. ``render()`` emits this # as part of the prior turn, but vLLM stops on ``<|im_end|>`` so the @@ -295,12 +371,18 @@ def emit_text( content = msg.get("content") if isinstance(msg.get("content"), str) else "" if role == "user": emit_special(self._im_start, i) - emit_text("user\n" + content, i) + user_segments: list[tuple[str, bool]] = [("user\n", False)] + if content: + user_segments.append((content, True)) + emit_text_segments(user_segments, i) emit_special(self._im_end, i) emit_text("\n", i) elif role == "system": emit_special(self._im_start, i) - emit_text("system\n" + content, i) + sys_segments: list[tuple[str, bool]] = [("system\n", False)] + if content: + sys_segments.append((content, True)) + emit_text_segments(sys_segments, i) emit_special(self._im_end, i) emit_text("\n", i) elif role == "tool": @@ -310,6 +392,7 @@ def emit_text( content, emit_special=emit_special, emit_text=emit_text, + emit_text_segments=emit_text_segments, ) else: return None @@ -324,6 +407,7 @@ def emit_text( token_ids=previous_ids + ext, message_indices=[-1] * len(previous_ids) + ext_indices, sampled_mask=[False] * total_len, + is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], ) @@ -338,6 +422,7 @@ def _render_assistant( preserve_thinking: bool = False, emit_special, emit_text, + emit_text_segments, ): reasoning_content = "" if isinstance(msg.get("reasoning_content"), str): @@ -355,16 +440,20 @@ def _render_assistant( # at inference the chat template emits these as the generation # prompt and the model never samples them. Marking the role tag # as ``is_sampled=False`` keeps the SFT loss mask aligned with - # what the model would actually have produced. - emit_special(self._im_start, msg_idx, is_sampled=False) - emit_text("assistant\n", msg_idx, is_sampled=False) + # what the model would actually have produced. ``is_content`` is + # also False here — the role tag isn't part of any message's + # body, on any role. + emit_special(self._im_start, msg_idx, is_sampled=False, is_content=False) + emit_text("assistant\n", msg_idx, is_sampled=False, is_content=False) # Build the model-sampled portion (think block + content + tool # calls). Text segments stay contiguous within each is_sampled # span to preserve BPE merges (e.g., ".\n" is a single token in # Qwen3); the only split we introduce here is at ``\n`` after the # role tag, which the existing renderer already treats as a - # token boundary (cf. ``_render_tool``). + # token boundary (cf. ``_render_tool``). For assistant messages + # the invariant ``is_content == sampled_mask`` holds — every + # sampled token is body, every scaffold token isn't. tool_calls = msg.get("tool_calls") or [] emit_in_template_window = msg_idx > last_query_index and ( @@ -382,7 +471,7 @@ def _render_assistant( body = content if not tool_calls: - emit_text(body, msg_idx, is_sampled=True) + emit_text(body, msg_idx, is_sampled=True, is_content=True) else: for tc_idx, tc in enumerate(tool_calls): func = tc.get("function") or tc @@ -397,23 +486,29 @@ def _render_assistant( # Text before this tool_call (includes separator) if tc_idx == 0: separator = "\n" if content else "" - emit_text(body + separator, msg_idx, is_sampled=True) + emit_text( + body + separator, msg_idx, is_sampled=True, is_content=True + ) else: - emit_text("\n", msg_idx, is_sampled=True) + emit_text("\n", msg_idx, is_sampled=True, is_content=True) - emit_special(self._tool_call, msg_idx, is_sampled=True) + emit_special(self._tool_call, msg_idx, is_sampled=True, is_content=True) emit_text( '\n{"name": "' + name + '", "arguments": ' + args_str + "}\n", msg_idx, is_sampled=True, + is_content=True, + ) + emit_special( + self._tool_call_end, msg_idx, is_sampled=True, is_content=True ) - emit_special(self._tool_call_end, msg_idx, is_sampled=True) # ``<|im_end|>`` is the model's stop signal — it samples this to - # end its turn, so it is part of the sampled stream. The trailing - # ``\n`` is template-appended between turns and never sampled. - emit_special(self._im_end, msg_idx, is_sampled=True) - emit_text("\n", msg_idx, is_sampled=False) + # end its turn, so it is part of the sampled stream (and the + # assistant's body). The trailing ``\n`` is template-appended + # between turns and never sampled — scaffold for is_content too. + emit_special(self._im_end, msg_idx, is_sampled=True, is_content=True) + emit_text("\n", msg_idx, is_sampled=False, is_content=False) def _render_tool( self, @@ -423,24 +518,40 @@ def _render_tool( *, emit_special, emit_text, + emit_text_segments, ) -> None: # Tool messages are conversation history injected by the runtime # between assistant turns — the model never samples any of these - # tokens, so every emission is is_sampled=False. + # tokens, so every emission is is_sampled=False. The ``content`` + # field's body bytes get ``is_content=True``; everything else — + # the ``<|im_start|>user`` wrap, the inter-section ``\n``s, the + # ``<|tool_response>`` specials — is scaffold so the SFT mask + # for tool body never trains the model to emit them. prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool" next_is_tool = ( msg_idx + 1 < len(messages) and messages[msg_idx + 1]["role"] == "tool" ) if not prev_is_tool: - emit_special(self._im_start, msg_idx, is_sampled=False) - emit_text("user", msg_idx, is_sampled=False) - - emit_text("\n", msg_idx, is_sampled=False) - emit_special(self._tool_response, msg_idx, is_sampled=False) - emit_text("\n" + content + "\n", msg_idx, is_sampled=False) - emit_special(self._tool_response_end, msg_idx, is_sampled=False) + emit_special(self._im_start, msg_idx, is_sampled=False, is_content=False) + emit_text("user", msg_idx, is_sampled=False, is_content=False) + + emit_text("\n", msg_idx, is_sampled=False, is_content=False) + emit_special(self._tool_response, msg_idx, is_sampled=False, is_content=False) + # ``\n`` + content + ``\n`` — body is the middle segment only. + # Single BPE pass over the joined text preserves boundary + # merges (Qwen3 keeps ``\n`` as its own token, so this is + # mostly a no-op, but we route through segments anyway so the + # attribution doesn't depend on tokenizer-specific behaviour). + emit_text_segments( + [("\n", False), (content, True), ("\n", False)], + msg_idx, + is_sampled=False, + ) + emit_special( + self._tool_response_end, msg_idx, is_sampled=False, is_content=False + ) if not next_is_tool: - emit_special(self._im_end, msg_idx, is_sampled=False) - emit_text("\n", msg_idx, is_sampled=False) + emit_special(self._im_end, msg_idx, is_sampled=False, is_content=False) + emit_text("\n", msg_idx, is_sampled=False, is_content=False) diff --git a/renderers/qwen35.py b/renderers/qwen35.py index e2c84c6..2deefcf 100644 --- a/renderers/qwen35.py +++ b/renderers/qwen35.py @@ -26,6 +26,7 @@ PlaceholderRange, RenderedTokens, ToolSpec, + attribute_text_segments, reject_assistant_in_extension, should_preserve_past_thinking, trim_to_turn_close, @@ -291,33 +292,73 @@ def render( tokens: list[int] = [] indices: list[int] = [] sampled: list[bool] = [] + content_mask: list[bool] = [] mm_hashes: dict[str, list[str]] = {} mm_placeholders: dict[str, list[PlaceholderRange]] = {} mm_items: dict[str, list[dict[str, Any]]] = {} - def emit_ids(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None: + def emit_ids( + ids: list[int], msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) - def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None: + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) + content_mask.append(is_content) + + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: + emit_ids( + self._encode(text), + msg_idx, + is_sampled=is_sampled, + is_content=is_content, + ) - def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None: - emit_ids(self._encode(text), msg_idx, is_sampled=is_sampled) + def emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool + ) -> None: + """Tokenize concatenated segments as one BPE pass; per-token + ``is_content`` follows each token's source segment. + + Lets call sites express "this wrap + this body, joined the + same way as the chat template, but attributed separately" + without splitting the encode call (which could shift BPE + merges at the boundary).""" + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + tokens.append(tok_id) + indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(is_content) def emit_image(part: dict[str, Any], msg_idx: int) -> None: # Image placeholders only appear in user / tool messages; the # model never samples them. Pin is_sampled=False here so - # callers don't need to thread the flag through. + # callers don't need to thread the flag through. The + # ``<|image_pad|>`` placeholders represent caller-provided + # image data, so they ARE body content (is_content=True); + # the surrounding ``<|vision_start|>`` / ``<|vision_end|>`` + # specials are template scaffold. _, out, n, h = self._process_image(part) - emit_special(self._vision_start, msg_idx, is_sampled=False) + emit_special( + self._vision_start, msg_idx, is_sampled=False, is_content=False + ) offset = len(tokens) for _ in range(n): - emit_special(self._image_pad, msg_idx, is_sampled=False) - emit_special(self._vision_end, msg_idx, is_sampled=False) + emit_special( + self._image_pad, msg_idx, is_sampled=False, is_content=True + ) + emit_special(self._vision_end, msg_idx, is_sampled=False, is_content=False) mm_hashes.setdefault("image", []).append(h) mm_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) @@ -337,20 +378,28 @@ def emit_user_with_media(content_list: list[Any], msg_idx: int) -> None: ``<|im_end|>``), matching how Jinja's ``render_content`` macro concatenates strings before tokenization. This preserves BPE byte-parity against ``apply_chat_template``. + + Within each flush the ``"user\\n"`` wrap is scaffold and the + text parts are body. ``emit_text_segments`` carries that + attribution through the BPE pass. """ # Every token in a user message is conversation history that # the model never samples at inference. - emit_special(self._im_start, msg_idx, is_sampled=False) - buf: list[str] = ["user\n"] + emit_special(self._im_start, msg_idx, is_sampled=False, is_content=False) + # First flush includes the ``"user\n"`` wrap as a scaffold + # segment; subsequent flushes are pure body (after a media + # break). + buf_segments: list[tuple[str, bool]] = [("user\n", False)] def flush_buf() -> None: - if buf: - emit_text("".join(buf), msg_idx, is_sampled=False) - buf.clear() + if buf_segments: + emit_text_segments(buf_segments, msg_idx, is_sampled=False) + buf_segments.clear() for item in content_list: if isinstance(item, str): - buf.append(item) + if item: + buf_segments.append((item, True)) elif isinstance(item, dict): if _is_image_part(item): flush_buf() @@ -360,14 +409,15 @@ def flush_buf() -> None: "Video parts are not yet supported by Qwen35Renderer." ) elif "text" in item: - buf.append(item["text"]) + if item["text"]: + buf_segments.append((item["text"], True)) else: raise ValueError(f"Unexpected content item: {item}") else: raise ValueError(f"Unexpected content item: {item}") flush_buf() - emit_special(self._im_end, msg_idx, is_sampled=False) - emit_text("\n", msg_idx, is_sampled=False) + emit_special(self._im_end, msg_idx, is_sampled=False, is_content=False) + emit_text("\n", msg_idx, is_sampled=False, is_content=False) # ── 1. System message + optional tools ────────────────────── first_is_system = messages[0].get("role") == "system" @@ -376,31 +426,37 @@ def flush_buf() -> None: # System message index for attribution sys_idx = 0 if first_is_system else -1 - emit_special(self._im_start, sys_idx, is_sampled=False) - emit_text("system\n", sys_idx, is_sampled=False) - - # Tools header + JSON definitions - tool_text = _TOOLS_HEADER + emit_special(self._im_start, sys_idx, is_sampled=False, is_content=False) + # Body = system content (if any). Everything else in this + # block — role tag, tools header / footer / instructions, the + # JSON tool specs — is scaffold. The tools dict is + # recoverable from the ``tools`` argument; don't re-attribute + # its embedded JSON as message body. + segments: list[tuple[str, bool]] = [ + ("system\n", False), + (_TOOLS_HEADER, False), + ] for tool in tools: - tool_text += "\n" + json.dumps(tool, ensure_ascii=False) - tool_text += _TOOLS_FOOTER - tool_text += _TOOLS_INSTRUCTIONS - - # Append user's system content if present + segments.append(("\n" + json.dumps(tool, ensure_ascii=False), False)) + segments.append((_TOOLS_FOOTER, False)) + segments.append((_TOOLS_INSTRUCTIONS, False)) if first_is_system: sys_content = self._render_content(messages[0].get("content")).strip() if sys_content: - tool_text += "\n\n" + sys_content - - emit_text(tool_text, sys_idx, is_sampled=False) - emit_special(self._im_end, sys_idx, is_sampled=False) - emit_text("\n", sys_idx, is_sampled=False) + segments.append(("\n\n", False)) + segments.append((sys_content, True)) + emit_text_segments(segments, sys_idx, is_sampled=False) + emit_special(self._im_end, sys_idx, is_sampled=False, is_content=False) + emit_text("\n", sys_idx, is_sampled=False, is_content=False) elif first_is_system: sys_content = self._render_content(messages[0].get("content")).strip() - emit_special(self._im_start, 0, is_sampled=False) - emit_text("system\n" + sys_content, 0, is_sampled=False) - emit_special(self._im_end, 0, is_sampled=False) - emit_text("\n", 0, is_sampled=False) + emit_special(self._im_start, 0, is_sampled=False, is_content=False) + sys_segments: list[tuple[str, bool]] = [("system\n", False)] + if sys_content: + sys_segments.append((sys_content, True)) + emit_text_segments(sys_segments, 0, is_sampled=False) + emit_special(self._im_end, 0, is_sampled=False, is_content=False) + emit_text("\n", 0, is_sampled=False, is_content=False) # ── 2. Compute last_query_index ───────────────────────────── last_qi = self._last_query_index(messages) @@ -420,10 +476,13 @@ def flush_buf() -> None: if self._content_has_media(raw_content): emit_user_with_media(raw_content, i) else: - emit_special(self._im_start, i, is_sampled=False) - emit_text("user\n" + content, i, is_sampled=False) - emit_special(self._im_end, i, is_sampled=False) - emit_text("\n", i, is_sampled=False) + emit_special(self._im_start, i, is_sampled=False, is_content=False) + user_segments: list[tuple[str, bool]] = [("user\n", False)] + if content: + user_segments.append((content, True)) + emit_text_segments(user_segments, i, is_sampled=False) + emit_special(self._im_end, i, is_sampled=False, is_content=False) + emit_text("\n", i, is_sampled=False, is_content=False) elif role == "assistant": preserve_thinking = should_preserve_past_thinking( @@ -441,6 +500,7 @@ def flush_buf() -> None: emit_special=emit_special, emit_text=emit_text, emit_ids=emit_ids, + emit_text_segments=emit_text_segments, ) elif role == "tool": @@ -450,6 +510,7 @@ def flush_buf() -> None: emit_special=emit_special, emit_text=emit_text, emit_image=emit_image, + emit_text_segments=emit_text_segments, ) else: @@ -457,16 +518,16 @@ def flush_buf() -> None: # ── 4. Generation prompt ──────────────────────────────────── if add_generation_prompt: - emit_special(self._im_start, -1, is_sampled=False) - emit_text("assistant\n", -1, is_sampled=False) + emit_special(self._im_start, -1, is_sampled=False, is_content=False) + emit_text("assistant\n", -1, is_sampled=False, is_content=False) if self._enable_thinking: - emit_special(self._think, -1, is_sampled=False) - emit_text("\n", -1, is_sampled=False) + emit_special(self._think, -1, is_sampled=False, is_content=False) + emit_text("\n", -1, is_sampled=False, is_content=False) else: - emit_special(self._think, -1, is_sampled=False) - emit_text("\n\n", -1, is_sampled=False) - emit_special(self._think_end, -1, is_sampled=False) - emit_text("\n\n", -1, is_sampled=False) + emit_special(self._think, -1, is_sampled=False, is_content=False) + emit_text("\n\n", -1, is_sampled=False, is_content=False) + emit_special(self._think_end, -1, is_sampled=False, is_content=False) + emit_text("\n\n", -1, is_sampled=False, is_content=False) mm_data: MultiModalData | None = None if mm_hashes or mm_placeholders or mm_items: @@ -480,6 +541,7 @@ def flush_buf() -> None: token_ids=tokens, message_indices=indices, sampled_mask=sampled, + is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], multi_modal_data=mm_data, ) @@ -549,34 +611,63 @@ def bridge_to_next_turn( # ``previous_ids``. Bridge-added tokens get proper ``msg_idx`` # (relative to ``new_messages``) and uniformly ``False`` # ``sampled``: nothing the bridge emits was model-sampled. + # ``is_content`` follows the same rules as in :meth:`render` so + # consumers can walk the trajectory and read each step's own + # body mask; the prior portion is uniformly False since we have + # no attribution info for it. tokens: list[int] = list(previous_ids) indices: list[int] = [-1] * len(previous_ids) sampled: list[bool] = [False] * len(previous_ids) + content_mask: list[bool] = [False] * len(previous_ids) new_hashes: dict[str, list[str]] = {} new_placeholders: dict[str, list[PlaceholderRange]] = {} new_items: dict[str, list[dict[str, Any]]] = {} def emit_special( - token_id: int, msg_idx: int = -1, *, is_sampled: bool = False + token_id: int, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: tokens.append(token_id) indices.append(msg_idx) sampled.append(is_sampled) + content_mask.append(is_content) def emit_text( - text: str, msg_idx: int = -1, *, is_sampled: bool = False + text: str, + msg_idx: int = -1, + *, + is_sampled: bool = False, + is_content: bool = False, ) -> None: ids = self._encode(text) tokens.extend(ids) indices.extend([msg_idx] * len(ids)) sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], + msg_idx: int = -1, + *, + is_sampled: bool = False, + ) -> None: + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + tokens.append(tok_id) + indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(is_content) def emit_image(part: dict[str, Any], msg_idx: int = -1) -> None: _, out, n, h = self._process_image(part) emit_special(self._vision_start, msg_idx) offset = len(tokens) for _ in range(n): - emit_special(self._image_pad, msg_idx) + emit_special(self._image_pad, msg_idx, is_content=True) emit_special(self._vision_end, msg_idx) new_hashes.setdefault("image", []).append(h) new_placeholders.setdefault("image", []).append( @@ -591,16 +682,17 @@ def emit_image(part: dict[str, Any], msg_idx: int = -1) -> None: def emit_user_with_media(content_list: list[Any], msg_idx: int) -> None: emit_special(self._im_start, msg_idx) - buf: list[str] = ["user\n"] + buf_segments: list[tuple[str, bool]] = [("user\n", False)] def flush_buf() -> None: - if buf: - emit_text("".join(buf), msg_idx) - buf.clear() + if buf_segments: + emit_text_segments(buf_segments, msg_idx) + buf_segments.clear() for item in content_list: if isinstance(item, str): - buf.append(item) + if item: + buf_segments.append((item, True)) elif isinstance(item, dict): if _is_image_part(item): flush_buf() @@ -610,7 +702,8 @@ def flush_buf() -> None: "Video parts are not yet supported by Qwen35Renderer." ) elif "text" in item: - buf.append(item["text"]) + if item["text"]: + buf_segments.append((item["text"], True)) else: raise ValueError(f"Unexpected content item: {item}") else: @@ -633,12 +726,18 @@ def flush_buf() -> None: emit_user_with_media(raw_content, i) else: emit_special(self._im_start, i) - emit_text("user\n" + content, i) + user_segments: list[tuple[str, bool]] = [("user\n", False)] + if content: + user_segments.append((content, True)) + emit_text_segments(user_segments, i) emit_special(self._im_end, i) emit_text("\n", i) elif role == "system": emit_special(self._im_start, i) - emit_text("system\n" + content, i) + sys_segments: list[tuple[str, bool]] = [("system\n", False)] + if content: + sys_segments.append((content, True)) + emit_text_segments(sys_segments, i) emit_special(self._im_end, i) emit_text("\n", i) elif role == "tool": @@ -648,6 +747,7 @@ def flush_buf() -> None: emit_special=emit_special, emit_text=emit_text, emit_image=emit_image, + emit_text_segments=emit_text_segments, ) else: return None @@ -693,6 +793,7 @@ def flush_buf() -> None: token_ids=tokens, message_indices=indices, sampled_mask=sampled, + is_content=content_mask, message_roles=bridge_roles, ) @@ -705,6 +806,7 @@ def flush_buf() -> None: token_ids=tokens, message_indices=indices, sampled_mask=sampled, + is_content=content_mask, message_roles=bridge_roles, multi_modal_data=mm_data, ) @@ -747,6 +849,7 @@ def _render_assistant( emit_special, emit_text, emit_ids, + emit_text_segments, ) -> None: # Extract reasoning_content reasoning_content = "" @@ -769,26 +872,36 @@ def _render_assistant( # at inference the chat template emits these as the generation # prompt and the model never samples them. Marking the role tag # as ``is_sampled=False`` keeps the SFT loss mask aligned with - # what the model would actually have produced. The split between - # ``assistant`` and ``\n`` is a safe BPE boundary in the Qwen - # tokenizer (``\n`` after the role is its own token). - emit_special(self._im_start, msg_idx, is_sampled=False) - emit_text("assistant\n", msg_idx, is_sampled=False) + # what the model would actually have produced. ``is_content`` is + # also False here — the role tag isn't part of any message's + # body, on any role. + emit_special(self._im_start, msg_idx, is_sampled=False, is_content=False) + emit_text("assistant\n", msg_idx, is_sampled=False, is_content=False) # Build the model-sampled portion (think block + content + tool # calls). Text segments stay contiguous within each is_sampled - # span to preserve BPE merges. + # span to preserve BPE merges. For assistant messages the + # invariant ``is_content == sampled_mask`` holds — every sampled + # token is body, every scaffold token isn't. The XML-style tool + # call tags (````, ````, etc.) are + # part of the model's emitted output too — keep them + # ``is_content=True`` per the assistant rule. emit_thinking = self._should_render_thinking(msg_idx, last_query_index) or ( preserve_thinking and bool(reasoning_content) ) if emit_thinking: # Include thinking block - emit_special(self._think, msg_idx, is_sampled=True) - emit_text("\n" + reasoning_content + "\n", msg_idx, is_sampled=True) - emit_special(self._think_end, msg_idx, is_sampled=True) - emit_text("\n\n" + content, msg_idx, is_sampled=True) + emit_special(self._think, msg_idx, is_sampled=True, is_content=True) + emit_text( + "\n" + reasoning_content + "\n", + msg_idx, + is_sampled=True, + is_content=True, + ) + emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) + emit_text("\n\n" + content, msg_idx, is_sampled=True, is_content=True) else: - emit_text(content, msg_idx, is_sampled=True) + emit_text(content, msg_idx, is_sampled=True, is_content=True) # Tool calls tool_calls = msg.get("tool_calls") or [] @@ -801,13 +914,18 @@ def _render_assistant( # Separator before if tc_idx == 0: if content.strip(): - emit_text("\n\n", msg_idx, is_sampled=True) + emit_text("\n\n", msg_idx, is_sampled=True, is_content=True) # else: no separator else: - emit_text("\n", msg_idx, is_sampled=True) - - emit_special(self._tool_call, msg_idx, is_sampled=True) - emit_text("\n\n", msg_idx, is_sampled=True) + emit_text("\n", msg_idx, is_sampled=True, is_content=True) + + emit_special(self._tool_call, msg_idx, is_sampled=True, is_content=True) + emit_text( + "\n\n", + msg_idx, + is_sampled=True, + is_content=True, + ) # Render arguments # OpenAI canonical form: arguments is a JSON string. Parse it so the @@ -828,16 +946,20 @@ def _render_assistant( + "\n\n", msg_idx, is_sampled=True, + is_content=True, ) - emit_text("\n", msg_idx, is_sampled=True) - emit_special(self._tool_call_end, msg_idx, is_sampled=True) + emit_text("\n", msg_idx, is_sampled=True, is_content=True) + emit_special( + self._tool_call_end, msg_idx, is_sampled=True, is_content=True + ) # ``<|im_end|>`` is the model's stop signal — it samples this to - # end its turn, so it is part of the sampled stream. The trailing - # ``\n`` is template-appended between turns and never sampled. - emit_special(self._im_end, msg_idx, is_sampled=True) - emit_text("\n", msg_idx, is_sampled=False) + # end its turn, so it is part of the sampled stream (and the + # assistant's body). The trailing ``\n`` is template-appended + # between turns and never sampled — scaffold for is_content too. + emit_special(self._im_end, msg_idx, is_sampled=True, is_content=True) + emit_text("\n", msg_idx, is_sampled=False, is_content=False) # ------------------------------------------------------------------ # Tool message rendering @@ -851,6 +973,7 @@ def _render_tool( emit_special, emit_text, emit_image, + emit_text_segments, ) -> None: # Consecutive tool messages share a single <|im_start|>user ... <|im_end|> # envelope. Whether to open and close the envelope depends only on the @@ -858,9 +981,11 @@ def _render_tool( # tool message — keep this predicate text/media-agnostic. # Tool messages are conversation history injected by the runtime # between assistant turns — the model never samples any of these - # tokens, so every emission is is_sampled=False. The bridge's - # local emit_special / emit_text accept the is_sampled kwarg for - # signature compatibility but ignore it. + # tokens, so every emission is is_sampled=False. The ``content`` + # field's body bytes get ``is_content=True``; everything else — + # the ``<|im_start|>user`` wrap, the inter-section ``\n``s, the + # ``<|tool_response>`` specials — is scaffold so the SFT mask + # for tool body never trains the model to emit them. prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool" next_is_tool = ( msg_idx + 1 < len(messages) and messages[msg_idx + 1]["role"] == "tool" @@ -868,11 +993,11 @@ def _render_tool( raw_content = messages[msg_idx].get("content") if not prev_is_tool: - emit_special(self._im_start, msg_idx, is_sampled=False) - emit_text("user", msg_idx, is_sampled=False) + emit_special(self._im_start, msg_idx, is_sampled=False, is_content=False) + emit_text("user", msg_idx, is_sampled=False, is_content=False) - emit_text("\n", msg_idx, is_sampled=False) - emit_special(self._tool_response, msg_idx, is_sampled=False) + emit_text("\n", msg_idx, is_sampled=False, is_content=False) + emit_special(self._tool_response, msg_idx, is_sampled=False, is_content=False) if self._content_has_media(raw_content): # Mirror the chat template's ``render_content`` macro for list @@ -882,16 +1007,22 @@ def _render_tool( # ``_content_has_media`` returns False unless content is a list, # but the type checker can't follow that through the call. assert isinstance(raw_content, list) - buf: list[str] = ["\n"] + # First flush: leading ``"\n"`` is scaffold (separates the + # ``<|tool_response>`` special from the body); subsequent + # text items in this run are body. After a media break, the + # buffer resets to pure body until the next media break or + # end-of-content. + buf_segments: list[tuple[str, bool]] = [("\n", False)] def flush_buf() -> None: - if buf: - emit_text("".join(buf), msg_idx, is_sampled=False) - buf.clear() + if buf_segments: + emit_text_segments(buf_segments, msg_idx, is_sampled=False) + buf_segments.clear() for item in raw_content: if isinstance(item, str): - buf.append(item) + if item: + buf_segments.append((item, True)) elif isinstance(item, dict): if _is_image_part(item): flush_buf() @@ -901,19 +1032,29 @@ def flush_buf() -> None: "Video parts are not yet supported by Qwen35Renderer." ) elif "text" in item: - buf.append(item["text"]) + if item["text"]: + buf_segments.append((item["text"], True)) else: raise ValueError(f"Unexpected content item: {item}") else: raise ValueError(f"Unexpected content item: {item}") flush_buf() - emit_text("\n", msg_idx, is_sampled=False) + emit_text("\n", msg_idx, is_sampled=False, is_content=False) else: content = self._render_content(raw_content).strip() - emit_text("\n" + content + "\n", msg_idx, is_sampled=False) + # ``\n`` + content + ``\n`` — body is the middle segment only. + # Single BPE pass over the joined text preserves boundary + # merges. + emit_text_segments( + [("\n", False), (content, True), ("\n", False)], + msg_idx, + is_sampled=False, + ) - emit_special(self._tool_response_end, msg_idx, is_sampled=False) + emit_special( + self._tool_response_end, msg_idx, is_sampled=False, is_content=False + ) if not next_is_tool: - emit_special(self._im_end, msg_idx, is_sampled=False) - emit_text("\n", msg_idx, is_sampled=False) + emit_special(self._im_end, msg_idx, is_sampled=False, is_content=False) + emit_text("\n", msg_idx, is_sampled=False, is_content=False) diff --git a/renderers/qwen3_vl.py b/renderers/qwen3_vl.py index 311ce26..94ae13d 100644 --- a/renderers/qwen3_vl.py +++ b/renderers/qwen3_vl.py @@ -42,6 +42,7 @@ PlaceholderRange, RenderedTokens, ToolSpec, + attribute_text_segments, reject_assistant_in_extension, trim_to_turn_close, ) @@ -177,14 +178,31 @@ class _Emitter: value than the current buffer triggers a flush first — split points are always at the ``is_sampled`` boundary, which the caller is expected to place at a ``\\n`` boundary so BPE merges don't drift. + + ``is_content`` is the per-token body/scaffold attribution. Within a + single flush adjacent text fragments may carry different + ``is_content`` values (e.g. ``"user\\n"`` scaffold + caller content + body): the buffer stores fragments as a list of + ``(text, is_content)`` segments and flushes via + :func:`attribute_text_segments`, which performs one BPE pass over + the joined text and assigns per-token is_content from each token's + source segment. When every segment in a flush shares the same + is_content (the common case for sampled assistant body / pure + scaffold) the fast path of a single ``encode()`` call is used and + no offset-tokenizer lookup is required. """ - def __init__(self, encode_fn, msg_idx: int = -1): + def __init__(self, encode_fn, tokenizer=None, msg_idx: int = -1): self._encode = encode_fn + self._tokenizer = tokenizer self.token_ids: list[int] = [] self.message_indices: list[int] = [] self.sampled: list[bool] = [] - self._buf: str = "" + self.is_content: list[bool] = [] + # Buffered text fragments as ``(text, is_content)`` tuples. All + # fragments share a single ``_buf_sampled`` / ``_buf_idx``; + # changing either of those triggers a flush. + self._segments: list[tuple[str, bool]] = [] self._buf_idx: int = msg_idx self._buf_sampled: bool = False self.msg_idx = msg_idx @@ -194,49 +212,78 @@ def set_msg_idx(self, msg_idx: int) -> None: # text doesn't get glued to the previous one's BPE context. # In practice messages are always separated by an <|im_end|> # special token, which already flushes — but be defensive. - if self._buf: + if self._segments: self._flush() self.msg_idx = msg_idx self._buf_idx = msg_idx - def text(self, text: str, *, is_sampled: bool) -> None: + def text(self, text: str, *, is_sampled: bool, is_content: bool) -> None: if not text: return # Adjacent text under different msg_idx or is_sampled is rare in # this template — but flush at those boundaries so attribution - # and the sampled signal stay accurate. - if self._buf and ( + # and the sampled signal stay accurate. is_content boundaries do + # NOT force a flush: they're carried through the joined BPE pass + # via :func:`attribute_text_segments`, preserving merges across + # the wrap/body boundary. + if self._segments and ( self._buf_idx != self.msg_idx or self._buf_sampled != is_sampled ): self._flush() - if not self._buf: + if not self._segments: self._buf_idx = self.msg_idx self._buf_sampled = is_sampled - self._buf += text + self._segments.append((text, is_content)) - def special(self, token_id: int, *, is_sampled: bool) -> None: - if self._buf: + def special(self, token_id: int, *, is_sampled: bool, is_content: bool) -> None: + if self._segments: self._flush() self.token_ids.append(token_id) self.message_indices.append(self.msg_idx) self.sampled.append(is_sampled) + self.is_content.append(is_content) def cursor(self) -> int: """Current token offset after flushing — used to anchor placeholder ranges.""" - if self._buf: + if self._segments: self._flush() return len(self.token_ids) def finalize(self) -> None: - if self._buf: + if self._segments: self._flush() def _flush(self) -> None: - ids = self._encode(self._buf) - self.token_ids.extend(ids) - self.message_indices.extend([self._buf_idx] * len(ids)) - self.sampled.extend([self._buf_sampled] * len(ids)) - self._buf = "" + segments = self._segments + self._segments = [] + if not segments: + return + # Fast path: every segment shares the same is_content — use the + # plain ``encode()`` call so we don't pay for the offset + # tokenizer. This is the common case (pure scaffold flushes, or + # pure body flushes). + first_ic = segments[0][1] + all_same = all(ic == first_ic for _, ic in segments) + if all_same: + joined = "".join(text for text, _ in segments) + ids = self._encode(joined) + self.token_ids.extend(ids) + self.message_indices.extend([self._buf_idx] * len(ids)) + self.sampled.extend([self._buf_sampled] * len(ids)) + self.is_content.extend([first_ic] * len(ids)) + return + # Mixed body/scaffold flush — encode once and attribute back to + # each segment via the fast tokenizer's offset_mapping. Requires + # a tokenizer (not just the encode fn) to look up offsets. + assert self._tokenizer is not None, ( + "_Emitter mixed-is_content flush requires a tokenizer; " + "pass one to the constructor." + ) + for tok_id, is_content in attribute_text_segments(self._tokenizer, segments): + self.token_ids.append(tok_id) + self.message_indices.append(self._buf_idx) + self.sampled.append(self._buf_sampled) + self.is_content.append(is_content) class Qwen3VLRenderer: @@ -401,7 +448,7 @@ def render( if not messages: raise ValueError("No messages provided.") - em = _Emitter(self._encode) + em = _Emitter(self._encode, tokenizer=self._tokenizer) mm_hashes: dict[str, list[str]] = {} mm_placeholders: dict[str, list[PlaceholderRange]] = {} mm_items: dict[str, list[dict[str, Any]]] = {} @@ -409,13 +456,17 @@ def render( def emit_image(part: dict[str, Any]) -> None: # Image placeholders are prompt-side scaffolding the user # message attaches — the model never samples ``<|vision_start|>`` - # / ``<|image_pad|>`` / ``<|vision_end|>``. + # / ``<|image_pad|>`` / ``<|vision_end|>``. The + # ``<|image_pad|>`` placeholders represent caller-provided + # image data, so they ARE body content (is_content=True); + # the surrounding ``<|vision_start|>`` / ``<|vision_end|>`` + # markers are renderer-emitted scaffold. _, out, n, h = self._process_image(part) - em.special(self._vision_start, is_sampled=False) + em.special(self._vision_start, is_sampled=False, is_content=False) offset = em.cursor() for _ in range(n): - em.special(self._image_pad, is_sampled=False) - em.special(self._vision_end, is_sampled=False) + em.special(self._image_pad, is_sampled=False, is_content=True) + em.special(self._vision_end, is_sampled=False, is_content=False) mm_hashes.setdefault("image", []).append(h) mm_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) @@ -432,16 +483,24 @@ def render_media_content(content: Any) -> None: User / tool content is conversation context the model never samples — every text fragment goes in as ``is_sampled=False``. + Text from the caller IS the message body, so every text + fragment is ``is_content=True``; the vision-marker specials + around image_pad placeholders are scaffold (handled in + :func:`emit_image`). """ if isinstance(content, str): - em.text(content, is_sampled=False) + em.text(content, is_sampled=False, is_content=True) return if not isinstance(content, list): - em.text(self._render_text_content(content), is_sampled=False) + em.text( + self._render_text_content(content), + is_sampled=False, + is_content=True, + ) return for item in content: if isinstance(item, str): - em.text(item, is_sampled=False) + em.text(item, is_sampled=False, is_content=True) elif isinstance(item, dict): if _is_image_part(item): emit_image(item) @@ -450,7 +509,7 @@ def render_media_content(content: Any) -> None: "Video parts are not yet supported by Qwen3VLRenderer." ) elif "text" in item: - em.text(item["text"], is_sampled=False) + em.text(item["text"], is_sampled=False, is_content=True) # ── 1. System + tools ─────────────────────────────────────── first_is_system = messages[0].get("role") == "system" @@ -458,26 +517,37 @@ def render_media_content(content: Any) -> None: if tools: sys_idx = 0 if first_is_system else -1 em.set_msg_idx(sys_idx) - em.special(self._im_start, is_sampled=False) - buf = "system\n" + em.special(self._im_start, is_sampled=False, is_content=False) + # Body = system content (if any). Everything else in this + # block — role tag, tools header / footer, JSON tool specs — + # is scaffold. The tools dict is recoverable from the + # ``tools`` argument; we don't re-attribute its embedded + # JSON as message body. + em.text("system\n", is_sampled=False, is_content=False) if first_is_system: - buf += self._render_text_content(messages[0].get("content")) + "\n\n" - buf += _TOOLS_HEADER + sys_content = self._render_text_content(messages[0].get("content")) + if sys_content: + em.text(sys_content, is_sampled=False, is_content=True) + em.text("\n\n", is_sampled=False, is_content=False) + em.text(_TOOLS_HEADER, is_sampled=False, is_content=False) for tool in tools: - buf += "\n" + json.dumps(tool, ensure_ascii=False) - buf += _TOOLS_FOOTER - em.text(buf, is_sampled=False) - em.special(self._im_end, is_sampled=False) - em.text("\n", is_sampled=False) + em.text( + "\n" + json.dumps(tool, ensure_ascii=False), + is_sampled=False, + is_content=False, + ) + em.text(_TOOLS_FOOTER, is_sampled=False, is_content=False) + em.special(self._im_end, is_sampled=False, is_content=False) + em.text("\n", is_sampled=False, is_content=False) elif first_is_system: em.set_msg_idx(0) - em.special(self._im_start, is_sampled=False) - em.text( - "system\n" + self._render_text_content(messages[0].get("content")), - is_sampled=False, - ) - em.special(self._im_end, is_sampled=False) - em.text("\n", is_sampled=False) + em.special(self._im_start, is_sampled=False, is_content=False) + em.text("system\n", is_sampled=False, is_content=False) + sys_content = self._render_text_content(messages[0].get("content")) + if sys_content: + em.text(sys_content, is_sampled=False, is_content=True) + em.special(self._im_end, is_sampled=False, is_content=False) + em.text("\n", is_sampled=False, is_content=False) # ── 2. Iterate messages ───────────────────────────────────── for i, msg in enumerate(messages): @@ -489,11 +559,11 @@ def render_media_content(content: Any) -> None: em.set_msg_idx(i) if role == "user": - em.special(self._im_start, is_sampled=False) - em.text("user\n", is_sampled=False) + em.special(self._im_start, is_sampled=False, is_content=False) + em.text("user\n", is_sampled=False, is_content=False) render_media_content(msg.get("content")) - em.special(self._im_end, is_sampled=False) - em.text("\n", is_sampled=False) + em.special(self._im_end, is_sampled=False, is_content=False) + em.text("\n", is_sampled=False, is_content=False) elif role == "assistant": self._render_assistant(msg, em) @@ -507,8 +577,8 @@ def render_media_content(content: Any) -> None: # ── 3. Generation prompt ──────────────────────────────────── if add_generation_prompt: em.set_msg_idx(-1) - em.special(self._im_start, is_sampled=False) - em.text("assistant\n", is_sampled=False) + em.special(self._im_start, is_sampled=False, is_content=False) + em.text("assistant\n", is_sampled=False, is_content=False) em.finalize() @@ -524,6 +594,7 @@ def render_media_content(content: Any) -> None: token_ids=em.token_ids, message_indices=em.message_indices, sampled_mask=em.sampled, + is_content=em.is_content, message_roles=[m.get("role") or "" for m in messages], multi_modal_data=mm_data, ) @@ -595,18 +666,21 @@ def bridge_to_next_turn( # Bridge populates ``message_indices`` (relative to ``new_messages``) # and ``sampled_mask`` (uniformly ``False`` — every token the # bridge emits is template scaffolding for the next prompt, not - # something the model sampled). Downstream consumers can run - # :meth:`RenderedTokens.tokens_per_message` on the bridge output - # to get per-new-message token counts without re-rendering. - em = _Emitter(self._encode) + # something the model sampled). ``is_content`` follows the same + # rules as in :meth:`render` so consumers can walk the trajectory + # and read each step's own body mask. Downstream consumers can + # run :meth:`RenderedTokens.tokens_per_message` on the bridge + # output to get per-new-message token counts without re-rendering. + em = _Emitter(self._encode, tokenizer=self._tokenizer) # Seed the emitter with the prior turn's tokens so cursor() reports # absolute offsets in the combined sequence. Per-token attribution # for the prior portion is unknown to the bridge (it only has # prev_prompt_ids + prev_completion_ids as raw lists), so seed - # both side channels with the "no info" sentinel. + # all side channels with the "no info" sentinel. em.token_ids = list(previous_ids) em.message_indices = [-1] * len(previous_ids) em.sampled = [False] * len(previous_ids) + em.is_content = [False] * len(previous_ids) new_hashes: dict[str, list[str]] = {} new_placeholders: dict[str, list[PlaceholderRange]] = {} @@ -614,11 +688,11 @@ def bridge_to_next_turn( def emit_image(part: dict[str, Any]) -> None: _, out, n, h = self._process_image(part) - em.special(self._vision_start, is_sampled=False) + em.special(self._vision_start, is_sampled=False, is_content=False) offset = em.cursor() for _ in range(n): - em.special(self._image_pad, is_sampled=False) - em.special(self._vision_end, is_sampled=False) + em.special(self._image_pad, is_sampled=False, is_content=True) + em.special(self._vision_end, is_sampled=False, is_content=False) new_hashes.setdefault("image", []).append(h) new_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) @@ -632,14 +706,18 @@ def emit_image(part: dict[str, Any]) -> None: def render_media_content(content: Any) -> None: if isinstance(content, str): - em.text(content, is_sampled=False) + em.text(content, is_sampled=False, is_content=True) return if not isinstance(content, list): - em.text(self._render_text_content(content), is_sampled=False) + em.text( + self._render_text_content(content), + is_sampled=False, + is_content=True, + ) return for item in content: if isinstance(item, str): - em.text(item, is_sampled=False) + em.text(item, is_sampled=False, is_content=True) elif isinstance(item, dict): if _is_image_part(item): emit_image(item) @@ -648,34 +726,34 @@ def render_media_content(content: Any) -> None: "Video parts are not yet supported by Qwen3VLRenderer." ) elif "text" in item: - em.text(item["text"], is_sampled=False) + em.text(item["text"], is_sampled=False, is_content=True) em.set_msg_idx(-1) - em.text("\n", is_sampled=False) + em.text("\n", is_sampled=False, is_content=False) for i, msg in enumerate(new_messages): role = msg.get("role") em.set_msg_idx(i) if role == "user": - em.special(self._im_start, is_sampled=False) - em.text("user\n", is_sampled=False) + em.special(self._im_start, is_sampled=False, is_content=False) + em.text("user\n", is_sampled=False, is_content=False) render_media_content(msg.get("content")) - em.special(self._im_end, is_sampled=False) - em.text("\n", is_sampled=False) + em.special(self._im_end, is_sampled=False, is_content=False) + em.text("\n", is_sampled=False, is_content=False) elif role == "system": - em.special(self._im_start, is_sampled=False) - em.text("system\n", is_sampled=False) + em.special(self._im_start, is_sampled=False, is_content=False) + em.text("system\n", is_sampled=False, is_content=False) render_media_content(msg.get("content")) - em.special(self._im_end, is_sampled=False) - em.text("\n", is_sampled=False) + em.special(self._im_end, is_sampled=False, is_content=False) + em.text("\n", is_sampled=False, is_content=False) elif role == "tool": self._render_tool(new_messages, i, em, render_media_content) else: return None em.set_msg_idx(-1) - em.special(self._im_start, is_sampled=False) - em.text("assistant\n", is_sampled=False) + em.special(self._im_start, is_sampled=False, is_content=False) + em.text("assistant\n", is_sampled=False, is_content=False) em.finalize() # Merge prev mm_data with the new turn's items. @@ -713,6 +791,7 @@ def render_media_content(content: Any) -> None: token_ids=em.token_ids, message_indices=em.message_indices, sampled_mask=em.sampled, + is_content=em.is_content, message_roles=[m.get("role") or "" for m in new_messages], multi_modal_data=mm_data, ) @@ -726,20 +805,22 @@ def _render_assistant(self, msg: Message, em: _Emitter) -> None: # at inference the chat template emits these as the generation # prompt and the model never samples them. Splitting the text # at the ``\n`` after the role tag is safe: Qwen3 BPE treats - # ``\n`` as a token boundary. - em.special(self._im_start, is_sampled=False) - em.text("assistant\n", is_sampled=False) + # ``\n`` as a token boundary. For the assistant role the + # invariant ``is_content == sampled_mask`` holds — every sampled + # token is body, every scaffold token isn't. + em.special(self._im_start, is_sampled=False, is_content=False) + em.text("assistant\n", is_sampled=False, is_content=False) # Body (content + tool calls) is the model-sampled portion. if not tool_calls: - em.text(content, is_sampled=True) + em.text(content, is_sampled=True, is_content=True) else: for tc_idx, tc in enumerate(tool_calls): if tc_idx == 0: separator = "\n" if original_content else "" - em.text(content + separator, is_sampled=True) + em.text(content + separator, is_sampled=True, is_content=True) else: - em.text("\n", is_sampled=True) + em.text("\n", is_sampled=True, is_content=True) func = tc.get("function") or tc name = func.get("name", "") @@ -750,18 +831,19 @@ def _render_assistant(self, msg: Message, em: _Emitter) -> None: else json.dumps(arguments, ensure_ascii=False) ) - em.special(self._tool_call, is_sampled=True) + em.special(self._tool_call, is_sampled=True, is_content=True) em.text( '\n{"name": "' + name + '", "arguments": ' + args_str + "}\n", is_sampled=True, + is_content=True, ) - em.special(self._tool_call_end, is_sampled=True) + em.special(self._tool_call_end, is_sampled=True, is_content=True) # ``<|im_end|>`` is the model's stop signal — it samples this to - # end its turn. The trailing ``\n`` is template-appended between - # turns and never sampled. - em.special(self._im_end, is_sampled=True) - em.text("\n", is_sampled=False) + # end its turn (and counts as part of its body). The trailing + # ``\n`` is template-appended between turns and never sampled. + em.special(self._im_end, is_sampled=True, is_content=True) + em.text("\n", is_sampled=False, is_content=False) def _render_tool( self, @@ -772,24 +854,27 @@ def _render_tool( ) -> None: # Tool messages are conversation history injected by the runtime # between assistant turns — the model never samples any of these - # tokens, so every emission is is_sampled=False. (render_media_content - # also stamps is_sampled=False on its emissions.) + # tokens, so every emission is is_sampled=False. The + # ``content`` body bytes get ``is_content=True`` (via + # ``render_media_content``); everything else — the + # ``<|im_start|>user`` wrap, inter-section ``\n``s, and the + # ``<|tool_response>`` specials — is scaffold. prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool" next_is_tool = ( msg_idx + 1 < len(messages) and messages[msg_idx + 1]["role"] == "tool" ) if not prev_is_tool: - em.special(self._im_start, is_sampled=False) - em.text("user", is_sampled=False) + em.special(self._im_start, is_sampled=False, is_content=False) + em.text("user", is_sampled=False, is_content=False) - em.text("\n", is_sampled=False) - em.special(self._tool_response, is_sampled=False) - em.text("\n", is_sampled=False) + em.text("\n", is_sampled=False, is_content=False) + em.special(self._tool_response, is_sampled=False, is_content=False) + em.text("\n", is_sampled=False, is_content=False) render_media_content(messages[msg_idx].get("content")) - em.text("\n", is_sampled=False) - em.special(self._tool_response_end, is_sampled=False) + em.text("\n", is_sampled=False, is_content=False) + em.special(self._tool_response_end, is_sampled=False, is_content=False) if not next_is_tool: - em.special(self._im_end, is_sampled=False) - em.text("\n", is_sampled=False) + em.special(self._im_end, is_sampled=False, is_content=False) + em.text("\n", is_sampled=False, is_content=False) diff --git a/tests/test_client.py b/tests/test_client.py index 850f220..0695d78 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -19,7 +19,15 @@ def render(self, messages, *, tools=None, add_generation_prompt=False): assert messages == [{"role": "user", "content": "hi"}] assert tools == [{"type": "function", "function": {"name": "echo"}}] assert add_generation_prompt is True - return RenderedTokens(token_ids=[1, 2, 3]) + # Populate the full attribution surface so the test can verify + # ``generate`` threads it through to the result dict unchanged. + return RenderedTokens( + token_ids=[1, 2, 3], + message_indices=[0, 0, -1], + sampled_mask=[False, False, False], + is_content=[False, True, False], + message_roles=["user"], + ) def render_ids(self, messages, *, tools=None, add_generation_prompt=False): return self.render( @@ -137,6 +145,17 @@ def test_generate_builds_request_body_and_parses_response(): assert result["routed_experts"] == [[[1]], [[2]]] assert result["multi_modal_data"] is None assert result["request_id"] == "gen-test" + # Per-token attribution from the renderer surfaces on the result so + # downstream consumers (verifiers RendererClient → prime-rl) can + # build selective loss masks without a second render pass. + attr = result["prompt_attribution"] + assert attr is not None + assert isinstance(attr, RenderedTokens) + assert attr.token_ids == [1, 2, 3] + assert attr.is_content == [False, True, False] + assert attr.sampled_mask == [False, False, False] + assert attr.message_indices == [0, 0, -1] + assert attr.message_roles == ["user"] assert len(result["tool_calls"]) == 1 tc = result["tool_calls"][0] assert tc.name == "echo" @@ -207,6 +226,44 @@ def test_generate_uses_prebuilt_prompt_ids_without_rendering(): assert client.calls[0]["body"]["token_ids"] == [11, 12, 13] assert result["prompt_ids"] == [11, 12, 13] + # Pre-built prompt without explicit attribution → ``None`` carried + # through. Consumers fall back to whatever attribution-free path + # they have (e.g. uniform completion mask). + assert result["prompt_attribution"] is None + + +def test_generate_threads_prompt_attribution_through_prebuilt_prompt_path(): + """When the caller passes both ``prompt_ids`` and ``prompt_attribution`` + (the multi-turn bridge path in verifiers), ``generate`` must thread + the attribution through to the result dict unchanged — no re-rendering, + no per-token reshuffling. Lets downstream consumers carry the + renderer's body/scaffold cut into the trajectory step without an + extra render pass.""" + client = _FakeClient() + # Caller-supplied attribution; mirrors what + # ``RendererClient._get_incremental_prompt_ids`` returns from the + # bridge_to_next_turn output. + supplied = RenderedTokens( + token_ids=[11, 12, 13], + message_indices=[-1, 0, 0], + sampled_mask=[False, False, False], + is_content=[False, True, True], + message_roles=["tool"], + ) + + result = asyncio.run( + generate( + client=client, + renderer=_NoRenderRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + prompt_ids=[11, 12, 13], + prompt_attribution=supplied, + ) + ) + + # Exact passthrough — same object, no copy / no transform. + assert result["prompt_attribution"] is supplied # --------------------------------------------------------------------------- diff --git a/tests/test_is_content.py b/tests/test_is_content.py new file mode 100644 index 0000000..ddac1f5 --- /dev/null +++ b/tests/test_is_content.py @@ -0,0 +1,389 @@ +"""Per-token ``RenderedTokens.is_content`` invariants. + +``is_content[k]`` answers "does ``token_ids[k]`` come from message body +bytes (caller-provided content / tool_calls / reasoning_content, or +the model's sampled emission for the assistant role) or is it template +scaffolding the renderer added around the body (role tags, special +tokens, separators, tool-response wraps, generation prompt)?" + +By design ``is_content`` is a superset of ``sampled_mask``: + +- Equal on every token attributed to an assistant message (sampled == + body for that role by construction). +- Carries new information on every other role: the model never samples + user / tool / system tokens, so ``sampled_mask`` is uniformly + ``False`` over those — ``is_content`` differentiates body from wrap. + +These tests parametrise across every renderer in +``conftest.RENDERER_MODELS`` and verify the contract for hand-coded +renderers. ``DefaultRenderer`` leaves ``is_content`` empty by design +(the Jinja template is opaque, so the renderer cannot know the +wrap/body split) and is exempt from the populated-length check. + +The body-bytes decode invariant uses ``in`` (substring) rather than +strict equality because some renderers normalise whitespace, strip +trailing newlines, etc. within the message body emit — but the +caller-provided content must always be recoverable as a substring of +the decoded body run. +""" + +from __future__ import annotations + + +def _is_populated(rendered) -> bool: + return len(rendered.is_content) == len(rendered.token_ids) and bool( + rendered.is_content + ) + + +def test_is_content_length_or_empty(model_name, renderer): + """``is_content`` is either empty (opt-out) or matches token_ids + length exactly. No partial fills.""" + msgs = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + ] + rendered = renderer.render(msgs) + n_tokens = len(rendered.token_ids) + n_mask = len(rendered.is_content) + assert n_mask == 0 or n_mask == n_tokens, ( + f"{model_name}: is_content length {n_mask} must be 0 or match " + f"token_ids length {n_tokens}" + ) + + +def test_is_content_equals_sampled_on_assistant(model_name, renderer): + """On every token attributed to an assistant message, + ``is_content[k] == sampled_mask[k]``. The two signals collapse on + that role by design — the model's sampled output IS the assistant + message's body, and the surrounding scaffold is neither sampled + nor body.""" + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello world!"}, + ] + rendered = renderer.render(msgs) + if not _is_populated(rendered): + return + if len(rendered.sampled_mask) != len(rendered.token_ids): + return # renderer opts out of sampled_mask + + mismatches = [] + for k, msg_idx in enumerate(rendered.message_indices): + if msg_idx < 0: + continue + if msgs[msg_idx].get("role") != "assistant": + continue + if rendered.is_content[k] != rendered.sampled_mask[k]: + mismatches.append((k, rendered.is_content[k], rendered.sampled_mask[k])) + assert not mismatches, ( + f"{model_name}: is_content != sampled_mask on assistant tokens " + f"(k, is_content, sampled): {mismatches[:8]}" + ) + + +def test_is_content_excludes_generation_prompt(model_name, renderer): + """All generation-prompt tokens (msg_idx=-1) are scaffold — the + next-turn opener the chat template injects so the model can + continue. ``is_content`` must be False over that entire span.""" + msgs = [{"role": "user", "content": "Hi"}] + rendered = renderer.render(msgs, add_generation_prompt=True) + if not _is_populated(rendered): + return + + bad = [ + k + for k, (msg_idx, is_content) in enumerate( + zip(rendered.message_indices, rendered.is_content) + ) + if msg_idx == -1 and is_content + ] + assert not bad, ( + f"{model_name}: generation-prompt tokens marked is_content=True " + f"at positions {bad[:8]}" + ) + + +def test_is_content_recovers_user_body(model_name, tokenizer, renderer): + """The decoded run of is_content=True tokens within a user message + contains the user's original content. ``in`` rather than equality + because some templates normalise whitespace inside the body emit; + the input substring must still be recoverable from the decoded + body run.""" + user_text = "Hello, my name is Sebastian." + msgs = [ + {"role": "user", "content": user_text}, + {"role": "assistant", "content": "Hi!"}, + ] + rendered = renderer.render(msgs) + if not _is_populated(rendered): + return + + user_body_ids = [ + tid + for tid, mi, ic in zip( + rendered.token_ids, rendered.message_indices, rendered.is_content + ) + if mi == 0 and ic + ] + assert user_body_ids, ( + f"{model_name}: no is_content=True tokens attributed to user message" + ) + decoded = tokenizer.decode(user_body_ids).strip() + assert user_text in decoded or decoded in user_text, ( + f"{model_name}: user body run decodes to {decoded!r}, " + f"expected to contain {user_text!r}" + ) + + +def test_is_content_recovers_tool_body(model_name, tokenizer, renderer): + """The decoded run of is_content=True tokens within a tool message + contains the tool response body. The whole point of the body/wrap + cut: SFT on this run trains the model to anticipate tool outputs + without learning to emit the surrounding ``<|tool_response>`` / + role-tag scaffold (which would interrupt a real rollout).""" + tool_text = "The capital of France is Paris." + msgs = [ + {"role": "user", "content": "What's the capital of France?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "type": "function", + "id": "call_1", + "function": { + "name": "lookup", + "arguments": {"q": "capital of France"}, + }, + } + ], + }, + {"role": "tool", "content": tool_text, "tool_call_id": "call_1"}, + {"role": "assistant", "content": "Paris."}, + ] + rendered = renderer.render(msgs) + if not _is_populated(rendered): + return + + tool_body_ids = [ + tid + for tid, mi, ic in zip( + rendered.token_ids, rendered.message_indices, rendered.is_content + ) + if mi == 2 and ic + ] + assert tool_body_ids, ( + f"{model_name}: no is_content=True tokens attributed to tool message" + ) + decoded = tokenizer.decode(tool_body_ids).strip() + assert tool_text in decoded, ( + f"{model_name}: tool body run decodes to {decoded!r}, " + f"expected to contain {tool_text!r}" + ) + + +def test_is_content_recovers_system_body(model_name, tokenizer, renderer): + """The decoded run of is_content=True tokens within a system + message contains the caller-provided system content. Tools header + / footer (if present) are scaffold and never appear as body.""" + sys_text = "You are an unusually precise assistant." + msgs = [ + {"role": "system", "content": sys_text}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hi!"}, + ] + rendered = renderer.render(msgs) + if not _is_populated(rendered): + return + + sys_body_ids = [ + tid + for tid, mi, ic in zip( + rendered.token_ids, rendered.message_indices, rendered.is_content + ) + if mi == 0 and ic + ] + assert sys_body_ids, ( + f"{model_name}: no is_content=True tokens attributed to system message" + ) + decoded = tokenizer.decode(sys_body_ids).strip() + assert sys_text in decoded, ( + f"{model_name}: system body run decodes to {decoded!r}, " + f"expected to contain {sys_text!r}" + ) + + +def test_is_content_no_body_on_role_tag(model_name, renderer): + """The first token attributed to a user/system/tool message must + have ``is_content=False`` — that's the leading role-tag run + (``<|im_start|>`` / equivalent), which is template scaffold, never + body.""" + msgs = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + ] + rendered = renderer.render(msgs) + if not _is_populated(rendered): + return + + user_positions = [k for k, idx in enumerate(rendered.message_indices) if idx == 0] + assert user_positions, f"{model_name}: no tokens attributed to user message" + first_k = user_positions[0] + assert not rendered.is_content[first_k], ( + f"{model_name}: first user-attributed token at k={first_k} should " + f"be is_content=False (role-tag scaffolding), but was True" + ) + + +def test_content_token_spans_by_role_isolates_tool_body( + model_name, tokenizer, renderer +): + """``content_token_spans_by_role()["tool"]`` returns spans over + which every token is the tool message body. Joining the decoded + spans recovers the tool response. Adjacent scaffold tokens + (``<|tool_response>``, role-tag openers) are never inside any + returned span.""" + tool_text = "Result: 42" + msgs = [ + {"role": "user", "content": "What's 6*7?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "type": "function", + "id": "call_x", + "function": {"name": "calc", "arguments": {"e": "6*7"}}, + } + ], + }, + {"role": "tool", "content": tool_text, "tool_call_id": "call_x"}, + {"role": "assistant", "content": "42."}, + ] + rendered = renderer.render(msgs) + if not _is_populated(rendered): + return + + spans = rendered.content_token_spans_by_role() + tool_spans = spans.get("tool") or [] + assert tool_spans, f"{model_name}: no tool content spans returned" + + pieces: list[str] = [] + for s, e in tool_spans: + run_ids = rendered.token_ids[s:e] + # All tokens in the span must be is_content=True by definition. + for k in range(s, e): + assert rendered.is_content[k], ( + f"{model_name}: span {(s, e)} contains is_content=False at k={k}" + ) + pieces.append(tokenizer.decode(run_ids)) + joined = "".join(pieces).strip() + assert tool_text in joined, ( + f"{model_name}: joined tool spans decode to {joined!r}, " + f"expected to contain {tool_text!r}" + ) + + +def test_content_mask_for_roles_excludes_assistant_when_unset(model_name, renderer): + """``content_mask_for_roles({"tool"})`` returns a mask that's True + only on tool body tokens — never on assistant tokens, even though + those also have ``is_content=True``. The role filter is the whole + point: SFT-on-tool-body must not bleed into the assistant span.""" + msgs = [ + {"role": "user", "content": "Hi"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "type": "function", + "id": "call_a", + "function": {"name": "ping", "arguments": {}}, + } + ], + }, + {"role": "tool", "content": "pong", "tool_call_id": "call_a"}, + {"role": "assistant", "content": "OK."}, + ] + rendered = renderer.render(msgs) + if not _is_populated(rendered): + return + + tool_mask = rendered.content_mask_for_roles({"tool"}) + assert len(tool_mask) == len(rendered.token_ids) + + for k, mi in enumerate(rendered.message_indices): + if mi < 0: + assert not tool_mask[k], ( + f"{model_name}: scaffold token k={k} (msg_idx=-1) marked in tool mask" + ) + continue + role = msgs[mi].get("role") + if role != "tool" and tool_mask[k]: + raise AssertionError( + f"{model_name}: tool-role mask True on a {role!r} token at k={k}" + ) + + +def test_build_training_sample_content_sft_roles_picks_up_tool_body( + model_name, renderer +): + """``build_training_sample(..., content_sft_roles={"tool"})`` + produces a loss mask that's True on tool body tokens AND assistant + sampled tokens, but False on the tool-message scaffold (role tag, + ``<|tool_response>`` wraps, separators). The canonical + SFT-on-tool-body + RL-on-assistant composition.""" + from renderers import build_training_sample + + msgs = [ + {"role": "user", "content": "Hi"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "type": "function", + "id": "call_z", + "function": {"name": "noop", "arguments": {}}, + } + ], + }, + {"role": "tool", "content": "done", "tool_call_id": "call_z"}, + {"role": "assistant", "content": "OK."}, + ] + ids, mask = build_training_sample( + renderer, + msgs, + role_to_mask=lambda m: m["role"] == "assistant", + content_sft_roles={"tool"}, + ) + assert len(mask) == len(ids) + + # We need at least one trainable tool-body token if the renderer + # populates is_content. Renderers that opt out (empty is_content) + # fall back to the existing role_to_mask behaviour, which leaves + # tool tokens False — that's the documented fallback and is fine. + rendered = renderer.render(msgs) + if not _is_populated(rendered): + return + + trainable_per_role: dict[str, int] = {} + for k, mi in enumerate(rendered.message_indices): + if mi < 0: + continue + if mask[k]: + r = msgs[mi].get("role") or "" + trainable_per_role[r] = trainable_per_role.get(r, 0) + 1 + + assert trainable_per_role.get("tool", 0) > 0, ( + f"{model_name}: build_training_sample with content_sft_roles={{'tool'}} " + f"trained on zero tool tokens" + ) + assert trainable_per_role.get("assistant", 0) > 0, ( + f"{model_name}: assistant tokens dropped from training mask" + ) + assert trainable_per_role.get("user", 0) == 0, ( + f"{model_name}: user tokens leaked into training mask: {trainable_per_role}" + )