
PyTorch to CoreML model conversion with coremltools v4.0b1 for a TransformerEncoder model/module does not complete. Missing op: bmm. #768

Closed
leovinus2001 opened this issue Jul 8, 2020 · 10 comments
Assignees: DawerG
Labels: bug (Unexpected behaviour that should be corrected (type)), PyTorch (traced)

Comments

@leovinus2001
Contributor

leovinus2001 commented Jul 8, 2020

Relevance:

Transformers train faster and often deliver better models.
The coremltools v4.0 converter from PyTorch to the Core ML .mlmodel format should handle them well for embedded work.

Background:

  • The TransformerEncoder module was introduced in PyTorch 1.4, I believe.
  • The log below shows a coremltools v4.0b1 conversion crash with:
    RuntimeError: PyTorch convert function for op bmm not implemented
  • The same coremltools crash occurs with the pip-installed version 4.0b1 as well as with a top-of-tree (TOT) git pull from July 7th.
  • I am not entirely sure what bmm is, but I expect it is "batch matrix multiply".
  • MAYBE related to multi-head attention; see [DO NOT MERGE] Decompose MultiheadAttention module pytorch/pytorch#34793.
  • In other words, IFF "bmm == batch matrix multiply", THEN maybe you can unfold the batch and do the multiplications one by one via a normal matmul? (See the sketch after this list.)
  • The log was produced with a patch adding an extra print() in coremltools to identify the offending graph node (diff below).
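
A minimal sketch of what that unfolding means, using plain torch ops (the shapes here are made up purely for illustration):

import torch

a = torch.randn(4, 3, 5)   # (batch, n, m)
b = torch.randn(4, 5, 2)   # (batch, m, p)

# torch.bmm multiplies the two stacks of matrices pairwise along the batch dim.
batched = torch.bmm(a, b)

# The same result, computed one batch element at a time with an ordinary matmul.
unfolded = torch.stack([torch.matmul(a[i], b[i]) for i in range(a.shape[0])])

assert torch.allclose(batched, unfolded)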

diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 1e19c7c..be2b13d 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -47,6 +47,7 @@ def convert_nodes(context, graph):
     for node in _tqdm(graph.nodes, desc="Converting Frontend ==> MIL Ops", unit=" ops"):
         _add_op = _TORCH_OPS_REGISTRY.get(node.kind, None)
         _logging.info("Converting op {} : {}".format(node.name, node.kind))
+        print ("Converting op {} : {}".format(node.name, node.kind))
         if _add_op is None:
             raise RuntimeError(
                 "PyTorch convert function for op {} not implemented".format(node.kind)
@@ -955,7 +956,8 @@ def lstm(context, node):
 )

Reproducible:

Yes

Testcase:

Attached
testTransformerEncoder.txt
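
The attachment is not inlined here; as a rough sketch, a minimal reproduction along these lines hits the same error (layer sizes match the model dump in the log below, while nhead and the input shape are assumptions):

import torch
import torch.nn as nn
import coremltools as ct

class TestModel(nn.Module):
    def __init__(self, d_model=28, nhead=4, dim_feedforward=16, num_layers=2):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward)
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers)

    def forward(self, x):
        return self.transformer_encoder(x)

model = TestModel().eval()
dummy_input = torch.rand(30, 1, 28)   # (seq_len, batch, d_model)
traced = torch.jit.trace(model, dummy_input)
mlmodel = ct.convert(traced, inputs=[ct.TensorType(name="input1", shape=dummy_input.shape)])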

Setup:

macOS Catalina latest
Python 3.7.6

Log:

Torch version : 1.5.1
CoreML tools version : 4.0b1
TestModel(
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=28, out_features=28, bias=True)
        )
        (linear1): Linear(in_features=28, out_features=16, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=16, out_features=28, bias=True)
        (norm1): LayerNorm((28,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((28,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=28, out_features=28, bias=True)
        )
        (linear1): Linear(in_features=28, out_features=16, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=16, out_features=28, bias=True)
        (norm1): LayerNorm((28,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((28,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
)
Library/Python/3.7/lib/python/site-packages/torch/jit/__init__.py:1037: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[29, 0, 14] (1.3225914239883423 vs. -0.8373285531997681) and 838 other locations (99.88%)
check_tolerance, _force_outplace, True, _module_class)
Converting Frontend ==> MIL Ops: 0%| | 0/188 [00:00<?, ? ops/s]Converting op 25 : constant
Converting op 26 : size
Converting op 27 : constant
Converting op 28 : size
Converting op 29 : constant
Converting op 30 : size
Converting op embed_dim.1 : numtotensor
Converting op 32 : constant
Converting op head_dim.1 : floor_divide
Converting op 34 : int
Converting op 35 : int
Converting op 36 : int
Converting op 37 : t
Converting op output.1 : matmul
Converting op 39 : constant
Converting op 40 : add
Converting op 41 : constantchunk
Converting op 44 : constant
Converting op q.2 : mul
Converting op 46 : constant
Converting op q.3 : contiguous
Converting op 48 : listconstruct
Converting op 49 : view
Converting op 50 : constant
Converting op 51 : constant
Converting op q.4 : transpose
Converting op 53 : constant
Converting op 54 : contiguous
Converting op 55 : constant
Converting op 56 : listconstruct
Converting op 57 : view
Converting op 58 : constant
Converting op 59 : constant
Converting op k.2 : transpose
Converting op 61 : constant
Converting op 62 : contiguous
Converting op 63 : constant
Converting op 64 : listconstruct
Converting op 65 : view
Converting op 66 : constant
Converting op 67 : constant
Converting op v.2 : transpose
Converting op 69 : constant
Converting op 70 : constant
Converting op 71 : transpose
Converting op attn_output_weights.1 : bmm
Converting Frontend ==> MIL Ops: 24%|██████████████████████████████████████████████████████████▉ | 45/188 [00:00<00:00, 679.96 ops/s]
Traceback (most recent call last):
File "testTransformerEncoder.py", line 40, in
inputs= [ ct.TensorType(name="input1", shape=dummy_input.shape) ]
File "Library/Python/3.7/lib/python/site-packages/coremltools/converters/_converters_entry.py", line 299, in convert
**kwargs
File "Library/Python/3.7/lib/python/site-packages/coremltools/converters/mil/converter.py", line 120, in _convert
prog = frontend_converter(model, **kwargs)
File "Library/Python/3.7/lib/python/site-packages/coremltools/converters/mil/converter.py", line 62, in call
return load(*args, **kwargs)
File "Library/Python/3.7/lib/python/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 84, in load
raise e
File "Library/Python/3.7/lib/python/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 76, in load
prog = converter.convert()
File "Library/Python/3.7/lib/python/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 302, in convert
convert_nodes(self.context, self.graph)
File "Library/Python/3.7/lib/python/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 53, in convert_nodes
"PyTorch convert function for op {} not implemented".format(node.kind)
RuntimeError: PyTorch convert function for op bmm not implemented

@leovinus2001 leovinus2001 added the bug Unexpected behaviour that should be corrected (type) label Jul 8, 2020
@DawerG DawerG self-assigned this Jul 8, 2020
@DawerG
Collaborator

DawerG commented Jul 8, 2020

@leovinus2001 Yes, you are right, bmm stands for batched matrix multiplication. Core ML supports this layer. We will hook it up in the torch frontend as soon as possible.

Thanks for reporting the issue with such a detailed bug report. We appreciate your effort and help.

@leovinus2001
Contributor Author

You are welcome. Appreciate the fixes!

@leovinus2001
Contributor Author

Just politely wondering whether there is any ETA on this fix? Happy to verify it.

@DawerG
Collaborator

DawerG commented Aug 12, 2020

@leovinus2001 Can you please add the following snippet in the conversion code to check if it fixes the issue? Thanks.

from coremltools.converters.mil import register_torch_op
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
import coremltools as ct

@register_torch_op
def bmm(context, node):
    inputs = _get_inputs(context, node)
    res = mb.matmul(x=inputs[0], y=inputs[1], name=node.name)
    context.add(res)
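
For what it's worth, the only placement requirement is that this registration runs before ct.convert() is called on the traced model, e.g. near the top of the conversion script.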

@leovinus2001
Contributor Author

@DawerG Thanks, that works like a charm for me :) Conversion now runs successfully, and inference on the float32 and int8-quantized models looks good as well.

One question: for
inputs = _get_inputs(context, node)
wouldn't it be nicer to pass an "expected=" argument with the expected number of args?
Unless it varies, of course.

@DawerG
Collaborator

DawerG commented Aug 12, 2020

@leovinus2001 Thanks for verifying :) Appreciate your help.

Yes, right. That would be better. We should change the snippet to
inputs = _get_inputs(context, node, expected=2)

We will ensure that this change is covered as part of the next beta release.
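
For reference, the registration above with that check added would read (same sketch, only the expected argument changes):

@register_torch_op
def bmm(context, node):
    # Expect exactly the two matrix operands carried by the traced bmm node.
    inputs = _get_inputs(context, node, expected=2)
    res = mb.matmul(x=inputs[0], y=inputs[1], name=node.name)
    context.add(res)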

@leovinus2001
Contributor Author

Nice, coremltools 4.0b3 with PyTorch 1.6.0 fixes the missing "bmm" op.
I see you have added an alias
@register_torch_op(torch_alias=["bmm"])
for matmul in ops.py. A nice one-liner :) Thanks and closing.

@leovinus2001
Contributor Author

Closing

@NewCoderQ

@leovinus2001 Hi, did you get the same speed for the float32 and int8 models?

@den-run-ai

The same approach worked for the mm op, in case anyone is interested (sketch below).
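
A sketch of that, assuming the same pattern as the bmm snippet above (mm being the plain 2-D matrix multiply):

@register_torch_op
def mm(context, node):
    # The 2-D matrix multiply maps onto the same MIL matmul primitive.
    inputs = _get_inputs(context, node, expected=2)
    res = mb.matmul(x=inputs[0], y=inputs[1], name=node.name)
    context.add(res)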
