Support splitting long sequences into multiple segments for transformers (#3666)

* First pass at folding long sequences

* bug fix

* various fixes

* more fixes

* More fixes

* Add tests

* fix and flake8

* black

* mypy

* make language consistent

* misc improvements

* max_len -> max_length

* Default max_length to None rather than -1

* Clean tokens_to_indices

* Add documentation

* Factor out long sequence handling logic in embedder in separate functions

* Add tests

* black

* typo fix

* Make end token embeddings robust

* bug fix

* improve unfold test

* black

* Add documentation

* Rebase fixes

* More fixes for type_id

* More fixes from rebase

* black

* Minor doc fixes

* mypy

* Resolving comments

* minor fixes

* black
ZhaofengWu committed Jan 27, 2020
1 parent 30d687c commit 5bec95c
Showing 8 changed files with 621 additions and 97 deletions.
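Note on usage: the docstrings added below require the indexer's `max_length` to match the `max_length` given to the corresponding embedder. A minimal sketch of such a setup, assuming the `PretrainedTransformerEmbedder` gains a matching `max_length` argument elsewhere in this commit (the model name and length below are placeholders):

from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# The indexer folds long sequences into segments of at most 512 wordpieces
# (including special tokens); the embedder must use the same value so it can
# split the concatenated segments back apart before running the transformer.
indexer = PretrainedTransformerIndexer(model_name="bert-base-uncased", max_length=512)
embedder = PretrainedTransformerEmbedder(model_name="bert-base-uncased", max_length=512)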
140 changes: 127 additions & 13 deletions allennlp/data/token_indexers/pretrained_transformer_indexer.py
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Dict, List, Optional, Tuple
import logging
import torch
from allennlp.common.util import pad_sequence_to_length
@@ -31,19 +31,44 @@ class PretrainedTransformerIndexer(TokenIndexer):
We use a somewhat confusing default value of `tags` so that we do not add padding or UNK
tokens to this namespace, which would break on loading because we wouldn't find our default
OOV token.
max_length : `int`, optional (default = None)
If not None, split the document into segments of this many tokens (including special tokens)
before feeding into the embedder. The embedder embeds these segments independently and
concatenates the results to get the original document representation. Should be set to
the same value as the `max_length` option on the `PretrainedTransformerEmbedder`.
"""

def __init__(self, model_name: str, namespace: str = "tags", **kwargs) -> None:
def __init__(
self, model_name: str, namespace: str = "tags", max_length: int = None, **kwargs
) -> None:
super().__init__(**kwargs)
self._namespace = namespace
self._tokenizer = AutoTokenizer.from_pretrained(model_name)
self._added_to_vocabulary = False

def _add_encoding_to_vocabulary(self, vocab: Vocabulary) -> None:
(
self._num_added_start_tokens,
self._num_added_end_tokens,
) = self.__class__.determine_num_special_tokens_added(self._tokenizer)

self._max_length = max_length
if self._max_length is not None:
self._effective_max_length = ( # we need to take into account special tokens
self._max_length - self._tokenizer.num_added_tokens()
)
if self._effective_max_length <= 0:
raise ValueError(
"max_length needs to be greater than the number of special tokens inserted."
)

def _add_encoding_to_vocabulary_if_needed(self, vocab: Vocabulary) -> None:
"""
Copies tokens from ```transformers``` model to the specified namespace.
Transformers vocab is taken from the <vocab>/<encoder> keys of the tokenizer object.
"""
if self._added_to_vocabulary:
return

vocab_field_name = None
if hasattr(self._tokenizer, "vocab"):
vocab_field_name = "vocab"
@@ -61,17 +86,32 @@ def _add_encoding_to_vocabulary(self, vocab: Vocabulary) -> None:
vocab._token_to_index[self._namespace][word] = idx
vocab._index_to_token[self._namespace][idx] = word

self._added_to_vocabulary = True

@overrides
def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
# If we only use pretrained models, we don't need to do anything here.
pass

@overrides
def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
if not self._added_to_vocabulary:
self._add_encoding_to_vocabulary(vocabulary)
self._added_to_vocabulary = True
self._add_encoding_to_vocabulary_if_needed(vocabulary)

indices, type_ids = self._extract_token_and_type_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
output = {"token_ids": indices, "mask": [1] * len(indices)}
if type_ids is not None:
output["type_ids"] = type_ids

return self._postprocess_output(output)

def _extract_token_and_type_ids(
self, tokens: List[Token]
) -> Tuple[List[int], Optional[List[int]]]:
"""
Roughly equivalent to `zip(*[(token.text_id, token.type_id) for token in tokens])`,
with some checks.
"""
indices: List[int] = []
type_ids: List[int] = []
for token in tokens:
@@ -91,17 +131,54 @@ def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
else:
type_ids = None

# The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
mask = [1] * len(indices)
return indices, type_ids

result = {"token_ids": indices, "mask": mask}
if type_ids is not None:
result["type_ids"] = type_ids
return result
def _postprocess_output(self, output: IndexedTokenList) -> IndexedTokenList:
"""
Takes an IndexedTokenList about to be returned by `tokens_to_indices()` and adds any
necessary postprocessing, e.g. long sequence splitting.
The input should have a `"token_ids"` key corresponding to the token indices. They should
have special tokens already inserted.
"""
if self._max_length is not None:
# We prepare long indices by converting them to (assuming max_length == 5)
# [CLS] A B C [SEP] [CLS] D E F [SEP] ...
# The embedder is responsible for folding this 1-d sequence to 2-d and feeding it to the
# transformer model.
# TODO(zhaofengw): we aren't respecting word boundaries when segmenting wordpieces.

indices = output["token_ids"]
# Strips original special tokens
indices = indices[self._num_added_start_tokens : -self._num_added_end_tokens]
# Folds indices
folded_indices = [
indices[i : i + self._effective_max_length]
for i in range(0, len(indices), self._effective_max_length)
]
# Adds special tokens to each segment
folded_indices = [
self._tokenizer.build_inputs_with_special_tokens(segment)
for segment in folded_indices
]
# Flattens
indices = [i for segment in folded_indices for i in segment]

output["token_ids"] = indices
# `create_token_type_ids_from_sequences()` inserts special tokens
output["type_ids"] = self._tokenizer.create_token_type_ids_from_sequences(
indices[self._num_added_start_tokens : -self._num_added_end_tokens]
)
output["segment_concat_mask"] = [1] * len(indices)

return output

@overrides
def get_empty_token_list(self) -> IndexedTokenList:
return {"token_ids": [], "mask": [], "type_ids": []}
output: IndexedTokenList = {"token_ids": [], "mask": [], "type_ids": []}
if self._max_length is not None:
output["segment_concat_mask"] = []
return output

@overrides
def as_padded_tensor_dict(
@@ -133,3 +210,40 @@ def __eq__(self, other):
return False
return True
return NotImplemented

@classmethod
def determine_num_special_tokens_added(cls, tokenizer) -> Tuple[int, int]:
"""
Determines the number of tokens `tokenizer` adds to a sequence (currently doesn't
consider sequence pairs) at the start & end.
# Parameters
tokenizer : `transformers.tokenization_utils.PretrainedTokenizer`, required.
We want to determine the number of added tokens by this tokenizer.
# Returns
The number of tokens (`int`) that are inserted at the start & end of a sequence.
"""
# Uses a slightly higher index to avoid tokenizer doing special things to lower-indexed
# tokens which might be special.
dummy = [1000]
inserted = tokenizer.build_inputs_with_special_tokens(dummy)

num_start = num_end = 0
seen_dummy = False
for idx in inserted:
if idx == dummy[0]:
if seen_dummy: # seeing it twice
raise ValueError("Cannot auto-determine the number of special tokens added.")
seen_dummy = True
continue

if not seen_dummy:
num_start += 1
else:
num_end += 1

assert num_start + num_end == tokenizer.num_added_tokens()
return num_start, num_end
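For reference, the folding that `_postprocess_output` performs can be reproduced outside of AllenNLP with a small self-contained sketch; the special-token ids, segment size, and `add_special_tokens` helper below are stand-ins for what the real `transformers` tokenizer provides:

# Toy stand-ins for the tokenizer's special tokens (not a real tokenizer).
CLS, SEP = 101, 102
num_start, num_end = 1, 1                      # special tokens added at the start / end
max_length = 5                                 # segment length, including special tokens
effective_max_length = max_length - (num_start + num_end)

def add_special_tokens(segment):
    # Plays the role of tokenizer.build_inputs_with_special_tokens().
    return [CLS] + segment + [SEP]

# Token ids as tokens_to_indices produces them: special tokens already inserted.
token_ids = add_special_tokens([1, 2, 3, 4, 5, 6, 7])

# Strip the original special tokens, fold into fixed-size segments, re-wrap each one.
stripped = token_ids[num_start:-num_end]
folded = [
    stripped[i : i + effective_max_length]
    for i in range(0, len(stripped), effective_max_length)
]
flattened = [idx for segment in folded for idx in add_special_tokens(segment)]

print(flattened)
# [101, 1, 2, 3, 102, 101, 4, 5, 6, 102, 101, 7, 102]
# i.e. [CLS] A B C [SEP] [CLS] D E F [SEP] [CLS] G [SEP]; the embedder later
# reshapes this 1-d concatenation back into separate segments.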
allennlp/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -9,7 +9,6 @@
from allennlp.data.tokenizers.token import Token
from allennlp.data.token_indexers import PretrainedTransformerIndexer, TokenIndexer
from allennlp.data.token_indexers.token_indexer import IndexedTokenList
from allennlp.data.tokenizers import PretrainedTransformerTokenizer

logger = logging.getLogger(__name__)

@@ -33,49 +32,47 @@ class PretrainedTransformerMismatchedIndexer(TokenIndexer):
We use a somewhat confusing default value of `tags` so that we do not add padding or UNK
tokens to this namespace, which would break on loading because we wouldn't find our default
OOV token.
max_length : `int`, optional (default = None)
If not None, split the document into segments of this many tokens (including special tokens)
before feeding into the embedder. The embedder embeds these segments independently and
concatenates the results to get the original document representation. Should be set to
the same value as the `max_length` option on the `PretrainedTransformerMismatchedEmbedder`.
"""

def __init__(self, model_name: str, namespace: str = "tags", **kwargs) -> None:
def __init__(
self, model_name: str, namespace: str = "tags", max_length: int = None, **kwargs
) -> None:
super().__init__(**kwargs)
# The matched version vs. the mismatched one
self._matched_indexer = PretrainedTransformerIndexer(model_name, namespace, **kwargs)

# add_special_tokens=False since we don't want wordpieces to be surrounded by special tokens
self._allennlp_tokenizer = PretrainedTransformerTokenizer(
model_name, add_special_tokens=False
self._matched_indexer = PretrainedTransformerIndexer(
model_name, namespace, max_length, **kwargs
)
self._tokenizer = self._allennlp_tokenizer.tokenizer

(
self._num_added_start_tokens,
self._num_added_end_tokens,
) = self._determine_num_special_tokens_added()
self._tokenizer = self._matched_indexer._tokenizer
self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens
self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens

@overrides
def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
return self._matched_indexer.count_vocab_items(token, counter)

@overrides
def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> IndexedTokenList:
orig_token_mask = [1] * len(tokens)
tokens, offsets = self._intra_word_tokenize(tokens)

# {"token_ids": ..., "mask": ...}
output = self._matched_indexer.tokens_to_indices(tokens, vocabulary)
self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)

# Insert type ids for the special tokens.
output["type_ids"] = self._tokenizer.create_token_type_ids_from_sequences(
output["token_ids"]
indices, offsets = self._intra_word_tokenize(tokens)
# `create_token_type_ids_from_sequences()` inserts special tokens
type_ids = self._tokenizer.create_token_type_ids_from_sequences(
indices[self._num_added_start_tokens : -self._num_added_end_tokens]
)
# Insert the special tokens themselves.
output["token_ids"] = self._tokenizer.build_inputs_with_special_tokens(output["token_ids"])
output["mask"] = orig_token_mask
output["offsets"] = [
(start + self._num_added_start_tokens, end + self._num_added_start_tokens)
for start, end in offsets
]
output["wordpiece_mask"] = [1] * len(output["token_ids"])
return output
output: IndexedTokenList = {
"token_ids": indices,
"mask": [1] * len(tokens), # for original tokens (i.e. word-level)
"type_ids": type_ids,
"offsets": offsets,
"wordpiece_mask": [1] * len(indices), # for wordpieces (i.e. subword-level)
}

return self._matched_indexer._postprocess_output(output)

@overrides
def get_empty_token_list(self) -> IndexedTokenList:
@@ -114,54 +111,27 @@ def __eq__(self, other):
return True
return NotImplemented

def _intra_word_tokenize(
self, tokens: List[Token]
) -> Tuple[List[Token], List[Tuple[int, int]]]:
def _intra_word_tokenize(self, tokens: List[Token]) -> Tuple[List[int], List[Tuple[int, int]]]:
"""
Tokenizes each word into wordpieces separately. Also calculates offsets such that
wordpieces[offsets[i][0]:offsets[i][1] + 1] corresponds to the original i-th token.
Does not insert special tokens.
Tokenizes each word into wordpieces separately and returns the wordpiece IDs.
Also calculates offsets such that wordpieces[offsets[i][0]:offsets[i][1] + 1]
corresponds to the original i-th token.
This function inserts special tokens.
"""
wordpieces: List[Token] = []
wordpieces: List[int] = []
offsets = []
cumulative = 0
cumulative = self._num_added_start_tokens
for token in tokens:
subword_wordpieces = self._allennlp_tokenizer.tokenize(token.text)
subword_wordpieces = self._tokenizer.encode(token.text, add_special_tokens=False)
wordpieces.extend(subword_wordpieces)

start_offset = cumulative
cumulative += len(subword_wordpieces)
end_offset = cumulative - 1 # inclusive
offsets.append((start_offset, end_offset))

return wordpieces, offsets

def _determine_num_special_tokens_added(self) -> Tuple[int, int]:
"""
Determines the number of tokens self._tokenizer adds to a sequence (currently doesn't
consider sequence pairs) in the start & end.
wordpieces = self._tokenizer.build_inputs_with_special_tokens(wordpieces)
assert cumulative + self._num_added_end_tokens == len(wordpieces)

# Returns
The number of tokens (`int`) that are inserted in the start & end of a sequence.
"""
# Uses a slightly higher index to avoid tokenizer doing special things to lower-indexed
# tokens which might be special.
dummy = [1000]
inserted = self._tokenizer.build_inputs_with_special_tokens(dummy)

num_start = num_end = 0
seen_dummy = False
for idx in inserted:
if idx == dummy[0]:
if seen_dummy: # seeing it twice
raise ValueError("Cannot auto-determine the number of special tokens added.")
seen_dummy = True
continue

if not seen_dummy:
num_start += 1
else:
num_end += 1

assert num_start + num_end == self._tokenizer.num_added_tokens()
return num_start, num_end
return wordpieces, offsets
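The offset bookkeeping in the new `_intra_word_tokenize` can likewise be checked with a standalone sketch; the per-word wordpiece ids and special-token ids below are made up for illustration and stand in for `tokenizer.encode(..., add_special_tokens=False)` and `build_inputs_with_special_tokens()`:

# Hypothetical per-word wordpiece ids (a real tokenizer would produce these).
word_to_pieces = {"unbelievable": [200, 201, 202], "luck": [300]}
CLS, SEP = 101, 102
num_added_start_tokens, num_added_end_tokens = 1, 1

wordpieces, offsets = [], []
cumulative = num_added_start_tokens            # offsets must account for the leading [CLS]
for word in ["unbelievable", "luck"]:
    pieces = word_to_pieces[word]
    wordpieces.extend(pieces)
    start = cumulative
    cumulative += len(pieces)
    offsets.append((start, cumulative - 1))    # inclusive end index

wordpieces = [CLS] + wordpieces + [SEP]        # build_inputs_with_special_tokens
assert cumulative + num_added_end_tokens == len(wordpieces)

for word, (start, end) in zip(["unbelievable", "luck"], offsets):
    # wordpieces[start : end + 1] recovers exactly the pieces of the original word.
    print(word, wordpieces[start : end + 1])
# unbelievable [200, 201, 202]
# luck [300]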
