Skip to content

Commit

Permalink
Use pybind11 for the Python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln committed Feb 25, 2019
1 parent 7c20d11 commit 1976f45
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 87 deletions.
9 changes: 3 additions & 6 deletions Dockerfile
Expand Up @@ -5,7 +5,7 @@ RUN apt-get update && \
build-essential \
cpio \
libboost-program-options-dev \
libboost-python-dev \
python-dev \
python-pip \
wget && \
apt-get clean && \
Expand Down Expand Up @@ -50,8 +50,8 @@ RUN mkdir build && \
COPY python python

WORKDIR /root/ctranslate2-dev/python
RUN pip --no-cache-dir install setuptools wheel
RUN CFLAGS="-DWITH_MKL=ON" CTRANSLATE_ROOT=/root/ctranslate2 \
RUN pip --no-cache-dir install setuptools wheel pybind11
RUN CFLAGS="-DWITH_MKL=ON" CTRANSLATE2_ROOT=/root/ctranslate2 \
python setup.py bdist_wheel

WORKDIR /root
Expand All @@ -63,14 +63,11 @@ FROM ubuntu:16.04
RUN apt-get update && \
apt-get install -y --no-install-recommends \
libboost-program-options1.58.0 \
libboost-python1.58.0 \
python-pip && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

COPY --from=builder /root/ctranslate2 /root/ctranslate2

