Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions openadapt_evals/training/standalone/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,11 @@ def _collect_rollout(self, task_id: str, instruction: str) -> Rollout:
logger.info("Stuck at step %d", step_idx)
break

image = Image.open(io.BytesIO(screenshot)).convert("RGB")
image = Image.open(io.BytesIO(screenshot))
if image.mode != "RGB":
image = image.convert("RGB")
# .convert() drops .format; restore it for outlines.Image
image.format = "PNG"
messages = build_agent_messages(instruction, include_image=True)
if hasattr(self._processor, "apply_chat_template"):
text_input = self._processor.apply_chat_template(
Expand All @@ -215,10 +219,13 @@ def _collect_rollout(self, task_id: str, instruction: str) -> Rollout:
else None
)
if outlines_gen is not None:
# Outlines v1.2 Generator API: handles tokenization,
# generation, and decoding internally. For multimodal
# models, pass a dict with "text" + image keys.
model_input = {"text": text_input, "images": [image]}
# Outlines v1.2 Generator API for multimodal models.
# TransformersMultiModal.format_input dispatches on type:
# list → [prompt_text, Image(pil), ...]
# Chat → Chat([Message(...)])
# A dict is NOT accepted (raises TypeError).
import outlines
model_input = [text_input, outlines.Image(image)]
decoded = outlines_gen(
model_input,
max_new_tokens=self._config.max_new_tokens,
Expand Down
43 changes: 43 additions & 0 deletions tests/test_standalone_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,49 @@ def test_outlines_generator_api_contract(self) -> None:
for p in params_call
), f"SteerableGenerator.__call__ doesn't accept **kwargs: {sig_call}"

def test_outlines_multimodal_input_format(self) -> None:
    """Pin the multimodal input shape the trainer hands to outlines.

    The trainer builds ``[prompt, outlines.Image(pil)]`` rather than a
    ``{"text": ..., "images": ...}`` mapping. This test guards that choice
    by inspecting the dispatch registry of
    ``TransformersMultiModalTypeAdapter.format_input``: ``list`` must be a
    registered type and ``dict`` must not be. It then checks that
    ``outlines.Image`` exists and can wrap a PIL image loaded from PNG
    bytes.

    Skipped when ``outlines`` (or the adapter class) cannot be imported.
    """
    try:
        import outlines
        from outlines.models.transformers import TransformersMultiModalTypeAdapter
    except ImportError:
        pytest.skip("outlines not installed")

    # Go through the class namespace to reach the raw descriptor: the
    # dispatch registry is read from the descriptor's ``dispatcher``,
    # not from a bound method.
    descriptor = vars(TransformersMultiModalTypeAdapter)["format_input"]
    registered_types = set(descriptor.dispatcher.registry)

    assert list in registered_types, (
        f"list not registered in format_input dispatch: {registered_types}. "
        f"The trainer passes [prompt, Image(pil)] — this type must be accepted."
    )
    assert dict not in registered_types, (
        "dict is registered in format_input — if this changes, the trainer's "
        "input format can be simplified back to a dict."
    )

    # The wrapper class itself has to be exposed at the package top level.
    assert hasattr(outlines, "Image"), "outlines.Image not found"

    import io
    from PIL import Image as PILImage

    # A freshly constructed PIL image has no ``.format``; round-trip it
    # through an in-memory PNG so the reloaded image carries one before
    # it is handed to outlines.Image (presumably required by the wrapper
    # — confirm against the outlines docs).
    png_buffer = io.BytesIO()
    PILImage.new("RGB", (10, 10)).save(png_buffer, format="PNG")
    png_buffer.seek(0)
    reloaded_img = PILImage.open(png_buffer)

    wrapped_image = outlines.Image(reloaded_img)
    assert wrapped_image is not None

def test_false_sentinel_not_confused_with_none(self) -> None:
"""Regression: False sentinel must return None, not be treated as uninitialized."""
config = TrainingConfig(constrained_decoding=True)
Expand Down
Loading