diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 32bd7e6c1154..f58fb7352c38 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -487,7 +487,7 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet for param in params: if is_float8tensor(param): param._reset_caches() - param.transpose(update_cache=True) + param.transpose_2d(cache=True) param._lazy_transpose_cache = True @torch.no_grad()