feat: added user level API for fallback
Signed-off-by: Bo Wang <wangbo1995ee@163.com>
bowang007 committed Mar 10, 2021
1 parent 55e0510 commit f4c29b4
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 30 deletions.
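For orientation before the diff: a minimal sketch of how a user would exercise the new fallback API. The load path, input shape, and operator name are placeholders, and the trtorch::CompileSpec / trtorch::CompileGraph entry points are assumptions based on the TRTorch C++ API of this era; only the torch_fallback field and its members come from this commit.

#include <torch/script.h>
#include "trtorch/trtorch.h"

int main() {
  // Load a TorchScript module (path is a placeholder).
  torch::jit::script::Module mod = torch::jit::load("model.ts");

  // One fixed input shape; the shape here is illustrative.
  trtorch::CompileSpec spec({{1, 3, 224, 224}});

  // Turn on automatic fallback: unsupported ops stay in PyTorch, and only
  // runs of at least 3 consecutive supported ops are converted to TensorRT.
  spec.torch_fallback = trtorch::CompileSpec::TorchFallback(true, 3);

  // Force a specific op to run in PyTorch even if a converter exists
  // (operator name is a hypothetical example).
  spec.torch_fallback.forced_fallback_operators.push_back("aten::max_pool2d");

  auto trt_mod = trtorch::CompileGraph(mod, spec);
  return 0;
}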
54 changes: 28 additions & 26 deletions core/compiler.cpp
@@ -156,29 +156,6 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
return std::move(engine);
}

//torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
// // TODO: Should be doing a functional transform but need PR #31978
// // [jit] More robust mangling
// // torch::jit::script::Module new_mod = mod.clone();
// torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
// std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
// for (const torch::jit::script::Method& method : mod.get_methods()) {
// // Don't convert hidden methods
// if (method.name().rfind("_", 0)) {
// auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
// auto new_g = std::make_shared<torch::jit::Graph>();
// AddEngineToGraph(new_mod, new_g, engine);
// auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
// auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
// new_mod.type()->addMethod(new_method);
// new_method->setSchema(schema);
// }
// }
//
// return new_mod;
//}



void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioning::SegmentedBlock &seg,
std::unordered_map<torch::jit::Value*, torch::jit::Value*> &old_to_new_g) {
@@ -198,7 +175,6 @@ void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitionin
}
}

torch::jit::Node *node;
for (const auto n : seg.nodes()) {
partitioning::cloneNode(n, g, old_to_new_g);
}
@@ -212,8 +188,7 @@ void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitionin
return;
}


torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
// TODO: Should be doing a functional transform but need PR #31978
// [jit] More robust mangling
// torch::jit::script::Module new_mod = mod.clone();
@@ -270,6 +245,33 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
return new_mod;
}


torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
// TODO: not sure how to deal with the duplicated code here, so just split out a branch temporarily
if (cfg.convert_info.engine_settings.torch_fallback.enabled) {
return CompileGraphWithFallback(mod, cfg);
}
// TODO: Should be doing a functional transform but need PR #31978
// [jit] More robust mangling
// torch::jit::script::Module new_mod = mod.clone();
torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
for (const torch::jit::script::Method& method : mod.get_methods()) {
// Don't convert hidden methods
if (method.name().rfind("_", 0)) {
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
auto new_g = std::make_shared<torch::jit::Graph>();
AddEngineToGraph(new_mod, new_g, engine);
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
new_mod.type()->addMethod(new_method);
new_method->setSchema(schema);
}
}

return new_mod;
}

void set_device(const int gpu_id) {
TRTORCH_ASSERT(cudaSetDevice(gpu_id) == cudaSuccess, "Unable to set CUDA device: " << gpu_id);
}
5 changes: 5 additions & 0 deletions core/conversion/conversionctx/ConversionCtx.cpp
@@ -36,6 +36,11 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
}
os << "\n Engine Capability: " << s.capability \
<< "\n Calibrator Created: " << (s.calibrator != nullptr);

os << "\n Torch Fallback: " << s.torch_fallback.enabled;
if (s.torch_fallback.enabled) {
os << "\n Fallback min block size: " << s.torch_fallback.min_block_size;
}
return os;
}
// clang-format on
7 changes: 7 additions & 0 deletions core/conversion/conversionctx/ConversionCtx.h
@@ -22,13 +22,20 @@ struct Device {
Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
};

struct TorchFallback {
bool enabled = false;
uint64_t min_block_size = 1;
std::vector<std::string> forced_fallback_operators;
};

struct BuilderSettings {
nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
bool disable_tf32 = false;
bool refit = false;
bool debug = false;
bool strict_types = false;
Device device;
TorchFallback torch_fallback;
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
nvinfer1::IInt8Calibrator* calibrator = nullptr;
uint64_t num_min_timing_iters = 2;
4 changes: 0 additions & 4 deletions core/partitioning/partitioning.cpp
@@ -124,10 +124,6 @@ void registerSegmentsInputsOutputs(std::vector<SegmentedBlock> &segmented_blocks
}
}

// for (auto &graph_input : g->inputs()) {
// input_values.erase(graph_input);
// }

for (auto &graph_output : g->outputs()) {
input_values.insert(graph_output);
}
31 changes: 31 additions & 0 deletions cpp/api/include/trtorch/trtorch.h
@@ -381,6 +381,37 @@ struct TRTORCH_API CompileSpec {
*/
Device device;

/**
* @brief A struct to hold fallback info
*/
struct TRTORCH_API TorchFallback {
/// Enable the automatic fallback feature
bool enabled = false;

/// Minimum number of consecutive supported operations a block must contain to be converted to TensorRT
uint64_t min_block_size = 1;

/// A list of names of operations that will explicitly run in PyTorch
std::vector<std::string> forced_fallback_operators;

/**
* @brief Construct a default TorchFallback object; fallback will be off
*/
TorchFallback() = default;

/**
* @brief Construct from a bool
*/
TorchFallback(bool enabled) : enabled(enabled) {}

/**
* @brief Constructor for setting min_block_size
*/
TorchFallback(bool enabled, uint64_t min_size) : enabled(enabled), min_block_size(min_size) {}
};

TorchFallback torch_fallback;

/**
* Sets the restrictions for the engine (CUDA Safety)
*/
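To make the constructor overloads above concrete, a short sketch of the three ways the struct can be built (the operator name is a hypothetical example; everything else follows the declarations in this diff):

trtorch::CompileSpec::TorchFallback off;            // default-constructed: fallback disabled
trtorch::CompileSpec::TorchFallback on(true);       // enabled, min_block_size stays at 1
trtorch::CompileSpec::TorchFallback tuned(true, 5); // enabled, require blocks of 5+ supported ops
tuned.forced_fallback_operators = {"aten::relu"};   // always run these ops in PyTorch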
3 changes: 3 additions & 0 deletions cpp/api/src/compile_spec.cpp
@@ -95,6 +95,9 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
internal.convert_info.engine_settings.strict_types = external.strict_types;
internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
internal.convert_info.engine_settings.max_batch_size = external.max_batch_size;
internal.convert_info.engine_settings.torch_fallback.enabled = external.torch_fallback.enabled;
internal.convert_info.engine_settings.torch_fallback.min_block_size = external.torch_fallback.min_block_size;
internal.convert_info.engine_settings.torch_fallback.forced_fallback_operators = external.torch_fallback.forced_fallback_operators;

switch (external.device.device_type) {
case CompileSpec::Device::DeviceType::kDLA:
