Skip to content

Commit

Permalink
Let PS_weights with same device as parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
bichengying committed May 31, 2020
1 parent 6358ba9 commit 27edc87
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion bluefog/torch/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def _register_window(self):
raise KeyError(
"Cannot find parameter {} in the _parameter_names dictionary".format(name))

ps_weights = torch.Tensor([1.0]).to(p.data.dtype)
ps_weights = torch.Tensor([1.0]).to(p.data.dtype).to(p.data.device)
self._named_ps_weights[name] = ps_weights
# If do not modify in the C level, it is inevitable to copy
# the parameter once in the cat ops.
Expand Down

0 comments on commit 27edc87

Please sign in to comment.