-
Notifications
You must be signed in to change notification settings - Fork 528
[MRG] Adding greenkhorn #66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
arakotom
commented
Sep 24, 2018
- added greenkhorn algorithm in bregman.py
- added novel option for resolution
- modified unit test in test_variants
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hope this helps
ot/bregman.py
Outdated
>>> a=[.5,.5] | ||
>>> b=[.5,.5] | ||
>>> M=[[0.,1.],[1.,0.]] | ||
>>> ot.sinkhorn(a,b,M,1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ot.greenkhorn
ot/bregman.py
Outdated
m = b.shape[0] | ||
|
||
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute | ||
K = np.empty(M.shape, dtype=M.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
K = np.empty_like(M)
ot/bregman.py
Outdated
np.divide(M, -reg, out=K) | ||
np.exp(K, out=K) | ||
|
||
u = np.ones(n)/n |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
np.full(n, 1. / n)
ot/bregman.py
Outdated
np.exp(K, out=K) | ||
|
||
u = np.ones(n)/n | ||
v = np.ones(m)/m |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
np.full(m, 1. / m)
ot/bregman.py
Outdated
|
||
u = np.ones(n)/n | ||
v = np.ones(m)/m | ||
G = np.diag(u)@K@np.diag(v) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use broadcasting to avoid filling diagonal matrices
G = u[:, np.newaxis] * K * v[, np.newaxis]
ot/bregman.py
Outdated
G[:,i_2] = u*K[:,i_2]*v[i_2] | ||
#aviol = (G@one_m - a) | ||
#aviol_2 = (G.T@one_n - b) | ||
viol = viol + ( -old_v + v[i_2])*K[:,i_2]*u |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
viol += ...
test/test_bregman.py
Outdated
|
||
# check values | ||
np.testing.assert_allclose(G0, Gs, atol=1e-05) | ||
np.testing.assert_allclose(G0, Ges, atol=1e-05) | ||
np.testing.assert_allclose(G0, Gerr) | ||
|
||
np.testing.assert_allclose(G0, G_green, atol = 1e-32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a pep8 checker would tell you but you should not put spaces around = in function signatures. It's to visually distinguish what is a function parameter from a variable assignment.
test/test_bregman.py
Outdated
|
||
# check values | ||
np.testing.assert_allclose(G0, Gs, atol=1e-05) | ||
np.testing.assert_allclose(G0, Ges, atol=1e-05) | ||
np.testing.assert_allclose(G0, Gerr) | ||
|
||
np.testing.assert_allclose(G0, G_green, atol = 1e-32) | ||
print(G0,G_green) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and you should always put a space after a ,
ot/bregman.py
Outdated
one_n = np.ones(n) | ||
one_m = np.ones(m) | ||
viol = G@one_m - a | ||
viol_2 = G.T@one_n - b |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here to allocate arrays of ones to compute sum of rows and columns. I would just use np.sum(..., axis=)
ot/bregman.py
Outdated
log['u'] = u | ||
log['v'] = v | ||
|
||
while i < numItermax and stopThr_val > stopThr: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rather than using a while you could use a for loop. For optim solvers I tend to do:
for i in range(numItermax):
...
if stopping condition satisfied do:
break
else:
print("Solver did not converge")
so you can easily print a message when you did not converge.
ot/bregman.py
Outdated
|
||
""" | ||
|
||
i = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not needed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indeed, thougt I got rid of all of them...