diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 0c81f47ac83b..71714acf6be9 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -79,6 +79,7 @@
     from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam
     from nemo.core.optim.mcore_optim import McoreDistributedOptimizer
+
     HAVE_APEX = True
 
 except (ImportError, ModuleNotFoundError):
diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py
index 5103cce2dd84..caad2ba17c6e 100644
--- a/nemo/core/optim/mcore_optim.py
+++ b/nemo/core/optim/mcore_optim.py
@@ -59,7 +59,8 @@ def load_state_dict(self, state_dict):
 
     def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
         return self.mcore_optimizer.sharded_state_dict(
-            model_sharded_state_dict, is_loading=False, sharding_type='dp_zero_gather_scatter')
+            model_sharded_state_dict, is_loading=False, sharding_type='dp_zero_gather_scatter'
+        )
 
     def step(self, closure):
         """Clip gradients (if needed) and step the base optimizer.
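
For context, the second hunk only re-wraps the delegating call in McoreDistributedOptimizer.sharded_state_dict; the call itself is unchanged. Below is a minimal, self-contained sketch of that delegation pattern, kept runnable without Megatron-Core installed. _FakeMcoreOptimizer and McoreDistributedOptimizerSketch are hypothetical stand-ins introduced only for illustration; the attribute name mcore_optimizer, the is_loading=False flag, and the 'dp_zero_gather_scatter' sharding type are taken from the diff.

# Sketch of the delegation pattern visible in the mcore_optim.py hunk.
# _FakeMcoreOptimizer is a hypothetical stand-in for the wrapped
# Megatron-Core distributed optimizer, used so the example runs standalone.


class _FakeMcoreOptimizer:
    def sharded_state_dict(self, model_sharded_state_dict, is_loading, sharding_type):
        # The real optimizer returns a sharded state dict suitable for
        # distributed checkpointing; here we just echo the arguments so the
        # call flow is visible.
        return {
            'model_keys': list(model_sharded_state_dict),
            'is_loading': is_loading,
            'sharding_type': sharding_type,
        }


class McoreDistributedOptimizerSketch:
    """Thin wrapper that defers sharded checkpointing to the base optimizer."""

    def __init__(self, mcore_optimizer):
        self.mcore_optimizer = mcore_optimizer

    def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
        # Same call as in the diff: save path (is_loading=False) with
        # ZeRO-style gather/scatter sharding of the optimizer state.
        return self.mcore_optimizer.sharded_state_dict(
            model_sharded_state_dict, is_loading=False, sharding_type='dp_zero_gather_scatter'
        )


if __name__ == '__main__':
    opt = McoreDistributedOptimizerSketch(_FakeMcoreOptimizer())
    print(opt.sharded_state_dict({'layer.weight': object()}))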