forked from zerollzeng/tiny-tensorrt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
PyTrt.cpp
88 lines (83 loc) · 3.66 KB
/
PyTrt.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
/*
* @Author: zerollzeng
* @Date: 2019-08-29 15:45:15
* @LastEditors: zerollzeng
* @LastEditTime: 2019-12-16 13:45:57
*/
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "pybind11/stl.h"
#include "spdlog/spdlog.h"
namespace py = pybind11;
#include "Trt.h"
/*
 * Python bindings for tiny-tensorrt, built as the `pytrt` extension module.
 *
 * Registered classes:
 *   - TrtPluginParams: default constructor only.
 *   - Trt: default / TrtPluginParams constructors, three CreateEngine
 *     overloads, DoInference(array) and GetOutput(bindName).
 */
PYBIND11_MODULE(pytrt, m) {
    // Pointer-to-member types, used only to select one Trt::CreateEngine
    // overload per binding (named casts instead of C-style casts).
    using CaffeCreateEngine = void (Trt::*)(const std::string&,
                                            const std::string&,
                                            const std::string&,
                                            const std::vector<std::string>&,
                                            const std::vector<std::vector<float>>&,
                                            int,
                                            int);
    using OnnxCreateEngine = void (Trt::*)(const std::string&,
                                           const std::string&,
                                           const std::vector<std::string>&,
                                           int);
    using TensorflowCreateEngine = void (Trt::*)(const std::string&,
                                                 const std::string&,
                                                 const std::vector<std::string>&,
                                                 const std::vector<std::vector<int>>&,
                                                 const std::vector<std::string>&,
                                                 int);

    m.doc() = "python interface of tiny-tensorrt";

    py::class_<TrtPluginParams>(m, "TrtPluginParams")
        .def(py::init<>());

    py::class_<Trt>(m, "Trt")
        .def(py::init([]() {
            return std::unique_ptr<Trt>(new Trt());
        }))
        .def(py::init([](TrtPluginParams params) {
            return std::unique_ptr<Trt>(new Trt(params));
        }))
        .def("CreateEngine",
             static_cast<CaffeCreateEngine>(&Trt::CreateEngine),
             "create engine with caffe model")
        .def("CreateEngine",
             static_cast<OnnxCreateEngine>(&Trt::CreateEngine),
             "create engine with onnx model")
        .def("CreateEngine",
             static_cast<TensorflowCreateEngine>(&Trt::CreateEngine),
             "create engine with tensorflow model")
        // DoInference(array): the array is converted by pybind11 to a C-contiguous
        // float array (forcecast), copied into a flat vector, handed to
        // DataTransfer(input, 0, 1), and then Forward() is called.
        // NOTE(review): presumably binding 0 is the network input and the last
        // argument selects host-to-device - confirm against Trt.h.
        .def("DoInference", [](Trt& self, py::array_t<float, py::array::c_style | py::array::forcecast> array) {
            // Build the vector straight from the array's range: no zero-fill
            // followed by an overwrite, and an empty array needs no special case.
            const float* first = array.data();
            std::vector<float> input(first, first + array.size());
            self.DataTransfer(input, 0, 1);
            self.Forward();
        })
        // GetOutput(bindName): looks bindName up in Trt::mBindingName and returns the
        // data of that binding as a float numpy array shaped by GetBindingDims().
        // If the name is unknown, or the fetched buffer is smaller than the reported
        // shape, an error is logged and a default-constructed py::array is returned.
        .def("GetOutput", [](Trt& self, const std::string& bindName) {
            const std::vector<std::string>& names = self.mBindingName;
            const auto found = std::find(names.begin(), names.end(), bindName);
            if (found == names.end()) {
                spdlog::error("invalid output binding name: {}", bindName);
                return py::array();
            }
            const int outputIndex = static_cast<int>(std::distance(names.begin(), found));

            std::vector<float> output;
            self.DataTransfer(output, outputIndex, 0);

            const nvinfer1::Dims dims = self.GetBindingDims(outputIndex);
            const ssize_t nbDims = dims.nbDims;
            std::vector<ssize_t> shape;
            for (ssize_t axis = 0; axis < nbDims; ++axis) {
                shape.push_back(dims.d[axis]);
            }

            // Row-major (C order) byte strides, built in one pass from the innermost axis.
            std::vector<ssize_t> strides(shape.size());
            ssize_t step = sizeof(float);
            for (ssize_t axis = static_cast<ssize_t>(shape.size()) - 1; axis >= 0; --axis) {
                strides[axis] = step;
                step *= shape[axis];
            }

            // py::array copies shape-many floats out of output.data(); make sure the
            // vector really holds that many so the copy cannot read past its end.
            // Negative extents are not checked here and are left for numpy to reject.
            ssize_t required = 1;
            bool extentsKnown = true;
            for (const ssize_t extent : shape) {
                if (extent < 0) {
                    extentsKnown = false;
                    break;
                }
                required *= extent;
            }
            if (extentsKnown && static_cast<ssize_t>(output.size()) < required) {
                spdlog::error("output binding {} holds {} floats but its shape needs {}",
                              bindName, output.size(), required);
                return py::array();
            }

            // No base handle is passed, so py::array takes its own copy of the data;
            // the result stays valid after the local vector is destroyed.
            return py::array(py::buffer_info(
                output.data(),
                sizeof(float),
                py::format_descriptor<float>::format(),
                nbDims,
                shape,
                strides
            ));
        })
    ;
}