Add flag to return ranges on Unicode characters (#159)
guillaumekln committed Sep 2, 2020
1 parent bf6fb8b commit db2c3b7
Showing 3 changed files with 41 additions and 5 deletions.
25 changes: 23 additions & 2 deletions bindings/python/Python.cc
@@ -4,6 +4,8 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <unicode/unistr.h>

#include <onmt/Tokenizer.h>
#include <onmt/BPE.h>
#include <onmt/SentencePiece.h>
@@ -217,7 +219,9 @@ class TokenizerWrapper
return to_py_list(tokens);
}

py::tuple detokenize_with_ranges(const py::list& words, bool merge_ranges) const
py::tuple detokenize_with_ranges(const py::list& words,
bool merge_ranges,
bool with_unicode_ranges) const
{
onmt::Ranges ranges;
std::string text;
@@ -231,6 +235,21 @@
ranges, merge_ranges);
}

if (with_unicode_ranges)
{
onmt::Ranges unicode_ranges;
for (const auto& pair : ranges)
{
const size_t word_index = pair.first;
const onmt::Range& range = pair.second;
const icu::UnicodeString prefix(text.c_str(), range.first);
const icu::UnicodeString piece(text.c_str() + range.first, range.second - range.first + 1);
unicode_ranges.emplace(word_index,
onmt::Range(prefix.length(), prefix.length() + piece.length() - 1));
}
ranges = std::move(unicode_ranges);
}

py::list ranges_py(ranges.size());
size_t index = 0;
for (const auto& pair : ranges)
@@ -511,7 +530,9 @@ PYBIND11_MODULE(pyonmttok, m)
.def("detokenize", &TokenizerWrapper::detokenize,
py::arg("tokens"), py::arg("features")=py::none())
.def("detokenize_with_ranges", &TokenizerWrapper::detokenize_with_ranges,
py::arg("tokens"), py::arg("merge_ranges")=false)
py::arg("tokens"),
py::arg("merge_ranges")=false,
py::arg("unicode_ranges")=false)
.def("detokenize_file", &TokenizerWrapper::detokenize_file,
py::arg("input_path"),
py::arg("output_path"))
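For reference, a rough Python analogue of the byte-to-character conversion added above: it decodes the text before and inside a byte range and counts characters, mirroring the icu::UnicodeString loop. The helper name is illustrative and not part of the binding, and icu::UnicodeString::length() counts UTF-16 code units, so the results can differ for characters outside the Basic Multilingual Plane.

def byte_range_to_unicode_range(text: bytes, byte_range: tuple) -> tuple:
    # Inclusive byte offsets, as produced by detokenize_with_ranges.
    start, end = byte_range
    # Count the characters before the range and the characters it covers.
    prefix_length = len(text[:start].decode("utf-8"))
    piece_length = len(text[start:end + 1].decode("utf-8"))
    return (prefix_length, prefix_length + piece_length - 1)

# "测 试" is 7 bytes in UTF-8; the byte range (4, 6) covering "试" maps to the
# character range (2, 2).
assert byte_range_to_unicode_range("测 试".encode("utf-8"), (4, 6)) == (2, 2)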
9 changes: 6 additions & 3 deletions bindings/python/README.md
@@ -72,11 +72,14 @@ tokenizer.detokenize(
) -> str

# The detokenize_with_ranges method also returns a dictionary mapping a token
# index to a range in the detokenized text. Set merge_ranges=True to merge
# consecutive ranges, e.g. subwords of the same token in case of subword tokenization.
# index to a range in the detokenized text.
# Set merge_ranges=True to merge consecutive ranges, e.g. subwords of the same
# token in case of subword tokenization.
# Set unicode_ranges=True to return ranges over Unicode characters instead of bytes.
tokenizer.detokenize_with_ranges(
tokens: Union[List[str], List[pyonmttok.Token]],
merge_ranges: bool = True
merge_ranges: bool = True,
unicode_ranges: bool = True
) -> Tuple[str, Dict[int, Pair[int, int]]]

# Detokenize a file.
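A quick usage sketch of the documented flag, assuming a joiner-annotating tokenizer like the one used in the tests below (the constructor arguments are illustrative, and the byte offsets in the comments assume UTF-8 text where each of these characters takes 3 bytes):

import pyonmttok

tokenizer = pyonmttok.Tokenizer("conservative", joiner_annotate=True)

# Default byte-based ranges: each CJK character spans 3 UTF-8 bytes.
text, ranges = tokenizer.detokenize_with_ranges(["测", "试"])
# text == "测 试", ranges == {0: (0, 2), 1: (4, 6)}

# Ranges over Unicode characters with the new flag.
text, ranges = tokenizer.detokenize_with_ranges(["测", "试"], unicode_ranges=True)
# text == "测 试", ranges == {0: (0, 0), 1: (2, 2)}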
12 changes: 12 additions & 0 deletions bindings/python/test/test.py
@@ -126,6 +126,18 @@ def test_detok_with_ranges():
assert ranges[0] == (0, 0)
assert ranges[1] == (2, 2)

_, ranges = tokenizer.detokenize_with_ranges(
["测", "试"], unicode_ranges=True)
assert len(ranges) == 2
assert ranges[0] == (0, 0)
assert ranges[1] == (2, 2)

_, ranges = tokenizer.detokenize_with_ranges(
["测", "■试"], unicode_ranges=True, merge_ranges=True)
assert len(ranges) == 2
assert ranges[0] == (0, 1)
assert ranges[1] == (0, 1)

def test_bpe_case_insensitive_issue_147():
tokenizer = pyonmttok.Tokenizer(
"conservative",
