Bring back global gradient clipping and improve speed of collecting metrics (#326)
epwalsh committed Oct 12, 2023
1 parent 54572d3 commit d4744d0
Showing 2 changed files with 152 additions and 92 deletions.
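
For orientation, here is a standalone sketch (illustrative only, not code from this repository) of the two clipping strategies the diff below dispatches between: global fixed clipping scales every gradient by min(1, max_norm / total_norm), where total_norm is the L2 norm over all gradients, while adaptive clipping scales each parameter's gradient against a running average of that parameter's own gradient norm.

    # Standalone sketch of the two clipping modes; names and signatures are illustrative.
    from typing import List

    import torch

    def global_fixed_clip(grads: List[torch.Tensor], max_norm: float, eps: float = 1e-6) -> None:
        # One coefficient for the whole model, based on the global L2 norm of all gradients.
        total_norm = torch.sqrt(sum(torch.linalg.vector_norm(g, 2.0) ** 2 for g in grads))
        clip_coef = torch.clamp(max_norm / (total_norm + eps), max=1.0)
        for g in grads:
            g.mul_(clip_coef)

    def adaptive_clip(grad: torch.Tensor, grad_norm_exp_avg: torch.Tensor, max_norm_ratio: float, eps: float = 1e-6) -> None:
        # Per-parameter coefficient: clip relative to an exponential average of this parameter's grad norm.
        grad_norm = torch.linalg.vector_norm(grad, 2.0)
        clip_coef = torch.clamp(max_norm_ratio * grad_norm_exp_avg / (grad_norm + eps), max=1.0)
        grad.mul_(clip_coef)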
211 changes: 136 additions & 75 deletions olmo/optim.py
@@ -67,6 +67,7 @@ def clip_grads_and_collect_metrics(
per_param_avg_metric_names: List[str] = []
per_param_norm_metric_names: List[str] = []

# Collect metrics locally.
for group in self.param_groups:
if is_distributed():
# TODO (epwalsh): handle non-sharded params. We don't have any right now but we would
@@ -90,38 +91,27 @@ def clip_grads_and_collect_metrics(
for x, prefix in zip(tensors, prefixes):
# grad or state tensors could be none for params that have their shards completely on
# other ranks.
x = (
x.to(device="cpu")
if x is not None
else torch.tensor([], device="cpu", dtype=torch.float32)
)
if x.numel() > 0:
if x is not None and x.numel() > 0:
if collect_param_metrics:
x_abs = x.abs()
per_param_min_metrics.append(
x_abs.min().unsqueeze(0).to(device="cpu", dtype=torch.float32)
)
per_param_max_metrics.append(
x_abs.max().unsqueeze(0).to(device="cpu", dtype=torch.float32)
)
per_param_sum_metrics.append(
x.sum().unsqueeze(0).to(device="cpu", dtype=torch.float32)
)
per_param_min_metrics.append(x_abs.min().unsqueeze(0).to(dtype=torch.float32))
per_param_max_metrics.append(x_abs.max().unsqueeze(0).to(dtype=torch.float32))
per_param_sum_metrics.append(x.sum().unsqueeze(0).to(dtype=torch.float32))
per_param_numel_metrics.append(
torch.tensor([x.numel()], device="cpu", dtype=torch.float32)
torch.tensor([x.numel()], device=device, dtype=torch.float32)
)
per_param_norm_metrics.append(
torch.linalg.vector_norm(x, 2.0, dtype=torch.float32).unsqueeze(0).to(device="cpu")
torch.linalg.vector_norm(x, 2.0, dtype=torch.float32).unsqueeze(0)
)
else:
if collect_param_metrics:
per_param_min_metrics.append(
torch.tensor([float("inf")], device="cpu", dtype=torch.float32)
torch.tensor([float("inf")], device=device, dtype=torch.float32)
)
per_param_max_metrics.append(torch.tensor([0.0], device="cpu", dtype=torch.float32))
per_param_sum_metrics.append(torch.tensor([0.0], device="cpu", dtype=torch.float32))
per_param_numel_metrics.append(torch.tensor([0.0], device="cpu", dtype=torch.float32))
per_param_norm_metrics.append(torch.tensor([0.0], device="cpu", dtype=torch.float32))
per_param_max_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
per_param_sum_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
per_param_numel_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
per_param_norm_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32))
if collect_param_metrics:
per_param_min_metric_names.append(f"{prefix}.min")
per_param_max_metric_names.append(f"{prefix}.max")
@@ -139,6 +129,11 @@ def clip_grads_and_collect_metrics(
)
assert len(per_param_norm_metrics) == len(per_param_norm_metric_names)

def is_grad_norm_metric(metric_name: str) -> bool:
return metric_name.startswith("grad/") and metric_name.endswith(".norm")

# Now reduce metrics over all ranks.
total_grad_norm: torch.Tensor
per_param_avg_metrics: List[torch.Tensor] = []
if is_distributed(): # TODO (epwalsh): skip for non-sharded params
# Reduce metrics across all ranks. Note that we can use a `reduce` for most cases
@@ -148,12 +143,12 @@ def clip_grads_and_collect_metrics(
if per_param_min_metrics:
all_mins = torch.cat(per_param_min_metrics).to(device)
dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN)
per_param_min_metrics = all_mins.to(device="cpu").split(1)
per_param_min_metrics = all_mins.split(1)
# Reduce maxs.
if per_param_max_metrics:
all_maxs = torch.cat(per_param_max_metrics).to(device)
dist.reduce(all_maxs, 0, op=dist.ReduceOp.MAX)
per_param_max_metrics = all_maxs.to(device="cpu").split(1)
per_param_max_metrics = all_maxs.split(1)
# Reduce sums or just norms.
all_norms = torch.cat(per_param_norm_metrics).to(device) ** 2.0
if per_param_sum_metrics and per_param_numel_metrics:
@@ -163,14 +158,28 @@ def clip_grads_and_collect_metrics(
[all_sums.unsqueeze(0), all_norms.unsqueeze(0), all_numels.unsqueeze(0)], dim=0
)
dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM)
all_sums, all_norms, all_numels = all_sums_norms_numels.to(device="cpu").split(1)
all_sums, all_norms, all_numels = all_sums_norms_numels.split(1)
# Get averages.
# NOTE: could get infs for non-rank0 processes but that's okay.
per_param_avg_metrics = (all_sums / all_numels).squeeze(0).split(1)
else:
dist.all_reduce(all_norms, op=dist.ReduceOp.SUM)
grad_norm_metric_mask = torch.tensor(
[float(is_grad_norm_metric(n)) for n in per_param_norm_metric_names], device=all_norms.device
)
total_grad_norm = (all_norms * grad_norm_metric_mask).sum() ** 0.5
per_param_norm_metrics = (all_norms ** (0.5)).squeeze(0).split(1)
else:
total_grad_norm = (
torch.cat(
[
m
for m, n in zip(per_param_norm_metrics, per_param_norm_metric_names)
if is_grad_norm_metric(n)
]
)
** 2.0
).sum() ** 0.5
per_param_avg_metrics = [x / n for x, n in zip(per_param_sum_metrics, per_param_numel_metrics)]

assert len(per_param_avg_metrics) == len(per_param_avg_metric_names)
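
As a side note on the hunk above: the squared per-parameter norms are packed into a single tensor, summed across ranks with one all_reduce, and the total gradient norm is the square root of the masked sum (the mask selects only the grad/*.norm entries). A simplified standalone sketch of that pattern, with illustrative names, is:

    # Illustrative sketch: one collective call covers every per-parameter norm.
    import torch
    import torch.distributed as dist

    def reduce_norms(per_param_norms: torch.Tensor, grad_norm_mask: torch.Tensor):
        squared = per_param_norms ** 2.0
        if dist.is_available() and dist.is_initialized():
            # Each rank holds the norms of its own shards; summing the squares gives global squared norms.
            dist.all_reduce(squared, op=dist.ReduceOp.SUM)
        total_grad_norm = (squared * grad_norm_mask).sum() ** 0.5
        return squared ** 0.5, total_grad_norm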
@@ -185,69 +194,121 @@ def clip_grads_and_collect_metrics(
all_metrics[metric_name] = metric.squeeze(0)
for metric_name, metric in zip(per_param_norm_metric_names, per_param_norm_metrics):
all_metrics[metric_name] = metric.squeeze(0)
all_metrics["total_grad_norm"] = total_grad_norm

# Clip gradients.
num_grads_clipped = 0
num_eligible_grads = 0
for group in self.param_groups:
# We'll use the bigger of beta1 and beta2 to update the exponential average of the norm of
# the gradient (a scalar), not to be confused with the exponential average of the gradient.
# TODO (epwalsh): handle optimizers that don't have betas.
beta1, beta2 = group["betas"]
beta = max(beta1, beta2)
max_norm = group.get("max_grad_norm")
max_norm_ratio = group.get("max_grad_norm_ratio")
if max_norm is None and max_norm_ratio is None:
if (max_norm_ratio := group.get("max_grad_norm_ratio")) is not None:
num_clipped = self._do_adaptive_clipping(
group, max_norm_ratio, global_step, all_metrics, collect_param_metrics=collect_param_metrics
)
elif (max_norm := group.get("max_grad_norm")) is not None:
num_clipped = self._do_global_fixed_clipping(
group, max_norm, all_metrics, collect_param_metrics=collect_param_metrics
)
else:
# No clipping needed.
continue
num_eligible_grads += len(group["params"])
if num_clipped is not None:
num_grads_clipped += num_clipped

for name, p in zip(group["param_names"], group["params"]):
name = self._clean_param_name(name)
grad_norm = all_metrics.get(f"grad/{name}.norm")
if grad_norm is None:
continue
if collect_param_metrics:
clipping_rate = torch.tensor(num_grads_clipped / num_eligible_grads, device="cpu")
all_metrics["clipping_rate"] = clipping_rate
return all_metrics
else:
return {}

@torch.no_grad()
def _do_adaptive_clipping(
self,
group: Dict[str, Any],
max_norm_ratio: float,
global_step: int,
all_metrics: Dict[str, torch.Tensor],
collect_param_metrics: bool = True,
) -> Optional[int]:
"""
Do adaptive gradient clipping on a param group.
num_eligible_grads += 1
If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
"""
device = get_default_device()
num_grads_clipped = 0
# We'll use the bigger of beta1 and beta2 to update the exponential average of the norm of
# the gradient (a scalar), not to be confused with the exponential average of the gradient.
# TODO (epwalsh): handle optimizers that don't have betas.
beta1, beta2 = group["betas"]
beta = max(beta1, beta2)
for name, p in zip(group["param_names"], group["params"]):
name = self._clean_param_name(name)
grad_norm = all_metrics.get(f"grad/{name}.norm")
if grad_norm is None:
continue

# Get or initialize the exponential average of grad norm.
state = self.state[p]
grad_norm_exp_avg = state.get("grad_norm_exp_avg")
if grad_norm_exp_avg is None:
grad_norm_exp_avg = grad_norm.clone().to(device)
# We don't want to add anything to `state` until `state` has been initialized, otherwise
# this will crash some optimizers which rely on checking `len(state)`. The downside here
# is that we won't start tracking `grad_norm_exp_avg` until the 2nd training step.
if global_step > 1:
state["grad_norm_exp_avg"] = grad_norm_exp_avg

# Determine the clipping coefficient based on the clipping strategy.
if max_norm_ratio is not None:
# Adaptive clipping.
clipped_norm = max_norm_ratio * grad_norm_exp_avg
clip_coef = clipped_norm / (grad_norm + 1e-6)
else:
# Fixed clipping.
clipped_norm = torch.tensor(max_norm, device=device)
clip_coef = clipped_norm / (grad_norm + 1e-6)

# Clip the gradients and update the exponential average.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
# Get or initialize the exponential average of grad norm.
state = self.state[p]
grad_norm_exp_avg = state.get("grad_norm_exp_avg")
if grad_norm_exp_avg is None:
grad_norm_exp_avg = grad_norm.clone().to(device)
# We don't want to add anything to `state` until `state` has been initialized, otherwise
# this will crash some optimizers which rely on checking `len(state)`. The downside here
# is that we won't start tracking `grad_norm_exp_avg` until the 2nd training step.
if global_step > 1:
state["grad_norm_exp_avg"] = grad_norm_exp_avg

clipped_norm = max_norm_ratio * grad_norm_exp_avg
clip_coef = clipped_norm / (grad_norm + 1e-6)

# Clip the gradients and update the exponential average.
# Note that multiplying by the clamped coefficient is meaningless when it is
# equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
if p.grad is not None:
# p.grad could be none for some ranks when using FSDP.
p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
grad_norm_exp_avg.lerp_(clipped_norm.to(grad_norm_exp_avg.device), 1 - beta)

if collect_param_metrics:
# Can't avoid host-device sync here.
if clip_coef_clamped < 1.0:
num_grads_clipped += 1
if p.grad is not None:
# p.grad could be none for some ranks when using FSDP.
p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
grad_norm_exp_avg.lerp_(clipped_norm.to(grad_norm_exp_avg.device), 1 - beta)
else:
grad_norm_exp_avg.lerp_(grad_norm.to(grad_norm_exp_avg.device), 1 - beta)
all_metrics[f"grad_norm_exp_avg/{name}"] = grad_norm_exp_avg.to(device="cpu")

clipping_rate = torch.tensor(num_grads_clipped / num_eligible_grads, device="cpu")
all_metrics[f"grad_norm_exp_avg/{name}"] = grad_norm_exp_avg
return num_grads_clipped if collect_param_metrics else None
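
The lerp_ call above is an in-place exponential moving average: a.lerp_(b, 1 - beta) computes beta * a + (1 - beta) * b. A quick standalone check (not part of the diff):

    # Standalone check of the in-place EMA update used for grad_norm_exp_avg.
    import torch

    beta = 0.95
    avg = torch.tensor(2.0)
    new_norm = torch.tensor(4.0)
    expected = beta * avg + (1 - beta) * new_norm
    avg.lerp_(new_norm, 1 - beta)  # a <- a + w * (b - a) with w = 1 - beta
    assert torch.allclose(avg, expected)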

@torch.no_grad()
def _do_global_fixed_clipping(
self,
group: Dict[str, Any],
max_norm: float,
all_metrics: Dict[str, torch.Tensor],
collect_param_metrics: bool = True,
) -> Optional[int]:
"""
Do global fixed gradient clipping on a param group.
If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped.
"""
device = get_default_device()
total_grad_norm = all_metrics["total_grad_norm"]
clip_coef = max_norm / (total_grad_norm.to(device) + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
num_grads_clipped: Optional[int] = None
if collect_param_metrics:
all_metrics["clipping_rate"] = clipping_rate
return all_metrics
else:
return {"clipping_rate": clipping_rate}
# Can't avoid host-device sync here.
if clip_coef_clamped < 1.0:
num_grads_clipped = len(group["params"])
for p in group["params"]:
# Clip the gradients.
# Note that multiplying by the clamped coefficient is meaningless when it is
# equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
if p.grad is not None:
# p.grad could be none for some ranks when using FSDP.
p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
return num_grads_clipped

def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
del module
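For the global fixed path above, the per-group work reduces to scaling every gradient by min(1, max_norm / (total_grad_norm + 1e-6)), reusing the already-reduced total norm. As a standalone sanity check (not part of the commit), this is the same rule torch.nn.utils.clip_grad_norm_ applies:

    # Standalone check that the fixed-clipping rule matches torch.nn.utils.clip_grad_norm_.
    import torch

    torch.manual_seed(0)
    params = [torch.nn.Parameter(torch.randn(8)) for _ in range(3)]
    for p in params:
        p.grad = torch.randn_like(p)

    max_norm = 1.0
    total_norm = torch.sqrt(sum(torch.linalg.vector_norm(p.grad, 2.0) ** 2 for p in params))
    clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    manual = [p.grad * clip_coef for p in params]

    reported = torch.nn.utils.clip_grad_norm_(params, max_norm)  # clips in place, returns the pre-clip norm
    assert torch.allclose(reported, total_norm)
    assert all(torch.allclose(m, p.grad, atol=1e-6) for m, p in zip(manual, params))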
33 changes: 16 additions & 17 deletions olmo/train.py
@@ -762,12 +762,6 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
# Run backward pass.
loss.backward()

# Check for nan.
if torch.isnan(ce_batch_loss):
raise ValueError("nan loss encountered")
if z_batch_loss is not None and torch.isnan(z_batch_loss):
raise ValueError("nan loss encountered")

return ce_batch_loss, z_batch_loss

def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
@@ -787,13 +781,19 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
# Run forward-backward pass.
ce_batch_loss, z_batch_loss = self.train_batch(batch)

# Collect loss, potentially reducing over all ranks.
if reduce_global_loss:
dist.reduce(ce_batch_loss, 0)
ce_batch_loss.div_(get_world_size())
if z_batch_loss is not None:
dist.reduce(z_batch_loss, 0)
z_batch_loss.div_(get_world_size())

# Clip gradient norms and collect param/gradient/optim metrics.
should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
optim_metrics = self.optim.clip_grads_and_collect_metrics(
self.global_step, collect_param_metrics=should_log_optim_metrics_this_step
)
for key, value in optim_metrics.items():
metrics[f"optim/{key}"] = value.item()

# Adjust the learning rate.
for group in self.optim.param_groups:
@@ -804,20 +804,19 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
# Optimizer step.
self.optim.step()

# Collect loss, potentially reducing over all ranks.
if reduce_global_loss:
dist.reduce(ce_batch_loss, 0)
ce_batch_loss.div_(get_world_size())
# TODO (dirkgr): If we did this much earlier, like, right after the forwards step, but then didn't
# call `.item()` for a long time, would it use laziness to interleave this reduce call with the backward step?
# Collect metrics and check for NaN loss.
# NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
if torch.isnan(ce_batch_loss):
raise ValueError("nan loss encountered")
if z_batch_loss is not None and torch.isnan(z_batch_loss):
raise ValueError("nan loss encountered")
for key, value in optim_metrics.items():
metrics[f"optim/{key}"] = value.item()
self.cur_train_loss = ce_batch_loss.item()
self.min_train_loss = min(self.min_train_loss, self.cur_train_loss)
metrics["train/CrossEntropyLoss"] = self.cur_train_loss
metrics["train/Perplexity"] = math.exp(self.cur_train_loss)
if z_batch_loss is not None:
if reduce_global_loss:
dist.reduce(z_batch_loss, 0)
z_batch_loss.div_(get_world_size())
metrics["train/ZLoss"] = z_batch_loss.item()

# Maybe collect post-step optimizer-specific metrics.
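The train.py changes above defer every host-device sync (the NaN checks, the .item() calls, and reading the optimizer metrics) until after the optimizer step, so the GPU is not stalled mid-step. A standalone sketch of the pattern, with made-up names, is:

    # Standalone sketch: keep losses as tensors and read them once, at the end of the step.
    import math

    import torch

    def train_step_sketch(model: torch.nn.Module, optim: torch.optim.Optimizer, batch: torch.Tensor) -> float:
        loss = model(batch).mean()           # stays on device, no sync yet
        loss.backward()
        optim.step()
        optim.zero_grad(set_to_none=True)
        loss_value = loss.item()             # single host-device sync, deferred to the end
        if math.isnan(loss_value):
            raise ValueError("nan loss encountered")
        return loss_value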
