Skip to content

Commit

Permalink
fix(//core/conversion/conversionctx): In the case of strict types and
Browse files Browse the repository at this point in the history
int8 do not enable fp16 kernels

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Jun 11, 2020
1 parent 26709cc commit 3611778
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions core/conversion/conversionctx/ConversionCtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
<< "\n Operating Precision: " << s.op_precision \
<< "\n Make Refittable Engine: " << s.refit \
<< "\n Debuggable Engine: " << s.debug \
<< "\n Strict Type: " << s.strict_types \
<< "\n Strict Types: " << s.strict_types \
<< "\n Allow GPU Fallback (if running on DLA): " << s.allow_gpu_fallback \
<< "\n Min Timing Iterations: " << s.num_min_timing_iters \
<< "\n Avg Timing Iterations: " << s.num_avg_timing_iters \
Expand Down Expand Up @@ -51,7 +51,9 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
case nvinfer1::DataType::kINT8:
TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does support INT8");
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
if (!settings.strict_types) {
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
}
input_type = nvinfer1::DataType::kFLOAT;
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator");
cfg->setInt8Calibrator(settings.calibrator);
Expand Down

0 comments on commit 3611778

Please sign in to comment.