RUN pip --no-cache-dir install setuptools
RUN pip --no-cache-dir install /root/ctranslate2/*.whl

WORKDIR /root
Expand Down
9 changes: 3 additions & 6 deletions Dockerfile.cuda
Expand Up @@ -7,10 +7,10 @@ RUN apt-get update && \
build-essential \
cpio \
libboost-program-options-dev \
libboost-python-dev \
libcudnn7=$CUDNN_VERSION-1+cuda10.0 \
libcudnn7-dev=$CUDNN_VERSION-1+cuda10.0 \
nvinfer-runtime-trt-repo-ubuntu1604-$TENSORRT_VERSION-ga-cuda10.0 \
python-dev \
python-pip \
wget && \
apt-get update && \
Expand Down Expand Up @@ -71,8 +71,8 @@ RUN mkdir build && \
COPY python python

WORKDIR /root/ctranslate2-dev/python
RUN pip --no-cache-dir install setuptools wheel
RUN CFLAGS="-DWITH_CUDA=ON -DWITH_MKL=ON" CTRANSLATE_ROOT=/root/ctranslate2 \
RUN pip --no-cache-dir install setuptools wheel pybind11
RUN CFLAGS="-DWITH_CUDA=ON -DWITH_MKL=ON" CTRANSLATE2_ROOT=/root/ctranslate2 \
python setup.py bdist_wheel

WORKDIR /root
Expand All @@ -87,7 +87,6 @@ RUN apt-get update && \
apt-get install -y --no-install-recommends \
cuda-cublas-10-0=10.0.130-1 \
libboost-program-options1.58.0 \
libboost-python1.58.0 \
libcudnn7=${CUDNN_VERSION}-1+cuda10.0 \
nvinfer-runtime-trt-repo-ubuntu1604-$TENSORRT_VERSION-ga-cuda10.0 \
python-pip && \
Expand All @@ -101,8 +100,6 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/*

COPY --from=builder /root/ctranslate2 /root/ctranslate2

RUN pip --no-cache-dir install setuptools
RUN pip --no-cache-dir install /root/ctranslate2/*.whl

WORKDIR /root
Expand Down
20 changes: 5 additions & 15 deletions python/setup.py
@@ -1,27 +1,17 @@
import os
import pybind11

from setuptools import setup, find_packages, Extension


include_dirs = []
library_dirs = []

def _maybe_add_library_root(lib_name):
    """Register the include/lib directories of an optional library root.

    Looks up the environment variable ``<lib_name>_ROOT``. When it is set,
    ``<root>/include`` is appended to the module-level ``include_dirs`` and
    ``<root>/lib`` to the module-level ``library_dirs``. When it is not set,
    nothing happens.
    """
    env_key = "%s_ROOT" % lib_name
    prefix = os.environ.get(env_key)
    if prefix is None:
        return
    include_dirs.append("%s/include" % prefix)
    library_dirs.append("%s/lib" % prefix)

_maybe_add_library_root("BOOST")
_maybe_add_library_root("CTRANSLATE")

ctranslate2_root = os.getenv("CTRANSLATE2_ROOT", "/usr/local")
ctranslate2_module = Extension(
"ctranslate2.translator",
sources=["translator.cc"],
extra_compile_args=["-std=c++11"],
include_dirs=include_dirs,
library_dirs=library_dirs,
libraries=[os.getenv("BOOST_PYTHON_LIBRARY", "boost_python"), "ctranslate2"])
include_dirs=["%s/include" % ctranslate2_root, pybind11.get_include()],
library_dirs=["%s/lib" % ctranslate2_root],
libraries=["ctranslate2"])

setup(
name="ctranslate2",
Expand Down
111 changes: 51 additions & 60 deletions python/translator.cc
@@ -1,30 +1,32 @@
#include <boost/python.hpp>
#include <boost/python/stl_iterator.hpp>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <ctranslate2/translator_pool.h>

namespace py = boost::python;
namespace py = pybind11;

// Scoped guard around the Python GIL: the constructor calls
// PyEval_SaveThread() and keeps the returned thread state, and the
// destructor hands that state back to PyEval_RestoreThread(). Instantiate it
// as a local variable around calls that should run without holding the GIL.
class GILReleaser {
public:
  GILReleaser()
    : _save_state(PyEval_SaveThread()) {
  }
  ~GILReleaser() {
    PyEval_RestoreThread(_save_state);
  }

  // Non-copyable: a copy would pass the same saved thread state to
  // PyEval_RestoreThread() a second time when it is destroyed.
  GILReleaser(const GILReleaser&) = delete;
  GILReleaser& operator=(const GILReleaser&) = delete;

private:
  // Thread state returned by PyEval_SaveThread(), restored in the destructor.
  PyThreadState* _save_state;
};
// pybind11 type used to wrap C++ strings (list items and dict keys) that are
// handed back to Python: py::bytes for Python major versions below 3,
// py::str otherwise.
#if PY_MAJOR_VERSION < 3
# define STR_TYPE py::bytes
#else
# define STR_TYPE py::str
#endif

template<class T>
// Builds a Python list containing every element of `v`, in order.
// Each element is appended as-is and converted by py::list::append.
template <typename T>
py::list std_vector_to_py_list(const std::vector<T>& v) {
  py::list result;
  for (auto it = v.begin(); it != v.end(); ++it) {
    result.append(*it);
  }
  return result;
}

// Specialization for vectors of strings: each element is wrapped in STR_TYPE
// before being appended, so the Python-side string type follows the
// STR_TYPE macro rather than the default std::string conversion.
template<>
py::list std_vector_to_py_list(const std::vector<std::string>& v) {
  py::list result;
  for (auto it = v.begin(); it != v.end(); ++it) {
    result.append(STR_TYPE(*it));
  }
  return result;
}

class TranslatorWrapper
{
public:
Expand Down Expand Up @@ -58,30 +60,21 @@ class TranslatorWrapper
options.num_hypotheses = num_hypotheses;
options.use_vmap = use_vmap;

GILReleaser releaser;
py::gil_scoped_release release;
_translator_pool.consume_text_file(in_file, out_file, max_batch_size, options, with_scores);
}

py::list translate_batch(const py::object& tokens,
py::list translate_batch(const std::vector<std::vector<std::string>>& tokens,
size_t beam_size,
size_t num_hypotheses,
float length_penalty,
size_t max_decoding_length,
size_t min_decoding_length,
bool use_vmap,
bool return_attention) {
if (tokens == py::object())
if (tokens.empty())
return py::list();

std::vector<std::vector<std::string>> tokens_vec;
tokens_vec.reserve(py::len(tokens));

for (auto it = py::stl_input_iterator<py::list>(tokens);
it != py::stl_input_iterator<py::list>(); it++) {
tokens_vec.emplace_back(py::stl_input_iterator<std::string>(*it),
py::stl_input_iterator<std::string>());
}

auto options = ctranslate2::TranslationOptions();
options.beam_size = beam_size;
options.length_penalty = length_penalty;
Expand All @@ -94,22 +87,22 @@ class TranslatorWrapper
std::vector<ctranslate2::TranslationResult> results;

{
GILReleaser releaser;
results = std::move(_translator_pool.post(tokens_vec, options).get());
py::gil_scoped_release release;
results = std::move(_translator_pool.post(tokens, options).get());
}

py::list py_results;
for (const auto& result : results) {
py::list batch;
for (size_t i = 0; i < result.num_hypotheses(); ++i) {
py::dict hyp;
hyp["score"] = result.scores()[i];
hyp["tokens"] = std_vector_to_py_list(result.hypotheses()[i]);
hyp[STR_TYPE("score")] = result.scores()[i];
hyp[STR_TYPE("tokens")] = std_vector_to_py_list(result.hypotheses()[i]);
if (result.has_attention()) {
py::list attn;
for (const auto& attn_vector : result.attention()[i])
attn.append(std_vector_to_py_list(attn_vector));
hyp["attention"] = attn;
hyp[STR_TYPE("attention")] = attn;
}
batch.append(hyp);
}
Expand All @@ -123,36 +116,34 @@ class TranslatorWrapper
ctranslate2::TranslatorPool _translator_pool;
};

BOOST_PYTHON_MODULE(translator)
PYBIND11_MODULE(translator, m)
{
PyEval_InitThreads();
py::class_<TranslatorWrapper, boost::noncopyable>(
"Translator",
py::init<std::string, std::string, int, size_t, size_t>(
(py::arg("model_path"),
py::arg("device")="cpu",
py::arg("device_index")=0,
py::arg("inter_threads")=1,
py::arg("intra_threads")=4)))
py::class_<TranslatorWrapper>(m, "Translator")
.def(py::init<std::string, std::string, int, size_t, size_t>(),
py::arg("model_path"),
py::arg("device")="cpu",
py::arg("device_index")=0,
py::arg("inter_threads")=1,
py::arg("intra_threads")=4)
.def("translate_batch", &TranslatorWrapper::translate_batch,
(py::arg("tokens"),
py::arg("beam_size")=4,
py::arg("num_hypotheses")=1,
py::arg("length_penalty")=0.6,
py::arg("max_decoding_length")=250,
py::arg("min_decoding_length")=1,
py::arg("use_vmap")=false,
py::arg("return_attention")=false))
py::arg("tokens"),
py::arg("beam_size")=4,
py::arg("num_hypotheses")=1,
py::arg("length_penalty")=0.6,
py::arg("max_decoding_length")=250,
py::arg("min_decoding_length")=1,
py::arg("use_vmap")=false,
py::arg("return_attention")=false)
.def("translate_file", &TranslatorWrapper::translate_file,
(py::arg("input_path"),
py::arg("output_path"),
py::arg("max_batch_size"),
py::arg("beam_size")=4,
py::arg("num_hypotheses")=1,
py::arg("length_penalty")=0.6,
py::arg("max_decoding_length")=250,
py::arg("min_decoding_length")=1,
py::arg("use_vmap")=false,
py::arg("with_scores")=false))
py::arg("input_path"),
py::arg("output_path"),
py::arg("max_batch_size"),
py::arg("beam_size")=4,
py::arg("num_hypotheses")=1,
py::arg("length_penalty")=0.6,
py::arg("max_decoding_length")=250,
py::arg("min_decoding_length")=1,
py::arg("use_vmap")=false,
py::arg("with_scores")=false)
;
}

0 comments on commit 1976f45

Please sign in to comment.