Skip to content

[MRG] Update pymanopt requirement and API for ot.dr #443

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/requirements_test_windows.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ scipy>=1.3
cython
matplotlib
autograd
pymanopt==0.2.4; python_version <'3'
pymanopt==0.2.6rc1; python_version >= '3'
pymanopt
cvxopt
scikit-learn
pytest
2 changes: 2 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
- Backend version of `ot.partial` and `ot.smooth` (PR #388)
- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
- Add parameters method in `ot.da.SinkhornTransport` (PR #440)
- `ot.dr` now uses the new Pymanopt API and POT is compatible with current
Pymanopt (PR #443)

#### Closed issues

Expand Down
3 changes: 1 addition & 2 deletions docs/requirements_rtd.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ scipy>=1.0
cython
matplotlib
autograd
pymanopt==0.2.4; python_version <'3'
pymanopt; python_version >= '3'
pymanopt
cvxopt
scikit-learn
30 changes: 16 additions & 14 deletions ot/dr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

from scipy import linalg
import autograd.numpy as np
from pymanopt.function import Autograd
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions

import pymanopt
import pymanopt.manifolds
import pymanopt.optimizers


def dist(x1, x2):
Expand All @@ -38,8 +38,8 @@ def sinkhorn(w1, w2, M, reg, k):
ui = np.ones((M.shape[0],))
vi = np.ones((M.shape[1],))
for i in range(k):
vi = w2 / (np.dot(K.T, ui))
ui = w1 / (np.dot(K, vi))
vi = w2 / (np.dot(K.T, ui) + 1e-50)
ui = w1 / (np.dot(K, vi) + 1e-50)
G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
return G

Expand Down Expand Up @@ -222,7 +222,9 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
else:
regmean = np.ones((len(xc), len(xc)))

@Autograd
manifold = pymanopt.manifolds.Stiefel(d, p)

@pymanopt.function.autograd(manifold)
def cost(P):
# wda loss
loss_b = 0
Expand All @@ -243,21 +245,21 @@ def cost(P):
return loss_w / loss_b

# declare manifold and problem
manifold = Stiefel(d, p)
problem = Problem(manifold=manifold, cost=cost)

problem = pymanopt.Problem(manifold=manifold, cost=cost)

# declare solver and solve
if solver is None:
solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
solver = pymanopt.optimizers.SteepestDescent(max_iterations=maxiter, log_verbosity=verbose)
elif solver in ['tr', 'TrustRegions']:
solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)
solver = pymanopt.optimizers.TrustRegions(max_iterations=maxiter, log_verbosity=verbose)

Popt = solver.solve(problem, x=P0)
Popt = solver.run(problem, initial_point=P0)

def proj(X):
return (X - mx.reshape((1, -1))).dot(Popt)
return (X - mx.reshape((1, -1))).dot(Popt.point)

return Popt, proj
return Popt.point, proj


def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ numpy>=1.20
scipy>=1.3
matplotlib
autograd
pymanopt==0.2.4; python_version <'3'
pymanopt==0.2.6rc1; python_version >= '3'
pymanopt
cvxopt
scikit-learn
torch
Expand Down