Skip to content

Commit

Permalink
Adding support for hidden_states and attentions in unbatching (hu…
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil authored and Alberto Bégué committed Jan 27, 2022
1 parent eefa5d4 commit 942e507
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,9 +747,14 @@ def loader_batch_item(self):
else:
loader_batched = {}
for k, element in self._loader_batch_data.items():
if k == "past_key_values":
continue
if isinstance(element[self._loader_batch_index], torch.Tensor):
if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
if isinstance(element[0], torch.Tensor):
loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
elif isinstance(element[0], np.ndarray):
loader_batched[k] = tuple(
np.expand_dims(el[self._loader_batch_index], 0) for el in element
)
elif isinstance(element[self._loader_batch_index], torch.Tensor):
loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
elif isinstance(element[self._loader_batch_index], np.ndarray):
loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TOKENIZER_MAPPING,
AutoFeatureExtractor,
AutoTokenizer,
DistilBertForSequenceClassification,
IBertConfig,
RobertaConfig,
TextClassificationPipeline,
Expand Down Expand Up @@ -322,6 +323,19 @@ def data(n: int):
results.append(out)
self.assertEqual(len(results), 10)

@require_torch
def test_unbatch_attentions_hidden_states(self):
model = DistilBertForSequenceClassification.from_pretrained(
"Narsil/tiny-distilbert-sequence-classification", output_hidden_states=True, output_attentions=True
)
tokenizer = AutoTokenizer.from_pretrained("Narsil/tiny-distilbert-sequence-classification")
text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)

# Used to throw an error because `hidden_states` are a tuple of tensors
# instead of the expected tensor.
outputs = text_classifier(["This is great !"] * 20, batch_size=32)
self.assertEqual(len(outputs), 20)


@is_pipeline_test
class PipelinePadTest(unittest.TestCase):
Expand Down

0 comments on commit 942e507

Please sign in to comment.