feat(//py)!: Porting forward the API to use kwargs
BREAKING CHANGE: This changes the API for compile settings from a
dictionary of settings to a set of kwargs for the various compilation
functions. This will break existing code; however, there is simple
guidance to port your code forward:

Given a dict of valid TRTorch CompileSpec settings:

```py
spec = {
	"inputs": ...
	...
}
```

You can use this same dict with the new APIs by changing your code from:

```py
trtorch.compile(mod, spec)
```

to:

```py
trtorch.compile(mod, **spec)
```
which unpacks the dictionary as keyword arguments to the function.
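
For example, once ported, a complete call looks like the following minimal sketch (the model, input shape, and precision set are illustrative placeholders, not part of this commit):

```py
import torch
import trtorch

# Illustrative stand-in model; any traced or scripted TorchScript module works.
class MyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

mod = torch.jit.script(MyModel().eval().cuda())

# A spec dict written for the old API (shape is illustrative):
spec = {
    "inputs": [trtorch.Input(shape=[1, 3, 224, 224])],
    "enabled_precisions": {torch.float},
}

# New API: the dict is unpacked into keyword arguments.
trt_mod = trtorch.compile(mod, **spec)
```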

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Oct 19, 2021
1 parent 4d95b04 commit 17e0e8a
Showing 10 changed files with 208 additions and 118 deletions.
3 changes: 3 additions & 0 deletions docsrc/py_api/trtorch.rst
@@ -1,5 +1,8 @@
  .. _trtorch_py:

+ .. automodule trtorch
+    :undoc-members:
+
  trtorch
  ===============
4 changes: 3 additions & 1 deletion py/trtorch/_compile_spec.py
@@ -64,7 +64,7 @@ def _parse_op_precision(precision: Any) -> _types.dtype:
          raise TypeError("Provided an unsupported dtype as operating precision (support: int8, half, float), got: " +
                          str(precision))

-     elif isinstance(precision, _types.DataTypes):
+     elif isinstance(precision, _types.dtype):
          return precision

      else:
@@ -170,6 +170,8 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
          inputs = [trtorch.Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]]
          info.inputs = [i._to_internal() for i in inputs]

+         assert (len(info.inputs) > 0), "Require at least one input definition to compile model"
+
      if "op_precision" in compile_spec and "enabled_precisions" in compile_spec:
          raise KeyError(
              "Found both key \"op_precision\", and \"enabled_precisions\" in compile spec, please port forward to using only \"enabled_precisions\""
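For reference, the `KeyError` above enforces that the deprecated `op_precision` key and its replacement `enabled_precisions` are not mixed. A hedged sketch of the port between the two (the fragments below are illustrative, not from this diff):

```py
import torch

# Old style (deprecated): a single operating precision.
old_spec = {"op_precision": torch.half}

# New style: a set of enabled precisions, so FP32 and FP16 can be allowed together.
new_spec = {"enabled_precisions": {torch.float, torch.half}}
```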
248 changes: 155 additions & 93 deletions py/trtorch/_compiler.py

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.cpp
@@ -38,6 +38,13 @@ nvinfer1::DataType toTRTDataType(DataType value) {
    }
  }

+ Device::Device(const core::runtime::CudaDevice& internal_dev) {
+   device_type = DeviceType::kGPU;
+   gpu_id = internal_dev.id;
+   dla_core = -1;
+   allow_gpu_fallback = false;
+ }
+
  nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {
    switch (value) {
      case TensorFormat::kChannelLast:
2 changes: 2 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.h
@@ -74,6 +74,8 @@ struct Device : torch::CustomClassHolder {
        allow_gpu_fallback(false) // allow_gpu_fallback
  {}

+ Device(const core::runtime::CudaDevice& internal_dev);
+
  ADD_ENUM_GET_SET(device_type, DeviceType, static_cast<int64_t>(DeviceType::kDLA));
  ADD_FIELD_GET_SET(gpu_id, int64_t);
  ADD_FIELD_GET_SET(dla_core, int64_t);
6 changes: 6 additions & 0 deletions py/trtorch/csrc/trtorch_py.cpp
@@ -103,6 +103,10 @@ void set_device(const int device_id) {
  core::set_device(device_id);
}

+ Device get_current_device() {
+   return Device(core::runtime::get_current_device());
+ }
+
torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec& info) {
  py::gil_scoped_acquire gil;
  auto trt_mod = core::CompileGraph(mod, info.toInternalCompileSpec());
@@ -315,6 +319,8 @@ PYBIND11_MODULE(_C, m) {
  m.def("_set_is_colored_output_on", &logging::set_is_colored_output_on, "Set if the logging output should be colored");
  m.def("_log", &logging::log, "Add a message to the logger");
  m.def("set_device", &trtorch::pyapi::set_device, "Set CUDA device id");
+ m.def("_get_current_device", &trtorch::pyapi::get_current_device, "Get the current active CUDA device");
+
  py::enum_<core::util::logging::LogLevel>(m, "LogLevel", py::arithmetic())
    .value("INTERNAL_ERROR", core::util::logging::LogLevel::kINTERNAL_ERROR)
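Note that `_get_current_device` is only registered on the private `_C` extension module in this diff; whether a public wrapper exists is not shown. A hedged sketch of reaching the raw binding, assuming the extension is built and importable:

```py
import trtorch

# Private binding added by this commit; the leading underscore marks it as
# internal, so relying on it directly is an assumption outside this diff.
dev = trtorch._C._get_current_device()
```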
42 changes: 25 additions & 17 deletions tests/py/test_api.py
@@ -23,44 +23,52 @@ def test_compile_traced(self):
              "enabled_precisions": {torch.float}
          }

-         trt_mod = trtorch.compile(self.traced_model, compile_spec)
+         trt_mod = trtorch.compile(self.traced_model, **compile_spec)
          same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
          self.assertTrue(same < 2e-2)

      def test_compile_script(self):
+         trt_mod = trtorch.compile(self.scripted_model, inputs=[self.input], device=trtorch.Device(gpu_id=0), enabled_precisions={torch.float})
+         same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
+         self.assertTrue(same < 2e-2)
+
+     def test_from_torch_tensor(self):
          compile_spec = {
-             "inputs": [trtorch.Input(shape=self.input.shape)],
+             "inputs": [self.input],
              "device": {
                  "device_type": trtorch.DeviceType.GPU,
                  "gpu_id": 0,
              },
              "enabled_precisions": {torch.float}
          }

-         trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+         trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
          same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
          self.assertTrue(same < 2e-2)

-     def test_from_torch_tensor(self):
+     def test_device(self):
+         compile_spec = {"inputs": [self.input], "device": trtorch.Device("gpu:0"), "enabled_precisions": {torch.float}}
+
+         trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
+         same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
+         self.assertTrue(same < 2e-2)
+
+
+     def test_compile_script_from_dict(self):
          compile_spec = {
-             "inputs": [self.input],
+             "inputs": [trtorch.Input(shape=self.input.shape)],
              "device": {
                  "device_type": trtorch.DeviceType.GPU,
                  "gpu_id": 0,
              },
              "enabled_precisions": {torch.float}
          }

-         trt_mod = trtorch.compile(self.scripted_model, compile_spec)
-         same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
+         trt_mod = trtorch.compile(self.traced_model, **compile_spec)
+         same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
          self.assertTrue(same < 2e-2)

-     def test_device(self):
-         compile_spec = {"inputs": [self.input], "device": trtorch.Device("gpu:0"), "enabled_precisions": {torch.float}}
-
-         trt_mod = trtorch.compile(self.scripted_model, compile_spec)
-         same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
-         self.assertTrue(same < 2e-2)


class TestCompileHalf(ModelTestCase):
@@ -80,7 +88,7 @@ def test_compile_script_half(self):
              "enabled_precisions": {torch.half}
          }

-         trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+         trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
          same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
          trtorch.logging.log(trtorch.logging.Level.Debug, "Max diff: " + str(same))
          self.assertTrue(same < 3e-2)
@@ -103,7 +111,7 @@ def test_compile_script_half_by_default(self):
              "enabled_precisions": {torch.float, torch.half}
          }

-         trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+         trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
          same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
          trtorch.logging.log(trtorch.logging.Level.Debug, "Max diff: " + str(same))
          self.assertTrue(same < 3e-2)
@@ -132,7 +140,7 @@ def test_compile_script(self):
              }
          }

-         trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+         trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
          same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
          self.assertTrue(same < 2e-3)

@@ -160,7 +168,7 @@ def test_compile_script(self):
              }
          }

-         trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+         trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
          same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
          self.assertTrue(same < 2e-3)

@@ -183,7 +191,7 @@ def test_pt_to_trt_to_pt(self):
              }
          }

-         trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
+         trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", **compile_spec)
          trt_mod = trtorch.embed_engine_in_new_module(trt_engine, trtorch.Device("cuda:0"))
          same = (trt_mod(self.input) - self.ts_model(self.input)).abs().max()
          self.assertTrue(same < 2e-3)
4 changes: 2 additions & 2 deletions tests/py/test_api_dla.py
@@ -40,7 +40,7 @@ def test_compile_traced(self):
              "enabled_precisions": {torch.half}
          }

-         trt_mod = trtorch.compile(self.traced_model, compile_spec)
+         trt_mod = trtorch.compile(self.traced_model, **compile_spec)
          same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
          self.assertTrue(same < 2e-2)

@@ -56,7 +56,7 @@ def test_compile_script(self):
              "enabled_precisions": {torch.half}
          }

-         trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+         trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
          same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
          self.assertTrue(same < 2e-2)

8 changes: 4 additions & 4 deletions tests/py/test_multi_gpu.py
@@ -32,7 +32,7 @@ def test_compile_traced(self):
              }
          }

-         trt_mod = trtorch.compile(self.traced_model, compile_spec)
+         trt_mod = trtorch.compile(self.traced_model, **compile_spec)
          trtorch.set_device(self.target_gpu)
          same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
          trtorch.set_device(0)
@@ -51,7 +51,7 @@ def test_compile_script(self):
              }
          }

-         trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+         trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
          trtorch.set_device(self.target_gpu)
          same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
          trtorch.set_device(0)
@@ -84,7 +84,7 @@ def test_compile_traced(self):
              }
          }

-         trt_mod = trtorch.compile(self.traced_model, compile_spec)
+         trt_mod = trtorch.compile(self.traced_model, **compile_spec)
          # Changing the device ID deliberately. It should still run on correct device ID by context switching
          trtorch.set_device(1)
          same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
@@ -103,7 +103,7 @@ def test_compile_script(self):
              }
          }

-         trt_mod = trtorch.compile(self.scripted_model, compile_spec)
+         trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
          # Changing the device ID deliberately. It should still run on correct device ID by context switching
          trtorch.set_device(1)
          same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
2 changes: 1 addition & 1 deletion tests/py/test_ptq_dataloader_calibrator.py
@@ -72,7 +72,7 @@ def test_compile_script(self):
              }
          }

-         trt_mod = trtorch.compile(self.model, compile_spec)
+         trt_mod = trtorch.compile(self.model, **compile_spec)
          int8_test_acc = self.compute_accuracy(self.testing_dataloader, trt_mod)
          log(Level.Info, "[TRT INT8] Test Acc: {:.2f}%".format(100 * int8_test_acc))
          acc_diff = fp32_test_acc - int8_test_acc
