From 942e50700c373cef9c5e9aafe08d1a8cd14eb521 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 19 Nov 2021 15:37:52 +0100
Subject: [PATCH] Adding support for `hidden_states` and `attentions` in
 unbatching support. (#14420)

---
 src/transformers/pipelines/base.py | 11 ++++++++---
 tests/test_pipelines_common.py     | 14 ++++++++++++++
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index f59da6a87074d9..e3eefec6c8c63b 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -747,9 +747,14 @@ def loader_batch_item(self):
         else:
             loader_batched = {}
             for k, element in self._loader_batch_data.items():
-                if k == "past_key_values":
-                    continue
-                if isinstance(element[self._loader_batch_index], torch.Tensor):
+                if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
+                    if isinstance(element[0], torch.Tensor):
+                        loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
+                    elif isinstance(element[0], np.ndarray):
+                        loader_batched[k] = tuple(
+                            np.expand_dims(el[self._loader_batch_index], 0) for el in element
+                        )
+                elif isinstance(element[self._loader_batch_index], torch.Tensor):
                     loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
                 elif isinstance(element[self._loader_batch_index], np.ndarray):
                     loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py
index b027429318521a..88d7b2b184bd5a 100644
--- a/tests/test_pipelines_common.py
+++ b/tests/test_pipelines_common.py
@@ -27,6 +27,7 @@
     TOKENIZER_MAPPING,
     AutoFeatureExtractor,
     AutoTokenizer,
+    DistilBertForSequenceClassification,
     IBertConfig,
     RobertaConfig,
     TextClassificationPipeline,
@@ -322,6 +323,19 @@ def data(n: int):
             results.append(out)
         self.assertEqual(len(results), 10)
 
+    @require_torch
+    def test_unbatch_attentions_hidden_states(self):
+        model = DistilBertForSequenceClassification.from_pretrained(
+            "Narsil/tiny-distilbert-sequence-classification", output_hidden_states=True, output_attentions=True
+        )
+        tokenizer = AutoTokenizer.from_pretrained("Narsil/tiny-distilbert-sequence-classification")
+        text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
+
+        # Used to throw an error because `hidden_states` are a tuple of tensors
+        # instead of the expected tensor.
+        outputs = text_classifier(["This is great !"] * 20, batch_size=32)
+        self.assertEqual(len(outputs), 20)
+
 
 @is_pipeline_test
 class PipelinePadTest(unittest.TestCase):
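
Note (illustration, not part of the patch): the change above teaches
`loader_batch_item` that outputs such as `hidden_states` and `attentions`
arrive as a *tuple* of per-layer tensors of shape (batch, ...), so unbatching
must slice each element of the tuple rather than index the tuple itself.
Below is a minimal standalone sketch of that logic; the helper name
`unbatch_item`, the shapes, and the example batch are assumptions made for
the demo only.

    # Standalone sketch of the tuple-aware unbatching added by this patch.
    import numpy as np
    import torch

    def unbatch_item(loader_batch_data, index):
        loader_batched = {}
        for k, element in loader_batch_data.items():
            if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
                # One tensor per layer, each (batch, ...): slice out item
                # `index` from every layer and keep a batch dimension of 1.
                if isinstance(element[0], torch.Tensor):
                    loader_batched[k] = tuple(el[index].unsqueeze(0) for el in element)
                elif isinstance(element[0], np.ndarray):
                    loader_batched[k] = tuple(np.expand_dims(el[index], 0) for el in element)
            elif isinstance(element[index], torch.Tensor):
                loader_batched[k] = element[index].unsqueeze(0)
            elif isinstance(element[index], np.ndarray):
                loader_batched[k] = np.expand_dims(element[index], 0)
            else:
                loader_batched[k] = element[index]
        return loader_batched

    # Example: a batch of 4 items with 2 "layers" of hidden states.
    batch = {
        "logits": torch.randn(4, 2),
        "hidden_states": tuple(torch.randn(4, 7, 16) for _ in range(2)),
    }
    item = unbatch_item(batch, 0)
    assert item["logits"].shape == (1, 2)
    assert all(h.shape == (1, 7, 16) for h in item["hidden_states"])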