Skip to content

Commit

Permalink
Revert "ignore inf in loss tb summary, and report valid_rate"
Browse files Browse the repository at this point in the history
This reverts commit 72cf8fa.
  • Loading branch information
Le Horizon committed Jun 10, 2024
1 parent 305462e commit b75cdd7
Showing 1 changed file with 1 addition and 19 deletions.
20 changes: 1 addition & 19 deletions alf/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,25 +1207,7 @@ def update_with_gradient(self,
loss_info.loss * weight)

loss_info = loss_info._replace(gns=gns)

new_fields_values = {}

def _mean_ignore_inf(path, loss):
inf_mask = torch.isinf(loss)
if inf_mask.any():
# populate the valid_rate field only when value can be inf.
new_fields_values[
path + '-valid_rate'] = 1.0 - inf_mask.float().mean()
if inf_mask.all():
return torch.tensor(np.float32(np.inf), device=loss.device)
else:
return loss[~inf_mask].mean()

loss_info = alf.nest.py_map_structure_with_path(
_mean_ignore_inf, loss_info)

for k, v in new_fields_values.items():
loss_info = alf.nest.set_field(loss_info, k, v)
loss_info = alf.nest.map_structure(torch.mean, loss_info)

return loss_info, all_params

Expand Down

0 comments on commit b75cdd7

Please sign in to comment.