
torch.compile with inductor failing: no ‘operator!=’ for ‘at::vec::CPU_CAPABILITY::VectorizedN<long int, 2>’ and ‘int’ #126619

Closed
johnnv1 opened this issue May 18, 2024 · 3 comments
Labels: oncall: cpu inductor (CPU Inductor issues for Intel team to triage), oncall: pt2

Comments

johnnv1 commented May 18, 2024

🐛 Describe the bug

When running the kornia test suite against torch 2.3.0, we got some new errors around dynamo (CI in kornia/kornia#2912). The kornia.filters.InRange operator fails to run on CPU.

It worked fine with the previous torch version, and it still works on CUDA, but not on CPU. Eager mode is fine; it fails only under the inductor backend.

Code snippet:

import logging

import torch

from kornia.filters import InRange

torch._logging.set_logs(dynamo=logging.DEBUG)
torch._dynamo.config.verbose = True


device = torch.device("cpu")
dtype = torch.float32
batch_size = 1

inpt = torch.rand(batch_size, 3, 5, 5, device=device, dtype=dtype)
op = InRange(lower=(0.2, 0.2, 0.2), upper=(0.6, 0.6, 0.6), return_mask=True)

op_optimized = torch.compile(op, backend="inductor")

op_optimized(inpt)

Error logs

V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] TRACED GRAPH TENSOR SIZES
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] ===== __compiled_fn_0 =====
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] input_1: (1, 3, 5, 5)
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] tensor: (3,)
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] reshape: (1, 3, 1, 1)
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] lower: (1, 3, 1, 1)
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] tensor_1: (3,)
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] reshape_1: (1, 3, 1, 1)
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] upper: (1, 3, 1, 1)
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] ge: (1, 3, 5, 5)
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] le: (1, 3, 5, 5)
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] mask: (1, 3, 5, 5)
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] all_1: (1, 1, 5, 5)
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] output: (1, 1, 5, 5)
V0518 11:38:06.062000 137524814369856 torch/_dynamo/output_graph.py:1154] [0/0] [__graph_sizes] 
I0518 11:38:06.063000 137524814369856 torch/_dynamo/logging.py:55] [0/0] Step 2: calling compiler function inductor
V0518 11:38:06.404000 137524814369856 torch/fx/experimental/symbolic_shapes.py:4119] [0/0] eval True == True [statically known]
Traceback (most recent call last):
  File "/tmp/kornia/t.py", line 20, in <module>
    op_optimized(inpt)
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 786, in _convert_frame
    result = inner_convert(
             ^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
    super().run()
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
        ^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2268, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 971, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/tmp/kornia/venv/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1168, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1241, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1222, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/__init__.py", line 1729, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1330, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 58, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 903, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 628, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 443, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 648, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 119, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1257, in fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/debug.py", line 304, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 438, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
                     ^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 714, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1307, in compile_to_fn
    return self.compile_to_module().call
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1254, in compile_to_module
    mod = PyCodeCache.load_by_key_path(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 2160, in load_by_key_path
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_ghost/4l/c4lgprtfzlvjqpoisayik6diql5lvpqhl6axihqxinppsw3ovona.py", line 104, in <module>
    async_compile.wait(globals())
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 2715, in wait
    scope[key] = result.result()
                 ^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/tmp/kornia/venv/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 2074, in load_pybinding
    result = cls.load(source_code + suffix, cuda)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1948, in load
    compile_file(input_path, output_path, cmd)
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/kornia/venv/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1888, in compile_file
    raise exc.CppCompileError(cmd, output) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
CppCompileError: C++ compile error

Command:
g++ /tmp/torchinductor_ghost/vd/cvdqhyulsvxn7ax2t2plarzirnzrqsstxph3wmdzwrs4kaeeffju.cpp -shared -fPIC -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -D_GLIBCXX_USE_CXX11_ABI=0 -I/tmp/kornia/venv/lib/python3.11/site-packages/torch/include -I/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -I/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/TH -I/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/THC -I/tmp/kornia/venv/include/python3.11 -L/tmp/kornia/venv/lib/python3.11/site-packages/torch/lib -L/tmp/kornia/venv/lib -L/tmp/kornia/venv/lib/python3.11/site-packages/torch/lib -ltorch -ltorch_cpu -lgomp -ltorch_python -lc10 -mavx2 -mfma -DCPU_CAPABILITY_AVX2 -O3 -DNDEBUG -ffast-math -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -march=native -fopenmp -D C10_USING_CUSTOM_GENERATED_MACROS -o /tmp/torchinductor_ghost/vd/cvdqhyulsvxn7ax2t2plarzirnzrqsstxph3wmdzwrs4kaeeffju.so

Output:
/tmp/torchinductor_ghost/vd/cvdqhyulsvxn7ax2t2plarzirnzrqsstxph3wmdzwrs4kaeeffju.cpp: In function ‘void kernel(const float*, float*)’:
/tmp/torchinductor_ghost/vd/cvdqhyulsvxn7ax2t2plarzirnzrqsstxph3wmdzwrs4kaeeffju.cpp:28:33: error: no match for ‘operator!=’ (operand types are ‘at::vec::CPU_CAPABILITY::VectorizedN<long int, 2>’ and ‘int’)
   28 |             auto tmp18 = (tmp10 != 0) | (tmp17 != 0);
      |                           ~~~~~ ^~ ~
      |                           |        |
      |                           |        int
      |                           at::vec::CPU_CAPABILITY::VectorizedN<long int, 2>
In file included from /tmp/torchinductor_ghost/wy/cwyvgno7oj63mpe36f4v6pizgeyvccmavffogp6xnqv56a32gbwo.h:31,
                 from /tmp/torchinductor_ghost/vd/cvdqhyulsvxn7ax2t2plarzirnzrqsstxph3wmdzwrs4kaeeffju.cpp:2:
/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/ATen/cpu/vec/vec_n.h:261:32: note: candidate: ‘at::vec::CPU_CAPABILITY::VectorizedN<T, N> at::vec::CPU_CAPABILITY::VectorizedN<T, N>::operator!=(const at::vec::CPU_CAPABILITY::VectorizedN<T, N>&) const [with T = long int; int N = 2]’
  261 |   VECTORIZEDN_DEFINE_BINARY_OP(operator!=)
      |                                ^~~~~~~~
/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/ATen/cpu/vec/vec_n.h:46:21: note: in definition of macro ‘VECTORIZEDN_DEFINE_BINARY_OP’
   46 |   VectorizedN<T, N> op(const VectorizedN<T, N>& other) const {      \
      |                     ^~
/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/ATen/cpu/vec/vec_n.h:46:49: note:   no known conversion for argument 1 from ‘int’ to ‘const at::vec::CPU_CAPABILITY::VectorizedN<long int, 2>&’
   46 |   VectorizedN<T, N> op(const VectorizedN<T, N>& other) const {      \
      |                        ~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~
/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/ATen/cpu/vec/vec_n.h:261:3: note: in expansion of macro ‘VECTORIZEDN_DEFINE_BINARY_OP’
  261 |   VECTORIZEDN_DEFINE_BINARY_OP(operator!=)
      |   ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
/tmp/torchinductor_ghost/vd/cvdqhyulsvxn7ax2t2plarzirnzrqsstxph3wmdzwrs4kaeeffju.cpp:28:48: error: no match for ‘operator!=’ (operand types are ‘at::vec::CPU_CAPABILITY::VectorizedN<long int, 2>’ and ‘int’)
   28 |             auto tmp18 = (tmp10 != 0) | (tmp17 != 0);
      |                                          ~~~~~ ^~ ~
      |                                          |        |
      |                                          |        int
      |                                          at::vec::CPU_CAPABILITY::VectorizedN<long int, 2>
In file included from /tmp/torchinductor_ghost/wy/cwyvgno7oj63mpe36f4v6pizgeyvccmavffogp6xnqv56a32gbwo.h:31,
                 from /tmp/torchinductor_ghost/vd/cvdqhyulsvxn7ax2t2plarzirnzrqsstxph3wmdzwrs4kaeeffju.cpp:2:
/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/ATen/cpu/vec/vec_n.h:261:32: note: candidate: ‘at::vec::CPU_CAPABILITY::VectorizedN<T, N> at::vec::CPU_CAPABILITY::VectorizedN<T, N>::operator!=(const at::vec::CPU_CAPABILITY::VectorizedN<T, N>&) const [with T = long int; int N = 2]’
  261 |   VECTORIZEDN_DEFINE_BINARY_OP(operator!=)
      |                                ^~~~~~~~
/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/ATen/cpu/vec/vec_n.h:46:21: note: in definition of macro ‘VECTORIZEDN_DEFINE_BINARY_OP’
   46 |   VectorizedN<T, N> op(const VectorizedN<T, N>& other) const {      \
      |                     ^~
/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/ATen/cpu/vec/vec_n.h:46:49: note:   no known conversion for argument 1 from ‘int’ to ‘const at::vec::CPU_CAPABILITY::VectorizedN<long int, 2>&’
   46 |   VectorizedN<T, N> op(const VectorizedN<T, N>& other) const {      \
      |                        ~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~
/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/ATen/cpu/vec/vec_n.h:261:3: note: in expansion of macro ‘VECTORIZEDN_DEFINE_BINARY_OP’
  261 |   VECTORIZEDN_DEFINE_BINARY_OP(operator!=)
      |   ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
/tmp/torchinductor_ghost/vd/cvdqhyulsvxn7ax2t2plarzirnzrqsstxph3wmdzwrs4kaeeffju.cpp:35:48: error: no match for ‘operator!=’ (operand types are ‘at::vec::CPU_CAPABILITY::VectorizedN<long int, 2>’ and ‘int’)
   35 |             auto tmp26 = (tmp18 != 0) | (tmp25 != 0);
      |                                          ~~~~~ ^~ ~
      |                                          |        |
      |                                          |        int
      |                                          at::vec::CPU_CAPABILITY::VectorizedN<long int, 2>
In file included from /tmp/torchinductor_ghost/wy/cwyvgno7oj63mpe36f4v6pizgeyvccmavffogp6xnqv56a32gbwo.h:31,
                 from /tmp/torchinductor_ghost/vd/cvdqhyulsvxn7ax2t2plarzirnzrqsstxph3wmdzwrs4kaeeffju.cpp:2:
/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/ATen/cpu/vec/vec_n.h:261:32: note: candidate: ‘at::vec::CPU_CAPABILITY::VectorizedN<T, N> at::vec::CPU_CAPABILITY::VectorizedN<T, N>::operator!=(const at::vec::CPU_CAPABILITY::VectorizedN<T, N>&) const [with T = long int; int N = 2]’
  261 |   VECTORIZEDN_DEFINE_BINARY_OP(operator!=)
      |                                ^~~~~~~~
/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/ATen/cpu/vec/vec_n.h:46:21: note: in definition of macro ‘VECTORIZEDN_DEFINE_BINARY_OP’
   46 |   VectorizedN<T, N> op(const VectorizedN<T, N>& other) const {      \
      |                     ^~
/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/ATen/cpu/vec/vec_n.h:46:49: note:   no known conversion for argument 1 from ‘int’ to ‘const at::vec::CPU_CAPABILITY::VectorizedN<long int, 2>&’
   46 |   VectorizedN<T, N> op(const VectorizedN<T, N>& other) const {      \
      |                        ~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~
/tmp/kornia/venv/lib/python3.11/site-packages/torch/include/ATen/cpu/vec/vec_n.h:261:3: note: in expansion of macro ‘VECTORIZEDN_DEFINE_BINARY_OP’
  261 |   VECTORIZEDN_DEFINE_BINARY_OP(operator!=)
      |   ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
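The compiler notes show that VectorizedN<T, N> in vec_n.h defines operator!= only through VECTORIZEDN_DEFINE_BINARY_OP, which takes another const VectorizedN<T, N>&. Nothing converts the literal 0 into a vector, so the generated tmp10 != 0 has no viable overload. The Python analogy that follows is a sketch, not PyTorch code: VecN and VecN.broadcast are hypothetical names, and broadcasting the scalar into a full vector before comparing is one plausible shape of a codegen fix, not necessarily the one that landed upstream.

```python
class VecN:
    """Hypothetical analogy of at::vec::VectorizedN: a fixed-width vector whose
    binary operators accept only another vector (no scalar overload), mirroring
    VECTORIZEDN_DEFINE_BINARY_OP in vec_n.h."""

    def __init__(self, lanes):
        self.lanes = list(lanes)

    @classmethod
    def broadcast(cls, value, width):
        # Analogous to constructing a vector from a scalar: splat it into every lane.
        return cls([value] * width)

    def __ne__(self, other):
        if not isinstance(other, VecN):
            # C++ reports "no match for operator!="; here we raise TypeError instead.
            raise TypeError(
                f"no match for '!=' (operand types are VecN and {type(other).__name__})"
            )
        # Lane-wise comparison producing 0/1 lanes.
        return VecN(int(a != b) for a, b in zip(self.lanes, other.lanes))

    def __or__(self, other):
        # Lane-wise bitwise OR, like the '|' in the generated kernel.
        return VecN(a | b for a, b in zip(self.lanes, other.lanes))


tmp10 = VecN([0, 1, 0, 0, 1, 0, 0, 0])
tmp17 = VecN([1, 0, 0, 0, 0, 0, 0, 1])

try:
    tmp10 != 0  # fails, like the generated C++ line 28
except TypeError as err:
    print(err)

zero = VecN.broadcast(0, 8)
tmp18 = (tmp10 != zero) | (tmp17 != zero)
print(tmp18.lanes)  # [1, 1, 0, 0, 1, 0, 0, 1]
```

With the scalar splatted into a vector first, both operands have the same type and the comparison and OR go through lane by lane.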

Minified repro

# TORCHDYNAMO_REPRO_AFTER="dynamo" python t.py

from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config
torch._dynamo.config.verbose = True








from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()



    def forward(self, L_input_ : torch.Tensor):
        input_1 = L_input_
        tensor = torch.tensor((0.2, 0.2, 0.2), device = device(type='cpu'), dtype = torch.float32)
        reshape = tensor.reshape(1, -1, 1, 1);  tensor = None
        lower = reshape.repeat(1, 1, 1, 1);  reshape = None
        tensor_1 = torch.tensor((0.6, 0.6, 0.6), device = device(type='cpu'), dtype = torch.float32)
        reshape_1 = tensor_1.reshape(1, -1, 1, 1);  tensor_1 = None
        upper = reshape_1.repeat(1, 1, 1, 1);  reshape_1 = None
        ge = input_1 >= lower;  lower = None
        le = input_1 <= upper;  input_1 = upper = None
        mask = torch.logical_and(ge, le);  ge = le = None
        all_1 = mask.all(dim = 1, keepdim = True);  mask = None
        output = all_1.to(torch.float32);  all_1 = None
        return (output,)


mod = Repro()

def load_args(reader):
    buf0 = reader.storage('e2ab56dff6760f39c0f07bc28f2c2b46e05c4083', 300)
    reader.tensor(buf0, (1, 3, 5, 5), is_leaf=True)  # L_input_
load_args._version = 0

if __name__ == '__main__':
    from torch._dynamo.repro.after_dynamo import run_repro
    run_repro(mod, load_args, accuracy=False, command='minify',
        save_dir='/tmp/kornia/torch_compile_debug/run_2024_05_18_11_39_38_826873-pid_26238/minifier/checkpoints', autocast=False, backend='inductor')
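For reference, the dynamo graph in the repro above computes a simple per-pixel mask: each channel is compared against its lower and upper bound with inclusive ge/le, the two results are combined with logical_and, reduced with all() over the channel dimension (keepdim=True), and cast to float32. A (1, 3, 5, 5) input therefore yields a (1, 1, 5, 5) output, matching the traced graph sizes in the log. The sketch that follows mirrors that graph with plain nested lists so expected values can be checked without torch; in_range_mask is a hypothetical name, this is not kornia's implementation, and it ignores float32 rounding of the bounds. As a side check, the 300-byte storage in load_args is exactly 1*3*5*5 = 75 float32 elements at 4 bytes each.

```python
from typing import List, Sequence

Tensor4D = List[List[List[List[float]]]]  # nested lists shaped (B, C, H, W)

# 75 float32 elements * 4 bytes, matching reader.storage(..., 300) in load_args.
INPUT_BYTES = 1 * 3 * 5 * 5 * 4


def in_range_mask(x: Tensor4D, lower: Sequence[float], upper: Sequence[float]) -> Tensor4D:
    """Illustrative re-implementation of the traced graph (ge, le, logical_and,
    all over dim=1 with keepdim=True, then cast to float)."""
    batch, channels = len(x), len(x[0])
    height, width = len(x[0][0]), len(x[0][0][0])
    out: Tensor4D = []
    for b in range(batch):
        plane = []
        for i in range(height):
            row = []
            for j in range(width):
                # all(dim=1): every channel must lie inside [lower[c], upper[c]].
                inside = all(lower[c] <= x[b][c][i][j] <= upper[c] for c in range(channels))
                row.append(float(inside))  # like .to(torch.float32)
            plane.append(row)
        out.append([plane])  # keepdim=True keeps a channel axis of size 1
    return out


x = [[[[0.5, 0.9]], [[0.3, 0.4]], [[0.25, 0.1]]]]  # shape (1, 3, 1, 2)
print(in_range_mask(x, (0.2, 0.2, 0.2), (0.6, 0.6, 0.6)))  # [[[[1.0, 0.0]]]]
```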

# TORCHDYNAMO_REPRO_AFTER="aot" python t.py


import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
import torch._inductor.inductor_prims

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config
torch._dynamo.config.verbose = True





isolate_fails_code_str = None



# torch version: 2.3.0
# torch cuda version: 12.1
# torch git version: 97ff6cfd9c86c5c09d7ce775ab64ec5c99230f5d


# CUDA Info: 
# nvcc not found
# GPU Hardware Info: 
# NVIDIA GeForce RTX 3060 Ti : 1 


from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('_tensor_constant0', tensor([0.2000, 0.2000, 0.2000]))
        self.register_buffer('_tensor_constant1', tensor([0.6000, 0.6000, 0.6000]))

    
    
    def forward(self, arg0_1):
        _tensor_constant0 = self._tensor_constant0
        lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
        view = torch.ops.aten.view.default(lift_fresh_copy, [1, -1, 1, 1]);  lift_fresh_copy = None
        full_default = torch.ops.aten.full.default([1, 3, 1, 1], 0.20000000298023224, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        _tensor_constant1 = self._tensor_constant1
        lift_fresh_copy_1 = torch.ops.aten.lift_fresh_copy.default(_tensor_constant1);  _tensor_constant1 = None
        view_1 = torch.ops.aten.view.default(lift_fresh_copy_1, [1, -1, 1, 1]);  lift_fresh_copy_1 = None
        full_default_1 = torch.ops.aten.full.default([1, 3, 1, 1], 0.6000000238418579, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
        ge = torch.ops.aten.ge.Tensor(arg0_1, full_default);  full_default = None
        le = torch.ops.aten.le.Tensor(arg0_1, full_default_1);  arg0_1 = full_default_1 = None
        logical_and = torch.ops.aten.logical_and.default(ge, le);  ge = le = None
        logical_not = torch.ops.aten.logical_not.default(logical_and);  logical_and = None
        any_1 = torch.ops.aten.any.dim(logical_not, 1, True);  logical_not = None
        logical_not_1 = torch.ops.aten.logical_not.default(any_1);  any_1 = None
        convert_element_type = torch.ops.prims.convert_element_type.default(logical_not_1, torch.float32);  logical_not_1 = None
        return (convert_element_type,)
        
def load_args(reader):
    buf0 = reader.storage(None, 300)
    reader.tensor(buf0, (1, 3, 5, 5), is_leaf=True)  # arg0_1
load_args._version = 0
mod = Repro()
if __name__ == '__main__':
    from torch._dynamo.repro.after_aot import run_repro
    with torch.no_grad():
        run_repro(mod, load_args, accuracy=False, command='minify', save_dir='/tmp/kornia/torch_compile_debug/run_2024_05_18_11_40_34_518846-pid_26379/minifier/checkpoints', tracing_mode='real', check_str=None)
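The AOT graph above no longer calls all(): it expresses mask.all(dim=1, keepdim=True) as logical_not, then any.dim(1, True), then logical_not. That is De Morgan's law, all(x) == not any(not x), and it explains why the failing C++ lines are OR reductions. The generated kernel appears to unroll the 3-channel any() into chained ORs of != 0 tests (tmp18, then tmp26, in the error output). The pure-Python sketch that follows checks both identities over every 3-channel combination; all_via_any and any_unrolled are hypothetical names used only for illustration.

```python
from itertools import product
from typing import Sequence


def all_via_any(values: Sequence[bool]) -> bool:
    # logical_not -> any -> logical_not, as in the post-AOT graph.
    negated = [not v for v in values]
    return not any(negated)


def any_unrolled(values: Sequence[bool]) -> bool:
    # Fold channels with chained ORs: acc = (acc != 0) | (next != 0),
    # the same pattern as tmp18 and tmp26 in the generated kernel.
    acc = int(values[0])
    for v in values[1:]:
        acc = int(acc != 0) | int(int(v) != 0)
    return bool(acc)


# Exhaustively check both identities for every 3-channel combination.
for combo in product([False, True], repeat=3):
    assert all_via_any(combo) == all(combo)
    assert any_unrolled(combo) == any(combo)
print("identities hold for all 8 combinations")
```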

Versions

PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.5.0-35-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060 Ti
Nvidia driver version: 535.171.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             12
On-line CPU(s) list:                0-11
Vendor ID:                          AuthenticAMD
Model name:                         AMD Ryzen 5 5600X 6-Core Processor
CPU family:                         25
Model:                              33
Thread(s) per core:                 2
Core(s) per socket:                 6
Socket(s):                          1
Stepping:                           0
Frequency boost:                    enabled
CPU max MHz:                        4650,2920
CPU min MHz:                        2200,0000
BogoMIPS:                           7399.51
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                     AMD-V
L1d cache:                          192 KiB (6 instances)
L1i cache:                          192 KiB (6 instances)
L2 cache:                           3 MiB (6 instances)
L3 cache:                           32 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] onnx==1.16.0
[pip3] torch==2.3.0
[pip3] torchaudio==2.3.0
[pip3] torchvision==0.18.0
[pip3] triton==2.3.0
[conda] blas                      1.0                         mkl  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] mkl-service               2.4.0           py311h5eee18b_1  
[conda] mkl_fft                   1.3.8           py311h5eee18b_0  
[conda] mkl_random                1.2.4           py311hdb19cb5_0  
[conda] numpy                     1.26.4          py311h08b1b3b_0  
[conda] numpy-base                1.26.4          py311hf175353_0  
[conda] pytorch                   2.3.0           py3.11_cuda12.1_cudnn8.9.2_0    pytorch
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                2.3.0               py311_cu121    pytorch
[conda] torchtriton               2.3.0                     py311    pytorch
[conda] torchvision               0.18.0              py311_cu121    pytorch

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang

johnnv1 added a commit to johnnv1/kornia that referenced this issue May 18, 2024
ezyang added the oncall: cpu inductor label May 19, 2024
jgong5 (Collaborator) commented May 20, 2024

I cannot reproduce the problem with the latest PyTorch mainline (d9c3485). @johnnv1, can you double-check whether the problem is fixed on mainline? Thanks.

johnnv1 (Author) commented May 20, 2024

Yes, it already works in the nightly (torch-2.4.0.dev20240515+cu121). Sorry for raising the issue :)

Any idea what change caused this? I'd just like to know whether the fix is going into a patch release.

johnnv1 closed this as completed May 20, 2024

jgong5 (Collaborator) commented May 21, 2024

> any idea what change caused this? just to know if it's going into a patch version

Originally, the long type was not vectorized; vectorizing it introduced some regressions, which were later fixed on trunk.
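jgong5 attributes the regression to the newly added vectorization of long (int64) values. Assuming the AVX2 flags in the compile command (-mavx2 -DCPU_CAPABILITY_AVX2) and that the kernel is tiled at the float32 vector width, a quick calculation suggests where the N = 2 in VectorizedN<long int, 2> comes from: a 256-bit register holds 8 float32 lanes but only 4 int64 lanes, so one 8-element tile of int64 values needs two registers. This is an illustrative sketch; lanes and vectors_needed are hypothetical helpers, not Inductor functions.

```python
AVX2_BITS = 256  # assumption: AVX2 build, per -mavx2 -DCPU_CAPABILITY_AVX2


def lanes(bits_per_element: int, register_bits: int = AVX2_BITS) -> int:
    # Number of elements of this width that fit in one SIMD register.
    return register_bits // bits_per_element


def vectors_needed(elem_bits: int, tile_lanes: int, register_bits: int = AVX2_BITS) -> int:
    # How many native registers of this element type cover one tile.
    per_reg = lanes(elem_bits, register_bits)
    return -(-tile_lanes // per_reg)  # ceiling division


float_lanes = lanes(32)                    # 8 float32 lanes per 256-bit register
n_int64 = vectors_needed(64, float_lanes)  # 8 int64 values span 2 registers
print(float_lanes, n_int64)  # 8 2
```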

edgarriba pushed a commit to kornia/kornia that referenced this issue May 21, 2024
* chore (CI): ensure support to pytorch 2.3.0

* chore: skip specific dynamo tests for torch 2.3.0

- Report in pytorch/pytorch#126617

* chore: skip specific dynamo tests for torch 2.3.0

- Report in pytorch/pytorch#126619
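The kornia commit above works around the regression by skipping the affected dynamo tests on torch 2.3.0. A minimal sketch of such a version gate follows, assuming (per this thread) that only 2.3.x is affected and that 2.4.0 builds such as the torch-2.4.0.dev20240515+cu121 nightly include the fix. Since a patch release with the fix is not confirmed here, every 2.3.x is conservatively skipped. parse_release and skip_inductor_cpu are hypothetical helpers, not kornia's actual code.

```python
import re
from typing import Tuple


def parse_release(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from strings like '2.3.0' or
    '2.4.0.dev20240515+cu121'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = (int(g) for g in match.groups())
    return major, minor, patch


def skip_inductor_cpu(version: str) -> bool:
    # Hypothetical policy: treat every 2.3.x release as affected.
    return parse_release(version)[:2] == (2, 3)


print(skip_inductor_cpu("2.3.0"))                    # True
print(skip_inductor_cpu("2.4.0.dev20240515+cu121"))  # False
```

In a pytest suite, the result would typically feed pytest.mark.skipif(skip_inductor_cpu(torch.__version__), reason=...).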