[optim/dynamo] shortcut adagrad with has_complex (pytorch#112722)
Follow-up to pytorch#110706; it was missed there because it depended on another fix.

Pull Request resolved: pytorch#112722
Approved by: https://github.com/albanD
jon-chuang authored and Skylion007 committed Nov 14, 2023
1 parent 8155646 commit f868397
Showing 2 changed files with 18 additions and 14 deletions.
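The gist of the change: both the functional and the standard Adagrad now accumulate a has_complex flag while collecting gradients and thread it down to the update kernels, so the complex-handling pass runs (and is traced by dynamo) only when a complex parameter is actually present. Below is a minimal sketch of that pattern, separate from the PyTorch internals; the names prepare_params and convert_if_needed are illustrative only.

from typing import List

import torch


def prepare_params(params: List[torch.Tensor]) -> bool:
    # Accumulate one flag up front instead of re-checking every tensor later.
    has_complex = False
    for p in params:
        has_complex |= torch.is_complex(p)
    return has_complex


def convert_if_needed(params: List[torch.Tensor], has_complex: bool) -> List[torch.Tensor]:
    # The whole per-tensor pass is skipped when the flag is False (the common
    # case), which is what the shortcut buys.
    if not has_complex:
        return params
    return [torch.view_as_real(p) if torch.is_complex(p) else p for p in params]


params = [torch.randn(3), torch.randn(2, dtype=torch.complex64)]
flag = prepare_params(params)
print(flag, [p.shape for p in convert_if_needed(params, flag)])

Accumulating the flag with |= inside the existing collection loop costs one cheap check per parameter, rather than a per-tensor torch.is_complex test inside the hot multi-tensor path.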
7 changes: 4 additions & 3 deletions torch/distributed/optim/functional_adagrad.py
@@ -76,11 +76,11 @@ def step(self, gradients: List[Optional[Tensor]]):
                 + f"Gradients length: {len(gradients)}"
             )
 
-        has_sparse_grad = False
+        has_sparse_grad, has_complex = False, False
         for param, gradient in zip(self.param_group["params"], gradients):
             if gradient is not None:
-                if gradient.is_sparse:
-                    has_sparse_grad = True
+                has_sparse_grad |= gradient.is_sparse
+                has_complex |= torch.is_complex(param)
                 params_with_grad.append(param)
                 grads.append(gradient)
                 state = self.state[param]
@@ -100,4 +100,5 @@ def step(self, gradients: List[Optional[Tensor]]):
                 has_sparse_grad=has_sparse_grad,
                 foreach=self.foreach,
                 maximize=self.maximize,
+                has_complex=has_complex,
             )
25 changes: 14 additions & 11 deletions torch/optim/adagrad.py
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 
-from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
+from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _view_as_real,
                         _default_to_fused_or_foreach, _differentiable_doc, _foreach_doc, _maximize_doc)
 from typing import List, Optional
 
@@ -82,18 +82,18 @@ def share_memory(self):
                 state["sum"].share_memory_()
 
     def _init_group(self, group, params_with_grad, grads, state_sums, state_steps):
-        has_sparse_grad = False
+        has_sparse_grad, has_complex = False, False
         for p in group["params"]:
             if p.grad is not None:
-                if p.grad.is_sparse:
-                    has_sparse_grad = True
+                has_sparse_grad |= p.grad.is_sparse
+                has_complex |= torch.is_complex(p)
                 params_with_grad.append(p)
                 grads.append(p.grad)
                 state = self.state[p]
                 state_sums.append(state["sum"])
                 state_steps.append(state["step"])
 
-        return has_sparse_grad
+        return has_sparse_grad, has_complex
 
     @_use_grad_for_differentiable
     def step(self, closure=None):
@@ -115,7 +115,7 @@ def step(self, closure=None):
             state_sums = []
             state_steps = []
 
-            has_sparse_grad = self._init_group(group, params_with_grad, grads, state_sums, state_steps)
+            has_sparse_grad, has_complex = self._init_group(group, params_with_grad, grads, state_sums, state_steps)
 
             adagrad(
                 params_with_grad,
@@ -130,6 +130,7 @@ def step(self, closure=None):
                 foreach=group["foreach"],
                 maximize=group["maximize"],
                 differentiable=group["differentiable"],
+                has_complex=has_complex,
             )
 
         return loss
@@ -189,6 +190,7 @@ def adagrad(
     has_sparse_grad: bool = None,
     foreach: Optional[bool] = None,
     differentiable: bool = False,
+    has_complex: bool = False,
     *,
     lr: float,
     weight_decay: float,
@@ -229,6 +231,7 @@
         has_sparse_grad=has_sparse_grad,
         maximize=maximize,
         differentiable=differentiable,
+        has_complex=has_complex,
     )
 
 
@@ -252,6 +255,7 @@ def _single_tensor_adagrad(
     has_sparse_grad: bool,
     maximize: bool,
     differentiable: bool,
+    has_complex: bool,
 ):
 
     for (param, grad, state_sum, step_t) in zip(params, grads, state_sums, state_steps):
@@ -310,6 +314,7 @@ def _multi_tensor_adagrad(
     has_sparse_grad: bool,
     maximize: bool,
     differentiable: bool,
+    has_complex: bool,
 ):
 
     assert not differentiable, "_foreach ops don't support autograd"
@@ -335,18 +340,16 @@ def _multi_tensor_adagrad(
                 has_sparse_grad=True,
                 maximize=False,
                 differentiable=differentiable,
+                has_complex=has_complex,
             )
             continue
 
         if maximize:
             device_grads = torch._foreach_neg(device_grads)
 
         # Handle complex parameters
-        for i in range(len(device_params)):
-            if torch.is_complex(device_params[i]):
-                device_params[i] = torch.view_as_real(device_params[i])
-                device_grads[i] = torch.view_as_real(device_grads[i])
-                device_state_sums[i] = torch.view_as_real(device_state_sums[i])
+        if has_complex:
+            _view_as_real(device_params, device_grads, device_state_sums)
 
         # Update steps
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
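The final hunk replaces the inline complex-handling loop with the shared _view_as_real helper imported at the top of the file, now guarded by the new flag. Below is a simplified, hedged stand-in for that helper, inferred only from the call site above (the real helper is a private utility in torch/optim/optimizer.py).

from typing import List

import torch


def view_as_real_(params: List[torch.Tensor], *state_lists: List[torch.Tensor]) -> None:
    # For every complex parameter, swap it and the matching state tensors for
    # torch.view_as_real views (same storage, extra trailing dimension of size 2).
    for i, p in enumerate(params):
        if torch.is_complex(p):
            for tensors in state_lists:
                tensors[i] = torch.view_as_real(tensors[i])
            params[i] = torch.view_as_real(p)


params = [torch.zeros(2, dtype=torch.complex64)]
grads = [torch.ones(2, dtype=torch.complex64)]
state_sums = [torch.zeros(2, dtype=torch.complex64)]
view_as_real_(params, grads, state_sums)
print(params[0].shape, grads[0].shape, state_sums[0].shape)  # all torch.Size([2, 2])

Viewing complex tensors as real tensors with a trailing dimension of size 2 lets the real-dtype foreach kernels update the real and imaginary parts in a single pass.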
