diff --git a/src/lightning/pytorch/callbacks/prediction_writer.py b/src/lightning/pytorch/callbacks/prediction_writer.py index b27697fa55080..f1f19f07719b6 100644 --- a/src/lightning/pytorch/callbacks/prediction_writer.py +++ b/src/lightning/pytorch/callbacks/prediction_writer.py @@ -58,7 +58,7 @@ def __init__(self, output_dir, write_interval): self.output_dir = output_dir def write_on_batch_end( - self, trainer, pl_module', prediction, batch_indices, batch, batch_idx, dataloader_idx + self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx ): torch.save(prediction, os.path.join(self.output_dir, dataloader_idx, f"{batch_idx}.pt"))