Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
# under the License.
"""Common base structures."""

import tvm_ffi
from tvm_ffi import get_global_func, register_object

import tvm.error
from tvm.runtime import Object, _ffi_node_api

from . import _ffi_api, json_compact
Expand Down Expand Up @@ -205,9 +205,7 @@ def structural_equal(lhs, rhs, map_free_vars=False):
structural_hash
assert_strucural_equal
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
return bool(_ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars)) # type: ignore # pylint: disable=no-member
return tvm_ffi.structural_equal(lhs, rhs, map_free_vars)


def get_first_structural_mismatch(lhs, rhs, map_free_vars=False, skip_tensor_content=False):
Expand All @@ -234,9 +232,7 @@ def get_first_structural_mismatch(lhs, rhs, map_free_vars=False, skip_tensor_con
`None` if `lhs` and `rhs` are structurally equal.
Otherwise, a tuple of two AccessPath objects that point to the first detected mismtach.
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
return _ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars, skip_tensor_content) # type: ignore # pylint: disable=no-member
return tvm_ffi.get_first_structural_mismatch(lhs, rhs, map_free_vars, skip_tensor_content)


def assert_structural_equal(lhs, rhs, map_free_vars=False):
Expand All @@ -262,9 +258,22 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False):
--------
structural_equal
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
_ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars) # type: ignore # pylint: disable=no-member
first_mismatch = tvm_ffi.get_first_structural_mismatch(lhs, rhs, map_free_vars)
if first_mismatch is not None:
from tvm.runtime.script_printer import ( # pylint: disable=import-outside-toplevel
PrinterConfig,
_script,
)

lhs_path, rhs_path = first_mismatch
lhs_script = _script(lhs, PrinterConfig(syntax_sugar=False, path_to_underline=[lhs_path]))
rhs_script = _script(rhs, PrinterConfig(syntax_sugar=False, path_to_underline=[rhs_path]))
raise ValueError(
f"StructuralEqual check failed, caused by lhs at {lhs_path}:\n"
f"{lhs_script}\n"
f"and rhs at {rhs_path}:\n"
f"{rhs_script}"
)


def structural_hash(node, map_free_vars=False):
Expand Down Expand Up @@ -306,7 +315,7 @@ def structural_hash(node, map_free_vars=False):
--------
structrual_equal
"""
return _ffi_node_api.StructuralHash(node, map_free_vars) # type: ignore # pylint: disable=no-member
return tvm_ffi.structural_hash(node, map_free_vars)


def deprecated(
Expand Down
15 changes: 15 additions & 0 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@
*/
#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/variant.h>
#include <tvm/ffi/extra/base64.h>
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/rvalue_ref.h>
#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/module.h>
#include <tvm/ir/type_functor.h>
#include <tvm/target/codegen.h>

#include <algorithm>
#include <fstream>
Expand Down Expand Up @@ -230,6 +233,18 @@ IRModule IRModule::FromExpr(const RelaxExpr& expr,

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::TypeAttrDef<ffi::ModuleObj>()
.def("__data_to_json__",
[](const ffi::ModuleObj* node) {
std::string bytes = codegen::SerializeModuleToBytes(ffi::GetRef<ffi::Module>(node),
/*export_dso*/ false);
return ffi::Base64Encode(ffi::Bytes(bytes));
})
.def("__data_from_json__", [](const ffi::String& base64_bytes) {
ffi::Bytes bytes = ffi::Base64Decode(base64_bytes);
ffi::Module rtmod = codegen::DeserializeModuleFromBytes(bytes.operator std::string());
return rtmod;
});
refl::GlobalDef()
.def("ir.IRModule",
[](tvm::ffi::Map<GlobalVar, BaseFunc> funcs, tvm::ffi::ObjectRef attrs,
Expand Down
83 changes: 0 additions & 83 deletions src/ir/structural_equal.cc

This file was deleted.

89 changes: 0 additions & 89 deletions src/ir/structural_hash.cc

This file was deleted.

21 changes: 21 additions & 0 deletions src/runtime/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
* \brief Tensor container infratructure.
*/
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/base64.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/base.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/tensor.h>

#include "../support/base64.h"
#include "../support/bytes_io.h"
#include "tvm/runtime/data_type.h"

namespace tvm {
Expand Down Expand Up @@ -241,6 +244,24 @@ using namespace tvm::runtime;

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::TypeAttrDef<tvm::ffi::TensorObj>()
.def("__data_to_json__",
[](const tvm::ffi::TensorObj* node) {
std::string result;
tvm::support::BytesOutStream mstrm(&result);
tvm::support::Base64OutStream b64strm(&mstrm);
tvm::runtime::SaveDLTensor(&b64strm, node);
b64strm.Finish();
return tvm::ffi::String(std::move(result));
})
.def("__data_from_json__", [](const std::string& blob) {
tvm::support::BytesInStream mstrm(blob);
tvm::support::Base64InStream b64strm(&mstrm);
b64strm.InitPosition();
tvm::runtime::Tensor temp;
TVM_FFI_ICHECK(temp.Load(&b64strm));
return temp;
});
refl::GlobalDef()
.def("runtime.TVMTensorAllocWithScope", Tensor::Empty)
.def_method("runtime.TVMTensorCreateView", &Tensor::CreateView)
Expand Down
Loading