Skip to content

Commit

Permalink
Check vocabulary size when converting OpenNMT-tf models (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln committed Dec 10, 2019
1 parent 667b074 commit 974635e
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 5 deletions.
22 changes: 22 additions & 0 deletions python/ctranslate2/converters/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def convert(self, output_dir, model_spec, vmap=None, quantization=None, force=Fa
except NotImplementedError:
raise NotImplementedError("This converter does not support the model %s" % model_spec)
model_spec.validate()
self._check_vocabulary_size("source", src_vocab, model_spec.source_vocabulary_size)
self._check_vocabulary_size("target", tgt_vocab, model_spec.target_vocabulary_size)
if quantization is not None:
model_spec.quantize(quantization)
model_spec.serialize(os.path.join(output_dir, "model.bin"))
Expand All @@ -76,3 +78,23 @@ def _load(self, model_spec):
@abc.abstractmethod
def _save_vocabulary(self, vocab, destination):
raise NotImplementedError()

def _vocabulary_size(self, vocab):
"""Returns the vocabulary size.
When defined, this enables additional error checking when converting models.
"""
return None

def _check_vocabulary_size(self, name, vocab, expected_size):
"""Raises an exception if expected and actual vocabulary sizes are known but
do not match.
"""
if expected_size is None:
return
vocab_size = self._vocabulary_size(vocab)
if vocab_size is None:
return
if vocab_size != expected_size:
raise ValueError("%s vocabulary has size %d but the model expected a vocabulary "
"of size %d" % (name.capitalize(), vocab_size, expected_size))
6 changes: 6 additions & 0 deletions python/ctranslate2/converters/opennmt_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ def _load(self, model_spec):
def _save_vocabulary(self, vocab, destination):
shutil.copy(vocab, destination)

def _vocabulary_size(self, vocab):
with open(vocab, "rb") as vocab_file:
num_tokens = 0
for _ in vocab_file:
num_tokens += 1
return num_tokens + 1 # Add OOV token.

def set_transformer_spec_v2(spec, variables):
set_embeddings(
Expand Down
27 changes: 23 additions & 4 deletions python/ctranslate2/specs/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ def index_spec(spec, index):
class LayerSpec(object):
"""Layer specification."""

@property
def revision(self):
return 1

def validate(self):
"""Checks that required variables are set to a valid value."""
def _check(spec, name, value):
Expand Down Expand Up @@ -154,3 +150,26 @@ def _write_string(string):
for alias, variable_name in aliases:
_write_string(alias)
_write_string(variable_name)


class ModelSpec(LayerSpec):
    """Layer specification sitting at the top of the model hierarchy."""

    @property
    def source_vocabulary_size(self):
        """Size of the source vocabulary as implied by the model weights,
        or ``None`` when it cannot be determined.
        """
        return None

    @property
    def target_vocabulary_size(self):
        """Size of the target vocabulary as implied by the model weights,
        or ``None`` when it cannot be determined.
        """
        return None

    @property
    def revision(self):
        """Revision number of this model specification.

        Increment this value whenever the weights layout changes (for
        example when a weight is renamed).
        """
        return 1
10 changes: 9 additions & 1 deletion python/ctranslate2/specs/transformer_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ctranslate2.specs import model_spec


class TransformerSpec(model_spec.LayerSpec):
class TransformerSpec(model_spec.ModelSpec):
"""Describes a Transformer model.
The specification is invariant to hidden dimensions but requires to
Expand All @@ -20,6 +20,14 @@ def __init__(self, num_layers, num_heads):
def revision(self):
return 3

@property
def source_vocabulary_size(self):
    """Source vocabulary size, derived from the encoder embedding matrix
    (its first dimension is the number of source tokens)."""
    encoder_embeddings = self.encoder.embeddings.weight
    return encoder_embeddings.shape[0]

@property
def target_vocabulary_size(self):
    """Target vocabulary size, derived from the decoder embedding matrix
    (its first dimension is the number of target tokens)."""
    decoder_embeddings = self.decoder.embeddings.weight
    return decoder_embeddings.shape[0]

class TransformerEncoderSpec(model_spec.LayerSpec):
def __init__(self, num_layers):
self.embeddings = EmbeddingsSpec()
Expand Down
11 changes: 11 additions & 0 deletions python/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ def test_opennmt_tf_model_conversion(tmpdir, model_path, src_vocab, tgt_vocab, m
output = translator.translate_batch([["آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"]])
assert output[0][0]["tokens"] == ["a", "t", "z", "m", "o", "n"]

def test_opennmt_tf_model_conversion_invalid_vocab(tmpdir):
    """Conversion must fail when the vocabulary sizes do not match the weights."""
    model_path = os.path.join(
        _TEST_DATA_DIR, "models", "transliteration-aren-all", "opennmt_tf", "v2", "checkpoint")
    # Deliberately swap the source and target vocabularies so that the
    # converter's vocabulary size check is triggered.
    swapped_src = os.path.join(model_path, "en.vocab")
    swapped_tgt = os.path.join(model_path, "ar.vocab")
    converter = ctranslate2.converters.OpenNMTTFConverter(
        model_path,
        src_vocab=swapped_src,
        tgt_vocab=swapped_tgt)
    output_dir = str(tmpdir.join("ctranslate2_model"))
    with pytest.raises(ValueError):
        converter.convert(output_dir, ctranslate2.specs.TransformerBase())

try:
import onmt
Expand Down

0 comments on commit 974635e

Please sign in to comment.