From 8ba6b88c2850b4e674491a27e35db6c0a5dea41d Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Tue, 21 May 2019 12:29:36 +0300 Subject: [PATCH] Resolve inconsistency with tensor strides and optimizer updates (#71) * initial commit * fix typo * more efficient workaround * fix some typos * spacing in docs * update changelog * more intuitive name --- CHANGELOG.rst | 1 + geoopt/optim/radam.py | 6 ++++-- geoopt/optim/rsgd.py | 9 +++++---- geoopt/samplers/rhmc.py | 4 ++-- geoopt/samplers/rsgld.py | 9 ++++----- geoopt/samplers/sgrhmc.py | 7 +++---- geoopt/tensor.py | 6 +++--- geoopt/utils.py | 24 ++++++++++++++++++++++++ tests/test_adam.py | 3 +-- tests/test_rhmc.py | 2 +- tests/test_rsgd.py | 3 ++- 11 files changed, 50 insertions(+), 24 deletions(-) create mode 100644 geoopt/utils.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a672ac9d..d0f54e99 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -28,3 +28,4 @@ Deprecations Bug Fixes --------- * Make pickle work with ManifoldTensors (#47) +* Resolve inconsistency with tensor strides and optimizer updates (#71) diff --git a/geoopt/optim/radam.py b/geoopt/optim/radam.py index 9740bd8d..7d497b02 100644 --- a/geoopt/optim/radam.py +++ b/geoopt/optim/radam.py @@ -4,6 +4,7 @@ from .tracing import create_traced_update from ..tensor import ManifoldParameter, ManifoldTensor from ..manifolds import Euclidean +from ..utils import copy_or_set_ class RiemannianAdam(OptimMixin, torch.optim.Adam): @@ -189,7 +190,8 @@ def perform_step( new_point, exp_avg_new = manifold.retr_transp( point, exp_avg, u=direction, t=-step_size ) - point.set_(new_point) + # use copy only for user facing point + copy_or_set_(point, new_point) exp_avg.set_(exp_avg_new) def stabilize_group(self, group): @@ -202,7 +204,7 @@ def stabilize_group(self, group): continue manifold = p.manifold exp_avg = state["exp_avg"] - p.set_(manifold.projx(p)) + copy_or_set_(p, manifold.projx(p)) exp_avg.set_(manifold.proju(p, exp_avg)) def _sanitize_group(self, group): diff --git a/geoopt/optim/rsgd.py b/geoopt/optim/rsgd.py index 20c83fb4..6b106f08 100644 --- a/geoopt/optim/rsgd.py +++ b/geoopt/optim/rsgd.py @@ -3,7 +3,7 @@ from ..tensor import ManifoldParameter, ManifoldTensor from .mixin import OptimMixin from .tracing import create_traced_update - +from ..utils import copy_or_set_ __all__ = ["RiemannianSGD"] @@ -168,10 +168,11 @@ def perform_step( point, momentum_buffer, u=grad, t=-lr ) momentum_buffer.set_(new_momentum_buffer) - point.set_(new_point) + # use copy only for user facing point + copy_or_set_(point, new_point) else: new_point = manifold.retr(point, grad, -lr) - point.set_(new_point) + copy_or_set_(point, new_point) def stabilize_group(self, group): with torch.no_grad(): @@ -180,7 +181,7 @@ def stabilize_group(self, group): continue manifold = p.manifold momentum = group["momentum"] - p.set_(manifold.projx(p)) + copy_or_set_(p, manifold.projx(p)) if momentum > 0: param_state = self.state[p] if not param_state: # due to None grads diff --git a/geoopt/samplers/rhmc.py b/geoopt/samplers/rhmc.py index 5f9d33d1..58cd2b1d 100644 --- a/geoopt/samplers/rhmc.py +++ b/geoopt/samplers/rhmc.py @@ -6,7 +6,7 @@ from geoopt.tensor import ManifoldParameter, ManifoldTensor from geoopt.manifolds import Euclidean from geoopt.samplers.base import Sampler - +from ..utils import copy_or_set_ __all__ = ["RHMC"] @@ -40,7 +40,7 @@ def _step(self, p, r, epsilon): r.add_(epsilon * egrad2rgrad(p, p.grad)) p_, r_ = retr_transp(p, r, u=r, t=epsilon) - p.set_(p_) + copy_or_set_(p, p_) r.set_(r_) def step(self, closure): diff --git a/geoopt/samplers/rsgld.py b/geoopt/samplers/rsgld.py index f85be87e..68e82ae3 100644 --- a/geoopt/samplers/rsgld.py +++ b/geoopt/samplers/rsgld.py @@ -5,7 +5,7 @@ from geoopt.tensor import ManifoldParameter, ManifoldTensor from geoopt.manifolds import Euclidean from geoopt.samplers.base import Sampler - +from ..utils import copy_or_set_ __all__ = ["RSGLD"] @@ -50,8 +50,8 @@ def step(self, closure): n = torch.randn_like(p).mul_(math.sqrt(epsilon)) r = egrad2rgrad(p, 0.5 * epsilon * p.grad + n) - - p.set_(retr(p, r, 1.0)) + # use copy only for user facing point + copy_or_set_(p, retr(p, r, 1.0)) p.grad.zero_() if not self.burnin: @@ -66,5 +66,4 @@ def stabilize(self): for p in group["params"]: if not isinstance(p, (ManifoldParameter, ManifoldTensor)): continue - - p.set_(p.manifold.projx(p)) + copy_or_set_(p, p.manifold.projx(p)) diff --git a/geoopt/samplers/sgrhmc.py b/geoopt/samplers/sgrhmc.py index 8b8d1f07..0d9ff238 100644 --- a/geoopt/samplers/sgrhmc.py +++ b/geoopt/samplers/sgrhmc.py @@ -5,7 +5,7 @@ from geoopt.tensor import ManifoldParameter, ManifoldTensor from geoopt.manifolds import Euclidean from geoopt.samplers.base import Sampler - +from ..utils import copy_or_set_ __all__ = ["SGRHMC"] @@ -76,7 +76,7 @@ def step(self, closure): v = self.state[p]["v"] p_, v_ = retr_transp(p, v, u=v, t=1.0) - p.set_(p_) + copy_or_set_(p, p_) v.set_(v_) n = egrad2rgrad(p, torch.randn_like(v)) @@ -103,7 +103,6 @@ def stabilize(self): manifold = p.manifold v = self.state[p]["v"] - - p.set_(manifold.projx(p)) + copy_or_set_(p, manifold.projx(p)) # proj here is ok v.set_(manifold.proju(p, v)) diff --git a/geoopt/tensor.py b/geoopt/tensor.py index 335b28df..d0e3d805 100644 --- a/geoopt/tensor.py +++ b/geoopt/tensor.py @@ -1,6 +1,7 @@ import torch.nn from .manifolds import Euclidean from .docutils import insert_docs +from .utils import copy_or_set_ __all__ = ["ManifoldTensor", "ManifoldParameter"] @@ -27,6 +28,7 @@ def __new__(cls, *args, manifold=Euclidean(), requires_grad=False, **kwargs): instance.manifold = manifold return instance + @torch.no_grad() def proj_(self): """ Inplace projection to the manifold @@ -36,9 +38,7 @@ def proj_(self): tensor same instance """ - with torch.no_grad(): - self.set_(self.manifold.projx(self.data)) - return self + return copy_or_set_(self, self.manifold.projx(self)) @insert_docs(Euclidean.retr.__doc__, r"\s+x : .+\n.+", "") def retr(self, u, t=1.0, order=None): diff --git a/geoopt/utils.py b/geoopt/utils.py new file mode 100644 index 00000000..aec919e4 --- /dev/null +++ b/geoopt/utils.py @@ -0,0 +1,24 @@ +__all__ = "copy_or_set_" + + +def copy_or_set_(dest, source): + """ + A workaround to respect strides of :code:`dest` when copying :code:`source` + (https://github.com/geoopt/geoopt/issues/70) + + Parameters + ---------- + dest : torch.Tensor + Destination tensor where to store new data + source : torch.Tensor + Source data to put in the new tensor + + Returns + ------- + dest + torch.Tensor, modified inplace + """ + if dest.stride() != source.stride(): + return dest.copy_(source) + else: + return dest.set_(source) diff --git a/tests/test_adam.py b/tests/test_adam.py index 96a3159e..daa9841f 100644 --- a/tests/test_adam.py +++ b/tests/test_adam.py @@ -26,7 +26,7 @@ def closure(): if (X - Xstar).norm() < 1e-5: break optim.step(closure) - + assert X.is_contiguous() np.testing.assert_allclose(X.data, Xstar, atol=1e-5, rtol=1e-5) optim.load_state_dict(optim.state_dict()) optim.step(closure) @@ -49,5 +49,4 @@ def closure(): for _ in range(2000): optim.step(closure) - np.testing.assert_allclose(start.data, ideal, atol=1e-5, rtol=1e-5) diff --git a/tests/test_rhmc.py b/tests/test_rhmc.py index 2ce01c8f..7c7a25d4 100644 --- a/tests/test_rhmc.py +++ b/tests/test_rhmc.py @@ -116,7 +116,7 @@ def forward(self): points = np.asarray(points) points = points[::20] - + assert nd.x.is_contiguous() np.testing.assert_allclose(mu.numpy(), points.mean(axis=0), atol=1e-1) np.testing.assert_allclose(sigma.numpy(), points.std(axis=0), atol=1e-1) diff --git a/tests/test_rsgd.py b/tests/test_rsgd.py index 6f93c3b7..36db1dda 100644 --- a/tests/test_rsgd.py +++ b/tests/test_rsgd.py @@ -34,7 +34,7 @@ def closure(): if (X - Xstar).norm() < 1e-5: break optim.step(closure) - + assert X.is_contiguous() np.testing.assert_allclose(X.data, Xstar, atol=1e-5) optim.load_state_dict(optim.state_dict()) optim.step(closure) @@ -56,5 +56,6 @@ def test_init_manifold(): opt.zero_grad() opt.step() assert not np.allclose(p0.data, p0old.data) + assert p0.is_contiguous() np.testing.assert_allclose(p1.data, p1old.data) np.testing.assert_allclose(p0.data, stiefel.projx(p0old.data), atol=1e-4)