
[Inductor] [CPU] Crash failure in torchbench model mobilenet_v2_quantized_qat & resnet50_quantized_qat #93430

Closed
chuanqi129 opened this issue Nov 18, 2022 · 20 comments
Labels: bug, oncall: cpu inductor, oncall: pt2, triaged

@chuanqi129 (Collaborator) commented Nov 18, 2022

🐛 Describe the bug

This failure was found in the latest TorchInductor CPU Performance Dashboard refresh test, with the error log below.

SW info:

SW           Nightly commit   Master/Main commit
PyTorch      0662e90          e2f0648
Torchbench   /                022dfe3
torchaudio   4b10b6a          74f9a89
torchtext    71e4561          c047efe
torchvision  797e1ac          ffd5a56

For detailed info, refer to the Dashboard.

Error logs

ERROR:common:Failed for dynamo 

from user code:
   File "benchmarks/dynamo/torchbench.py", line 361, in forward_pass
    return mod(*inputs)
  File "<eval_with_key>.8", line 7, in forward
    quantize_per_tensor = torch.quantize_per_tensor(x, features_0_0_input_scale_0, features_0_0_input_zero_point_0, torch.quint8);  x = features_0_0_input_scale_0 = features_0_0_input_zero_point_0 = None

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
Traceback (most recent call last):
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 887, in __torch_dispatch__
    r = func(*args, **kwargs)
  File "/workspace/pytorch/torch/_ops.py", line 285, in __call__
    return self._op(*args, **kwargs or {})
  File "/workspace/pytorch/torch/_ops.py", line 367, in _get_dispatch
    final_key = resolve_key(self, key)
  File "/workspace/pytorch/torch/_ops.py", line 107, in resolve_key
    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
NotImplementedError: could not find kernel for aten.quantize_per_tensor.tensor_qparams at dispatch key DispatchKey.Meta
During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1076, in run_node
    return node.target(*args, **kwargs)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 892, in __torch_dispatch__
    return run_fallback_kernel(self, func, args, kwargs, not_implemented_error)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 1068, in run_fallback_kernel
    return tree_map(map_out, r)
  File "/workspace/pytorch/torch/utils/_pytree.py", line 195, in tree_map
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/workspace/pytorch/torch/utils/_pytree.py", line 195, in <listcomp>
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 1064, in map_out
    return fake_mode.fake_tensor_converter(fake_mode, e)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 262, in __call__
    return self.from_real_tensor(fake_mode, t, make_constant, shape_env=shape_env)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 214, in from_real_tensor
    raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
torch._subclasses.fake_tensor.UnsupportedFakeTensorException: quantized nyi in meta tensors

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1042, in get_fake_value
    return wrap_fake_exception(
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 721, in wrap_fake_exception
    return fn()
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1043, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1085, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <built-in method quantize_per_tensor of type object at 0x7fab96d3fd20>(*(FakeTensor(FakeTensor(..., device='meta', size=(96, 3, 224, 224)), cpu), FakeTensor(FakeTensor(..., device='meta', size=()), cpu), FakeTensor(FakeTensor(..., device='meta', size=(), dtype=torch.int64), cpu), torch.quint8), **{}):
quantized nyi in meta tensors
(scroll up for backtrace)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspace/pytorch/benchmarks/dynamo/common.py", line 1204, in warmup
    fn(model, example_inputs)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 169, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 247, in catch_errors
    return callback(frame, cache_size)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 476, in _convert_frame
    result = inner_convert(frame, cache_size)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 89, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 349, in _convert_frame_assert
    return _compile(
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 404, in _compile
    out_code = transform_code_object(code, transform)
  File "/workspace/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 392, in transform
    tracer.run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1612, in run
    super().run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 478, in run
    and self.step()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 448, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 282, in wrapper
    return inner_fn(self, inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 942, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 390, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspace/pytorch/torch/_dynamo/variables/nn_module.py", line 222, in call_function
    return tx.inline_user_function_return(
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 419, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1684, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1738, in inline_call_
    tracer.run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 478, in run
    and self.step()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 448, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 282, in wrapper
    return inner_fn(self, inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 905, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 390, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspace/pytorch/torch/_dynamo/variables/torch.py", line 406, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/workspace/pytorch/torch/_dynamo/variables/builder.py", line 636, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/workspace/pytorch/torch/_dynamo/variables/builder.py", line 676, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1055, in get_fake_value
    raise TorchRuntimeError() from e
torch._dynamo.exc.TorchRuntimeError: 

from user code:
   File "benchmarks/dynamo/torchbench.py", line 361, in forward_pass
    return mod(*inputs)
  File "<eval_with_key>.8", line 7, in forward
    quantize_per_tensor = torch.quantize_per_tensor(x, features_0_0_input_scale_0, features_0_0_input_zero_point_0, torch.quint8);  x = features_0_0_input_scale_0 = features_0_0_input_zero_point_0 = None

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

Minified repro

python benchmarks/dynamo/torchbench.py --performance --float32 -dcpu -n50 --inductor --no-skip --dashboard --only=mobilenet_v2_quantized_qat
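
For reference, a minimal standalone sketch of the failing pattern, assuming the error reproduces outside the torchbench harness. The shapes and qparams mirror the FakeTensor repr in the log above; the function and variable names are hypothetical, and torch.compile stands in for the benchmark's dynamo/inductor entry point.

import torch

def fn(x, scale, zero_point):
    # Quantized outputs are not representable as fake/meta tensors, which is
    # what trips the UnsupportedFakeTensorException in the traceback above.
    return torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)

x = torch.randn(96, 3, 224, 224)
scale = torch.tensor(0.1)                        # tensor-valued qparams select
zero_point = torch.tensor(0, dtype=torch.int64)  # the .tensor_qparams overload

compiled = torch.compile(fn, backend="inductor")
compiled(x, scale, zero_point)  # expected to raise on affected builds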

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @soumith @wconstab @ngimel

@chuanqi129 chuanqi129 added the bug label Nov 18, 2022
@eellison (Contributor)

I'm not sure we support quantized models currently.

@malfet malfet transferred this issue from pytorch/torchdynamo Feb 1, 2023
@ezyang (Contributor) commented Feb 1, 2023

This is still failing, though differently:

convolution forward propagation primitive
Traceback (most recent call last):
  File "/data/users/ezyang/a/pytorch/benchmarks/dynamo/common.py", line 1350, in warmup
    fn(model, example_inputs)
  File "/data/users/ezyang/a/pytorch/benchmarks/dynamo/torchbench.py", line 361, in forward_pass
    return mod(*inputs)
  File "/data/users/ezyang/a/pytorch/torch/fx/graph_module.py", line 660, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/data/users/ezyang/a/pytorch/torch/fx/graph_module.py", line 279, in __call__
    raise e
  File "/data/users/ezyang/a/pytorch/torch/fx/graph_module.py", line 269, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/data/users/ezyang/a/pytorch/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.8", line 15, in forward
    features_2_conv_1_0 = getattr(getattr(getattr(self.features, "2").conv, "1"), "0")(features_2_conv_0_2);  features_2_conv_0_2 = None
  File "/data/users/ezyang/a/pytorch/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/ezyang/a/pytorch/torch/ao/nn/quantized/modules/conv.py", line 465, in forward
    return ops.quantized.conv2d(
  File "/data/users/ezyang/a/pytorch/torch/_ops.py", line 499, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: could not create a descriptor for a dilated convolution forward propagation primitive

@jgong5 (Collaborator) commented Feb 2, 2023

> This is still failing, though differently

This change in the error message is related to the change of the default qengine to "x86". Originally the code went to FBGEMM, but now it goes to oneDNN, which throws this error. We need to check why oneDNN throws the error here.
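
A quick way to test this hypothesis, assuming an FBGEMM-enabled build (a sketch, not part of the original report): switch the qengine back and re-run the model. If the dilated-convolution error disappears, the regression is isolated to the oneDNN kernels.

import torch

# Which engines are available depends on how PyTorch was built.
print(torch.backends.quantized.supported_engines)
print(torch.backends.quantized.engine)  # 'x86' is the default on recent builds

# Temporarily force the old FBGEMM path before running the model.
torch.backends.quantized.engine = "fbgemm"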

@Xia-Weiwen (Collaborator)

Hi @ezyang @jgong5 The error RuntimeError: could not create a descriptor for a dilated convolution forward propagation primitive should have been fixed by eea752f. Could you please check?

@Xia-Weiwen (Collaborator)

@ezyang For the original error NotImplementedError: could not find kernel for aten.quantize_per_tensor.tensor_qparams at dispatch key DispatchKey.Meta, could you please continue investigating? Thanks!

@Xia-Weiwen (Collaborator)

I think the problem is that Inductor does not support quantization right now. If we add a Meta kernel for aten.quantize_per_tensor.tensor_qparams, the original error goes away, but new quantization-related errors occur.
Actually, we (the Intel PyTorch team) have been working on quantization on CPU with Inductor recently. It is still in an early PoC phase. We will update later when quantization is available in Inductor.

@jgong5 (Collaborator) commented Feb 3, 2023

> Hi @ezyang @jgong5 The error RuntimeError: could not create a descriptor for a dilated convolution forward propagation primitive should have been fixed by eea752f. Could you please check?

Thanks. I confirmed that the new issue with RuntimeError: could not create a descriptor for a dilated convolution forward propagation primitive has been fixed on the latest master.

@jgong5 (Collaborator) commented Feb 3, 2023

> I think the problem is that Inductor does not support quantization right now. If we add a Meta kernel for aten.quantize_per_tensor.tensor_qparams, the original error goes away, but new quantization-related errors occur. Actually, we (the Intel PyTorch team) have been working on quantization on CPU with Inductor recently. It is still in an early PoC phase. We will update later when quantization is available in Inductor.

May I know what error you encountered? I think Inductor shouldn't crash on the quantized model even though it lacks the optimization right now.

@Xia-Weiwen (Collaborator)

> May I know what error you encountered? I think Inductor shouldn't crash on the quantized model even though it lacks the optimization right now.

If I register aten.quantize_per_tensor.tensor_qparams on Meta so that it produces quint8 output, the error is NotImplementedError: Could not run 'aten::empty_strided' with arguments from the 'QuantizedMeta' backend.
If aten.quantize_per_tensor.tensor_qparams produces fp32 output instead, the error is NotImplementedError: could not find kernel for quantized.conv2d.new at dispatch key DispatchKey.Meta.
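
For context, the kind of Meta registration being described looks roughly like the sketch below. This is illustrative only, not a fix: as noted above, a quint8 Meta output is exactly what the fake-tensor machinery cannot represent, so both variants dead-end.

import torch

# Register an IMPL for the Meta dispatch key on the existing aten op.
lib = torch.library.Library("aten", "IMPL", "Meta")

def quantize_per_tensor_meta(x, scale, zero_point, dtype):
    # Shape propagation only. Returning dtype=torch.quint8 here leads to the
    # 'QuantizedMeta' empty_strided error; returning float32 instead just
    # pushes the failure to the next quantized op (quantized.conv2d.new).
    return torch.empty_like(x, dtype=dtype)

lib.impl("quantize_per_tensor.tensor_qparams", quantize_per_tensor_meta)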

@albanD albanD added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Feb 7, 2023
@Xia-Weiwen (Collaborator)

The problem is that quantization is not yet implemented in this case. There are probably two ways to fix this problem:

  1. Enable quantization.
     This is actually a new feature. We (the Intel PyTorch team) are working on it.
  2. Use the fallback path.
     However, the fallback path is not yet implemented for quantized tensors. There are guards in FakeTensorConverter and MetaConverter that raise an unimplemented exception if the input tensor is quantized (a minimal demonstration follows below). This is due to this issue: No factory functions for strided quantized tensors #74540.

So I think we probably have to go the first way. It will still take a while before those quantization PRs are landed.
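
The guard in question can be seen directly, assuming a build from this era (a minimal sketch; the exception matches the tracebacks above).

import torch
from torch._subclasses import FakeTensorMode

# A real quantized tensor that the fallback kernel would need to wrap.
qx = torch.quantize_per_tensor(torch.randn(2, 2), 0.1, 0, torch.quint8)

with FakeTensorMode() as mode:
    # from_real_tensor refuses quantized inputs, so the fallback path cannot
    # convert the op's real output back into a FakeTensor.
    fake = mode.from_tensor(qx)  # raises UnsupportedFakeTensorException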

@ydwu4 (Contributor) commented Nov 29, 2023

Hi @Xia-Weiwen, is this issue still valid? Do we want to keep it open?

@Xia-Weiwen (Collaborator)

Hi @chuanqi129. Are you still seeing the same issue with the latest PyTorch? If not, we may close this issue. Thanks.

@chuanqi129 (Collaborator, Author)

Hi @Xia-Weiwen, according to the latest test results, the accuracy test and the perf test still crash with the messages below.
Perf test:

ERROR:common:Backend dynamo failed in warmup()
Traceback (most recent call last):
  File "/workspace/pytorch/benchmarks/dynamo/common.py", line 2383, in warmup
    fn(model, example_inputs)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 721, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 645, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/workspace/pytorch/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 2123, in run
    super().run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspace/pytorch/torch/_dynamo/variables/lazy.py", line 90, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/variables/nn_module.py", line 328, in call_function
    return tx.inline_user_function_return(
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 2256, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 2371, in inline_call_
    tracer.run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspace/pytorch/torch/_dynamo/variables/torch.py", line 599, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/workspace/pytorch/torch/_dynamo/variables/builder.py", line 1283, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/variables/builder.py", line 1368, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1524, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1485, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1026, in wrap_fake_exception
    return fn()
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1486, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1591, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1570, in run_node
    return node.target(*args, **kwargs)
  File "/workspace/pytorch/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 1705, in dispatch
    return maybe_run_unsafe_fallback(not_implemented_error)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 1689, in maybe_run_unsafe_fallback
    return run_fallback_kernel(self, func, flat_args, args_spec, error)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 1953, in run_fallback_kernel
    return pytree.tree_map(map_out, r)
  File "/workspace/pytorch/torch/utils/_pytree.py", line 437, in tree_map
    return tree_unflatten([func(i) for i in flat_args], spec)
  File "/workspace/pytorch/torch/utils/_pytree.py", line 437, in <listcomp>
    return tree_unflatten([func(i) for i in flat_args], spec)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 1949, in map_out
    return fake_mode.fake_tensor_converter(fake_mode, e)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 396, in __call__
    return self.from_real_tensor(
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 333, in from_real_tensor
    raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method quantize_per_tensor of type object at 0x7f7c2b09cf00>(*(FakeTensor(..., size=(96, 3, 224, 224)), FakeTensor(..., size=()), FakeTensor(..., size=(), dtype=torch.int64), torch.quint8), **{}):
quantized nyi in meta tensors

Acc test:

WARNING:root:mobilenet_v2_quantized_qat failed to load
Original Error: quantized_resize_cpu_ does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation, or you can use the 'warn_only=True' option, if that's acceptable for your application. You can also file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation.
Eager model failed to run
Traceback (most recent call last):
  File "/workspace/pytorch/benchmarks/dynamo/common.py", line 1926, in validate_model
    self.model_iter_fn(model, example_inputs)
  File "benchmarks/dynamo/torchbench.py", line 529, in forward_pass
    return mod(*inputs)
  File "/workspace/pytorch/torch/fx/graph_module.py", line 736, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/workspace/pytorch/torch/fx/graph_module.py", line 315, in __call__
    raise e
  File "/workspace/pytorch/torch/fx/graph_module.py", line 302, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/workspace/pytorch/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/pytorch/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.6", line 128, in forward
    classifier_1 = getattr(self.classifier, "1")(classifier_0);  classifier_0 = None
  File "/workspace/pytorch/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/pytorch/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/pytorch/torch/ao/nn/quantized/modules/linear.py", line 168, in forward
    return torch.ops.quantized.linear(
  File "/workspace/pytorch/torch/_ops.py", line 753, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: quantized_resize_cpu_ does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation, or you can use the 'warn_only=True' option, if that's acceptable for your application. You can also file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation.
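
The accuracy-side failure is a separate determinism issue: the harness enables strict determinism, and quantized_resize_cpu_ has no deterministic implementation. The error message itself points at the workaround (a sketch, assuming a warning instead of an error is acceptable for the harness):

import torch

# Warn instead of raising when an op (here quantized_resize_cpu_) lacks a
# deterministic implementation.
torch.use_deterministic_algorithms(True, warn_only=True)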

@Xia-Weiwen (Collaborator)

Hi @leslie-fang-intel Since you've been working on it, could you please help check this issue? Is it a bug or a missing feature? Thanks!

@leslie-fang-intel (Collaborator) commented Dec 4, 2023

I think we just need to enable mobilenet_v2 & resnet50 with the PT2 QAT flow in torchbench, since that flow is already enabled.
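
For readers unfamiliar with it, the PT2 QAT flow referred to here looks roughly like the sketch below, using PyTorch 2.2-era APIs. The module paths and signatures are assumptions (they have moved between releases), and the tiny model is a hypothetical stand-in for mobilenet_v2/resnet50.

import torch
import torch.nn as nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())  # stand-in model
example_inputs = (torch.randn(1, 3, 224, 224),)

# Capture an FX graph, insert fake-quant observers, train, then convert.
exported = capture_pre_autograd_graph(model, example_inputs)
quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config(is_qat=True))
prepared = prepare_qat_pt2e(exported, quantizer)
# ... run QAT fine-tuning on `prepared` here ...
converted = convert_pt2e(prepared)
optimized = torch.compile(converted)  # lowered through Inductor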

@chuanqi129 (Collaborator, Author)

Hi @leslie-fang-intel, the models mobilenet_v2_quantized_qat & resnet50_quantized_qat are different models from mobilenet_v2 & resnet50. Do you mean those are legacy models for PT2? https://github.com/pytorch/benchmark/blob/main/torchbenchmark/models/mobilenet_v2_quantized_qat/__init__.py

@leslie-fang-intel (Collaborator)

Oh, thanks for the correction, @chuanqi129. Then it's actually a different flow.

@WeizhuoZhang-intel (Contributor)

In the release/2.2 branch commit (44d1157), those two models' accuracy tests still report errors with AMP + static shapes + the default wrapper.

Traceback (most recent call last):
  File "/workspace/pytorch/benchmarks/dynamo/common.py", line 2395, in warmup
    fn(model, example_inputs)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/workspace/pytorch/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspace/pytorch/torch/_dynamo/variables/lazy.py", line 90, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/variables/nn_module.py", line 328, in call_function
    return tx.inline_user_function_return(
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
    tracer.run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspace/pytorch/torch/_dynamo/variables/torch.py", line 542, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/workspace/pytorch/torch/_dynamo/variables/builder.py", line 1314, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/variables/builder.py", line 1399, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1525, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1486, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1027, in wrap_fake_exception
    return fn()
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1487, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1592, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1571, in run_node
    return node.target(*args, **kwargs)
  File "/workspace/pytorch/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 1714, in dispatch
    return maybe_run_unsafe_fallback(not_implemented_error)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 1698, in maybe_run_unsafe_fallback
    return run_fallback_kernel(self, func, flat_args, args_spec, error)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 1962, in run_fallback_kernel
    return pytree.tree_map(map_out, r)
  File "/workspace/pytorch/torch/utils/_pytree.py", line 602, in tree_map
    return tree_unflatten([func(i) for i in flat_args], spec)
  File "/workspace/pytorch/torch/utils/_pytree.py", line 602, in <listcomp>
    return tree_unflatten([func(i) for i in flat_args], spec)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 1958, in map_out
    return fake_mode.fake_tensor_converter(fake_mode, e)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 396, in __call__
    return self.from_real_tensor(
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 333, in from_real_tensor
    raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method quantize_per_tensor of type object at 0x7f9d4d9f7340>(*(FakeTensor(..., size=(96, 3, 224, 224)), FakeTensor(..., size=()), FakeTensor(..., size=(), dtype=torch.int64), torch.quint8), **{}):
quantized nyi in meta tensors

from user code:
   File "benchmarks/dynamo/torchbench.py", line 532, in forward_pass
    return mod(*inputs)
  File "<eval_with_key>.8", line 7, in forward
    quantize_per_tensor = torch.quantize_per_tensor(x, features_0_0_input_scale_0, features_0_0_input_zero_point_0, torch.quint8);  x = features_0_0_input_scale_0 = features_0_0_input_zero_point_0 = None

@penguinwu (Member)

Is this still valid? If not, please close.

@penguinwu penguinwu added the oncall: cpu inductor CPU Inductor issues for Intel team to triage label Mar 22, 2024
@chuanqi129 (Collaborator, Author)

Double-checked in the latest TorchInductor CPU Performance Dashboard: this issue has been fixed.
