Handle edge case for DeepSpeed optimization
amorehead committed Oct 8, 2023
1 parent 47707c6 commit fc265af
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/models/mnist_module.py
@@ -198,7 +198,11 @@ def configure_optimizers(self) -> Dict[str, Any]:
         :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
         """
-        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
+        try:
+            optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
+        except TypeError:
+            # NOTE: strategies such as DeepSpeed require `params` to instead be specified as `model_params`
+            optimizer = self.hparams.optimizer(model_params=self.trainer.model.parameters())
         if self.hparams.scheduler is not None:
             scheduler = self.hparams.scheduler(optimizer=optimizer)
             return {

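For context, the fallback works because DeepSpeed's CPU Adam optimizer (`deepspeed.ops.adam.DeepSpeedCPUAdam`) names its parameter-iterable argument `model_params`, whereas standard `torch.optim` optimizers name it `params`; the first call therefore raises a `TypeError` and the second succeeds. Below is a minimal, self-contained sketch of that logic, assuming the optimizer is supplied as a partially-applied callable (as a Hydra-style `_partial_` config would provide in a Lightning-Hydra setup). `DeepSpeedCPUAdamStub` and `build_optimizer` are hypothetical names used only for illustration and are not part of the repository.

import functools

import torch


class DeepSpeedCPUAdamStub(torch.optim.Adam):
    """Hypothetical stand-in for `deepspeed.ops.adam.DeepSpeedCPUAdam`, which
    takes `model_params` rather than `params` as its first argument."""

    def __init__(self, model_params, lr: float = 1e-3):
        # Calling this class with `params=...` raises a TypeError, which is
        # exactly the edge case the commit above handles.
        super().__init__(model_params, lr=lr)


def build_optimizer(optimizer_partial, model: torch.nn.Module) -> torch.optim.Optimizer:
    """Mirror the try/except fallback from `configure_optimizers` above."""
    try:
        return optimizer_partial(params=model.parameters())
    except TypeError:
        # Strategies such as DeepSpeed expect `model_params` instead of `params`.
        return optimizer_partial(model_params=model.parameters())


model = torch.nn.Linear(28 * 28, 10)

# Standard torch optimizer: the `params=...` call succeeds on the first try.
adam = build_optimizer(functools.partial(torch.optim.Adam, lr=1e-3), model)

# DeepSpeed-style optimizer: `params=...` raises a TypeError, so the fallback is used.
cpu_adam = build_optimizer(functools.partial(DeepSpeedCPUAdamStub, lr=1e-3), model)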