Skip to content

Commit

Permalink
Resolve inconsistency with tensor strides and optimizer updates (geoopt#71)
Browse files Browse the repository at this point in the history

* initial commit

* fix typo

* more efficient workaround

* fix some typos

* spacing in docs

* update changelog

* more intuitive name
  • Loading branch information
ferrine committed May 21, 2019
1 parent 113d567 commit 8ba6b88
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 24 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ Deprecations
Bug Fixes
---------
* Make pickle work with ManifoldTensors (#47)
* Resolve inconsistency with tensor strides and optimizer updates (#71)
6 changes: 4 additions & 2 deletions geoopt/optim/radam.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .tracing import create_traced_update
from ..tensor import ManifoldParameter, ManifoldTensor
from ..manifolds import Euclidean
from ..utils import copy_or_set_


class RiemannianAdam(OptimMixin, torch.optim.Adam):
Expand Down Expand Up @@ -189,7 +190,8 @@ def perform_step(
new_point, exp_avg_new = manifold.retr_transp(
point, exp_avg, u=direction, t=-step_size
)
point.set_(new_point)
# use copy only for user facing point
copy_or_set_(point, new_point)
exp_avg.set_(exp_avg_new)

def stabilize_group(self, group):
Expand All @@ -202,7 +204,7 @@ def stabilize_group(self, group):
continue
manifold = p.manifold
exp_avg = state["exp_avg"]
p.set_(manifold.projx(p))
copy_or_set_(p, manifold.projx(p))
exp_avg.set_(manifold.proju(p, exp_avg))

def _sanitize_group(self, group):
Expand Down
9 changes: 5 additions & 4 deletions geoopt/optim/rsgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ..tensor import ManifoldParameter, ManifoldTensor
from .mixin import OptimMixin
from .tracing import create_traced_update

from ..utils import copy_or_set_

__all__ = ["RiemannianSGD"]

Expand Down Expand Up @@ -168,10 +168,11 @@ def perform_step(
point, momentum_buffer, u=grad, t=-lr
)
momentum_buffer.set_(new_momentum_buffer)
point.set_(new_point)
# use copy only for user facing point
copy_or_set_(point, new_point)
else:
new_point = manifold.retr(point, grad, -lr)
point.set_(new_point)
copy_or_set_(point, new_point)

def stabilize_group(self, group):
with torch.no_grad():
Expand All @@ -180,7 +181,7 @@ def stabilize_group(self, group):
continue
manifold = p.manifold
momentum = group["momentum"]
p.set_(manifold.projx(p))
copy_or_set_(p, manifold.projx(p))
if momentum > 0:
param_state = self.state[p]
if not param_state: # due to None grads
Expand Down
4 changes: 2 additions & 2 deletions geoopt/samplers/rhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.manifolds import Euclidean
from geoopt.samplers.base import Sampler

from ..utils import copy_or_set_

__all__ = ["RHMC"]

Expand Down Expand Up @@ -40,7 +40,7 @@ def _step(self, p, r, epsilon):

r.add_(epsilon * egrad2rgrad(p, p.grad))
p_, r_ = retr_transp(p, r, u=r, t=epsilon)
p.set_(p_)
copy_or_set_(p, p_)
r.set_(r_)

def step(self, closure):
Expand Down
9 changes: 4 additions & 5 deletions geoopt/samplers/rsgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.manifolds import Euclidean
from geoopt.samplers.base import Sampler

from ..utils import copy_or_set_

__all__ = ["RSGLD"]

Expand Down Expand Up @@ -50,8 +50,8 @@ def step(self, closure):

n = torch.randn_like(p).mul_(math.sqrt(epsilon))
r = egrad2rgrad(p, 0.5 * epsilon * p.grad + n)

p.set_(retr(p, r, 1.0))
# use copy only for user facing point
copy_or_set_(p, retr(p, r, 1.0))
p.grad.zero_()

if not self.burnin:
Expand All @@ -66,5 +66,4 @@ def stabilize(self):
for p in group["params"]:
if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
continue

p.set_(p.manifold.projx(p))
copy_or_set_(p, p.manifold.projx(p))
7 changes: 3 additions & 4 deletions geoopt/samplers/sgrhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.manifolds import Euclidean
from geoopt.samplers.base import Sampler

from ..utils import copy_or_set_

__all__ = ["SGRHMC"]

Expand Down Expand Up @@ -76,7 +76,7 @@ def step(self, closure):
v = self.state[p]["v"]

p_, v_ = retr_transp(p, v, u=v, t=1.0)
p.set_(p_)
copy_or_set_(p, p_)
v.set_(v_)

n = egrad2rgrad(p, torch.randn_like(v))
Expand All @@ -103,7 +103,6 @@ def stabilize(self):

manifold = p.manifold
v = self.state[p]["v"]

p.set_(manifold.projx(p))
copy_or_set_(p, manifold.projx(p))
# proj here is ok
v.set_(manifold.proju(p, v))
6 changes: 3 additions & 3 deletions geoopt/tensor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch.nn
from .manifolds import Euclidean
from .docutils import insert_docs
from .utils import copy_or_set_

__all__ = ["ManifoldTensor", "ManifoldParameter"]

Expand All @@ -27,6 +28,7 @@ def __new__(cls, *args, manifold=Euclidean(), requires_grad=False, **kwargs):
instance.manifold = manifold
return instance

@torch.no_grad()
def proj_(self):
"""
Inplace projection to the manifold
Expand All @@ -36,9 +38,7 @@ def proj_(self):
tensor
same instance
"""
with torch.no_grad():
self.set_(self.manifold.projx(self.data))
return self
return copy_or_set_(self, self.manifold.projx(self))

@insert_docs(Euclidean.retr.__doc__, r"\s+x : .+\n.+", "")
def retr(self, u, t=1.0, order=None):
Expand Down
24 changes: 24 additions & 0 deletions geoopt/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# ``__all__`` must be a sequence of names. A bare string is iterated character
# by character by ``from geoopt.utils import *`` and raises AttributeError.
__all__ = ["copy_or_set_"]


def copy_or_set_(dest, source):
    """
    Write ``source`` into ``dest`` in place while respecting ``dest`` strides.

    A workaround to respect strides of :code:`dest` when copying :code:`source`
    (https://github.com/geoopt/geoopt/issues/70)

    Parameters
    ----------
    dest : torch.Tensor
        Destination tensor where to store new data
    source : torch.Tensor
        Source data to put in the destination tensor

    Returns
    -------
    dest
        torch.Tensor, modified inplace

    Notes
    -----
    Only strides are compared. When they differ, the values are copied with
    ``dest.copy_(source)`` so that ``dest`` keeps its own memory layout.
    When they match, ``dest.set_(source)`` is used instead, which avoids a
    copy but makes ``dest`` share storage with ``source`` afterwards.
    """
    if dest.stride() != source.stride():
        # Layouts differ: element-wise copy keeps dest's strides intact.
        return dest.copy_(source)
    else:
        # Layouts agree: cheaply re-point dest at source's storage.
        return dest.set_(source)
3 changes: 1 addition & 2 deletions tests/test_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def closure():
if (X - Xstar).norm() < 1e-5:
break
optim.step(closure)

assert X.is_contiguous()
np.testing.assert_allclose(X.data, Xstar, atol=1e-5, rtol=1e-5)
optim.load_state_dict(optim.state_dict())
optim.step(closure)
Expand All @@ -49,5 +49,4 @@ def closure():

for _ in range(2000):
optim.step(closure)

np.testing.assert_allclose(start.data, ideal, atol=1e-5, rtol=1e-5)
2 changes: 1 addition & 1 deletion tests/test_rhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def forward(self):

points = np.asarray(points)
points = points[::20]

assert nd.x.is_contiguous()
np.testing.assert_allclose(mu.numpy(), points.mean(axis=0), atol=1e-1)
np.testing.assert_allclose(sigma.numpy(), points.std(axis=0), atol=1e-1)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_rsgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def closure():
if (X - Xstar).norm() < 1e-5:
break
optim.step(closure)

assert X.is_contiguous()
np.testing.assert_allclose(X.data, Xstar, atol=1e-5)
optim.load_state_dict(optim.state_dict())
optim.step(closure)
Expand All @@ -56,5 +56,6 @@ def test_init_manifold():
opt.zero_grad()
opt.step()
assert not np.allclose(p0.data, p0old.data)
assert p0.is_contiguous()
np.testing.assert_allclose(p1.data, p1old.data)
np.testing.assert_allclose(p0.data, stiefel.projx(p0old.data), atol=1e-4)

0 comments on commit 8ba6b88

Please sign in to comment.