feat(//core/execution): Type checking for the executor; it is now the
responsibility of the user to transfer data to the GPU and ensure types are
correct.

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Apr 23, 2020
1 parent 2f86f84 commit 2dd1ba3
Showing 4 changed files with 24 additions and 6 deletions.
7 changes: 6 additions & 1 deletion README.md
@@ -17,14 +17,19 @@ More Information / System Architecture:
...
auto compile_settings = trtorch::ExtraInfo(dims);
// FP16 execution
compile_settings.op_precision = torch::kHalf;
compile_settings.op_precision = torch::kFloat;
// Compile module
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);
// Run like normal
auto results = trt_mod.forward({in_tensor});
...
```
> Notes on running in lower precisions:
> - Set precision with extra_info.op_precision
> - The module should be left in FP32 before compilation
> - In FP16 mode, only the input tensors should be converted to FP16; inputs for other precisions should remain FP32 (see the usage sketch below)
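
A minimal FP16 usage sketch following these notes (a sketch only; `dims`, `ts_mod`, and the input shape are illustrative, assuming the API shown in the example above):

```c++
// Module weights stay in FP32; the compile setting selects FP16 execution
auto compile_settings = trtorch::ExtraInfo(dims);
compile_settings.op_precision = torch::kHalf;
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);

// Inputs must already be on the GPU and converted to FP16 by the caller
auto in_tensor = at::randn({1, 3, 224, 224}).to(torch::kCUDA).to(torch::kHalf);
auto results = trt_mod.forward({in_tensor});
```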
## Platform Support
| Platform | Support |
8 changes: 4 additions & 4 deletions core/execution/TRTEngine.cpp
@@ -40,7 +40,7 @@ c10::FunctionSchema GenerateEngineFunctionSchema(EngineID id, nvinfer1::ICudaEngine
}
}
}

ss << in_ss.str();
ss << ") -> (";
ss << out_ss.str();
@@ -56,15 +56,15 @@ TRTEngine::TRTEngine(nvinfer1::ILogger& logger, std::string& serialized_engine)
: schema(torch::jit::parseSchema("trt::noop() -> ()")) { // Need a better default

rt = nvinfer1::createInferRuntime(logger);

cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size());
// Easy way to get a unique name for each engine, maybe there is a more descriptive way (using something associated with the graph maybe)
id = reinterpret_cast<EngineID>(cuda_engine);
exec_ctx = cuda_engine->createExecutionContext();

uint64_t inputs = 0;
uint64_t outputs = 0;

for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
if(cuda_engine->bindingIsInput(x)) {
inputs++;
2 changes: 1 addition & 1 deletion core/execution/execution.h
@@ -1,5 +1,5 @@
#pragma once
#include <utility>
#include <utility>
#include "NvInfer.h"
#include "ATen/core/function_schema.h"

13 changes: 13 additions & 0 deletions core/execution/register_trt_op.cpp
@@ -1,5 +1,6 @@
#include "c10/cuda/CUDAStream.h"

#include "torch/torch.h"
#include "torch/csrc/jit/custom_operator.h"

#include "core/util/prelude.h"
@@ -15,6 +16,18 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
std::vector<at::Tensor> contig_inputs{};
contig_inputs.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
TRTORCH_CHECK(inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
auto expected_type = torch::kF32;
switch (ctx->getEngine().getBindingDataType(i)) {
case nvinfer1::DataType::kHALF:
expected_type = torch::kF16;
break;
case nvinfer1::DataType::kFLOAT:
case nvinfer1::DataType::kINT8:
default:
expected_type = torch::kF32;
}
TRTORCH_CHECK(inputs[i].dtype() == expected_type, "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
auto shape = core::util::toVec(dims);
contig_inputs.push_back(inputs[i].to(at::kCUDA).view(shape).contiguous());
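A caller-side sketch of what these new checks imply (hedged; `trt_mod` stands in for a module compiled as in the README example, and the shape is illustrative): the executor now validates inputs rather than fixing them, so tensors must arrive on the GPU in the binding's expected dtype.

```c++
// A CPU tensor, or a dtype that does not match the engine binding,
// now trips a TRTORCH_CHECK at runtime instead of being converted silently.
auto input = at::randn({1, 3, 224, 224});  // CPU, FP32 -- would fail the is_cuda() check

auto staged = input.to(at::kCUDA);         // satisfies the device check (FP32 engine)
// For an FP16 engine binding, also match the expected dtype:
// staged = staged.to(torch::kHalf);
auto out = trt_mod.forward({staged});
```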
