[Bug][TE Schedule] Unsupported nested parallel created by Softmax TE schedule #12001

@lazycal

Description

Compiling this model (see "Steps to reproduce") fails with the following error:

Check failed: (!parallel_env_.in_parallel_loop) is false: Nested parallel loop is not supported by threadpool, try fuse them instead

The produced IR does contain nested parallel loops:

@tvmgen_default_fused_floor_layout_transform_subtract_nn_fast_softmax_subtract_layout_transform__fd9b5c2e66420ec5_ = primfn(placeholder_3: handle, placeholder_4: handle, placeholder_5: handle, T_add_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_floor_layout_transform_subtract_nn_fast_softmax_subtract_layout_transform__fd9b5c2e66420ec5_", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_6: Pointer(float32), float32, [3, 3, 10, 1], []),
             placeholder_1: Buffer(placeholder_7: Pointer(float32), float32, [3, 1, 10, 1, 3], []),
             placeholder_2: Buffer(placeholder_8: Pointer(float32), float32, [3, 3, 10, 1], []),
             T_add: Buffer(T_add_2: Pointer(float32), float32, [3, 1, 10, 1, 3], [])}
  buffer_map = {placeholder_3: placeholder, placeholder_4: placeholder_1, placeholder_5: placeholder_2, T_add_1: T_add} {
  realize(T_add, [0:3, 0:1, 0:10, 0:1, 0:3], True {
    for (ax0.ax1.fused: int32, 0, 3) "parallel" {
      realize(T_softmax_norm: Buffer(T_softmax_norm_1: Pointer(float32), float32, [3, 3, 10, 1], []), [ax0.ax1.fused:(ax0.ax1.fused + 1), 0:3, 0:10, 0:1], True {
        for (i0.i1.fused: int32, 0, 3) "parallel" {
          realize(T_softmax_maxelem: Buffer(T_softmax_maxelem_1: Pointer(float32), float32, [3, 3, 1], []), [(floordiv(i0.i1.fused, 3) + ax0.ax1.fused):((floordiv(i0.i1.fused, 3) + ax0.ax1.fused) + 1), floormod(i0.i1.fused, 3):(floormod(i0.i1.fused, 3) + 1), 0:1], True {
...
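To check a dumped IR for this pattern without reading it by eye, a small text scan can help. The following is a minimal sketch, not part of TVM: `find_nested_parallel` and `FOR_RE` are hypothetical names, the `sample` string is a trimmed stand-in for the IR above, and the scan assumes the printer indents nested scopes as shown in the excerpt.

```python
import re

# Matches a printed TIR loop header such as:
#   for (ax0.ax1.fused: int32, 0, 3) "parallel" {
FOR_RE = re.compile(r'^\s*for \((?P<var>[\w.]+): \w+, .*\) "(?P<kind>\w+)" \{')


def find_nested_parallel(ir_text):
    """Return (outer_var, inner_var) pairs where one parallel loop encloses another.

    Assumption: nesting is inferred from indentation of the printed IR.
    """
    stack = []  # (indent, loop_var) of the enclosing parallel loops
    pairs = []
    for line in ir_text.splitlines():
        if not line.strip():
            continue
        indent = len(line) - len(line.lstrip())
        # A line indented no deeper than a recorded loop closes that loop.
        while stack and stack[-1][0] >= indent:
            stack.pop()
        match = FOR_RE.match(line)
        if match and match.group("kind") == "parallel":
            if stack:
                pairs.append((stack[-1][1], match.group("var")))
            stack.append((indent, match.group("var")))
    return pairs


# Trimmed, hypothetical stand-in for the IR in this issue.
sample = '''
realize(T_add, [0:3], True {
  for (ax0.ax1.fused: int32, 0, 3) "parallel" {
    realize(T_softmax_norm, [0:3], True {
      for (i0.i1.fused: int32, 0, 3) "parallel" {
        T_softmax_norm[i0.i1.fused] = 0f32
      }
    })
  }
})
'''

print(find_nested_parallel(sample))
```

Run against the full log, a non-empty result would point at the loop variables involved, here `ax0.ax1.fused` enclosing `i0.i1.fused`.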

The two nested "parallel" annotations seem to be introduced by these two schedule calls:

s[output].parallel(fused_outer_axes)
s[softmax_op].parallel(fused_outer_axes)

The problem disappears after commenting out either line.
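The error message suggests fusing the loops instead. For two perfectly nested loops, fusing means iterating once over a combined index and recovering the original indices with floor division and modulo, the same floordiv/floormod pattern that appears in the IR above. The sketch below shows only that index arithmetic in plain Python; the function names are hypothetical and it is not TVM code. Whether the two loops in this issue can be fused this way is not established here, since they belong to different stages.

```python
def nested_indices(outer, inner):
    """Visit (i, j) with two nested loops."""
    return [(i, j) for i in range(outer) for j in range(inner)]


def fused_indices(outer, inner):
    """Visit the same (i, j) pairs with a single loop over a fused index."""
    # f // inner plays the role of floordiv, f % inner the role of floormod.
    return [(f // inner, f % inner) for f in range(outer * inner)]


# Both orderings cover the same iteration space in the same order.
print(fused_indices(3, 3) == nested_indices(3, 3))
```

A single fused loop gives the thread pool one parallel region to split, which is the shape the check in the error message expects.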

The model also compiles successfully with opt_level=2.

Expected behavior

The model compiles and runs without error.

Actual behavior

Compilation fails at a higher opt_level (the reproduction below uses opt_level=4).

Environment

System: Ubuntu 18.04
TVM version: commit 1787cca

Steps to reproduce

Run this code:

import torch


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, (1, 1), 1)

    @torch.no_grad()
    def forward(self, x):
        t0a = self.conv(x)
        t0b = torch.floor(x)
        t2b = torch.softmax(t0a, dim=2)
        t3a = t0a - t0b
        t4a = t2b - t0b
        t6a = t3a + t4a
        return t6a,


sh = [3, 3, 10, 1]
inputs = (torch.ones(*sh, dtype=torch.float32),)
model = Model()
model.eval()
outputs_torch = [o.numpy() for o in model(*inputs)]
iname = ["i0"]
oname = ["o0"]
torch.onnx.export(model, inputs, "output.onnx",
                  input_names=iname, output_names=oname, opset_version=14)
from onnx import checker
checker.check_model("output.onnx", full_check=True)

import tvm
from tvm import relay
import onnx
mod, params = relay.frontend.from_onnx(
    onnx.load("output.onnx"), freeze_params=True)
with tvm.transform.PassContext(opt_level=4):
    relay.create_executor("graph", mod, params=params).evaluate()(
        i0=inputs[0].numpy())

Or run the provided onnx file:
model.onnx.zip

The error log (a more complete log with all IRs after each pass is here: output.log):

One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
/workspace/workspace/tvm-intact/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  "target_host parameter is going to be deprecated. "
Traceback (most recent call last):
  File "test.py", line 38, in <module>
    relay.create_executor("graph", mod, params=params).evaluate()(
  File "/workspace/workspace/tvm-intact/python/tvm/relay/backend/interpreter.py", line 171, in evaluate
    return self._make_executor()
  File "/workspace/workspace/tvm-intact/python/tvm/relay/build_module.py", line 593, in _make_executor
    mod = build(self.mod, target=self.target)
  File "/workspace/workspace/tvm-intact/python/tvm/relay/build_module.py", line 446, in build
    mod_name=mod_name,
  File "/workspace/workspace/tvm-intact/python/tvm/relay/build_module.py", line 169, in build
    mod_name,
  File "tvm/_ffi/_cython/./packed_func.pxi", line 331, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 276, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./base.pxi", line 181, in tvm._ffi._cy3.core.CHECK_CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  231: TVMFuncCall
  230: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  229: tvm::relay::backend::RelayBuildModule::Build(tvm::IRModule, tvm::runtime::Array<tvm::Target, void> const&, tvm::Target const&, tvm::relay::Executor const&, tvm::relay::Runtime const&, tvm::WorkspaceMemoryPools const&, tvm::ConstantMemoryPools const&, tvm::runtime::String)
  228: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  227: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  226: tvm::codegen::Build(tvm::IRModule, tvm::Target)
  225: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::codegen::{lambda(tvm::IRModule, tvm::Target)#6}>(tvm::codegen::{lambda(tvm::IRModule, tvm::Target)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  224: tvm::codegen::LLVMModuleNode::Init(tvm::IRModule const&, tvm::Target const&)
  223: void tvm::codegen::CodeGenLLVM::AddFunctionsOrdered<__gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >, tvm::codegen::CodeGenLLVM::AddFunctionsOrdered<__gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > > >(__gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >, __gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >)::{lambda(auto:1)#1}>(__gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >, __gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >, tvm::codegen::CodeGenLLVM::AddFunctionsOrdered<__gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > > >(__gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >, __gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >)::{lambda(auto:1)#1})
  222: tvm::codegen::CodeGenCPU::AddFunction(tvm::tir::PrimFunc const&)
  221: tvm::codegen::CodeGenLLVM::AddFunctionInternal(tvm::tir::PrimFunc const&, bool)
  220: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  219: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  218: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  217: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  216: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  215: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  214: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  213: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  212: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  211: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  210: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  209: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  208: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  207: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  206: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  205: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  204: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  203: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  202: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  201: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  200: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  199: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  198: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
  197: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*)
  196: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  195: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  194: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  193: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  192: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  191: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  190: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  189: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  188: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  187: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
  186: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*)
  185: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  184: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  183: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  182: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  181: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  180: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  179: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  178: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
  177: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*)
  176: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  175: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  174: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  173: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  172: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  171: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  170: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  169: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
  168: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*)
  167: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  166: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  165: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  164: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  163: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  162: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  161: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  160: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  159: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  158: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  157: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  156: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  155: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  154: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  153: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  152: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  151: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  150: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  149: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  148: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  147: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  146: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  145: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  144: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  143: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  142: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  141: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  140: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  139: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  138: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  137: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  136: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  135: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  134: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  133: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  132: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  131: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  130: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  129: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
  128: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  127: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  126: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  125: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  124: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  123: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  122: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  121: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  120: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  119: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  118: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  117: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  116: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  115: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  114: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  113: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  112: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  111: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  110: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  109: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  108: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  107: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  106: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  105: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  104: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  103: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  102: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  101: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  100: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  99: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  98: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  97: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
  96: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  95: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  94: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  93: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  92: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  91: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  90: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  89: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  88: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  87: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  86: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  85: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  84: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  83: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  82: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  81: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  80: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  79: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  78: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  77: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  76: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  75: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  74: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  73: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  72: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  71: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  70: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  69: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  68: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  67: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  66: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  65: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
  64: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  63: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  62: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  61: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  60: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  59: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  58: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  57: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  56: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  55: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  54: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  53: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  52: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  51: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  50: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  49: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  48: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  47: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  46: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  45: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  44: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  43: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  42: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  41: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  40: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  39: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  38: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  37: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  36: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  35: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  34: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  33: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  32: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  31: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  30: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
  29: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  28: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  27: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  26: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  25: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  24: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  23: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  22: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  21: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  20: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  19: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
  18: tvm::codegen::CodeGenCPU::CreateComputeScope(tvm::tir::AttrStmtNode const*)
  17: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  16: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::ForNode const*)
  15: tvm::codegen::CodeGenCPU::CreateParallelLaunch(tvm::tir::Stmt const&, int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)
  14: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  13: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::ForNode const*)
  12: tvm::codegen::CodeGenLLVM::CreateSerialFor(llvm::Value*, llvm::Value*, llvm::Value*, tvm::tir::Var const&, tvm::tir::Stmt const&)
  11: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  10: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AllocateNode const*)
  9: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  8: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AllocateNode const*)
  7: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  6: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AllocateNode const*)
  5: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  4: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AllocateNode const*)
  3: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  2: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
  1: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  0: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::ForNode const*)
  File "/workspace/workspace/tvm-intact/src/target/llvm/codegen_cpu.cc", line 1514
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (!parallel_env_.in_parallel_loop) is false: Nested parallel loop is not supported by threadpool, try fuse them instead
