From d4744d089e377334eb140eecd83b7061f97253ce Mon Sep 17 00:00:00 2001
From: Pete
Date: Thu, 12 Oct 2023 08:35:20 -0700
Subject: [PATCH] Bring back global gradient clipping and improve speed of collecting metrics (#326)

---
 olmo/optim.py | 211 ++++++++++++++++++++++++++++++++------------------
 olmo/train.py |  33 ++++----
 2 files changed, 152 insertions(+), 92 deletions(-)

diff --git a/olmo/optim.py b/olmo/optim.py
index b34c837ae..492cf041e 100644
--- a/olmo/optim.py
+++ b/olmo/optim.py
@@ -67,6 +67,7 @@ def clip_grads_and_collect_metrics(
         per_param_avg_metric_names: List[str] = []
         per_param_norm_metric_names: List[str] = []
 
+        # Collect metrics locally.
         for group in self.param_groups:
             if is_distributed():
                 # TODO (epwalsh): handle non-sharded params. We don't have any right now but we would
@@ -90,38 +91,27 @@ def clip_grads_and_collect_metrics(
                     for x, prefix in zip(tensors, prefixes):
                         # grad or state tensors could be none for params that have their shards completely on
                         # other ranks.
-                        x = (
-                            x.to(device="cpu")
-                            if x is not None
-                            else torch.tensor([], device="cpu", dtype=torch.float32)
-                        )
-                        if x.numel() > 0:
+                        if x is not None and x.numel() > 0:
                             if collect_param_metrics:
                                 x_abs = x.abs()
-                                per_param_min_metrics.append(
-                                    x_abs.min().unsqueeze(0).to(device="cpu", dtype=torch.float32)
-                                )
-                                per_param_max_metrics.append(
-                                    x_abs.max().unsqueeze(0).to(device="cpu", dtype=torch.float32)
-                                )
-                                per_param_sum_metrics.append(
-                                    x.sum().unsqueeze(0).to(device="cpu", dtype=torch.float32)
-                                )
+                                per_param_min_metrics.append(x_abs.min().unsqueeze(0).to(dtype=torch.float32))
+                                per_param_max_metrics.append(x_abs.max().unsqueeze(0).to(dtype=torch.float32))
+                                per_param_sum_metrics.append(x.sum().unsqueeze(0).to(dtype=torch.float32))
                                 per_param_numel_metrics.append(
-                                    torch.tensor([x.numel()], device="cpu", dtype=torch.float32)
+                                    torch.tensor([x.numel()], device=device, dtype=torch.float32)
                                 )
                             per_param_norm_metrics.append(
-                                torch.linalg.vector_norm(x, 2.0, dtype=torch.float32).unsqueeze(0).to(device="cpu")
+                                torch.linalg.vector_norm(x, 2.0, dtype=torch.float32).unsqueeze(0)
                             )
                         else:
                             if collect_param_metrics:
                                 per_param_min_metrics.append(
-                                    torch.tensor([float("inf")], device="cpu", dtype=torch.float32)
+                                    torch.tensor([float("inf")], device=device, dtype=torch.float32)
                                 )
-                                per_param_max_metrics.append(torch.tensor([0.0], device="cpu", dtype=torch.float32))
-                                per_param_sum_metrics.append(torch.tensor([0.0], device="cpu", dtype=torch.float32))
-                                per_param_numel_metrics.append(torch.tensor([0.0], device="cpu", dtype=torch.float32))
-                            per_param_norm_metrics.append(torch.tensor([0.0], device="cpu", dtype=torch.float32))
+                                per_param_max_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
+                                per_param_sum_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
+                                per_param_numel_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
+                            per_param_norm_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
                         if collect_param_metrics:
                             per_param_min_metric_names.append(f"{prefix}.min")
                             per_param_max_metric_names.append(f"{prefix}.max")
@@ -139,6 +129,11 @@ def clip_grads_and_collect_metrics(
         )
         assert len(per_param_norm_metrics) == len(per_param_norm_metric_names)
 
+        def is_grad_norm_metric(metric_name: str) -> bool:
+            return metric_name.startswith("grad/") and metric_name.endswith(".norm")
+
+        # Now reduce metrics over all ranks.
+        total_grad_norm: torch.Tensor
         per_param_avg_metrics: List[torch.Tensor] = []
         if is_distributed():  # TODO (epwalsh): skip for non-sharded params
             # Reduce metrics across all ranks. Note that we can use a `reduce` for most cases
@@ -148,12 +143,12 @@ def clip_grads_and_collect_metrics(
             if per_param_min_metrics:
                 all_mins = torch.cat(per_param_min_metrics).to(device)
                 dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN)
-                per_param_min_metrics = all_mins.to(device="cpu").split(1)
+                per_param_min_metrics = all_mins.split(1)
             # Reduce maxs.
             if per_param_max_metrics:
                 all_maxs = torch.cat(per_param_max_metrics).to(device)
                 dist.reduce(all_maxs, 0, op=dist.ReduceOp.MAX)
-                per_param_max_metrics = all_maxs.to(device="cpu").split(1)
+                per_param_max_metrics = all_maxs.split(1)
             # Reduce sums or just norms.
             all_norms = torch.cat(per_param_norm_metrics).to(device) ** 2.0
             if per_param_sum_metrics and per_param_numel_metrics:
@@ -163,14 +158,28 @@ def clip_grads_and_collect_metrics(
                     [all_sums.unsqueeze(0), all_norms.unsqueeze(0), all_numels.unsqueeze(0)], dim=0
                 )
                 dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM)
-                all_sums, all_norms, all_numels = all_sums_norms_numels.to(device="cpu").split(1)
+                all_sums, all_norms, all_numels = all_sums_norms_numels.split(1)
                 # Get averages.
                 # NOTE: could get infs for non-rank0 processes but that's okay.
                 per_param_avg_metrics = (all_sums / all_numels).squeeze(0).split(1)
             else:
                 dist.all_reduce(all_norms, op=dist.ReduceOp.SUM)
+            grad_norm_metric_mask = torch.tensor(
+                [float(is_grad_norm_metric(n)) for n in per_param_norm_metric_names], device=all_norms.device
+            )
+            total_grad_norm = (all_norms * grad_norm_metric_mask).sum() ** 0.5
             per_param_norm_metrics = (all_norms ** (0.5)).squeeze(0).split(1)
         else:
+            total_grad_norm = (
+                torch.cat(
+                    [
+                        m
+                        for m, n in zip(per_param_norm_metrics, per_param_norm_metric_names)
+                        if is_grad_norm_metric(n)
+                    ]
+                )
+                ** 2.0
+            ).sum() ** 0.5
             per_param_avg_metrics = [x / n for x, n in zip(per_param_sum_metrics, per_param_numel_metrics)]
 
         assert len(per_param_avg_metrics) == len(per_param_avg_metric_names)
@@ -185,69 +194,121 @@ def clip_grads_and_collect_metrics(
             all_metrics[metric_name] = metric.squeeze(0)
         for metric_name, metric in zip(per_param_norm_metric_names, per_param_norm_metrics):
             all_metrics[metric_name] = metric.squeeze(0)
+        all_metrics["total_grad_norm"] = total_grad_norm
 
         # Clip gradients.
         num_grads_clipped = 0
         num_eligible_grads = 0
         for group in self.param_groups:
-            # We'll use the bigger of beta1 and beta2 to update the exponential average of the norm of
-            # the gradient (a scalar), not to be confused with the exponential average of the gradient.
-            # TODO (epwalsh): handle optimizers that don't have betas.
-            beta1, beta2 = group["betas"]
-            beta = max(beta1, beta2)
-            max_norm = group.get("max_grad_norm")
-            max_norm_ratio = group.get("max_grad_norm_ratio")
-            if max_norm is None and max_norm_ratio is None:
+            if (max_norm_ratio := group.get("max_grad_norm_ratio")) is not None:
+                num_clipped = self._do_adaptive_clipping(
+                    group, max_norm_ratio, global_step, all_metrics, collect_param_metrics=collect_param_metrics
+                )
+            elif (max_norm := group.get("max_grad_norm")) is not None:
+                num_clipped = self._do_global_fixed_clipping(
+                    group, max_norm, all_metrics, collect_param_metrics=collect_param_metrics
+                )
+            else:
+                # No clipping needed.
                 continue
+            num_eligible_grads += len(group["params"])
+            if num_clipped is not None:
+                num_grads_clipped += num_clipped
 
-            for name, p in zip(group["param_names"], group["params"]):
-                name = self._clean_param_name(name)
-                grad_norm = all_metrics.get(f"grad/{name}.norm")
-                if grad_norm is None:
-                    continue
+        if collect_param_metrics:
+            clipping_rate = torch.tensor(num_grads_clipped / num_eligible_grads, device="cpu")
+            all_metrics["clipping_rate"] = clipping_rate
+            return all_metrics
+        else:
+            return {}
+
+    @torch.no_grad()
+    def _do_adaptive_clipping(
+        self,
+        group: Dict[str, Any],
+        max_norm_ratio: float,
+        global_step: int,
+        all_metrics: Dict[str, torch.Tensor],
+        collect_param_metrics: bool = True,
+    ) -> Optional[int]:
+        """
+        Do adaptive gradient clipping on a param group.
 
-                num_eligible_grads += 1
+        If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
+        """
+        device = get_default_device()
+        num_grads_clipped = 0
+        # We'll use the bigger of beta1 and beta2 to update the exponential average of the norm of
+        # the gradient (a scalar), not to be confused with the exponential average of the gradient.
+        # TODO (epwalsh): handle optimizers that don't have betas.
+        beta1, beta2 = group["betas"]
+        beta = max(beta1, beta2)
+        for name, p in zip(group["param_names"], group["params"]):
+            name = self._clean_param_name(name)
+            grad_norm = all_metrics.get(f"grad/{name}.norm")
+            if grad_norm is None:
+                continue
 
-                # Get or initialize the exponential average of grad norm.
-                state = self.state[p]
-                grad_norm_exp_avg = state.get("grad_norm_exp_avg")
-                if grad_norm_exp_avg is None:
-                    grad_norm_exp_avg = grad_norm.clone().to(device)
-                    # We don't want to add anything to `state` until `state` has been initialized, otherwise
-                    # this will crash some optimizers which rely on checking `len(state)`. The downside here
-                    # is that we won't start tracking `grad_norm_exp_avg` until the 2nd training step.
-                    if global_step > 1:
-                        state["grad_norm_exp_avg"] = grad_norm_exp_avg
-
-                # Determine the clipping coefficient based on the clipping strategy.
-                if max_norm_ratio is not None:
-                    # Adaptive clipping.
-                    clipped_norm = max_norm_ratio * grad_norm_exp_avg
-                    clip_coef = clipped_norm / (grad_norm + 1e-6)
-                else:
-                    # Fixed clipping.
-                    clipped_norm = torch.tensor(max_norm, device=device)
-                    clip_coef = clipped_norm / (grad_norm + 1e-6)
-
-                # Clip the gradients and update the exponential average.
-                clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+            # Get or initialize the exponential average of grad norm.
+            state = self.state[p]
+            grad_norm_exp_avg = state.get("grad_norm_exp_avg")
+            if grad_norm_exp_avg is None:
+                grad_norm_exp_avg = grad_norm.clone().to(device)
+                # We don't want to add anything to `state` until `state` has been initialized, otherwise
+                # this will crash some optimizers which rely on checking `len(state)`. The downside here
+                # is that we won't start tracking `grad_norm_exp_avg` until the 2nd training step.
+                if global_step > 1:
+                    state["grad_norm_exp_avg"] = grad_norm_exp_avg
+
+            clipped_norm = max_norm_ratio * grad_norm_exp_avg
+            clip_coef = clipped_norm / (grad_norm + 1e-6)
+
+            # Clip the gradients and update the exponential average.
+            # Note that multiplying by the clamped coefficient is meaningless when it is
+            # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
+            clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+            if p.grad is not None:
+                # p.grad could be none for some ranks when using FSDP.
+                p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
+            grad_norm_exp_avg.lerp_(clipped_norm.to(grad_norm_exp_avg.device), 1 - beta)
+
+            if collect_param_metrics:
+                # Can't avoid host-device sync here.
                 if clip_coef_clamped < 1.0:
                     num_grads_clipped += 1
-                    if p.grad is not None:
-                        # p.grad could be none for some ranks when using FSDP.
-                        p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
-                    grad_norm_exp_avg.lerp_(clipped_norm.to(grad_norm_exp_avg.device), 1 - beta)
-                else:
-                    grad_norm_exp_avg.lerp_(grad_norm.to(grad_norm_exp_avg.device), 1 - beta)
-                all_metrics[f"grad_norm_exp_avg/{name}"] = grad_norm_exp_avg.to(device="cpu")
-
-        clipping_rate = torch.tensor(num_grads_clipped / num_eligible_grads, device="cpu")
+                all_metrics[f"grad_norm_exp_avg/{name}"] = grad_norm_exp_avg
+        return num_grads_clipped if collect_param_metrics else None
+
+    @torch.no_grad()
+    def _do_global_fixed_clipping(
+        self,
+        group: Dict[str, Any],
+        max_norm: float,
+        all_metrics: Dict[str, torch.Tensor],
+        collect_param_metrics: bool = True,
+    ) -> Optional[int]:
+        """
+        Do global fixed gradient clipping on a param group.
+
+        If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
+        """
+        device = get_default_device()
+        total_grad_norm = all_metrics["total_grad_norm"]
+        clip_coef = max_norm / (total_grad_norm.to(device) + 1e-6)
+        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+        num_grads_clipped: Optional[int] = None
         if collect_param_metrics:
-            all_metrics["clipping_rate"] = clipping_rate
-            return all_metrics
-        else:
-            return {"clipping_rate": clipping_rate}
+            # Can't avoid host-device sync here.
+            if clip_coef_clamped < 1.0:
+                num_grads_clipped = len(group["params"])
+        for p in group["params"]:
+            # Clip the gradients.
+            # Note that multiplying by the clamped coefficient is meaningless when it is
+            # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
+            if p.grad is not None:
+                # p.grad could be none for some ranks when using FSDP.
+                p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
+        return num_grads_clipped
 
     def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
         del module
diff --git a/olmo/train.py b/olmo/train.py
index 144cb2428..85962f8f1 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -762,12 +762,6 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
         # Run backward pass.
         loss.backward()
 
-        # Check for nan.
-        if torch.isnan(ce_batch_loss):
-            raise ValueError("nan loss encountered")
-        if z_batch_loss is not None and torch.isnan(z_batch_loss):
-            raise ValueError("nan loss encountered")
-
         return ce_batch_loss, z_batch_loss
 
     def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
@@ -787,13 +781,19 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
         # Run forward-backward pass.
         ce_batch_loss, z_batch_loss = self.train_batch(batch)
 
+        # Collect loss, potentially reducing over all ranks.
+        if reduce_global_loss:
+            dist.reduce(ce_batch_loss, 0)
+            ce_batch_loss.div_(get_world_size())
+            if z_batch_loss is not None:
+                dist.reduce(z_batch_loss, 0)
+                z_batch_loss.div_(get_world_size())
+
         # Clip gradient norms and collect param/gradient/optim metrics.
         should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
         optim_metrics = self.optim.clip_grads_and_collect_metrics(
             self.global_step, collect_param_metrics=should_log_optim_metrics_this_step
         )
-        for key, value in optim_metrics.items():
-            metrics[f"optim/{key}"] = value.item()
 
         # Adjust the learning rate.
         for group in self.optim.param_groups:
@@ -804,20 +804,19 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
         # Optimizer step.
         self.optim.step()
 
-        # Collect loss, potentially reducing over all ranks.
-        if reduce_global_loss:
-            dist.reduce(ce_batch_loss, 0)
-            ce_batch_loss.div_(get_world_size())
-            # TODO (dirkgr): If we did this much earlier, like, right after the forwards step, but then didn't
-            # call `.item()` for a long time, would it use laziness to interleave this reduce call with the backward step?
+        # Collect metrics and check for NaN loss.
+        # NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
+        if torch.isnan(ce_batch_loss):
+            raise ValueError("nan loss encountered")
+        if z_batch_loss is not None and torch.isnan(z_batch_loss):
+            raise ValueError("nan loss encountered")
+        for key, value in optim_metrics.items():
+            metrics[f"optim/{key}"] = value.item()
         self.cur_train_loss = ce_batch_loss.item()
         self.min_train_loss = min(self.min_train_loss, self.cur_train_loss)
         metrics["train/CrossEntropyLoss"] = self.cur_train_loss
        metrics["train/Perplexity"] = math.exp(self.cur_train_loss)
         if z_batch_loss is not None:
-            if reduce_global_loss:
-                dist.reduce(z_batch_loss, 0)
-                z_batch_loss.div_(get_world_size())
             metrics["train/ZLoss"] = z_batch_loss.item()
 
         # Maybe collect post-step optimizer-specific metrics.
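
Note (illustrative sketch, not part of the patch): the _do_global_fixed_clipping path above scales every gradient in a param group by min(1, max_grad_norm / (total_grad_norm + 1e-6)), where total_grad_norm is the square root of the summed squared "grad/*.norm" values. The standalone Python below shows the same arithmetic under simplifying assumptions: unsharded local gradients, no FSDP or distributed reduction, and a hypothetical clip_grads_globally helper name.

    import torch

    def clip_grads_globally(params, max_norm: float, eps: float = 1e-6) -> torch.Tensor:
        grads = [p.grad for p in params if p.grad is not None]
        # Global L2 norm: sqrt of the sum of squared per-parameter gradient norms.
        total_norm = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(g.detach(), 2.0) for g in grads]), 2.0
        )
        # Scale factor min(1, max_norm / (total_norm + eps)). Multiplying even when the
        # factor is 1.0 is a no-op, but it avoids a host-device sync from branching on it.
        clip_coef = torch.clamp(max_norm / (total_norm + eps), max=1.0)
        for g in grads:
            g.detach().mul_(clip_coef.to(g.device, g.dtype))
        return total_norm

Under those assumptions the result matches what torch.nn.utils.clip_grad_norm_(params, max_norm) computes; the patch keeps its own implementation so the per-parameter norms collected for metrics are reused instead of being recomputed.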