Skip to content

Commit

Permalink
fix: change counts to scaler tensors, fixes metaopt#70
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Sep 9, 2022
1 parent da00100 commit ac54272
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torchopt/_src/transform.py
Expand Up @@ -161,7 +161,7 @@ def _scale_by_schedule(

def init_fn(params):
zero = tree_map( # count init
lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params
lambda t: torch.zeros(1, dtype=torch.int32, device=t.device).squeeze_(), params
)
return ScaleByScheduleState(count=zero)

Expand Down Expand Up @@ -299,7 +299,7 @@ def _scale_by_adam(

def init_fn(params):
zero = tree_map( # count init
lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params
lambda t: torch.zeros(1, dtype=torch.int32, device=t.device).squeeze_(), params
)
mu = tree_map( # first moment
lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params
Expand Down Expand Up @@ -426,7 +426,7 @@ def update_fn(updates, state, *, params=None, inplace=True):

def init_fn(params):
zero = tree_map( # count init
lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params
lambda t: torch.zeros(1, dtype=torch.int32, device=t.device).squeeze_(), params
)
mu = tree_map( # first moment
lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params
Expand Down

0 comments on commit ac54272

Please sign in to comment.