Adding an end of sentence symbol to the source side. #392
thanks for making this backwards compatible!
sockeye/data_io.py
Outdated
yield sequence


def create_sequence_readers(sources, target, vocab_sources, vocab_target):
You could add type annotations.
done
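For illustration, an annotated version of the signature could look roughly like the sketch below. The concrete types are assumptions, not taken from the PR: `Vocab` is assumed to be a token-to-id mapping, `sources` and `target` are assumed to be file paths, and `read_sequences` is a hypothetical helper standing in for whatever reader the module actually uses.

```python
from typing import Dict, Iterator, List, Tuple

# Assumption: a vocabulary maps tokens to integer ids.
Vocab = Dict[str, int]
# Assumption: a reader lazily yields one list of token ids per input line.
SequenceReader = Iterator[List[int]]


def read_sequences(path: str, vocab: Vocab, unk_id: int = 0) -> SequenceReader:
    """Hypothetical helper: yield the token ids for each line of the file at `path`."""
    with open(path, encoding="utf-8") as lines:
        for line in lines:
            yield [vocab.get(token, unk_id) for token in line.split()]


def create_sequence_readers(sources: List[str],
                            target: str,
                            vocab_sources: List[Vocab],
                            vocab_target: Vocab) -> Tuple[List[SequenceReader], SequenceReader]:
    """
    Create one reader per source file and one reader for the target file.

    :param sources: Paths to the source files (assumed: one per source factor).
    :param target: Path to the target file.
    :param vocab_sources: One vocabulary per source file.
    :param vocab_target: The target vocabulary.
    :return: The list of source readers and the target reader.
    """
    source_readers = [read_sequences(path, vocab) for path, vocab in zip(sources, vocab_sources)]
    target_reader = read_sequences(target, vocab_target)
    return source_readers, target_reader
```

With annotations in place, a type checker can flag a call that passes a single vocabulary where a list of vocabularies is expected.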
sockeye/inference.py
Outdated
@@ -423,6 +427,11 @@ def load_models(context: mx.context.Context,
utils.check_condition(vocab.are_identical(*[source_vocabs[i][fi] for i in range(len(source_vocabs))]),
                      "Source vocabulary ids do not match. Factor %d" % fi)

source_with_eos = models[0].source_with_eos
utils.check_condition(all(source_with_eos == m.source_with_eos for m in models),
                      "All models must match either take an additional EOS symbol on the source side or not. "
grammar, maybe: "All models must agree on using source-side EOS symbols or not"
sounds good, changed.
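As a rough, self-contained sketch of the check discussed in this thread, using the reworded message suggested above: `ModelStub` and the local `check_condition` are hypothetical stand-ins for the loaded models and for `utils.check_condition`, and `common_source_with_eos` is an illustrative name that does not appear in the PR.

```python
from typing import Sequence


class ModelStub:
    """Hypothetical stand-in for a loaded model; only the EOS flag matters here."""

    def __init__(self, source_with_eos: bool) -> None:
        self.source_with_eos = source_with_eos


def check_condition(condition: bool, error_message: str) -> None:
    """Stand-in for utils.check_condition (assumed behavior): raise if the condition fails."""
    if not condition:
        raise ValueError(error_message)


def common_source_with_eos(models: Sequence[ModelStub]) -> bool:
    """Return the source-side EOS setting shared by all models, or raise if they differ."""
    check_condition(len(models) > 0, "At least one model is required.")
    source_with_eos = models[0].source_with_eos
    check_condition(all(source_with_eos == m.source_with_eos for m in models),
                    "All models must agree on using source-side EOS symbols or not.")
    return source_with_eos
```

Comparing every model against the first one is enough, because equality with a single reference value implies that all values are equal to each other.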
sockeye/inference.py
Outdated
@@ -591,6 +600,13 @@ def chunks(self, chunk_size: int) -> Generator['TranslatorInput', None, None]:
factors=factors,
chunk_id=chunk_id)

def with_eos(self):
missing docstring
added
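A documented `with_eos` might look something like the following sketch. It is not the code from the PR: `TranslatorInputSketch` is a simplified, hypothetical stand-in for `TranslatorInput`, the `"</s>"` value of `EOS_SYMBOL` is an assumption, and appending EOS to every factor stream is likewise assumed.

```python
from typing import List, Optional

EOS_SYMBOL = "</s>"  # assumption: the actual constant may be named or spelled differently


class TranslatorInputSketch:
    """Hypothetical, simplified stand-in for TranslatorInput."""

    def __init__(self,
                 sentence_id: int,
                 tokens: List[str],
                 factors: Optional[List[List[str]]] = None) -> None:
        self.sentence_id = sentence_id
        self.tokens = tokens
        self.factors = factors

    def with_eos(self) -> 'TranslatorInputSketch':
        """
        Create a copy of this input with an EOS symbol appended to the tokens
        and to every factor stream.

        :return: A new input object; the original is left unchanged.
        """
        factors = None
        if self.factors is not None:
            factors = [factor + [EOS_SYMBOL] for factor in self.factors]
        return TranslatorInputSketch(sentence_id=self.sentence_id,
                                     tokens=self.tokens + [EOS_SYMBOL],
                                     factors=factors)
```

Returning a new object instead of mutating in place keeps the raw input available, for example for length checks against the user-facing limit.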
sockeye/train.py
Outdated
@@ -258,12 +260,14 @@ def create_data_iters_and_vocabs(args: argparse.Namespace,
Create the data iterators and the vocabularies.

:param args: Arguments as returned by argparse.
:param max_seq_len_source: Source maximum sequence length.
:param max_seq_len_target: Target maximum sequence length.
:param shared_vocab: Whether to create a shared vocabulary.
double docstring for shared_vocab
1 unit test is failing (test_arguments.py)
true, should be fixed now.
they seem to work now :)
Lgtm, thanks!
Feel free to merge when travis or system tests are ready.
With this change we add an additional EOS symbol on the source side. From the CLI/user perspective, --max-seq-len is the maximum length of the raw tokens without special symbols. Internally, lengths are now based on the lengths including special symbols such as EOS. The change is backwards compatible in that the inference logic for existing models did not change.

System tests are currently running.
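The length accounting described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the PR's implementation: the function names are hypothetical, the `"</s>"` spelling of EOS is assumed, and exactly one special symbol is assumed to be added on the source side.

```python
from typing import List

EOS_SYMBOL = "</s>"  # assumption: spelling of the EOS symbol
NUM_SOURCE_SPECIAL_SYMBOLS = 1  # assumption: only EOS is appended on the source side


def add_source_eos(tokens: List[str]) -> List[str]:
    """Return the raw source tokens followed by the EOS symbol."""
    return tokens + [EOS_SYMBOL]


def internal_max_seq_len(user_max_seq_len: int) -> int:
    """Convert the user-facing limit (raw tokens) into the internal limit (including EOS)."""
    return user_max_seq_len + NUM_SOURCE_SPECIAL_SYMBOLS


def fits(tokens: List[str], user_max_seq_len: int) -> bool:
    """Check a sentence against the internal limit after EOS has been appended."""
    return len(add_source_eos(tokens)) <= internal_max_seq_len(user_max_seq_len)
```

Under these assumptions, a sentence with exactly --max-seq-len raw tokens still fits, because the internal limit grows by the same amount as the sequence.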
Pull Request Checklist
- [...] until you can check this box.
- [...] (pytest)
- [...] (pytest test/system)
- [...] (./style-check.sh)
- [...] sockeye/__init__.py. Major version bump if this is a backwards incompatible change.

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.