Skip to content

Commit

Permalink
[MRG] Update pymanopt requirement and API for ot.dr (#443)
Browse files Browse the repository at this point in the history
* update pymanopt API step 1

* add release information

* update requirements for tests on windows
  • Loading branch information
rflamary committed Mar 9, 2023
1 parent a6d5d75 commit 263a36f
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 20 deletions.
3 changes: 1 addition & 2 deletions .github/requirements_test_windows.txt
Expand Up @@ -3,8 +3,7 @@ scipy>=1.3
cython
matplotlib
autograd
pymanopt==0.2.4; python_version <'3'
pymanopt==0.2.6rc1; python_version >= '3'
pymanopt
cvxopt
scikit-learn
pytest
2 changes: 2 additions & 0 deletions RELEASES.md
Expand Up @@ -15,6 +15,8 @@
- Backend version of `ot.partial` and `ot.smooth` (PR #388)
- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
- Add parameters method in `ot.da.SinkhornTransport` (PR #440)
- `ot.dr` now uses the new Pymanopt API and POT is compatible with current
Pymanopt (PR #443)

#### Closed issues

Expand Down
3 changes: 1 addition & 2 deletions docs/requirements_rtd.txt
Expand Up @@ -9,7 +9,6 @@ scipy>=1.0
cython
matplotlib
autograd
pymanopt==0.2.4; python_version <'3'
pymanopt; python_version >= '3'
pymanopt
cvxopt
scikit-learn
30 changes: 16 additions & 14 deletions ot/dr.py
Expand Up @@ -17,10 +17,10 @@

from scipy import linalg
import autograd.numpy as np
from pymanopt.function import Autograd
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions

import pymanopt
import pymanopt.manifolds
import pymanopt.optimizers


def dist(x1, x2):
Expand All @@ -38,8 +38,8 @@ def sinkhorn(w1, w2, M, reg, k):
ui = np.ones((M.shape[0],))
vi = np.ones((M.shape[1],))
for i in range(k):
vi = w2 / (np.dot(K.T, ui))
ui = w1 / (np.dot(K, vi))
vi = w2 / (np.dot(K.T, ui) + 1e-50)
ui = w1 / (np.dot(K, vi) + 1e-50)
G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
return G

Expand Down Expand Up @@ -222,7 +222,9 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
else:
regmean = np.ones((len(xc), len(xc)))

@Autograd
manifold = pymanopt.manifolds.Stiefel(d, p)

@pymanopt.function.autograd(manifold)
def cost(P):
# wda loss
loss_b = 0
Expand All @@ -243,21 +245,21 @@ def cost(P):
return loss_w / loss_b

# declare manifold and problem
manifold = Stiefel(d, p)
problem = Problem(manifold=manifold, cost=cost)

problem = pymanopt.Problem(manifold=manifold, cost=cost)

# declare solver and solve
if solver is None:
solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
solver = pymanopt.optimizers.SteepestDescent(max_iterations=maxiter, log_verbosity=verbose)
elif solver in ['tr', 'TrustRegions']:
solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)
solver = pymanopt.optimizers.TrustRegions(max_iterations=maxiter, log_verbosity=verbose)

Popt = solver.solve(problem, x=P0)
Popt = solver.run(problem, initial_point=P0)

def proj(X):
return (X - mx.reshape((1, -1))).dot(Popt)
return (X - mx.reshape((1, -1))).dot(Popt.point)

return Popt, proj
return Popt.point, proj


def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Expand Up @@ -2,8 +2,7 @@ numpy>=1.20
scipy>=1.3
matplotlib
autograd
pymanopt==0.2.4; python_version <'3'
pymanopt==0.2.6rc1; python_version >= '3'
pymanopt
cvxopt
scikit-learn
torch
Expand Down

0 comments on commit 263a36f

Please sign in to comment.