Skip to content

Commit

Permalink
Add quantized CTranslate2 exporters (#755)
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln committed Dec 14, 2020
1 parent bd50737 commit 3baf268
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 5 deletions.
5 changes: 3 additions & 2 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,9 @@ eval:

# (optional) Export a model when a metric has the best value so far (default: null).
export_on_best: bleu
# (optional) Format of the exported model (can be: "saved_model, "ctranslate2",
# "checkpoint", default: "saved_model").
# (optional) Format of the exported model (can be: "saved_model, "checkpoint",
# "ctranslate2", "ctranslate2_int8", "ctranslate2_int16", "ctranslate2_float16",
# default: "saved_model").
export_format: saved_model
# (optional) Maximum number of exports to keep on disk (default: 5).
max_exports_to_keep: 5
Expand Down
9 changes: 7 additions & 2 deletions opennmt/tests/runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,19 @@ def testExport(self, export_vocabulary_assets):
tokens = result["tokens"][:result["length"]]
self.assertAllEqual(tokens, [b"a", b"t", b"z", b"m", b"o", b"n"])

def testCTranslate2Export(self):
@parameterized.expand([
("ctranslate2",),
("ctranslate2_int8",),
("ctranslate2_int16",),
])
def testCTranslate2Export(self, variant):
try:
import ctranslate2
except ImportError:
self.skipTest("ctranslate2 module is not available")
export_dir = os.path.join(self.get_temp_dir(), "export")
runner = self._getTransliterationRunner()
runner.export(export_dir, exporter=exporters.make_exporter("ctranslate2"))
runner.export(export_dir, exporter=exporters.make_exporter(variant))
self.assertTrue(ctranslate2.contains_model(export_dir))
translator = ctranslate2.Translator(export_dir)
output = translator.translate_batch([["آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"]])
Expand Down
4 changes: 4 additions & 0 deletions opennmt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from opennmt.utils.decoding import dynamic_decode

from opennmt.utils.exporters import CTranslate2Exporter
from opennmt.utils.exporters import CTranslate2Float16Exporter
from opennmt.utils.exporters import CTranslate2Int16Exporter
from opennmt.utils.exporters import CTranslate2Int8Exporter
from opennmt.utils.exporters import CheckpointExporter
from opennmt.utils.exporters import Exporter
from opennmt.utils.exporters import SavedModelExporter
from opennmt.utils.exporters import register_exporter
Expand Down
31 changes: 30 additions & 1 deletion opennmt/utils/exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,18 @@ def __init__(self, quantization=None):
Args:
quantization: Quantize model weights to this type when exporting the model.
Can be "int16" or "int8". Default is no quantization.
Can be "int8", "int16", or "float16". Default is no quantization.
Raises:
ImportError: if the CTranslate2 package is missing.
ValueError: if :obj:`quantization` is invalid.
"""
# Fail now if ctranslate2 package is missing.
import ctranslate2 # pylint: disable=import-outside-toplevel,unused-import
accepted_quantization = ("int8", "int16", "float16")
if quantization is not None and quantization not in accepted_quantization:
raise ValueError("Invalid quantization '%s' for CTranslate2, accepted values are: %s" % (
quantization, ", ".join(accepted_quantization)))
self._quantization = quantization

def _export_model(self, model, export_dir):
Expand All @@ -107,3 +115,24 @@ def _export_model(self, model, export_dir):
tgt_vocab=model.labels_inputter.vocabulary_file,
variables=variables)
converter.convert(export_dir, model_spec, quantization=self._quantization, force=True)


@register_exporter(name="ctranslate2_int8")
class CTranslate2Int8Exporter(CTranslate2Exporter):
"""CTranslate2 exporter with int8 quantization."""
def __init__(self):
super().__init__(quantization="int8")


@register_exporter(name="ctranslate2_int16")
class CTranslate2Int16Exporter(CTranslate2Exporter):
"""CTranslate2 exporter with int16 quantization."""
def __init__(self):
super().__init__(quantization="int16")


@register_exporter(name="ctranslate2_float16")
class CTranslate2Float16Exporter(CTranslate2Exporter):
"""CTranslate2 exporter with float16 quantization."""
def __init__(self):
super().__init__(quantization="float16")

0 comments on commit 3baf268

Please sign in to comment.