diff --git a/renderers/__init__.py b/renderers/__init__.py
index 60cc271..7c82510 100644
--- a/renderers/__init__.py
+++ b/renderers/__init__.py
@@ -28,6 +28,7 @@
ToolCallParseStatus,
ToolSpec,
VideoPart,
+ attribute_text_segments,
build_training_sample,
build_trajectory_step,
create_renderer,
@@ -90,6 +91,7 @@
"ToolSpec",
"VideoPart",
"__version__",
+ "attribute_text_segments",
"build_training_sample",
"build_trajectory_step",
"create_renderer",
diff --git a/renderers/base.py b/renderers/base.py
index 2da9dbf..defb267 100644
--- a/renderers/base.py
+++ b/renderers/base.py
@@ -169,6 +169,32 @@ class RenderedTokens:
masking. ``DefaultRenderer`` leaves it empty because the Jinja
template is opaque; hand-coded renderers populate it.
+ ``is_content`` is a per-token signal generalizing the "scaffold vs
+ body" distinction across all roles: ``True`` iff the token was
+ produced from message-body bytes (caller-provided ``content`` /
+ ``tool_calls`` / ``reasoning_content``, or the model's sampled
+ emission for the assistant role), ``False`` iff it is template
+ scaffolding the renderer added around message bodies — role-tag
+ openers, closers when not model-sampled, inter-turn separators,
+ tool-response wraps, the tools-header block, the generation prompt.
+ Generalises ``sampled_mask``: where ``sampled_mask`` answers "would
+ the model emit this?" (useful for assistant tokens; uniformly
+ ``False`` elsewhere), ``is_content`` answers "is this from caller
+ or model data?" (meaningful on every role). By construction
+ ``is_content[k] == sampled_mask[k]`` over every token attributed to
+ an assistant message; on other roles ``is_content`` carries new
+ information that ``sampled_mask`` does not.
+
+ The use case: SFT on tool response bodies while applying RL only to
+ assistant tokens. The trainer wants the model to anticipate tool
+ outputs but never to emit ``<|tool_response>`` itself (that would
+ interrupt the rollout), so the SFT loss mask is
+ ``message_role == "tool" AND is_content``.
+
+ Empty ``is_content`` (``[]``) — like ``sampled_mask`` — means the
+ renderer doesn't provide the signal. ``DefaultRenderer`` leaves it
+ empty for the same reason.
+
``multi_modal_data`` is populated by multimodal renderers (e.g.
``Qwen3VLRenderer``) when image / video content parts are present;
text-only renderers leave it as ``None``.
@@ -177,6 +203,7 @@ class RenderedTokens:
token_ids: list[int] = field(default_factory=list)
message_indices: list[int] = field(default_factory=list)
sampled_mask: list[bool] = field(default_factory=list)
+ is_content: list[bool] = field(default_factory=list)
message_roles: list[str] = field(default_factory=list)
multi_modal_data: "MultiModalData | None" = None
@@ -333,6 +360,94 @@ def tokens_by_role(self, *, sampled_only: bool = False) -> dict[str, int]:
out[role] = out.get(role, 0) + n
return out
+ def content_token_spans_by_role(self) -> dict[str, list[tuple[int, int]]]:
+ """Per-role spans of contiguous body-only tokens (``is_content=True``).
+
+ Maps each role appearing in :attr:`message_roles` to a list of
+ half-open ``[start, end)`` slices into :attr:`token_ids` over
+ which every token satisfies ``is_content=True`` AND belongs to
+ a message of that role. Spans never cross message boundaries:
+ a tool message contributes its own runs; an immediately
+ adjacent assistant message contributes separate runs even when
+ the bodies abut on the token axis.
+
+ Returns an empty dict when :attr:`is_content` or
+ :attr:`message_roles` is empty (renderer didn't populate the
+ signal — e.g. ``DefaultRenderer``).
+
+ Intended for selective loss masking: SFT on tool response
+ bodies while RL acts only on assistant turns is the canonical
+ case::
+
+ spans = rendered.content_token_spans_by_role()
+ tool_sft_mask = [False] * len(rendered.token_ids)
+ for s, e in spans.get("tool", []):
+ for k in range(s, e):
+ tool_sft_mask[k] = True
+
+ See also :meth:`content_mask_for_roles` for the same
+ computation returned as a per-token bool list.
+ """
+ out: dict[str, list[tuple[int, int]]] = {}
+ if not self.is_content or not self.message_roles:
+ return out
+ n = len(self.token_ids)
+ if len(self.is_content) != n or len(self.message_indices) != n:
+ return out
+
+ msg_spans = self.message_token_spans()
+ for role, span in zip(self.message_roles, msg_spans):
+ bucket = out.setdefault(role, [])
+ if span is None:
+ continue
+ start, end = span
+ run_start: int | None = None
+ for k in range(start, end):
+ if self.is_content[k]:
+ if run_start is None:
+ run_start = k
+ else:
+ if run_start is not None:
+ bucket.append((run_start, k))
+ run_start = None
+ if run_start is not None:
+ bucket.append((run_start, end))
+ return out
+
+ def content_mask_for_roles(self, roles: "set[str] | frozenset[str]") -> list[bool]:
+ """Per-token bool list: ``True`` iff the token is body of a
+ message whose role is in ``roles``.
+
+ Length matches :attr:`token_ids`. Returns an all-``False``
+ list of that length when :attr:`is_content` or
+ :attr:`message_roles` is empty — consumers can AND this with
+ their own attribution masks without length checks.
+
+ ``role_to_mask`` style helpers in :func:`build_training_sample`
+ cover the trainable-role question; this one covers the
+ complementary "body-only" question. The two compose: SFT mask
+ on tool body is
+ ``rendered.content_mask_for_roles({"tool"})``; RL mask on
+ assistant tokens stays
+ ``[s and (mi >= 0 and rendered.message_roles[mi] == "assistant")
+ for s, mi in zip(rendered.sampled_mask, rendered.message_indices)]``.
+ """
+ n = len(self.token_ids)
+ mask = [False] * n
+ if not self.is_content or not self.message_roles:
+ return mask
+ if len(self.is_content) != n or len(self.message_indices) != n:
+ return mask
+
+ for k, msg_idx in enumerate(self.message_indices):
+ if msg_idx < 0:
+ continue
+ if msg_idx >= len(self.message_roles):
+ continue
+ if self.message_roles[msg_idx] in roles and self.is_content[k]:
+ mask[k] = True
+ return mask
+
class ToolCallParseStatus(str, enum.Enum):
"""Per-attempt outcome of parsing a single ```` block.
@@ -530,6 +645,15 @@ def bridge_to_next_turn(
caller needs that distinction for the prior portion, they
have it directly: every token in ``prev_completion_ids`` was
sampled; every token in ``prev_prompt_ids`` was not.
+ - ``is_content`` mirrors ``sampled_mask``'s scheme for the
+ prior portion (uniformly ``False`` — body-vs-wrap
+ attribution can't be recovered from raw token ids), and on
+ the bridge-added portion the renderer populates it the same
+ way as in :meth:`render`: ``True`` over the body bytes of
+ each new message, ``False`` over the surrounding scaffold.
+ Consumers walk the trajectory and read each step's own
+ ``is_content`` for full-conversation body masks; the bridge
+ output covers only the *new* tokens this turn adds.
Text-only renderers return :class:`RenderedTokens` with
``multi_modal_data=None``. Multimodal renderers (see
@@ -1208,6 +1332,7 @@ def build_training_sample(
*,
role_to_mask: Callable[[Message], bool],
tools: list[ToolSpec] | None = None,
+ content_sft_roles: "set[str] | frozenset[str] | None" = None,
) -> tuple[list[int], list[bool]]:
"""Build (token_ids, loss_mask) for supervised training.
@@ -1223,17 +1348,53 @@ def build_training_sample(
back to attribution-only masking — every token attributed to a
trainable role is trained on, including template-injected
``<|im_start|>role\\n`` openers.
+
+ ``content_sft_roles`` opts in additional roles for "body-only"
+ supervision: for every message whose role is in this set, tokens
+ with ``is_content=True`` are marked trainable even though the
+ ``sampled_mask`` gate excludes them (the model never samples
+ tool / user / system tokens). Template scaffolding around those
+ messages — ``<|im_start|>role\\n`` openers, ``<|im_end|>``
+ closers, ``<|tool_response>`` wraps, inter-turn ``\\n`` — stays
+ masked out, so the model learns to anticipate the body text
+ without producing the surrounding special tokens (which would
+ interrupt a real rollout). The canonical use case is RL on
+ assistant tokens (``role_to_mask=lambda m: m["role"] ==
+ "assistant"``) plus SFT on tool response bodies
+ (``content_sft_roles={"tool"}``).
+
+ Requires the renderer to populate ``is_content`` for the body-only
+ path to fire. Renderers that leave it empty (``DefaultRenderer``,
+ or hand-coded renderers that haven't been wired up yet) ignore
+ ``content_sft_roles`` silently — falling back to the original
+ ``role_to_mask`` + ``sampled_mask`` behaviour.
"""
rendered = renderer.render(messages, tools=tools)
has_sampled_info = len(rendered.sampled_mask) == len(rendered.token_ids)
+ has_content_info = len(rendered.is_content) == len(rendered.token_ids)
+ body_roles: "frozenset[str]"
+ if content_sft_roles and has_content_info:
+ body_roles = frozenset(content_sft_roles)
+ else:
+ body_roles = frozenset()
+
loss_mask: list[bool] = []
for k, msg_idx in enumerate(rendered.message_indices):
if msg_idx < 0:
loss_mask.append(False)
- elif has_sampled_info and not rendered.sampled_mask[k]:
+ continue
+ msg = messages[msg_idx]
+ # Body-only path for opt-in roles. Fires only on tokens whose
+ # is_content bit is set; never adds the scaffolding around the
+ # message, so the model isn't supervised on emitting the role
+ # tags / wraps that would derail a rollout.
+ if body_roles and msg.get("role") in body_roles:
+ loss_mask.append(rendered.is_content[k])
+ continue
+ if has_sampled_info and not rendered.sampled_mask[k]:
loss_mask.append(False)
else:
- loss_mask.append(role_to_mask(messages[msg_idx]))
+ loss_mask.append(role_to_mask(msg))
return rendered.token_ids, loss_mask
@@ -1280,6 +1441,157 @@ def trim_to_turn_close(
return previous_ids
+# Per-model offset-aware tokenizer cache. ``attribute_text_segments``
+# uses the fast HuggingFace tokenizer's ``offset_mapping`` to attribute
+# each token to its source text segment under one BPE pass. Fastokens
+# (the Rust BPE we patch in by default for ~10x faster encode) does not
+# track character offsets — the patched tokenizer's
+# ``return_offsets_mapping=True`` raises ``NotImplementedError``. So we
+# keep a parallel vanilla tokenizer per model purely for offset queries.
+# Memory cost is one extra tokenizer per *unique* model name across all
+# pools / renderers (the cache is process-global), independent of pool
+# size.
+_offset_tokenizers: dict[str, Any] = {}
+_offset_tokenizers_lock = threading.Lock()
+
+
+def _get_offset_tokenizer(tokenizer):
+ """Return a tokenizer that supports ``return_offsets_mapping=True``.
+
+ If ``tokenizer`` itself supports offsets, returns it unchanged.
+ Otherwise loads a vanilla (non-fastokens) tokenizer from
+ ``tokenizer.name_or_path`` and caches it. Raises if the tokenizer
+ has no usable ``name_or_path`` — hand-coded renderers always pass
+ a tokenizer loaded via ``load_tokenizer`` which does set it.
+ """
+ # Cheap probe: does this tokenizer already provide offsets?
+ try:
+ tokenizer("a", add_special_tokens=False, return_offsets_mapping=True)
+ return tokenizer
+ except (NotImplementedError, ValueError, TypeError):
+ pass
+
+ name_or_path = getattr(tokenizer, "name_or_path", "")
+ if not name_or_path:
+ raise RuntimeError(
+ "Cannot construct an offset-aware tokenizer: the supplied "
+ "tokenizer has no ``name_or_path`` to fall back on. Pass a "
+ "tokenizer loaded via ``renderers.base.load_tokenizer``."
+ )
+
+ with _offset_tokenizers_lock:
+ cached = _offset_tokenizers.get(name_or_path)
+ if cached is not None:
+ return cached
+ from transformers import AutoTokenizer
+
+ kwargs: dict[str, Any] = {}
+ revision = TRUSTED_REVISIONS.get(name_or_path)
+ if revision is not None:
+ kwargs = {"trust_remote_code": True, "revision": revision}
+ else:
+ kwargs = {"trust_remote_code": False}
+ # Explicitly vanilla — we want HF's Rust tokenizer with offset
+ # tracking, not the fastokens shim. ``load_tokenizer`` would
+ # patch fastokens in by default; calling
+ # ``AutoTokenizer.from_pretrained`` directly here keeps the
+ # fastokens patch out of this code path entirely.
+ offset_tok = AutoTokenizer.from_pretrained(name_or_path, **kwargs)
+ if not getattr(offset_tok, "is_fast", False):
+ raise RuntimeError(
+ f"Vanilla tokenizer for {name_or_path!r} is not a fast "
+ "tokenizer; offset_mapping is unavailable. Hand-coded "
+ "renderers require a fast tokenizer for body/scaffold "
+ "attribution."
+ )
+ _offset_tokenizers[name_or_path] = offset_tok
+ return offset_tok
+
+
+def attribute_text_segments(
+ tokenizer,
+ segments: "list[tuple[str, bool]]",
+) -> "list[tuple[int, bool]]":
+ """Tokenize concatenated segments as a single BPE pass and return
+ ``(token_id, is_content)`` pairs.
+
+ ``segments`` is a list of ``(text, is_content)`` chunks the renderer
+ wants to emit contiguously — for example ``[("user\\n", False),
+ (content, True)]`` for a user message. Concatenation is done before
+ encoding to preserve BPE merges across the wrap/body boundary; the
+ resulting tokens are then attributed back to their source segment
+ via the fast tokenizer's ``offset_mapping``.
+
+ A token is attributed to the segment containing its first source
+ character (``offset_mapping[k][0]``). Tokens whose first character
+ falls exactly on a segment boundary are attributed to the segment
+ that *starts* at that offset (the "later" segment). Zero-length
+ tokens (rare; usually pre-tokenizer artefacts) are attributed to
+ the most recently entered segment.
+
+ Requires a HuggingFace fast tokenizer with offset tracking. The
+ ``fastokens`` patch ``load_tokenizer`` applies by default does
+ **not** track offsets — when that's the case we transparently load
+ a vanilla offset-capable tokenizer for the same model and cache it
+ (see :func:`_get_offset_tokenizer`). Hand-coded renderers are only
+ registered for model families that ship a fast tokenizer, so a
+ silent slow-tokenizer fallback isn't supported — BPE drift at the
+ wrap/body boundary would defeat the whole point.
+
+ Empty input or empty joined text returns an empty list.
+ """
+ if not segments:
+ return []
+ full_text = "".join(text for text, _ in segments)
+ if not full_text:
+ return []
+
+ offset_tokenizer = _get_offset_tokenizer(tokenizer)
+ encoding = offset_tokenizer(
+ full_text,
+ add_special_tokens=False,
+ return_offsets_mapping=True,
+ )
+ token_ids = list(encoding["input_ids"])
+ offsets = list(encoding["offset_mapping"])
+
+ # Build segment char-span lookup. Track the half-open span
+ # [seg_start, seg_end) of each segment and its is_content bit.
+ spans: list[tuple[int, int, bool]] = []
+ pos = 0
+ for text, is_content in segments:
+ spans.append((pos, pos + len(text), is_content))
+ pos += len(text)
+ total_len = pos
+
+ out: list[tuple[int, bool]] = []
+ last_is_content = spans[-1][2] if spans else False
+ for tok_id, (start, _end) in zip(token_ids, offsets):
+ if start >= total_len:
+ # Token's character offset is past every segment (shouldn't
+ # normally happen for add_special_tokens=False, but defensive
+ # against tokenizer-specific edge cases).
+ out.append((tok_id, last_is_content))
+ continue
+ # Find the segment that contains `start`. Segments are
+ # contiguous and ordered, so a linear scan is fine — the inner
+ # loop runs at most len(segments) times per token and segments
+ # is typically 2-3 in practice.
+ is_content = last_is_content
+ for seg_start, seg_end, seg_is_content in spans:
+ if seg_start <= start < seg_end:
+ is_content = seg_is_content
+ break
+ else:
+ # start == total_len handled above; the remaining case is
+ # an empty segment in the middle. Empty segments emit no
+ # characters, so no token can land in them; fall through to
+ # the last non-empty segment's bit.
+ pass
+ out.append((tok_id, is_content))
+ return out
+
+
def reject_assistant_in_extension(new_messages: list[Message]) -> bool:
"""Return True if any message in ``new_messages`` is an assistant turn.
diff --git a/renderers/client.py b/renderers/client.py
index ae1722f..576e5c2 100644
--- a/renderers/client.py
+++ b/renderers/client.py
@@ -23,6 +23,7 @@
from renderers.base import (
Message,
MultiModalData,
+ RenderedTokens,
Renderer,
RendererPool,
ToolCallParseStatus,
@@ -127,6 +128,7 @@ async def generate(
model: str,
prompt_ids: list[int] | None = None,
multi_modal_data: MultiModalData | None = None,
+ prompt_attribution: RenderedTokens | None = None,
tools: list[ToolSpec] | None = None,
sampling_params: dict[str, Any] | None = None,
cache_salt: str | None = None,
@@ -141,7 +143,11 @@ async def generate(
renderer) and ``logprobs=1`` (we always emit completion_logprobs). Pass
``prompt_ids`` to skip rendering and use a prebuilt token sequence —
pair it with ``multi_modal_data`` when the prebuilt prompt has image /
- video placeholders that need engine-side mm payload.
+ video placeholders that need engine-side mm payload, and with
+ ``prompt_attribution`` (a :class:`RenderedTokens` whose ``token_ids``
+ match the passed-in ``prompt_ids``) to carry the renderer's per-token
+ attribution (``is_content`` / ``sampled_mask`` / ``message_indices`` /
+ ``message_roles``) into the result without re-rendering.
For multimodal renderers (e.g. ``Qwen3VLRenderer``), the call goes
through ``renderer.render(...)`` to recover the ``multi_modal_data``
@@ -161,7 +167,19 @@ async def generate(
Returns a dict with: request_id, prompt_ids, completion_ids,
completion_logprobs, content, reasoning_content, tool_calls,
- finish_reason, routed_experts.
+ finish_reason, routed_experts, multi_modal_data, prompt_attribution.
+
+ ``prompt_attribution`` is the renderer's :class:`RenderedTokens` for
+ the prompt — either the one this call computed via
+ ``renderer.render(...)`` or the one the caller threaded in alongside
+ ``prompt_ids``. Carries ``token_ids``, ``message_indices``,
+ ``sampled_mask``, ``is_content``, ``message_roles``, and
+ ``multi_modal_data``, so downstream consumers (verifiers
+ ``RendererClient`` → prime-rl) can build per-token loss masks
+ (``content_mask_for_roles({"tool"})`` for SFT-on-tool-body,
+ ``sampled_mask`` for RL trainable spans) without a second render
+ pass. ``None`` when the caller passed pre-built ``prompt_ids``
+ without attribution.
"""
if tools and not getattr(renderer, "supports_tools", True):
raise ValueError(
@@ -171,15 +189,26 @@ async def generate(
def _prepare():
if prompt_ids is not None:
- return list(prompt_ids), renderer.get_stop_token_ids(), multi_modal_data
+ # Caller-supplied prompt; if they also gave us pre-computed
+ # attribution (e.g. the bridge path in verifiers), thread it
+ # through unchanged.
+ return (
+ list(prompt_ids),
+ renderer.get_stop_token_ids(),
+ multi_modal_data,
+ prompt_attribution,
+ )
rendered = renderer.render(messages, tools=tools, add_generation_prompt=True)
return (
rendered.token_ids,
renderer.get_stop_token_ids(),
rendered.multi_modal_data,
+ rendered,
)
- prompt_ids, stop_token_ids, mm_data = await _maybe_offload(renderer, _prepare)
+ prompt_ids, stop_token_ids, mm_data, prompt_attr = await _maybe_offload(
+ renderer, _prepare
+ )
if max_prompt_len is None:
max_prompt_len = await _resolve_max_prompt_len(client, model)
@@ -279,6 +308,14 @@ def _prepare():
# callers can persist it on the trajectory step for downstream
# multi-turn bridging and training-sample construction.
"multi_modal_data": mm_data,
+ # The renderer's per-token attribution for the prompt — either
+ # the RenderedTokens computed here via renderer.render(...) or
+ # the one threaded in by the caller alongside prompt_ids (the
+ # bridge path). Lets downstream consumers (verifiers
+ # RendererClient → prime-rl) build SFT-on-tool-body and other
+ # selective loss masks without a second render pass. ``None``
+ # when the caller passed prompt_ids without attribution.
+ "prompt_attribution": prompt_attr,
}
diff --git a/renderers/deepseek_v3.py b/renderers/deepseek_v3.py
index b3c843b..507d81d 100644
--- a/renderers/deepseek_v3.py
+++ b/renderers/deepseek_v3.py
@@ -21,6 +21,7 @@
ParsedResponse,
RenderedTokens,
ToolSpec,
+ attribute_text_segments,
reject_assistant_in_extension,
trim_to_turn_close,
)
@@ -114,22 +115,47 @@ def render(
tokens: list[int] = []
indices: list[int] = []
sampled: list[bool] = []
+ content_mask: list[bool] = []
- def emit_ids(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_ids(
+ ids: list[int], msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.extend(ids)
indices.extend([msg_idx] * len(ids))
sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
- def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_special(
+ token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.append(token_id)
indices.append(msg_idx)
sampled.append(is_sampled)
+ content_mask.append(is_content)
- def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
- emit_ids(self._encode(text), msg_idx, is_sampled=is_sampled)
+ def emit_text(
+ text: str, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
+ emit_ids(
+ self._encode(text),
+ msg_idx,
+ is_sampled=is_sampled,
+ is_content=is_content,
+ )
+
+ def emit_text_segments(
+ segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool
+ ) -> None:
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ tokens.append(tok_id)
+ indices.append(msg_idx)
+ sampled.append(is_sampled)
+ content_mask.append(is_content)
# ── 1. BOS token ─────────────────────────────────────────────
- emit_special(self._bos, -1, is_sampled=False)
+ emit_special(self._bos, -1, is_sampled=False, is_content=False)
# ── 2. Collect system messages at the start ───────────────────
# All leading system messages are concatenated with "\n\n" and emitted
@@ -151,7 +177,8 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
if sys_parts:
# Attribute the concatenated system text to the first system message (index 0).
- emit_text("\n\n".join(sys_parts), 0, is_sampled=False)
+ # The system content is the caller's body — mark is_content=True.
+ emit_text("\n\n".join(sys_parts), 0, is_sampled=False, is_content=True)
# ── 3. Render non-system messages ─────────────────────────────
num_messages = len(messages)
@@ -166,8 +193,8 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
content = "".join(
p.get("text", "") for p in content if isinstance(p, dict)
)
- emit_special(self._user_token, i, is_sampled=False)
- emit_text(str(content), i, is_sampled=False)
+ emit_special(self._user_token, i, is_sampled=False, is_content=False)
+ emit_text(str(content), i, is_sampled=False, is_content=True)
elif role == "user":
content = msg.get("content") or ""
@@ -180,8 +207,8 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
else ""
for p in content
)
- emit_special(self._user_token, i, is_sampled=False)
- emit_text(str(content), i, is_sampled=False)
+ emit_special(self._user_token, i, is_sampled=False, is_content=False)
+ emit_text(str(content), i, is_sampled=False, is_content=True)
elif role == "assistant":
self._render_assistant(
@@ -190,6 +217,7 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
messages,
emit_special=emit_special,
emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
)
elif role == "tool":
@@ -198,6 +226,7 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
i,
emit_special=emit_special,
emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
)
# ── 4. Generation prompt ──────────────────────────────────────
@@ -205,14 +234,17 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
# Don't add <|Assistant|> after tool outputs — content flows directly.
last_role = messages[-1]["role"] if messages else None
if last_role != "tool":
- emit_special(self._assistant_token, -1, is_sampled=False)
+ emit_special(
+ self._assistant_token, -1, is_sampled=False, is_content=False
+ )
if self._enable_thinking:
- emit_text("\n", -1, is_sampled=False)
+ emit_text("\n", -1, is_sampled=False, is_content=False)
return RenderedTokens(
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
+ is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
)
@@ -276,27 +308,38 @@ def bridge_to_next_turn(
ext: list[int] = []
ext_indices: list[int] = []
ext_sampled: list[bool] = []
+ ext_content: list[bool] = []
# Bridge populates ``message_indices`` (relative to ``new_messages``)
# and ``sampled_mask`` (uniformly ``False`` — every token the
# bridge emits is template scaffolding for the next prompt, not
- # something the model sampled). Downstream consumers can run
- # :meth:`RenderedTokens.tokens_per_message` on the bridge output
- # to get per-new-message token counts without re-rendering.
+ # something the model sampled). ``is_content`` follows the same
+ # rules as in :meth:`render` so consumers can walk the trajectory
+ # and read each step's own body mask.
def emit_special(
- token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
+ token_id: int,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ext.append(token_id)
ext_indices.append(msg_idx)
ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
def emit_text(
- text: str, msg_idx: int = -1, *, is_sampled: bool = False
+ text: str,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ids = self._encode(text)
ext.extend(ids)
ext_indices.extend([msg_idx] * len(ids))
ext_sampled.extend([is_sampled] * len(ids))
+ ext_content.extend([is_content] * len(ids))
for i, msg in enumerate(new_messages):
role = msg.get("role")
@@ -309,11 +352,11 @@ def emit_text(
if role == "user":
emit_special(self._user_token, i)
- emit_text(content, i)
+ emit_text(content, i, is_content=True)
elif role == "system":
# Post-initial system messages render as user turns.
emit_special(self._user_token, i)
- emit_text(content, i)
+ emit_text(content, i, is_content=True)
elif role == "tool":
prev_is_tool = i > 0 and new_messages[i - 1].get("role") == "tool"
next_is_tool = (
@@ -323,7 +366,7 @@ def emit_text(
if not prev_is_tool:
emit_special(self._tool_outputs_begin, i)
emit_special(self._tool_output_begin, i)
- emit_text(content, i)
+ emit_text(content, i, is_content=True)
emit_special(self._tool_output_end, i)
if not next_is_tool:
emit_special(self._tool_outputs_end, i)
@@ -344,6 +387,7 @@ def emit_text(
token_ids=previous_ids + ext,
message_indices=[-1] * len(previous_ids) + ext_indices,
sampled_mask=[False] * total_len,
+ is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
)
@@ -359,6 +403,7 @@ def _render_assistant(
*,
emit_special,
emit_text,
+ emit_text_segments,
) -> None:
# Determine whether this message follows a tool output sequence.
# The HF template emits <|tool▁outputs▁end|> before the assistant content
@@ -391,18 +436,23 @@ def _render_assistant(
# keeps the SFT loss mask aligned with what the model would
# actually have produced. When the previous message is a tool
# response, the template skips this token entirely (content
- # flows directly out of ``<|tool▁outputs▁end|>``).
+ # flows directly out of ``<|tool▁outputs▁end|>``). On assistant
+ # the invariant ``is_content == sampled_mask`` holds.
if not prev_is_tool:
- emit_special(self._assistant_token, msg_idx, is_sampled=False)
+ emit_special(
+ self._assistant_token, msg_idx, is_sampled=False, is_content=False
+ )
if not tool_calls:
- emit_text(content, msg_idx, is_sampled=True)
+ emit_text(content, msg_idx, is_sampled=True, is_content=True)
else:
# Emit any pre-tool-call content first.
- emit_text(content, msg_idx, is_sampled=True)
+ emit_text(content, msg_idx, is_sampled=True, is_content=True)
# Tool call section.
- emit_special(self._tool_calls_begin, msg_idx, is_sampled=True)
+ emit_special(
+ self._tool_calls_begin, msg_idx, is_sampled=True, is_content=True
+ )
for tc in tool_calls:
func = tc.get("function") or tc
name = func.get("name", "")
@@ -414,17 +464,28 @@ def _render_assistant(
)
# Format: <|tool▁call▁begin|>function<|tool▁sep|>{name}\n```json\n{args}\n```<|tool▁call▁end|>
# tool_sep is a special token; type ("function") and name+args are plain text.
- emit_special(self._tool_call_begin, msg_idx, is_sampled=True)
- emit_text("function", msg_idx, is_sampled=True)
- emit_special(self._tool_sep, msg_idx, is_sampled=True)
- emit_text(f"{name}\n```json\n{args_str}\n```", msg_idx, is_sampled=True)
- emit_special(self._tool_call_end, msg_idx, is_sampled=True)
- emit_special(self._tool_calls_end, msg_idx, is_sampled=True)
+ emit_special(
+ self._tool_call_begin, msg_idx, is_sampled=True, is_content=True
+ )
+ emit_text("function", msg_idx, is_sampled=True, is_content=True)
+ emit_special(self._tool_sep, msg_idx, is_sampled=True, is_content=True)
+ emit_text(
+ f"{name}\n```json\n{args_str}\n```",
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
+ emit_special(
+ self._tool_call_end, msg_idx, is_sampled=True, is_content=True
+ )
+ emit_special(
+ self._tool_calls_end, msg_idx, is_sampled=True, is_content=True
+ )
# ``<|end▁of▁sentence|>`` is the model's stop signal — it
# samples this to end its turn, so it is part of the sampled
# stream.
- emit_special(self._eos, msg_idx, is_sampled=True)
+ emit_special(self._eos, msg_idx, is_sampled=True, is_content=True)
# ------------------------------------------------------------------
# Tool (tool-response) rendering
@@ -437,10 +498,13 @@ def _render_tool(
*,
emit_special,
emit_text,
+ emit_text_segments,
) -> None:
# Tool messages are conversation history injected by the runtime
# between assistant turns — the model never samples any of these
- # tokens, so every emission is is_sampled=False.
+ # tokens, so every emission is is_sampled=False. The ``content``
+ # body bytes get ``is_content=True``; the surrounding section
+ # specials are scaffold.
prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool"
next_is_tool = (
msg_idx + 1 < len(messages) and messages[msg_idx + 1]["role"] == "tool"
@@ -451,11 +515,17 @@ def _render_tool(
content = "".join(p.get("text", "") for p in content if isinstance(p, dict))
if not prev_is_tool:
- emit_special(self._tool_outputs_begin, msg_idx, is_sampled=False)
+ emit_special(
+ self._tool_outputs_begin, msg_idx, is_sampled=False, is_content=False
+ )
- emit_special(self._tool_output_begin, msg_idx, is_sampled=False)
- emit_text(str(content), msg_idx, is_sampled=False)
- emit_special(self._tool_output_end, msg_idx, is_sampled=False)
+ emit_special(
+ self._tool_output_begin, msg_idx, is_sampled=False, is_content=False
+ )
+ emit_text(str(content), msg_idx, is_sampled=False, is_content=True)
+ emit_special(self._tool_output_end, msg_idx, is_sampled=False, is_content=False)
if not next_is_tool:
- emit_special(self._tool_outputs_end, msg_idx, is_sampled=False)
+ emit_special(
+ self._tool_outputs_end, msg_idx, is_sampled=False, is_content=False
+ )
diff --git a/renderers/glm45.py b/renderers/glm45.py
index 4a80ab3..206f366 100644
--- a/renderers/glm45.py
+++ b/renderers/glm45.py
@@ -20,6 +20,7 @@
ParsedResponse,
RenderedTokens,
ToolSpec,
+ attribute_text_segments,
reject_assistant_in_extension,
should_preserve_past_thinking,
)
@@ -128,30 +129,58 @@ def render(
tokens: list[int] = []
indices: list[int] = []
sampled: list[bool] = []
+ content_mask: list[bool] = []
- def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_special(
+ token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.append(token_id)
indices.append(msg_idx)
sampled.append(is_sampled)
+ content_mask.append(is_content)
- def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_text(
+ text: str, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
ids = self._encode(text)
tokens.extend(ids)
indices.extend([msg_idx] * len(ids))
sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
+
+ def emit_text_segments(
+ segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool
+ ) -> None:
+ """Tokenize concatenated segments as one BPE pass; per-token
+ ``is_content`` follows each token's source segment.
+
+ Lets call sites express "this wrap + this body, joined the
+ same way as the chat template, but attributed separately"
+ without splitting the encode call (which could shift BPE
+ merges at the boundary)."""
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ tokens.append(tok_id)
+ indices.append(msg_idx)
+ sampled.append(is_sampled)
+ content_mask.append(is_content)
# ── Prefix ──────────────────────────────────────────────────
- emit_special(self._gmask, -1, is_sampled=False)
- emit_special(self._sop, -1, is_sampled=False)
+ emit_special(self._gmask, -1, is_sampled=False, is_content=False)
+ emit_special(self._sop, -1, is_sampled=False, is_content=False)
# ── Tools in system prompt ──────────────────────────────────
+ # The tools-header block is all scaffold by design — the tools
+ # dict is recoverable from the ``tools`` argument; don't
+ # re-attribute the embedded JSON specs as message body.
if tools:
- emit_special(self._system, -1, is_sampled=False)
+ emit_special(self._system, -1, is_sampled=False, is_content=False)
tool_text = _TOOLS_HEADER
for tool in tools:
tool_text += json.dumps(tool, ensure_ascii=False) + "\n"
tool_text += _TOOLS_FOOTER
- emit_text(tool_text, -1, is_sampled=False)
+ emit_text(tool_text, -1, is_sampled=False, is_content=False)
# ── Compute last_user_index ─────────────────────────────────
last_ui = self._last_user_index(messages)
@@ -162,15 +191,22 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
content = self._visible_text(msg.get("content"))
if role == "system":
- emit_special(self._system, i, is_sampled=False)
- emit_text("\n" + content, i, is_sampled=False)
+ emit_special(self._system, i, is_sampled=False, is_content=False)
+ # ``\n`` is the scaffold separator after the role tag;
+ # the body proper is the caller-provided content.
+ emit_text_segments(
+ [("\n", False), (content, True)], i, is_sampled=False
+ )
elif role == "user":
- emit_special(self._user, i, is_sampled=False)
- user_text = "\n" + content
+ emit_special(self._user, i, is_sampled=False, is_content=False)
+ # ``\n`` is scaffold; ``content`` is body; the optional
+ # ``/nothink`` suffix is scaffold the renderer injects
+ # when ``enable_thinking=False``.
+ user_segments: list[tuple[str, bool]] = [("\n", False), (content, True)]
if not self._enable_thinking and not content.endswith("/nothink"):
- user_text += "/nothink"
- emit_text(user_text, i, is_sampled=False)
+ user_segments.append(("/nothink", False))
+ emit_text_segments(user_segments, i, is_sampled=False)
elif role == "assistant":
preserve_thinking = should_preserve_past_thinking(
@@ -187,25 +223,32 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
preserve_thinking=preserve_thinking,
emit_special=emit_special,
emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
)
elif role == "tool":
self._render_tool(
- messages, i, content, emit_special=emit_special, emit_text=emit_text
+ messages,
+ i,
+ content,
+ emit_special=emit_special,
+ emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
)
# ── Generation prompt ───────────────────────────────────────
if add_generation_prompt:
- emit_special(self._assistant, -1, is_sampled=False)
+ emit_special(self._assistant, -1, is_sampled=False, is_content=False)
if not self._enable_thinking:
- emit_text("\n", -1, is_sampled=False)
- emit_special(self._think, -1, is_sampled=False)
- emit_special(self._think_end, -1, is_sampled=False)
+ emit_text("\n", -1, is_sampled=False, is_content=False)
+ emit_special(self._think, -1, is_sampled=False, is_content=False)
+ emit_special(self._think_end, -1, is_sampled=False, is_content=False)
return RenderedTokens(
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
+ is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
)
@@ -276,27 +319,54 @@ def bridge_to_next_turn(
ext: list[int] = []
ext_indices: list[int] = []
ext_sampled: list[bool] = []
+ ext_content: list[bool] = []
# Bridge populates ``message_indices`` (relative to ``new_messages``)
# and ``sampled_mask`` (uniformly ``False`` — every token the
# bridge emits is template scaffolding for the next prompt, not
- # something the model sampled). Downstream consumers can run
- # :meth:`RenderedTokens.tokens_per_message` on the bridge output
- # to get per-new-message token counts without re-rendering.
+ # something the model sampled). ``is_content`` follows the same
+ # rules as in :meth:`render` so consumers can walk the trajectory
+ # and read each step's own body mask. Downstream consumers can
+ # run :meth:`RenderedTokens.tokens_per_message` on the bridge
+ # output to get per-new-message token counts without re-rendering.
def emit_special(
- token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
+ token_id: int,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ext.append(token_id)
ext_indices.append(msg_idx)
ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
def emit_text(
- text: str, msg_idx: int = -1, *, is_sampled: bool = False
+ text: str,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ids = self._encode(text)
ext.extend(ids)
ext_indices.extend([msg_idx] * len(ids))
ext_sampled.extend([is_sampled] * len(ids))
+ ext_content.extend([is_content] * len(ids))
+
+ def emit_text_segments(
+ segments: list[tuple[str, bool]],
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ ) -> None:
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ ext.append(tok_id)
+ ext_indices.append(msg_idx)
+ ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
for i, msg in enumerate(new_messages):
role = msg.get("role")
@@ -304,20 +374,30 @@ def emit_text(
if role == "user":
if not (i == 0 and last_prev == self._user):
emit_special(self._user, i)
- user_text = "\n" + content
+ user_segments: list[tuple[str, bool]] = [
+ ("\n", False),
+ (content, True),
+ ]
if not self._enable_thinking and not content.endswith("/nothink"):
- user_text += "/nothink"
- emit_text(user_text, i)
+ user_segments.append(("/nothink", False))
+ emit_text_segments(user_segments, i)
elif role == "system":
emit_special(self._system, i)
- emit_text("\n" + content, i)
+ emit_text_segments([("\n", False), (content, True)], i)
elif role == "tool":
prev_is_tool = i > 0 and new_messages[i - 1].get("role") == "tool"
if i == 0 and last_prev == self._observation:
pass
elif not prev_is_tool:
emit_special(self._observation, i)
- emit_text("\n\n" + content + "\n", i)
+ emit_text_segments(
+ [
+ ("\n\n", False),
+ (content, True),
+ ("\n", False),
+ ],
+ i,
+ )
else:
return None
@@ -333,6 +413,7 @@ def emit_text(
token_ids=previous_ids + ext,
message_indices=[-1] * len(previous_ids) + ext_indices,
sampled_mask=[False] * total_len,
+ is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
)
@@ -346,6 +427,7 @@ def _render_assistant(
preserve_thinking: bool = False,
emit_special,
emit_text,
+ emit_text_segments,
):
reasoning_content = ""
if isinstance(msg.get("reasoning_content"), str):
@@ -373,23 +455,31 @@ def _render_assistant(
# turn). So no sampled stop-signal token lives inside this
# assistant span — content / think / tool_calls carry the
# is_sampled=True signal.
- emit_special(self._assistant, msg_idx, is_sampled=False)
- emit_text("\n", msg_idx, is_sampled=False)
+ #
+ # Invariant on assistant tokens: ``is_content == sampled_mask``.
+ # Every scaffold token here gets ``is_sampled=False/is_content=False``;
+ # every model-sampled emit gets both True.
+ emit_special(self._assistant, msg_idx, is_sampled=False, is_content=False)
+ emit_text("\n", msg_idx, is_sampled=False, is_content=False)
if (msg_idx > last_user_index or preserve_thinking) and reasoning_content:
- emit_special(self._think, msg_idx, is_sampled=True)
- emit_text(reasoning_content.strip(), msg_idx, is_sampled=True)
- emit_special(self._think_end, msg_idx, is_sampled=True)
+ emit_special(self._think, msg_idx, is_sampled=True, is_content=True)
+ emit_text(
+ reasoning_content.strip(), msg_idx, is_sampled=True, is_content=True
+ )
+ emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True)
else:
- emit_special(self._think, msg_idx, is_sampled=True)
- emit_special(self._think_end, msg_idx, is_sampled=True)
+ emit_special(self._think, msg_idx, is_sampled=True, is_content=True)
+ emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True)
# Tool calls — keep content + \n contiguous to preserve BPE merges
tool_calls = msg.get("tool_calls") or []
if content.strip() and tool_calls:
- emit_text("\n" + content.strip() + "\n", msg_idx, is_sampled=True)
+ emit_text(
+ "\n" + content.strip() + "\n", msg_idx, is_sampled=True, is_content=True
+ )
elif content.strip():
- emit_text("\n" + content.strip(), msg_idx, is_sampled=True)
+ emit_text("\n" + content.strip(), msg_idx, is_sampled=True, is_content=True)
for tc in tool_calls:
func = tc.get("function") or tc
@@ -397,9 +487,9 @@ def _render_assistant(
arguments = func.get("arguments", {})
if not content.strip():
- emit_text("\n", msg_idx, is_sampled=True)
- emit_special(self._tool_call_tok, msg_idx, is_sampled=True)
- emit_text(name + "\n", msg_idx, is_sampled=True)
+ emit_text("\n", msg_idx, is_sampled=True, is_content=True)
+ emit_special(self._tool_call_tok, msg_idx, is_sampled=True, is_content=True)
+ emit_text(name + "\n", msg_idx, is_sampled=True, is_content=True)
# OpenAI canonical form: arguments is a JSON string. Parse it so the
# per-argument rendering below still works.
if isinstance(arguments, str):
@@ -409,22 +499,33 @@ def _render_assistant(
arguments = {}
if isinstance(arguments, dict):
for arg_name, arg_value in arguments.items():
- emit_special(self._arg_key, msg_idx, is_sampled=True)
- emit_text(arg_name, msg_idx, is_sampled=True)
- emit_special(self._arg_key_end, msg_idx, is_sampled=True)
- emit_text("\n", msg_idx, is_sampled=True)
- emit_special(self._arg_value, msg_idx, is_sampled=True)
+ emit_special(
+ self._arg_key, msg_idx, is_sampled=True, is_content=True
+ )
+ emit_text(arg_name, msg_idx, is_sampled=True, is_content=True)
+ emit_special(
+ self._arg_key_end, msg_idx, is_sampled=True, is_content=True
+ )
+ emit_text("\n", msg_idx, is_sampled=True, is_content=True)
+ emit_special(
+ self._arg_value, msg_idx, is_sampled=True, is_content=True
+ )
if isinstance(arg_value, str):
- emit_text(arg_value, msg_idx, is_sampled=True)
+ emit_text(arg_value, msg_idx, is_sampled=True, is_content=True)
else:
emit_text(
json.dumps(arg_value, ensure_ascii=False),
msg_idx,
is_sampled=True,
+ is_content=True,
)
- emit_special(self._arg_value_end, msg_idx, is_sampled=True)
- emit_text("\n", msg_idx, is_sampled=True)
- emit_special(self._tool_call_end_tok, msg_idx, is_sampled=True)
+ emit_special(
+ self._arg_value_end, msg_idx, is_sampled=True, is_content=True
+ )
+ emit_text("\n", msg_idx, is_sampled=True, is_content=True)
+ emit_special(
+ self._tool_call_end_tok, msg_idx, is_sampled=True, is_content=True
+ )
def _render_tool(
self,
@@ -434,17 +535,30 @@ def _render_tool(
*,
emit_special,
emit_text,
+ emit_text_segments,
) -> None:
# Tool messages are conversation history injected by the runtime
# between assistant turns — the model never samples any of these
- # tokens, so every emission is is_sampled=False.
+ # tokens, so every emission is is_sampled=False. The body bytes
+ # get ``is_content=True``; the ``\n\n`` /
+ # ``\n`` wraps and the ``<|observation|>`` role
+ # tag are scaffold so the SFT mask for tool body never trains
+ # the model to emit them. Single BPE pass over the joined text
+ # preserves boundary merges (the tool body's leading/trailing
+ # chars can merge with the wrap's ``\n``s if the tokenizer would
+ # do so; we route through ``emit_text_segments`` so the
+ # attribution is offset-driven and tokenizer-agnostic).
prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool"
if not prev_is_tool:
- emit_special(self._observation, msg_idx, is_sampled=False)
-
- emit_text(
- "\n\n" + content + "\n",
+ emit_special(self._observation, msg_idx, is_sampled=False, is_content=False)
+
+ emit_text_segments(
+ [
+ ("\n\n", False),
+ (content, True),
+ ("\n", False),
+ ],
msg_idx,
is_sampled=False,
)
diff --git a/renderers/glm5.py b/renderers/glm5.py
index 63599f2..6de6ba3 100644
--- a/renderers/glm5.py
+++ b/renderers/glm5.py
@@ -21,6 +21,7 @@
ParsedResponse,
RenderedTokens,
ToolSpec,
+ attribute_text_segments,
reject_assistant_in_extension,
should_preserve_past_thinking,
)
@@ -143,32 +144,61 @@ def render(
tokens: list[int] = []
indices: list[int] = []
sampled: list[bool] = []
+ content_mask: list[bool] = []
- def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_special(
+ token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.append(token_id)
indices.append(msg_idx)
sampled.append(is_sampled)
+ content_mask.append(is_content)
- def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_text(
+ text: str, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
ids = self._encode(text)
tokens.extend(ids)
indices.extend([msg_idx] * len(ids))
sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
+
+ def emit_text_segments(
+ segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool
+ ) -> None:
+ """Tokenize concatenated segments as one BPE pass; per-token
+ ``is_content`` follows each token's source segment.
+
+ Lets call sites express "this wrap + this body, joined the
+ same way as the chat template, but attributed separately"
+ without splitting the encode call (which could shift BPE
+ merges at the boundary)."""
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ tokens.append(tok_id)
+ indices.append(msg_idx)
+ sampled.append(is_sampled)
+ content_mask.append(is_content)
# ── Prefix ──────────────────────────────────────────────────
# ``[gMASK]`` is unconditional template scaffolding at the
- # very start of the stream — the model never samples these.
- emit_special(self._gmask, -1, is_sampled=False)
- emit_special(self._sop, -1, is_sampled=False)
+ # very start of the stream — the model never samples these and
+ # they are not part of any message body.
+ emit_special(self._gmask, -1, is_sampled=False, is_content=False)
+ emit_special(self._sop, -1, is_sampled=False, is_content=False)
# ── Tools in system prompt ──────────────────────────────────
+ # The tools-header block is all scaffold by design — the tools
+ # dict is recoverable from the ``tools`` argument; don't
+ # re-attribute the embedded JSON specs as message body.
if tools:
- emit_special(self._system, -1, is_sampled=False)
+ emit_special(self._system, -1, is_sampled=False, is_content=False)
tool_text = _TOOLS_HEADER
for tool in tools:
tool_text += self._format_tool_spec(tool) + "\n"
tool_text += _TOOLS_FOOTER
- emit_text(tool_text, -1, is_sampled=False)
+ emit_text(tool_text, -1, is_sampled=False, is_content=False)
# ── Compute last_user_index ─────────────────────────────────
last_ui = self._last_user_index(messages)
@@ -179,12 +209,12 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
content = self._visible_text(msg.get("content"))
if role == "system":
- emit_special(self._system, i, is_sampled=False)
- emit_text(content, i, is_sampled=False)
+ emit_special(self._system, i, is_sampled=False, is_content=False)
+ emit_text(content, i, is_sampled=False, is_content=True)
elif role == "user":
- emit_special(self._user, i, is_sampled=False)
- emit_text(content, i, is_sampled=False)
+ emit_special(self._user, i, is_sampled=False, is_content=False)
+ emit_text(content, i, is_sampled=False, is_content=True)
elif role == "assistant":
preserve_thinking = should_preserve_past_thinking(
@@ -201,28 +231,35 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
preserve_thinking=preserve_thinking,
emit_special=emit_special,
emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
)
elif role == "tool":
self._render_tool(
- messages, i, content, emit_special=emit_special, emit_text=emit_text
+ messages,
+ i,
+ content,
+ emit_special=emit_special,
+ emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
)
# ── Generation prompt ───────────────────────────────────────
# Gen prompt tokens are what the chat template prepends before
# sampling starts — the model continues from these, never emits
- # them. Always is_sampled=False.
+ # them. Always is_sampled=False / is_content=False.
if add_generation_prompt:
- emit_special(self._assistant, -1, is_sampled=False)
+ emit_special(self._assistant, -1, is_sampled=False, is_content=False)
if self._enable_thinking:
- emit_special(self._think, -1, is_sampled=False)
+ emit_special(self._think, -1, is_sampled=False, is_content=False)
else:
- emit_special(self._think_end, -1, is_sampled=False)
+ emit_special(self._think_end, -1, is_sampled=False, is_content=False)
return RenderedTokens(
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
+ is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
)
@@ -297,27 +334,54 @@ def bridge_to_next_turn(
ext: list[int] = []
ext_indices: list[int] = []
ext_sampled: list[bool] = []
+ ext_content: list[bool] = []
# Bridge populates ``message_indices`` (relative to ``new_messages``)
# and ``sampled_mask`` (uniformly ``False`` — every token the
# bridge emits is template scaffolding for the next prompt, not
- # something the model sampled). Downstream consumers can run
- # :meth:`RenderedTokens.tokens_per_message` on the bridge output
- # to get per-new-message token counts without re-rendering.
+ # something the model sampled). ``is_content`` follows the same
+ # rules as in :meth:`render` so consumers can walk the trajectory
+ # and read each step's own body mask. Downstream consumers can
+ # run :meth:`RenderedTokens.tokens_per_message` on the bridge
+ # output to get per-new-message token counts without re-rendering.
def emit_special(
- token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
+ token_id: int,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ext.append(token_id)
ext_indices.append(msg_idx)
ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
def emit_text(
- text: str, msg_idx: int = -1, *, is_sampled: bool = False
+ text: str,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ids = self._encode(text)
ext.extend(ids)
ext_indices.extend([msg_idx] * len(ids))
ext_sampled.extend([is_sampled] * len(ids))
+ ext_content.extend([is_content] * len(ids))
+
+ def emit_text_segments(
+ segments: list[tuple[str, bool]],
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ ) -> None:
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ ext.append(tok_id)
+ ext_indices.append(msg_idx)
+ ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
for i, msg in enumerate(new_messages):
role = msg.get("role")
@@ -326,10 +390,10 @@ def emit_text(
# Dedup: model already emitted <|user|> as its stop token.
if not (i == 0 and last_prev == self._user):
emit_special(self._user, i)
- emit_text(content, i)
+ emit_text(content, i, is_content=True)
elif role == "system":
emit_special(self._system, i)
- emit_text(content, i)
+ emit_text(content, i, is_content=True)
elif role == "tool":
prev_is_tool = i > 0 and new_messages[i - 1].get("role") == "tool"
if i == 0 and last_prev == self._observation:
@@ -338,7 +402,7 @@ def emit_text(
elif not prev_is_tool:
emit_special(self._observation, i)
emit_special(self._tool_response_tok, i)
- emit_text(content, i)
+ emit_text(content, i, is_content=True)
emit_special(self._tool_response_end_tok, i)
else:
return None
@@ -355,6 +419,7 @@ def emit_text(
token_ids=previous_ids + ext,
message_indices=[-1] * len(previous_ids) + ext_indices,
sampled_mask=[False] * total_len,
+ is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
)
@@ -368,6 +433,7 @@ def _render_assistant(
preserve_thinking: bool = False,
emit_special,
emit_text,
+ emit_text_segments,
):
reasoning_content = ""
if isinstance(msg.get("reasoning_content"), str):
@@ -386,7 +452,11 @@ def _render_assistant(
# samples it. Same for the ```` open / standalone
# ```` separator that the template wraps around the
# assistant body — see the per-branch comments below.
- emit_special(self._assistant, msg_idx, is_sampled=False)
+ #
+ # Invariant on assistant tokens: ``is_content == sampled_mask``.
+ # Every scaffold token here gets ``is_sampled=False/is_content=False``;
+ # every model-sampled emit gets both True.
+ emit_special(self._assistant, msg_idx, is_sampled=False, is_content=False)
# Chat-template default: keep ```` only on the in-flight cycle
# (post-last-user). Past-cycle assistants drop their reasoning.
@@ -403,9 +473,11 @@ def _render_assistant(
# inference (gen prompt = ``<|assistant|>``), so it's
# template-injected scaffolding. The reasoning text and the
# closing ```` are what the model actually samples.
- emit_special(self._think, msg_idx, is_sampled=False)
- emit_text(reasoning_content.strip(), msg_idx, is_sampled=True)
- emit_special(self._think_end, msg_idx, is_sampled=True)
+ emit_special(self._think, msg_idx, is_sampled=False, is_content=False)
+ emit_text(
+ reasoning_content.strip(), msg_idx, is_sampled=True, is_content=True
+ )
+ emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True)
elif self.empty_think_on_last_assistant and msg_idx > last_user_index:
# GLM-5.1: wrap the last assistant with an empty
# even without reasoning, matching the Jinja template. With
@@ -413,26 +485,28 @@ def _render_assistant(
# ````; the model then samples ```` to close an
# empty think block. So ```` is scaffolding,
# ```` is sampled.
- emit_special(self._think, msg_idx, is_sampled=False)
- emit_special(self._think_end, msg_idx, is_sampled=True)
+ emit_special(self._think, msg_idx, is_sampled=False, is_content=False)
+ emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True)
else:
# Lone ```` separator the template injects when no
# reasoning is rendered (historical assistants, GLM-5 default
# with no thinking). Not sampled.
- emit_special(self._think_end, msg_idx, is_sampled=False)
+ emit_special(self._think_end, msg_idx, is_sampled=False, is_content=False)
if content.strip():
- emit_text(content.strip(), msg_idx, is_sampled=True)
+ emit_text(content.strip(), msg_idx, is_sampled=True, is_content=True)
- # Tool calls (directly after content, no newlines)
+ # Tool calls (directly after content, no newlines). All of these
+ # are the model's sampled output — both is_sampled and is_content
+ # are True across the entire tool-call span.
tool_calls = msg.get("tool_calls") or []
for tc in tool_calls:
func = tc.get("function") or tc
name = func.get("name", "")
arguments = func.get("arguments", {})
- emit_special(self._tool_call_tok, msg_idx, is_sampled=True)
- emit_text(name, msg_idx, is_sampled=True)
+ emit_special(self._tool_call_tok, msg_idx, is_sampled=True, is_content=True)
+ emit_text(name, msg_idx, is_sampled=True, is_content=True)
# OpenAI canonical form: arguments is a JSON string. Parse it so the
# per-argument rendering below still works.
if isinstance(arguments, str):
@@ -442,20 +516,31 @@ def _render_assistant(
arguments = {}
if isinstance(arguments, dict):
for arg_name, arg_value in arguments.items():
- emit_special(self._arg_key, msg_idx, is_sampled=True)
- emit_text(arg_name, msg_idx, is_sampled=True)
- emit_special(self._arg_key_end, msg_idx, is_sampled=True)
- emit_special(self._arg_value, msg_idx, is_sampled=True)
+ emit_special(
+ self._arg_key, msg_idx, is_sampled=True, is_content=True
+ )
+ emit_text(arg_name, msg_idx, is_sampled=True, is_content=True)
+ emit_special(
+ self._arg_key_end, msg_idx, is_sampled=True, is_content=True
+ )
+ emit_special(
+ self._arg_value, msg_idx, is_sampled=True, is_content=True
+ )
if isinstance(arg_value, str):
- emit_text(arg_value, msg_idx, is_sampled=True)
+ emit_text(arg_value, msg_idx, is_sampled=True, is_content=True)
else:
emit_text(
json.dumps(arg_value, ensure_ascii=False),
msg_idx,
is_sampled=True,
+ is_content=True,
)
- emit_special(self._arg_value_end, msg_idx, is_sampled=True)
- emit_special(self._tool_call_end_tok, msg_idx, is_sampled=True)
+ emit_special(
+ self._arg_value_end, msg_idx, is_sampled=True, is_content=True
+ )
+ emit_special(
+ self._tool_call_end_tok, msg_idx, is_sampled=True, is_content=True
+ )
def _render_tool(
self,
@@ -465,18 +550,26 @@ def _render_tool(
*,
emit_special,
emit_text,
+ emit_text_segments,
) -> None:
# Tool messages are conversation history injected by the runtime
# between assistant turns — the model never samples any of these
- # tokens, so every emission is is_sampled=False.
+ # tokens, so every emission is is_sampled=False. The tool body
+ # bytes get ``is_content=True``; the ``<|observation|>`` /
+ # ```` wraps are scaffold so the SFT mask for
+ # tool body never trains the model to emit them.
prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool"
if not prev_is_tool:
- emit_special(self._observation, msg_idx, is_sampled=False)
+ emit_special(self._observation, msg_idx, is_sampled=False, is_content=False)
- emit_special(self._tool_response_tok, msg_idx, is_sampled=False)
- emit_text(content, msg_idx, is_sampled=False)
- emit_special(self._tool_response_end_tok, msg_idx, is_sampled=False)
+ emit_special(
+ self._tool_response_tok, msg_idx, is_sampled=False, is_content=False
+ )
+ emit_text(content, msg_idx, is_sampled=False, is_content=True)
+ emit_special(
+ self._tool_response_end_tok, msg_idx, is_sampled=False, is_content=False
+ )
class GLM51Renderer(GLM5Renderer):
diff --git a/renderers/gpt_oss.py b/renderers/gpt_oss.py
index 7a5b26e..9939de1 100644
--- a/renderers/gpt_oss.py
+++ b/renderers/gpt_oss.py
@@ -184,6 +184,86 @@ def _encode(self, text: str) -> list[int]:
return []
return self._tokenizer.encode(text, add_special_tokens=False)
+ def _prefix_content_mask(
+ self,
+ prefix_tokens: list[int],
+ first_system_idx: int | None,
+ messages: list[Message],
+ tools: list[ToolSpec] | None,
+ ) -> list[bool]:
+ """Per-token is_content mask over the rendered system+developer prefix.
+
+ Harmony's prefix is one opaque block. The caller's system content
+ lands inside the developer message as ``# Instructions\\n\\n{content}``.
+ To attribute body bytes back, we render the same prefix with empty
+ instructions and diff: the unique-to-with-instructions span is the
+ body region. Falls back to all-False (whole prefix scaffold) if
+ the caller didn't supply a system message — there's no body to
+ attribute in that case.
+ """
+ n = len(prefix_tokens)
+ mask = [False] * n
+ if first_system_idx is None:
+ return mask
+ instructions = _content_text(messages[first_system_idx].get("content"))
+ if not instructions:
+ return mask
+
+ # Build the same prefix with empty instructions.
+ empty_prefix_msgs: list[HarmonyMessage] = []
+ if self._use_system_prompt:
+ sys_content = SystemContent.new().with_reasoning_effort(
+ self._reasoning_effort
+ )
+ sys_content = sys_content.with_conversation_start_date(
+ self._conversation_start_date
+ )
+ if self._knowledge_cutoff is not None:
+ sys_content = sys_content.with_knowledge_cutoff(self._knowledge_cutoff)
+ if self._model_identity is not None:
+ sys_content = sys_content.with_model_identity(self._model_identity)
+ empty_prefix_msgs.append(
+ HarmonyMessage.from_role_and_content(Role.SYSTEM, sys_content)
+ )
+ dev = DeveloperContent.new()
+ if tools:
+ dev = dev.with_function_tools([_tool_to_description(t) for t in tools])
+ empty_prefix_msgs.append(
+ HarmonyMessage.from_role_and_content(Role.DEVELOPER, dev)
+ )
+ try:
+ empty_tokens = self._enc.render_conversation(
+ Conversation.from_messages(empty_prefix_msgs)
+ )
+ except Exception:
+ return mask
+
+ # Longest common prefix.
+ i_start = 0
+ n_empty = len(empty_tokens)
+ while (
+ i_start < min(n, n_empty)
+ and prefix_tokens[i_start] == empty_tokens[i_start]
+ ):
+ i_start += 1
+ # Longest common suffix.
+ j_full = n
+ j_empty = n_empty
+ while (
+ j_full > i_start
+ and j_empty > i_start
+ and prefix_tokens[j_full - 1] == empty_tokens[j_empty - 1]
+ ):
+ j_full -= 1
+ j_empty -= 1
+ # Tokens [i_start:j_full] in prefix_tokens are unique to the
+ # with-instructions render — that's the body span (includes the
+ # ``# Instructions\n\n`` scaffolding header, which the substring
+ # match in the body-decode test ignores).
+ for k in range(i_start, j_full):
+ mask[k] = True
+ return mask
+
# ── public interface ─────────────────────────────────────────────────────
def render(
@@ -199,11 +279,15 @@ def render(
tokens: list[int] = []
indices: list[int] = []
sampled: list[bool] = []
+ content_mask: list[bool] = []
- def emit(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None:
+ def emit(
+ ids: list[int], msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.extend(ids)
indices.extend([msg_idx] * len(ids))
sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
def emit_harmony_message(
hm_ids: list[int], msg_idx: int, *, is_assistant: bool
@@ -225,18 +309,53 @@ def emit_harmony_message(
developer, tool) the whole message is conversation history
the model never samples — every token is
``is_sampled=False``.
+
+ ``is_content`` further splits the body: the trailing
+ terminator (``<|end|>`` / ``<|return|>`` / ``<|call|>``) is
+ scaffold on non-assistant turns; on assistant turns it's the
+ model's stop signal, so ``is_content=True`` mirrors
+ ``sampled_mask`` as the invariant on assistant requires.
+ The body bytes between ``<|message|>`` and the terminator
+ are body (``is_content=True``) on every role — that's the
+ caller-provided content (or, for assistant, the model's
+ sampled emission). The header (``<|start|>`` ... ``<|message|>``)
+ — including any ``functions.{name}`` recipient text on tool
+ results, which comes from a prior assistant's tool_calls
+ rather than this tool message's own content — is scaffold.
"""
try:
msg_marker = hm_ids.index(self._message)
except ValueError:
# Defensive: a harmony message without <|message|> is
# malformed. Treat the whole thing as scaffolding.
- emit(hm_ids, msg_idx, is_sampled=False)
+ emit(hm_ids, msg_idx, is_sampled=False, is_content=False)
return
header = hm_ids[: msg_marker + 1]
body = hm_ids[msg_marker + 1 :]
- emit(header, msg_idx, is_sampled=False)
- emit(body, msg_idx, is_sampled=is_assistant)
+ emit(header, msg_idx, is_sampled=False, is_content=False)
+ # Split body into content + terminator. The terminator (if
+ # present) is the last token of the body and is one of the
+ # three harmony stop tokens.
+ terminator_ids = {self._end, self._return, self._call}
+ if body and body[-1] in terminator_ids:
+ body_content = body[:-1]
+ terminator = body[-1:]
+ else:
+ body_content = body
+ terminator = []
+ emit(
+ body_content,
+ msg_idx,
+ is_sampled=is_assistant,
+ is_content=True,
+ )
+ if terminator:
+ emit(
+ terminator,
+ msg_idx,
+ is_sampled=is_assistant,
+ is_content=is_assistant,
+ )
# ── Build harmony prefix (system + developer) ───────────────────
# When tools are present, harmony's conversation-level renderer
@@ -285,7 +404,23 @@ def emit_harmony_message(
# caller-relative attribution); otherwise to -1 (pure scaffolding).
# The whole prefix is pure template scaffolding — never sampled.
prefix_origin = first_system_idx if first_system_idx is not None else -1
- emit(prefix_tokens, prefix_origin, is_sampled=False)
+ # Compute the body-token span inside the prefix by diffing
+ # against the same render with empty developer instructions.
+ # Tokens unique to the with-instructions render are the body
+ # span (``# Instructions\n\n{caller_system_content}``). Marking
+ # those is_content=True so the caller's system text is
+ # recoverable from ``content_token_spans_by_role()["system"]``.
+ # The scaffolding ``# Instructions\n\n`` prefix bleeds into
+ # the body run; consumers reading the body do a substring
+ # check rather than expecting an exact match.
+ prefix_content_mask = self._prefix_content_mask(
+ prefix_tokens, first_system_idx, messages, tools
+ )
+ for tid, is_content in zip(prefix_tokens, prefix_content_mask):
+ tokens.append(tid)
+ indices.append(prefix_origin)
+ sampled.append(False)
+ content_mask.append(is_content)
# ── Iterate the rest of the messages ────────────────────────────
last_idx = len(messages) - 1
@@ -326,16 +461,17 @@ def emit_harmony_message(
# ── Generation prompt: <|start|>assistant<|channel|>analysis<|message|>
# Pure template scaffolding the model continues from — never sampled.
if add_generation_prompt:
- emit([self._start], -1, is_sampled=False)
- emit(self._encode("assistant"), -1, is_sampled=False)
- emit([self._channel], -1, is_sampled=False)
- emit(self._encode("analysis"), -1, is_sampled=False)
- emit([self._message], -1, is_sampled=False)
+ emit([self._start], -1, is_sampled=False, is_content=False)
+ emit(self._encode("assistant"), -1, is_sampled=False, is_content=False)
+ emit([self._channel], -1, is_sampled=False, is_content=False)
+ emit(self._encode("analysis"), -1, is_sampled=False, is_content=False)
+ emit([self._message], -1, is_sampled=False, is_content=False)
return RenderedTokens(
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
+ is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
)
@@ -407,17 +543,47 @@ def bridge_to_next_turn(
# and ``sampled_mask`` (uniformly ``False``). The harmony encoder
# renders each ``new_messages[i]`` as a single block, so every
# token in that block carries index ``i``; the trailing
- # generation prompt uses ``-1``.
+ # generation prompt uses ``-1``. ``is_content`` follows the same
+ # rules as :meth:`render`'s ``emit_harmony_message``: header is
+ # scaffold, body bytes are body, terminator scaffold (the bridge
+ # never carries assistant turns, so terminators are always
+ # scaffold on the non-assistant roles the bridge accepts).
+ terminator_ids = {self._end, self._return, self._call}
ext: list[int] = []
ext_indices: list[int] = []
+ ext_content: list[bool] = []
for i, msg in enumerate(new_messages):
role = msg.get("role")
if role not in ("tool", "user", "system", "developer"):
return None
for hm in self._to_harmony_messages(msg):
ids = self._enc.render(hm)
- ext.extend(ids)
- ext_indices.extend([i] * len(ids))
+ try:
+ msg_marker = ids.index(self._message)
+ except ValueError:
+ # Defensive: treat as scaffolding.
+ ext.extend(ids)
+ ext_indices.extend([i] * len(ids))
+ ext_content.extend([False] * len(ids))
+ continue
+ header = ids[: msg_marker + 1]
+ body = ids[msg_marker + 1 :]
+ ext.extend(header)
+ ext_indices.extend([i] * len(header))
+ ext_content.extend([False] * len(header))
+ if body and body[-1] in terminator_ids:
+ body_content = body[:-1]
+ terminator = body[-1:]
+ else:
+ body_content = body
+ terminator = []
+ ext.extend(body_content)
+ ext_indices.extend([i] * len(body_content))
+ ext_content.extend([True] * len(body_content))
+ if terminator:
+ ext.extend(terminator)
+ ext_indices.extend([i] * len(terminator))
+ ext_content.extend([False] * len(terminator))
# Generation prompt: <|start|>assistant<|channel|>analysis<|message|>
gen_before = len(ext)
@@ -427,12 +593,14 @@ def bridge_to_next_turn(
ext.extend(self._encode("analysis"))
ext.append(self._message)
ext_indices.extend([-1] * (len(ext) - gen_before))
+ ext_content.extend([False] * (len(ext) - gen_before))
total_len = len(previous_ids) + len(ext)
return RenderedTokens(
token_ids=previous_ids + ext,
message_indices=[-1] * len(previous_ids) + ext_indices,
sampled_mask=[False] * total_len,
+ is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
)
diff --git a/renderers/kimi_k2.py b/renderers/kimi_k2.py
index b0106e4..9e08141 100644
--- a/renderers/kimi_k2.py
+++ b/renderers/kimi_k2.py
@@ -118,6 +118,11 @@ def render(
if not messages:
raise ValueError("No messages provided.")
+ # Preserve the caller's list — ``message_roles`` and per-token
+ # attribution refer to this frame (not the post-normalisation
+ # list that includes auto-injected system / tool_declare).
+ caller_messages = list(messages)
+
# Inject tools as tool_declare message + ensure system message.
# The Jinja template emits the tools list directly (no
# ``{"type":"function","function":...}`` wrapper) using
@@ -166,19 +171,33 @@ def orig_idx(i: int) -> int:
token_ids: list[int] = []
indices: list[int] = []
sampled: list[bool] = []
+ content_mask: list[bool] = []
- def emit_ids(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_ids(
+ ids: list[int], msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
token_ids.extend(ids)
indices.extend([msg_idx] * len(ids))
sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
- def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_special(
+ token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
token_ids.append(token_id)
indices.append(msg_idx)
sampled.append(is_sampled)
+ content_mask.append(is_content)
- def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
- emit_ids(self._encode(text), msg_idx, is_sampled=is_sampled)
+ def emit_text(
+ text: str, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
+ emit_ids(
+ self._encode(text),
+ msg_idx,
+ is_sampled=is_sampled,
+ is_content=is_content,
+ )
# Compute last non-tool-call assistant index to determine thinking preservation
last_plain_assistant_idx = -1
@@ -206,31 +225,40 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
content = "".join(parts)
oi = orig_idx(i)
+ # Auto-injected system / tool_declare messages have ``oi == -1``.
+ # Their text isn't from the caller's input, so we treat the
+ # whole emission as scaffold (``is_content=False`` everywhere).
+ # The test contract is that ``msg_idx == -1`` runs are
+ # template-only and ``is_content=False``.
+ body_is_content = oi >= 0
if role == "system":
- emit_special(self._im_system, oi, is_sampled=False)
- emit_text("system", oi, is_sampled=False)
- emit_special(self._im_middle, oi, is_sampled=False)
- emit_text(content, oi, is_sampled=False)
- emit_special(self._im_end, oi, is_sampled=False)
+ emit_special(self._im_system, oi, is_sampled=False, is_content=False)
+ emit_text("system", oi, is_sampled=False, is_content=False)
+ emit_special(self._im_middle, oi, is_sampled=False, is_content=False)
+ emit_text(content, oi, is_sampled=False, is_content=body_is_content)
+ emit_special(self._im_end, oi, is_sampled=False, is_content=False)
# Jinja emits a literal newline only after the auto-injected
# system's <|im_end|> (see _ensure_system_message's contract).
if i == auto_system_idx:
- emit_text("\n", oi, is_sampled=False)
+ emit_text("\n", oi, is_sampled=False, is_content=False)
elif role == "tool_declare":
- emit_special(self._im_system, oi, is_sampled=False)
- emit_text("tool_declare", oi, is_sampled=False)
- emit_special(self._im_middle, oi, is_sampled=False)
- emit_text(content, oi, is_sampled=False)
- emit_special(self._im_end, oi, is_sampled=False)
+ # The tool_declare body is the tools JSON — recoverable
+ # from the caller's ``tools`` argument, so we treat it as
+ # scaffold (consistent with Qwen3's tools-header block).
+ emit_special(self._im_system, oi, is_sampled=False, is_content=False)
+ emit_text("tool_declare", oi, is_sampled=False, is_content=False)
+ emit_special(self._im_middle, oi, is_sampled=False, is_content=False)
+ emit_text(content, oi, is_sampled=False, is_content=False)
+ emit_special(self._im_end, oi, is_sampled=False, is_content=False)
elif role == "user":
- emit_special(self._im_user, oi, is_sampled=False)
- emit_text("user", oi, is_sampled=False)
- emit_special(self._im_middle, oi, is_sampled=False)
- emit_text(content, oi, is_sampled=False)
- emit_special(self._im_end, oi, is_sampled=False)
+ emit_special(self._im_user, oi, is_sampled=False, is_content=False)
+ emit_text("user", oi, is_sampled=False, is_content=False)
+ emit_special(self._im_middle, oi, is_sampled=False, is_content=False)
+ emit_text(content, oi, is_sampled=False, is_content=body_is_content)
+ emit_special(self._im_end, oi, is_sampled=False, is_content=False)
elif role == "assistant":
# Kimi strips reasoning from historical assistant turns and
@@ -250,30 +278,35 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
elif role == "tool":
self._render_tool(
- msg, oi, content, emit_special=emit_special, emit_text=emit_text
+ msg,
+ oi,
+ content,
+ emit_special=emit_special,
+ emit_text=emit_text,
)
else:
# Unknown role: use system-style formatting. Not a sampled
# assistant turn — every token is template-injected from the
# caller's POV, so is_sampled=False across the whole emission.
- emit_special(self._im_system, oi, is_sampled=False)
- emit_text(role, oi, is_sampled=False)
- emit_special(self._im_middle, oi, is_sampled=False)
- emit_text(content, oi, is_sampled=False)
- emit_special(self._im_end, oi, is_sampled=False)
+ emit_special(self._im_system, oi, is_sampled=False, is_content=False)
+ emit_text(role, oi, is_sampled=False, is_content=False)
+ emit_special(self._im_middle, oi, is_sampled=False, is_content=False)
+ emit_text(content, oi, is_sampled=False, is_content=body_is_content)
+ emit_special(self._im_end, oi, is_sampled=False, is_content=False)
# Generation prompt
if add_generation_prompt:
- emit_special(self._im_assistant, -1, is_sampled=False)
- emit_text("assistant", -1, is_sampled=False)
- emit_special(self._im_middle, -1, is_sampled=False)
+ emit_special(self._im_assistant, -1, is_sampled=False, is_content=False)
+ emit_text("assistant", -1, is_sampled=False, is_content=False)
+ emit_special(self._im_middle, -1, is_sampled=False, is_content=False)
return RenderedTokens(
token_ids=token_ids,
message_indices=indices,
sampled_mask=sampled,
- message_roles=[m.get("role") or "" for m in messages],
+ is_content=content_mask,
+ message_roles=[m.get("role") or "" for m in caller_messages],
)
def render_ids(
@@ -336,27 +369,40 @@ def bridge_to_next_turn(
ext: list[int] = []
ext_indices: list[int] = []
ext_sampled: list[bool] = []
+ ext_content: list[bool] = []
# Bridge populates ``message_indices`` (relative to ``new_messages``)
# and ``sampled_mask`` (uniformly ``False`` — every token the
# bridge emits is template scaffolding for the next prompt, not
- # something the model sampled). Downstream consumers can run
- # :meth:`RenderedTokens.tokens_per_message` on the bridge output
- # to get per-new-message token counts without re-rendering.
+ # something the model sampled). ``is_content`` follows the same
+ # rules as in :meth:`render` so consumers can walk the trajectory
+ # and read each step's own body mask. Downstream consumers can
+ # run :meth:`RenderedTokens.tokens_per_message` on the bridge
+ # output to get per-new-message token counts without re-rendering.
def emit_special(
- token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
+ token_id: int,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ext.append(token_id)
ext_indices.append(msg_idx)
ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
def emit_text(
- text: str, msg_idx: int = -1, *, is_sampled: bool = False
+ text: str,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ids = self._encode(text)
ext.extend(ids)
ext_indices.extend([msg_idx] * len(ids))
ext_sampled.extend([is_sampled] * len(ids))
+ ext_content.extend([is_content] * len(ids))
for i, msg in enumerate(new_messages):
role = msg.get("role")
@@ -376,34 +422,39 @@ def emit_text(
content = "".join(parts)
if role == "user":
- emit_special(self._im_user, i, is_sampled=False)
- emit_text("user", i, is_sampled=False)
- emit_special(self._im_middle, i, is_sampled=False)
- emit_text(content, i, is_sampled=False)
- emit_special(self._im_end, i, is_sampled=False)
+ emit_special(self._im_user, i)
+ emit_text("user", i)
+ emit_special(self._im_middle, i)
+ emit_text(content, i, is_content=True)
+ emit_special(self._im_end, i)
elif role == "system":
- emit_special(self._im_system, i, is_sampled=False)
- emit_text("system", i, is_sampled=False)
- emit_special(self._im_middle, i, is_sampled=False)
- emit_text(content, i, is_sampled=False)
- emit_special(self._im_end, i, is_sampled=False)
+ emit_special(self._im_system, i)
+ emit_text("system", i)
+ emit_special(self._im_middle, i)
+ emit_text(content, i, is_content=True)
+ emit_special(self._im_end, i)
elif role == "tool":
self._render_tool(
- msg, i, content, emit_special=emit_special, emit_text=emit_text
+ msg,
+ i,
+ content,
+ emit_special=emit_special,
+ emit_text=emit_text,
)
else:
return None
# Generation prompt.
- emit_special(self._im_assistant, -1, is_sampled=False)
- emit_text("assistant", -1, is_sampled=False)
- emit_special(self._im_middle, -1, is_sampled=False)
+ emit_special(self._im_assistant, -1)
+ emit_text("assistant", -1)
+ emit_special(self._im_middle, -1)
total_len = len(previous_ids) + len(ext)
return RenderedTokens(
token_ids=previous_ids + ext,
message_indices=[-1] * len(previous_ids) + ext_indices,
sampled_mask=[False] * total_len,
+ is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
)
@@ -421,23 +472,32 @@ def _render_assistant(
# scaffolding — at inference the chat template emits these as the
# generation prompt and the model never samples them. Marking the
# role tag as ``is_sampled=False`` keeps the SFT loss mask aligned
- # with what the model would actually have produced.
- emit_special(self._im_assistant, msg_idx, is_sampled=False)
- emit_text("assistant", msg_idx, is_sampled=False)
- emit_special(self._im_middle, msg_idx, is_sampled=False)
+ # with what the model would actually have produced. ``is_content``
+ # is also False here — the role tag isn't part of any message's
+ # body, on any role.
+ emit_special(self._im_assistant, msg_idx, is_sampled=False, is_content=False)
+ emit_text("assistant", msg_idx, is_sampled=False, is_content=False)
+ emit_special(self._im_middle, msg_idx, is_sampled=False, is_content=False)
# Kimi K2's Jinja template has no reasoning-content support: the
# assistant turn renders its ``content`` verbatim, including any
# inline ``...`` tags. The separate
# ``reasoning_content`` field is dropped (the template never reads
# it). ``is_last_turn`` is unused here for the same reason.
+ # On assistant tokens, ``is_content == sampled_mask`` by construction
+ # — every sampled token is body, every scaffold token isn't.
_ = is_last_turn
- emit_text(content, msg_idx, is_sampled=True)
+ emit_text(content, msg_idx, is_sampled=True, is_content=True)
- # Tool calls
+ # Tool calls — model-sampled markup carrying caller / model body.
tool_calls = msg.get("tool_calls") or []
if tool_calls:
- emit_special(self._tool_calls_section_begin, msg_idx, is_sampled=True)
+ emit_special(
+ self._tool_calls_section_begin,
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
for tc in tool_calls:
func = tc.get("function") or tc
arguments = func.get("arguments", {})
@@ -451,16 +511,37 @@ def _render_assistant(
# caller to provide an id in ``functions.{name}:{idx}`` form
# (that's where the Kimi parser recovers the function name).
tc_id = tc.get("id") or ""
- emit_special(self._tool_call_begin, msg_idx, is_sampled=True)
- emit_text(tc_id, msg_idx, is_sampled=True)
- emit_special(self._tool_call_argument_begin, msg_idx, is_sampled=True)
- emit_text(args_str, msg_idx, is_sampled=True)
- emit_special(self._tool_call_end, msg_idx, is_sampled=True)
- emit_special(self._tool_calls_section_end, msg_idx, is_sampled=True)
+ emit_special(
+ self._tool_call_begin,
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
+ emit_text(tc_id, msg_idx, is_sampled=True, is_content=True)
+ emit_special(
+ self._tool_call_argument_begin,
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
+ emit_text(args_str, msg_idx, is_sampled=True, is_content=True)
+ emit_special(
+ self._tool_call_end,
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
+ emit_special(
+ self._tool_calls_section_end,
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
# ``<|im_end|>`` is the model's stop signal — it samples this to
- # end its turn, so it is part of the sampled stream.
- emit_special(self._im_end, msg_idx, is_sampled=True)
+ # end its turn, so it is part of the sampled stream (and the
+ # assistant's body).
+ emit_special(self._im_end, msg_idx, is_sampled=True, is_content=True)
def _render_tool(
self,
@@ -473,13 +554,28 @@ def _render_tool(
) -> None:
# Tool messages are conversation history injected by the runtime
# between assistant turns — the model never samples any of these
- # tokens, so every emission is is_sampled=False.
+ # tokens, so every emission is is_sampled=False. The ``content``
+ # field's body bytes get ``is_content=True``; everything else —
+ # the ``<|im_system|>name<|im_middle|>`` wrap, the ``## Return of
+ # …\n`` header (template-synthesised, not part of the body) —
+ # is scaffold so the SFT mask for tool body never trains the
+ # model to emit them.
+ #
+ # We keep the original kimi_k2 emit boundaries — the header and
+ # the content are encoded separately, which preserves the
+ # template's byte-identity since the original code also emitted
+ # them as separate ``encode`` calls.
name = msg.get("name") or "tool"
tool_call_id = msg.get("tool_call_id") or ""
- emit_special(self._im_system, msg_idx, is_sampled=False)
- emit_text(name, msg_idx, is_sampled=False)
- emit_special(self._im_middle, msg_idx, is_sampled=False)
- emit_text(f"## Return of {tool_call_id}\n", msg_idx, is_sampled=False)
- emit_text(content, msg_idx, is_sampled=False)
- emit_special(self._im_end, msg_idx, is_sampled=False)
+ emit_special(self._im_system, msg_idx, is_sampled=False, is_content=False)
+ emit_text(name, msg_idx, is_sampled=False, is_content=False)
+ emit_special(self._im_middle, msg_idx, is_sampled=False, is_content=False)
+ emit_text(
+ f"## Return of {tool_call_id}\n",
+ msg_idx,
+ is_sampled=False,
+ is_content=False,
+ )
+ emit_text(content, msg_idx, is_sampled=False, is_content=True)
+ emit_special(self._im_end, msg_idx, is_sampled=False, is_content=False)
diff --git a/renderers/kimi_k25.py b/renderers/kimi_k25.py
index 31ee759..d3729ac 100644
--- a/renderers/kimi_k25.py
+++ b/renderers/kimi_k25.py
@@ -749,24 +749,44 @@ def render(
tokens: list[int] = []
indices: list[int] = []
sampled: list[bool] = []
+ content_mask: list[bool] = []
mm_hashes: dict[str, list[str]] = {}
mm_placeholders: dict[str, list[PlaceholderRange]] = {}
mm_items: dict[str, list[dict[str, Any]]] = {}
- def emit_ids(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_ids(
+ ids: list[int], msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.extend(ids)
indices.extend([msg_idx] * len(ids))
sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
- def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_special(
+ token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.append(token_id)
indices.append(msg_idx)
sampled.append(is_sampled)
+ content_mask.append(is_content)
- def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
- emit_ids(self._encode(text), msg_idx, is_sampled=is_sampled)
+ def emit_text(
+ text: str, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
+ emit_ids(
+ self._encode(text),
+ msg_idx,
+ is_sampled=is_sampled,
+ is_content=is_content,
+ )
- def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_image(
+ part: dict[str, Any],
+ msg_idx: int,
+ *,
+ is_sampled: bool,
+ is_content: bool,
+ ) -> None:
"""Emit Kimi K2.5's image wrap and accumulate ``mm_data``.
Template-equivalent expansion per image:
@@ -779,15 +799,32 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None:
Kimi's chat template after every image — kept here verbatim
for byte-parity, regardless of what follows (more images,
text, or the ``<|im_end|>`` close).
+
+ ``is_content`` attribution: ``<|media_pad|>`` represents
+ caller-provided image data — body. The wrap tokens
+ (``<|media_begin|>``, the literal ``"image"`` prose,
+ ``<|media_content|>``, ``<|media_end|>``, the trailing
+ ``\\n``) are template-injected scaffold.
"""
_, out, _num_patches, h = self._process_image(part)
- emit_special(self._media_begin, msg_idx, is_sampled=is_sampled)
- emit_text("image", msg_idx, is_sampled=is_sampled)
- emit_special(self._media_content, msg_idx, is_sampled=is_sampled)
+ emit_special(
+ self._media_begin, msg_idx, is_sampled=is_sampled, is_content=False
+ )
+ emit_text("image", msg_idx, is_sampled=is_sampled, is_content=False)
+ emit_special(
+ self._media_content, msg_idx, is_sampled=is_sampled, is_content=False
+ )
offset = len(tokens)
- emit_special(self._media_pad, msg_idx, is_sampled=is_sampled)
- emit_special(self._media_end, msg_idx, is_sampled=is_sampled)
- emit_text("\n", msg_idx, is_sampled=is_sampled)
+ emit_special(
+ self._media_pad,
+ msg_idx,
+ is_sampled=is_sampled,
+ is_content=is_content,
+ )
+ emit_special(
+ self._media_end, msg_idx, is_sampled=is_sampled, is_content=False
+ )
+ emit_text("\n", msg_idx, is_sampled=is_sampled, is_content=False)
mm_hashes.setdefault("image", []).append(h)
mm_placeholders.setdefault("image", []).append(
PlaceholderRange(offset=offset, length=1)
@@ -807,13 +844,16 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None:
# K2.5/K2.6's tokenizer auto-computes ``tools_ts_str`` and threads
# it into apply_chat_template, so the template's TS branch always
# fires when tools are present. Match that with our own TS encoder.
+ # The tools-TS body is recoverable from the ``tools`` argument, so
+ # we attribute the entire tool_declare emission as scaffold —
+ # consistent with Qwen3's tools-header treatment.
if tools:
tools_ts = _encode_tools_typescript(tools)
- emit_special(self._im_system, -1, is_sampled=False)
- emit_text("tool_declare", -1, is_sampled=False)
- emit_special(self._im_middle, -1, is_sampled=False)
- emit_text(tools_ts, -1, is_sampled=False)
- emit_special(self._im_end, -1, is_sampled=False)
+ emit_special(self._im_system, -1, is_sampled=False, is_content=False)
+ emit_text("tool_declare", -1, is_sampled=False, is_content=False)
+ emit_special(self._im_middle, -1, is_sampled=False, is_content=False)
+ emit_text(tools_ts, -1, is_sampled=False, is_content=False)
+ emit_special(self._im_end, -1, is_sampled=False, is_content=False)
# ── Iterate messages ─────────────────────────────────────────
for i, msg in enumerate(messages):
@@ -824,14 +864,14 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None:
# generation prompt emits them at inference for assistants;
# user / system / tool roles are conversation history).
if role == "user":
- emit_special(self._im_user, i, is_sampled=False)
+ emit_special(self._im_user, i, is_sampled=False, is_content=False)
elif role == "assistant":
- emit_special(self._im_assistant, i, is_sampled=False)
+ emit_special(self._im_assistant, i, is_sampled=False, is_content=False)
else:
- emit_special(self._im_system, i, is_sampled=False)
+ emit_special(self._im_system, i, is_sampled=False, is_content=False)
role_name = msg.get("name") or role
- emit_text(role_name, i, is_sampled=False)
- emit_special(self._im_middle, i, is_sampled=False)
+ emit_text(role_name, i, is_sampled=False, is_content=False)
+ emit_special(self._im_middle, i, is_sampled=False, is_content=False)
# Body
if role == "assistant":
@@ -852,10 +892,10 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None:
)
# ``<|im_end|>`` is the model's stop signal — it samples
# this to end its turn, so it is part of the sampled
- # stream. Kimi K2.5 has no inter-turn ``\n`` separator
- # (unlike Qwen3), so the turn-close token is the last
- # sampled token.
- emit_special(self._im_end, i, is_sampled=True)
+ # stream (and the assistant's body). Kimi K2.5 has no
+ # inter-turn ``\n`` separator (unlike Qwen3), so the
+ # turn-close token is the last sampled token.
+ emit_special(self._im_end, i, is_sampled=True, is_content=True)
continue
elif role == "tool":
self._render_tool_body(
@@ -869,7 +909,8 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None:
elif msg.get("content") is not None:
# User / other content branches — images allowed. All
# non-assistant content is conversation history, never
- # sampled by the model.
+ # sampled by the model. The body is caller-provided, so
+ # ``is_content=True`` over the content emit.
self._emit_content(
msg.get("content"),
i,
@@ -878,21 +919,22 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None:
emit_ids,
emit_image=emit_image,
is_sampled=False,
+ is_content=True,
)
- emit_special(self._im_end, i, is_sampled=False)
+ emit_special(self._im_end, i, is_sampled=False, is_content=False)
# ── Generation prompt ────────────────────────────────────────
if add_generation_prompt:
- emit_special(self._im_assistant, -1, is_sampled=False)
- emit_text("assistant", -1, is_sampled=False)
- emit_special(self._im_middle, -1, is_sampled=False)
+ emit_special(self._im_assistant, -1, is_sampled=False, is_content=False)
+ emit_text("assistant", -1, is_sampled=False, is_content=False)
+ emit_special(self._im_middle, -1, is_sampled=False, is_content=False)
if self._enable_thinking:
# Prefill open tag to trigger thinking mode
- emit_text("", -1, is_sampled=False)
+ emit_text("", -1, is_sampled=False, is_content=False)
else:
# Empty to disable thinking
- emit_text("", -1, is_sampled=False)
+ emit_text("", -1, is_sampled=False, is_content=False)
mm_data: MultiModalData | None = None
if mm_hashes or mm_placeholders or mm_items:
@@ -906,6 +948,7 @@ def emit_image(part: dict[str, Any], msg_idx: int, *, is_sampled: bool) -> None:
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
+ is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
multi_modal_data=mm_data,
)
@@ -997,49 +1040,74 @@ def bridge_to_next_turn(
# Seed combined-token list with prior turn so placeholder offsets
# are absolute in the bridged sequence. Parallel
- # ``indices``/``sampled`` are seeded with ``-1``/``False`` for the
- # prior portion — the bridge has no attribution info for
- # ``previous_ids``. Bridge-added tokens get proper ``msg_idx``
- # (relative to ``new_messages``) and uniformly ``False``
- # ``sampled``: nothing the bridge emits was model-sampled.
+ # ``indices``/``sampled``/``content_mask`` are seeded with
+ # ``-1``/``False``/``False`` for the prior portion — the bridge
+ # has no attribution info for ``previous_ids``. Bridge-added
+ # tokens get proper ``msg_idx`` (relative to ``new_messages``)
+ # and uniformly ``False`` ``sampled``: nothing the bridge emits
+ # was model-sampled. ``is_content`` follows the same rules as in
+ # :meth:`render` so consumers can walk the trajectory and read
+ # each step's own body mask.
tokens: list[int] = list(previous_ids)
indices: list[int] = [-1] * len(previous_ids)
sampled: list[bool] = [False] * len(previous_ids)
+ content_mask: list[bool] = [False] * len(previous_ids)
new_hashes: dict[str, list[str]] = {}
new_placeholders: dict[str, list[PlaceholderRange]] = {}
new_items: dict[str, list[dict[str, Any]]] = {}
def emit_special(
- token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
+ token_id: int,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
tokens.append(token_id)
indices.append(msg_idx)
sampled.append(is_sampled)
+ content_mask.append(is_content)
def emit_text(
- text: str, msg_idx: int = -1, *, is_sampled: bool = False
+ text: str,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ids = self._encode(text)
tokens.extend(ids)
indices.extend([msg_idx] * len(ids))
sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
def emit_ids(
- ids: list[int], msg_idx: int = -1, *, is_sampled: bool = False
+ ids: list[int],
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
tokens.extend(ids)
indices.extend([msg_idx] * len(ids))
sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
def emit_image(
- part: dict[str, Any], msg_idx: int = -1, *, is_sampled: bool = False
+ part: dict[str, Any],
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
_, out, _num_patches, h = self._process_image(part)
emit_special(self._media_begin, msg_idx)
emit_text("image", msg_idx)
emit_special(self._media_content, msg_idx)
offset = len(tokens)
- emit_special(self._media_pad, msg_idx)
+ # ``<|media_pad|>`` stands in for caller-provided image data —
+ # mark it as body when the surrounding content is body.
+ emit_special(self._media_pad, msg_idx, is_content=is_content)
emit_special(self._media_end, msg_idx)
emit_text("\n", msg_idx)
new_hashes.setdefault("image", []).append(h)
@@ -1086,6 +1154,7 @@ def emit_image(
emit_ids,
emit_image=emit_image,
is_sampled=False,
+ is_content=True,
)
emit_special(self._im_end, i)
@@ -1128,6 +1197,7 @@ def emit_image(
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
+ is_content=content_mask,
message_roles=bridge_roles,
)
@@ -1140,6 +1210,7 @@ def emit_image(
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
+ is_content=content_mask,
message_roles=bridge_roles,
multi_modal_data=mm_data,
)
@@ -1158,6 +1229,7 @@ def _emit_content(
*,
emit_image=None,
is_sampled: bool,
+ is_content: bool = True,
) -> None:
"""Emit message content, handling strings, multipart lists, and (when
``emit_image`` is supplied) image parts.
@@ -1172,11 +1244,15 @@ def _emit_content(
``bridge_to_next_turn``), so consecutive images naturally
produce the template's ``...<|media_end|>\\n<|media_begin|>...``
pattern without an inter-image separator here.
+
+ ``is_content`` propagates to text emits and to the
+ ``<|media_pad|>`` body token of each image; the surrounding
+ media wrap tokens are always scaffold (handled by ``emit_image``).
"""
if content is None:
return
if isinstance(content, str):
- emit_text(content, msg_idx, is_sampled=is_sampled)
+ emit_text(content, msg_idx, is_sampled=is_sampled, is_content=is_content)
return
if isinstance(content, list):
for part in content:
@@ -1189,20 +1265,30 @@ def _emit_content(
if emit_image is None:
# Silently drop — caller didn't opt into multimodal.
continue
- emit_image(part, msg_idx, is_sampled=is_sampled)
+ emit_image(
+ part, msg_idx, is_sampled=is_sampled, is_content=is_content
+ )
continue
if is_video:
raise NotImplementedError(
"Video parts are not yet supported by KimiK25Renderer."
)
if ptype == "text":
- emit_text(part.get("text", ""), msg_idx, is_sampled=is_sampled)
+ emit_text(
+ part.get("text", ""),
+ msg_idx,
+ is_sampled=is_sampled,
+ is_content=is_content,
+ )
elif ptype == "thinking":
# Thinking parts in non-assistant roles are rendered as text
thinking = part.get("thinking", "")
if thinking:
emit_text(
- f"{thinking}", msg_idx, is_sampled=is_sampled
+ f"{thinking}",
+ msg_idx,
+ is_sampled=is_sampled,
+ is_content=is_content,
)
# Other part types are silently skipped
@@ -1265,16 +1351,27 @@ def _render_assistant_body(
# block, text content, and any tool_calls section all live in
# the sampled stream. The closing ``<|im_end|>`` is emitted by
# ``render`` (also is_sampled=True; it's the model's stop
- # signal).
+ # signal). On assistant tokens ``is_content == sampled_mask`` by
+ # construction.
if is_suffix or (preserve_thinking and reasoning_content):
- emit_text(f"{reasoning_content}", msg_idx, is_sampled=True)
+ emit_text(
+ f"{reasoning_content}",
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
else:
- emit_text("", msg_idx, is_sampled=True)
- emit_text(text_content, msg_idx, is_sampled=True)
+ emit_text("", msg_idx, is_sampled=True, is_content=True)
+ emit_text(text_content, msg_idx, is_sampled=True, is_content=True)
tool_calls = msg.get("tool_calls") or []
if tool_calls:
- emit_special(self._tool_calls_section_begin, msg_idx, is_sampled=True)
+ emit_special(
+ self._tool_calls_section_begin,
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
for tc in tool_calls:
func = tc.get("function") or tc
arguments = func.get("arguments", {})
@@ -1288,12 +1385,32 @@ def _render_assistant_body(
# ``functions.{name}:{idx}`` form (Kimi's parser recovers
# the function name from that field).
tool_id = tc.get("id") or ""
- emit_special(self._tool_call_begin, msg_idx, is_sampled=True)
- emit_text(tool_id, msg_idx, is_sampled=True)
- emit_special(self._tool_call_argument_begin, msg_idx, is_sampled=True)
- emit_text(args_str, msg_idx, is_sampled=True)
- emit_special(self._tool_call_end, msg_idx, is_sampled=True)
- emit_special(self._tool_calls_section_end, msg_idx, is_sampled=True)
+ emit_special(
+ self._tool_call_begin,
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
+ emit_text(tool_id, msg_idx, is_sampled=True, is_content=True)
+ emit_special(
+ self._tool_call_argument_begin,
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
+ emit_text(args_str, msg_idx, is_sampled=True, is_content=True)
+ emit_special(
+ self._tool_call_end,
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
+ emit_special(
+ self._tool_calls_section_end,
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
def _render_tool_body(
self,
@@ -1319,9 +1436,16 @@ def _render_tool_body(
"""
# Tool messages are conversation-history injected by the runtime
# between assistant turns — the model never samples any of these
- # tokens, so every emission is is_sampled=False.
+ # tokens, so every emission is is_sampled=False. The ``## Return
+ # of …\n`` header is template-synthesised scaffold; the
+ # ``content`` body bytes get ``is_content=True``.
tool_call_id = msg.get("tool_call_id") or ""
- emit_text(f"## Return of {tool_call_id}\n", msg_idx, is_sampled=False)
+ emit_text(
+ f"## Return of {tool_call_id}\n",
+ msg_idx,
+ is_sampled=False,
+ is_content=False,
+ )
content = msg.get("content")
if content is not None:
self._emit_content(
@@ -1332,6 +1456,7 @@ def _render_tool_body(
emit_ids,
emit_image=emit_image,
is_sampled=False,
+ is_content=True,
)
def _normalize_response_tokens(self, response: list[int]) -> list[int]:
diff --git a/renderers/laguna_xs2.py b/renderers/laguna_xs2.py
index df62e07..ce85037 100644
--- a/renderers/laguna_xs2.py
+++ b/renderers/laguna_xs2.py
@@ -35,6 +35,7 @@
ParsedResponse,
RenderedTokens,
ToolSpec,
+ attribute_text_segments,
reject_assistant_in_extension,
)
from renderers.parsing import parse_laguna_xs2
@@ -154,24 +155,43 @@ def render(
tokens: list[int] = []
indices: list[int] = []
sampled: list[bool] = []
+ content_mask: list[bool] = []
- def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_special(
+ token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.append(token_id)
indices.append(msg_idx)
sampled.append(is_sampled)
+ content_mask.append(is_content)
- def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_text(
+ text: str, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
ids = self._encode(text)
tokens.extend(ids)
indices.extend([msg_idx] * len(ids))
sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
- emit_special(self._eos, -1, is_sampled=False)
+ def emit_text_segments(
+ segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool
+ ) -> None:
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ tokens.append(tok_id)
+ indices.append(msg_idx)
+ sampled.append(is_sampled)
+ content_mask.append(is_content)
+
+ emit_special(self._eos, -1, is_sampled=False, is_content=False)
# ── System header (absorbs messages[0] if it's a system message) ──
system_content = _DEFAULT_SYSTEM_MESSAGE
system_msg_idx = -1
- if messages and messages[0].get("role") == "system":
+ caller_has_system = bool(messages and messages[0].get("role") == "system")
+ if caller_has_system:
system_content = self._visible_text(messages[0].get("content"))
system_msg_idx = 0
@@ -187,9 +207,18 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
"\n\n" if has_sys_content else "\n",
-1,
is_sampled=False,
+ is_content=False,
)
if has_sys_content:
- emit_text(system_content.rstrip(), system_msg_idx, is_sampled=False)
+ # If the caller provided system content, it's body bytes;
+ # otherwise this is the default system prompt (scaffold).
+ sys_is_content = caller_has_system
+ emit_text(
+ system_content.rstrip(),
+ system_msg_idx,
+ is_sampled=False,
+ is_content=sys_is_content,
+ )
if tools:
tool_text = _TOOLS_HEADER
for tool in tools:
@@ -199,8 +228,8 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
if self._enable_thinking
else _TOOLS_FOOTER_NO_THINKING
)
- emit_text(tool_text, -1, is_sampled=False)
- emit_text("\n\n", -1, is_sampled=False)
+ emit_text(tool_text, -1, is_sampled=False, is_content=False)
+ emit_text("\n\n", -1, is_sampled=False, is_content=False)
# ── Per-message loop ──────────────────────────────────────────
for i, msg in enumerate(messages):
@@ -211,35 +240,49 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
# Already consumed in the header block.
if i == 0:
continue
- emit_text(
- "\n" + content + "\n\n", i, is_sampled=False
- )
+ # Body = caller's content; the ``...``
+ # wrap and surrounding ``\n``s are scaffold.
+ sys_segs: list[tuple[str, bool]] = [("\n", False)]
+ if content:
+ sys_segs.append((content, True))
+ sys_segs.append(("\n\n", False))
+ emit_text_segments(sys_segs, i, is_sampled=False)
case "user":
- emit_text("\n" + content + "\n\n", i, is_sampled=False)
+ user_segs: list[tuple[str, bool]] = [("\n", False)]
+ if content:
+ user_segs.append((content, True))
+ user_segs.append(("\n\n", False))
+ emit_text_segments(user_segs, i, is_sampled=False)
case "assistant":
self._render_assistant(
- msg, i, content, emit_special=emit_special, emit_text=emit_text
- )
- case "tool":
- emit_text(
- "\n" + content + "\n\n",
+ msg,
i,
- is_sampled=False,
+ content,
+ emit_special=emit_special,
+ emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
)
+ case "tool":
+ tool_segs: list[tuple[str, bool]] = [("\n", False)]
+ if content:
+ tool_segs.append((content, True))
+ tool_segs.append(("\n\n", False))
+ emit_text_segments(tool_segs, i, is_sampled=False)
# ── Generation prompt ─────────────────────────────────────────
if add_generation_prompt:
- emit_special(self._assistant, -1, is_sampled=False)
- emit_text("\n", -1, is_sampled=False)
+ emit_special(self._assistant, -1, is_sampled=False, is_content=False)
+ emit_text("\n", -1, is_sampled=False, is_content=False)
if self._enable_thinking:
- emit_special(self._think, -1, is_sampled=False)
+ emit_special(self._think, -1, is_sampled=False, is_content=False)
else:
- emit_special(self._think_end, -1, is_sampled=False)
+ emit_special(self._think_end, -1, is_sampled=False, is_content=False)
return RenderedTokens(
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
+ is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
)
@@ -307,37 +350,74 @@ def bridge_to_next_turn(
ext: list[int] = []
ext_indices: list[int] = []
ext_sampled: list[bool] = []
+ ext_content: list[bool] = []
# Bridge populates ``message_indices`` (relative to ``new_messages``)
# and ``sampled_mask`` (uniformly ``False`` — every token the
# bridge emits is template scaffolding for the next prompt, not
- # something the model sampled). Downstream consumers can run
- # :meth:`RenderedTokens.tokens_per_message` on the bridge output
- # to get per-new-message token counts without re-rendering.
+ # something the model sampled). ``is_content`` follows the same
+ # rules as in :meth:`render` so consumers can walk the trajectory
+ # and read each step's own body mask.
def emit_special(
- token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
+ token_id: int,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ext.append(token_id)
ext_indices.append(msg_idx)
ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
def emit_text(
- text: str, msg_idx: int = -1, *, is_sampled: bool = False
+ text: str,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ids = self._encode(text)
ext.extend(ids)
ext_indices.extend([msg_idx] * len(ids))
ext_sampled.extend([is_sampled] * len(ids))
+ ext_content.extend([is_content] * len(ids))
+
+ def emit_text_segments(
+ segments: list[tuple[str, bool]],
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ ) -> None:
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ ext.append(tok_id)
+ ext_indices.append(msg_idx)
+ ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
for i, msg in enumerate(new_messages):
role = msg.get("role")
content = self._visible_text(msg.get("content"))
if role == "user":
- emit_text("\n" + content + "\n\n", i)
+ segs: list[tuple[str, bool]] = [("\n", False)]
+ if content:
+ segs.append((content, True))
+ segs.append(("\n\n", False))
+ emit_text_segments(segs, i)
elif role == "system":
- emit_text("\n" + content + "\n\n", i)
+ segs = [("\n", False)]
+ if content:
+ segs.append((content, True))
+ segs.append(("\n\n", False))
+ emit_text_segments(segs, i)
elif role == "tool":
- emit_text("\n" + content + "\n\n", i)
+ segs = [("\n", False)]
+ if content:
+ segs.append((content, True))
+ segs.append(("\n\n", False))
+ emit_text_segments(segs, i)
else:
return None
@@ -353,6 +433,7 @@ def emit_text(
token_ids=previous_ids + ext,
message_indices=[-1] * len(previous_ids) + ext_indices,
sampled_mask=[False] * total_len,
+ is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
)
@@ -364,6 +445,7 @@ def _render_assistant(
*,
emit_special,
emit_text,
+ emit_text_segments,
) -> None:
reasoning_content = ""
if isinstance(msg.get("reasoning_content"), str):
@@ -380,23 +462,30 @@ def _render_assistant(
# template emits these as the generation prompt at inference and
# the model never samples them. Marking the role tag as
# ``is_sampled=False`` keeps the SFT loss mask aligned with what
- # the model would actually have produced.
- emit_special(self._assistant, msg_idx, is_sampled=False)
- emit_text("\n", msg_idx, is_sampled=False)
+ # the model would actually have produced. ``is_content`` is also
+ # False on the role tag. On assistant the invariant
+ # ``is_content == sampled_mask`` holds.
+ emit_special(self._assistant, msg_idx, is_sampled=False, is_content=False)
+ emit_text("\n", msg_idx, is_sampled=False, is_content=False)
if reasoning_content:
- emit_special(self._think, msg_idx, is_sampled=True)
- emit_text("\n" + reasoning_content.strip() + "\n", msg_idx, is_sampled=True)
- emit_special(self._think_end, msg_idx, is_sampled=True)
+ emit_special(self._think, msg_idx, is_sampled=True, is_content=True)
+ emit_text(
+ "\n" + reasoning_content.strip() + "\n",
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
+ emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True)
else:
- emit_special(self._think_end, msg_idx, is_sampled=True)
+ emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True)
# Combined newline-after- with optional content. Bundling
# preserves BPE merges across the boundary.
post_think_text = "\n"
if content.strip():
post_think_text += content.strip() + "\n"
- emit_text(post_think_text, msg_idx, is_sampled=True)
+ emit_text(post_think_text, msg_idx, is_sampled=True, is_content=True)
tool_calls = msg.get("tool_calls") or []
for tc in tool_calls:
@@ -409,7 +498,7 @@ def _render_assistant(
except json.JSONDecodeError:
arguments = {}
- emit_special(self._tool_call, msg_idx, is_sampled=True)
+ emit_special(self._tool_call, msg_idx, is_sampled=True, is_content=True)
inner = name + "\n"
if isinstance(arguments, dict):
for k, v in arguments.items():
@@ -419,13 +508,13 @@ def _render_assistant(
else:
val_text = json.dumps(v, ensure_ascii=False)
inner += "" + val_text + "\n"
- emit_text(inner, msg_idx, is_sampled=True)
- emit_special(self._tool_call_end, msg_idx, is_sampled=True)
- emit_text("\n", msg_idx, is_sampled=True)
+ emit_text(inner, msg_idx, is_sampled=True, is_content=True)
+ emit_special(self._tool_call_end, msg_idx, is_sampled=True, is_content=True)
+ emit_text("\n", msg_idx, is_sampled=True, is_content=True)
# ```` is the model's stop signal (alongside
# ``〈|EOS|〉``) — it samples this to end its turn, so it's part of
# the sampled stream. The trailing ``\n`` is template-appended
# between turns and never sampled.
- emit_special(self._assistant_end, msg_idx, is_sampled=True)
- emit_text("\n", msg_idx, is_sampled=False)
+ emit_special(self._assistant_end, msg_idx, is_sampled=True, is_content=True)
+ emit_text("\n", msg_idx, is_sampled=False, is_content=False)
diff --git a/renderers/minimax_m2.py b/renderers/minimax_m2.py
index 45357fd..f3c26c8 100644
--- a/renderers/minimax_m2.py
+++ b/renderers/minimax_m2.py
@@ -21,6 +21,7 @@
ParsedResponse,
RenderedTokens,
ToolSpec,
+ attribute_text_segments,
reject_assistant_in_extension,
should_preserve_past_thinking,
trim_to_turn_close,
@@ -119,17 +120,69 @@ def render(
tokens: list[int] = []
indices: list[int] = []
sampled: list[bool] = []
+ content_mask: list[bool] = []
- def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_special(
+ token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.append(token_id)
indices.append(msg_idx)
sampled.append(is_sampled)
+ content_mask.append(is_content)
- def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_text(
+ text: str, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
ids = self._encode(text)
tokens.extend(ids)
indices.extend([msg_idx] * len(ids))
sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
+
+ def emit_text_segments(
+ segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool
+ ) -> None:
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ tokens.append(tok_id)
+ indices.append(msg_idx)
+ sampled.append(is_sampled)
+ content_mask.append(is_content)
+
+ def emit_token_overlap_body(
+ full_text: str,
+ body_start: int,
+ body_end: int,
+ msg_idx: int,
+ *,
+ is_sampled: bool,
+ ) -> None:
+ """Tokenize ``full_text`` and mark tokens that overlap the body
+ char span as ``is_content=True``.
+
+ Differs from :func:`attribute_text_segments` only in the
+ boundary-token rule: a token straddling scaffold→body gets
+ ``True`` if any of its bytes are body bytes (overlap rule),
+ rather than being attributed to whichever segment its first
+ char belongs to. The body's first byte is preserved even when
+ BPE merges it with the wrap's trailing byte (``>The`` →
+ single token).
+ """
+ from renderers.base import _get_offset_tokenizer
+
+ offset_tok = _get_offset_tokenizer(self._tokenizer)
+ encoding = offset_tok(
+ full_text, add_special_tokens=False, return_offsets_mapping=True
+ )
+ for tok_id, (start, end) in zip(
+ encoding["input_ids"], encoding["offset_mapping"]
+ ):
+ overlaps = start < body_end and end > body_start
+ tokens.append(tok_id)
+ indices.append(msg_idx)
+ sampled.append(is_sampled)
+ content_mask.append(overlaps)
# ── Extract system message ──────────────────────────────────
first_is_system = messages[0].get("role") == "system"
@@ -137,27 +190,38 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
conversation: list[Message] = messages[1:] if first_is_system else messages
# ── System block (always present) ───────────────────────────
- emit_special(self._bos, sys_idx, is_sampled=False)
- emit_special(self._role, sys_idx, is_sampled=False)
+ emit_special(self._bos, sys_idx, is_sampled=False, is_content=False)
+ emit_special(self._role, sys_idx, is_sampled=False, is_content=False)
sys_content = (
self._visible_text(messages[0].get("content")) if first_is_system else ""
)
- system_text = "system\n" + (sys_content or self._default_system)
+ # Body = caller's system content (if any). Default system message
+ # is template-injected scaffold; tools header / per-tool JSON /
+ # footer / instructions are scaffold too (the tools dict is
+ # recoverable from the ``tools`` arg).
+ sys_segments: list[tuple[str, bool]] = [("system\n", False)]
+ if sys_content:
+ sys_segments.append((sys_content, True))
+ else:
+ sys_segments.append((self._default_system, False))
if tools:
- system_text += _TOOLS_HEADER
+ sys_segments.append((_TOOLS_HEADER, False))
for tool in tools:
func = tool.get("function", tool)
- system_text += (
- "" + json.dumps(func, ensure_ascii=False) + "\n"
+ sys_segments.append(
+ (
+ "" + json.dumps(func, ensure_ascii=False) + "\n",
+ False,
+ )
)
- system_text += _TOOLS_FOOTER_PREFIX
- system_text += _TOOLS_INSTRUCTIONS
+ sys_segments.append((_TOOLS_FOOTER_PREFIX, False))
+ sys_segments.append((_TOOLS_INSTRUCTIONS, False))
- emit_text(system_text, sys_idx, is_sampled=False)
- emit_special(self._eos, sys_idx, is_sampled=False)
- emit_text("\n", sys_idx, is_sampled=False)
+ emit_text_segments(sys_segments, sys_idx, is_sampled=False)
+ emit_special(self._eos, sys_idx, is_sampled=False, is_content=False)
+ emit_text("\n", sys_idx, is_sampled=False, is_content=False)
# ── Compute last_user_index (relative to conversation) ──────
last_ui = -1
@@ -172,14 +236,14 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
orig_idx = ci + (1 if first_is_system else 0)
if role == "user":
- emit_special(self._role, orig_idx, is_sampled=False)
- emit_text(
- "user\n" + self._visible_text(msg.get("content")),
- orig_idx,
- is_sampled=False,
- )
- emit_special(self._eos, orig_idx, is_sampled=False)
- emit_text("\n", orig_idx, is_sampled=False)
+ emit_special(self._role, orig_idx, is_sampled=False, is_content=False)
+ user_content = self._visible_text(msg.get("content"))
+ user_segments: list[tuple[str, bool]] = [("user\n", False)]
+ if user_content:
+ user_segments.append((user_content, True))
+ emit_text_segments(user_segments, orig_idx, is_sampled=False)
+ emit_special(self._eos, orig_idx, is_sampled=False, is_content=False)
+ emit_text("\n", orig_idx, is_sampled=False, is_content=False)
elif role == "assistant":
preserve_thinking = should_preserve_past_thinking(
@@ -196,6 +260,7 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
preserve_thinking=preserve_thinking,
emit_special=emit_special,
emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
)
elif role == "tool":
@@ -206,19 +271,22 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
msg,
emit_special=emit_special,
emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
+ emit_token_overlap_body=emit_token_overlap_body,
)
# ── Generation prompt ───────────────────────────────────────
if add_generation_prompt:
- emit_special(self._role, -1, is_sampled=False)
- emit_text("ai\n", -1, is_sampled=False)
- emit_special(self._think, -1, is_sampled=False)
- emit_text("\n", -1, is_sampled=False)
+ emit_special(self._role, -1, is_sampled=False, is_content=False)
+ emit_text("ai\n", -1, is_sampled=False, is_content=False)
+ emit_special(self._think, -1, is_sampled=False, is_content=False)
+ emit_text("\n", -1, is_sampled=False, is_content=False)
return RenderedTokens(
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
+ is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
)
@@ -282,27 +350,75 @@ def bridge_to_next_turn(
ext: list[int] = []
ext_indices: list[int] = []
ext_sampled: list[bool] = []
+ ext_content: list[bool] = []
# Bridge populates ``message_indices`` (relative to ``new_messages``)
# and ``sampled_mask`` (uniformly ``False`` — every token the
# bridge emits is template scaffolding for the next prompt, not
- # something the model sampled). Downstream consumers can run
- # :meth:`RenderedTokens.tokens_per_message` on the bridge output
- # to get per-new-message token counts without re-rendering.
+ # something the model sampled). ``is_content`` follows the same
+ # rules as in :meth:`render` so consumers can walk the trajectory
+ # and read each step's own body mask.
def emit_special(
- token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
+ token_id: int,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ext.append(token_id)
ext_indices.append(msg_idx)
ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
def emit_text(
- text: str, msg_idx: int = -1, *, is_sampled: bool = False
+ text: str,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ids = self._encode(text)
ext.extend(ids)
ext_indices.extend([msg_idx] * len(ids))
ext_sampled.extend([is_sampled] * len(ids))
+ ext_content.extend([is_content] * len(ids))
+
+ def emit_text_segments(
+ segments: list[tuple[str, bool]],
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ ) -> None:
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ ext.append(tok_id)
+ ext_indices.append(msg_idx)
+ ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
+
+ def emit_token_overlap_body(
+ full_text: str,
+ body_start: int,
+ body_end: int,
+ msg_idx: int,
+ *,
+ is_sampled: bool,
+ ) -> None:
+ from renderers.base import _get_offset_tokenizer
+
+ offset_tok = _get_offset_tokenizer(self._tokenizer)
+ encoding = offset_tok(
+ full_text, add_special_tokens=False, return_offsets_mapping=True
+ )
+ for tok_id, (start, end) in zip(
+ encoding["input_ids"], encoding["offset_mapping"]
+ ):
+ overlaps = start < body_end and end > body_start
+ ext.append(tok_id)
+ ext_indices.append(msg_idx)
+ ext_sampled.append(is_sampled)
+ ext_content.append(overlaps)
# Trailing ``\n`` after the ``[e~[`` turn close — see ``render()``.
emit_text("\n", -1)
@@ -312,12 +428,18 @@ def emit_text(
content = self._visible_text(msg.get("content"))
if role == "user":
emit_special(self._role, i)
- emit_text("user\n" + content, i)
+ user_segments: list[tuple[str, bool]] = [("user\n", False)]
+ if content:
+ user_segments.append((content, True))
+ emit_text_segments(user_segments, i)
emit_special(self._eos, i)
emit_text("\n", i)
elif role == "system":
emit_special(self._role, i)
- emit_text("system\n" + content, i)
+ sys_segments: list[tuple[str, bool]] = [("system\n", False)]
+ if content:
+ sys_segments.append((content, True))
+ emit_text_segments(sys_segments, i)
emit_special(self._eos, i)
emit_text("\n", i)
elif role == "tool":
@@ -328,6 +450,8 @@ def emit_text(
msg,
emit_special=emit_special,
emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
+ emit_token_overlap_body=emit_token_overlap_body,
)
else:
return None
@@ -343,6 +467,7 @@ def emit_text(
token_ids=previous_ids + ext,
message_indices=[-1] * len(previous_ids) + ext_indices,
sampled_mask=[False] * total_len,
+ is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
)
@@ -356,6 +481,7 @@ def _render_assistant(
preserve_thinking: bool = False,
emit_special,
emit_text,
+ emit_text_segments,
):
content = self._visible_text(msg.get("content"))
@@ -374,12 +500,14 @@ def _render_assistant(
# the chat template emits these as the generation prompt and the
# model never samples them. Marking the role marker and tag as
# ``is_sampled=False`` keeps the SFT loss mask aligned with what
- # the model would actually have produced.
- emit_special(self._role, orig_idx, is_sampled=False)
+ # the model would actually have produced. ``is_content`` is also
+ # False here — the role tag isn't part of any message's body.
+ emit_special(self._role, orig_idx, is_sampled=False, is_content=False)
# Build the model-sampled portion (think block + content + tool
- # calls). Text segments stay contiguous within each is_sampled
- # span to preserve BPE merges.
+ # calls). For assistant messages the invariant
+ # ``is_content == sampled_mask`` holds — every sampled token is
+ # body, every scaffold token isn't.
tool_calls = msg.get("tool_calls") or []
emit_thinking = reasoning_content and (
conv_idx > last_user_index or preserve_thinking
@@ -390,10 +518,15 @@ def _render_assistant(
# immediately after ``ai\n``, which forces a tokenizer
# boundary — splitting ``ai\n`` (not_sampled) from the
# ````-led body (sampled) is BPE-safe.
- emit_text("ai\n", orig_idx, is_sampled=False)
- emit_special(self._think, orig_idx, is_sampled=True)
- emit_text("\n" + reasoning_content + "\n", orig_idx, is_sampled=True)
- emit_special(self._think_end, orig_idx, is_sampled=True)
+ emit_text("ai\n", orig_idx, is_sampled=False, is_content=False)
+ emit_special(self._think, orig_idx, is_sampled=True, is_content=True)
+ emit_text(
+ "\n" + reasoning_content + "\n",
+ orig_idx,
+ is_sampled=True,
+ is_content=True,
+ )
+ emit_special(self._think_end, orig_idx, is_sampled=True, is_content=True)
# \n\n + content must be contiguous for BPE
body = "\n\n" + content if content else "\n\n"
else:
@@ -406,17 +539,19 @@ def _render_assistant(
# choice for SFT (don't train on a token whose first byte
# is template scaffolding).
if tool_calls and not body:
- emit_text("ai\n\n", orig_idx, is_sampled=False)
+ emit_text("ai\n\n", orig_idx, is_sampled=False, is_content=False)
else:
- emit_text("ai\n", orig_idx, is_sampled=False)
+ emit_text("ai\n", orig_idx, is_sampled=False, is_content=False)
if tool_calls:
# \n before must be contiguous with preceding text.
# The empty-body / non-thinking case folded the leading \n
# into the role-tag emission above; skip it here.
if emit_thinking or body:
- emit_text(body + "\n", orig_idx, is_sampled=True)
- emit_special(self._tool_call_tok, orig_idx, is_sampled=True)
+ emit_text(body + "\n", orig_idx, is_sampled=True, is_content=True)
+ emit_special(
+ self._tool_call_tok, orig_idx, is_sampled=True, is_content=True
+ )
invoke_block = "\n"
for tc in tool_calls:
@@ -448,16 +583,18 @@ def _render_assistant(
)
invoke_block += "\n"
- emit_text(invoke_block, orig_idx, is_sampled=True)
- emit_special(self._tool_call_end_tok, orig_idx, is_sampled=True)
+ emit_text(invoke_block, orig_idx, is_sampled=True, is_content=True)
+ emit_special(
+ self._tool_call_end_tok, orig_idx, is_sampled=True, is_content=True
+ )
elif body:
- emit_text(body, orig_idx, is_sampled=True)
+ emit_text(body, orig_idx, is_sampled=True, is_content=True)
# ``[e~[`` is the model's stop signal — it samples this to end
# its turn, so it is part of the sampled stream. The trailing
# ``\n`` is template-appended between turns and never sampled.
- emit_special(self._eos, orig_idx, is_sampled=True)
- emit_text("\n", orig_idx, is_sampled=False)
+ emit_special(self._eos, orig_idx, is_sampled=True, is_content=True)
+ emit_text("\n", orig_idx, is_sampled=False, is_content=False)
def _render_tool(
self,
@@ -468,10 +605,15 @@ def _render_tool(
*,
emit_special,
emit_text,
+ emit_text_segments,
+ emit_token_overlap_body=None,
) -> None:
# Tool messages are conversation history injected by the runtime
# between assistant turns — the model never samples any of these
- # tokens, so every emission is is_sampled=False.
+ # tokens, so every emission is is_sampled=False. The ``content``
+ # body bytes get ``is_content=True``; the surrounding ````
+ # wrap, role tag and separators are scaffold so an SFT mask over
+ # tool body never trains the model to emit those.
prev_is_tool = conv_idx > 0 and conversation[conv_idx - 1]["role"] == "tool"
next_is_tool = (
conv_idx + 1 < len(conversation)
@@ -479,8 +621,8 @@ def _render_tool(
)
if not prev_is_tool:
- emit_special(self._role, orig_idx, is_sampled=False)
- emit_text("tool", orig_idx, is_sampled=False)
+ emit_special(self._role, orig_idx, is_sampled=False, is_content=False)
+ emit_text("tool", orig_idx, is_sampled=False, is_content=False)
content = self._visible_text(msg.get("content"))
# Leading ``\n`` before ```` only on the first of a
@@ -489,12 +631,39 @@ def _render_tool(
# through a single emit_text call instead of splitting the merge.
prefix = "" if prev_is_tool else "\n"
suffix = "\n" if next_is_tool else ""
- emit_text(
- prefix + "" + content + "" + suffix,
- orig_idx,
- is_sampled=False,
- )
+
+ # ```` is plain text with no separator between the
+ # closing ``>`` and ``content``'s first byte, so BPE can merge
+ # them into a single token (e.g., ``>The``). The shared
+ # ``attribute_text_segments`` helper picks the segment of a
+ # boundary-spanning token by its *first* char (here scaffold),
+ # which would drop the body's leading letter out of the body
+ # run. We instead use an "intersects body" rule: any token whose
+ # ``[start, end)`` char range overlaps the body span gets
+ # ``is_content=True``. A few scaffold bytes (the leading ``>``
+ # or trailing ``<``) bleed into the body run, but body bytes are
+ # recoverable as a substring of the decoded body span.
+ body_text = prefix + "" + content + "" + suffix
+ body_start = len(prefix) + len("")
+ body_end = body_start + len(content)
+ if content and emit_token_overlap_body is not None:
+ emit_token_overlap_body(
+ body_text, body_start, body_end, orig_idx, is_sampled=False
+ )
+ else:
+ # Empty body or no overlap-aware emitter available — fall back
+ # to the standard segments path.
+ tool_segments: list[tuple[str, bool]] = []
+ if prefix:
+ tool_segments.append((prefix, False))
+ tool_segments.append(("", False))
+ if content:
+ tool_segments.append((content, True))
+ tool_segments.append(("", False))
+ if suffix:
+ tool_segments.append((suffix, False))
+ emit_text_segments(tool_segments, orig_idx, is_sampled=False)
if not next_is_tool:
- emit_special(self._eos, orig_idx, is_sampled=False)
- emit_text("\n", orig_idx, is_sampled=False)
+ emit_special(self._eos, orig_idx, is_sampled=False, is_content=False)
+ emit_text("\n", orig_idx, is_sampled=False, is_content=False)
diff --git a/renderers/nemotron3.py b/renderers/nemotron3.py
index e17e7ee..e97790d 100644
--- a/renderers/nemotron3.py
+++ b/renderers/nemotron3.py
@@ -24,6 +24,7 @@
ParsedResponse,
RenderedTokens,
ToolSpec,
+ attribute_text_segments,
reject_assistant_in_extension,
should_preserve_past_thinking,
trim_to_turn_close,
@@ -241,19 +242,44 @@ def orig_idx(i: int) -> int:
tokens: list[int] = []
indices: list[int] = []
sampled: list[bool] = []
+ content_mask: list[bool] = []
- def emit_ids(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_ids(
+ ids: list[int], msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.extend(ids)
indices.extend([msg_idx] * len(ids))
sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
- def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_special(
+ token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.append(token_id)
indices.append(msg_idx)
sampled.append(is_sampled)
+ content_mask.append(is_content)
+
+ def emit_text(
+ text: str, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
+ emit_ids(
+ self._encode(text),
+ msg_idx,
+ is_sampled=is_sampled,
+ is_content=is_content,
+ )
- def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
- emit_ids(self._encode(text), msg_idx, is_sampled=is_sampled)
+ def emit_text_segments(
+ segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool
+ ) -> None:
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ tokens.append(tok_id)
+ indices.append(msg_idx)
+ sampled.append(is_sampled)
+ content_mask.append(is_content)
# ── 1. System message + optional tools ──────────────────────
first_is_system = messages[0].get("role") == "system"
@@ -262,8 +288,7 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
# Nemotron 3: system prompt BEFORE tools block
sys_idx = orig_idx(0) if first_is_system else -1
- emit_special(self._im_start, sys_idx, is_sampled=False)
- emit_text("system\n", sys_idx, is_sampled=False)
+ emit_special(self._im_start, sys_idx, is_sampled=False, is_content=False)
# Build system content: user's system text first, then tools
if first_is_system:
@@ -282,22 +307,27 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
+ _TOOLS_INSTRUCTIONS
)
+ # Body = caller's system text only; tools block (header, per-
+ # tool XML, footer, instructions) is scaffold.
+ sys_segments: list[tuple[str, bool]] = [("system\n", False)]
if sys_content:
- full_sys = sys_content + "\n\n" + tools_block
- else:
- full_sys = tools_block
-
- emit_text(full_sys, sys_idx, is_sampled=False)
- emit_special(self._im_end, sys_idx, is_sampled=False)
- emit_text("\n", sys_idx, is_sampled=False)
+ sys_segments.append((sys_content, True))
+ sys_segments.append(("\n\n", False))
+ sys_segments.append((tools_block, False))
+ emit_text_segments(sys_segments, sys_idx, is_sampled=False)
+ emit_special(self._im_end, sys_idx, is_sampled=False, is_content=False)
+ emit_text("\n", sys_idx, is_sampled=False, is_content=False)
elif first_is_system:
sys_idx = orig_idx(0)
sys_content = self._render_content(messages[0].get("content")).strip()
- emit_special(self._im_start, sys_idx, is_sampled=False)
- emit_text("system\n" + sys_content, sys_idx, is_sampled=False)
- emit_special(self._im_end, sys_idx, is_sampled=False)
- emit_text("\n", sys_idx, is_sampled=False)
+ emit_special(self._im_start, sys_idx, is_sampled=False, is_content=False)
+ sys_segments2: list[tuple[str, bool]] = [("system\n", False)]
+ if sys_content:
+ sys_segments2.append((sys_content, True))
+ emit_text_segments(sys_segments2, sys_idx, is_sampled=False)
+ emit_special(self._im_end, sys_idx, is_sampled=False, is_content=False)
+ emit_text("\n", sys_idx, is_sampled=False, is_content=False)
# Track the most-recent plain (non-tool-call) assistant so we can
# preserve its reasoning while stripping reasoning from earlier
@@ -322,10 +352,17 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
continue # Already handled above
elif role == "user":
- emit_special(self._im_start, msg_orig_idx, is_sampled=False)
- emit_text("user\n" + content, msg_orig_idx, is_sampled=False)
- emit_special(self._im_end, msg_orig_idx, is_sampled=False)
- emit_text("\n", msg_orig_idx, is_sampled=False)
+ emit_special(
+ self._im_start, msg_orig_idx, is_sampled=False, is_content=False
+ )
+ user_segments: list[tuple[str, bool]] = [("user\n", False)]
+ if content:
+ user_segments.append((content, True))
+ emit_text_segments(user_segments, msg_orig_idx, is_sampled=False)
+ emit_special(
+ self._im_end, msg_orig_idx, is_sampled=False, is_content=False
+ )
+ emit_text("\n", msg_orig_idx, is_sampled=False, is_content=False)
elif role == "assistant":
is_last_turn = i >= last_plain_assistant_idx
@@ -344,6 +381,7 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
emit_special=emit_special,
emit_text=emit_text,
emit_ids=emit_ids,
+ emit_text_segments=emit_text_segments,
)
elif role == "tool":
@@ -355,6 +393,7 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
auto_system_injected=auto_system_injected,
emit_special=emit_special,
emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
)
else:
@@ -362,21 +401,22 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
# ── 3. Generation prompt ────────────────────────────────────
if add_generation_prompt:
- emit_special(self._im_start, -1, is_sampled=False)
- emit_text("assistant\n", -1, is_sampled=False)
+ emit_special(self._im_start, -1, is_sampled=False, is_content=False)
+ emit_text("assistant\n", -1, is_sampled=False, is_content=False)
if self._enable_thinking:
- emit_special(self._think, -1, is_sampled=False)
- emit_text("\n", -1, is_sampled=False)
+ emit_special(self._think, -1, is_sampled=False, is_content=False)
+ emit_text("\n", -1, is_sampled=False, is_content=False)
else:
# Disable-thinking suffix: with no trailing newlines
- emit_special(self._think, -1, is_sampled=False)
- emit_special(self._think_end, -1, is_sampled=False)
+ emit_special(self._think, -1, is_sampled=False, is_content=False)
+ emit_special(self._think_end, -1, is_sampled=False, is_content=False)
return RenderedTokens(
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
- message_roles=[m.get("role") or "" for m in messages],
+ is_content=content_mask,
+ message_roles=[m.get("role") or "" for m in original_messages],
)
def render_ids(
@@ -447,27 +487,52 @@ def bridge_to_next_turn(
ext: list[int] = []
ext_indices: list[int] = []
ext_sampled: list[bool] = []
+ ext_content: list[bool] = []
# Bridge populates ``message_indices`` (relative to ``new_messages``)
# and ``sampled_mask`` (uniformly ``False`` — every token the
# bridge emits is template scaffolding for the next prompt, not
- # something the model sampled). Downstream consumers can run
- # :meth:`RenderedTokens.tokens_per_message` on the bridge output
- # to get per-new-message token counts without re-rendering.
+ # something the model sampled). ``is_content`` follows the same
+ # rules as in :meth:`render` so consumers can walk the trajectory
+ # and read each step's own body mask.
def emit_special(
- token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
+ token_id: int,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ext.append(token_id)
ext_indices.append(msg_idx)
ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
def emit_text(
- text: str, msg_idx: int = -1, *, is_sampled: bool = False
+ text: str,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ids = self._encode(text)
ext.extend(ids)
ext_indices.extend([msg_idx] * len(ids))
ext_sampled.extend([is_sampled] * len(ids))
+ ext_content.extend([is_content] * len(ids))
+
+ def emit_text_segments(
+ segments: list[tuple[str, bool]],
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ ) -> None:
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ ext.append(tok_id)
+ ext_indices.append(msg_idx)
+ ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
emit_text("\n", -1)
@@ -476,12 +541,18 @@ def emit_text(
content = self._render_content(msg.get("content")).strip()
if role == "user":
emit_special(self._im_start, i)
- emit_text("user\n" + content, i)
+ user_segments: list[tuple[str, bool]] = [("user\n", False)]
+ if content:
+ user_segments.append((content, True))
+ emit_text_segments(user_segments, i)
emit_special(self._im_end, i)
emit_text("\n", i)
elif role == "system":
emit_special(self._im_start, i)
- emit_text("system\n" + content, i)
+ sys_segments: list[tuple[str, bool]] = [("system\n", False)]
+ if content:
+ sys_segments.append((content, True))
+ emit_text_segments(sys_segments, i)
emit_special(self._im_end, i)
emit_text("\n", i)
elif role == "tool":
@@ -493,6 +564,7 @@ def emit_text(
auto_system_injected=False,
emit_special=emit_special,
emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
)
else:
return None
@@ -512,6 +584,7 @@ def emit_text(
token_ids=previous_ids + ext,
message_indices=[-1] * len(previous_ids) + ext_indices,
sampled_mask=[False] * total_len,
+ is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
)
@@ -530,6 +603,7 @@ def _render_assistant(
emit_special,
emit_text,
emit_ids,
+ emit_text_segments,
) -> None:
# Extract reasoning_content
reasoning_content = ""
@@ -550,9 +624,10 @@ def _render_assistant(
# at inference the chat template emits these as the generation
# prompt and the model never samples them. Marking the role tag
# as ``is_sampled=False`` keeps the SFT loss mask aligned with
- # what the model would actually have produced.
- emit_special(self._im_start, msg_idx, is_sampled=False)
- emit_text("assistant\n", msg_idx, is_sampled=False)
+ # what the model would actually have produced. On assistant the
+ # invariant ``is_content == sampled_mask`` holds.
+ emit_special(self._im_start, msg_idx, is_sampled=False, is_content=False)
+ emit_text("assistant\n", msg_idx, is_sampled=False, is_content=False)
# Nemotron 3 keeps reasoning on the most-recent plain assistant but
# strips it from historical turns, which collapse to an empty
@@ -567,23 +642,43 @@ def _render_assistant(
content_suffix = "\n" if tool_calls else ""
if reasoning_content and (is_last_turn or preserve_thinking):
- emit_special(self._think, msg_idx, is_sampled=True)
- emit_text("\n" + reasoning_content + "\n", msg_idx, is_sampled=True)
- emit_special(self._think_end, msg_idx, is_sampled=True)
+ emit_special(self._think, msg_idx, is_sampled=True, is_content=True)
+ emit_text(
+ "\n" + reasoning_content + "\n",
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
+ emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True)
# Single \n separator (not \n\n like Qwen3.5)
- emit_text("\n" + content + content_suffix, msg_idx, is_sampled=True)
+ emit_text(
+ "\n" + content + content_suffix,
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
elif reasoning_content:
# Historical assistant whose reasoning got stripped — template
# keeps a single \n between the collapsed and
# the content as a marker that reasoning existed.
- emit_special(self._think, msg_idx, is_sampled=True)
- emit_special(self._think_end, msg_idx, is_sampled=True)
- emit_text("\n" + content + content_suffix, msg_idx, is_sampled=True)
+ emit_special(self._think, msg_idx, is_sampled=True, is_content=True)
+ emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True)
+ emit_text(
+ "\n" + content + content_suffix,
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
else:
# No reasoning ever — glued directly to content.
- emit_special(self._think, msg_idx, is_sampled=True)
- emit_special(self._think_end, msg_idx, is_sampled=True)
- emit_text(content + content_suffix, msg_idx, is_sampled=True)
+ emit_special(self._think, msg_idx, is_sampled=True, is_content=True)
+ emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True)
+ emit_text(
+ content + content_suffix,
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
# Tool calls (leading \n was glued to the content above; each
# iteration's trailing \n after handles the
@@ -594,8 +689,13 @@ def _render_assistant(
name = func.get("name", "")
arguments = func.get("arguments", {})
- emit_special(self._tool_call, msg_idx, is_sampled=True)
- emit_text("\n\n", msg_idx, is_sampled=True)
+ emit_special(self._tool_call, msg_idx, is_sampled=True, is_content=True)
+ emit_text(
+ "\n\n",
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
# Render arguments
# OpenAI canonical form: arguments is a JSON string. Parse it so the
@@ -619,18 +719,21 @@ def _render_assistant(
+ "\n\n",
msg_idx,
is_sampled=True,
+ is_content=True,
)
- emit_text("\n", msg_idx, is_sampled=True)
- emit_special(self._tool_call_end, msg_idx, is_sampled=True)
+ emit_text("\n", msg_idx, is_sampled=True, is_content=True)
+ emit_special(
+ self._tool_call_end, msg_idx, is_sampled=True, is_content=True
+ )
# Trailing \n after (Nemotron 3 specific)
- emit_text("\n", msg_idx, is_sampled=True)
+ emit_text("\n", msg_idx, is_sampled=True, is_content=True)
# ``<|im_end|>`` is the model's stop signal — it samples this to
# end its turn, so it is part of the sampled stream. The trailing
# ``\n`` is template-appended between turns and never sampled.
- emit_special(self._im_end, msg_idx, is_sampled=True)
- emit_text("\n", msg_idx, is_sampled=False)
+ emit_special(self._im_end, msg_idx, is_sampled=True, is_content=True)
+ emit_text("\n", msg_idx, is_sampled=False, is_content=False)
# ------------------------------------------------------------------
# Tool message rendering
@@ -646,10 +749,13 @@ def _render_tool(
auto_system_injected: bool,
emit_special,
emit_text,
+ emit_text_segments,
) -> None:
# Tool messages are conversation history injected by the runtime
# between assistant turns — the model never samples any of these
- # tokens, so every emission is is_sampled=False.
+ # tokens, so every emission is is_sampled=False. The ``content``
+ # body bytes get ``is_content=True``; the surrounding wrap is
+ # scaffold.
prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool"
next_is_tool = (
msg_idx + 1 < len(messages) and messages[msg_idx + 1]["role"] == "tool"
@@ -657,17 +763,19 @@ def _render_tool(
oi = msg_orig_idx
if not prev_is_tool:
- emit_special(self._im_start, oi, is_sampled=False)
- emit_text("user\n", oi, is_sampled=False)
+ emit_special(self._im_start, oi, is_sampled=False, is_content=False)
+ emit_text("user\n", oi, is_sampled=False, is_content=False)
# else: the previous tool's trailing \n already provides the
# separator into this block.
- emit_special(self._tool_response, oi, is_sampled=False)
- emit_text("\n" + content + "\n", oi, is_sampled=False)
- emit_special(self._tool_response_end, oi, is_sampled=False)
+ emit_special(self._tool_response, oi, is_sampled=False, is_content=False)
+ emit_text_segments(
+ [("\n", False), (content, True), ("\n", False)], oi, is_sampled=False
+ )
+ emit_special(self._tool_response_end, oi, is_sampled=False, is_content=False)
# Nemotron 3: trailing \n after
- emit_text("\n", oi, is_sampled=False)
+ emit_text("\n", oi, is_sampled=False, is_content=False)
if not next_is_tool:
- emit_special(self._im_end, oi, is_sampled=False)
- emit_text("\n", oi, is_sampled=False)
+ emit_special(self._im_end, oi, is_sampled=False, is_content=False)
+ emit_text("\n", oi, is_sampled=False, is_content=False)
diff --git a/renderers/qwen3.py b/renderers/qwen3.py
index 8f5f17a..4562546 100644
--- a/renderers/qwen3.py
+++ b/renderers/qwen3.py
@@ -18,6 +18,7 @@
ParsedResponse,
RenderedTokens,
ToolSpec,
+ attribute_text_segments,
reject_assistant_in_extension,
should_preserve_past_thinking,
trim_to_turn_close,
@@ -108,43 +109,76 @@ def render(
tokens: list[int] = []
indices: list[int] = []
sampled: list[bool] = []
+ content_mask: list[bool] = []
- def emit_ids(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None:
- tokens.extend(ids)
- indices.extend([msg_idx] * len(ids))
- sampled.extend([is_sampled] * len(ids))
-
- def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_special(
+ token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.append(token_id)
indices.append(msg_idx)
sampled.append(is_sampled)
+ content_mask.append(is_content)
- def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
- emit_ids(self._encode(text), msg_idx, is_sampled=is_sampled)
+ def emit_text(
+ text: str, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
+ ids = self._encode(text)
+ tokens.extend(ids)
+ indices.extend([msg_idx] * len(ids))
+ sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
+
+ def emit_text_segments(
+ segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool
+ ) -> None:
+ """Tokenize concatenated segments as one BPE pass; per-token
+ ``is_content`` follows each token's source segment.
+
+ Lets call sites express "this wrap + this body, joined the
+ same way as the chat template, but attributed separately"
+ without splitting the encode call (which could shift BPE
+ merges at the boundary)."""
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ tokens.append(tok_id)
+ indices.append(msg_idx)
+ sampled.append(is_sampled)
+ content_mask.append(is_content)
# ── 1. System + tools ───────────────────────────────────────
first_is_system = messages[0].get("role") == "system"
if tools:
sys_idx = 0 if first_is_system else -1
- emit_special(self._im_start, sys_idx, is_sampled=False)
- tool_text = "system\n"
+ emit_special(self._im_start, sys_idx, is_sampled=False, is_content=False)
+ # Body = system content (if any). Everything else in this
+ # block — role tag, tools header / footer, the JSON tool
+ # specs — is scaffold. The tools dict is recoverable from
+ # the ``tools`` argument; don't re-attribute its embedded
+ # JSON as message body.
+ segments: list[tuple[str, bool]] = [("system\n", False)]
if first_is_system:
- tool_text += (messages[0].get("content") or "") + "\n\n"
- tool_text += _TOOLS_HEADER
+ sys_content = messages[0].get("content") or ""
+ if sys_content:
+ segments.append((sys_content, True))
+ segments.append(("\n\n", False))
+ segments.append((_TOOLS_HEADER, False))
for tool in tools:
- tool_text += "\n" + json.dumps(tool, ensure_ascii=False)
- tool_text += _TOOLS_FOOTER
- emit_text(tool_text, sys_idx, is_sampled=False)
- emit_special(self._im_end, sys_idx, is_sampled=False)
- emit_text("\n", sys_idx, is_sampled=False)
+ segments.append(("\n" + json.dumps(tool, ensure_ascii=False), False))
+ segments.append((_TOOLS_FOOTER, False))
+ emit_text_segments(segments, sys_idx, is_sampled=False)
+ emit_special(self._im_end, sys_idx, is_sampled=False, is_content=False)
+ emit_text("\n", sys_idx, is_sampled=False, is_content=False)
elif first_is_system:
- emit_special(self._im_start, 0, is_sampled=False)
- emit_text(
- "system\n" + (messages[0].get("content") or ""), 0, is_sampled=False
- )
- emit_special(self._im_end, 0, is_sampled=False)
- emit_text("\n", 0, is_sampled=False)
+ emit_special(self._im_start, 0, is_sampled=False, is_content=False)
+ sys_content = messages[0].get("content") or ""
+ sys_segments: list[tuple[str, bool]] = [("system\n", False)]
+ if sys_content:
+ sys_segments.append((sys_content, True))
+ emit_text_segments(sys_segments, 0, is_sampled=False)
+ emit_special(self._im_end, 0, is_sampled=False, is_content=False)
+ emit_text("\n", 0, is_sampled=False, is_content=False)
# ── 2. Compute last_query_index ─────────────────────────────
last_qi = self._last_query_index(messages)
@@ -158,16 +192,22 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
if role == "system":
if i == 0:
continue
- emit_special(self._im_start, i, is_sampled=False)
- emit_text(role + "\n" + content, i, is_sampled=False)
- emit_special(self._im_end, i, is_sampled=False)
- emit_text("\n", i, is_sampled=False)
+ emit_special(self._im_start, i, is_sampled=False, is_content=False)
+ msg_segments: list[tuple[str, bool]] = [(role + "\n", False)]
+ if content:
+ msg_segments.append((content, True))
+ emit_text_segments(msg_segments, i, is_sampled=False)
+ emit_special(self._im_end, i, is_sampled=False, is_content=False)
+ emit_text("\n", i, is_sampled=False, is_content=False)
elif role == "user":
- emit_special(self._im_start, i, is_sampled=False)
- emit_text("user\n" + content, i, is_sampled=False)
- emit_special(self._im_end, i, is_sampled=False)
- emit_text("\n", i, is_sampled=False)
+ emit_special(self._im_start, i, is_sampled=False, is_content=False)
+ user_segments: list[tuple[str, bool]] = [("user\n", False)]
+ if content:
+ user_segments.append((content, True))
+ emit_text_segments(user_segments, i, is_sampled=False)
+ emit_special(self._im_end, i, is_sampled=False, is_content=False)
+ emit_text("\n", i, is_sampled=False, is_content=False)
elif role == "assistant":
preserve_thinking = should_preserve_past_thinking(
@@ -185,24 +225,33 @@ def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
preserve_thinking=preserve_thinking,
emit_special=emit_special,
emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
)
elif role == "tool":
self._render_tool(
- messages, i, content, emit_special=emit_special, emit_text=emit_text
+ messages,
+ i,
+ content,
+ emit_special=emit_special,
+ emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
)
# ── 4. Generation prompt ────────────────────────────────────
if add_generation_prompt:
- emit_special(self._im_start, -1, is_sampled=False)
- emit_text("assistant\n", -1, is_sampled=False)
+ emit_special(self._im_start, -1, is_sampled=False, is_content=False)
+ emit_text("assistant\n", -1, is_sampled=False, is_content=False)
if not self._enable_thinking:
- emit_text("\n\n\n\n", -1, is_sampled=False)
+ emit_text(
+ "\n\n\n\n", -1, is_sampled=False, is_content=False
+ )
return RenderedTokens(
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
+ is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
)
@@ -263,27 +312,54 @@ def bridge_to_next_turn(
ext: list[int] = []
ext_indices: list[int] = []
ext_sampled: list[bool] = []
+ ext_content: list[bool] = []
# Bridge populates ``message_indices`` (relative to ``new_messages``)
# and ``sampled_mask`` (uniformly ``False`` — every token the
# bridge emits is template scaffolding for the next prompt, not
- # something the model sampled). Downstream consumers can run
- # :meth:`RenderedTokens.tokens_per_message` on the bridge output
- # to get per-new-message token counts without re-rendering.
+ # something the model sampled). ``is_content`` follows the same
+ # rules as in :meth:`render` so consumers can walk the trajectory
+ # and read each step's own body mask. Downstream consumers can
+ # run :meth:`RenderedTokens.tokens_per_message` on the bridge
+ # output to get per-new-message token counts without re-rendering.
def emit_special(
- token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
+ token_id: int,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ext.append(token_id)
ext_indices.append(msg_idx)
ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
def emit_text(
- text: str, msg_idx: int = -1, *, is_sampled: bool = False
+ text: str,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ids = self._encode(text)
ext.extend(ids)
ext_indices.extend([msg_idx] * len(ids))
ext_sampled.extend([is_sampled] * len(ids))
+ ext_content.extend([is_content] * len(ids))
+
+ def emit_text_segments(
+ segments: list[tuple[str, bool]],
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ ) -> None:
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ ext.append(tok_id)
+ ext_indices.append(msg_idx)
+ ext_sampled.append(is_sampled)
+ ext_content.append(is_content)
# Trailing ``\n`` after the turn-close token. ``render()`` emits this
# as part of the prior turn, but vLLM stops on ``<|im_end|>`` so the
@@ -295,12 +371,18 @@ def emit_text(
content = msg.get("content") if isinstance(msg.get("content"), str) else ""
if role == "user":
emit_special(self._im_start, i)
- emit_text("user\n" + content, i)
+ user_segments: list[tuple[str, bool]] = [("user\n", False)]
+ if content:
+ user_segments.append((content, True))
+ emit_text_segments(user_segments, i)
emit_special(self._im_end, i)
emit_text("\n", i)
elif role == "system":
emit_special(self._im_start, i)
- emit_text("system\n" + content, i)
+ sys_segments: list[tuple[str, bool]] = [("system\n", False)]
+ if content:
+ sys_segments.append((content, True))
+ emit_text_segments(sys_segments, i)
emit_special(self._im_end, i)
emit_text("\n", i)
elif role == "tool":
@@ -310,6 +392,7 @@ def emit_text(
content,
emit_special=emit_special,
emit_text=emit_text,
+ emit_text_segments=emit_text_segments,
)
else:
return None
@@ -324,6 +407,7 @@ def emit_text(
token_ids=previous_ids + ext,
message_indices=[-1] * len(previous_ids) + ext_indices,
sampled_mask=[False] * total_len,
+ is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
)
@@ -338,6 +422,7 @@ def _render_assistant(
preserve_thinking: bool = False,
emit_special,
emit_text,
+ emit_text_segments,
):
reasoning_content = ""
if isinstance(msg.get("reasoning_content"), str):
@@ -355,16 +440,20 @@ def _render_assistant(
# at inference the chat template emits these as the generation
# prompt and the model never samples them. Marking the role tag
# as ``is_sampled=False`` keeps the SFT loss mask aligned with
- # what the model would actually have produced.
- emit_special(self._im_start, msg_idx, is_sampled=False)
- emit_text("assistant\n", msg_idx, is_sampled=False)
+ # what the model would actually have produced. ``is_content`` is
+ # also False here — the role tag isn't part of any message's
+ # body, on any role.
+ emit_special(self._im_start, msg_idx, is_sampled=False, is_content=False)
+ emit_text("assistant\n", msg_idx, is_sampled=False, is_content=False)
# Build the model-sampled portion (think block + content + tool
# calls). Text segments stay contiguous within each is_sampled
# span to preserve BPE merges (e.g., ".\n" is a single token in
# Qwen3); the only split we introduce here is at ``\n`` after the
# role tag, which the existing renderer already treats as a
- # token boundary (cf. ``_render_tool``).
+ # token boundary (cf. ``_render_tool``). For assistant messages
+ # the invariant ``is_content == sampled_mask`` holds — every
+ # sampled token is body, every scaffold token isn't.
tool_calls = msg.get("tool_calls") or []
emit_in_template_window = msg_idx > last_query_index and (
@@ -382,7 +471,7 @@ def _render_assistant(
body = content
if not tool_calls:
- emit_text(body, msg_idx, is_sampled=True)
+ emit_text(body, msg_idx, is_sampled=True, is_content=True)
else:
for tc_idx, tc in enumerate(tool_calls):
func = tc.get("function") or tc
@@ -397,23 +486,29 @@ def _render_assistant(
# Text before this tool_call (includes separator)
if tc_idx == 0:
separator = "\n" if content else ""
- emit_text(body + separator, msg_idx, is_sampled=True)
+ emit_text(
+ body + separator, msg_idx, is_sampled=True, is_content=True
+ )
else:
- emit_text("\n", msg_idx, is_sampled=True)
+ emit_text("\n", msg_idx, is_sampled=True, is_content=True)
- emit_special(self._tool_call, msg_idx, is_sampled=True)
+ emit_special(self._tool_call, msg_idx, is_sampled=True, is_content=True)
emit_text(
'\n{"name": "' + name + '", "arguments": ' + args_str + "}\n",
msg_idx,
is_sampled=True,
+ is_content=True,
+ )
+ emit_special(
+ self._tool_call_end, msg_idx, is_sampled=True, is_content=True
)
- emit_special(self._tool_call_end, msg_idx, is_sampled=True)
# ``<|im_end|>`` is the model's stop signal — it samples this to
- # end its turn, so it is part of the sampled stream. The trailing
- # ``\n`` is template-appended between turns and never sampled.
- emit_special(self._im_end, msg_idx, is_sampled=True)
- emit_text("\n", msg_idx, is_sampled=False)
+ # end its turn, so it is part of the sampled stream (and the
+ # assistant's body). The trailing ``\n`` is template-appended
+ # between turns and never sampled — scaffold for is_content too.
+ emit_special(self._im_end, msg_idx, is_sampled=True, is_content=True)
+ emit_text("\n", msg_idx, is_sampled=False, is_content=False)
def _render_tool(
self,
@@ -423,24 +518,40 @@ def _render_tool(
*,
emit_special,
emit_text,
+ emit_text_segments,
) -> None:
# Tool messages are conversation history injected by the runtime
# between assistant turns — the model never samples any of these
- # tokens, so every emission is is_sampled=False.
+ # tokens, so every emission is is_sampled=False. The ``content``
+ # field's body bytes get ``is_content=True``; everything else —
+ # the ``<|im_start|>user`` wrap, the inter-section ``\n``s, the
+ # ``<|tool_response>`` specials — is scaffold so the SFT mask
+ # for tool body never trains the model to emit them.
prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool"
next_is_tool = (
msg_idx + 1 < len(messages) and messages[msg_idx + 1]["role"] == "tool"
)
if not prev_is_tool:
- emit_special(self._im_start, msg_idx, is_sampled=False)
- emit_text("user", msg_idx, is_sampled=False)
-
- emit_text("\n", msg_idx, is_sampled=False)
- emit_special(self._tool_response, msg_idx, is_sampled=False)
- emit_text("\n" + content + "\n", msg_idx, is_sampled=False)
- emit_special(self._tool_response_end, msg_idx, is_sampled=False)
+ emit_special(self._im_start, msg_idx, is_sampled=False, is_content=False)
+ emit_text("user", msg_idx, is_sampled=False, is_content=False)
+
+ emit_text("\n", msg_idx, is_sampled=False, is_content=False)
+ emit_special(self._tool_response, msg_idx, is_sampled=False, is_content=False)
+ # ``\n`` + content + ``\n`` — body is the middle segment only.
+ # Single BPE pass over the joined text preserves boundary
+ # merges (Qwen3 keeps ``\n`` as its own token, so this is
+ # mostly a no-op, but we route through segments anyway so the
+ # attribution doesn't depend on tokenizer-specific behaviour).
+ emit_text_segments(
+ [("\n", False), (content, True), ("\n", False)],
+ msg_idx,
+ is_sampled=False,
+ )
+ emit_special(
+ self._tool_response_end, msg_idx, is_sampled=False, is_content=False
+ )
if not next_is_tool:
- emit_special(self._im_end, msg_idx, is_sampled=False)
- emit_text("\n", msg_idx, is_sampled=False)
+ emit_special(self._im_end, msg_idx, is_sampled=False, is_content=False)
+ emit_text("\n", msg_idx, is_sampled=False, is_content=False)
diff --git a/renderers/qwen35.py b/renderers/qwen35.py
index e2c84c6..2deefcf 100644
--- a/renderers/qwen35.py
+++ b/renderers/qwen35.py
@@ -26,6 +26,7 @@
PlaceholderRange,
RenderedTokens,
ToolSpec,
+ attribute_text_segments,
reject_assistant_in_extension,
should_preserve_past_thinking,
trim_to_turn_close,
@@ -291,33 +292,73 @@ def render(
tokens: list[int] = []
indices: list[int] = []
sampled: list[bool] = []
+ content_mask: list[bool] = []
mm_hashes: dict[str, list[str]] = {}
mm_placeholders: dict[str, list[PlaceholderRange]] = {}
mm_items: dict[str, list[dict[str, Any]]] = {}
- def emit_ids(ids: list[int], msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_ids(
+ ids: list[int], msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.extend(ids)
indices.extend([msg_idx] * len(ids))
sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
- def emit_special(token_id: int, msg_idx: int, *, is_sampled: bool) -> None:
+ def emit_special(
+ token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
tokens.append(token_id)
indices.append(msg_idx)
sampled.append(is_sampled)
+ content_mask.append(is_content)
+
+ def emit_text(
+ text: str, msg_idx: int, *, is_sampled: bool, is_content: bool
+ ) -> None:
+ emit_ids(
+ self._encode(text),
+ msg_idx,
+ is_sampled=is_sampled,
+ is_content=is_content,
+ )
- def emit_text(text: str, msg_idx: int, *, is_sampled: bool) -> None:
- emit_ids(self._encode(text), msg_idx, is_sampled=is_sampled)
+ def emit_text_segments(
+ segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool
+ ) -> None:
+ """Tokenize concatenated segments as one BPE pass; per-token
+ ``is_content`` follows each token's source segment.
+
+ Lets call sites express "this wrap + this body, joined the
+ same way as the chat template, but attributed separately"
+ without splitting the encode call (which could shift BPE
+ merges at the boundary)."""
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ tokens.append(tok_id)
+ indices.append(msg_idx)
+ sampled.append(is_sampled)
+ content_mask.append(is_content)
def emit_image(part: dict[str, Any], msg_idx: int) -> None:
# Image placeholders only appear in user / tool messages; the
# model never samples them. Pin is_sampled=False here so
- # callers don't need to thread the flag through.
+ # callers don't need to thread the flag through. The
+ # ``<|image_pad|>`` placeholders represent caller-provided
+ # image data, so they ARE body content (is_content=True);
+ # the surrounding ``<|vision_start|>`` / ``<|vision_end|>``
+ # specials are template scaffold.
_, out, n, h = self._process_image(part)
- emit_special(self._vision_start, msg_idx, is_sampled=False)
+ emit_special(
+ self._vision_start, msg_idx, is_sampled=False, is_content=False
+ )
offset = len(tokens)
for _ in range(n):
- emit_special(self._image_pad, msg_idx, is_sampled=False)
- emit_special(self._vision_end, msg_idx, is_sampled=False)
+ emit_special(
+ self._image_pad, msg_idx, is_sampled=False, is_content=True
+ )
+ emit_special(self._vision_end, msg_idx, is_sampled=False, is_content=False)
mm_hashes.setdefault("image", []).append(h)
mm_placeholders.setdefault("image", []).append(
PlaceholderRange(offset=offset, length=n)
@@ -337,20 +378,28 @@ def emit_user_with_media(content_list: list[Any], msg_idx: int) -> None:
``<|im_end|>``), matching how Jinja's ``render_content`` macro
concatenates strings before tokenization. This preserves BPE
byte-parity against ``apply_chat_template``.
+
+ Within each flush the ``"user\\n"`` wrap is scaffold and the
+ text parts are body. ``emit_text_segments`` carries that
+ attribution through the BPE pass.
"""
# Every token in a user message is conversation history that
# the model never samples at inference.
- emit_special(self._im_start, msg_idx, is_sampled=False)
- buf: list[str] = ["user\n"]
+ emit_special(self._im_start, msg_idx, is_sampled=False, is_content=False)
+ # First flush includes the ``"user\n"`` wrap as a scaffold
+ # segment; subsequent flushes are pure body (after a media
+ # break).
+ buf_segments: list[tuple[str, bool]] = [("user\n", False)]
def flush_buf() -> None:
- if buf:
- emit_text("".join(buf), msg_idx, is_sampled=False)
- buf.clear()
+ if buf_segments:
+ emit_text_segments(buf_segments, msg_idx, is_sampled=False)
+ buf_segments.clear()
for item in content_list:
if isinstance(item, str):
- buf.append(item)
+ if item:
+ buf_segments.append((item, True))
elif isinstance(item, dict):
if _is_image_part(item):
flush_buf()
@@ -360,14 +409,15 @@ def flush_buf() -> None:
"Video parts are not yet supported by Qwen35Renderer."
)
elif "text" in item:
- buf.append(item["text"])
+ if item["text"]:
+ buf_segments.append((item["text"], True))
else:
raise ValueError(f"Unexpected content item: {item}")
else:
raise ValueError(f"Unexpected content item: {item}")
flush_buf()
- emit_special(self._im_end, msg_idx, is_sampled=False)
- emit_text("\n", msg_idx, is_sampled=False)
+ emit_special(self._im_end, msg_idx, is_sampled=False, is_content=False)
+ emit_text("\n", msg_idx, is_sampled=False, is_content=False)
# ── 1. System message + optional tools ──────────────────────
first_is_system = messages[0].get("role") == "system"
@@ -376,31 +426,37 @@ def flush_buf() -> None:
# System message index for attribution
sys_idx = 0 if first_is_system else -1
- emit_special(self._im_start, sys_idx, is_sampled=False)
- emit_text("system\n", sys_idx, is_sampled=False)
-
- # Tools header + JSON definitions
- tool_text = _TOOLS_HEADER
+ emit_special(self._im_start, sys_idx, is_sampled=False, is_content=False)
+ # Body = system content (if any). Everything else in this
+ # block — role tag, tools header / footer / instructions, the
+ # JSON tool specs — is scaffold. The tools dict is
+ # recoverable from the ``tools`` argument; don't re-attribute
+ # its embedded JSON as message body.
+ segments: list[tuple[str, bool]] = [
+ ("system\n", False),
+ (_TOOLS_HEADER, False),
+ ]
for tool in tools:
- tool_text += "\n" + json.dumps(tool, ensure_ascii=False)
- tool_text += _TOOLS_FOOTER
- tool_text += _TOOLS_INSTRUCTIONS
-
- # Append user's system content if present
+ segments.append(("\n" + json.dumps(tool, ensure_ascii=False), False))
+ segments.append((_TOOLS_FOOTER, False))
+ segments.append((_TOOLS_INSTRUCTIONS, False))
if first_is_system:
sys_content = self._render_content(messages[0].get("content")).strip()
if sys_content:
- tool_text += "\n\n" + sys_content
-
- emit_text(tool_text, sys_idx, is_sampled=False)
- emit_special(self._im_end, sys_idx, is_sampled=False)
- emit_text("\n", sys_idx, is_sampled=False)
+ segments.append(("\n\n", False))
+ segments.append((sys_content, True))
+ emit_text_segments(segments, sys_idx, is_sampled=False)
+ emit_special(self._im_end, sys_idx, is_sampled=False, is_content=False)
+ emit_text("\n", sys_idx, is_sampled=False, is_content=False)
elif first_is_system:
sys_content = self._render_content(messages[0].get("content")).strip()
- emit_special(self._im_start, 0, is_sampled=False)
- emit_text("system\n" + sys_content, 0, is_sampled=False)
- emit_special(self._im_end, 0, is_sampled=False)
- emit_text("\n", 0, is_sampled=False)
+ emit_special(self._im_start, 0, is_sampled=False, is_content=False)
+ sys_segments: list[tuple[str, bool]] = [("system\n", False)]
+ if sys_content:
+ sys_segments.append((sys_content, True))
+ emit_text_segments(sys_segments, 0, is_sampled=False)
+ emit_special(self._im_end, 0, is_sampled=False, is_content=False)
+ emit_text("\n", 0, is_sampled=False, is_content=False)
# ── 2. Compute last_query_index ─────────────────────────────
last_qi = self._last_query_index(messages)
@@ -420,10 +476,13 @@ def flush_buf() -> None:
if self._content_has_media(raw_content):
emit_user_with_media(raw_content, i)
else:
- emit_special(self._im_start, i, is_sampled=False)
- emit_text("user\n" + content, i, is_sampled=False)
- emit_special(self._im_end, i, is_sampled=False)
- emit_text("\n", i, is_sampled=False)
+ emit_special(self._im_start, i, is_sampled=False, is_content=False)
+ user_segments: list[tuple[str, bool]] = [("user\n", False)]
+ if content:
+ user_segments.append((content, True))
+ emit_text_segments(user_segments, i, is_sampled=False)
+ emit_special(self._im_end, i, is_sampled=False, is_content=False)
+ emit_text("\n", i, is_sampled=False, is_content=False)
elif role == "assistant":
preserve_thinking = should_preserve_past_thinking(
@@ -441,6 +500,7 @@ def flush_buf() -> None:
emit_special=emit_special,
emit_text=emit_text,
emit_ids=emit_ids,
+ emit_text_segments=emit_text_segments,
)
elif role == "tool":
@@ -450,6 +510,7 @@ def flush_buf() -> None:
emit_special=emit_special,
emit_text=emit_text,
emit_image=emit_image,
+ emit_text_segments=emit_text_segments,
)
else:
@@ -457,16 +518,16 @@ def flush_buf() -> None:
# ── 4. Generation prompt ────────────────────────────────────
if add_generation_prompt:
- emit_special(self._im_start, -1, is_sampled=False)
- emit_text("assistant\n", -1, is_sampled=False)
+ emit_special(self._im_start, -1, is_sampled=False, is_content=False)
+ emit_text("assistant\n", -1, is_sampled=False, is_content=False)
if self._enable_thinking:
- emit_special(self._think, -1, is_sampled=False)
- emit_text("\n", -1, is_sampled=False)
+ emit_special(self._think, -1, is_sampled=False, is_content=False)
+ emit_text("\n", -1, is_sampled=False, is_content=False)
else:
- emit_special(self._think, -1, is_sampled=False)
- emit_text("\n\n", -1, is_sampled=False)
- emit_special(self._think_end, -1, is_sampled=False)
- emit_text("\n\n", -1, is_sampled=False)
+ emit_special(self._think, -1, is_sampled=False, is_content=False)
+ emit_text("\n\n", -1, is_sampled=False, is_content=False)
+ emit_special(self._think_end, -1, is_sampled=False, is_content=False)
+ emit_text("\n\n", -1, is_sampled=False, is_content=False)
mm_data: MultiModalData | None = None
if mm_hashes or mm_placeholders or mm_items:
@@ -480,6 +541,7 @@ def flush_buf() -> None:
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
+ is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
multi_modal_data=mm_data,
)
@@ -549,34 +611,63 @@ def bridge_to_next_turn(
# ``previous_ids``. Bridge-added tokens get proper ``msg_idx``
# (relative to ``new_messages``) and uniformly ``False``
# ``sampled``: nothing the bridge emits was model-sampled.
+ # ``is_content`` follows the same rules as in :meth:`render` so
+ # consumers can walk the trajectory and read each step's own
+ # body mask; the prior portion is uniformly False since we have
+ # no attribution info for it.
tokens: list[int] = list(previous_ids)
indices: list[int] = [-1] * len(previous_ids)
sampled: list[bool] = [False] * len(previous_ids)
+ content_mask: list[bool] = [False] * len(previous_ids)
new_hashes: dict[str, list[str]] = {}
new_placeholders: dict[str, list[PlaceholderRange]] = {}
new_items: dict[str, list[dict[str, Any]]] = {}
def emit_special(
- token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
+ token_id: int,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
tokens.append(token_id)
indices.append(msg_idx)
sampled.append(is_sampled)
+ content_mask.append(is_content)
def emit_text(
- text: str, msg_idx: int = -1, *, is_sampled: bool = False
+ text: str,
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ is_content: bool = False,
) -> None:
ids = self._encode(text)
tokens.extend(ids)
indices.extend([msg_idx] * len(ids))
sampled.extend([is_sampled] * len(ids))
+ content_mask.extend([is_content] * len(ids))
+
+ def emit_text_segments(
+ segments: list[tuple[str, bool]],
+ msg_idx: int = -1,
+ *,
+ is_sampled: bool = False,
+ ) -> None:
+ for tok_id, is_content in attribute_text_segments(
+ self._tokenizer, segments
+ ):
+ tokens.append(tok_id)
+ indices.append(msg_idx)
+ sampled.append(is_sampled)
+ content_mask.append(is_content)
def emit_image(part: dict[str, Any], msg_idx: int = -1) -> None:
_, out, n, h = self._process_image(part)
emit_special(self._vision_start, msg_idx)
offset = len(tokens)
for _ in range(n):
- emit_special(self._image_pad, msg_idx)
+ emit_special(self._image_pad, msg_idx, is_content=True)
emit_special(self._vision_end, msg_idx)
new_hashes.setdefault("image", []).append(h)
new_placeholders.setdefault("image", []).append(
@@ -591,16 +682,17 @@ def emit_image(part: dict[str, Any], msg_idx: int = -1) -> None:
def emit_user_with_media(content_list: list[Any], msg_idx: int) -> None:
emit_special(self._im_start, msg_idx)
- buf: list[str] = ["user\n"]
+ buf_segments: list[tuple[str, bool]] = [("user\n", False)]
def flush_buf() -> None:
- if buf:
- emit_text("".join(buf), msg_idx)
- buf.clear()
+ if buf_segments:
+ emit_text_segments(buf_segments, msg_idx)
+ buf_segments.clear()
for item in content_list:
if isinstance(item, str):
- buf.append(item)
+ if item:
+ buf_segments.append((item, True))
elif isinstance(item, dict):
if _is_image_part(item):
flush_buf()
@@ -610,7 +702,8 @@ def flush_buf() -> None:
"Video parts are not yet supported by Qwen35Renderer."
)
elif "text" in item:
- buf.append(item["text"])
+ if item["text"]:
+ buf_segments.append((item["text"], True))
else:
raise ValueError(f"Unexpected content item: {item}")
else:
@@ -633,12 +726,18 @@ def flush_buf() -> None:
emit_user_with_media(raw_content, i)
else:
emit_special(self._im_start, i)
- emit_text("user\n" + content, i)
+ user_segments: list[tuple[str, bool]] = [("user\n", False)]
+ if content:
+ user_segments.append((content, True))
+ emit_text_segments(user_segments, i)
emit_special(self._im_end, i)
emit_text("\n", i)
elif role == "system":
emit_special(self._im_start, i)
- emit_text("system\n" + content, i)
+ sys_segments: list[tuple[str, bool]] = [("system\n", False)]
+ if content:
+ sys_segments.append((content, True))
+ emit_text_segments(sys_segments, i)
emit_special(self._im_end, i)
emit_text("\n", i)
elif role == "tool":
@@ -648,6 +747,7 @@ def flush_buf() -> None:
emit_special=emit_special,
emit_text=emit_text,
emit_image=emit_image,
+ emit_text_segments=emit_text_segments,
)
else:
return None
@@ -693,6 +793,7 @@ def flush_buf() -> None:
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
+ is_content=content_mask,
message_roles=bridge_roles,
)
@@ -705,6 +806,7 @@ def flush_buf() -> None:
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
+ is_content=content_mask,
message_roles=bridge_roles,
multi_modal_data=mm_data,
)
@@ -747,6 +849,7 @@ def _render_assistant(
emit_special,
emit_text,
emit_ids,
+ emit_text_segments,
) -> None:
# Extract reasoning_content
reasoning_content = ""
@@ -769,26 +872,36 @@ def _render_assistant(
# at inference the chat template emits these as the generation
# prompt and the model never samples them. Marking the role tag
# as ``is_sampled=False`` keeps the SFT loss mask aligned with
- # what the model would actually have produced. The split between
- # ``assistant`` and ``\n`` is a safe BPE boundary in the Qwen
- # tokenizer (``\n`` after the role is its own token).
- emit_special(self._im_start, msg_idx, is_sampled=False)
- emit_text("assistant\n", msg_idx, is_sampled=False)
+ # what the model would actually have produced. ``is_content`` is
+ # also False here — the role tag isn't part of any message's
+ # body, on any role.
+ emit_special(self._im_start, msg_idx, is_sampled=False, is_content=False)
+ emit_text("assistant\n", msg_idx, is_sampled=False, is_content=False)
# Build the model-sampled portion (think block + content + tool
# calls). Text segments stay contiguous within each is_sampled
- # span to preserve BPE merges.
+ # span to preserve BPE merges. For assistant messages the
+ # invariant ``is_content == sampled_mask`` holds — every sampled
+ # token is body, every scaffold token isn't. The XML-style tool
+ # call tags (````, ````, etc.) are
+ # part of the model's emitted output too — keep them
+ # ``is_content=True`` per the assistant rule.
emit_thinking = self._should_render_thinking(msg_idx, last_query_index) or (
preserve_thinking and bool(reasoning_content)
)
if emit_thinking:
# Include thinking block
- emit_special(self._think, msg_idx, is_sampled=True)
- emit_text("\n" + reasoning_content + "\n", msg_idx, is_sampled=True)
- emit_special(self._think_end, msg_idx, is_sampled=True)
- emit_text("\n\n" + content, msg_idx, is_sampled=True)
+ emit_special(self._think, msg_idx, is_sampled=True, is_content=True)
+ emit_text(
+ "\n" + reasoning_content + "\n",
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
+ emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True)
+ emit_text("\n\n" + content, msg_idx, is_sampled=True, is_content=True)
else:
- emit_text(content, msg_idx, is_sampled=True)
+ emit_text(content, msg_idx, is_sampled=True, is_content=True)
# Tool calls
tool_calls = msg.get("tool_calls") or []
@@ -801,13 +914,18 @@ def _render_assistant(
# Separator before
if tc_idx == 0:
if content.strip():
- emit_text("\n\n", msg_idx, is_sampled=True)
+ emit_text("\n\n", msg_idx, is_sampled=True, is_content=True)
# else: no separator
else:
- emit_text("\n", msg_idx, is_sampled=True)
-
- emit_special(self._tool_call, msg_idx, is_sampled=True)
- emit_text("\n\n", msg_idx, is_sampled=True)
+ emit_text("\n", msg_idx, is_sampled=True, is_content=True)
+
+ emit_special(self._tool_call, msg_idx, is_sampled=True, is_content=True)
+ emit_text(
+ "\n\n",
+ msg_idx,
+ is_sampled=True,
+ is_content=True,
+ )
# Render arguments
# OpenAI canonical form: arguments is a JSON string. Parse it so the
@@ -828,16 +946,20 @@ def _render_assistant(
+ "\n\n",
msg_idx,
is_sampled=True,
+ is_content=True,
)
- emit_text("\n", msg_idx, is_sampled=True)
- emit_special(self._tool_call_end, msg_idx, is_sampled=True)
+ emit_text("\n", msg_idx, is_sampled=True, is_content=True)
+ emit_special(
+ self._tool_call_end, msg_idx, is_sampled=True, is_content=True
+ )
# ``<|im_end|>`` is the model's stop signal — it samples this to
- # end its turn, so it is part of the sampled stream. The trailing
- # ``\n`` is template-appended between turns and never sampled.
- emit_special(self._im_end, msg_idx, is_sampled=True)
- emit_text("\n", msg_idx, is_sampled=False)
+ # end its turn, so it is part of the sampled stream (and the
+ # assistant's body). The trailing ``\n`` is template-appended
+ # between turns and never sampled — scaffold for is_content too.
+ emit_special(self._im_end, msg_idx, is_sampled=True, is_content=True)
+ emit_text("\n", msg_idx, is_sampled=False, is_content=False)
# ------------------------------------------------------------------
# Tool message rendering
@@ -851,6 +973,7 @@ def _render_tool(
emit_special,
emit_text,
emit_image,
+ emit_text_segments,
) -> None:
# Consecutive tool messages share a single <|im_start|>user ... <|im_end|>
# envelope. Whether to open and close the envelope depends only on the
@@ -858,9 +981,11 @@ def _render_tool(
# tool message — keep this predicate text/media-agnostic.
# Tool messages are conversation history injected by the runtime
# between assistant turns — the model never samples any of these
- # tokens, so every emission is is_sampled=False. The bridge's
- # local emit_special / emit_text accept the is_sampled kwarg for
- # signature compatibility but ignore it.
+ # tokens, so every emission is is_sampled=False. The ``content``
+ # field's body bytes get ``is_content=True``; everything else —
+ # the ``<|im_start|>user`` wrap, the inter-section ``\n``s, the
+ # ``<|tool_response>`` specials — is scaffold so the SFT mask
+ # for tool body never trains the model to emit them.
prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool"
next_is_tool = (
msg_idx + 1 < len(messages) and messages[msg_idx + 1]["role"] == "tool"
@@ -868,11 +993,11 @@ def _render_tool(
raw_content = messages[msg_idx].get("content")
if not prev_is_tool:
- emit_special(self._im_start, msg_idx, is_sampled=False)
- emit_text("user", msg_idx, is_sampled=False)
+ emit_special(self._im_start, msg_idx, is_sampled=False, is_content=False)
+ emit_text("user", msg_idx, is_sampled=False, is_content=False)
- emit_text("\n", msg_idx, is_sampled=False)
- emit_special(self._tool_response, msg_idx, is_sampled=False)
+ emit_text("\n", msg_idx, is_sampled=False, is_content=False)
+ emit_special(self._tool_response, msg_idx, is_sampled=False, is_content=False)
if self._content_has_media(raw_content):
# Mirror the chat template's ``render_content`` macro for list
@@ -882,16 +1007,22 @@ def _render_tool(
# ``_content_has_media`` returns False unless content is a list,
# but the type checker can't follow that through the call.
assert isinstance(raw_content, list)
- buf: list[str] = ["\n"]
+ # First flush: leading ``"\n"`` is scaffold (separates the
+ # ``<|tool_response>`` special from the body); subsequent
+ # text items in this run are body. After a media break, the
+ # buffer resets to pure body until the next media break or
+ # end-of-content.
+ buf_segments: list[tuple[str, bool]] = [("\n", False)]
def flush_buf() -> None:
- if buf:
- emit_text("".join(buf), msg_idx, is_sampled=False)
- buf.clear()
+ if buf_segments:
+ emit_text_segments(buf_segments, msg_idx, is_sampled=False)
+ buf_segments.clear()
for item in raw_content:
if isinstance(item, str):
- buf.append(item)
+ if item:
+ buf_segments.append((item, True))
elif isinstance(item, dict):
if _is_image_part(item):
flush_buf()
@@ -901,19 +1032,29 @@ def flush_buf() -> None:
"Video parts are not yet supported by Qwen35Renderer."
)
elif "text" in item:
- buf.append(item["text"])
+ if item["text"]:
+ buf_segments.append((item["text"], True))
else:
raise ValueError(f"Unexpected content item: {item}")
else:
raise ValueError(f"Unexpected content item: {item}")
flush_buf()
- emit_text("\n", msg_idx, is_sampled=False)
+ emit_text("\n", msg_idx, is_sampled=False, is_content=False)
else:
content = self._render_content(raw_content).strip()
- emit_text("\n" + content + "\n", msg_idx, is_sampled=False)
+ # ``\n`` + content + ``\n`` — body is the middle segment only.
+ # Single BPE pass over the joined text preserves boundary
+ # merges.
+ emit_text_segments(
+ [("\n", False), (content, True), ("\n", False)],
+ msg_idx,
+ is_sampled=False,
+ )
- emit_special(self._tool_response_end, msg_idx, is_sampled=False)
+ emit_special(
+ self._tool_response_end, msg_idx, is_sampled=False, is_content=False
+ )
if not next_is_tool:
- emit_special(self._im_end, msg_idx, is_sampled=False)
- emit_text("\n", msg_idx, is_sampled=False)
+ emit_special(self._im_end, msg_idx, is_sampled=False, is_content=False)
+ emit_text("\n", msg_idx, is_sampled=False, is_content=False)
diff --git a/renderers/qwen3_vl.py b/renderers/qwen3_vl.py
index 311ce26..94ae13d 100644
--- a/renderers/qwen3_vl.py
+++ b/renderers/qwen3_vl.py
@@ -42,6 +42,7 @@
PlaceholderRange,
RenderedTokens,
ToolSpec,
+ attribute_text_segments,
reject_assistant_in_extension,
trim_to_turn_close,
)
@@ -177,14 +178,31 @@ class _Emitter:
value than the current buffer triggers a flush first — split points
are always at the ``is_sampled`` boundary, which the caller is
expected to place at a ``\\n`` boundary so BPE merges don't drift.
+
+ ``is_content`` is the per-token body/scaffold attribution. Within a
+ single flush adjacent text fragments may carry different
+ ``is_content`` values (e.g. ``"user\\n"`` scaffold + caller content
+ body): the buffer stores fragments as a list of
+ ``(text, is_content)`` segments and flushes via
+ :func:`attribute_text_segments`, which performs one BPE pass over
+ the joined text and assigns per-token is_content from each token's
+ source segment. When every segment in a flush shares the same
+ is_content (the common case for sampled assistant body / pure
+ scaffold) the fast path of a single ``encode()`` call is used and
+ no offset-tokenizer lookup is required.
"""
- def __init__(self, encode_fn, msg_idx: int = -1):
+ def __init__(self, encode_fn, tokenizer=None, msg_idx: int = -1):
self._encode = encode_fn
+ self._tokenizer = tokenizer
self.token_ids: list[int] = []
self.message_indices: list[int] = []
self.sampled: list[bool] = []
- self._buf: str = ""
+ self.is_content: list[bool] = []
+ # Buffered text fragments as ``(text, is_content)`` tuples. All
+ # fragments share a single ``_buf_sampled`` / ``_buf_idx``;
+ # changing either of those triggers a flush.
+ self._segments: list[tuple[str, bool]] = []
self._buf_idx: int = msg_idx
self._buf_sampled: bool = False
self.msg_idx = msg_idx
@@ -194,49 +212,78 @@ def set_msg_idx(self, msg_idx: int) -> None:
# text doesn't get glued to the previous one's BPE context.
# In practice messages are always separated by an <|im_end|>
# special token, which already flushes — but be defensive.
- if self._buf:
+ if self._segments:
self._flush()
self.msg_idx = msg_idx
self._buf_idx = msg_idx
- def text(self, text: str, *, is_sampled: bool) -> None:
+ def text(self, text: str, *, is_sampled: bool, is_content: bool) -> None:
if not text:
return
# Adjacent text under different msg_idx or is_sampled is rare in
# this template — but flush at those boundaries so attribution
- # and the sampled signal stay accurate.
- if self._buf and (
+ # and the sampled signal stay accurate. is_content boundaries do
+ # NOT force a flush: they're carried through the joined BPE pass
+ # via :func:`attribute_text_segments`, preserving merges across
+ # the wrap/body boundary.
+ if self._segments and (
self._buf_idx != self.msg_idx or self._buf_sampled != is_sampled
):
self._flush()
- if not self._buf:
+ if not self._segments:
self._buf_idx = self.msg_idx
self._buf_sampled = is_sampled
- self._buf += text
+ self._segments.append((text, is_content))
- def special(self, token_id: int, *, is_sampled: bool) -> None:
- if self._buf:
+ def special(self, token_id: int, *, is_sampled: bool, is_content: bool) -> None:
+ if self._segments:
self._flush()
self.token_ids.append(token_id)
self.message_indices.append(self.msg_idx)
self.sampled.append(is_sampled)
+ self.is_content.append(is_content)
def cursor(self) -> int:
"""Current token offset after flushing — used to anchor placeholder ranges."""
- if self._buf:
+ if self._segments:
self._flush()
return len(self.token_ids)
def finalize(self) -> None:
- if self._buf:
+ if self._segments:
self._flush()
def _flush(self) -> None:
- ids = self._encode(self._buf)
- self.token_ids.extend(ids)
- self.message_indices.extend([self._buf_idx] * len(ids))
- self.sampled.extend([self._buf_sampled] * len(ids))
- self._buf = ""
+ segments = self._segments
+ self._segments = []
+ if not segments:
+ return
+ # Fast path: every segment shares the same is_content — use the
+ # plain ``encode()`` call so we don't pay for the offset
+ # tokenizer. This is the common case (pure scaffold flushes, or
+ # pure body flushes).
+ first_ic = segments[0][1]
+ all_same = all(ic == first_ic for _, ic in segments)
+ if all_same:
+ joined = "".join(text for text, _ in segments)
+ ids = self._encode(joined)
+ self.token_ids.extend(ids)
+ self.message_indices.extend([self._buf_idx] * len(ids))
+ self.sampled.extend([self._buf_sampled] * len(ids))
+ self.is_content.extend([first_ic] * len(ids))
+ return
+ # Mixed body/scaffold flush — encode once and attribute back to
+ # each segment via the fast tokenizer's offset_mapping. Requires
+ # a tokenizer (not just the encode fn) to look up offsets.
+ assert self._tokenizer is not None, (
+ "_Emitter mixed-is_content flush requires a tokenizer; "
+ "pass one to the constructor."
+ )
+ for tok_id, is_content in attribute_text_segments(self._tokenizer, segments):
+ self.token_ids.append(tok_id)
+ self.message_indices.append(self._buf_idx)
+ self.sampled.append(self._buf_sampled)
+ self.is_content.append(is_content)
class Qwen3VLRenderer:
@@ -401,7 +448,7 @@ def render(
if not messages:
raise ValueError("No messages provided.")
- em = _Emitter(self._encode)
+ em = _Emitter(self._encode, tokenizer=self._tokenizer)
mm_hashes: dict[str, list[str]] = {}
mm_placeholders: dict[str, list[PlaceholderRange]] = {}
mm_items: dict[str, list[dict[str, Any]]] = {}
@@ -409,13 +456,17 @@ def render(
def emit_image(part: dict[str, Any]) -> None:
# Image placeholders are prompt-side scaffolding the user
# message attaches — the model never samples ``<|vision_start|>``
- # / ``<|image_pad|>`` / ``<|vision_end|>``.
+ # / ``<|image_pad|>`` / ``<|vision_end|>``. The
+ # ``<|image_pad|>`` placeholders represent caller-provided
+ # image data, so they ARE body content (is_content=True);
+ # the surrounding ``<|vision_start|>`` / ``<|vision_end|>``
+ # markers are renderer-emitted scaffold.
_, out, n, h = self._process_image(part)
- em.special(self._vision_start, is_sampled=False)
+ em.special(self._vision_start, is_sampled=False, is_content=False)
offset = em.cursor()
for _ in range(n):
- em.special(self._image_pad, is_sampled=False)
- em.special(self._vision_end, is_sampled=False)
+ em.special(self._image_pad, is_sampled=False, is_content=True)
+ em.special(self._vision_end, is_sampled=False, is_content=False)
mm_hashes.setdefault("image", []).append(h)
mm_placeholders.setdefault("image", []).append(
PlaceholderRange(offset=offset, length=n)
@@ -432,16 +483,24 @@ def render_media_content(content: Any) -> None:
User / tool content is conversation context the model never
samples — every text fragment goes in as ``is_sampled=False``.
+ Text from the caller IS the message body, so every text
+ fragment is ``is_content=True``; the vision-marker specials
+ around image_pad placeholders are scaffold (handled in
+ :func:`emit_image`).
"""
if isinstance(content, str):
- em.text(content, is_sampled=False)
+ em.text(content, is_sampled=False, is_content=True)
return
if not isinstance(content, list):
- em.text(self._render_text_content(content), is_sampled=False)
+ em.text(
+ self._render_text_content(content),
+ is_sampled=False,
+ is_content=True,
+ )
return
for item in content:
if isinstance(item, str):
- em.text(item, is_sampled=False)
+ em.text(item, is_sampled=False, is_content=True)
elif isinstance(item, dict):
if _is_image_part(item):
emit_image(item)
@@ -450,7 +509,7 @@ def render_media_content(content: Any) -> None:
"Video parts are not yet supported by Qwen3VLRenderer."
)
elif "text" in item:
- em.text(item["text"], is_sampled=False)
+ em.text(item["text"], is_sampled=False, is_content=True)
# ── 1. System + tools ───────────────────────────────────────
first_is_system = messages[0].get("role") == "system"
@@ -458,26 +517,37 @@ def render_media_content(content: Any) -> None:
if tools:
sys_idx = 0 if first_is_system else -1
em.set_msg_idx(sys_idx)
- em.special(self._im_start, is_sampled=False)
- buf = "system\n"
+ em.special(self._im_start, is_sampled=False, is_content=False)
+ # Body = system content (if any). Everything else in this
+ # block — role tag, tools header / footer, JSON tool specs —
+ # is scaffold. The tools dict is recoverable from the
+ # ``tools`` argument; we don't re-attribute its embedded
+ # JSON as message body.
+ em.text("system\n", is_sampled=False, is_content=False)
if first_is_system:
- buf += self._render_text_content(messages[0].get("content")) + "\n\n"
- buf += _TOOLS_HEADER
+ sys_content = self._render_text_content(messages[0].get("content"))
+ if sys_content:
+ em.text(sys_content, is_sampled=False, is_content=True)
+ em.text("\n\n", is_sampled=False, is_content=False)
+ em.text(_TOOLS_HEADER, is_sampled=False, is_content=False)
for tool in tools:
- buf += "\n" + json.dumps(tool, ensure_ascii=False)
- buf += _TOOLS_FOOTER
- em.text(buf, is_sampled=False)
- em.special(self._im_end, is_sampled=False)
- em.text("\n", is_sampled=False)
+ em.text(
+ "\n" + json.dumps(tool, ensure_ascii=False),
+ is_sampled=False,
+ is_content=False,
+ )
+ em.text(_TOOLS_FOOTER, is_sampled=False, is_content=False)
+ em.special(self._im_end, is_sampled=False, is_content=False)
+ em.text("\n", is_sampled=False, is_content=False)
elif first_is_system:
em.set_msg_idx(0)
- em.special(self._im_start, is_sampled=False)
- em.text(
- "system\n" + self._render_text_content(messages[0].get("content")),
- is_sampled=False,
- )
- em.special(self._im_end, is_sampled=False)
- em.text("\n", is_sampled=False)
+ em.special(self._im_start, is_sampled=False, is_content=False)
+ em.text("system\n", is_sampled=False, is_content=False)
+ sys_content = self._render_text_content(messages[0].get("content"))
+ if sys_content:
+ em.text(sys_content, is_sampled=False, is_content=True)
+ em.special(self._im_end, is_sampled=False, is_content=False)
+ em.text("\n", is_sampled=False, is_content=False)
# ── 2. Iterate messages ─────────────────────────────────────
for i, msg in enumerate(messages):
@@ -489,11 +559,11 @@ def render_media_content(content: Any) -> None:
em.set_msg_idx(i)
if role == "user":
- em.special(self._im_start, is_sampled=False)
- em.text("user\n", is_sampled=False)
+ em.special(self._im_start, is_sampled=False, is_content=False)
+ em.text("user\n", is_sampled=False, is_content=False)
render_media_content(msg.get("content"))
- em.special(self._im_end, is_sampled=False)
- em.text("\n", is_sampled=False)
+ em.special(self._im_end, is_sampled=False, is_content=False)
+ em.text("\n", is_sampled=False, is_content=False)
elif role == "assistant":
self._render_assistant(msg, em)
@@ -507,8 +577,8 @@ def render_media_content(content: Any) -> None:
# ── 3. Generation prompt ────────────────────────────────────
if add_generation_prompt:
em.set_msg_idx(-1)
- em.special(self._im_start, is_sampled=False)
- em.text("assistant\n", is_sampled=False)
+ em.special(self._im_start, is_sampled=False, is_content=False)
+ em.text("assistant\n", is_sampled=False, is_content=False)
em.finalize()
@@ -524,6 +594,7 @@ def render_media_content(content: Any) -> None:
token_ids=em.token_ids,
message_indices=em.message_indices,
sampled_mask=em.sampled,
+ is_content=em.is_content,
message_roles=[m.get("role") or "" for m in messages],
multi_modal_data=mm_data,
)
@@ -595,18 +666,21 @@ def bridge_to_next_turn(
# Bridge populates ``message_indices`` (relative to ``new_messages``)
# and ``sampled_mask`` (uniformly ``False`` — every token the
# bridge emits is template scaffolding for the next prompt, not
- # something the model sampled). Downstream consumers can run
- # :meth:`RenderedTokens.tokens_per_message` on the bridge output
- # to get per-new-message token counts without re-rendering.
- em = _Emitter(self._encode)
+ # something the model sampled). ``is_content`` follows the same
+ # rules as in :meth:`render` so consumers can walk the trajectory
+ # and read each step's own body mask. Downstream consumers can
+ # run :meth:`RenderedTokens.tokens_per_message` on the bridge
+ # output to get per-new-message token counts without re-rendering.
+ em = _Emitter(self._encode, tokenizer=self._tokenizer)
# Seed the emitter with the prior turn's tokens so cursor() reports
# absolute offsets in the combined sequence. Per-token attribution
# for the prior portion is unknown to the bridge (it only has
# prev_prompt_ids + prev_completion_ids as raw lists), so seed
- # both side channels with the "no info" sentinel.
+ # all side channels with the "no info" sentinel.
em.token_ids = list(previous_ids)
em.message_indices = [-1] * len(previous_ids)
em.sampled = [False] * len(previous_ids)
+ em.is_content = [False] * len(previous_ids)
new_hashes: dict[str, list[str]] = {}
new_placeholders: dict[str, list[PlaceholderRange]] = {}
@@ -614,11 +688,11 @@ def bridge_to_next_turn(
def emit_image(part: dict[str, Any]) -> None:
_, out, n, h = self._process_image(part)
- em.special(self._vision_start, is_sampled=False)
+ em.special(self._vision_start, is_sampled=False, is_content=False)
offset = em.cursor()
for _ in range(n):
- em.special(self._image_pad, is_sampled=False)
- em.special(self._vision_end, is_sampled=False)
+ em.special(self._image_pad, is_sampled=False, is_content=True)
+ em.special(self._vision_end, is_sampled=False, is_content=False)
new_hashes.setdefault("image", []).append(h)
new_placeholders.setdefault("image", []).append(
PlaceholderRange(offset=offset, length=n)
@@ -632,14 +706,18 @@ def emit_image(part: dict[str, Any]) -> None:
def render_media_content(content: Any) -> None:
if isinstance(content, str):
- em.text(content, is_sampled=False)
+ em.text(content, is_sampled=False, is_content=True)
return
if not isinstance(content, list):
- em.text(self._render_text_content(content), is_sampled=False)
+ em.text(
+ self._render_text_content(content),
+ is_sampled=False,
+ is_content=True,
+ )
return
for item in content:
if isinstance(item, str):
- em.text(item, is_sampled=False)
+ em.text(item, is_sampled=False, is_content=True)
elif isinstance(item, dict):
if _is_image_part(item):
emit_image(item)
@@ -648,34 +726,34 @@ def render_media_content(content: Any) -> None:
"Video parts are not yet supported by Qwen3VLRenderer."
)
elif "text" in item:
- em.text(item["text"], is_sampled=False)
+ em.text(item["text"], is_sampled=False, is_content=True)
em.set_msg_idx(-1)
- em.text("\n", is_sampled=False)
+ em.text("\n", is_sampled=False, is_content=False)
for i, msg in enumerate(new_messages):
role = msg.get("role")
em.set_msg_idx(i)
if role == "user":
- em.special(self._im_start, is_sampled=False)
- em.text("user\n", is_sampled=False)
+ em.special(self._im_start, is_sampled=False, is_content=False)
+ em.text("user\n", is_sampled=False, is_content=False)
render_media_content(msg.get("content"))
- em.special(self._im_end, is_sampled=False)
- em.text("\n", is_sampled=False)
+ em.special(self._im_end, is_sampled=False, is_content=False)
+ em.text("\n", is_sampled=False, is_content=False)
elif role == "system":
- em.special(self._im_start, is_sampled=False)
- em.text("system\n", is_sampled=False)
+ em.special(self._im_start, is_sampled=False, is_content=False)
+ em.text("system\n", is_sampled=False, is_content=False)
render_media_content(msg.get("content"))
- em.special(self._im_end, is_sampled=False)
- em.text("\n", is_sampled=False)
+ em.special(self._im_end, is_sampled=False, is_content=False)
+ em.text("\n", is_sampled=False, is_content=False)
elif role == "tool":
self._render_tool(new_messages, i, em, render_media_content)
else:
return None
em.set_msg_idx(-1)
- em.special(self._im_start, is_sampled=False)
- em.text("assistant\n", is_sampled=False)
+ em.special(self._im_start, is_sampled=False, is_content=False)
+ em.text("assistant\n", is_sampled=False, is_content=False)
em.finalize()
# Merge prev mm_data with the new turn's items.
@@ -713,6 +791,7 @@ def render_media_content(content: Any) -> None:
token_ids=em.token_ids,
message_indices=em.message_indices,
sampled_mask=em.sampled,
+ is_content=em.is_content,
message_roles=[m.get("role") or "" for m in new_messages],
multi_modal_data=mm_data,
)
@@ -726,20 +805,22 @@ def _render_assistant(self, msg: Message, em: _Emitter) -> None:
# at inference the chat template emits these as the generation
# prompt and the model never samples them. Splitting the text
# at the ``\n`` after the role tag is safe: Qwen3 BPE treats
- # ``\n`` as a token boundary.
- em.special(self._im_start, is_sampled=False)
- em.text("assistant\n", is_sampled=False)
+ # ``\n`` as a token boundary. For the assistant role the
+ # invariant ``is_content == sampled_mask`` holds — every sampled
+ # token is body, every scaffold token isn't.
+ em.special(self._im_start, is_sampled=False, is_content=False)
+ em.text("assistant\n", is_sampled=False, is_content=False)
# Body (content + tool calls) is the model-sampled portion.
if not tool_calls:
- em.text(content, is_sampled=True)
+ em.text(content, is_sampled=True, is_content=True)
else:
for tc_idx, tc in enumerate(tool_calls):
if tc_idx == 0:
separator = "\n" if original_content else ""
- em.text(content + separator, is_sampled=True)
+ em.text(content + separator, is_sampled=True, is_content=True)
else:
- em.text("\n", is_sampled=True)
+ em.text("\n", is_sampled=True, is_content=True)
func = tc.get("function") or tc
name = func.get("name", "")
@@ -750,18 +831,19 @@ def _render_assistant(self, msg: Message, em: _Emitter) -> None:
else json.dumps(arguments, ensure_ascii=False)
)
- em.special(self._tool_call, is_sampled=True)
+ em.special(self._tool_call, is_sampled=True, is_content=True)
em.text(
'\n{"name": "' + name + '", "arguments": ' + args_str + "}\n",
is_sampled=True,
+ is_content=True,
)
- em.special(self._tool_call_end, is_sampled=True)
+ em.special(self._tool_call_end, is_sampled=True, is_content=True)
# ``<|im_end|>`` is the model's stop signal — it samples this to
- # end its turn. The trailing ``\n`` is template-appended between
- # turns and never sampled.
- em.special(self._im_end, is_sampled=True)
- em.text("\n", is_sampled=False)
+ # end its turn (and counts as part of its body). The trailing
+ # ``\n`` is template-appended between turns and never sampled.
+ em.special(self._im_end, is_sampled=True, is_content=True)
+ em.text("\n", is_sampled=False, is_content=False)
def _render_tool(
self,
@@ -772,24 +854,27 @@ def _render_tool(
) -> None:
# Tool messages are conversation history injected by the runtime
# between assistant turns — the model never samples any of these
- # tokens, so every emission is is_sampled=False. (render_media_content
- # also stamps is_sampled=False on its emissions.)
+ # tokens, so every emission is is_sampled=False. The
+ # ``content`` body bytes get ``is_content=True`` (via
+ # ``render_media_content``); everything else — the
+ # ``<|im_start|>user`` wrap, inter-section ``\n``s, and the
+ # ``<|tool_response>`` specials — is scaffold.
prev_is_tool = msg_idx > 0 and messages[msg_idx - 1]["role"] == "tool"
next_is_tool = (
msg_idx + 1 < len(messages) and messages[msg_idx + 1]["role"] == "tool"
)
if not prev_is_tool:
- em.special(self._im_start, is_sampled=False)
- em.text("user", is_sampled=False)
+ em.special(self._im_start, is_sampled=False, is_content=False)
+ em.text("user", is_sampled=False, is_content=False)
- em.text("\n", is_sampled=False)
- em.special(self._tool_response, is_sampled=False)
- em.text("\n", is_sampled=False)
+ em.text("\n", is_sampled=False, is_content=False)
+ em.special(self._tool_response, is_sampled=False, is_content=False)
+ em.text("\n", is_sampled=False, is_content=False)
render_media_content(messages[msg_idx].get("content"))
- em.text("\n", is_sampled=False)
- em.special(self._tool_response_end, is_sampled=False)
+ em.text("\n", is_sampled=False, is_content=False)
+ em.special(self._tool_response_end, is_sampled=False, is_content=False)
if not next_is_tool:
- em.special(self._im_end, is_sampled=False)
- em.text("\n", is_sampled=False)
+ em.special(self._im_end, is_sampled=False, is_content=False)
+ em.text("\n", is_sampled=False, is_content=False)
diff --git a/tests/test_client.py b/tests/test_client.py
index 850f220..0695d78 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -19,7 +19,15 @@ def render(self, messages, *, tools=None, add_generation_prompt=False):
assert messages == [{"role": "user", "content": "hi"}]
assert tools == [{"type": "function", "function": {"name": "echo"}}]
assert add_generation_prompt is True
- return RenderedTokens(token_ids=[1, 2, 3])
+ # Populate the full attribution surface so the test can verify
+ # ``generate`` threads it through to the result dict unchanged.
+ return RenderedTokens(
+ token_ids=[1, 2, 3],
+ message_indices=[0, 0, -1],
+ sampled_mask=[False, False, False],
+ is_content=[False, True, False],
+ message_roles=["user"],
+ )
def render_ids(self, messages, *, tools=None, add_generation_prompt=False):
return self.render(
@@ -137,6 +145,17 @@ def test_generate_builds_request_body_and_parses_response():
assert result["routed_experts"] == [[[1]], [[2]]]
assert result["multi_modal_data"] is None
assert result["request_id"] == "gen-test"
+ # Per-token attribution from the renderer surfaces on the result so
+ # downstream consumers (verifiers RendererClient → prime-rl) can
+ # build selective loss masks without a second render pass.
+ attr = result["prompt_attribution"]
+ assert attr is not None
+ assert isinstance(attr, RenderedTokens)
+ assert attr.token_ids == [1, 2, 3]
+ assert attr.is_content == [False, True, False]
+ assert attr.sampled_mask == [False, False, False]
+ assert attr.message_indices == [0, 0, -1]
+ assert attr.message_roles == ["user"]
assert len(result["tool_calls"]) == 1
tc = result["tool_calls"][0]
assert tc.name == "echo"
@@ -207,6 +226,44 @@ def test_generate_uses_prebuilt_prompt_ids_without_rendering():
assert client.calls[0]["body"]["token_ids"] == [11, 12, 13]
assert result["prompt_ids"] == [11, 12, 13]
+ # Pre-built prompt without explicit attribution → ``None`` carried
+ # through. Consumers fall back to whatever attribution-free path
+ # they have (e.g. uniform completion mask).
+ assert result["prompt_attribution"] is None
+
+
+def test_generate_threads_prompt_attribution_through_prebuilt_prompt_path():
+ """When the caller passes both ``prompt_ids`` and ``prompt_attribution``
+ (the multi-turn bridge path in verifiers), ``generate`` must thread
+ the attribution through to the result dict unchanged — no re-rendering,
+ no per-token reshuffling. Lets downstream consumers carry the
+ renderer's body/scaffold cut into the trajectory step without an
+ extra render pass."""
+ client = _FakeClient()
+ # Caller-supplied attribution; mirrors what
+ # ``RendererClient._get_incremental_prompt_ids`` returns from the
+ # bridge_to_next_turn output.
+ supplied = RenderedTokens(
+ token_ids=[11, 12, 13],
+ message_indices=[-1, 0, 0],
+ sampled_mask=[False, False, False],
+ is_content=[False, True, True],
+ message_roles=["tool"],
+ )
+
+ result = asyncio.run(
+ generate(
+ client=client,
+ renderer=_NoRenderRenderer(),
+ messages=[{"role": "user", "content": "hi"}],
+ model="test-model",
+ prompt_ids=[11, 12, 13],
+ prompt_attribution=supplied,
+ )
+ )
+
+ # Exact passthrough — same object, no copy / no transform.
+ assert result["prompt_attribution"] is supplied
# ---------------------------------------------------------------------------
diff --git a/tests/test_is_content.py b/tests/test_is_content.py
new file mode 100644
index 0000000..ddac1f5
--- /dev/null
+++ b/tests/test_is_content.py
@@ -0,0 +1,389 @@
+"""Per-token ``RenderedTokens.is_content`` invariants.
+
+``is_content[k]`` answers "does ``token_ids[k]`` come from message body
+bytes (caller-provided content / tool_calls / reasoning_content, or
+the model's sampled emission for the assistant role) or is it template
+scaffolding the renderer added around the body (role tags, special
+tokens, separators, tool-response wraps, generation prompt)?"
+
+By design ``is_content`` is a superset of ``sampled_mask``:
+
+- Equal on every token attributed to an assistant message (sampled ==
+ body for that role by construction).
+- Carries new information on every other role: the model never samples
+ user / tool / system tokens, so ``sampled_mask`` is uniformly
+ ``False`` over those — ``is_content`` differentiates body from wrap.
+
+These tests parametrise across every renderer in
+``conftest.RENDERER_MODELS`` and verify the contract for hand-coded
+renderers. ``DefaultRenderer`` leaves ``is_content`` empty by design
+(the Jinja template is opaque, so the renderer cannot know the
+wrap/body split) and is exempt from the populated-length check.
+
+The body-bytes decode invariant uses ``in`` (substring) rather than
+strict equality because some renderers normalise whitespace, strip
+trailing newlines, etc. within the message body emit — but the
+caller-provided content must always be recoverable as a substring of
+the decoded body run.
+"""
+
+from __future__ import annotations
+
+
+def _is_populated(rendered) -> bool:
+ return len(rendered.is_content) == len(rendered.token_ids) and bool(
+ rendered.is_content
+ )
+
+
+def test_is_content_length_or_empty(model_name, renderer):
+ """``is_content`` is either empty (opt-out) or matches token_ids
+ length exactly. No partial fills."""
+ msgs = [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello!"},
+ ]
+ rendered = renderer.render(msgs)
+ n_tokens = len(rendered.token_ids)
+ n_mask = len(rendered.is_content)
+ assert n_mask == 0 or n_mask == n_tokens, (
+ f"{model_name}: is_content length {n_mask} must be 0 or match "
+ f"token_ids length {n_tokens}"
+ )
+
+
+def test_is_content_equals_sampled_on_assistant(model_name, renderer):
+ """On every token attributed to an assistant message,
+ ``is_content[k] == sampled_mask[k]``. The two signals collapse on
+ that role by design — the model's sampled output IS the assistant
+ message's body, and the surrounding scaffold is neither sampled
+ nor body."""
+ msgs = [
+ {"role": "system", "content": "You are helpful."},
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello world!"},
+ ]
+ rendered = renderer.render(msgs)
+ if not _is_populated(rendered):
+ return
+ if len(rendered.sampled_mask) != len(rendered.token_ids):
+ return # renderer opts out of sampled_mask
+
+ mismatches = []
+ for k, msg_idx in enumerate(rendered.message_indices):
+ if msg_idx < 0:
+ continue
+ if msgs[msg_idx].get("role") != "assistant":
+ continue
+ if rendered.is_content[k] != rendered.sampled_mask[k]:
+ mismatches.append((k, rendered.is_content[k], rendered.sampled_mask[k]))
+ assert not mismatches, (
+ f"{model_name}: is_content != sampled_mask on assistant tokens "
+ f"(k, is_content, sampled): {mismatches[:8]}"
+ )
+
+
+def test_is_content_excludes_generation_prompt(model_name, renderer):
+ """All generation-prompt tokens (msg_idx=-1) are scaffold — the
+ next-turn opener the chat template injects so the model can
+ continue. ``is_content`` must be False over that entire span."""
+ msgs = [{"role": "user", "content": "Hi"}]
+ rendered = renderer.render(msgs, add_generation_prompt=True)
+ if not _is_populated(rendered):
+ return
+
+ bad = [
+ k
+ for k, (msg_idx, is_content) in enumerate(
+ zip(rendered.message_indices, rendered.is_content)
+ )
+ if msg_idx == -1 and is_content
+ ]
+ assert not bad, (
+ f"{model_name}: generation-prompt tokens marked is_content=True "
+ f"at positions {bad[:8]}"
+ )
+
+
+def test_is_content_recovers_user_body(model_name, tokenizer, renderer):
+ """The decoded run of is_content=True tokens within a user message
+ contains the user's original content. ``in`` rather than equality
+ because some templates normalise whitespace inside the body emit;
+ the input substring must still be recoverable from the decoded
+ body run."""
+ user_text = "Hello, my name is Sebastian."
+ msgs = [
+ {"role": "user", "content": user_text},
+ {"role": "assistant", "content": "Hi!"},
+ ]
+ rendered = renderer.render(msgs)
+ if not _is_populated(rendered):
+ return
+
+ user_body_ids = [
+ tid
+ for tid, mi, ic in zip(
+ rendered.token_ids, rendered.message_indices, rendered.is_content
+ )
+ if mi == 0 and ic
+ ]
+ assert user_body_ids, (
+ f"{model_name}: no is_content=True tokens attributed to user message"
+ )
+ decoded = tokenizer.decode(user_body_ids).strip()
+ assert user_text in decoded or decoded in user_text, (
+ f"{model_name}: user body run decodes to {decoded!r}, "
+ f"expected to contain {user_text!r}"
+ )
+
+
+def test_is_content_recovers_tool_body(model_name, tokenizer, renderer):
+ """The decoded run of is_content=True tokens within a tool message
+ contains the tool response body. The whole point of the body/wrap
+ cut: SFT on this run trains the model to anticipate tool outputs
+ without learning to emit the surrounding ``<|tool_response>`` /
+ role-tag scaffold (which would interrupt a real rollout)."""
+ tool_text = "The capital of France is Paris."
+ msgs = [
+ {"role": "user", "content": "What's the capital of France?"},
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [
+ {
+ "type": "function",
+ "id": "call_1",
+ "function": {
+ "name": "lookup",
+ "arguments": {"q": "capital of France"},
+ },
+ }
+ ],
+ },
+ {"role": "tool", "content": tool_text, "tool_call_id": "call_1"},
+ {"role": "assistant", "content": "Paris."},
+ ]
+ rendered = renderer.render(msgs)
+ if not _is_populated(rendered):
+ return
+
+ tool_body_ids = [
+ tid
+ for tid, mi, ic in zip(
+ rendered.token_ids, rendered.message_indices, rendered.is_content
+ )
+ if mi == 2 and ic
+ ]
+ assert tool_body_ids, (
+ f"{model_name}: no is_content=True tokens attributed to tool message"
+ )
+ decoded = tokenizer.decode(tool_body_ids).strip()
+ assert tool_text in decoded, (
+ f"{model_name}: tool body run decodes to {decoded!r}, "
+ f"expected to contain {tool_text!r}"
+ )
+
+
+def test_is_content_recovers_system_body(model_name, tokenizer, renderer):
+ """The decoded run of is_content=True tokens within a system
+ message contains the caller-provided system content. Tools header
+ / footer (if present) are scaffold and never appear as body."""
+ sys_text = "You are an unusually precise assistant."
+ msgs = [
+ {"role": "system", "content": sys_text},
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hi!"},
+ ]
+ rendered = renderer.render(msgs)
+ if not _is_populated(rendered):
+ return
+
+ sys_body_ids = [
+ tid
+ for tid, mi, ic in zip(
+ rendered.token_ids, rendered.message_indices, rendered.is_content
+ )
+ if mi == 0 and ic
+ ]
+ assert sys_body_ids, (
+ f"{model_name}: no is_content=True tokens attributed to system message"
+ )
+ decoded = tokenizer.decode(sys_body_ids).strip()
+ assert sys_text in decoded, (
+ f"{model_name}: system body run decodes to {decoded!r}, "
+ f"expected to contain {sys_text!r}"
+ )
+
+
+def test_is_content_no_body_on_role_tag(model_name, renderer):
+ """The first token attributed to a user/system/tool message must
+ have ``is_content=False`` — that's the leading role-tag run
+ (``<|im_start|>`` / equivalent), which is template scaffold, never
+ body."""
+ msgs = [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello!"},
+ ]
+ rendered = renderer.render(msgs)
+ if not _is_populated(rendered):
+ return
+
+ user_positions = [k for k, idx in enumerate(rendered.message_indices) if idx == 0]
+ assert user_positions, f"{model_name}: no tokens attributed to user message"
+ first_k = user_positions[0]
+ assert not rendered.is_content[first_k], (
+ f"{model_name}: first user-attributed token at k={first_k} should "
+ f"be is_content=False (role-tag scaffolding), but was True"
+ )
+
+
+def test_content_token_spans_by_role_isolates_tool_body(
+ model_name, tokenizer, renderer
+):
+ """``content_token_spans_by_role()["tool"]`` returns spans over
+ which every token is the tool message body. Joining the decoded
+ spans recovers the tool response. Adjacent scaffold tokens
+ (``<|tool_response>``, role-tag openers) are never inside any
+ returned span."""
+ tool_text = "Result: 42"
+ msgs = [
+ {"role": "user", "content": "What's 6*7?"},
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [
+ {
+ "type": "function",
+ "id": "call_x",
+ "function": {"name": "calc", "arguments": {"e": "6*7"}},
+ }
+ ],
+ },
+ {"role": "tool", "content": tool_text, "tool_call_id": "call_x"},
+ {"role": "assistant", "content": "42."},
+ ]
+ rendered = renderer.render(msgs)
+ if not _is_populated(rendered):
+ return
+
+ spans = rendered.content_token_spans_by_role()
+ tool_spans = spans.get("tool") or []
+ assert tool_spans, f"{model_name}: no tool content spans returned"
+
+ pieces: list[str] = []
+ for s, e in tool_spans:
+ run_ids = rendered.token_ids[s:e]
+ # All tokens in the span must be is_content=True by definition.
+ for k in range(s, e):
+ assert rendered.is_content[k], (
+ f"{model_name}: span {(s, e)} contains is_content=False at k={k}"
+ )
+ pieces.append(tokenizer.decode(run_ids))
+ joined = "".join(pieces).strip()
+ assert tool_text in joined, (
+ f"{model_name}: joined tool spans decode to {joined!r}, "
+ f"expected to contain {tool_text!r}"
+ )
+
+
+def test_content_mask_for_roles_excludes_assistant_when_unset(model_name, renderer):
+ """``content_mask_for_roles({"tool"})`` returns a mask that's True
+ only on tool body tokens — never on assistant tokens, even though
+ those also have ``is_content=True``. The role filter is the whole
+ point: SFT-on-tool-body must not bleed into the assistant span."""
+ msgs = [
+ {"role": "user", "content": "Hi"},
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [
+ {
+ "type": "function",
+ "id": "call_a",
+ "function": {"name": "ping", "arguments": {}},
+ }
+ ],
+ },
+ {"role": "tool", "content": "pong", "tool_call_id": "call_a"},
+ {"role": "assistant", "content": "OK."},
+ ]
+ rendered = renderer.render(msgs)
+ if not _is_populated(rendered):
+ return
+
+ tool_mask = rendered.content_mask_for_roles({"tool"})
+ assert len(tool_mask) == len(rendered.token_ids)
+
+ for k, mi in enumerate(rendered.message_indices):
+ if mi < 0:
+ assert not tool_mask[k], (
+ f"{model_name}: scaffold token k={k} (msg_idx=-1) marked in tool mask"
+ )
+ continue
+ role = msgs[mi].get("role")
+ if role != "tool" and tool_mask[k]:
+ raise AssertionError(
+ f"{model_name}: tool-role mask True on a {role!r} token at k={k}"
+ )
+
+
+def test_build_training_sample_content_sft_roles_picks_up_tool_body(
+ model_name, renderer
+):
+ """``build_training_sample(..., content_sft_roles={"tool"})``
+ produces a loss mask that's True on tool body tokens AND assistant
+ sampled tokens, but False on the tool-message scaffold (role tag,
+ ``<|tool_response>`` wraps, separators). The canonical
+ SFT-on-tool-body + RL-on-assistant composition."""
+ from renderers import build_training_sample
+
+ msgs = [
+ {"role": "user", "content": "Hi"},
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [
+ {
+ "type": "function",
+ "id": "call_z",
+ "function": {"name": "noop", "arguments": {}},
+ }
+ ],
+ },
+ {"role": "tool", "content": "done", "tool_call_id": "call_z"},
+ {"role": "assistant", "content": "OK."},
+ ]
+ ids, mask = build_training_sample(
+ renderer,
+ msgs,
+ role_to_mask=lambda m: m["role"] == "assistant",
+ content_sft_roles={"tool"},
+ )
+ assert len(mask) == len(ids)
+
+ # We need at least one trainable tool-body token if the renderer
+ # populates is_content. Renderers that opt out (empty is_content)
+ # fall back to the existing role_to_mask behaviour, which leaves
+ # tool tokens False — that's the documented fallback and is fine.
+ rendered = renderer.render(msgs)
+ if not _is_populated(rendered):
+ return
+
+ trainable_per_role: dict[str, int] = {}
+ for k, mi in enumerate(rendered.message_indices):
+ if mi < 0:
+ continue
+ if mask[k]:
+ r = msgs[mi].get("role") or ""
+ trainable_per_role[r] = trainable_per_role.get(r, 0) + 1
+
+ assert trainable_per_role.get("tool", 0) > 0, (
+ f"{model_name}: build_training_sample with content_sft_roles={{'tool'}} "
+ f"trained on zero tool tokens"
+ )
+ assert trainable_per_role.get("assistant", 0) > 0, (
+ f"{model_name}: assistant tokens dropped from training mask"
+ )
+ assert trainable_per_role.get("user", 0) == 0, (
+ f"{model_name}: user tokens leaked into training mask: {trainable_per_role}"
+ )