
Separate start/end token check for source and target tokenizer (#308)
* Separate start/end token check for source/target

* Update changelog

* Don't repeat check if source == target tokenizers

* Add brackets to enforce order of logical operators (see the precedence note below)

Co-authored-by: Pete <petew@allenai.org>
JohnGiorgi and epwalsh committed Nov 1, 2021
1 parent 84ba7cf commit 5ff0f79
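A note on the "add brackets" item in the commit message: in Python, `and` binds more tightly than `or`, so the parentheses around `target_add_start_token or target_add_end_token` change the meaning of the condition. A minimal sketch of the difference, where `tokenizers_differ` is illustrative shorthand for the `self._target_tokenizer != self._source_tokenizer` comparison in the diff:

```python
# Illustrative values: the target side adds a start token, and the target
# tokenizer is the same object as the source tokenizer (so the check was
# already performed on the source side and should not be repeated).
target_add_start_token, target_add_end_token = True, False
tokenizers_differ = False

# Without brackets this parses as
#   target_add_start_token or (target_add_end_token and tokenizers_differ)
# which is True here, i.e. the check would be repeated.
unbracketed = target_add_start_token or target_add_end_token and tokenizers_differ

# With brackets the check is correctly skipped when the tokenizers are the same.
bracketed = (target_add_start_token or target_add_end_token) and tokenizers_differ

print(unbracketed, bracketed)  # True False
```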
Showing 2 changed files with 31 additions and 16 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Changed
+
+- Separate start/end token check in `Seq2SeqDatasetReader` for source and target tokenizers.
+
 ## [v2.7.0](https://github.com/allenai/allennlp-models/releases/tag/v2.7.0) - 2021-09-01
 
 ### Added
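The changelog entry corresponds to the diff below: the start/end-symbol check now runs per tokenizer, and only for the side that actually adds those tokens. A minimal usage sketch of a configuration where this matters, assuming the reader exposes constructor arguments matching the flag and attribute names visible in the diff (`source_tokenizer`, `target_tokenizer`, `source_add_start_token`, and so on) and that these tokenizer classes are importable from `allennlp.data.tokenizers`:

```python
# Sketch only: argument names mirror the attributes/flags used in the diff below;
# check them against the Seq2SeqDatasetReader signature in your allennlp-models version.
from allennlp.data.tokenizers import PretrainedTransformerTokenizer, WhitespaceTokenizer
from allennlp_models.generation import Seq2SeqDatasetReader

reader = Seq2SeqDatasetReader(
    # Subword tokenizer on the source side; it would split "@start@"/"@end@",
    # so the reader is told not to add them -- and the check is now skipped here.
    source_tokenizer=PretrainedTransformerTokenizer("bert-base-uncased"),
    source_add_start_token=False,
    source_add_end_token=False,
    # Whitespace tokenizer on the target side keeps the symbols intact,
    # so only this tokenizer has to pass the start/end-symbol check.
    target_tokenizer=WhitespaceTokenizer(),
    target_add_start_token=True,
    target_add_end_token=True,
)
```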
43 changes: 27 additions & 16 deletions allennlp_models/generation/dataset_readers/seq2seq.py
@@ -103,22 +103,14 @@ def __init__(
             or target_add_start_token
             or target_add_end_token
         ):
-            # Check that the tokenizer correctly appends the start and end tokens to
-            # the sequence without splitting them.
-            tokens = self._source_tokenizer.tokenize(start_symbol + " " + end_symbol)
-            err_msg = (
-                f"Bad start or end symbol ('{start_symbol}', '{end_symbol}') "
-                f"for tokenizer {self._source_tokenizer}"
-            )
-            try:
-                start_token, end_token = tokens[0], tokens[-1]
-            except IndexError:
-                raise ValueError(err_msg)
-            if start_token.text != start_symbol or end_token.text != end_symbol:
-                raise ValueError(err_msg)
-
-            self._start_token = start_token
-            self._end_token = end_token
+            if source_add_start_token or source_add_end_token:
+                self._check_start_end_tokens(start_symbol, end_symbol, self._source_tokenizer)
+            if (
+                target_add_start_token or target_add_end_token
+            ) and self._target_tokenizer != self._source_tokenizer:
+                self._check_start_end_tokens(start_symbol, end_symbol, self._target_tokenizer)
+            self._start_token = Token(start_symbol)
+            self._end_token = Token(end_symbol)
 
         self._delimiter = delimiter
         self._source_max_tokens = source_max_tokens
@@ -190,3 +182,22 @@ def apply_token_indexers(self, instance: Instance) -> None:
             instance.fields["source_tokens"]._token_indexers = self._source_token_indexers  # type: ignore
         if "target_tokens" in instance.fields:
             instance.fields["target_tokens"]._token_indexers = self._target_token_indexers  # type: ignore
+
+    def _check_start_end_tokens(
+        self, start_symbol: str, end_symbol: str, tokenizer: Tokenizer
+    ) -> None:
+        """Check that `tokenizer` correctly appends `start_symbol` and `end_symbol` to the
+        sequence without splitting them. Raises a `ValueError` if this is not the case.
+        """
+
+        tokens = tokenizer.tokenize(start_symbol + " " + end_symbol)
+        err_msg = (
+            f"Bad start or end symbol ('{start_symbol}', '{end_symbol}') "
+            f"for tokenizer {self._source_tokenizer}"
+        )
+        try:
+            start_token, end_token = tokens[0], tokens[-1]
+        except IndexError:
+            raise ValueError(err_msg)
+        if start_token.text != start_symbol or end_token.text != end_symbol:
+            raise ValueError(err_msg)
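For context, the failure mode this helper guards against is a tokenizer that splits the start or end symbol into several pieces. A rough illustration, assuming the default `@start@`/`@end@` symbols and a subword tokenizer such as BERT's (the exact pieces depend on the vocabulary):

```python
from allennlp.data.tokenizers import PretrainedTransformerTokenizer, WhitespaceTokenizer

start_symbol, end_symbol = "@start@", "@end@"

# A whitespace tokenizer keeps the symbols whole, so the check passes.
whitespace = WhitespaceTokenizer()
print([t.text for t in whitespace.tokenize(start_symbol + " " + end_symbol)])
# ['@start@', '@end@']

# A subword tokenizer will typically split them, so tokens[0]/tokens[-1] no longer
# equal the symbols and _check_start_end_tokens raises the ValueError above.
subword = PretrainedTransformerTokenizer("bert-base-uncased", add_special_tokens=False)
print([t.text for t in subword.tokenize(start_symbol + " " + end_symbol)])
# e.g. ['@', 'start', '@', '@', 'end', '@'] (exact pieces depend on the vocabulary)
```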
