Fix broken function args in troubleshooting
JonasGeiping committed Aug 7, 2023
1 parent b8ac4df commit 1397b8c
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions pretrain.py
@@ -151,30 +151,29 @@ def collect_stats(step, loss_vals, train_time, stats, model_engine, dataloader,
     return loss_vals, train_time
 
 
-def engage_troubleshooting(model_engine, step, training_allowed, cfg):
-    training_allowed, no_recovery_necessary = training_allowed, no_recovery_necessary
-    # log.info(f"Non-finite loss in step {step} on device {cfg.impl.local_rank}.")
-    #
-    # is_finite_grad = [torch.isfinite(p.grad).all() for p in model_engine.model.parameters() if p.grad is not None]
-    # has_finite_gradients = torch.stack(is_finite_grad).all() if len(is_finite_grad) > 0 else True
-    # if not has_finite_gradients:
-    #     if "dump_nan_grads" in cfg.impl.troubleshoot_strategy:
-    #         log.info(f"Non-finite gradients in step {step} on device {cfg.impl.local_rank}, dumping...")
-    #         model_engine.optimizer.zero_grad()
-    #     else:
-    #         if "recover_checkpoint" in cfg.impl.troubleshoot_strategy:
-    #             no_recovery_necessary = False
-    #         else:
-    #             training_allowed = False
-    #             log.info(f"Stopping training due to non-finite grads in step {step} on device {cfg.impl.local_rank}.")
-    #
-    # has_finite_parameters = torch.stack([torch.isfinite(p).all() for p in model_engine.model.parameters()]).all()
-    # if not has_finite_parameters:
-    #     if "recover_checkpoint" in cfg.impl.troubleshoot_strategy:
-    #         no_recovery_necessary = False
-    #     else:
-    #         training_allowed = False
-    #         log.info(f"Stopping training due to non-finite parameters in step {step} on device {cfg.impl.local_rank}.")
+def engage_troubleshooting(model_engine, step, training_allowed, no_recovery_necessary, cfg):
+    log.info(f"Non-finite loss in step {step} on device {cfg.impl.local_rank}.")
+
+    is_finite_grad = [torch.isfinite(p.grad).all() for p in model_engine.model.parameters() if p.grad is not None]
+    has_finite_gradients = torch.stack(is_finite_grad).all() if len(is_finite_grad) > 0 else True
+    if not has_finite_gradients:
+        if "dump_nan_grads" in cfg.impl.troubleshoot_strategy:
+            log.info(f"Non-finite gradients in step {step} on device {cfg.impl.local_rank}, dumping...")
+            model_engine.optimizer.zero_grad()
+        else:
+            if "recover_checkpoint" in cfg.impl.troubleshoot_strategy:
+                no_recovery_necessary = False
+            else:
+                training_allowed = False
+                log.info(f"Stopping training due to non-finite grads in step {step} on device {cfg.impl.local_rank}.")
+
+    has_finite_parameters = torch.stack([torch.isfinite(p).all() for p in model_engine.model.parameters()]).all()
+    if not has_finite_parameters:
+        if "recover_checkpoint" in cfg.impl.troubleshoot_strategy:
+            no_recovery_necessary = False
+        else:
+            training_allowed = False
+            log.info(f"Stopping training due to non-finite parameters in step {step} on device {cfg.impl.local_rank}.")
     return training_allowed, no_recovery_necessary
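The fix adds no_recovery_necessary to the parameter list and drops the old first line, which read the name before it was ever bound inside the function and would have raised an UnboundLocalError. A minimal sketch of how the repaired signature is presumably threaded through a training loop follows; the loop scaffolding, the dataloader, and the model_engine.step call are assumptions for illustration, not code from pretrain.py:

import torch

training_allowed, no_recovery_necessary = True, True
for step, batch in enumerate(dataloader):
    loss = model_engine.step(batch)  # placeholder for the engine's forward/backward step
    if not torch.isfinite(loss):
        # Both flags go in and both come back out; the old four-argument
        # signature made this call impossible.
        training_allowed, no_recovery_necessary = engage_troubleshooting(
            model_engine, step, training_allowed, no_recovery_necessary, cfg
        )
        if not no_recovery_necessary:
            # hypothetical recovery branch: a checkpoint reload would
            # presumably happen here before the flag is reset
            no_recovery_necessary = True
    if not training_allowed:
        break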


