
Feature/enable l2 dense solver #29

Merged: 12 commits into main on Jun 19, 2023
Conversation

alexisthual (Owner)

This PR:

  • enables fitting fugw problems with an L2 norm using the existing dense solver
  • implements a sparse solver for fugw problems with an L2 norm and enables it (see the sketch below)
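
For context, unbalanced OT solvers penalize the gap between the coupling's marginals and the input measures; this PR adds a squared-L2 alternative to the usual KL penalty. Below is a minimal dense sketch of the two penalties; the names `kl_penalty` and `l2_penalty` and the toy tensors are illustrative, not the actual fugw API:

```python
import torch

def kl_penalty(p, q):
    # Generalized KL between positive measures:
    # <p, log(p / q)> - mass(p) + mass(q)
    return (p * (p / q).log()).sum() - p.sum() + q.sum()

def l2_penalty(p, q):
    # Squared-L2 alternative (illustrative form)
    return ((p - q) ** 2).sum()

pi = torch.rand(4, 5)
pi = pi / pi.sum()             # toy coupling plan
w_s = torch.full((4,), 1 / 4)  # source weights
print(kl_penalty(pi.sum(dim=1), w_s))  # penalize first marginal vs w_s
print(l2_penalty(pi.sum(dim=1), w_s))
```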

@bthirion (Collaborator) left a comment:

Sounds good overall. Just left a few comments.

Resolved review threads:
  • src/fugw/solvers/dense.py (2)
  • src/fugw/solvers/utils.py (2)
```diff
@@ -938,18 +1077,96 @@ def elementwise_prod_fact_sparse(a, b, p):
     )


 def compute_approx_kl(p, q):
```
@bthirion (Collaborator):

Not really clear what approx_kl is supposed to be (ref?).

Again, for public functions, add at least a 1-line docstring.

@alexisthual (Owner, author):

Indeed, this naming is misleading. The idea is that what we are computing here is not exactly a KL divergence, because p and q are not necessarily probability measures. @6Ulm and I suggest we replace _approx_ with _unnormalized_ or _generalized_. What do you think?

@6Ulm (Collaborator), Jun 13, 2023:

To clarify, the true KL divergence between any two positive vectors is KL(p | q) = ⟨p, log(p / q)⟩ - mass(p) + mass(q). That's why, at the beginning, we denoted the first term approx_kl, as an approximation of the KL.
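
A minimal sketch of this decomposition (illustrative code, not the fugw source; `compute_generalized_kl` is a hypothetical name):

```python
import torch

def compute_approx_kl(p, q):
    # First term only: <p, log(p / q)>.
    # Not a true divergence on its own when p and q are unnormalized.
    return (p * (p / q).log()).sum()

def compute_generalized_kl(p, q):
    # True KL between positive vectors:
    # <p, log(p / q)> - mass(p) + mass(q)
    return compute_approx_kl(p, q) - p.sum() + q.sum()

p = torch.tensor([0.2, 0.5, 0.1])  # positive, not normalized
q = torch.tensor([0.3, 0.3, 0.3])
assert compute_generalized_kl(p, q) >= 0  # generalized KL is non-negative
```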

Collaborator:
+1 for unnormalized

@alexisthual (Owner, author):
Yaaay, unnormalized it is!

Resolved review threads:
  • src/fugw/solvers/utils.py (3)
@bthirion (Collaborator) left a comment:
LGTM.

@pbarbarant (Collaborator) left a comment:
LGTM, but this PR makes me think that we should refactor the global structure of the codebase at some point.

```diff
-loss_wasserstein = csr_sum(
-    elementwise_prod_fact_sparse(K1, K2, pi + gamma)
+loss_wasserstein = (
+    csr_sum(elementwise_prod_fact_sparse(K1, K2, pi + gamma)) / 2
```
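
The change halves the sum so the loss reads as an inner product between the cost and the average of the two coupling plans, (pi + gamma) / 2, rather than their sum. A dense sketch of the intended computation, assuming the cost matrix is factored as K1 @ K2 (the sparse helpers csr_sum and elementwise_prod_fact_sparse are internal to fugw, so this is an equivalent dense rewrite, not the actual code):

```python
import torch

def dense_wasserstein_loss(K1, K2, pi, gamma):
    # Dense equivalent of
    # csr_sum(elementwise_prod_fact_sparse(K1, K2, pi + gamma)) / 2:
    # inner product between the factored cost K1 @ K2 and the
    # average of the two coupling plans.
    cost = K1 @ K2
    return (cost * (pi + gamma)).sum() / 2

n, m, r = 6, 5, 3
K1, K2 = torch.rand(n, r), torch.rand(r, m)
pi, gamma = torch.rand(n, m), torch.rand(n, m)
print(dense_wasserstein_loss(K1, K2, pi, gamma))
```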
Collaborator:
👍

@alexisthual merged commit 4285baa into main on Jun 19, 2023. 8 checks passed.
@alexisthual deleted the feature/enable_l2_dense_solver branch on June 19, 2023 at 20:46.
Labels: none yet
Projects: none yet
Linked issues: none yet
4 participants