Skip to content

Commit

Permalink
feat(//cpp/api): Remove the extra includes in the API header
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Apr 22, 2020
1 parent f022dfe commit 2f86f84
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
7 changes: 7 additions & 0 deletions cpp/api/include/trtorch/ptq.h
Expand Up @@ -11,6 +11,13 @@ class IInt8Calibrator;
class IInt8EntropyCalibrator2;
}

namespace torch {
namespace data {
template<typename Example>
class Iterator;
}
}

namespace trtorch {
namespace ptq {

Expand Down
3 changes: 0 additions & 3 deletions cpp/api/include/trtorch/trtorch.h
Expand Up @@ -12,9 +12,6 @@
#include <vector>
#include <memory>

#include "torch/torch.h"
#include "NvInfer.h"

// Just include the .h?
namespace torch {
namespace jit {
Expand Down
4 changes: 2 additions & 2 deletions cpp/ptq/main.cpp
Expand Up @@ -41,8 +41,8 @@ int main(int argc, const char* argv[]) {

std::string calibration_cache_file = "/tmp/vgg16_TRT_ptq_calibration.cache";

//auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);
auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
//auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);


std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
Expand Down

0 comments on commit 2f86f84

Please sign in to comment.