From b1df810de5b99ad8b85d5f7c8c3843369a4b35bd Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Wed, 15 May 2024 07:13:36 +0000
Subject: [PATCH] pass dp_zero_gather_scatter to sharded-state-dict

Signed-off-by: Alexandros Koumparoulis
---
 nemo/core/optim/mcore_optim.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py
index 0d4b524049ca..5103cce2dd84 100644
--- a/nemo/core/optim/mcore_optim.py
+++ b/nemo/core/optim/mcore_optim.py
@@ -37,6 +37,8 @@ def __init__(self, optim):
         self.mcore_optimizer = optim
         self.param_groups = self.mcore_optimizer.param_groups
         self.state = self.mcore_optimizer.state
+        self.sharding_type = 'dp_zero_gather_scatter'
+        # 'fully_sharded_bucket_space' if args.ckpt_fully_parallel_save else 'dp_zero_gather_scatter'
 
     def zero_grad(self, set_to_none: bool = True):
         """We only need to zero the model related parameters, i.e.,
@@ -55,8 +57,9 @@ def state_dict(self):
     def load_state_dict(self, state_dict):
         self.mcore_optimizer.load_state_dict(state_dict)
 
-    def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs):
-        return self.mcore_optimizer.sharded_state_dict(model_sharded_state_dict, is_loading, **kwargs)
+    def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
+        return self.mcore_optimizer.sharded_state_dict(
+            model_sharded_state_dict, is_loading=False, sharding_type='dp_zero_gather_scatter')
 
     def step(self, closure):
         """Clip gradients (if needed) and step the base optimizer.
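
For readers adapting this change: the patch stores self.sharding_type in __init__ but then hardcodes both is_loading=False and the sharding_type string at the call site. Below is a minimal sketch of how the wrapper method could instead forward the caller's is_loading flag and the stored strategy. This is a sketch only, not the committed code; the class name McoreDistributedOptimizer is taken from NeMo's mcore_optim.py, and the keyword signature of the wrapped Megatron-core sharded_state_dict (is_loading, sharding_type) is assumed from what the diff itself passes.

    class McoreDistributedOptimizer:
        """Thin wrapper delegating to a Megatron-core optimizer (sketch)."""

        def __init__(self, optim):
            self.mcore_optimizer = optim
            self.param_groups = self.mcore_optimizer.param_groups
            self.state = self.mcore_optimizer.state
            # Default sharding strategy; 'fully_sharded_bucket_space' could be
            # chosen instead when fully parallel checkpoint saving is enabled.
            self.sharding_type = 'dp_zero_gather_scatter'

        def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs):
            # Forward the caller's is_loading flag and the configured strategy
            # instead of hardcoding them at the call site (assumed kwargs).
            return self.mcore_optimizer.sharded_state_dict(
                model_sharded_state_dict,
                is_loading=is_loading,
                sharding_type=self.sharding_type,
                **kwargs,
            )

Keeping the strategy in self.sharding_type means a single attribute controls both save and load paths, and preserving the is_loading parameter lets checkpoint-restore callers signal that they are building a state dict for loading rather than saving.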