Skip to content

[MRG] Adding greenkhorn #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Sep 24, 2018
Merged

[MRG] Adding greenkhorn #66

merged 14 commits into from
Sep 24, 2018

Conversation

arakotom
Copy link
Collaborator

  • added greenkhorn algorithm in bregman.py
  • added novel option for resolution
  • modified unit test in test_variants

Copy link
Collaborator

@agramfort agramfort left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hope this helps

ot/bregman.py Outdated
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
>>> ot.sinkhorn(a,b,M,1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ot.greenkhorn

ot/bregman.py Outdated
m = b.shape[0]

# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

K = np.empty_like(M)

ot/bregman.py Outdated
np.divide(M, -reg, out=K)
np.exp(K, out=K)

u = np.ones(n)/n
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.full(n, 1. / n)

ot/bregman.py Outdated
np.exp(K, out=K)

u = np.ones(n)/n
v = np.ones(m)/m
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.full(m, 1. / m)

ot/bregman.py Outdated

u = np.ones(n)/n
v = np.ones(m)/m
G = np.diag(u)@K@np.diag(v)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use broadcasting to avoid filling diagonal matrices

G = u[:, np.newaxis] * K * v[, np.newaxis]

ot/bregman.py Outdated
G[:,i_2] = u*K[:,i_2]*v[i_2]
#aviol = (G@one_m - a)
#aviol_2 = (G.T@one_n - b)
viol = viol + ( -old_v + v[i_2])*K[:,i_2]*u
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

viol += ...


# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, Gerr)

np.testing.assert_allclose(G0, G_green, atol = 1e-32)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a pep8 checker would tell you but you should not put spaces around = in function signatures. It's to visually distinguish what is a function parameter from a variable assignment.


# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, Gerr)

np.testing.assert_allclose(G0, G_green, atol = 1e-32)
print(G0,G_green)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and you should always put a space after a ,

ot/bregman.py Outdated
one_n = np.ones(n)
one_m = np.ones(m)
viol = G@one_m - a
viol_2 = G.T@one_n - b
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here to allocate arrays of ones to compute sum of rows and columns. I would just use np.sum(..., axis=)

ot/bregman.py Outdated
log['u'] = u
log['v'] = v

while i < numItermax and stopThr_val > stopThr:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather than using a while you could use a for loop. For optim solvers I tend to do:

for i in range(numItermax):
      ...
      if stopping condition satisfied do:
              break
else:
     print("Solver did not converge")

so you can easily print a message when you did not converge.

@rflamary rflamary changed the title adding greenkhorn Adding greenkhorn Sep 24, 2018
@rflamary rflamary changed the title Adding greenkhorn [MRG] Adding greenkhorn Sep 24, 2018
ot/bregman.py Outdated

"""

i = 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed, thougt I got rid of all of them...

@rflamary rflamary merged commit 22d310d into master Sep 24, 2018
@rflamary rflamary deleted the greenkhorn branch December 5, 2018 11:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants