Fix int32 torch.mm runtime by lowering to matmul#2673

Merged
TobyRoseman merged 2 commits into apple:main from holly-agyei:fix-int32-torch-mm-runtime
Apr 28, 2026
Conversation

@holly-agyei

Fixes #2575

Summary

  • lower torch.mm/torch.bmm through mb.matmul instead of the constant-weight mb.linear path
  • keep the int32 constant-weight runtime path off the buggy linear lowering
  • add a regression test that checks the converted graph uses matmul rather than linear for int32 torch.mm
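The case this PR targets can be reproduced in plain PyTorch: a `torch.mm` whose right-hand side is a constant int32 tensor, which previously routed through the `mb.linear` lowering. A minimal sketch (the module and values below are assumed for illustration, not taken from the test suite):

```python
import torch

class Int32MM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Constant int32 RHS: this is the shape of graph that previously
        # hit the buggy constant-weight linear lowering.
        self.weight = torch.tensor([[5, 6], [7, 8]], dtype=torch.int32)

    def forward(self, x):
        return torch.mm(x, self.weight)

model = Int32MM().eval()
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.int32)
out = model(x)
print(out)  # tensor([[19, 22], [43, 50]], dtype=torch.int32)
```

In PyTorch itself the int32 matmul is well defined; the failure described in #2575 only appeared after conversion, at Core ML runtime.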

Testing

  • /opt/miniconda3/envs/adamed/bin/python -m py_compile coremltools/converters/mil/frontend/torch/ops.py coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
  • /opt/miniconda3/envs/adamed/bin/python -m pytest coremltools/converters/mil/frontend/torch/test/test_torch_ops.py -k "mm_with_int32_constant_weight" -q
    • local result: 1 passed, 1 skipped
    • the mlprogram case is skipped in this source checkout when BlobWriter is not available locally

Copilot AI review requested due to automatic review settings April 20, 2026 19:32

Copilot AI left a comment


Pull request overview

This PR fixes a Core ML runtime failure for int32 torch.mm by ensuring torch.mm/torch.bmm lower to MIL matmul instead of the constant-weight linear lowering path.

Changes:

  • Update Torch frontend matmul lowering to always emit mb.matmul (after dtype promotion) rather than conditionally using mb.linear for constant RHS.
  • Add a regression test verifying int32 constant-weight torch.mm converts to a graph containing matmul and not linear, and (when runnable) matches runtime output.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

  • coremltools/converters/mil/frontend/torch/ops.py: changes lowering for mm/bmm/matmul to always use mb.matmul, avoiding the buggy linear lowering with int32 constant weights.
  • coremltools/converters/mil/frontend/torch/test/test_torch_ops.py: adds regression coverage ensuring converted graphs use matmul (not linear) for int32 torch.mm with a constant weight.


Comment on lines +1038 to +1044
@register_torch_op(torch_alias=["bmm", "mm"])
def matmul(context, node):
    x, y = _get_inputs(context, node, expected=2)
    x, y = promote_input_dtypes([x, y])
    # Keep mm/bmm on the matmul path even when the RHS is constant. Lowering
    # constant int32 weights to linear produces incorrect runtime behavior.
    res = mb.matmul(x=x, y=y, name=node.name)
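For context on the promote_input_dtypes call in the snippet above: mixed-dtype operands must be brought to a common type before emitting mb.matmul, analogous to PyTorch's own promotion rules. A small illustration (tensors assumed for the example):

```python
import torch

a = torch.tensor([[1, 2]], dtype=torch.int32)
b = torch.tensor([[3.0], [4.0]], dtype=torch.float32)

# torch.mm requires both operands to share a dtype; promote_types picks
# the common type, here int32 + float32 -> float32.
common = torch.promote_types(a.dtype, b.dtype)
print(common)                                 # torch.float32
print(torch.mm(a.to(common), b.to(common)))   # tensor([[11.]])
```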

@TobyRoseman
Collaborator

@holly-agyei - there are CI failures, please take a look:
https://gitlab.com/coremltools1/coremltools/-/pipelines/2472453197

@holly-agyei
Author

holly-agyei commented Apr 23, 2026

@TobyRoseman
Thank you for the notice. It seems I left out the _construct_matmul compatibility wrapper, which broke the import tests. I've pushed a fix to this PR to restore it. Could you take another look when you have a moment?


@holly-agyei
Author

holly-agyei commented Apr 27, 2026

Hi @TobyRoseman
It seems the only failing job (test_py310_pytorch_executorch) died in scripts/env_create.sh before any tests ran, with [Errno 2] No such file or directory: '/Users/gitlab/miniforge3/pkgs/packaging-26.2-pyhc364b38_0' during conda create. Looks like a stale conda package cache on the runner. Could you retry it when you have a moment? Thanks!

Edit: I can see that it has been retried. Thanks.

@TobyRoseman
Collaborator

I restarted those jobs and the CI passed. Thanks for the pull request @holly-agyei.

@TobyRoseman TobyRoseman merged commit e95804f into apple:main Apr 28, 2026


Development

Successfully merging this pull request may close these issues.

CoreML fails at runtime with int32 torch.mm

3 participants