Skip to content

Commit

Permalink
fix(//py): Build system issues
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed May 14, 2020
1 parent 7088245 commit c1de126
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
12 changes: 7 additions & 5 deletions py/setup.py
@@ -1,8 +1,12 @@
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
import os
import sys
import setuptools
import os
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
from setuptools.command.develop import develop
from setuptools.command.install import install
from distutils.cmd import Command

from torch.utils import cpp_extension
from shutil import copyfile

Expand All @@ -27,7 +31,6 @@ def copy_libtrtorch():

class DevelopCommand(develop):
description = "Builds the package and symlinks it into the PYTHONPATH"
user_options = develop.user_options + plugins_user_options

def initialize_options(self):
develop.initialize_options(self)
Expand All @@ -43,7 +46,6 @@ def run(self):

class InstallCommand(install):
description = "Builds the package"
user_options = install.user_options + plugins_user_options

def initialize_options(self):
install.initialize_options(self)
Expand Down
2 changes: 1 addition & 1 deletion py/trtorch/__init__.py
Expand Up @@ -16,7 +16,7 @@ def _load_trtorch_lib():
_load_trtorch_lib()

from .version import __version__
#from trtorch import _C
from trtorch import _C
from trtorch.compiler import *
from trtorch.types import *

Expand Down
12 changes: 6 additions & 6 deletions py/trtorch/csrc/trtorch_py.cpp
Expand Up @@ -108,7 +108,7 @@ struct ExtraInfo {

torch::jit::Module CompileGraph(const torch::jit::Module& mod, ExtraInfo& info) {
py::gil_scoped_acquire gil;
auto trt_mod = trtorch::CompileGraph(mod, info.toInternalExtraInfo());
auto trt_mod = core::CompileGraph(mod, info.toInternalExtraInfo());
return trt_mod;
}

Expand Down Expand Up @@ -139,8 +139,8 @@ PYBIND11_MODULE(_C, m) {
.def_readwrite("max", &InputRange::max)
.def("_to_internal_input_range", &InputRange::toInternalInputRange);

py::class_<core::conversion::InputRange>(m, "_InternalInputRange")
.def(py::init<>());
//py::class_<core::conversion::InputRange>(m, "_InternalInputRange")
// .def(py::init<>());

py::enum_<DataType>(m, "dtype")
.value("float", DataType::kFloat)
Expand Down Expand Up @@ -176,10 +176,10 @@ PYBIND11_MODULE(_C, m) {
.def_readwrite("max_batch_size", &ExtraInfo::max_batch_size);

m.doc() = "TRTorch Internal C Bindings: Ahead of Time compilation for PyTorch JIT. A tool to convert PyTorch JIT to TensorRT";
m.def("_compile_graph", &trtorch::pyapi::CompileGraph, "Ingest a PyTorch JIT module and convert supported subgraphs to TensorRT engines, returns a JIT module with the engines embedded");
m.def("_compile_graph", &trtorch::pyapi::CompileGraph, "Ingest a PyTorch JIT module and convert supported subgraphs to TensorRT engines, returns a JIT module with the engines embedded");
m.def("_convert_graph_to_trt_engine", &trtorch::pyapi::ConvertGraphToTRTEngine, "Given a PyTorch JIT Module, convert forward into a TensorRT engine and return a serialized engine");
m.def("_check_method_op_support", &trtorch::pyapi::CheckMethodOperatorSupport, "Takes a module and a method name and checks if the method graph contains purely convertable operators");
m.def("_get_build_info", &get_build_info, "Returns build info about the compiler as a string");
m.def("_check_method_op_support", &trtorch::pyapi::CheckMethodOperatorSupport, "Takes a module and a method name and checks if the method graph contains purely convertable operators");
m.def("_get_build_info", &get_build_info, "Returns build info about the compiler as a string");
m.def("_test", &test);
}

Expand Down

0 comments on commit c1de126

Please sign in to comment.