Feature/enable l2 dense solver #29
Conversation
Sounds good overall. Just left a few comments.
src/fugw/solvers/utils.py (outdated)

@@ -938,18 +1077,96 @@ def elementwise_prod_fact_sparse(a, b, p):
    )


def compute_approx_kl(p, q):
Not really clear what `approx_kl` is supposed to be (ref?).
Again, for public functions, add at least a one-line docstring.
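For instance, something along these lines would do (a hypothetical wording, not a prescription):

```python
def compute_approx_kl(p, q):
    """Compute <p, log(p / q)>, the entropic term of the KL divergence."""
```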
Indeed, this naming is misleading. The idea is that what we are computing here is not exactly a KL divergence, because p and q are not necessarily probability measures. @6Ulm and I suggest we replace `approx` with `unnormalized` or `generalized`. What do you think?
To clarify, the true KL divergence between any two positive vectors is KL(p | q) = <p, log(p / q)> - mass(p) + mass(q). That's why, at the beginning, we denoted the first term `approx_kl`, as an approximation of the KL.
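Concretely, here is a minimal sketch of the two quantities (hypothetical names, assuming torch tensors as in the rest of the solver):

```python
import torch

def compute_unnormalized_kl(p, q):
    """Compute <p, log(p / q)> for positive tensors p and q."""
    # Convention 0 * log(0) = 0: nan_to_num clears the 0 * (-inf) NaNs.
    return torch.nan_to_num(p * (p / q).log()).sum()

def compute_true_kl(p, q):
    """True KL between positive vectors: <p, log(p / q)> - mass(p) + mass(q)."""
    return compute_unnormalized_kl(p, q) - p.sum() + q.sum()

# For probability measures the masses cancel and the two quantities coincide:
p = torch.tensor([0.2, 0.5, 0.3])
q = torch.tensor([0.3, 0.3, 0.4])
assert torch.isclose(compute_unnormalized_kl(p, q), compute_true_kl(p, q))
```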
+1 for `unnormalized`
Yaaay, `unnormalized` it is!
LGTM.
LGTM, but this PR makes me think that we should refactor the global structure of the codebase at some point.
-        loss_wasserstein = csr_sum(
-            elementwise_prod_fact_sparse(K1, K2, pi + gamma)
+        loss_wasserstein = (
+            csr_sum(elementwise_prod_fact_sparse(K1, K2, pi + gamma)) / 2
👍
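For context, the division by 2 turns the sum of the transported costs over the two plans pi and gamma into their average, since <K, pi + gamma> / 2 = (<K, pi> + <K, gamma>) / 2. A dense illustration of the diffed line (a hypothetical helper, assuming the cost is factored as K = K1 @ K2 and bypassing the sparse csr_sum / elementwise_prod_fact_sparse machinery):

```python
import torch

def loss_wasserstein_dense(K1, K2, pi, gamma):
    # Rebuild the full cost from its low-rank factors.
    K = K1 @ K2
    # <K, pi + gamma> / 2, i.e. the mean of <K, pi> and <K, gamma>.
    return (K * (pi + gamma)).sum() / 2
```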
This PR: