Skip to content

Commit

Permalink
remove overly restrictive checks for cudagraph (pytorch#80881)
Browse files Browse the repository at this point in the history
Finish fixing pytorch#80809
Pull Request resolved: pytorch#80881
Approved by: https://github.com/jbschlosser
  • Loading branch information
albanD authored and atalman committed Jul 21, 2022
1 parent 67ece03 commit 6090a66
Showing 1 changed file with 0 additions and 5 deletions.
5 changes: 0 additions & 5 deletions torch/optim/adamw.py
Expand Up @@ -255,8 +255,6 @@ def _single_tensor_adamw(params: List[Tensor],

if capturable:
assert param.is_cuda and step_t.is_cuda, "If capturable=True, params and state_steps must be CUDA tensors."
else:
assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors."

# update step
step_t += 1
Expand Down Expand Up @@ -334,9 +332,6 @@ def _multi_tensor_adamw(params: List[Tensor],
if capturable:
assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \
"If capturable=True, params and state_steps must be CUDA tensors."
else:
assert all(not step.is_cuda for step in state_steps), \
"If capturable=False, state_steps should not be CUDA tensors."

if maximize:
grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment]
Expand Down

0 comments on commit 6090a66

Please sign in to comment.