Skip to content

Commit

Permalink
Fix a Bug, trainer_seq2seq.py, in the else branch at Line 172, genera…
Browse files Browse the repository at this point in the history
…tion_inputs should be a dict (huggingface#14546)

* fix bug, trainer_seq2seq.py, Line 172, generation_inputs must be a dict before feeding into self.model.generation()

* fix bug, trainer_seq2seq.py, Line 172, generation_inputs must be a dict before feeding into self.model.generation()
  • Loading branch information
TranSirius authored and Alberto Bégué committed Jan 27, 2022
1 parent 12d1dd6 commit ca2d57f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/trainer_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def prediction_step(
# very ugly hack to make it work
generation_inputs["input_ids"] = generation_inputs.pop(self.tokenizer.model_input_names[0])
else:
generation_inputs = inputs["input_ids"]
generation_inputs = {"input_ids": inputs["input_ids"]}

generated_tokens = self.model.generate(
**generation_inputs,
Expand Down

0 comments on commit ca2d57f

Please sign in to comment.