diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index 1e2b88d7be170..f308c48b16dfb 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -191,10 +191,7 @@ def __call__( else: raise ValueError(f"Unable to understand extra arguments {args}") - result = super().__call__(sequences, **kwargs) - if len(result) == 1: - return result[0] - return result + return super().__call__(sequences, **kwargs) def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."): sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template) @@ -264,4 +261,6 @@ def postprocess(self, model_outputs, multi_label=False): "scores": scores[iseq, top_inds].tolist(), } ) + if len(result) == 1: + return result[0] return result diff --git a/tests/test_pipelines_zero_shot.py b/tests/test_pipelines_zero_shot.py index d22ce68621b70..ae47eb626cc42 100644 --- a/tests/test_pipelines_zero_shot.py +++ b/tests/test_pipelines_zero_shot.py @@ -61,6 +61,24 @@ def run_pipeline_test(self, model, tokenizer, feature_extractor): ) self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]}) + # https://github.com/huggingface/transformers/issues/13846 + outputs = classifier(["I am happy"], ["positive", "negative"]) + self.assertEqual( + outputs, + [ + {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]} + for i in range(1) + ], + ) + outputs = classifier(["I am happy", "I am sad"], ["positive", "negative"]) + self.assertEqual( + outputs, + [ + {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]} + for i in range(2) + ], + ) + with self.assertRaises(ValueError): classifier("", candidate_labels="politics")