Skip to content

Commit

Permalink
[CI] Fix CI Script and Broken Tests (#16521)
Browse files Browse the repository at this point in the history
* [CI] Fix CI Script and Broken Tests

Co-authored-by: Shengjie Liu <Shengjie.Liu@armchina.com>

* Enhance IterMapSimplify to support uncommon predicate

* Fix runtime traced_callpacked

* Fix derived object attribute get

* update debug line info testcase

* fix relay/relax import and debug_info

* fix lint

---------

Co-authored-by: Shengjie Liu <Shengjie.Liu@armchina.com>
Co-authored-by: tqchen <tianqi.tchen@gmail.com>
  • Loading branch information
3 people committed Feb 7, 2024
1 parent 2dcf9ec commit 268d15c
Show file tree
Hide file tree
Showing 17 changed files with 248 additions and 316 deletions.
6 changes: 6 additions & 0 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@
# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel

# Relay and Relax contain modules that are only available in compiler package
# Do not import them if TVM is built with runtime only
if not _RUNTIME_ONLY:
from . import relay
from . import relax

if not _RUNTIME_ONLY and support.libinfo().get("USE_MICRO", "OFF") == "ON":
from . import micro

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/contrib/debugger/debug_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ def create(graph_json_str, libmod, device, dump_root=None):
# Automatically set params if they can be extracted from the libmod
try:
params = libmod["get_graph_params"]()
if isinstance(params, tvm.ir.container.Map):
gmod.set_input(**params)
except (AttributeError, tvm.error.RPCError):
# Params can not be extracted from the libmod and must be set somewhere else manually
# Do not set params during RPC communication
pass
else:
gmod.set_input(**params)

return gmod

Expand Down
10 changes: 7 additions & 3 deletions python/tvm/meta_schedule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,15 @@ def __init__(self, *args, **kwargs):
self._inst._outer = weakref.ref(self)

def __getattr__(self, name):
# fall back to instance attribute if there is not any
# return self._inst.__getattribute__(name)
import inspect # pylint: disable=import-outside-toplevel

result = self._inst.__getattribute__(name)
try:
# fall back to instance attribute if there is not any
# return self._inst.__getattribute__(name)
result = self._inst.__getattribute__(name)
except AttributeError:
result = super(TVMDerivedObject, self).__getattr__(name)

if inspect.ismethod(result):

def method(*args, **kwargs):
Expand Down
9 changes: 9 additions & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2143,6 +2143,15 @@ Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, R
/*simplify_trivial_iterators=*/simplify_trivial_iterators);
Array<IterSumExpr> rewrite = res->indices;

if (rewrite.empty() && !is_one(input_pred) && check_level != IterMapLevel::Bijective) {
// The input predicate may cause detect iter map to fail
// but we can still detect the iter map without the input predicate
// in which case the resulting iter map is valid and can be used for simplification.
rewrite = DetectIterMap(indices, input_iters, const_true(), check_level, ana,
/*simplify_trivial_iterators=*/simplify_trivial_iterators)
->indices;
}

if (rewrite.empty()) {
return indices;
}
Expand Down
10 changes: 10 additions & 0 deletions src/tir/ir/tir_visitor_with_path.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,17 @@ void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) {
for (size_t i = 0; i < op->match_buffers.size(); i++) {
auto buf = op->match_buffers[i]->buffer;
auto buffer_path = match_path->ArrayIndex(i)->Attr("buffer");
auto buffer_strides_path = buffer_path->Attr("strides");
context.push_back(WithDef(buf->data, buffer_path->Attr("data")));
// Define buffer strides and elem_offset if they are vars
if (const auto* v = buf->elem_offset.as<VarNode>()) {
context.push_back(WithDef(GetRef<Var>(v), buffer_path->Attr("elem_offset")));
}
for (size_t i = 0; i < buf->strides.size(); ++i) {
if (const auto* v = buf->strides[i].as<VarNode>()) {
context.push_back(WithDef(GetRef<Var>(v), buffer_strides_path->ArrayIndex(i)));
}
}
context.push_back(WithDef(buf, buffer_path));
}
}
Expand Down
19 changes: 14 additions & 5 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,13 +320,16 @@ class BuiltinLower : public StmtExprMutator {
PrimExpr VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::tvm_call_packed())) {
return MakeCallPackedGeneric(op, 0, builtin::tvm_call_packed_lowered(),
/* use_string_lookup */ true);
/* use_string_lookup */ true,
/* use_last_value_as_traced_value*/ false);
} else if (op->op.same_as(builtin::tvm_call_cpacked())) {
return MakeCallPackedGeneric(op, 0, builtin::tvm_call_cpacked_lowered(),
/* use_string_lookup */ false);
/* use_string_lookup */ false,
/* use_last_value_as_traced_value*/ false);
} else if (op->op.same_as(builtin::tvm_call_trace_packed())) {
return MakeCallPackedGeneric(op, 0, builtin::tvm_call_trace_packed_lowered(),
/* use_string_lookup */ true);
/* use_string_lookup */ true,
/* use_last_value_as_traced_value*/ true);
} else if (op->op.same_as(builtin::anylist_setitem_call_packed())) {
return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_packed_lowered(), true);
} else if (op->op.same_as(builtin::anylist_setitem_call_cpacked())) {
Expand Down Expand Up @@ -510,7 +513,7 @@ class BuiltinLower : public StmtExprMutator {
PrimExpr list_handle = op->args[0];
PrimExpr list_index = op->args[1];

Call call = MakeCallPackedGeneric(op, 2, lowered_op, use_string_lookup);
Call call = MakeCallPackedGeneric(op, 2, lowered_op, use_string_lookup, false);
PrimExpr value_stack = call->args[1];
PrimExpr tcode_stack = call->args[2];
// The stack offset of return value stack_end
Expand All @@ -528,9 +531,10 @@ class BuiltinLower : public StmtExprMutator {
* \param name_offset The beginning of function name and call packed section.
* \param lowered_packed_op The target lowered op.
* \param use_string_lookup Whether to lookup function by string.
* \param pass_last_arg_as_traced_value Whether to pass last argument as traced value
*/
Call MakeCallPackedGeneric(const CallNode* op, size_t name_offset, const Op& lowered_packed_op,
bool use_string_lookup) {
bool use_string_lookup, bool pass_last_arg_as_traced_value) {
auto& scope = alloca_scope_.back();
auto& prep_seq = prep_seq_stack_.back();

Expand Down Expand Up @@ -571,6 +575,7 @@ class BuiltinLower : public StmtExprMutator {
ConstInt32(arg_stack_begin + num_args)};
// cpacked call resource_handle
if (!use_string_lookup) {
ICHECK(!pass_last_arg_as_traced_value);
PrimExpr last_arg = op->args[args_end];
const VarNode* var_node = last_arg.as<VarNode>();
if (var_node != nullptr) {
Expand All @@ -579,6 +584,10 @@ class BuiltinLower : public StmtExprMutator {
} else {
packed_args.push_back(last_arg);
}
} else if (pass_last_arg_as_traced_value) {
// pass in last element as traced value
// used by call_packed_traced
packed_args.push_back(op->args[op->args.size() - 1]);
}
return Call(op->dtype, lowered_packed_op, packed_args);
}
Expand Down
2 changes: 2 additions & 0 deletions tests/python/codegen/test_target_codegen_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,8 @@ def check_cuda(n, lanes):
check_cuda(64, 2)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_thread_sync_inside_condition():
@T.prim_func
def func1(A: T.Buffer((4, 4), "float32")) -> None:
Expand Down

0 comments on commit 268d15c

Please sign in to comment.