
[Inductor] [CPU] accuracy failure in torchbench model detectron2_fcos_r_50_fpn #93426

Closed
chuanqi129 opened this issue Nov 17, 2022 · 3 comments
Labels: bug, oncall: pt2, triaged


chuanqi129 commented Nov 17, 2022

🐛 Describe the bug

This failure was found in the latest TorchInductor CPU Performance Dashboard refresh test (#93531).

SW information:

| SW | Nightly commit | Master/Main commit |
|---|---|---|
| PyTorch | 637228b | 46796fe |
| Torchbench | / | 022dfe3 |
| torchaudio | 4b10b6a | 74f9a89 |
| torchtext | 71e4561 | c047efe |
| torchvision | 797e1ac | ffd5a56 |

For details, refer to the Dashboard.

Error logs

RuntimeError: Storage size calculation overflowed with sizes=[2, 3, 140639896720224, 140639896720384]
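The two trailing sizes are implausible for an image batch. A quick pure-Python check (no PyTorch needed) shows that, assuming 4-byte float32 elements, the implied byte count is far beyond the signed 64-bit range. It also shows that both values fall just below 2**47, the range where x86-64 Linux user-space addresses usually live. That pattern hints, though nothing in this thread confirms it, that the shape was read from uninitialized or stale memory rather than computed.

```python
import math

# Sizes copied verbatim from the RuntimeError above.
sizes = [2, 3, 140639896720224, 140639896720384]
ITEMSIZE = 4  # assumption: float32 elements, matching --float32 in the repro
INT64_MAX = 2**63 - 1

# For a contiguous tensor the storage needs roughly numel * itemsize bytes.
numel = math.prod(sizes)
nbytes = numel * ITEMSIZE
print(f"numel={numel:.3e}, nbytes={nbytes:.3e}, int64 max={INT64_MAX:.3e}")
print("overflows int64:", nbytes > INT64_MAX)

# Both large dimensions sit just below 2**47, the usual upper bound of
# user-space addresses on x86-64 Linux.
for s in sizes[2:]:
    print(hex(s), 2**46 < s < 2**47)
```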

Minified repro

python benchmarks/dynamo/torchbench.py --accuracy --float32 -dcpu -n50 --inductor --no-skip --dashboard --only=detectron2_fcos_r_50_fpn --batch_size 2

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh

@chuanqi129 added the bug label Nov 17, 2022
@eellison changed the title [Inductor] accuracy failure in torchbench model detectron2_fcos_r_50_fpn [Inductor] [CPU] accuracy failure in torchbench model detectron2_fcos_r_50_fpn Nov 17, 2022
@chuanqi129 (Collaborator, Author) commented:

Update from the latest TorchInductor CPU Performance Dashboard refresh test, which now fails with the error log below.

SW info:

| SW | Nightly commit | Master/Main commit |
|---|---|---|
| PyTorch | 0662e90 | e2f0648 |
| Torchbench | / | 022dfe3 |
| torchaudio | 4b10b6a | 74f9a89 |
| torchtext | 71e4561 | c047efe |
| torchvision | 797e1ac | ffd5a56 |
ERROR:common:Failed for dynamo 

from user code:
   File "/opt/conda/lib/python3.8/site-packages/detectron2/modeling/meta_arch/dense_detector.py", line 126, in preprocess_image
    images = ImageList.from_tensors(images, self.backbone.size_divisibility)
  File "/opt/conda/lib/python3.8/site-packages/detectron2/structures/image_list.py", line 102, in from_tensors
    batched_imgs = F.pad(tensors[0], padding_size, value=pad_value).unsqueeze_(0)

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
Traceback (most recent call last):
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1076, in run_node
    return node.target(*args, **kwargs)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 879, in __torch_dispatch__
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/workspace/pytorch/torch/_subclasses/fake_tensor.py", line 366, in data_dep
    raise DataDependentOutputException(func)
torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1042, in get_fake_value
    return wrap_fake_exception(
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 721, in wrap_fake_exception
    return fn()
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1043, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1085, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <built-in function pad>(*(FakeTensor(FakeTensor(..., device='meta', size=(3, 800, 1199)), cpu), [0, FakeTensor(FakeTensor(..., device='meta', size=(), dtype=torch.int64), cpu), 0, FakeTensor(FakeTensor(..., device='meta', size=(), dtype=torch.int64), cpu)]), **{'value': 0.0}):
aten._local_scalar_dense.default
(scroll up for backtrace)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspace/pytorch/benchmarks/dynamo/common.py", line 1204, in warmup
    fn(model, example_inputs)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 169, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 247, in catch_errors
    return callback(frame, cache_size)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 476, in _convert_frame
    result = inner_convert(frame, cache_size)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 89, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 349, in _convert_frame_assert
    return _compile(
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 404, in _compile
    out_code = transform_code_object(code, transform)
  File "/workspace/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/workspace/pytorch/torch/_dynamo/convert_frame.py", line 392, in transform
    tracer.run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1612, in run
    super().run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 478, in run
    and self.step()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 448, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 282, in wrapper
    return inner_fn(self, inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 942, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 390, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspace/pytorch/torch/_dynamo/variables/nn_module.py", line 222, in call_function
    return tx.inline_user_function_return(
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 419, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1684, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1738, in inline_call_
    tracer.run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 478, in run
    and self.step()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 448, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 282, in wrapper
    return inner_fn(self, inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 905, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 390, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspace/pytorch/torch/_dynamo/variables/functions.py", line 224, in call_function
    return super().call_function(tx, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/variables/functions.py", line 194, in call_function
    return super(UserFunctionVariable, self).call_function(tx, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/variables/functions.py", line 65, in call_function
    return tx.inline_user_function_return(
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 419, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1684, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1738, in inline_call_
    tracer.run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 478, in run
    and self.step()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 448, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 282, in wrapper
    return inner_fn(self, inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 905, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 390, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspace/pytorch/torch/_dynamo/variables/functions.py", line 194, in call_function
    return super(UserFunctionVariable, self).call_function(tx, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/variables/functions.py", line 65, in call_function
    return tx.inline_user_function_return(
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 419, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1684, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 1738, in inline_call_
    tracer.run()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 478, in run
    and self.step()
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 448, in step
    getattr(self, inst.opname)(inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 282, in wrapper
    return inner_fn(self, inst)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 954, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/workspace/pytorch/torch/_dynamo/symbolic_convert.py", line 390, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/workspace/pytorch/torch/_dynamo/variables/torch.py", line 406, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/workspace/pytorch/torch/_dynamo/variables/builder.py", line 636, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/workspace/pytorch/torch/_dynamo/variables/builder.py", line 676, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 1055, in get_fake_value
    raise TorchRuntimeError() from e
torch._dynamo.exc.TorchRuntimeError: 

from user code:
   File "/opt/conda/lib/python3.8/site-packages/detectron2/modeling/meta_arch/dense_detector.py", line 126, in preprocess_image
    images = ImageList.from_tensors(images, self.backbone.size_divisibility)
  File "/opt/conda/lib/python3.8/site-packages/detectron2/structures/image_list.py", line 102, in from_tensors
    batched_imgs = F.pad(tensors[0], padding_size, value=pad_value).unsqueeze_(0)

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
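In the traceback, `F.pad` receives a padding list whose non-zero entries are 0-dim int64 FakeTensors rather than Python ints. Turning those into ints requires `aten._local_scalar_dense` (the op behind `Tensor.item()`), and a fake tensor cannot evaluate it because the result depends on real data, hence the `DataDependentOutputException`. The sketch below is a hypothetical, simplified version of the padding arithmetic done entirely with plain ints. `pad_amounts` is an illustrative helper, not detectron2's actual code, and `size_divisibility=32` is an assumed value, not the model's confirmed setting.

```python
def pad_amounts(image_hw, max_hw, size_divisibility):
    """Hypothetical helper: return an F.pad-style list [left, right, top, bottom] of ints."""
    if size_divisibility > 1:
        stride = size_divisibility
        # Round each spatial dimension up to the next multiple of stride.
        max_hw = [(d + stride - 1) // stride * stride for d in max_hw]
    h, w = image_hw
    max_h, max_w = max_hw
    # F.pad pads the last dim first: (left, right) for width, then (top, bottom) for height.
    return [0, max_w - w, 0, max_h - h]


# Spatial shape of tensors[0] from the traceback: (C, H, W) = (3, 800, 1199).
# Assumption: a single image, so the max size is the image's own size.
padding = pad_amounts((800, 1199), (800, 1199), size_divisibility=32)
print(padding)
```

Because every entry is an ordinary Python int, nothing here needs a data-dependent `.item()` call; in the failing trace the same values were tensors, which is what forced the conversion.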

@malfet transferred this issue from pytorch/torchdynamo Feb 1, 2023
@albanD added the triaged label Feb 7, 2023
blzheng added a commit that referenced this issue Feb 9, 2023
Fix #93426 (comment)

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]

blzheng added four more commits with the same message that referenced this issue (one on Feb 10, 2023 and three on Feb 13, 2023).

ydwu4 commented Nov 29, 2023

Hi @chuanqi129, is this issue still valid? Do we want to keep it open?

@chuanqi129 (Collaborator, Author) commented:

Thanks @ydwu4 for the reminder; this issue has been fixed. Closing it.
