From ca2d57f1cc29106475fb0bef7e51c31765cd7b5d Mon Sep 17 00:00:00 2001
From: TranSirius
Date: Wed, 8 Dec 2021 01:09:18 +0800
Subject: [PATCH] Fix a bug in trainer_seq2seq.py: in the else branch at line
 172, generation_inputs should be a dict (#14546)

* fix bug in trainer_seq2seq.py, line 172: generation_inputs must be a dict
  before being fed into self.model.generate()

* fix bug in trainer_seq2seq.py, line 172: generation_inputs must be a dict
  before being fed into self.model.generate()

---
 src/transformers/trainer_seq2seq.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py
index 19296f51603528..e2ca5bd1a0d4d6 100644
--- a/src/transformers/trainer_seq2seq.py
+++ b/src/transformers/trainer_seq2seq.py
@@ -169,7 +169,7 @@ def prediction_step(
             # very ugly hack to make it work
             generation_inputs["input_ids"] = generation_inputs.pop(self.tokenizer.model_input_names[0])
         else:
-            generation_inputs = inputs["input_ids"]
+            generation_inputs = {"input_ids": inputs["input_ids"]}
 
         generated_tokens = self.model.generate(
             **generation_inputs,
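
Note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of why the else branch must produce a dict: `self.model.generate(**generation_inputs, ...)` unpacks `generation_inputs` with `**`, which requires a mapping, so a bare tensor raises a TypeError. The `generate` stub and the toy token ids here are placeholders, not the real transformers API call.

    # Minimal sketch, assuming a stand-in generate() and a toy tokenized batch
    # (hypothetical values; the real code calls self.model.generate()).

    def generate(**kwargs):
        # stand-in for self.model.generate(**generation_inputs, **gen_kwargs)
        return kwargs["input_ids"]

    inputs = {"input_ids": [[101, 2023, 102]]}  # hypothetical tokenized batch

    # Before the fix: generation_inputs is the raw tensor/list, not a mapping.
    generation_inputs = inputs["input_ids"]
    try:
        generate(**generation_inputs)
    except TypeError as err:
        print(f"fails: {err}")  # ** unpacking needs a mapping, not a list/tensor

    # After the fix: wrap the tensor in a dict so ** unpacking works.
    generation_inputs = {"input_ids": inputs["input_ids"]}
    print(generate(**generation_inputs))  # [[101, 2023, 102]]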