Skip to content

Commit

Permalink
feat(//core): New API to register arbitrary TRT engines in TorchScript
Browse files Browse the repository at this point in the history
Modules

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Apr 21, 2021
1 parent bbf997e commit 3ec836e
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 0 deletions.
14 changes: 14 additions & 0 deletions core/compiler.cpp
Expand Up @@ -173,6 +173,20 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
return new_mod;
}

/// Wrap a pre-serialized TensorRT engine in a brand new TorchScript module.
///
/// @param engine: serialized TensorRT engine (raw bytes in a std::string).
///                Taken by non-const reference to match the existing public
///                signature; the engine contents are not modified here.
/// @return a torch::jit::script::Module whose "forward" method executes the
///         embedded engine (graph built by AddEngineToGraph, defined
///         elsewhere in this file).
torch::jit::script::Module EmbedEngineInNewModule(std::string& engine) {
  // Name the module after a hash of the engine *contents*, not the address
  // of the incoming string. The previous reinterpret_cast<int*>(&engine)
  // streamed a pointer value, which produced a different module name on
  // every call and leaked an address into the name. A content hash is
  // stable across runs for the same engine.
  const auto engine_id = std::to_string(std::hash<std::string>{}(engine));
  torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id);
  auto new_g = std::make_shared<torch::jit::Graph>();
  // Embed the engine into the fresh graph (project helper).
  AddEngineToGraph(new_mod, new_g, engine);
  // Register the graph as the module's "forward" method with a matching schema.
  auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
  auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
  new_mod.type()->addMethod(new_method);
  new_method->setSchema(schema);

  return new_mod;
}

/// Select the active CUDA device for the calling thread.
///
/// @param gpu_id: CUDA device ordinal to make current.
/// Fails via TRTORCH_ASSERT when cudaSetDevice rejects the id (e.g. the
/// ordinal is out of range) — presumably the macro throws or aborts; see
/// its definition for the exact failure behavior.
void set_device(const int gpu_id) {
TRTORCH_ASSERT(cudaSetDevice(gpu_id) == cudaSuccess, "Unable to set CUDA device: " << gpu_id);
}
Expand Down
2 changes: 2 additions & 0 deletions core/compiler.h
Expand Up @@ -19,6 +19,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::

torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);

torch::jit::script::Module EmbedEngineInNewModule(std::string& engine);

void set_device(const int gpu_id);

} // namespace core
Expand Down
14 changes: 14 additions & 0 deletions cpp/api/include/trtorch/trtorch.h
Expand Up @@ -480,6 +480,20 @@ TRTORCH_API std::string ConvertGraphToTRTEngine(
const torch::jit::Module& module,
std::string method_name,
CompileSpec info);

/**
 * @brief Take a previously created TensorRT engine and embed it in
 * a TorchScript module
 *
 * @param engine: std::string - Precompiled serialized TensorRT engine
 *
 * Takes a prebuilt serialized TensorRT engine and embeds it in a TorchScript
 * graph. Registers the engine as the forward method of the module
 *
 * @return: A new module targeting a TensorRT engine
 */
TRTORCH_API torch::jit::Module EmbedEngineInNewModule(std::string& engine);

/**
* @brief Set gpu device id
*
Expand Down
4 changes: 4 additions & 0 deletions cpp/api/src/trtorch.cpp
Expand Up @@ -31,6 +31,10 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module
return core::CompileGraph(module, to_internal_compile_spec(info));
}

// Public API shim: forwards the serialized engine straight to the core
// implementation, which builds a new module with the engine registered as
// its "forward" method.
torch::jit::Module EmbedEngineInNewModule(std::string& engine) {
return core::EmbedEngineInNewModule(engine);
}

std::string get_build_info() {
auto info = core::util::get_build_info();
return std::string("TRTorch Version: ") + TRTORCH_VERSION + '\n' + info;
Expand Down
28 changes: 28 additions & 0 deletions tests/modules/test_modules_as_engines.cpp
Expand Up @@ -16,6 +16,34 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) {
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
}

// Round-trip test: convert the module to a TRT engine, embed the engine in a
// fresh TorchScript module, and check that the embedded module's output
// matches the original JIT module's output on the same inputs.
TEST_P(ModuleTests, ModuleToModuleIsClose) {
  // Random integer inputs; clones are stored as IValues so both runs see
  // identical data.
  std::vector<at::Tensor> inputs;
  std::vector<torch::jit::IValue> inputs_ivalues;
  for (auto in_shape : input_shapes) {
    inputs.push_back(at::randint(5, in_shape, {at::kCUDA}));
    inputs_ivalues.push_back(inputs[inputs.size() - 1].clone());
  }

  // Reference result from the original TorchScript module.
  torch::jit::IValue jit_results_ivalues = trtorch::tests::util::RunModuleForward(mod, inputs_ivalues);
  std::vector<at::Tensor> jit_results;
  jit_results.push_back(jit_results_ivalues.toTensor());

  // Collect input shapes for engine conversion. Iterate by const reference
  // to avoid copying tensors (the sizes() ArrayRef views the shared impl,
  // which `inputs` keeps alive). The unused `forward_graph` local from the
  // original version has been removed.
  std::vector<c10::ArrayRef<int64_t>> input_ranges;
  for (const auto& in : inputs) {
    input_ranges.push_back(in.sizes());
  }

  auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", input_ranges);
  auto trt_mod = trtorch::EmbedEngineInNewModule(engine);

  // BUG FIX: run the embedded-engine module (`trt_mod`), not `mod` again —
  // the original ran the JIT module twice and compared its results against
  // themselves, so the embedded engine was never exercised.
  torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, inputs_ivalues);
  std::vector<at::Tensor> trt_results;
  trt_results.push_back(trt_results_ivalues.toTensor());

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
}

INSTANTIATE_TEST_SUITE_P(
ModuleAsEngineForwardIsCloseSuite,
ModuleTests,
Expand Down

0 comments on commit 3ec836e

Please sign in to comment.