fix: patch model.forward() directly — survives TRL unwrapping#253

Merged
abrichr merged 1 commit into main from fix/patch-model-forward-directly
Mar 29, 2026

abrichr (Member) commented Mar 29, 2026

Summary

VLMModelWrapper was stripped by TRL/Accelerate's model unwrapping, so the underlying model no longer received image inputs. Fix: patch forward() directly on the model instance. An instance-level patch survives unwrapping because it lives on the model object itself, not on a wrapper around it.

Before: VLMModelWrapper(model) → TRL unwraps → wrapper gone → model blind
After: patch_model_for_trl(model) → patches model.forward → TRL unwraps → patch stays → model sees images
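
The before/after difference can be illustrated with a small, framework-free sketch. Everything here is a stand-in: `ToyModel`, `ToyWrapper`, `toy_unwrap`, and `patch_forward` are hypothetical names, and `toy_unwrap` only imitates the idea of peeling off wrapper objects; it is not Accelerate's actual unwrapping logic. The one behaviour it relies on is that `__call__` looks up `forward` on the instance, which is also how `torch.nn.Module` dispatches.

```python
import types


class ToyModel:
    """Stand-in for a model whose __call__ dispatches to self.forward."""

    def forward(self, input_ids, pixel_values=None):
        return "sees images" if pixel_values is not None else "blind"

    def __call__(self, *args, **kwargs):
        # Attribute lookup hits the instance first, so an instance-level
        # `forward` shadows the class-level one.
        return self.forward(*args, **kwargs)


class ToyWrapper:
    """Wrapper-class approach: injects pixel_values, then delegates."""

    def __init__(self, module, pixel_values):
        self.module = module
        self.pixel_values = pixel_values

    def __call__(self, input_ids, **kwargs):
        kwargs.setdefault("pixel_values", self.pixel_values)
        return self.module(input_ids, **kwargs)


def toy_unwrap(model):
    """Hypothetical unwrapping: peel off anything exposing `.module`."""
    while hasattr(model, "module"):
        model = model.module
    return model


def patch_forward(model, pixel_values):
    """Instance-level approach: replace forward on the object itself."""
    original_forward = model.forward  # bound method, captured before patching

    def patched_forward(self, *args, **kwargs):
        # Only fill in pixel_values when the caller did not supply them.
        if kwargs.get("pixel_values") is None:
            kwargs["pixel_values"] = pixel_values
        return original_forward(*args, **kwargs)

    model.forward = types.MethodType(patched_forward, model)
    return model
```

In this sketch, a `ToyWrapper` loses its injection as soon as `toy_unwrap` returns the inner model, while a model passed through `patch_forward` keeps it, because there is no wrapper layer to remove.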

🤖 Generated with Claude Code

TRL unwraps models via Accelerate, stripping wrapper classes. The fix:
patch forward() on the model instance itself. This survives unwrapping.

- patch_model_for_trl(model) → returns cache_fn
- cache_fn(inputs) caches pixel_values from processor output
- Patched forward() injects cached pixel_values when TRL omits them
- Patched __call__ also injects (covers all call paths)
- trl_wrapper passes original model to TRL (not a wrapper)
- cache_vision_fn passed through to rollout_func
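
The cache-and-inject flow in the list above might look roughly like the following. This is a minimal sketch under stated assumptions, not the code merged in this PR: `patch_model_for_trl_sketch` and `ToyVLM` are hypothetical names, the processor output is assumed to be dict-like with a `pixel_values` key, and the sketch does not check that the cached tensor matches the current batch.

```python
import types


class ToyVLM:
    """Hypothetical stand-in for a vision-language model."""

    def forward(self, input_ids=None, pixel_values=None):
        # Return what was received so the injection is easy to observe.
        return pixel_values

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def patch_model_for_trl_sketch(model):
    """Patch model.forward on the instance and return a cache function."""
    cache = {"pixel_values": None}
    original_forward = model.forward  # captured before the patch is applied

    def cache_fn(inputs):
        # Assumption: `inputs` is dict-like processor output.
        cache["pixel_values"] = inputs.get("pixel_values")

    def patched_forward(self, *args, **kwargs):
        # Inject the cached value only when the caller omitted pixel_values
        # and something has actually been cached.
        if kwargs.get("pixel_values") is None and cache["pixel_values"] is not None:
            kwargs["pixel_values"] = cache["pixel_values"]
        return original_forward(*args, **kwargs)

    model.forward = types.MethodType(patched_forward, model)
    return cache_fn


model = ToyVLM()
cache_fn = patch_model_for_trl_sketch(model)
cache_fn({"input_ids": [1, 2, 3], "pixel_values": [[0.5, 0.5]]})
result = model(input_ids=[1, 2, 3])  # pixel_values omitted by the caller
```

Returning `cache_fn` as a closure keeps the cache private to the patched instance, so a rollout function can refresh it per batch without holding a reference to any wrapper object. Explicitly passed `pixel_values` still take precedence over the cached value.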

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
abrichr merged commit 0f381b1 into main on Mar 29, 2026
1 check passed