Fix get_parameters when using main params optimizer #6764

Merged · 3 commits · Jun 1, 2023
@@ -240,14 +240,16 @@ def _vocab_size_with_padding(self, orig_vocab_size, make_vocab_size_divisible_by
         )
         return after

-    def _get_parameters(self):
+    def get_parameters_with_grad(self):
         """
-        private method to load all the trainable parameters from optimizer param groups
+        Get all parameters with grad from optimizer param groups
         """
         params = []
         for param_group in self._optimizer_param_groups:
             for param in param_group['params']:
-                if param.requires_grad: # (@adithyare) adapter training with pp>1 can result in params with no grads
+                if (
+                    param.grad is not None
+                ): # (@adithyare) adapter training with pp>1 can result in params with no grads
                     params.append(param)
         return params

@@ -272,9 +274,9 @@ def configure_gradient_clipping(self, *args, **kwargs):
         else:
             if self.megatron_amp_o2:
                 # grep fp32 master parameters for gradient clipping
-                parameters = self._optimizer.get_parameters()
+                parameters = self._optimizer.get_parameters_with_grad()
             else:
-                parameters = self._get_parameters()
+                parameters = self.get_parameters_with_grad()
             grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)

         self.log('grad_norm', grad_norm, rank_zero_only=True, batch_size=1)
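Note (reviewer context, not part of the diff): `requires_grad` is a static flag on a parameter, while `.grad` is only populated for parameters that actually took part in the backward pass on this rank. With adapter training and pipeline parallelism > 1, a trainable parameter can end up with no gradient, so filtering on `requires_grad` would hand grad-less parameters to `clip_grad_norm_fp32`. A minimal, self-contained sketch of the distinction in plain PyTorch (no NeMo dependencies, illustrative names only):

```python
import torch

# Three kinds of parameters that can coexist in an optimizer's param groups:
frozen = torch.nn.Linear(4, 4)   # base-model weight, kept frozen
adapter = torch.nn.Linear(4, 4)  # trainable adapter weight, used in the forward pass
unused = torch.nn.Linear(4, 4)   # trainable, but not touched on this "pipeline stage"

for p in frozen.parameters():
    p.requires_grad_(False)

loss = adapter(frozen(torch.randn(2, 4))).sum()  # 'unused' never enters the graph
loss.backward()

all_params = [*frozen.parameters(), *adapter.parameters(), *unused.parameters()]
by_flag = [p for p in all_params if p.requires_grad]     # includes 'unused', whose .grad is None
by_grad = [p for p in all_params if p.grad is not None]  # only params that really received grads
print(len(by_flag), len(by_grad))  # 4 vs. 2: the flag over-selects
```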
4 changes: 2 additions & 2 deletions nemo/core/optim/optimizer_with_main_params.py
@@ -488,11 +488,11 @@ def async_master_grads_allreudce(self):
     def fp32_grad_accumulation(self):
         return self._fp32_grad_accum

-    def get_parameters(self):
+    def get_parameters_with_grad(self):
         params = []
         for param_group in self.optimizer.param_groups:
             for param in param_group['params']:
-                if param.requires_grad: # (@adithyare) added to enable pp>1 training for adapters
+                if param.grad is not None: # (@adithyare) added to enable pp>1 training for adapters
                     params.append(param)
         return params

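A usage-level sketch of the renamed helper, with the caveat that `MainParamsOptimizerWrapper` and `clip_grad_norm_fp32` are NeMo internals not reproduced here; a plain Adam optimizer and `torch.nn.utils.clip_grad_norm_` stand in for them, and the free function below merely mirrors the method added in this PR:

```python
import torch

def get_parameters_with_grad(optimizer: torch.optim.Optimizer):
    # Mirrors the renamed helper: walk the optimizer's param groups and keep
    # only parameters that actually received a gradient this step.
    params = []
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            if param.grad is not None:
                params.append(param)
    return params

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 8)).sum()
loss.backward()

# Clip only over parameters that have gradients (stand-in for clip_grad_norm_fp32).
grad_norm = torch.nn.utils.clip_grad_norm_(get_parameters_with_grad(optimizer), max_norm=1.0)
print(float(grad_norm))
```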