Disable grouping by dtype and device if compiling (pytorch#102771)
Disable grouping if we are compiling; the grouping happens during Inductor lowering instead.
Pull Request resolved: pytorch#102771
Approved by: https://github.com/janeyx99
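
For context, the foreach optimizer paths bucket their tensor lists so that each torch._foreach_* call only ever sees tensors that share a device and dtype. The sketch below is an illustrative re-implementation of that bucketing idea, not the actual torch.utils._foreach_utils helper; the function name group_by_device_and_dtype and the toy update are made up for the example:

from collections import defaultdict
from typing import Dict, List, Tuple

import torch


def group_by_device_and_dtype(
    tensorlists: List[List[torch.Tensor]],
) -> Dict[Tuple[torch.device, torch.dtype], List[List[torch.Tensor]]]:
    """Toy version of the grouping: bucket parallel tensor lists by the
    (device, dtype) of the corresponding entry in the first list."""
    buckets: Dict[Tuple[torch.device, torch.dtype], List[List[torch.Tensor]]] = defaultdict(
        lambda: [[] for _ in tensorlists]
    )
    for i, ref in enumerate(tensorlists[0]):
        key = (ref.device, ref.dtype)
        for j, tensorlist in enumerate(tensorlists):
            buckets[key][j].append(tensorlist[i])
    return dict(buckets)


# Params of different dtypes land in different buckets, so each foreach call
# below operates on homogeneous lists. Under torch.compile this PR skips the
# Python-level bucketing and leaves the regrouping to Inductor's lowering.
params = [torch.zeros(3), torch.zeros(3, dtype=torch.float64)]
grads = [torch.ones(3), torch.ones(3, dtype=torch.float64)]
for (device, dtype), (bucket_params, bucket_grads) in group_by_device_and_dtype([params, grads]).items():
    torch._foreach_add_(bucket_params, bucket_grads, alpha=-0.1)
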
mlazos authored and alimoezzi committed Jun 3, 2023
1 parent d32999c commit e626afe
Showing 12 changed files with 29 additions and 25 deletions.
3 changes: 1 addition & 2 deletions torch/optim/adadelta.py
@@ -3,7 +3,6 @@

 from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach,
                         _differentiable_doc, _foreach_doc, _maximize_doc)
-from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 from typing import List, Optional

 __all__ = ["Adadelta", "adadelta"]
@@ -276,7 +275,7 @@ def _multi_tensor_adadelta(
     if len(params) == 0:
         return

-    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, square_avgs, acc_deltas])
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, acc_deltas])
     for device_params, device_grads, device_square_avgs, device_acc_deltas in grouped_tensors.values():
         if maximize:
             device_grads = torch._foreach_neg(device_grads)
3 changes: 1 addition & 2 deletions torch/optim/adagrad.py
@@ -3,7 +3,6 @@

 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
                         _default_to_fused_or_foreach, _differentiable_doc, _foreach_doc, _maximize_doc)
-from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 from typing import List, Optional

 __all__ = ["Adagrad", "adagrad"]
@@ -321,7 +320,7 @@ def _multi_tensor_adagrad(
     if len(params) == 0:
         return

-    grouped_tensorlists = _group_tensors_by_device_and_dtype([params, grads, state_sums, state_steps])
+    grouped_tensorlists = Optimizer._group_tensors_by_device_and_dtype([params, grads, state_sums, state_steps])
     for device_params, device_grads, device_state_sums, device_state_steps in grouped_tensorlists.values():

         if maximize:
7 changes: 4 additions & 3 deletions torch/optim/adam.py
@@ -5,7 +5,6 @@
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
                         _dispatch_sqrt, _default_to_fused_or_foreach, _capturable_doc,
                         _differentiable_doc, _foreach_doc, _fused_doc, _maximize_doc)
-from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

 __all__ = ['Adam', 'adam']

@@ -424,7 +423,8 @@ def _multi_tensor_adam(params: List[Tensor],

     assert not differentiable, "_foreach ops don't support autograd"

-    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
     for (device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,
          device_max_exp_avg_sqs, device_state_steps) in grouped_tensors.values():

@@ -532,7 +532,8 @@ def _fused_adam(
 ) -> None:
     grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None
     found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None
-    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
     for (device, dtype) in grouped_tensors:
         (
             device_params,
3 changes: 1 addition & 2 deletions torch/optim/adamax.py
@@ -4,7 +4,6 @@
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
                         _default_to_fused_or_foreach, _differentiable_doc, _maximize_doc, _foreach_doc)
 from typing import List, Optional
-from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

 __all__ = ["Adamax", "adamax"]

@@ -305,7 +304,7 @@ def _multi_tensor_adamax(
     if len(params) == 0:
         return

-    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps])
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps])
     for grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs, grouped_state_steps in grouped_tensors.values():
         if maximize:
             grouped_grads = torch._foreach_neg(grouped_grads)
6 changes: 3 additions & 3 deletions torch/optim/adamw.py
@@ -4,7 +4,6 @@
                         _stack_if_compiling, _capturable_doc, _differentiable_doc, _foreach_doc,
                         _fused_doc, _maximize_doc, _default_to_fused_or_foreach)
 from typing import List, Optional
-from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

 __all__ = ["AdamW", "adamw"]

@@ -476,7 +475,7 @@ def _multi_tensor_adamw(

     assert grad_scale is None and found_inf is None

-    grouped_tensors = _group_tensors_by_device_and_dtype([
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([
         params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
     for (device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,
          device_max_exp_avg_sqs, device_state_steps) in grouped_tensors.values():
@@ -593,7 +592,8 @@ def _fused_adamw(
         raise RuntimeError("_fused_adamw is not differentiable")
     grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None
     found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None
-    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
     for (device, dtype) in grouped_tensors:
         (
             device_params,
3 changes: 1 addition & 2 deletions torch/optim/asgd.py
@@ -4,7 +4,6 @@
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _default_to_fused_or_foreach,
                         _differentiable_doc, _foreach_doc, _maximize_doc)
 from torch._utils import is_compiling
-from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 from typing import List, Optional

 __all__ = ["ASGD", "asgd"]
@@ -294,7 +293,7 @@ def _multi_tensor_asgd(

     assert not differentiable, "_foreach ops don't support autograd"

-    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
     for (grouped_params, grouped_grads, grouped_axs, grouped_mus,
          grouped_etas, grouped_state_steps) in grouped_tensors.values():
         if maximize:
4 changes: 1 addition & 3 deletions torch/optim/nadam.py
@@ -3,7 +3,6 @@
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling,
                         _differentiable_doc, _foreach_doc, _default_to_fused_or_foreach)
 from typing import List, Optional
-from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

 __all__ = ['NAdam', 'nadam']

@@ -291,8 +290,7 @@ def _multi_tensor_nadam(params: List[Tensor],

     assert not differentiable, "_foreach ops don't support autograd"

-    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs,
-                                                          mu_products, state_steps])
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps])
     for (grouped_params, grouped_grads, grouped_exp_avgs,
          grouped_exp_avg_sqs, grouped_mu_products, grouped_state_steps) in grouped_tensors.values():

13 changes: 13 additions & 0 deletions torch/optim/optimizer.py
@@ -11,6 +11,7 @@
 import torch.utils.hooks as hooks
 from torch.utils.hooks import RemovableHandle
 from torch._utils import is_compiling
+from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

 __all__ = ['Optimizer', 'register_optimizer_step_pre_hook', 'register_optimizer_step_post_hook']
 _global_optimizer_pre_hooks: Dict[int, Callable] = OrderedDict()
@@ -288,6 +289,18 @@ def wrapper(*args, **kwargs):

         return wrapper

+    @staticmethod
+    def _group_tensors_by_device_and_dtype(tensorlistlist, with_indices=False):
+        """Groups a list of lists of tensors by device and dtype.
+        Skips this step if we are compiling since this will occur during inductor lowering."""
+        if is_compiling():
+            if with_indices:
+                indices = list(range(len(tensorlistlist[0])))
+                tensorlistlist.append(indices)
+            return {(None, None): tensorlistlist}
+        else:
+            return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
+
     def _patch_step_function(self):
         self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__)
         hooked = getattr(self.__class__.step, "hooked", None)
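
The call sites updated in this commit all consume the new helper the same way; a minimal sketch of that pattern, assuming the return convention visible in this diff (a dict whose values are the grouped tensor lists, collapsing to a single (None, None) bucket when compiling), might look like this. The toy_foreach_sgd function is hypothetical and only mirrors the structure of the kernels touched by this PR:

import torch
from torch.optim.optimizer import Optimizer


def toy_foreach_sgd(params, grads, lr=0.1):
    # Eager: one bucket per (device, dtype). Compiling: a single (None, None)
    # bucket whose lists pass through untouched, so the Python-level grouping
    # is not traced and Inductor regroups the tensors during lowering instead.
    # Assumes the static method added in this commit is available.
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads])
    for device_params, device_grads in grouped_tensors.values():
        torch._foreach_add_(device_params, device_grads, alpha=-lr)


params = [torch.zeros(4), torch.zeros(4, dtype=torch.float64)]
grads = [torch.ones(4), torch.ones(4, dtype=torch.float64)]
toy_foreach_sgd(params, grads)

The with_indices=True variant used by sgd.py below additionally appends a list of the original positions, which the caller uses to write newly created momentum buffers back into the ungrouped momentum_buffer_list.
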
3 changes: 1 addition & 2 deletions torch/optim/radam.py
@@ -5,7 +5,6 @@
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling,
                         _default_to_fused_or_foreach, _differentiable_doc, _foreach_doc)
 from typing import List, Optional
-from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

 __all__ = ["RAdam", "radam"]

@@ -315,7 +314,7 @@ def _multi_tensor_radam(

     assert not differentiable, "_foreach ops don't support autograd"

-    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, state_steps])
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, state_steps])
     for grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs, grouped_state_steps in grouped_tensors.values():
         # Update steps
         torch._foreach_add_(grouped_state_steps, 1)
3 changes: 1 addition & 2 deletions torch/optim/rmsprop.py
@@ -3,7 +3,6 @@
 from .optimizer import (Optimizer, _default_to_fused_or_foreach, _use_grad_for_differentiable,
                         _differentiable_doc, _foreach_doc, _maximize_doc)
 from typing import List, Optional
-from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

 __all__ = ["RMSprop", "rmsprop"]

@@ -326,7 +325,7 @@ def _multi_tensor_rmsprop(

     assert not differentiable, "_foreach ops don't support autograd"

-    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, square_avgs, grad_avgs, momentum_buffer_list])
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, grad_avgs, momentum_buffer_list])
     for (grouped_params, grouped_grads, grouped_square_avgs, grouped_grad_avgs,
          grouped_momentum_buffer_list) in grouped_tensors.values():
         if maximize:
3 changes: 1 addition & 2 deletions torch/optim/rprop.py
@@ -3,7 +3,6 @@
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach,
                         _differentiable_doc, _foreach_doc, _maximize_doc)
 from typing import List, Optional
-from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

 __all__ = ["Rprop", "rprop"]

@@ -281,7 +280,7 @@ def _multi_tensor_rprop(

     assert not differentiable, "_foreach ops don't support autograd"

-    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, prevs, step_sizes])
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, prevs, step_sizes])
     for grouped_params, grouped_grads, grouped_prevs, grouped_step_sizes in grouped_tensors.values():
         # Handle complex params
         def _view_complex_as_real(tensor_list):
3 changes: 1 addition & 2 deletions torch/optim/sgd.py
@@ -3,7 +3,6 @@
 from .optimizer import (Optimizer, required, _use_grad_for_differentiable, _default_to_fused_or_foreach,
                         _differentiable_doc, _foreach_doc, _maximize_doc)
 from typing import List, Optional
-from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

 __all__ = ['SGD', 'sgd']

@@ -280,7 +279,7 @@ def _multi_tensor_sgd(params: List[Tensor],
     if len(params) == 0:
         return

-    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, momentum_buffer_list], with_indices=True)
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, momentum_buffer_list], with_indices=True)
     for device_params, device_grads, device_momentum_buffer_list, indices in grouped_tensors.values():
         device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)

