minimal fixes to run DataCollatorForWholeWordMask with return_tensors="np" and return_tensors="tf" (huggingface#13891)

* minimal fixes to run DataCollatorForWholeWordMask with return_tensors="np" and return_tensors="tf"

* more consistent implementation for numpy_mask_tokens
dwyatte authored and Alberto Bégué committed Jan 27, 2022
1 parent e2a4647 commit fa8228f
Showing 2 changed files with 37 additions and 8 deletions.
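
For context, this is the usage the commit unblocks: DataCollatorForWholeWordMask can now be asked for NumPy or TensorFlow batches, not only PyTorch ones. A minimal sketch (the pretrained checkpoint is only an illustration; the collator arguments and expected shapes match the tests added below):

from transformers import BertTokenizerFast, DataCollatorForWholeWordMask

# Any BERT-style tokenizer works; "bert-base-uncased" is just an example checkpoint.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]

# Before this commit, only return_tensors="pt" ran end to end.
np_batch = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")(features)
tf_batch = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf")(features)  # needs TensorFlow installed

print(np_batch["input_ids"].shape)  # (2, 10) NumPy array
print(tf_batch["input_ids"].shape)  # (2, 10) tf.Tensor
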
14 changes: 6 additions & 8 deletions src/transformers/data/data_collator.py
@@ -883,14 +883,14 @@ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict
inputs, labels = self.tf_mask_tokens(batch_input, batch_mask)
return {"input_ids": inputs, "labels": labels}

- def np_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+ def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
if isinstance(examples[0], (dict, BatchEncoding)):
input_ids = [e["input_ids"] for e in examples]
else:
input_ids = examples
examples = [{"input_ids": e} for e in examples]

- batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
+ batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)

mask_labels = []
for e in examples:
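
The rename from np_call to numpy_call matters because the shared collator mixin dispatches on return_tensors by method name, roughly like this (paraphrased sketch of the surrounding data_collator.py, not part of this diff):

class DataCollatorMixin:
    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        if return_tensors == "tf":
            return self.tf_call(features)
        elif return_tensors == "pt":
            return self.torch_call(features)
        elif return_tensors == "np":
            # return_tensors="np" looks for numpy_call, so a method named np_call is never reached
            return self.numpy_call(features)
        raise ValueError(f"Framework '{return_tensors}' not recognized!")

Likewise, _numpy_collate_batch is the NumPy counterpart of _tf_collate_batch, so the NumPy path no longer pads through TensorFlow.
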
@@ -1009,15 +1009,15 @@ def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
)
- labels = inputs.clone()
+ labels = tf.identity(inputs)
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)

masked_indices = tf.cast(mask_labels, tf.bool)

special_tokens_mask = [
- self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
]
- masked_indices = masked_indices & ~tf.convert_to_tensor(special_tokens_mask, dtype=tf.bool)
+ masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
if self.tokenizer._pad_token is not None:
padding_mask = inputs == self.tokenizer.pad_token_id
masked_indices = masked_indices & ~padding_mask
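
In isolation, the TensorFlow pattern the new lines rely on looks like this (toy values; only the tf.identity copy and the tf.cast boolean mask mirror the change above):

import tensorflow as tf

inputs = tf.constant([[101, 7592, 2088, 102]])   # toy token ids
labels = tf.identity(inputs)                     # copy of the tensor, TF's stand-in for torch's .clone()

masked_indices = tf.constant([[False, True, True, False]])
special_tokens_mask = [[1, 0, 0, 1]]             # 0/1 ints as returned by get_special_tokens_mask
# tf.cast turns the int mask into a boolean tensor so it composes with & and ~
masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
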
@@ -1073,9 +1073,7 @@ def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
indices_random = (
np.random.binomial(1, 0.5, size=labels.shape).astype(np.bool) & masked_indices & ~indices_replaced
)
- random_words = np.random.randint(
-     low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
- )
+ random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
inputs[indices_random] = random_words[indices_random]

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
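
The numpy_mask_tokens change adopts the same full-shape draw the PyTorch collator uses: sample a random token for every position, then let the boolean index pick out only the positions being randomized. A standalone sketch (array values are invented; vocab_size stands in for len(self.tokenizer), and inputs.shape plays the role of labels.shape since labels is a copy of inputs):

import numpy as np

vocab_size = 30522
inputs = np.array([[101, 7592, 2088, 102]], dtype=np.int64)
indices_random = np.array([[False, True, False, False]])  # positions to replace with random tokens

# Drawing random_words with the full shape keeps it aligned with the boolean mask,
# so random_words[indices_random] and inputs[indices_random] refer to the same positions.
random_words = np.random.randint(low=0, high=vocab_size, size=inputs.shape, dtype=np.int64)
inputs[indices_random] = random_words[indices_random]
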
31 changes: 31 additions & 0 deletions tests/test_data_collator.py
@@ -24,6 +24,7 @@
DataCollatorForLanguageModeling,
DataCollatorForPermutationLanguageModeling,
DataCollatorForTokenClassification,
+ DataCollatorForWholeWordMask,
DataCollatorWithPadding,
default_data_collator,
is_tf_available,
@@ -224,6 +225,16 @@ def test_data_collator_for_language_modeling(self):
pad_features = [list(range(5)), list(range(10))]
self._test_no_pad_and_pad(no_pad_features, pad_features)

+ def test_data_collator_for_whole_word_mask(self):
+ features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
+
+ tokenizer = BertTokenizer(self.vocab_file)
+ data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
+ batch = data_collator(features)
+
+ self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
+ self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
+
def test_plm(self):
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
@@ -488,6 +499,16 @@ def test_data_collator_for_language_modeling(self):
pad_features = [list(range(5)), list(range(10))]
self._test_no_pad_and_pad(no_pad_features, pad_features)

+ def test_data_collator_for_whole_word_mask(self):
+ features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
+
+ tokenizer = BertTokenizer(self.vocab_file)
+ data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf")
+ batch = data_collator(features)
+
+ self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
+ self.assertEqual(batch["labels"].shape.as_list(), [2, 10])
+
def test_plm(self):
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
@@ -750,6 +771,16 @@ def test_data_collator_for_language_modeling(self):
pad_features = [list(range(5)), list(range(10))]
self._test_no_pad_and_pad(no_pad_features, pad_features)

+ def test_data_collator_for_whole_word_mask(self):
+ features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
+
+ tokenizer = BertTokenizer(self.vocab_file)
+ data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")
+ batch = data_collator(features)
+
+ self.assertEqual(batch["input_ids"].shape, (2, 10))
+ self.assertEqual(batch["labels"].shape, (2, 10))
+
def test_plm(self):
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
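
To exercise the new tests from a repository checkout, a command along the lines of pytest tests/test_data_collator.py -k whole_word_mask should pick up all three framework variants (the -k filter is a suggestion, not part of the commit).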
