Fix int32 torch.mm runtime by lowering to matmul #2673

TobyRoseman merged 2 commits into apple:main
Conversation
Pull request overview
This PR fixes a Core ML runtime failure for int32 torch.mm by ensuring torch.mm/torch.bmm lower to MIL matmul instead of the constant-weight linear lowering path.
Changes:
- Update the Torch frontend `matmul` lowering to always emit `mb.matmul` (after dtype promotion) rather than conditionally using `mb.linear` for a constant RHS.
- Add a regression test verifying that an `int32` constant-weight `torch.mm` converts to a graph containing `matmul` and not `linear`, and (when runnable) matches the runtime output; a hedged sketch of such a test follows this list.
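A regression test along these lines might look roughly like the sketch below. The module name, the constant `WEIGHT`, and the use of the internal `_mil_program` attribute to inspect op types are illustrative assumptions, not the exact test added by this PR.

```python
import numpy as np
import torch
import coremltools as ct

# Constant int32 RHS, mirroring the constant-weight torch.mm case from the issue.
WEIGHT = torch.arange(6, dtype=torch.int32).reshape(3, 2)

class MMModel(torch.nn.Module):
    def forward(self, x):
        return torch.mm(x, WEIGHT)

x = torch.arange(12, dtype=torch.int32).reshape(4, 3)
traced = torch.jit.trace(MMModel().eval(), x)

mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=x.shape, dtype=np.int32)],
    convert_to="mlprogram",
)

# Inspect the converted MIL program (internal attribute; the exact inspection
# API may differ between coremltools versions).
op_types = [op.op_type for op in mlmodel._mil_program.functions["main"].operations]
assert "matmul" in op_types
assert "linear" not in op_types
```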
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| coremltools/converters/mil/frontend/torch/ops.py | Changes lowering for mm/bmm/matmul to always use mb.matmul to avoid buggy linear lowering with int32 constant weights. |
| coremltools/converters/mil/frontend/torch/test/test_torch_ops.py | Adds regression coverage to ensure converted graphs use matmul (not linear) for int32 torch.mm with a constant weight. |
 @register_torch_op(torch_alias=["bmm", "mm"])
 def matmul(context, node):
     x, y = _get_inputs(context, node, expected=2)
-    res = _construct_matmul(x, y, node.name)
+    x, y = promote_input_dtypes([x, y])
+    # Keep mm/bmm on the matmul path even when the RHS is constant. Lowering
+    # constant int32 weights to linear produces incorrect/runtime behavior.
+    res = mb.matmul(x=x, y=y, name=node.name)
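For contrast, a minimal sketch of the kind of conditional constant-weight lowering this change moves away from. The helper name and exact conditions are assumptions for illustration, not the repository's pre-change code; note that MIL's `linear` op computes `x @ weight.T`, so a constant RHS would have to be transposed.

```python
import numpy as np
from coremltools.converters.mil import Builder as mb

def _lower_matmul_conditionally(x, y, name):
    # Illustrative only: route a constant 2-D RHS through mb.linear,
    # otherwise fall back to mb.matmul.
    if y.val is not None and y.val.ndim == 2:
        # linear computes x @ weight.T, so the constant RHS is transposed here.
        return mb.linear(x=x, weight=np.transpose(y.val), name=name)
    return mb.matmul(x=x, y=y, name=name)
```

Keeping mm/bmm on the matmul path avoids this special case entirely, which is what the comment in the diff above refers to.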
@holly-agyei - there are CI failures, please take a look:

@TobyRoseman

Hi @TobyRoseman. Edit: I can see that it has been retried. Thanks.

I restarted those jobs and the CI passed. Thanks for the pull request @holly-agyei.
Fixes #2575
Summary
Testing