Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ class TVM_DLL ModuleNode : public Object {
* \param format The format of the file.
*/
virtual void SaveToFile(const std::string& file_name, const std::string& format);

/*!
* \brief Save the module with separate funciton files.
* \param file_name The prefix of the files to be saved to.
* \param format The format of the file.
*/
virtual std::string SaveToFileSeparateFuncs(const std::string& prefix, const std::string& format);

/*!
* \brief Save the module to binary stream.
* \param stream The binary stream to save to.
Expand Down
63 changes: 62 additions & 1 deletion python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module

from . import _ffi_api

import logging

class BenchmarkResult:
"""Runtimes from benchmarking"""
Expand Down Expand Up @@ -251,6 +251,10 @@ def is_dso_exportable(self):
"""
return _ffi_api.ModuleIsDSOExportable(self)

def save_separate_funcs(self, file_name, fmt=""):
return _ffi_api.ModuleSaveToFileSeparateFuncs(self, file_name, fmt)


def save(self, file_name, fmt=""):
"""Save the module to file.

Expand Down Expand Up @@ -394,6 +398,63 @@ def _collect_from_import_tree(self, filter_func):
def _collect_dso_modules(self):
return self._collect_from_import_tree(lambda m: m.is_dso_exportable)

# Most of the following code is copied from `export_library`
def export_library_separate_funcs(self, output_dir, prefix, fcompile=None, addons=None, workspace_dir=None, **kwargs):
# NOTE: this function depends on contrib library features
# which are only available in when TVM function is available.
if _RUNTIME_ONLY:
raise RuntimeError("Cannot call export_library in runtime only mode")
# Extra dependencies during runtime.
from pathlib import Path
from tvm.contrib import cc as _cc, tar as _tar, utils as _utils

if self.type_key == "stackvm":
raise NotImplementedError()

modules = self._collect_dso_modules()
if workspace_dir is None:
temp = _utils.tempdir()
workspace_dir = temp.temp_dir
is_system_lib = False
llvm_target_string = None
assert len(modules) == 1
for index, module in enumerate(modules):

object_format = "o"
path_prefix = os.path.join(workspace_dir, f"lib{index}_{prefix}")
#build_files_str=""
build_files_str=module.save_separate_funcs(path_prefix, fmt=object_format)
build_files=build_files_str.split(",")
logging.info(f'Build files: {build_files}')


if module.type_key == "llvm":
is_system_lib = module.get_function("__tvm_is_system_module")()
llvm_target_string = module.get_function("_get_target_string")()
if not fcompile:
fcompile = _cc.create_shared

if llvm_target_string is None and hasattr(fcompile, "get_target_triple"):
triple = fcompile.get_target_triple()
assert triple, "Target triple should not be empty"
llvm_target_string = "llvm -mtriple " + triple

if getattr(fcompile, "need_system_lib", False) and not is_system_lib:
raise ValueError("%s need --system-lib option" % str(fcompile))

if self.imported_modules:
raise NotImplementedError()

for file in build_files:
file=file.strip()
if file=='':
continue
compiled_file_name=file.split("/")[-1][:-2]+".so"
compiled_file_name=os.path.join(output_dir, compiled_file_name)
logging.info(f'compiled_file_name: {compiled_file_name}')
fcompile(compiled_file_name, [file], **kwargs)


def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs):
"""
Export the module and all imported modules into a single device library.
Expand Down
9 changes: 9 additions & 0 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class LibraryModuleNode final : public ModuleNode {
const char* type_key() const final { return "library"; }

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
LOG(INFO)<<"Library module load func: "<<name;
TVMBackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name =
Expand Down Expand Up @@ -198,6 +199,7 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib,
}

Module CreateModuleFromLibrary(ObjectPtr<Library> lib, PackedFuncWrapper packed_func_wrapper) {
LOG(INFO)<<"Create Module from lib: ";
InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); });
auto n = make_object<LibraryModuleNode>(lib, packed_func_wrapper);
// Load the imported modules
Expand All @@ -210,6 +212,7 @@ Module CreateModuleFromLibrary(ObjectPtr<Library> lib, PackedFuncWrapper packed_
ProcessModuleBlob(dev_mblob, lib, packed_func_wrapper, &root_mod, &dso_ctx_addr);
} else {
// Only have one single DSO Module
LOG(INFO)<<"Only have one dso";
root_mod = Module(n);
dso_ctx_addr = root_mod.operator->();
}
Expand All @@ -223,8 +226,14 @@ Module CreateModuleFromLibrary(ObjectPtr<Library> lib, PackedFuncWrapper packed_
}

TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) {
LOG(INFO)<<"loadfile_so";
ObjectPtr<Library> n = CreateDSOLibraryObject(args[0]);
*rv = CreateModuleFromLibrary(n);
});

TVM_REGISTER_GLOBAL("runtime.module.loadfile_func_so").set_body([](TVMArgs args, TVMRetValue* rv) {
LOG(INFO)<<"loadfile_func_so";
});

} // namespace runtime
} // namespace tvm
30 changes: 30 additions & 0 deletions src/runtime/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ Module Module::LoadFromFile(const std::string& file_name, const std::string& for
fmt = "so";
}
std::string load_f_name = "runtime.module.loadfile_" + fmt;
LOG(INFO)<<"load from file"<<load_f_name;
VLOG(1) << "Loading module from '" << file_name << "' of format '" << fmt << "'";
const PackedFunc* f = Registry::Get(load_f_name);
ICHECK(f != nullptr) << "Loader for `." << format << "` files is not registered,"
Expand All @@ -93,10 +94,35 @@ Module Module::LoadFromFile(const std::string& file_name, const std::string& for
return m;
}

Module Module::LoadFromSeparateFuncFiles(const std::string& dest_dir,
const std::string& prefix,
const std::string& format = "") {
std::string fmt = format;
ICHECK(fmt.length() != 0) << "Cannot deduce format of file " << fmt;
if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
fmt = "so";
}
std::string load_f_name = "runtime.module.loadfile_func_" + fmt;
LOG(INFO)<<"load from file"<<load_f_name;
const PackedFunc* f = Registry::Get(load_f_name);
ICHECK(f != nullptr) << "Loader for `." << format << "` files is not registered,"
<< " resolved to (" << load_f_name << ") in the global registry."
<< "Ensure that you have loaded the correct runtime code, and"
<< "that you are on the correct hardware architecture.";
// Module m = (*f)(file_name, format);
// return m;
}



void ModuleNode::SaveToFile(const std::string& file_name, const std::string& format) {
LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
}

std::string ModuleNode::SaveToFileSeparateFuncs(const std::string& prefix, const std::string& format) {
LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFileSeparateFuncs";
}

void ModuleNode::SaveToBinary(dmlc::Stream* stream) {
LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToBinary";
}
Expand Down Expand Up @@ -199,6 +225,10 @@ TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFro
TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
.set_body_typed([](Module mod, String name, tvm::String fmt) { mod->SaveToFile(name, fmt); });

TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFileSeparateFuncs")
.set_body_typed([](Module mod, String name, tvm::String fmt) { return mod->SaveToFileSeparateFuncs(name, fmt); });


TVM_REGISTER_GLOBAL("runtime.ModuleIsDSOExportable").set_body_typed([](Module mod) {
return mod->IsDSOExportable();
});
Expand Down
136 changes: 135 additions & 1 deletion src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;

std::string SaveToFileSeparateFuncs(const std::string& prefix, const std::string& format) final;
void SaveToFile(const std::string& file_name, const std::string& format) final;
void SaveToBinary(dmlc::Stream* stream) final;
std::string GetSource(const std::string& format) final;
Expand All @@ -109,15 +110,20 @@ class LLVMModuleNode final : public runtime::ModuleNode {
bool IsCompatibleWithHost(const llvm::TargetMachine* tm) const;
void* GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const;
void* GetFunctionAddr(const std::string& name, const LLVMTarget& llvm_target) const;
void AddFuncModule(const IRModule& mod, const Target& target, PrimFunc func);

// The LLVM scope object.
std::unique_ptr<LLVMInstance> llvm_instance_;
// llvm instance for each function
std::vector<LLVMInstance> func_llvm_instances_;
// JIT lock
std::mutex mutex_;
// execution engine
llvm::ExecutionEngine* ee_{nullptr};
// The raw pointer to the module.
llvm::Module* module_{nullptr};
// llvm module for each function
std::vector<llvm::Module*> func_modules_;
// The unique_ptr owning the module. This becomes empty once JIT has been initialized
// (EngineBuilder takes ownership of the module).
std::unique_ptr<llvm::Module> module_owning_ptr_;
Expand Down Expand Up @@ -168,6 +174,86 @@ PackedFunc LLVMModuleNode::GetFunction(const std::string& name,
return WrapPackedFunc(faddr, sptr_to_self);
}

std::string LLVMModuleNode::SaveToFileSeparateFuncs(const std::string& file_prefix, const std::string& format) {
std::string files="";
for(int i=0;i<func_modules_.size();i++){
LLVMInstance func_llvm_instance = func_llvm_instances_[i];
llvm::Module* func_llvm_module = func_modules_[i];
String func_name=function_names_[i];

std::string fmt = format;

std::string func_file_name = file_prefix + "_" + func_name + "." + fmt;
LOG(INFO)<<"Save file: "<<func_file_name;
files=files + "," + func_file_name;

std::error_code ecode;
#if TVM_LLVM_VERSION <= 70
llvm::raw_fd_ostream dest(func_file_name, ecode, llvm::sys::fs::F_None);
#else
llvm::raw_fd_ostream dest(func_file_name, ecode, llvm::sys::fs::OF_None);
#endif
ICHECK_EQ(ecode.value(), 0) << "Cannot open file: " << func_file_name << " " << ecode.message();
if (fmt == "o" || fmt == "obj") {
With<LLVMTarget> llvm_target(func_llvm_instance, LLVMTarget::GetTargetMetadata(*func_llvm_module));
#if TVM_LLVM_VERSION <= 60
std::unique_ptr<llvm::Module> m = llvm::CloneModule(func_llvm_module);
#else
std::unique_ptr<llvm::Module> m = llvm::CloneModule(*func_llvm_module);
#endif
llvm::legacy::PassManager pass;
llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
#if TVM_LLVM_VERSION <= 60
ICHECK(tm->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0)
<< "Cannot emit target CGFT_ObjectFile";
#elif TVM_LLVM_VERSION <= 90
ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0)
<< "Cannot emit target CGFT_ObjectFile";
#else
ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0)
<< "Cannot emit target CGFT_ObjectFile";
#endif
pass.run(*m);
} else if (fmt == "s" || fmt == "asm") {
With<LLVMTarget> llvm_target(func_llvm_instance, LLVMTarget::GetTargetMetadata(*func_llvm_module));
#if TVM_LLVM_VERSION <= 60
std::unique_ptr<llvm::Module> m = llvm::CloneModule(func_llvm_module);
#else
std::unique_ptr<llvm::Module> m = llvm::CloneModule(*func_llvm_module);
#endif
llvm::legacy::PassManager pass;
llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
#if TVM_LLVM_VERSION <= 60
ICHECK(tm->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";
#elif TVM_LLVM_VERSION <= 90
ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) ==
0)
<< "Cannot emit target CGFT_AssemblyFile";
#else
ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_AssemblyFile";
#endif
pass.run(*m);
} else if (fmt == "ll") {
func_llvm_module->print(dest, nullptr);
} else if (fmt == "bc") {
#if TVM_LLVM_VERSION <= 60
llvm::WriteBitcodeToFile(func_llvm_module, dest);
#else
llvm::WriteBitcodeToFile(*func_llvm_module, dest);
#endif
} else {
LOG(FATAL) << "Do not know how to save file " << func_file_name << " with format=\'" << format
<< "\'";
}
dest.close();
}
return files;

}


void LLVMModuleNode::SaveToFile(const std::string& file_name, const std::string& format) {
std::string fmt = runtime::GetFileFormat(file_name, format);
std::error_code ecode;
Expand Down Expand Up @@ -276,6 +362,54 @@ std::string LLVMModuleNode::GetSource(const std::string& format) {
return "";
}

void LLVMModuleNode::AddFuncModule(const IRModule& mod, const Target& target, PrimFunc func) {
LLVMInstance func_llvm_instance_ = LLVMInstance();
With<LLVMTarget> llvm_target(func_llvm_instance_, target);
llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(llvm_target.get());

std::vector<PrimFunc> funcs;
std::string entry_func;
relay::Runtime runtime =
mod->GetAttr<relay::Runtime>(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp"));
bool system_lib = runtime->GetAttr<Bool>("system-lib").value_or(Bool(false));
bool target_c_runtime = runtime->name == "crt";

funcs.push_back(func);

cg->Init("TVMFuncMod", llvm_target.get(), system_lib, system_lib, target_c_runtime);
cg->SetFastMathFlags(llvm_target->GetFastMathFlags());

cg->AddFunctionsOrdered(funcs.begin(), funcs.end());

// (jzh18) assume no entry function
// if (entry_func.length() != 0) {
// cg->AddMainFunction(entry_func);
// }

// // (@jzh18) what's the usage of module_owning_ptr_
std::unique_ptr<llvm::Module> func_module_owning_ptr_ = cg->Finish();
llvm::Module* func_module_ = func_module_owning_ptr_.get();
func_module_owning_ptr_.release();
llvm_target->SetTargetMetadata(func_module_);
func_module_->addModuleFlag(llvm::Module::Override, "Debug Info Version",
llvm::DEBUG_METADATA_VERSION);

if (tm->getTargetTriple().isOSDarwin()) {
func_module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
}

func_llvm_instances_.push_back(func_llvm_instance_);
func_modules_.push_back(func_module_);
std::string verify_errors_storage;
llvm::raw_string_ostream verify_errors(verify_errors_storage);
LOG_IF(FATAL, llvm::verifyModule(*func_module_, &verify_errors))
<< "LLVM module verification failed with the following errors: \n"
<< verify_errors.str();

}


void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
llvm_instance_ = std::make_unique<LLVMInstance>();
With<LLVMTarget> llvm_target(*llvm_instance_, target);
Expand Down Expand Up @@ -303,6 +437,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
entry_func = global_symbol.value();
}
funcs.push_back(f);
AddFuncModule(mod, target, f);
}
// TODO(@jroesch): follow up on this condition.
// ICHECK(funcs.size() > 0);
Expand All @@ -325,7 +460,6 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
if (tm->getTargetTriple().isOSDarwin()) {
module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
}

std::string verify_errors_storage;
llvm::raw_string_ostream verify_errors(verify_errors_storage);
LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors))
Expand Down