Fix compute_hidden_states_hf.py: handle BatchEncoding from apply_chat_template (#1225)
Conversation
In transformers 4.46+, apply_chat_template() with return_tensors="pt"
returns a BatchEncoding object that no longer subclasses dict. The
previous isinstance(tokenized, dict) guard evaluated to False, so
input_ids was set to the whole BatchEncoding; input_ids.shape[1] then
invoked BatchEncoding.__getattr__("shape") and raised AttributeError.
Fix by checking isinstance(tokenized, torch.Tensor) instead, which
correctly handles both old transformers (plain tensor return) and new
transformers (BatchEncoding return).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
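The control flow behind this fix can be sketched with stand-in classes. Note that `FakeTensor`, `FakeBatchEncoding`, and `extract_input_ids` below are hypothetical illustrations, not code from `compute_hidden_states_hf.py`; the real guard checks `isinstance(tokenized, torch.Tensor)`:

```python
# Hypothetical stand-ins for torch.Tensor and transformers.BatchEncoding,
# used only to illustrate why the dict check failed and a tensor check works.

class FakeTensor:
    """Old transformers: apply_chat_template returned input_ids directly."""
    shape = (1, 7)

class FakeBatchEncoding:
    """New transformers (4.46+): mapping-style access, but not a dict subclass."""
    def __getitem__(self, key):
        if key == "input_ids":
            return FakeTensor()
        raise KeyError(key)

def extract_input_ids(tokenized):
    # The old guard was isinstance(tokenized, dict): False for the
    # BatchEncoding-like object, so the whole encoding leaked through and
    # .shape raised downstream. Testing for the tensor type covers both cases.
    if isinstance(tokenized, FakeTensor):
        return tokenized               # old transformers: already the tensor
    return tokenized["input_ids"]      # new transformers: index the encoding

assert extract_input_ids(FakeTensor()).shape[1] == 7
assert extract_input_ids(FakeBatchEncoding()).shape[1] == 7
```

Either return shape reaches `.shape[1]` safely, which is exactly the property the old dict guard lost.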
No actionable comments were generated in the recent review.
Walkthrough: Tokenization in a speculative decoding example was simplified.
# apply_chat_template return type varies by transformers version:
# - older versions return a plain torch.Tensor (input_ids directly)
# - newer versions (4.46+) return a BatchEncoding which no longer
#   subclasses dict, so isinstance(tokenized, dict) is False
Is this comment correct? We use 4.56+ so the older version should not need to be supported
Good point, simplified to just ["input_ids"] directly since transformers 4.46+ is always assumed.
Since transformers 4.46+ is required, apply_chat_template always returns a BatchEncoding. Drop the torch.Tensor fallback and just index ["input_ids"] directly. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
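The simplified path can be sketched as a minimal mock, assuming transformers >= 4.46 so that a BatchEncoding is always returned. `BatchEncodingStub` is hypothetical and only mimics the mapping-style access of the real `BatchEncoding`:

```python
class BatchEncodingStub:
    """Hypothetical stand-in: supports ["input_ids"] like a real BatchEncoding."""
    def __init__(self, data):
        self._data = data

    def __getitem__(self, key):
        return self._data[key]

# With 4.46+ guaranteed, no isinstance guard is needed: index directly.
tokenized = BatchEncodingStub({"input_ids": [[101, 2023, 102]]})
input_ids = tokenized["input_ids"]
assert len(input_ids[0]) == 3
```

Dropping the fallback removes a branch that could silently mask a version mismatch, which is the trade-off the reviewer asked for.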
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files:
@@ Coverage Diff @@
## main #1225 +/- ##
==========================================
+ Coverage 75.68% 77.39% +1.70%
==========================================
Files 353 353
Lines 40491 40491
==========================================
+ Hits 30644 31336 +692
+ Misses 9847 9155 -692
Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
Without return_dict=True, apply_chat_template returns a raw torch.Tensor on transformers <5.0 (default return_dict=False) and a BatchEncoding on transformers >=5.0 (default changed to True). Subscripting a Tensor with ["input_ids"] raises TypeError on <5.0. Passing return_dict=True explicitly forces BatchEncoding on all versions (verified locally on 4.57.1 and 5.0.0). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
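The version-dependent default described above can be illustrated with a mock. `StubTokenizer` is hypothetical and only mimics the `return_dict` default flip between transformers <5.0 and >=5.0; it is not the transformers API:

```python
class StubTokenizer:
    """Hypothetical mock of the return_dict behavior of apply_chat_template."""

    def apply_chat_template(self, conversation, return_tensors="pt",
                            return_dict=False):
        ids = [[1, 2, 3]]                  # fake token ids
        if return_dict:
            return {"input_ids": ids}      # BatchEncoding-like mapping
        return ids                         # raw-tensor-like return

tok = StubTokenizer()
# Old default (return_dict=False): a raw sequence comes back, and
# subscripting it with the string "input_ids" would raise TypeError.
assert tok.apply_chat_template([], return_dict=False) == [[1, 2, 3]]
# Explicit return_dict=True: a mapping on every version, safe to index.
assert tok.apply_chat_template([], return_dict=True)["input_ids"] == [[1, 2, 3]]
```

Passing the flag explicitly, as the commit does, makes the call's return type independent of the installed transformers version.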
Tested locally across versions.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
…_template (#1225)

## Summary
- `apply_chat_template(..., return_tensors="pt")` returns a `BatchEncoding` in transformers 4.46+, which no longer subclasses `dict`
- The old guard `isinstance(tokenized, dict)` evaluates to `False` for `BatchEncoding`, so `input_ids` was set to the whole `BatchEncoding` object
- Calling `.shape[1]` on a `BatchEncoding` triggers `__getattr__("shape")` → `AttributeError`
- Fix: check `isinstance(tokenized, torch.Tensor)` instead, which correctly handles both old transformers (plain tensor) and new transformers (BatchEncoding)

This is causing `test_collect_hidden_states` to fail in the speculative decoding CI for all open PRs (#1207, #1210, #1221).

## Test plan
- [ ] `torch-pr (speculative_decoding, 26.01)` CI passes
- [ ] Verify fix handles both `torch.Tensor` return (old transformers) and `BatchEncoding` return (new transformers 4.46+)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Signed-off-by: Ye Yu <yeyu@nvidia.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>