From ae8d398b8852ac84823facd6d58b229e92a8bca3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 11 Jan 2024 19:25:28 -0600 Subject: [PATCH 01/16] [CI] In jenkins.cmd_utils.Sh.tee, check for failing subprocess (#16382) Prior to this commit, the `Sh.tee` method was implemented by calling `f"{cmd} | tee"` in `subprocess.run`. While the `check=True` flag was used, the return code was from `tee`, not from the command itself. This causes failures in the command itself to be silently ignored, such as in [this CI pipeline](https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-i386/detail/PR-16183/37/pipeline) in the `ci/scripts/jenkins/s3.py` step. This commit updates `Sh.tee` to call `subprocess.Popen` for `cmd`, tee the stdout, and check the return code. (Roughly adapted from [this stackoverflow post](https://stackoverflow.com/a/56484734).) --- ci/scripts/jenkins/cmd_utils.py | 44 ++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/ci/scripts/jenkins/cmd_utils.py b/ci/scripts/jenkins/cmd_utils.py index 1b282c50ba0f..57ec39973114 100644 --- a/ci/scripts/jenkins/cmd_utils.py +++ b/ci/scripts/jenkins/cmd_utils.py @@ -58,24 +58,50 @@ def tee(self, cmd: str, **kwargs): """ Run 'cmd' in a shell then return the (process, stdout) as a tuple """ - with tempfile.NamedTemporaryFile(delete=False) as f: - proc = self.run(f"{cmd} | tee {f.name}", **kwargs) - with open(f.name, "r") as f: - output = f.read() - return proc, output + + logging.info(f"+ {cmd}") + + kwargs = { + **self._default_popen_flags(), + **kwargs, + "stdout": subprocess.PIPE, + } + proc = subprocess.Popen(cmd, **kwargs) + + stdout = [] + + def _tee_output(s): + stdout.append(s) + print(s, end="") + + while proc.poll() is None: + _tee_output(proc.stdout.readline()) + _tee_output(proc.stdout.read()) + + stdout = "".join(stdout) + if proc.returncode: + raise subprocess.CalledProcessError(proc.returncode, proc.args, stdout) + + return proc, stdout def run(self, cmd: str, **kwargs): logging.info(f"+ {cmd}") - defaults = { + + kwargs = { + **self._default_popen_flags(), "check": True, + **kwargs, + } + + return subprocess.run(cmd, **kwargs) + + def _default_popen_flags(self): + return { "shell": True, "env": self.env, "encoding": "utf-8", "cwd": self.cwd, } - defaults.update(kwargs) - - return subprocess.run(cmd, **defaults) def tags_from_title(title: str) -> List[str]: From f1bf20a9504e5d6a7e07468ac9bbe568b34b06b4 Mon Sep 17 00:00:00 2001 From: gmeeker Date: Thu, 11 Jan 2024 21:45:14 -0800 Subject: [PATCH 02/16] [RPC] Fix tuning on macOS and Windows (#15771) (#16357) * [RPC] Fix tuning on macOS and Windows (#15771) Fix regression in (#15187) when multiprocessing start method is not 'fork', which prevented tuning from working. This affects macOS and Windows. Also in python 3.14 the default start method will be 'spawn'. * [RPC] clean up _serve_loop function --- python/tvm/rpc/server.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 6ee683c73ba5..ea9576708667 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -119,6 +119,11 @@ def download_linked_module(file_name): return temp +def _serve_loop(sock, load_library, work_path): + _server_env(load_library, work_path) + _ffi_api.ServerLoop(sock.fileno()) + + def _parse_server_opt(opts): # parse client options ret = {} @@ -135,11 +140,7 @@ def _serving(sock, addr, opts, load_library): os.chdir(work_path.path) # Avoiding file name conflict between sessions. logger.info(f"start serving at {work_path.path}") - def _serve_loop(): - _server_env(load_library, work_path) - _ffi_api.ServerLoop(sock.fileno()) - - server_proc = multiprocessing.Process(target=_serve_loop) + server_proc = multiprocessing.Process(target=_serve_loop, args=(sock, load_library, work_path)) server_proc.start() server_proc.join(opts.get("timeout", None)) # Wait until finish or timeout. From 4258c864b91f1b0b5cffc5ba792a331998f793bd Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 12 Jan 2024 11:53:19 -0500 Subject: [PATCH 03/16] [RUNTIME][RPC] Enable RPCObjectRef return in RPC (#16387) [Runtime] Enable RPCObjectRef return in RPC This PR enables RPCObjectRef return object similar to the disco transporation. This allows us to do advanced remote debugging when remote vm requires advanced object input like kv cache and shape. To keep the implementation with minRPC(used in some of the limited protocols) forn now, we only support RPCObjectRef for now and do not enable unpacking Shape and String. --- include/tvm/runtime/object.h | 4 +- src/runtime/minrpc/minrpc_server.h | 15 ++++++- src/runtime/minrpc/rpc_reference.h | 8 ++++ src/runtime/rpc/rpc_endpoint.cc | 51 ++++++++++++++++++++---- src/runtime/rpc/rpc_local_session.cc | 20 +++++++++- src/runtime/rpc/rpc_module.cc | 7 ++++ src/runtime/rpc/rpc_session.h | 51 +++++++++++++++++++++++- tests/python/runtime/test_runtime_rpc.py | 31 ++++++++++++++ 8 files changed, 174 insertions(+), 13 deletions(-) diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 94644d797c1a..92f477b058fd 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -72,8 +72,10 @@ struct TypeIndex { kRuntimeShapeTuple = 6, /*! \brief runtime::PackedFunc. */ kRuntimePackedFunc = 7, - /*! \brief runtime::DRef */ + /*! \brief runtime::DRef for disco distributed runtime */ kRuntimeDiscoDRef = 8, + /*! \brief runtime::RPCObjectRef */ + kRuntimeRPCObjectRef = 9, // static assignments that may subject to change. kRuntimeClosure, kRuntimeADT, diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h index cca47f80b9df..96a4dbce79cd 100644 --- a/src/runtime/minrpc/minrpc_server.h +++ b/src/runtime/minrpc/minrpc_server.h @@ -206,7 +206,8 @@ class MinRPCExecute : public MinRPCExecInterface { ret_tcode[1] = kTVMBytes; ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2); TVMByteArrayFree(reinterpret_cast(ret_value[1].v_handle)); // NOLINT(*) - } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) { + } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle || + rv_tcode == kTVMObjectHandle) { ret_tcode[1] = kTVMOpaqueHandle; ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2); } else { @@ -755,7 +756,17 @@ class MinRPCServer { } void ReadObject(int* tcode, TVMValue* value) { - this->ThrowError(RPCServerStatus::kUnknownTypeCode); + // handles RPCObject in minRPC + // NOTE: object needs to be supported by C runtime + // because minrpc's restriction of C only + // we only handle RPCObjectRef + uint32_t type_index; + Read(&type_index); + MINRPC_CHECK(type_index == kRuntimeRPCObjectRefTypeIndex); + uint64_t object_handle; + Read(&object_handle); + tcode[0] = kTVMObjectHandle; + value[0].v_handle = reinterpret_cast(object_handle); } private: diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index e16f09cb9dee..732b017e44fe 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -33,6 +33,14 @@ class Object; /*! \brief The current RPC procotol version. */ constexpr const char* kRPCProtocolVer = "0.8.0"; +/*! + * \brief type index of kRuntimeRPCObjectRefTypeIndex + * \note this needs to be kept consistent with runtime/object.h + * but we explicitly declare it here because minrpc needs to be minimum dep + * only c C API + */ +constexpr const int kRuntimeRPCObjectRefTypeIndex = 9; + // When tvm.rpc.server.GetCRTMaxPacketSize global function is not registered. const uint64_t kRPCMaxTransferSizeBytesDefault = UINT64_MAX; diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index f2c09132fc70..2c431cdb643c 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -175,8 +175,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { for (int i = 0; i < num_args; ++i) { int tcode = type_codes[i]; if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) { - LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type " - << args[i].AsObjectRef()->GetTypeKey() << " is not supported by RPC"; + if (!args[i].IsObjectRef()) { + LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type " + << args[i].AsObjectRef()->GetTypeKey() + << " is not supported by RPC"; + } } else if (tcode == kDLDevice) { DLDevice dev = args[i]; ICHECK(!IsRPCSessionDevice(dev)) << "InternalError: cannot pass RPC device in the channel"; @@ -219,14 +222,48 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { this->Write(cdata); } - void WriteObject(void* obj) { this->ThrowError(RPCServerStatus::kUnknownTypeCode); } - uint64_t GetObjectBytes(void* obj) { - this->ThrowError(RPCServerStatus::kUnknownTypeCode); - return 0; + void WriteObject(Object* obj) { + // NOTE: for now all remote object are encoded as RPCObjectRef + // follow the same disco protocol in case we would like to upgrade later + // + // Rationale note: Only handle remote object allows the same mechanism to work for minRPC + // which is needed for wasm and other env that goes through C API + if (obj->IsInstance()) { + auto* ref = static_cast(obj); + this->template Write(kRuntimeRPCObjectRefTypeIndex); + uint64_t handle = reinterpret_cast(ref->object_handle()); + this->template Write(handle); + } else { + LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: " + << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")"; + } + } + uint64_t GetObjectBytes(Object* obj) { + if (obj->IsInstance()) { + return sizeof(uint32_t) + sizeof(int64_t); + } else { + LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: " + << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")"; + } } void ReadObject(int* tcode, TVMValue* value) { - this->ThrowError(RPCServerStatus::kUnknownTypeCode); + // NOTE: for now all remote object are encoded as RPCObjectRef + // follow the same disco protocol in case we would like to upgrade later + // + // Rationale note: Only handle remote object allows the same mechanism to work for minRPC + // which is needed for wasm and other env that goes through C API + uint32_t type_index; + this->template Read(&type_index); + if (type_index == kRuntimeRPCObjectRefTypeIndex) { + uint64_t handle; + this->template Read(&handle); + tcode[0] = kTVMObjectHandle; + value[0].v_handle = reinterpret_cast(handle); + } else { + LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " + << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")"; + } } void MessageDone() { diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index d4aec5596f37..92691ee6fd28 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -27,6 +27,7 @@ #include #include +#include namespace tvm { namespace runtime { @@ -64,7 +65,8 @@ void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_retu ret_value_pack[2].v_handle = ret_value_pack[1].v_handle; ret_tcode_pack[2] = kTVMOpaqueHandle; encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3)); - } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) { + } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle || + rv_tcode == kTVMObjectHandle) { // MoveToCHost means rv no longer manages the object. // return handle instead. rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]); @@ -88,7 +90,21 @@ void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, const TVMValue* a const FEncodeReturn& encode_return) { PackedFuncObj* pf = static_cast(func); TVMRetValue rv; - pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv); + + // unwrap RPCObjectRef in case we are directly using it to call LocalSession + std::vector values(arg_values, arg_values + num_args); + std::vector type_codes(arg_type_codes, arg_type_codes + num_args); + TVMArgs args(arg_values, arg_type_codes, num_args); + + for (int i = 0; i < num_args; ++i) { + if (args[i].IsObjectRef()) { + RPCObjectRef obj_ref = args[i]; + values[i].v_handle = obj_ref->object_handle(); + continue; + } + } + + pf->CallPacked(TVMArgs(values.data(), type_codes.data(), args.size()), &rv); this->EncodeReturn(std::move(rv), encode_return); } diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 94f6720ca8da..a696005ab836 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -157,6 +157,8 @@ class RPCWrappedFunc : public Object { } }; +TVM_REGISTER_OBJECT_TYPE(RPCObjectRefObj); + // RPC that represents a remote module session. class RPCModuleNode final : public ModuleNode { public: @@ -294,6 +296,11 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) cons void* handle = args[1]; auto n = make_object(handle, sess_); *rv = Module(n); + } else if (tcode == kTVMObjectHandle) { + ICHECK_EQ(args.size(), 2); + void* handle = args[1]; + auto n = make_object(handle, sess_); + *rv = ObjectRef(n); } else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) { ICHECK_EQ(args.size(), 3); DLTensor* tensor = args[1]; diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index 60d067e49d3f..b09900d0abaa 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -142,7 +142,7 @@ class RPCSession { /*! * \brief Free a remote function. - * \param handle The remote handle, can be NDArray/PackedFunc/Module + * \param handle The remote handle, can be NDArray/PackedFunc/Module/Object * \param type_code The type code of the underlying type. */ virtual void FreeHandle(void* handle, int type_code) = 0; @@ -287,6 +287,55 @@ struct RemoteSpace { std::shared_ptr sess; }; +/*! + * \brief Object wrapper that represents a reference to a remote object + */ +class RPCObjectRefObj : public Object { + public: + /*! + * \brief constructor + * \param object_handle handle that points to the remote object + * \param sess The remote session + */ + RPCObjectRefObj(void* object_handle, std::shared_ptr sess) + : object_handle_(object_handle), sess_(sess) {} + + ~RPCObjectRefObj() { + if (object_handle_ != nullptr) { + try { + sess_->FreeHandle(object_handle_, kTVMObjectHandle); + } catch (const Error& e) { + // fault tolerance to remote close + } + object_handle_ = nullptr; + } + } + + const std::shared_ptr& sess() const { return sess_; } + + void* object_handle() const { return object_handle_; } + + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeRPCObjectRef; + static constexpr const char* _type_key = "runtime.RPCObjectRef"; + TVM_DECLARE_FINAL_OBJECT_INFO(RPCObjectRefObj, Object); + + private: + // The object handle + void* object_handle_{nullptr}; + // The local channel + std::shared_ptr sess_; +}; + +/*! + * \brief Managed reference to RPCObjectRefObj. + * \sa RPCObjectRefObj + * \note No public constructor is provided as it is not supposed to be directly created by users. + */ +class RPCObjectRef : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RPCObjectRef, ObjectRef, RPCObjectRefObj); +}; + /*! * \brief Create a Global RPC module that refers to the session. * \param sess The RPC session of the global module. diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index 9591e3ea4d60..fff203df0051 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -426,6 +426,7 @@ def test_rpc_return_ndarray(): ref_count = m("ref_count") get_elem = m("get_elem") get_arr_elem = m("get_arr_elem") + # array test def run_arr_test(): arr = get_arr() @@ -435,6 +436,36 @@ def run_arr_test(): run_arr_test() +@tvm.testing.requires_rpc +def test_rpc_return_remote_object(): + def check(client, is_local): + make_shape = client.get_function("runtime.ShapeTuple") + get_elem = client.get_function("runtime.GetShapeTupleElem") + get_size = client.get_function("runtime.GetShapeTupleSize") + shape = make_shape(2, 3) + assert shape.type_key == "runtime.RPCObjectRef" + assert get_elem(shape, 0) == 2 + assert get_elem(shape, 1) == 3 + assert get_size(shape) == 2 + + # start server + server = rpc.Server(key="x1") + client = rpc.connect("127.0.0.1", server.port, key="x1") + check(rpc.LocalSession(), True) + check(client, False) + + def check_minrpc(): + if tvm.get_global_func("rpc.CreatePipeClient", allow_missing=True) is None: + return + # Test minrpc server. + temp = utils.tempdir() + minrpc_exec = temp.relpath("minrpc") + tvm.rpc.with_minrpc(cc.create_executable)(minrpc_exec, []) + check(rpc.PopenSession(minrpc_exec), False) + + check_minrpc() + + @tvm.testing.requires_rpc def test_local_func(): client = rpc.LocalSession() From 196b413813ea6b5e85720118c9aea1fe043a81fb Mon Sep 17 00:00:00 2001 From: TaoMiao Date: Sat, 13 Jan 2024 03:22:18 +0800 Subject: [PATCH 04/16] [Relay][Frontend][Torch] fix a typo mistake in nonzero_numpy (#16390) fix a typo mistake in pytorch frontend nonzero_numpy --- python/tvm/relay/frontend/pytorch.py | 2 +- tests/python/frontend/pytorch/test_forward.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 0213dcc488fd..b9650e6e9a9c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2680,7 +2680,7 @@ def nonzero(self, inputs, input_types, is_numpy_style=False): return ret def nonzero_numpy(self, inputs, input_types): - return self.nonzero(inputs, input_types, is_numpy_style=False) + return self.nonzero(inputs, input_types, is_numpy_style=True) def scatter(self, inputs, input_types): assert len(inputs) == 4 or len(inputs) == 5, ( diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 6178a58b6d13..9bf40cfcdd85 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4445,6 +4445,7 @@ def forward(self, data): inp = torch.Tensor(np.array([[0, 1, 0], [2, 0, 9], [-1, -1, 0]]).astype("float32")) verify_trace_model(Nonzero(), [inp], ["llvm"]) + verify_trace_model(Nonzero(as_tuple=True), [inp], ["llvm"]) def test_forward_scatter(): From 3e52c3dba5dbcd9d248fca65bd432f1c339c5a1e Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sun, 14 Jan 2024 19:34:31 -0800 Subject: [PATCH 05/16] [CI] Remove NVIDIA_DISABLE_REQUIRE (#16384) --- ci/jenkins/generated/arm_jenkinsfile.groovy | 4 ++-- ci/jenkins/generated/cortexm_jenkinsfile.groovy | 4 ++-- ci/jenkins/generated/cpu_jenkinsfile.groovy | 4 ++-- ci/jenkins/generated/docker_jenkinsfile.groovy | 4 ++-- ci/jenkins/generated/gpu_jenkinsfile.groovy | 4 ++-- ci/jenkins/generated/hexagon_jenkinsfile.groovy | 4 ++-- ci/jenkins/generated/i386_jenkinsfile.groovy | 4 ++-- ci/jenkins/generated/lint_jenkinsfile.groovy | 4 ++-- ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy | 4 ++-- ci/jenkins/generated/minimal_jenkinsfile.groovy | 4 ++-- ci/jenkins/generated/riscv_jenkinsfile.groovy | 4 ++-- ci/jenkins/generated/wasm_jenkinsfile.groovy | 4 ++-- ci/jenkins/templates/utils/base.groovy.j2 | 2 +- 13 files changed, 25 insertions(+), 25 deletions(-) diff --git a/ci/jenkins/generated/arm_jenkinsfile.groovy b/ci/jenkins/generated/arm_jenkinsfile.groovy index eaa2efedabcc..80b2e59a812c 100644 --- a/ci/jenkins/generated/arm_jenkinsfile.groovy +++ b/ci/jenkins/generated/arm_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2024-01-06T14:50:09.699639 +// Generated at 2024-01-10T13:15:25.226391 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -112,7 +112,7 @@ properties([ upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME --env NVIDIA_DISABLE_REQUIRE=true' +docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME' docker_build = 'docker/build.sh' // timeout in minutes max_time = 180 diff --git a/ci/jenkins/generated/cortexm_jenkinsfile.groovy b/ci/jenkins/generated/cortexm_jenkinsfile.groovy index 3ec30d18addc..c7b452723089 100644 --- a/ci/jenkins/generated/cortexm_jenkinsfile.groovy +++ b/ci/jenkins/generated/cortexm_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2024-01-06T14:50:09.593229 +// Generated at 2024-01-10T13:15:25.121865 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -112,7 +112,7 @@ properties([ upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME --env NVIDIA_DISABLE_REQUIRE=true' +docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME' docker_build = 'docker/build.sh' // timeout in minutes max_time = 180 diff --git a/ci/jenkins/generated/cpu_jenkinsfile.groovy b/ci/jenkins/generated/cpu_jenkinsfile.groovy index 7fa7f84f22d0..f26a69109f63 100644 --- a/ci/jenkins/generated/cpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/cpu_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2024-01-06T14:50:09.574671 +// Generated at 2024-01-10T13:15:25.103852 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -112,7 +112,7 @@ properties([ upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME --env NVIDIA_DISABLE_REQUIRE=true' +docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME' docker_build = 'docker/build.sh' // timeout in minutes max_time = 180 diff --git a/ci/jenkins/generated/docker_jenkinsfile.groovy b/ci/jenkins/generated/docker_jenkinsfile.groovy index af0431324e0c..2aa6a173620e 100644 --- a/ci/jenkins/generated/docker_jenkinsfile.groovy +++ b/ci/jenkins/generated/docker_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2024-01-06T14:50:09.681292 +// Generated at 2024-01-10T13:15:25.207618 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -112,7 +112,7 @@ properties([ upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME --env NVIDIA_DISABLE_REQUIRE=true' +docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME' docker_build = 'docker/build.sh' // timeout in minutes max_time = 180 diff --git a/ci/jenkins/generated/gpu_jenkinsfile.groovy b/ci/jenkins/generated/gpu_jenkinsfile.groovy index deb3631a69f5..ce0e2c5b0d86 100644 --- a/ci/jenkins/generated/gpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/gpu_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2024-01-06T14:50:09.659750 +// Generated at 2024-01-10T13:15:25.186261 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -112,7 +112,7 @@ properties([ upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME --env NVIDIA_DISABLE_REQUIRE=true' +docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME' docker_build = 'docker/build.sh' // timeout in minutes max_time = 180 diff --git a/ci/jenkins/generated/hexagon_jenkinsfile.groovy b/ci/jenkins/generated/hexagon_jenkinsfile.groovy index 22d7c26713d2..e63f6fb14b19 100644 --- a/ci/jenkins/generated/hexagon_jenkinsfile.groovy +++ b/ci/jenkins/generated/hexagon_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2024-01-06T14:50:09.557404 +// Generated at 2024-01-10T13:15:25.087221 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -112,7 +112,7 @@ properties([ upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME --env NVIDIA_DISABLE_REQUIRE=true' +docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME' docker_build = 'docker/build.sh' // timeout in minutes max_time = 180 diff --git a/ci/jenkins/generated/i386_jenkinsfile.groovy b/ci/jenkins/generated/i386_jenkinsfile.groovy index 95a041d71307..c6eeba186115 100644 --- a/ci/jenkins/generated/i386_jenkinsfile.groovy +++ b/ci/jenkins/generated/i386_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2024-01-06T14:50:09.642757 +// Generated at 2024-01-10T13:15:25.169799 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -112,7 +112,7 @@ properties([ upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME --env NVIDIA_DISABLE_REQUIRE=true' +docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME' docker_build = 'docker/build.sh' // timeout in minutes max_time = 180 diff --git a/ci/jenkins/generated/lint_jenkinsfile.groovy b/ci/jenkins/generated/lint_jenkinsfile.groovy index 20573d704341..64cda41994a3 100644 --- a/ci/jenkins/generated/lint_jenkinsfile.groovy +++ b/ci/jenkins/generated/lint_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2024-01-06T14:50:09.718687 +// Generated at 2024-01-10T13:15:25.245060 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -112,7 +112,7 @@ properties([ upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME --env NVIDIA_DISABLE_REQUIRE=true' +docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME' docker_build = 'docker/build.sh' // timeout in minutes max_time = 180 diff --git a/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy b/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy index f4f1ca310a12..6a878b21ee50 100644 --- a/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy +++ b/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2024-01-06T14:50:09.520898 +// Generated at 2024-01-10T13:15:25.045689 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -112,7 +112,7 @@ properties([ upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME --env NVIDIA_DISABLE_REQUIRE=true' +docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME' docker_build = 'docker/build.sh' // timeout in minutes max_time = 180 diff --git a/ci/jenkins/generated/minimal_jenkinsfile.groovy b/ci/jenkins/generated/minimal_jenkinsfile.groovy index 55d717b7876d..cdca556bd038 100644 --- a/ci/jenkins/generated/minimal_jenkinsfile.groovy +++ b/ci/jenkins/generated/minimal_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2024-01-06T14:50:09.611058 +// Generated at 2024-01-10T13:15:25.139264 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -112,7 +112,7 @@ properties([ upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME --env NVIDIA_DISABLE_REQUIRE=true' +docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME' docker_build = 'docker/build.sh' // timeout in minutes max_time = 180 diff --git a/ci/jenkins/generated/riscv_jenkinsfile.groovy b/ci/jenkins/generated/riscv_jenkinsfile.groovy index 68bf864f84c3..d378a7a5999a 100644 --- a/ci/jenkins/generated/riscv_jenkinsfile.groovy +++ b/ci/jenkins/generated/riscv_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2024-01-06T14:50:09.540604 +// Generated at 2024-01-10T13:15:25.070888 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -112,7 +112,7 @@ properties([ upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME --env NVIDIA_DISABLE_REQUIRE=true' +docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME' docker_build = 'docker/build.sh' // timeout in minutes max_time = 180 diff --git a/ci/jenkins/generated/wasm_jenkinsfile.groovy b/ci/jenkins/generated/wasm_jenkinsfile.groovy index 910cb44b722f..1dea6599d0bc 100644 --- a/ci/jenkins/generated/wasm_jenkinsfile.groovy +++ b/ci/jenkins/generated/wasm_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2024-01-06T14:50:09.628006 +// Generated at 2024-01-10T13:15:25.155555 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -112,7 +112,7 @@ properties([ upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME --env NVIDIA_DISABLE_REQUIRE=true' +docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME' docker_build = 'docker/build.sh' // timeout in minutes max_time = 180 diff --git a/ci/jenkins/templates/utils/base.groovy.j2 b/ci/jenkins/templates/utils/base.groovy.j2 index 88936598f6ab..68395d05a941 100644 --- a/ci/jenkins/templates/utils/base.groovy.j2 +++ b/ci/jenkins/templates/utils/base.groovy.j2 @@ -85,7 +85,7 @@ properties([ upstream_revision = null // command to start a docker container -docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME --env NVIDIA_DISABLE_REQUIRE=true' +docker_run = 'docker/bash.sh --env CI --env PLATFORM --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME' docker_build = 'docker/build.sh' // timeout in minutes max_time = 180 From fe9814c73e24df3ecd031b216492fb555a1ab95a Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Mon, 15 Jan 2024 09:46:47 +0300 Subject: [PATCH 06/16] [OpenCL][CMake] Fix OpenCL tests compilation (#16394) [OpenCL] Fix OpenCL tests compilation Found a problem when you are in a different cmake project (not TVM) and you run TVM build with OpenCL tests, then `CMAKE_SOURCE_DIR` returns the path to the `CMakeList.txt` in the current project (not to the TVM) and in this case we will see the following error: `No SOURCES given to target: opencl-cpptest`. To be consistent with code style in `OpenCL.cmake`, I removed the usage of `CMAKE_SOURCE_DIR` variable. It also fixes the issue if TVM cmake was called from directory with another cmake project. --- cmake/modules/OpenCL.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/modules/OpenCL.cmake b/cmake/modules/OpenCL.cmake index 2dc1fc18f36c..ddcd1e4190d1 100644 --- a/cmake/modules/OpenCL.cmake +++ b/cmake/modules/OpenCL.cmake @@ -81,7 +81,7 @@ if(USE_OPENCL) if(Build_OpenCL_GTests) message(STATUS "Building OpenCL-Gtests") tvm_file_glob(GLOB_RECURSE OPENCL_TEST_SRCS - "${CMAKE_SOURCE_DIR}/tests/cpp-runtime/opencl/*.cc" + "tests/cpp-runtime/opencl/*.cc" ) add_executable(opencl-cpptest ${OPENCL_TEST_SRCS}) target_link_libraries(opencl-cpptest PRIVATE gtest_main tvm_runtime) From a7dd32cc168b434b591bc4bfe1f446e42c07e9de Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 15 Jan 2024 18:43:27 -0500 Subject: [PATCH 07/16] [DeviceAPI] Support querying total global memory (#16398) This PR introduces a new attribute for device backends: `total_global_memory`. This attributes returns the total available global memory on a device in bytes. Tested locally on CUDA/ROCm/Metal/OpenCL: ```python >>> import tvm >>> tvm.metal().total_global_memory 154618822656 ``` --- include/tvm/runtime/device_api.h | 1 + python/tvm/_ffi/runtime_ctypes.py | 14 ++++++++++++++ src/runtime/cuda/cuda_device_api.cc | 10 +++++++++- src/runtime/metal/metal_device_api.mm | 4 ++++ src/runtime/opencl/opencl_device_api.cc | 10 +++++++++- src/runtime/rocm/rocm_device_api.cc | 11 ++++++++++- src/runtime/vulkan/vulkan_device_api.cc | 4 ++++ 7 files changed, 51 insertions(+), 3 deletions(-) diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index e33539daddb7..9ff469b7c837 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -50,6 +50,7 @@ enum DeviceAttrKind : int { kApiVersion = 11, kDriverVersion = 12, kL2CacheSizeBytes = 13, + kTotalGlobalMemory = 14, }; #ifdef TVM_KALLOC_ALIGNMENT diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 7836f4224769..54e4d8f205a1 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -506,6 +506,20 @@ def l2_cache_size_bytes(self): """ return self._GetDeviceAttr(self.device_type, self.device_id, 13) + @property + def total_global_memory(self): + """Return size of the total global memory. + + Supported devices include CUDA/ROCm/Metal/OpenCL. + + Returns + ------- + total_global_memory : int or None + Return the global memory available on device in bytes. + Return None if the device does not support this feature. + """ + return self._GetDeviceAttr(self.device_type, self.device_id, 14) + def texture_spatial_limit(self): """Returns limits for textures by spatial dimensions diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 769f01063ff2..f493865e0d3c 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -106,12 +106,20 @@ class CUDADeviceAPI final : public DeviceAPI { } case kDriverVersion: return; - case kL2CacheSizeBytes: + case kL2CacheSizeBytes: { // Get size of device l2 cache size in bytes. int l2_size = 0; CUDA_CALL(cudaDeviceGetAttribute(&l2_size, cudaDevAttrL2CacheSize, dev.device_id)); *rv = l2_size; return; + } + case kTotalGlobalMemory: { + cudaDeviceProp prop; + CUDA_CALL(cudaGetDeviceProperties(&prop, dev.device_id)); + int64_t total_global_memory = prop.totalGlobalMem; + *rv = total_global_memory; + return; + } } *rv = value; } diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index f7c2976d2240..c4ffc8943c01 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -89,6 +89,10 @@ return; case kL2CacheSizeBytes: return; + case kTotalGlobalMemory: { + *rv = static_cast([devices[dev.device_id] recommendedMaxWorkingSetSize]); + return; + } } }; } diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index fb9adc27573d..96ec8ed69f2c 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -199,13 +199,21 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) *rv = std::string(value); break; } - case kL2CacheSizeBytes: + case kL2CacheSizeBytes: { // NOTE(Zihao): this API cannot reflect the real L2 cache size in both CUDA/AMD GPUs. cl_ulong value; OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_GLOBAL_MEM_CACHE_SIZE, sizeof(value), &value, nullptr)); *rv = static_cast(value); break; + } + case kTotalGlobalMemory: { + cl_ulong total_global_memory; + OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(total_global_memory), + &total_global_memory, nullptr)); + *rv = static_cast(total_global_memory); + return; + } } } diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index c2fb42ee360a..72f17ede5257 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -122,11 +122,20 @@ class ROCMDeviceAPI final : public DeviceAPI { } case kDriverVersion: return; - case kL2CacheSizeBytes: + case kL2CacheSizeBytes: { // Get size of device l2 cache size in bytes. int l2_size; ROCM_CALL(hipDeviceGetAttribute(&l2_size, hipDeviceAttributeL2CacheSize, device.device_id)); *rv = l2_size; + return; + } + case kTotalGlobalMemory: { + hipDeviceProp_t prop; + ROCM_CALL(hipGetDeviceProperties(&prop, device.device_id)); + int64_t total_global_memory = prop.totalGlobalMem; + *rv = total_global_memory; + return; + } } *rv = value; } diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index d67746856cfc..e02c9304e126 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -163,6 +163,10 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) case kL2CacheSizeBytes: break; + + case kTotalGlobalMemory: { + return; + } } } From 68be158d35d83de46f45a4d86caa5400a9dad6a9 Mon Sep 17 00:00:00 2001 From: Bohan Hou Date: Tue, 16 Jan 2024 03:18:22 -0500 Subject: [PATCH 08/16] [ROCm] Some fixes of ROCm codegen (#16404) - Handle tvm_thread_invariant as no op. - `llvm.amdgcn.ds.bpermute` requires i32 as its input, but it can handle all 32 bit types - ocml intrinsics lead to incorrect codegen when used with vectorization, remove it and use llvm intrinsics instead --- src/target/llvm/codegen_llvm.cc | 2 + src/target/llvm/intrin_rule_rocm.cc | 87 +++++++++++-------- src/tir/transforms/lower_thread_allreduce.cc | 2 +- .../codegen/test_target_codegen_rocm.py | 53 +++++++++++ 4 files changed, 108 insertions(+), 36 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 3d4d3def2411..9701a299f1d1 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1476,6 +1476,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::assume())) { llvm::Value* cond = MakeValue(op->args[0]); return builder_->CreateAssumption(cond); + } else if (op->op.same_as(builtin::tvm_thread_invariant())) { + return MakeValue(op->args[0]); } else { LOG(FATAL) << "unknown intrinsic " << op->op; } diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index d25126f5d828..0fbfade3354a 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -89,8 +89,14 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { index = self + delta; index = Select((self & (width - 1)) + delta >= width, self, index); } + // reinterprete var as int32 + bool is_int32 = var.dtype().is_int() && var.dtype().bits() == 32; + PrimExpr source = is_int32 ? var : reinterpret(DataType::Int(32), var); PrimExpr res = Call(DataType::Int(32), builtin::call_pure_extern(), - {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var}); + {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, source}); + if (!is_int32) { + res = reinterpret(var.dtype(), res); + } return res; } @@ -114,73 +120,84 @@ TVM_REGISTER_OP("tir.tvm_warp_shuffle_down") .set_attr("rocm.FLowerIntrinsic", DispatchShuffle); TVM_REGISTER_OP("tir.floor") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); TVM_REGISTER_OP("tir.ceil") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); TVM_REGISTER_OP("tir.round") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); TVM_REGISTER_OP("tir.nearbyint") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); TVM_REGISTER_OP("tir.trunc") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); TVM_REGISTER_OP("tir.fabs") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); -TVM_REGISTER_OP("tir.exp").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.exp").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); TVM_REGISTER_OP("tir.exp2") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); -TVM_REGISTER_OP("tir.exp10") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.exp10") +// .set_attr("rocm.FLowerIntrinsic", +// DispatchLLVMPureIntrin<::llvm::Intrinsic::exp10, 1>); -TVM_REGISTER_OP("tir.erf").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.erf").set_attr("rocm.FLowerIntrinsic", +// DispatchPureExternOCML); TVM_REGISTER_OP("tir.fma").set_attr( "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); -TVM_REGISTER_OP("tir.log").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.log").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); TVM_REGISTER_OP("tir.log2") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); TVM_REGISTER_OP("tir.log10") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); TVM_REGISTER_OP("tir.sqrt") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); -TVM_REGISTER_OP("tir.pow").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.pow").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); -TVM_REGISTER_OP("tir.tanh") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.tanh") +// .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_OP("tir.tan").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.tan").set_attr("rocm.FLowerIntrinsic", +// DispatchPureExternOCML); -TVM_REGISTER_OP("tir.cos").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.cos").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); -TVM_REGISTER_OP("tir.cosh") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.cosh") +// .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_OP("tir.sin").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.sin").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); -TVM_REGISTER_OP("tir.sinh") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.sinh") +// .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_OP("tir.atan") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.atan") +// .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); } // namespace llvm } // namespace codegen diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 1b2e8e9db04a..7094d6adaf3c 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -730,7 +730,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // rocm only supports 32 bit operands for shuffling at the moment if ((target_->kind->name == "rocm") && (std::any_of(types.begin(), types.end(), [](DataType ty) { - if ((ty.is_vector()) || !ty.is_int()) return true; + if (ty.is_vector()) return ty.bits() * ty.lanes() != 32; return ty.bits() != 32; }))) { return false; diff --git a/tests/python/codegen/test_target_codegen_rocm.py b/tests/python/codegen/test_target_codegen_rocm.py index 3e286f6ebff2..a0990c330f03 100644 --- a/tests/python/codegen/test_target_codegen_rocm.py +++ b/tests/python/codegen/test_target_codegen_rocm.py @@ -19,6 +19,7 @@ from tvm import te import numpy as np import unittest +from tvm.script import tir as T tx = te.thread_axis("threadIdx.x") ty = te.thread_axis("threadIdx.y") @@ -130,9 +131,61 @@ def check_rocm(dtype, n, lanes): check_rocm("float16", 64, 2) +@tvm.testing.requires_rocm +def test_rocm_warp_shuffle(): + @T.prim_func + def func( + A_handle: T.handle, + ): + A = T.match_buffer(A_handle, (32,), dtype="float32") + + for bx in T.thread_binding(1, thread="blockIdx.x"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + with T.block("test"): + A_local = T.alloc_buffer((1,), "float32", scope="local") + mask = T.alloc_buffer((1,), "uint32", scope="local") + t0 = T.alloc_buffer((1,), "float32", scope="local") + + A_local[0] = A[tx] + A_local[0] = T.tvm_warp_shuffle(mask[0], A_local[0], 0, 32, 32) + A[tx] = A_local[0] + + mod = tvm.build(func, target="rocm") + dev = tvm.rocm(0) + a = tvm.nd.array(np.random.uniform(size=(32,)).astype("float32"), dev) + mod(a) + tvm.testing.assert_allclose(a.numpy(), np.ones((32,)) * a.numpy()[0]) + + +@tvm.testing.requires_rocm +def test_rocm_vectorized_exp(): + @T.prim_func + def func( + A_handle: T.handle, + B_handle: T.handle, + ): + A = T.match_buffer(A_handle, (4,), dtype="float32") + B = T.match_buffer(B_handle, (4,), dtype="float32") + + for bx in T.thread_binding(1, thread="blockIdx.x"): + for tx in T.thread_binding(1, thread="threadIdx.x"): + with T.block("test"): + for i in T.vectorized(0, 4): + B[i] = T.exp2(A[i]) + + mod = tvm.build(func, target="rocm") + dev = tvm.rocm(0) + a = tvm.nd.array(np.ones((4,)).astype("float32"), dev) + b = tvm.nd.array(np.zeros((4,)).astype("float32"), dev) + mod(a, b) + tvm.testing.assert_allclose(b.numpy(), np.exp2(a.numpy())) + + if __name__ == "__main__": test_rocm_cross_thread_reduction() test_rocm_inf_nan() test_rocm_reduction_binding() test_rocm_copy() test_rocm_vectorize_add() + test_rocm_warp_shuffle() + test_rocm_vectorized_exp() From 3053f65da7b156763fdb3110835947113b327b39 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Tue, 16 Jan 2024 17:00:30 +0200 Subject: [PATCH 09/16] Add NVIDIA Hopper H100 target tag (#16407) Co-authored-by: Andrey Malyshev --- src/target/tag.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/target/tag.cc b/src/target/tag.cc index d8f341351d21..9caeec3b9205 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -155,6 +155,8 @@ TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536) .with_config("l2_cache_size_bytes", Integer(41943040)); +TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90", 49152, 65536) + .with_config("l2_cache_size_bytes", Integer(52428800)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536); From 12ad4fbcf43f3d73d757e69b1b9c02e45a291ffa Mon Sep 17 00:00:00 2001 From: TaoMiao Date: Wed, 17 Jan 2024 05:15:02 +0800 Subject: [PATCH 10/16] [Relay][Frontend][Torch] fix pytorch frontend not support logical or (#16400) add logical_or to relay pytorch frontend --- python/tvm/relay/frontend/pytorch.py | 7 +++++++ tests/python/frontend/pytorch/test_forward.py | 15 +++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b9650e6e9a9c..35f74544b833 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2672,6 +2672,12 @@ def logical_and(self, inputs, input_types): return _op.logical_and(lhs, rhs) + def logical_or(self, inputs, input_types): + lhs = _op.cast(inputs[0], "bool") + rhs = _op.cast(inputs[1], "bool") + + return _op.logical_or(lhs, rhs) + def nonzero(self, inputs, input_types, is_numpy_style=False): data = inputs[0] ret = _op.transform.argwhere(data) @@ -4238,6 +4244,7 @@ def create_convert_map(self): "aten::unbind": self.unbind, "aten::__and__": self.logical_and, "aten::logical_and": self.logical_and, + "aten::logical_or": self.logical_or, "aten::_shape_as_tensor": self.shape_as_tensor, "aten::nonzero": self.nonzero, "aten::nonzero_numpy": self.nonzero_numpy, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 9bf40cfcdd85..bf96c21399f0 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4882,6 +4882,21 @@ def test_fn(x, y): verify_model(test_fn, [a, b]) +def test_logical_or(): + """test_logical_or""" + + def test_fn(x, y): + return torch.logical_or(x, y) + + a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + verify_model(test_fn, [a, b]) + + a = torch.tensor([True, False, True]) + b = torch.tensor([True, False, False]) + verify_model(test_fn, [a, b]) + + def test_masked_select(): """test_masked_select""" From 7ef521fad626a87b3ce5a5060865618618bcd454 Mon Sep 17 00:00:00 2001 From: ysh329 Date: Thu, 18 Jan 2024 00:19:14 +0800 Subject: [PATCH 11/16] [COMMUNITY] Add new key for release signing (#16419) Co-authored-by: Star Yuan --- KEYS | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/KEYS b/KEYS index 3cee902acd32..32e451097045 100644 --- a/KEYS +++ b/KEYS @@ -642,3 +642,62 @@ WTgrESErlqNLN5ZTTW/1jBELJCfJKxgHUip+Yo6qNZoWwNLP1BaIcoA3miSG3DXf wS/UuN04NxDy7V6mPXE= =MTba -----END PGP PUBLIC KEY BLOCK----- +pub rsa4096 2024-01-15 [SC] + A4D9228E55761E665BF01CBB5CE869CB7DEC048C +uid [ultimate] Star Yuan (CODE SIGNING KEY) +sig 3 5CE869CB7DEC048C 2024-01-15 Star Yuan (CODE SIGNING KEY) +sub rsa4096 2024-01-15 [E] +sig 5CE869CB7DEC048C 2024-01-15 Star Yuan (CODE SIGNING KEY) + +-----BEGIN PGP PUBLIC KEY BLOCK----- + +mQINBGWlStcBEADaslyfbUNARhWftJoRAChoak0cFU6NxahhvyZfyTGtSuwuHNDD +2eyvhnDIaYXVClxoNgikiQ5Nkd1jtbA4rFCw6Pdbq+98fkpcr8N4o+jlbpu6Ff3j +dJ2Qu000MV5qe9FZ4QasdfglJElvizgfNbJv/Fz1ERl/BS1U0c7lyQF9jGGh7EY2 +1y+JFp5OMG6A9SpfaOd+iOw5/cfCQk8+sHQC4dp3hOJPK4NLvjotK+hlOhRsF7gU +goYYT2IP56kPQb6U/Uiv4/R6HbKugzqSMl6BMwAb9uG6UX0xUfAA8ciHoaITCJCQ +9e/jGWnDnqYlAMNqLkHEmW7THxJ3hHXcac/Z1C3PeLDJU0rpTxDcjuYkM5jFCu7H +TgT7lWBP/PyAAVSsLqMQbLJOWm0a14tb/oRoeYr/B2prIbJY5qJBM1nherKGMg0G +7Oqugo6A1VqgUxg7Chj73PledaNwvm5Lxpl6D+wPDSifhlz0vnwOCMoOon0pTjK4 +DXDEXnEXZtzkZgXI6g7AkVyt0gkqyUi+01ibmlBfcVHh3PVvU4oNdkaywQd5s29R +DsA4WOqt9cLv+iqIzM1juygfR6ooA1jHDIyIPmmC/kOrcxKXEFvIGXDDCbXAvdXc +uXgZeZqI3pbKjQaU3fF8HwJ956HTM8rywtVGH9BWRl/i6qn5sq9CcukcuQARAQAB +tDBTdGFyIFl1YW4gKENPREUgU0lHTklORyBLRVkpIDx5c2gzMjlAYXBhY2hlLm9y +Zz6JAk4EEwEKADgWIQSk2SKOVXYeZlvwHLtc6GnLfewEjAUCZaVK1wIbAwULCQgH +AgYVCgkICwIEFgIDAQIeAQIXgAAKCRBc6GnLfewEjBAiD/0cfaYfQ0DL7CPsP0lS +yezPDDTnDPIo//G1cuSYG0gnXQ1SpbJSzDE7deew+P506/sWFneOY5Kuv6DuSE8J +nM6vv1EYR4/9x/XstA4F04lQPngKKBV+UKrWj8zIA2Drn345Ece1150bWvrUD7mT ++ps1gfe8SGYpOmR/kRc8qra2zizcWBC1Dl4qd+RcY7Ac6Cu3G/JG2KvZnrUSVev9 +nzSl2V0JtFVIla2odSJqv0Zdj5E2vLvQd3Dxbf3BODCdL3iQqxrQhj+0T3QLEhPg +y2XOtqW7a96XosoQ44wUiHaS5LwFViG8LoiPADtSdXYb8m4FtMfB8t4mzXVqBjpz +2csMqOnNvo7bctfpJkjM14UKib39MR2wUv9fD6Qa+OAAIeXGTQH+wlXmlYjji9+A +4tgq/+d75qUC/tyHSgbZLNXobHF8v77g60cBvFXVL02W53xhVDZP4gwu5iSSN8BJ +a2hqwo4UO53mRUNkwFZONYxJE7MhLl22r08eu0xNYhoGtpHzDVoyHg26+2FUgFDd +TNsdqjMyJ+3GXEE3PdKVDTj9To+RoHLuCczk5uvtFYGhseRwIWbVhmTLKUL+wgSa ++b90slkv+CBJvLjvKbVCmCLXwiH8Cx+MZSu0oM5v8fbHuWOhkb7bJd1V+U7qV/OA +CCqBICt64F+ooQ0oEdC0oLvr2LkCDQRlpUrXARAA1DKsF2ZNUdPIn4VcsjRk/+qF +13VC9SaqMp+J+8m1XTIeXdr27uUa2vT4j8pAM4gwMVkpEqE0rmHK+S1SeEAlcizC +Bvp7vvso/glcOg9Sgt9PXvvEDPL/Hnsn1+3YX+Gye4cOTiDDgVW1RKcgGj9Xsir+ +5BS9Secj5CGo92cuaqIo/mMjxGlsuW/LvTU5qQhz7aOaBibe5EHPlGMqM6XJN0BZ +MHRfBiGDs2n/egMnTPL0JcTlAeird+yxDPULKzhQWkd8rfQKpwcRiY6IcYFHlWdM +VhZkXNRrxh6+q3rR7FKmxlvG/12YyT6Y1BocGLgROzKIeoEp+6vsU5LJ90jy82ig +oGSHwNjm2RRukjV3eebovl1dCo6IaI/j4idCv7NlcBnln/Unk4YOZbneMT5r+3Zy +Q4azLB8KHfHOrUwAxRAGPygdLtqbjs4mF45HDe6h3IOVoiOQlZNpesrwEumlK+Il +taU0T8hfxyMpIcTLUZpIddSxo0sVby2XZ+z00En3JvtqbpRcfA87thxpsE7uHxwT +YT8mPPDxo1R4I4LSzsDnekD8EB/7woz4n5I1RBoPB1LSoo0B2os+4vHGkiwZ0TN0 +ICcUYdM623Bv2wJQbVKEDvwjHZTkotjLx7R2lyqMRwFYrMXHxevOfbARJQCqrcY2 +ouLzQme9rE5MPQbKj2cAEQEAAYkCNgQYAQoAIBYhBKTZIo5Vdh5mW/Acu1zoact9 +7ASMBQJlpUrXAhsMAAoJEFzoact97ASMNsIP/3tlsvwUVfy19lUjxWT4rPw2GGz8 +lbPiaetgigK1F1rlzYnIVo32Fcj/GNNwWEdxxEzeaQR/AJmZLWB8sBDThoTGeSDK +fjKXeDjZh+ElpIKWyk7f3ddHN2TpBz698kZ7fYCciRE9T4d3xgbqx2rCfupxUFSj +lxLFRkasByJnLdAZI50NZjW838IHMaGsvgbWEqRuvKZOES6gFhrK1NTSxj5iuiHk +Uxj1KzMhOW+m1eZ0pQcCVXJDY6KYhmrZzw9q6kzSO9ukmS5yRf0EnD7Fsca4iIXP +Y28xs3zBxYHV4IGU1PtcIwNewmTnjnEy0apHPz0zDplHi1meXuhA7bBMjs/AouJg +6FIDNSQqDuFXufqvVQ6LZZgob+LklMAoGcka4/5ZLPjipj5SWNeZZunJujSqWK7f +KJaIfn7ILXqxjaTFrjBN3cm60rO1+zEektrjtWMmSBn0L76pY2ucenrqewruYYdD +12VQra/6QAS5R0HG8gzOfsZcrHaiIuLoTbsOgnqLVcdb9lO7f3oMbKPwejZ5yhyz +SraXHvmixlhf4uUYwsWyhw3UgHrv1psB8Z9NfdH9/T2BvRg0qy6ZmI0n0OagPNgz +v+SZrqrWkSjyPdl6j7x8EmePfNidqw/CnncYI2rEVSmP28W0Uhg5JLgroGYmycv6 +HeZaRpYvkV8UNmnE +=BtHq +-----END PGP PUBLIC KEY BLOCK----- From a5e883e8465e11221d3f22d6ef2f61a1bfa5d1f2 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Thu, 18 Jan 2024 12:38:57 +0530 Subject: [PATCH 12/16] [RUNTIME][CLML] Fix for Softmax op for 4D tensors (#16328) Fixed the softmax layer for 4D tensors to support for NCHW and NHWC layout types. Enabled relevant test cases for softmax layer --- python/tvm/relay/op/contrib/clml.py | 3 +- src/runtime/contrib/clml/clml_runtime.cc | 62 ++++++++++++---- tests/python/contrib/test_clml/test_ops.py | 86 ++++++++++++---------- 3 files changed, 98 insertions(+), 53 deletions(-) diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py index 14dd35a3cb4b..53b022c347b4 100644 --- a/python/tvm/relay/op/contrib/clml.py +++ b/python/tvm/relay/op/contrib/clml.py @@ -437,7 +437,8 @@ def check_pad_op(extract): def check_softmax_op(extract): call = extract - if len(call.args[0].checked_type.shape) > 2: + # supports 2D and 4D tensors + if len(call.args[0].checked_type.shape) not in [2, 4]: return False return True diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index aa1e2b82b657..8e69cb8bd13b 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -511,6 +511,7 @@ class CLMLRuntime : public JSONRuntimeBase { /*! * \brief Create an CLML tensor from JSON node entry. Lookup storage map before creation. + * Update input placeholder for NHWC layout * * \param nid The node index of graph JSON. * \param shape shape information of tensor @@ -528,15 +529,22 @@ class CLMLRuntime : public JSONRuntimeBase { uint32_t eid = EntryID(nid, 0); node_data = data_entry_[eid]->data; } + auto clml_tensor = MakeCLMLTensorFromJSONNode(node, layout, dtype, node_data, shape); + this->layer_.storage_map.insert({nid, std::make_pair(clml_tensor, node)}); if ("input" == node.GetOpType()) { this->layer_.inputs.insert({nid, this->layer_.storage_map[nid].first}); // Input copy placeholder Tensor - this->layer_.in_placeholder.insert( - {nid, MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, dtype, node_data, - shape)}); + if (layout == CL_TENSOR_LAYOUT_OPTIMAL_QCOM) { + this->layer_.in_placeholder.insert( + {nid, MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, dtype, node_data, + shape)}); + } else { + this->layer_.in_placeholder.insert( + {nid, MakeCLMLTensorFromJSONNode(node, layout, dtype, node_data, shape)}); + } } return clml_tensor; @@ -559,6 +567,7 @@ class CLMLRuntime : public JSONRuntimeBase { const auto& node = nodes_[nid]; if ("nn.dense" == node.GetOpName()) CreateDenseLayerTensor(&layer_, node, nid); if ("nn.batch_matmul" == node.GetOpName()) CreateBatchMatmulLayerTensor(&layer_, node, nid); + if ("nn.softmax" == node.GetOpName()) CreateSoftmaxLayerTensor(&layer_, node, nid); } for (nid = 0; nid < nodes_.size(); ++nid) { @@ -1092,6 +1101,37 @@ class CLMLRuntime : public JSONRuntimeBase { return; } + /*! + * \brief Create a Softmax layer Tensors with supported layout. + * \param layer The CLML layer to build. Containing inputs, outputs and the CLML function. + * \param node The JSON representation of the operator. + * \param nid The node index of JSON graph node, which points to this operator. + */ + + void CreateSoftmaxLayerTensor(CachedLayer* layer, const JSONGraphNode& node, size_t nid) { + cl_ml_tensor_layout_qcom layout; + cl_int result = 0; + cl_ml_op_qcom op = nullptr; + DLDataType tvm_dtype = node.GetOpDataType()[0]; + cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); + auto out_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]); + int axis = std::stoi(node.GetAttr>("axis")[0]); + // enabling NHWC layout && NCHW layout for 4D, basis the axis value + if (out_dims.h >= 1 && out_dims.w >= 1) { + if (axis == 3 || axis == -1) { + layout = CL_TENSOR_LAYOUT_NHWC_QCOM; + } else { + layout = CL_TENSOR_LAYOUT_NCHW_QCOM; + } + } else { // default layout for 2D + layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM; + } + auto output = MakeCLMLTensorFromJSONEntry(nid, {}, layout, cl_dtype); + auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {}, layout, cl_dtype); + + return; + } + /*! * \brief Create a SoftMax layer. * @@ -1100,24 +1140,20 @@ class CLMLRuntime : public JSONRuntimeBase { * \param nid The node index of JSON graph node, which points to this operator. */ void CreateSoftMaxLayer(CachedLayer* layer, const JSONGraphNode& node, size_t nid) { + cl_ml_tensor_layout_qcom layout; + cl_softmax_mode_qcom mode = CL_SOFTMAX_MODE_SPATIAL_QCOM; cl_int result = 0; cl_ml_op_qcom op = nullptr; DLDataType tvm_dtype = node.GetOpDataType()[0]; cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, cl_dtype); - auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {}, - CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); - auto out_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]); - auto output = MakeCLMLTensorFromJSONEntry(nid, {out_dims.n, out_dims.c, 1, 1}, - CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); - - cl_ml_op_softmax_desc_qcom softmax_desc = {CL_SOFTMAX_ALGORITHM_ACCURATE_QCOM, - CL_SOFTMAX_MODE_INSTANCE_QCOM, cl_arithmetic_mode}; - + auto output = MakeCLMLTensorFromJSONEntry(nid, {}, layout, cl_dtype); + auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {}, layout, cl_dtype); + cl_ml_op_softmax_desc_qcom softmax_desc = {CL_SOFTMAX_ALGORITHM_ACCURATE_QCOM, mode, + cl_arithmetic_mode}; result = CLML_INTF->clCreateMLOpSoftmaxQCOM(CLML_CTX, nullptr, &softmax_desc, input->tensor, output->tensor, &op, layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "SoftMax Error:" << result; - layer->function.push_back(op); return; } diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py index 58365bf4291a..3d89994126af 100644 --- a/tests/python/contrib/test_clml/test_ops.py +++ b/tests/python/contrib/test_clml/test_ops.py @@ -280,9 +280,9 @@ def test_conv2d(remote, dtype, target, trials, executor_type): has_activation=composite[2], ) outputs = _build_and_run_network(remote, func, params, inputs, target, executor_type) - out_rtol = 1e-1 if dtype == "float16" else 1e-5 + out_tol = 1e-1 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, out_channels) exp_codegen = _get_conv_expected_codegen( @@ -373,9 +373,9 @@ def test_conv2d_transpose(remote, dtype, target, trials, executor_type): func = relay.Function([x, w], y) mod = IRModule.from_expr(func) outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type) - out_rtol = 1e-1 if dtype == "float16" else 1e-5 + out_tol = 1e-1 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) args = ( dshape, @@ -425,9 +425,9 @@ def test_batchnorm(remote, dtype, target, trials, executor_type): "a": input_arr, } outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type) - out_rtol = 1e-3 if dtype == "float16" else 1e-5 + out_tol = 1e-3 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) exp_codegen = [ { @@ -485,9 +485,9 @@ def test_concat(remote, dtype, target, trials, executor_type): func = relay.concatenate((a, b), axis=1) outputs = _build_and_run_network(remote, func, params, inputs, target, executor_type) - out_rtol = 1e-2 if dtype == "float16" else 1e-5 + out_tol = 1e-2 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) exp_codegen = [ @@ -601,9 +601,9 @@ def test_pool(remote, dtype, target, trials, executor_type): func = relay.nn.avg_pool2d(a, pool_size=pool_size, strides=stride, padding=padding) outputs = _build_and_run_network(remote, func, params, inputs, target, executor_type) - out_rtol = 1e-2 if dtype == "float16" else 1e-5 + out_tol = 1e-2 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) args = (input_shape, pool_size, stride, padding, pooling_type, dtype) exp_codegen = _get_pool_expected_codegen(*args) @@ -690,9 +690,9 @@ def _get_model(x_shape, k_shape, has_bias=False): def _verify(out, params, inputs, exp_codegen): mod = IRModule.from_expr(out) outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type) - out_rtol = 1e-1 if dtype == "float16" else 1e-5 + out_tol = 1e-1 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) verify_codegen(remote, mod, params, exp_codegen, target) @@ -718,9 +718,9 @@ def _get_model(a_shape, b_shape, op_func): def _verify(out, params, inputs): mod = IRModule.from_expr(out) outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type) - out_rtol = 1e-2 if dtype == "float16" else 1e-5 + out_tol = 1e-2 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) exp_codegen = [ { @@ -776,9 +776,9 @@ def _get_model(a_shape, op): def _verify(out, params, inputs): mod = IRModule.from_expr(out) outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type) - out_rtol = 1e-2 if dtype == "float16" else 1e-5 + out_tol = 1e-2 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) exp_codegen = [ @@ -823,12 +823,11 @@ def _get_model(a_shape, block_size): def _verify(out, params, inputs): mod = IRModule.from_expr(out) outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type) - out_rtol = 1e-2 if dtype == "float16" else 1e-5 + out_tol = 1e-2 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) - # Check to make sure these ops are offloaded to CLML instead of TVM. exp_codegen = [ { "attrs": { @@ -877,12 +876,11 @@ def _get_model(a_shape, scale, align_corners): def _verify(out, params, inputs): mod = IRModule.from_expr(out) outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type) - out_rtol = 1e-2 if dtype == "float16" else 1e-5 + out_tol = 1e-2 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) - # Check to make sure these ops are offloaded to CLML instead of TVM. exp_codegen = [ { "attrs": { @@ -944,12 +942,11 @@ def _get_model(a_shape, b_shape, a_transpose, b_transpose): def _verify(out, params, inputs): mod = IRModule.from_expr(out) outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type) - out_rtol = 1e-1 if dtype == "float16" else 1e-5 + out_tol = 1e-1 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) - # Check to make sure these ops are offloaded to CLML instead of TVM. exp_codegen = [ { "attrs": { @@ -1026,20 +1023,30 @@ def _get_model(a_shape, axis): params = {} return out, params, inputs, axis - def _verify(out, params, inputs, axis): + def _verify(out, params, inputs, axis, out_tol): mod = IRModule.from_expr(out) outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type) - out_rtol = 1e-1 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].numpy(), rtol=out_tol, atol=out_tol ) args = (inputs, dtype, outputs[0].shape, axis) exp_codegen = _get_softmax_exp_codegen(*args) verify_codegen(remote, mod, params, exp_codegen, target) - _verify(*(_get_model((1, 5), 1))) - _verify(*(_get_model((1, 1000), 1))) - _verify(*(_get_model((1, 3), 1))) + # 2D Tensor TEST CASES + _verify(*(_get_model((1, 5), 1)), 1e-3) + _verify(*(_get_model((1, 16), 1)), 1e-3) + _verify(*(_get_model((1, 1000), -1)), 1e-3) + + # 4D Tensor TEST CASES layout = NCHW + _verify(*(_get_model((1, 100, 64, 100), 1)), 1e-3) + _verify(*(_get_model((1, 64, 64, 64), 1)), 1e-3) + _verify(*(_get_model((1, 5, 3, 4), 1)), 1e-3) + + # 4D Tensor TEST CASES layout = NHWC + _verify(*(_get_model((1, 64, 100, 100), 3)), 1e-1) + _verify(*(_get_model((1, 100, 100, 100), 3)), 1e-1) + _verify(*(_get_model((1, 64, 5, 32), -1)), 1e-1) @pytest.mark.parametrize("dtype", ["float32", "float16"]) @@ -1066,9 +1073,9 @@ def _verify(in_shape, scale_h, scale_w): ) mod = IRModule.from_expr(func) outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type) - out_rtol = 1e-2 if dtype == "float16" else 1e-5 + out_tol = 1e-2 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) exp_codegen = [ { @@ -1124,9 +1131,9 @@ def _verify(shape, newshape): params = {} mod = IRModule.from_expr(out) outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type) - out_rtol = 1e-3 if dtype == "float16" else 1e-5 + out_tol = 1e-3 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) exp_codegen = [ { @@ -1223,9 +1230,9 @@ def test_pool_global(remote, dtype, target, executor_type, trials): func = relay.nn.global_avg_pool2d(a) mod = IRModule.from_expr(func) outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type) - out_rtol = 1e-3 if dtype == "float16" else 1e-5 + out_tol = 1e-3 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) args = (input_shape, pooling_type, dtype, outputs[0].shape) exp_codegen = _get_pool_global_expected_codegen(*args) @@ -1241,6 +1248,7 @@ def _get_model(a_shape): # Defined the test case with unary operator # Single batch_flatten op is failing in native OpenCL # Empty TVM mod in VM doesn't pick appropriate cross compiler + np.random.seed(0) out = relay.nn.relu(a) out = relay.nn.batch_flatten(out) inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype))} @@ -1250,9 +1258,9 @@ def _get_model(a_shape): def _verify(out, params, inputs): mod = IRModule.from_expr(out) outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type) - out_rtol = 1e-3 if dtype == "float16" else 1e-5 + out_tol = 1e-3 if dtype == "float16" else 1e-5 tvm.testing.assert_allclose( - outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol + outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol ) exp_codegen = [ { From e1c430c7e3180b65e234cf39b3f1de6e71825f55 Mon Sep 17 00:00:00 2001 From: TaoMiao Date: Fri, 19 Jan 2024 02:40:36 +0800 Subject: [PATCH 13/16] [Relay][Frontend][Torch] fix pytorch frontend linspace op (#16417) fix pytorch frontend linspace op --- python/tvm/relay/frontend/pytorch.py | 2 +- tests/python/frontend/pytorch/test_forward.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 35f74544b833..8594ee0e0614 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -918,7 +918,7 @@ def linspace(self, inputs, input_types): # Find the spacing between values as step if step != 1: step = (stop - start) / (step - 1) - stop = stop + step + stop = stop + (step / 2) else: stop = start + step diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index bf96c21399f0..6d07f081e9ac 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3632,6 +3632,10 @@ class Linspace8(Module): def forward(self, *args): return torch.linspace(1, 2, 1, dtype=torch.int16) + class Linspace9(Module): + def forward(self, *args): + return torch.linspace(0, 8, 10) + verify_model(Linspace1().float().eval()) verify_model(Linspace2().float().eval()) verify_model(Linspace3().float().eval()) @@ -3640,6 +3644,7 @@ def forward(self, *args): verify_model(Linspace6().float().eval()) verify_model(Linspace7().float().eval()) verify_model(Linspace8().float().eval()) + verify_model(Linspace9().float().eval()) @tvm.testing.uses_gpu From 827beed0d6e130c3b3854ee27cbce632f7917867 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 18 Jan 2024 16:35:32 -0800 Subject: [PATCH 14/16] [CMake] Enable cuda lang if USE_CUDA is on (#16426) --- cmake/modules/CUDA.cmake | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index bbbf6b89ba2e..e66d3ba9fcd6 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -29,6 +29,7 @@ if(USE_CUDA) message(FATAL_ERROR "Cannot find CUDA, USE_CUDA=" ${USE_CUDA}) endif() message(STATUS "Build with CUDA ${CUDA_VERSION} support") + enable_language(CUDA) tvm_file_glob(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc) list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_SRCS}) list(APPEND COMPILER_SRCS src/target/opt/build_cuda_on.cc) @@ -62,8 +63,6 @@ if(USE_CUDA) if(USE_THRUST) message(STATUS "Build with Thrust support") - cmake_minimum_required(VERSION 3.13) # to compile CUDA code - enable_language(CUDA) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") tvm_file_glob(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu) list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC}) @@ -72,8 +71,6 @@ if(USE_CUDA) if(USE_CURAND) message(STATUS "Build with cuRAND support") message(STATUS "${CUDA_CURAND_LIBRARY}") - cmake_minimum_required(VERSION 3.13) # to compile CUDA code - enable_language(CUDA) tvm_file_glob(GLOB CONTRIB_CURAND_SRC_CC src/runtime/contrib/curand/*.cc) tvm_file_glob(GLOB CONTRIB_CURAND_SRC_CU src/runtime/contrib/curand/*.cu) list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CURAND_LIBRARY}) From 614a7a9e31e00ddd9442b218a5b7042f3a49e9b1 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 18 Jan 2024 19:36:16 -0500 Subject: [PATCH 15/16] [CI][WASM] Update emsdk and nodejs version (#16420) This PR updates the emsdk and nodejs version of docker. --- docker/install/ubuntu_install_emscripten.sh | 4 ++-- docker/install/ubuntu_install_nodejs.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docker/install/ubuntu_install_emscripten.sh b/docker/install/ubuntu_install_emscripten.sh index 87c95f2936bf..98331d2cbdc2 100755 --- a/docker/install/ubuntu_install_emscripten.sh +++ b/docker/install/ubuntu_install_emscripten.sh @@ -23,5 +23,5 @@ set -o pipefail cd / git clone https://github.com/emscripten-core/emsdk.git cd emsdk -./emsdk install 3.1.30 -./emsdk activate 3.1.30 +./emsdk install 3.1.51 +./emsdk activate 3.1.51 diff --git a/docker/install/ubuntu_install_nodejs.sh b/docker/install/ubuntu_install_nodejs.sh index b295d9e3e41d..6d9ef3f5ded8 100755 --- a/docker/install/ubuntu_install_nodejs.sh +++ b/docker/install/ubuntu_install_nodejs.sh @@ -28,5 +28,5 @@ apt-install-and-clear -y curl # The node install script fetched and executed here will update the # apt source list, hence the second apt-get update --fix-missing is necessary. -curl -s -S -L https://deb.nodesource.com/setup_14.x | bash - +curl -s -S -L https://deb.nodesource.com/setup_16.x | bash - apt-install-and-clear -y nodejs From 6e8115453891ad5caa74977dba0945b72bfd94fd Mon Sep 17 00:00:00 2001 From: quic_rutkoor Date: Thu, 18 Jan 2024 20:47:54 -0800 Subject: [PATCH 16/16] Loop-Partition Scheduling primitive --- include/tvm/tir/schedule/schedule.h | 11 + python/tvm/tir/schedule/schedule.py | 108 ++++ src/tir/schedule/concrete_schedule.cc | 137 ++++++ src/tir/schedule/concrete_schedule.h | 2 + src/tir/schedule/primitive.h | 12 + src/tir/schedule/primitive/get_block_loop.cc | 2 +- .../schedule/primitive/loop_transformation.cc | 293 +++++++++++ src/tir/schedule/schedule.cc | 2 + src/tir/schedule/traced_schedule.cc | 21 + src/tir/schedule/traced_schedule.h | 2 + .../test_tir_schedule_partition.py | 460 ++++++++++++++++++ 11 files changed, 1049 insertions(+), 1 deletion(-) create mode 100644 tests/python/tir-schedule/test_tir_schedule_partition.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 273912ed1f8f..457e6f28951d 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -354,6 +354,17 @@ class ScheduleNode : public runtime::Object { */ virtual Array Split(const LoopRV& loop_rv, const Array>& factors, bool preserve_unit_iters = true) = 0; + /*! + * \brief Partition the loops into sequence of multiple loops + * 1) The loop can't have annotation or thread binding. + * \param loop_rv The loop to be partition + * \param factors The positive integers, and at most one of which is `NullOpt`, which means + * that factor is inferred. + * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings + * \return The new loops after partition + */ + virtual Array LoopPartition(const LoopRV& loop_rv, const Array>& factors, + bool preserve_unit_iters = true) = 0; /*! * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. * It requires: diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 23b000c09015..b871c91987df 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -813,6 +813,114 @@ def after_split(a: T.handle, b: T.handle) -> None: ) ) + @type_checked + def loop_partition( + self, + loop: LoopRV, + factors: List[Union[int, ExprRV, None]], + preserve_unit_iters: bool = True, + ) -> List[LoopRV]: + """Partition a loop into a list of consecutive loops. It requires: + 1) The loop can't have annotation or thread binding. + Predicates may be added to ensure the total loop numbers keeps unchanged. + In `factors`, at most one of the factors can be None, + which will be automatically inferred. + + Parameters + ---------- + loop : LoopRV + The loop to be partition + + factors: List[Union[int, ExprRV, None]] + The partitioning factors + Potential inputs are: + - None + - ExprRV + - Positive constant integers + + preserve_unit_iters : bool + Whether or not to preserve unit iterators in block bindings + + Returns + ------- + partition_loops : List[LoopRV] + The new loops after partition + + Examples + -------- + + Before partition, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_partition(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do partition: + + .. code-block:: python + + sch = tir.Schedule(before_partition) + i, j = sch.get_loops(sch.get_block("B")) + sch.partition(i, factors=[2, 64]) + print(sch.mod["main"].script()) + + After applying partition, the IR becomes: + + .. code-block:: python + + def after_partition(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + # the original loop is partition into 3 loops + with T.block("root"): + T.reads() + T.writes() + with T.block("B_i_common"): + T.reads() + T.writes() + with T.block("B_i0_partition"): + T.reads() + T.writes() + for i0, j in T.grid(2, 128): + with T.block("B_i0"): + vi, vj = T.axis.remap("SS", [i0, j]) + T.reads(A[0:2, 0:128]) + T.writes(B[0:2, 0:128]) + B[vi, vj] = A[vi, vj] * T.float32(2) + with T.block("B_i1_partition"): + T.reads() + T.writes() + for i1 in range(2, 66): + for j in range(128): + with T.block("B_i1"): + vi, vj = T.axis.remap("SS", [i1, j]) + T.reads(A[2:66, 0:128]) + T.writes(B[2:66, 0:128]) + B[vi, vj] = A[vi, vj] * T.float32(2) + with T.block("B_partition_2"): + T.reads() + T.writes() + for i2 in range(66, 128): + for j in range(128): + with T.block("B_i2"): + vi, vj = T.axis.remap("SS", [i2, j]) + T.reads(A[66:128, 0:128]) + T.writes(B[66:128, 0:128]) + B[vi, vj] = A[vi, vj] * T.float32(2) + """ + return list( + _ffi_api.ScheduleLoopPartition( # type: ignore # pylint: disable=no-member + self, loop, factors, preserve_unit_iters + ) + ) + @type_checked def reorder(self, *ordered_loops: List[LoopRV]) -> None: """ diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 343fb7617886..88438b0108d3 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -500,6 +500,143 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, return CreateRV(results); } +Array ConcreteScheduleNode::LoopPartition(const LoopRV& loop_rv, + const Array>& factor_rvs, + bool preserve_unit_iters) { + class SymbolicShapeError : public ScheduleError { + public: + explicit SymbolicShapeError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The min and extent values of the loop are required to be known at " + "compile time. However, dynamic shape has been detected."; + } + + String DetailRenderTemplate() const final { + return "Detected dynamic shape in either min or extent of a loop {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; + }; + + class NotSingleInferFactorError : public ScheduleError { + public: + explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} + + String FastErrorString() const final { + return "ScheduleError: only one factor can be specified as -1 or none"; + } + + String DetailRenderTemplate() const final { + return "Only one factor can be specified as -1 or none"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; + }; + + class WrongFactorSumError : public ScheduleError { + public: + explicit WrongFactorSumError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The sum of factors is larger than or equal to the extent of " + "loop"; + } + + String DetailRenderTemplate() const final { + return "The sum of factors is not larger than or equal to the extent of loop {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; + }; + + class NonPositiveFactorError : public ScheduleError { + public: + explicit NonPositiveFactorError(IRModule mod, int64_t factor, size_t idx) + : mod_(std::move(mod)), factor_(factor), idx_(idx) {} + + String FastErrorString() const final { + return "ScheduleError: All the constant factors are required to be positive. However, some " + "constant input factor is zero or negative."; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "All the constant factors are required to be positive. However, the factor at position " + << idx_ << " is " << factor_; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + private: + IRModule mod_; + int64_t factor_; + size_t idx_; + }; + + // Prepare for the loop_partitioning + StmtSRef loop_sref = this->GetSRef(loop_rv); + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); + Array factors; + factors.reserve(factor_rvs.size()); + int infer_index = -1; + PrimExpr tot_length = 0; + Array results; + TVM_TIR_SCHEDULE_BEGIN(); + if (!is_const_number(loop->min) || !is_const_number(loop->extent)) { + throw SymbolicShapeError(state_->mod, GetRef(loop)); + } + // infer factor if needed and check validity of factors + for (size_t i = 0; i < factor_rvs.size(); i++) { + if (!factor_rvs[i].defined()) { + factors.push_back(Integer(-1)); + if (infer_index != -1) { + throw NotSingleInferFactorError(state_->mod); + } + infer_index = i; + } else { + PrimExpr factor = this->Get(factor_rvs[i].value()); + if (is_const_int(factor) && !is_positive_const(factor)) { + throw NonPositiveFactorError(state_->mod, factor.as()->value, i); + } + if (factor.dtype().bits() > loop->extent.dtype().bits()) { + factor = cast(loop->extent.dtype(), factor); + } + factors.push_back(factor); + tot_length += factor; + } + } + if (this->analyzer_->CanProve(tot_length >= loop->extent)) { + throw WrongFactorSumError(state_->mod, GetRef(loop)); + } + if (infer_index != -1) { + // if there is a 'None' in the factor list, 'None' becomes the difference between the extent and + // the sum of the factors excluding 'None' specified in the partition directive. + factors.Set(infer_index, loop->extent - tot_length); + } + for (size_t i = 1; i < factor_rvs.size(); i++) { + factors.Set(i, factors[i] + factors[i - 1]); + } + if (infer_index == -1) { + factors.push_back(loop->extent); + } + results = tir::LoopPartition(state_, loop_sref, factors, preserve_unit_iters); + TVM_TIR_SCHEDULE_END("loop_partition", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(results); +} + void ConcreteScheduleNode::Reorder(const Array& ordered_loop_rvs) { TVM_TIR_SCHEDULE_BEGIN(); tir::Reorder(state_, GetSRefs(ordered_loop_rvs)); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index a6c47070c8df..a510b0bc8683 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -109,6 +109,8 @@ class ConcreteScheduleNode : public ScheduleNode { LoopRV Merge(const Array& loop_rvs) override; Array Split(const LoopRV& loop_rv, const Array>& factors, bool preserve_unit_iters) override; + Array LoopPartition(const LoopRV& loop_rv, const Array>& factors, + bool preserve_unit_iters) override; void Reorder(const Array& ordered_loop_rvs) override; void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) override; LoopRV AddUnitLoop(const BlockRV& block_rv) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 02fb982f5ed9..dc4bfdc1a97d 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -209,6 +209,18 @@ Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array& factors, bool preserve_unit_iters); +/*! + * Partition a loop into a list of consecutive loops. It requires: + * 1) The loop can't have annotation or thread binding. + * \param self The state of the schedule + * \param loop_sref The sref to the loop being partition + * \param factors The partitioning factors + * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings + * \return An array of srefs to the loops after partitioning + */ +TVM_DLL Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref, + const Array& factors, bool preserve_unit_iters); + /*! * \brief Merge a list of loops into one. The loops under their LCA requires: * 1) Under the same scope diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 588770d968ef..f6afc71218e7 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -238,7 +238,7 @@ struct GetOutputBlocksTraits : public UnpackedInstTraits static String UnpackedAsPython(Array outputs, String block_rv) { PythonAPICall py("get_output_blocks"); - py.Input("block", block_rv); + py.Input("scope_block", block_rv); py.OutputList(outputs); return py.Str(); } diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index a6b97bf17906..9d50860661e4 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -454,6 +454,258 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array return result_srefs; } +class BufferIndicesMapExtractor : public StmtExprVisitor { + public: + explicit BufferIndicesMapExtractor(Var loop_var) : loop_var_(loop_var) {} + + static Map> Extract(Var loop_var, Block& block) { + BufferIndicesMapExtractor extractor(loop_var); + extractor(std::move(block->body)); + return extractor.buffer_indices_map; + } + + private: + void VisitStmt_(const BufferStoreNode* store) final { + Array indices; + bool check_ = false; + for (int i = 0; i < store->indices.size(); i++) { + const VarNode* var_node = store->indices[i].as(); + if (var_node == nullptr) { + check_ = true; + break; + } + indices.push_back(var_node->name_hint); + } + if (buffer_indices_map.find(store->buffer->name) == buffer_indices_map.end() && !check_) + buffer_indices_map.Set(store->buffer->name, indices); + StmtExprVisitor::VisitStmt_(store); + } + + void VisitExpr_(const BufferLoadNode* load) final { + Array indices; + bool check_ = false; + for (int i = 0; i < load->indices.size(); i++) { + const VarNode* var_node = load->indices[i].as(); + if (var_node == nullptr) { + check_ = true; + break; + } + indices.push_back(var_node->name_hint); + } + if (buffer_indices_map.find(load->buffer->name) == buffer_indices_map.end() && !check_) + buffer_indices_map.Set(load->buffer->name, indices); + StmtExprVisitor::VisitExpr_(load); + } + + void VisitStmt_(const BlockNode* op) final { StmtVisitor::VisitStmt_(op); } + + Var loop_var_; + Map> buffer_indices_map; +}; + +Array MutateBufferRegion(Map>& buffer_indices_map, + Map& index_range_map, + Array region_arr) { + // Update the region with new Ranges and return new BufferRegion + Array new_region_arr = + MutateArray(region_arr, [&buffer_indices_map, &index_range_map](const BufferRegion& region) { + BufferRegion new_region = region; + auto it = buffer_indices_map.find(new_region->buffer->name); + if (it == buffer_indices_map.end()) return new_region; + + Array old_indices = buffer_indices_map[new_region->buffer->name]; + Array new_ranges; + for (int i = 0; i < old_indices.size(); i++) { + new_ranges.push_back(index_range_map[old_indices[i]]); + } + new_region.CopyOnWrite()->region = std::move(new_ranges); + return new_region; + }); + return new_region_arr; +} + +class BlockMutator : public StmtExprMutator { + public: + explicit BlockMutator(Var new_loop_var, PrimExpr min, PrimExpr extent) + : new_loop_var_(new_loop_var), min_(min), extent_(extent) {} + + private: + Stmt VisitStmt_(const BlockNode* _op) final { + Block new_block = Downcast(StmtMutator::VisitStmt_(_op)); + + // If iter_vars.size() is 0, then the block most probably be an Opaque block + if (new_block->iter_vars.size() == 0 || inner_iter_var_index == -1) { + new_block.CopyOnWrite()->name_hint = + new_block.CopyOnWrite()->name_hint + "_" + new_loop_var_->name_hint; + return std::move(new_block); + } + + Var iter_var_ = new_block->iter_vars[inner_iter_var_index]->var; + inner_iter_var_index = -1; + // As we are working on cloned block, we need to create new instances of iter_var + Array new_iter_vars = + MutateArray(new_block->iter_vars, [this, &iter_var_](const IterVar& iter) { + auto dtype = iter->var.dtype(); + // Create new Var instance for each IterVar + Var new_var = Var(iter->var->name_hint, iter->var.dtype()); + IterVar new_iter = iter; + new_iter.CopyOnWrite()->var = new_var; + // Change the domain of IterVar corresponding to partitioned loop_var + if (iter_var_.same_as(iter->var)) { + new_iter.CopyOnWrite()->dom = Range(tvm::cast(dtype, min_), tvm::cast(dtype, extent_)); + } + return new_iter; + }); + + // Update the IterVars of new_block + if (!new_block->iter_vars.same_as(new_iter_vars)) { + new_block.CopyOnWrite()->iter_vars = std::move(new_iter_vars); + new_block.CopyOnWrite()->name_hint = + new_block.CopyOnWrite()->name_hint + "_" + new_loop_var_->name_hint; + } + + // Get the (iter_var, new Range) map + Map index_range_map; + for (int i = 0; i < new_block->iter_vars.size(); i++) { + IterVar iter = new_block->iter_vars[i]; + index_range_map.Set(iter->var->name_hint, iter->dom); + } + + // Get the (Buffer, indices) map + Map> buffer_indices_map = + BufferIndicesMapExtractor::Extract(new_loop_var_, new_block); + Array new_writes = + MutateBufferRegion(buffer_indices_map, index_range_map, new_block->writes); + if (!new_block->writes.same_as(new_writes)) { + // Update the writes with new_writes + new_block.CopyOnWrite()->writes = std::move(new_writes); + } + Array new_reads = + MutateBufferRegion(buffer_indices_map, index_range_map, new_block->reads); + if (!new_block->reads.same_as(new_reads)) { + // Update the reads with new_reads + new_block.CopyOnWrite()->reads = std::move(new_reads); + } + + Map var_map; + for (int i = 0; i < new_block->iter_vars.size(); i++) { + var_map.Set(_op->iter_vars[i]->var, new_block->iter_vars[i]->var); + } + + // Update all instances of old iter_vars in the block with new iter_vars + auto block_stmt = tir::Substitute(new_block, var_map); + return std::move(block_stmt); + } + + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + Array iter_values = realize->iter_values; + for (int i = 0; i < iter_values.size(); i++) { + if (new_loop_var_.same_as(iter_values[i])) { + // Get the iter_var index corresponding to loop_var iter_value index + inner_iter_var_index = i; + break; + } + } + BlockRealize stmt = Downcast(StmtExprMutator::VisitStmt_(realize)); + return std::move(stmt); + } + + Stmt VisitStmt_(const ForNode* op) final { + For res = Downcast(StmtMutator::VisitStmt_(op)); + Var new_var = Var(op->loop_var->name_hint, op->loop_var.dtype()); + + if (!op->loop_var.same_as(new_var)) { + // If the partioned loop contains nested for loop, then create new iteration variable instance + res.CopyOnWrite()->body = std::move(tir::Substitute(res->body, {{op->loop_var, new_var}})); + res.CopyOnWrite()->loop_var = new_var; + } + return res; + } + + Var new_loop_var_; + PrimExpr min_, extent_; + int inner_iter_var_index = -1; +}; + +const String get_block_name(Stmt loop_body) { + const BlockRealizeNode* blk_realize = loop_body.as(); + if (blk_realize == nullptr) { + return get_block_name(loop_body.as()->body); + } + return blk_realize->block->name_hint; +} + +Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref, + const Array& factors, bool preserve_unit_iters) { + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); + if (!loop->annotations.empty() || loop->thread_binding.defined()) { + throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); + } + + arith::Analyzer analyzer; + // Find the most common dtype + DataType dtype; + { + int bits = loop->loop_var.dtype().bits(); + for (const PrimExpr& factor : factors) { + bits = std::max(bits, factor.dtype().bits()); + } + dtype = DataType::Int(bits); + } + + String block_name = get_block_name(loop->body) + "_" + loop->loop_var->name_hint; + int n = factors.size(); + PrimExpr min_value = loop->min; + PrimExpr extent_value; + + Array block_partitions; + block_partitions.reserve(n); + + // Iterate over each pair of factors and create partition + for (int i = 0; i < n; i++) { + extent_value = analyzer.Simplify(factors[i]); + Var new_loop_var = loop->loop_var.copy_with_suffix(std::to_string(i)).copy_with_dtype(dtype); + Stmt loop_body = tir::Substitute(loop->body, {{loop->loop_var, new_loop_var}}); + + // Create new block with new reference to each variable/stmt/expr in the existing block + loop_body = BlockMutator(new_loop_var, min_value, extent_value)(std::move(loop_body)); + // Create new for loop with appropriate range + auto for_node = + For(new_loop_var, min_value, extent_value - min_value, ForKind::kSerial, loop_body); + + const auto& partition_block_name = block_name + std::to_string(i) + "_partition"; + // Create partition_block for the partitioned for loop + BlockRealize partition_block({}, extent_value > 0, + Block({}, {}, {}, partition_block_name, for_node)); + block_partitions.push_back(partition_block); + + min_value = extent_value; + } + + // Create common block with all the partitioned blocks as its children blocks + BlockRealize common({}, make_const(DataType::Bool(), 1), + Block({}, {}, {}, block_name + "_common", tir::SeqStmt(block_partitions))); + + // Replace existing loop with the newly created common block + self->Replace(loop_sref, common, {}); + StmtSRef scope_sref = self->stmt2ref.at(common->block.get()); + StmtSRef scope_root = tir::GetScopeRoot(self, scope_sref, /*require_stage_pipeline=*/false); + bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root); + // Update the SRefTree for the newly created common block + self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root)); + self->block_info[scope_root].affine_binding = scope_block_affine_binding; + + // Collect the SRef for each partitioned loop and return + Array partition_srefs; + partition_srefs.reserve(n); + for (int i = 0; i < n; i++) { + StmtSRef partition_loop_sref = + self->stmt2ref.at(block_partitions[i].as()->block->body.get()); + partition_srefs.push_back(partition_loop_sref); + } + return partition_srefs; +} + class LoopReconstructor : private StmtMutator { public: explicit LoopReconstructor(Block scope_root, const std::vector>& loops) @@ -954,6 +1206,46 @@ struct SplitTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct LoopPartitionTraits : public UnpackedInstTraits { + static constexpr const char* kName = "LoopPartition"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + template + static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter, + const Array& inputs) { + thread_local ObjectRef loop_rv{nullptr}; + thread_local Array factors{nullptr}; + loop_rv = inputs[0]; + factors = Array{inputs.begin() + 1, inputs.end()}; + setter(delta, loop_rv); + setter(delta + 1, factors); + } + + static Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, + Array> factors, + Bool preserve_unit_iters) { + return sch->LoopPartition(loop_rv, factors, preserve_unit_iters.operator bool()); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, Array factors, + Bool preserve_unit_iters) { + PythonAPICall py("loop_partition"); + py.Input("loop", loop_rv); + py.Input("factors", factors); + py.Input("preserve_unit_iters", preserve_unit_iters.operator bool()); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + struct MergeTraits : public UnpackedInstTraits { static constexpr const char* kName = "Merge"; static constexpr bool kIsPure = false; @@ -1084,6 +1376,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits { }; TVM_REGISTER_INST_KIND_TRAITS(SplitTraits); +TVM_REGISTER_INST_KIND_TRAITS(LoopPartitionTraits); TVM_REGISTER_INST_KIND_TRAITS(MergeTraits); TVM_REGISTER_INST_KIND_TRAITS(FuseTraits); TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits); diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index cdca53fa8a71..44f9b8f42c68 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -162,6 +162,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetOutputBlocks") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method(&ScheduleNode::Merge); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleLoopPartition") + .set_body_method(&ScheduleNode::LoopPartition); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder") .set_body_method(&ScheduleNode::Reorder); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorderBlockIterVar") diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index e55a5cf8078c..3b66112ac9ce 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -244,6 +244,27 @@ Array TracedScheduleNode::Split(const LoopRV& loop_rv, return results; } +Array TracedScheduleNode::LoopPartition(const LoopRV& loop_rv, + const Array>& factor_rvs, + bool preserve_unit_iters) { + Array results = + ConcreteScheduleNode::LoopPartition(loop_rv, factor_rvs, preserve_unit_iters); + + std::vector inputs; + inputs.reserve(1 + factor_rvs.size()); + inputs.push_back(loop_rv); + for (const ObjectRef& obj : factor_rvs) { + inputs.push_back(obj); + } + + static const InstructionKind& kind = InstructionKind::Get("LoopPartition"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/inputs, + /*attrs=*/{Integer(preserve_unit_iters)}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + void TracedScheduleNode::Reorder(const Array& ordered_loop_rvs) { ConcreteScheduleNode::Reorder(ordered_loop_rvs); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index d7a42f63d4dc..1586c15a439c 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -68,6 +68,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { LoopRV Merge(const Array& loop_rvs) final; Array Split(const LoopRV& loop_rv, const Array>& factor_rvs, bool preserve_unit_iters) final; + Array LoopPartition(const LoopRV& loop_rv, const Array>& factor_rvs, + bool preserve_unit_iters) final; void Reorder(const Array& ordered_loop_rvs) final; void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) final; LoopRV AddUnitLoop(const BlockRV& block_rv) final; diff --git a/tests/python/tir-schedule/test_tir_schedule_partition.py b/tests/python/tir-schedule/test_tir_schedule_partition.py new file mode 100644 index 000000000000..08595843e3aa --- /dev/null +++ b/tests/python/tir-schedule/test_tir_schedule_partition.py @@ -0,0 +1,460 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +import tvm.testing +from tvm import te, tir +from tvm.script import tir as T +from tvm.tir.expr import IntImm +from tvm.tir.schedule.testing import ( + assert_structural_equal_ignore_global_symbol, + verify_trace_roundtrip, +) + +# pylint: disable=no-member,invalid-name,unused-variable + + +@T.prim_func +def elementwise(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j, k in T.grid(128, 128, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@T.prim_func +def elementwise_symbolic(a: T.handle, b: T.handle, n: T.int32) -> None: + A = T.match_buffer(a, (128, 128, n)) + B = T.match_buffer(b, (128, 128, n)) + for i, j, k in T.grid(128, 128, n): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@T.prim_func +def elementwise_with_anno(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.serial(0, 128, annotations={"useless_annotation": True}): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@T.prim_func +def elementwise_with_thread_binding(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@T.prim_func +def elementwise_with_opaque_block(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j, k in T.grid(128, 128, 128): + with T.block("opaque"): + T.reads([A[i, j, k]]) + T.writes([B[i, j, k]]) + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@T.prim_func +def elementwise_partition_with_opaque_block(a: T.handle, b: T.handle) -> None: + B = T.match_buffer(b, [128, 128, 128]) + A = T.match_buffer(a, [128, 128, 128]) + with T.block("root"): + T.reads() + T.writes() + with T.block("opaque_i_common"): + T.reads() + T.writes() + with T.block("opaque_i0_partition"): + T.reads() + T.writes() + for i0, j, k in T.grid(112, 128, 128): + with T.block("opaque_i0"): + T.reads(A[i0, j, k]) + T.writes(B[i0, j, k]) + with T.block("B_i0"): + vi, vj, vk = T.axis.remap("SSS", [i0, j, k]) + T.reads(A[0:112, 0:128, 0:128]) + T.writes(B[0:112, 0:128, 0:128]) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) + with T.block("opaque_i1_partition"): + T.reads() + T.writes() + for i1 in range(112, 128): + for j, k in T.grid(128, 128): + with T.block("opaque_i1"): + T.reads(A[i1, j, k]) + T.writes(B[i1, j, k]) + with T.block("B_i1"): + vi, vj, vk = T.axis.remap("SSS", [i1, j, k]) + T.reads(A[112:128, 0:128, 0:128]) + T.writes(B[112:128, 0:128, 0:128]) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) + + +@T.prim_func +def elementwise_loop_partition_case0(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128]) + B = T.match_buffer(b, [128, 128, 128]) + with T.block("root"): + T.reads() + T.writes() + with T.block("B_i_common"): + T.reads() + T.writes() + with T.block("B_i0_partition"): + T.reads() + T.writes() + for i0 in range(2): + with T.block("B_i0_j_common"): + T.reads() + T.writes() + with T.block("B_i0_j0_partition"): + T.reads() + T.writes() + for j0, k in T.grid(4, 128): + with T.block("B_i0_j0"): + vi, vj, vk = T.axis.remap("SSS", [i0, j0, k]) + T.reads(A[0:2, 0:4, 0:128]) + T.writes(B[0:2, 0:4, 0:128]) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) + with T.block("B_i0_j1_partition"): + T.reads() + T.writes() + for j1 in range(4, 36): + for k in range(128): + with T.block("B_i0_j1"): + vi, vj, vk = T.axis.remap("SSS", [i0, j1, k]) + T.reads(A[0:2, 4:36, 0:128]) + T.writes(B[0:2, 4:36, 0:128]) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) + with T.block("B_i0_j2_partition"): + T.reads() + T.writes() + for j2 in range(36, 128): + for k in range(128): + with T.block("B_i0_j2"): + vi, vj, vk = T.axis.remap("SSS", [i0, j2, k]) + T.reads(A[0:2, 36:128, 0:128]) + T.writes(B[0:2, 36:128, 0:128]) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) + with T.block("B_i1_partition"): + T.reads() + T.writes() + for i1 in range(2, 3): + for j, k in T.grid(128, 128): + with T.block("B_i1"): + vi, vj, vk = T.axis.remap("SSS", [i1, j, k]) + T.reads(A[2, 0:128, 0:128]) + T.writes(B[2, 0:128, 0:128]) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) + with T.block("B_i2_partition"): + T.reads() + T.writes() + for i2 in range(3, 67): + for j, k in T.grid(128, 128): + with T.block("B_i2"): + vi, vj, vk = T.axis.remap("SSS", [i2, j, k]) + T.reads(A[3:67, 0:128, 0:128]) + T.writes(B[3:67, 0:128, 0:128]) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) + with T.block("B_i3_partition"): + T.reads() + T.writes() + for i3 in range(67, 128): + for j, k in T.grid(128, 128): + with T.block("B_i3"): + vi, vj, vk = T.axis.remap("SSS", [i3, j, k]) + T.reads(A[67:128, 0:128, 0:128]) + T.writes(B[67:128, 0:128, 0:128]) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) + + +@T.prim_func +def elementwise_loop_partition_case1(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128]) + B = T.match_buffer(b, [128, 128, 128]) + with T.block("root"): + T.reads() + T.writes() + with T.block("B_i_common"): + T.reads() + T.writes() + with T.block("B_i0_partition"): + T.reads() + T.writes() + for i0, j, k in T.grid(63, 128, 128): + with T.block("B_i0"): + vi, vj, vk = T.axis.remap("SSS", [i0, j, k]) + T.reads(A[0:63, 0:128, 0:128]) + T.writes(B[0:63, 0:128, 0:128]) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) + with T.block("B_i1_partition"): + T.reads() + T.writes() + for i1 in range(63, 64): + for j in range(128): + with T.block("B_i1_k_common"): + T.reads() + T.writes() + with T.block("B_i1_k0_partition"): + T.reads() + T.writes() + for k0 in range(1): + with T.block("B_i1_k0"): + vi, vj, vk = T.axis.remap("SSS", [i1, j, k0]) + T.reads(A[63, 0:128, 0]) + T.writes(B[63, 0:128, 0]) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) + with T.block("B_i1_k1_partition"): + T.reads() + T.writes() + for k1 in range(1, 65): + with T.block("B_i1_k1"): + vi, vj, vk = T.axis.remap("SSS", [i1, j, k1]) + T.reads(A[63, 0:128, 1:65]) + T.writes(B[63, 0:128, 1:65]) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) + with T.block("B_i1_k2_partition"): + T.reads() + T.writes() + for k2 in range(65, 128): + with T.block("B_i1_k2"): + vi, vj, vk = T.axis.remap("SSS", [i1, j, k2]) + T.reads(A[63, 0:128, 65:128]) + T.writes(B[63, 0:128, 65:128]) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) + with T.block("B_i2_partition"): + T.reads() + T.writes() + for i2 in range(64, 128): + for j, k in T.grid(128, 128): + with T.block("B_i2"): + vi, vj, vk = T.axis.remap("SSS", [i2, j, k]) + T.reads(A[64:128, 0:128, 0:128]) + T.writes(B[64:128, 0:128, 0:128]) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2) + + +@T.prim_func +def opaque_access(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [16, 16], "float32") + B = T.match_buffer(b, [16, 16], "float32") + for i, j in T.grid(16, 16): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([A[0:16, 0:16]]) + A[vi, vj] = 1 + for i, j in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([]) + T.writes([B[0:16, 0:16]]) + T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + + +@T.prim_func +def opaque_access_loop_partition(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16)) + B = T.match_buffer(b, (16, 16)) + for i in range(16): + with T.block("A_j_common"): + T.reads() + T.writes() + with T.block("A_j0_partition"): + T.reads() + T.writes() + for j0 in range(12): + with T.block("A_j0"): + vi, vj = T.axis.remap("SS", [i, j0]) + T.reads() + T.writes(A[0:16, 0:12]) + A[vi, vj] = T.float32(1) + with T.block("A_j1_partition"): + T.reads() + T.writes() + for j1 in range(12, 16): + with T.block("A_j1"): + vi, vj = T.axis.remap("SS", [i, j1]) + T.reads() + T.writes(A[0:16, 12:16]) + A[vi, vj] = T.float32(1) + for i in range(16): + with T.block("B_j_common"): + T.reads() + T.writes() + with T.block("B_j0_partition"): + T.reads() + T.writes() + for j0 in range(12): + with T.block("B_j0"): + vi, vj = T.axis.remap("SS", [i, j0]) + T.reads() + T.writes(B[0:16, 0:16]) + T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj) + with T.block("B_j1_partition"): + T.reads() + T.writes() + for j1 in range(12, 16): + with T.block("B_j1"): + vi, vj = T.axis.remap("SS", [i, j1]) + T.reads() + T.writes(B[0:16, 0:16]) + T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj) + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_loop_partition(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.loop_partition(i, factors=[2, 1, 64]) + + block_b_partition = sch.get_block("B_i0") + i, j, k = sch.get_loops(block_b_partition) + loops = sch.loop_partition(j, factors=[4, 32]) + + assert_structural_equal_ignore_global_symbol(elementwise_loop_partition_case0, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_partition_with_inferred_factor(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.loop_partition(i, factors=[None, 1, 64]) + + block_b_partition = sch.get_block("B_i1") + i, j, k = sch.get_loops(block_b_partition) + sch.loop_partition(k, factors=[1, 64, None]) + + assert_structural_equal_ignore_global_symbol(elementwise_loop_partition_case1, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_partition_with_opaque_block(): + sch = tir.Schedule(elementwise_with_opaque_block, debug_mask="all") + block_opaque = sch.get_block("opaque") + i, _, _ = sch.get_loops(block_opaque) + sch.loop_partition(i, factors=[None, 16]) + assert_structural_equal_ignore_global_symbol( + elementwise_partition_with_opaque_block, sch.mod["main"] + ) + verify_trace_roundtrip(sch=sch, mod=elementwise_with_opaque_block) + + +def test_partition_with_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mask="all") + block_a = sch.get_block("A") + _, j = sch.get_loops(block_a) + sch.loop_partition(j, factors=[None, 4]) + block_b = sch.get_block("B") + _, j = sch.get_loops(block_b) + sch.loop_partition(j, factors=[None, 4]) + assert_structural_equal_ignore_global_symbol(opaque_access_loop_partition, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=opaque_access) + + +def test_partition_int64_extent_with_mixed_factors(): + def _create_prim_func(): + m = te.const(384, "int64") + A = te.placeholder((m,), name="A", dtype="float32") + B = te.compute((m,), lambda i: A[i] + 1, name="B") + return te.create_prim_func([A, B]) + + mod = _create_prim_func() + sch = tir.Schedule(mod, debug_mask="all") + (i,) = sch.get_loops(sch.get_block("B")) + sch.loop_partition( + i, + factors=[ + te.const(1, "int64"), + te.const(51, "int32"), + ], + ) + + +def test_partition_fail_symbolic(): + sch = tir.Schedule(elementwise_symbolic, debug_mask="all") + block_b = sch.get_block("B") + _, _, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.loop_partition(k, factors=[10, None]) + + +def test_partition_fail_out_of_bound(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.loop_partition(i, factors=[1000, 2, 3]) + + +def test_partition_with_non_positive_factors(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.loop_partition(i, factors=[-2, -64]) + with pytest.raises(tvm.tir.ScheduleError): + sch.loop_partition(j, factors=[0, None]) + with pytest.raises(tvm.tir.ScheduleError): + sch.loop_partition(k, factors=[None, -16]) + + +def test_partition_fail_with_annotation(): + sch = tir.Schedule(elementwise_with_anno, debug_mask="all") + block_b = sch.get_block("B") + _, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.loop_partition(k, factors=[None, 10]) + + +def test_partition_fail_with_thread_binding(): + sch = tir.Schedule(elementwise_with_thread_binding, debug_mask="all") + block_b = sch.get_block("B") + _, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.loop_partition(k, factors=[None, 10]) + + +if __name__ == "__main__": + tvm.testing.main()