
TRT INT8 unsupported hardware fix #19349

Merged (1 commit) on Oct 19, 2020
2 changes: 2 additions & 0 deletions src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
@@ -143,6 +143,7 @@ std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
       builder_config->setInt8Calibrator(calibrator);
     } else {
       LOG(WARNING) << "TensorRT can't use int8 on this platform";
+      calibrator->setDone();
       calibrator = nullptr;
     }
   }
@@ -177,6 +178,7 @@ std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
       trt_builder->setInt8Calibrator(calibrator);
     } else {
       LOG(WARNING) << "TensorRT can't use int8 on this platform";
+      calibrator->setDone();
       calibrator = nullptr;
     }
   }
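Why the calibrator has to be marked done before it is dropped: TRTInt8Calibrator synchronizes the thread feeding calibration batches with TensorRT's calibration loop through a mutex and condition variable (see waitAndSetDone() in the next file), so if int8 turns out to be unsupported and the pointer is simply nulled, a thread blocked on that handshake can wait forever. The sketch below illustrates the general flag-plus-condition-variable pattern with hypothetical names (FakeCalibrator, setBatch) and an explicit notify_all() added so the example terminates on its own; it is not the MXNet implementation.

```cpp
// Minimal sketch of the flag-plus-condition-variable pattern with hypothetical
// names (FakeCalibrator, setBatch); NOT the MXNet TRTInt8Calibrator.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

struct FakeCalibrator {
  std::mutex mutex_;
  std::condition_variable cv_;
  bool batch_is_set_ = false;  // a batch is waiting to be consumed
  bool done_ = false;          // calibration is over (or never started)

  // Producer side: blocks until the previous batch is consumed or calibration ends.
  bool setBatch() {
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return !batch_is_set_ || done_; });
    if (done_) return false;   // int8 path abandoned: stop feeding batches
    batch_is_set_ = true;
    return true;
  }

  // Unsupported-int8 path: mark the calibrator finished and wake any waiter.
  void setDone() {
    std::lock_guard<std::mutex> lk(mutex_);
    done_ = true;
    cv_.notify_all();
  }
};

int main() {
  FakeCalibrator calib;
  calib.batch_is_set_ = true;  // pretend a batch is already pending
  std::thread producer([&] {
    // Without setDone() below, this call would block forever.
    bool fed = calib.setBatch();
    std::cout << "producer returned, fed=" << std::boolalpha << fed << "\n";
  });
  calib.setDone();             // what the patched warning branch now does
  producer.join();
  return 0;
}
```

The two setDone() calls added above apply the same idea: flip the done flag before abandoning the calibrator so nothing is left waiting on it.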
4 changes: 4 additions & 0 deletions src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.cc
@@ -118,6 +118,10 @@ void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
             << " length=" << length;
 }
 
+void TRTInt8Calibrator::setDone() {
+  done_ = true;
+}
+
 void TRTInt8Calibrator::waitAndSetDone() {
   std::unique_lock<std::mutex> lk(mutex_);
   cv_.wait(lk, [&]{ return (!batch_is_set_ && !calib_running_) || done_; });
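Note the difference from the pre-existing waitAndSetDone() shown just below the new method: it blocks on cv_ until no batch is pending and calibration is not running (or done_ is already set), whereas setDone() flips the flag immediately, which is what the unsupported-hardware branch needs, since calibration never starts and nothing would ever consume a pending batch. A hedged call-site sketch follows (hypothetical function and parameter names, not code from this PR; it assumes the TRTInt8Calibrator declaration is reachable unqualified, though it may live in a namespace):

```cpp
// Hypothetical call-site sketch (not code from this PR): choosing a shutdown path.
#include "tensorrt_int8_calibrator.h"  // declares setDone() and waitAndSetDone()

void FinishInt8Calibration(TRTInt8Calibrator* calibrator, bool platform_has_fast_int8) {
  if (platform_has_fast_int8) {
    // Normal teardown: wait until any pending batch has been consumed.
    calibrator->waitAndSetDone();
  } else {
    // Unsupported hardware: calibration never runs and nothing will consume a
    // pending batch, so just mark the calibrator done before dropping it.
    calibrator->setDone();
  }
}
```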
2 changes: 2 additions & 0 deletions src/operator/subgraph/tensorrt/tensorrt_int8_calibrator.h
@@ -75,6 +75,8 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 {
   // TODO(spanev): determine if we need to serialize it
   const std::string& getCalibrationTableAsString() { return calibration_table_; }
 
+  void setDone();
+
   void waitAndSetDone();
 
   bool isCacheEmpty();
6 changes: 0 additions & 6 deletions tests/python/tensorrt/test_tensorrt.py
@@ -143,12 +143,6 @@ def get_top1(logits):

 def test_tensorrt_symbol_int8():
     ctx = mx.gpu(0)
-    cuda_arch = get_cuda_compute_capability(ctx)
-    cuda_arch_min = 70
-    if cuda_arch < cuda_arch_min:
-        print('Bypassing test_tensorrt_symbol_int8 on cuda arch {}, need arch >= {}).'.format(
-            cuda_arch, cuda_arch_min))
-        return
 
     # INT8 engine output are not lossless, so we don't expect numerical uniformity,
     # but we have to compare the TOP1 metric