Make VICRegLoss use BaseMetricLossFunction #560

Closed
KevinMusgrave opened this issue Dec 24, 2022 · 2 comments
Labels: enhancement (New feature or request)
Milestone: v2.0

Comments

KevinMusgrave (Owner) commented Dec 24, 2022

With the self-supervised wrapper planned for version 2, the loss function can be used as in the paper simply by wrapping it. So from a usability standpoint, there's no reason for it not to extend BaseMetricLossFunction.

KevinMusgrave added the enhancement (New feature or request) label Dec 24, 2022
KevinMusgrave added this to the v2.0 milestone Dec 24, 2022
TKassis commented Dec 24, 2022

This is great! Yeah, there really is no clear reason why the VICRegLoss can't follow the convention of the other losses.

KevinMusgrave (Owner, Author) commented Jan 29, 2023

In v2.0, for the sake of consistency, I've converted it to take the same forward arguments as the other losses: `(embeddings, labels, indices_tuple, ref_emb, ref_labels)`.

However, the only valid way to use it is still `loss_fn(embeddings, ref_emb=ref_emb)`. Passing in any of the other arguments raises an exception.
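The rule above amounts to an argument check. The sketch below is hypothetical: `check_vicreg_args` is not part of the library, and raising `ValueError` is an assumption, since the comment only says that an exception is raised.

```python
def check_vicreg_args(embeddings, labels=None, indices_tuple=None,
                      ref_emb=None, ref_labels=None):
    # Hypothetical helper: accepts the same forward arguments as the other
    # losses, but only the (embeddings, ref_emb) combination is allowed.
    unsupported = {
        "labels": labels,
        "indices_tuple": indices_tuple,
        "ref_labels": ref_labels,
    }
    passed = [name for name, value in unsupported.items() if value is not None]
    if passed:
        raise ValueError(f"VICRegLoss does not support: {', '.join(passed)}")
    # ref_emb carries the second "view" of the batch, so it is mandatory.
    if ref_emb is None:
        raise ValueError("VICRegLoss requires ref_emb, the second view of the batch")


check_vicreg_args([[0.0, 1.0]], ref_emb=[[0.5, 1.0]])  # accepted, returns None
```

Keeping the full signature while rejecting the extra arguments means a later generalization can start accepting them without changing how existing calls are written.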

The reason is that I couldn't find a good way to compute the individual sub-losses in general. For example, suppose that instead of passing in `ref_emb`, we passed in `labels = [0, 0, 0, 0, 1, 1, 1]` for a batch of 7 embeddings.

  • The invariance loss is supposed to be computed between two "views" of the inputs. How would that apply to the labels above? There are 4 views of label 0 and 3 views of label 1.
  • The variance and covariance losses are supposed to be computed over a single view, but how many views does this set of labels contain?
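To make the two bullets concrete, here is a minimal NumPy sketch of the three VICReg terms, written from the formulas in the VICReg paper rather than from this library's code. The function names, the default weights (`lam=25`, `mu=25`, `nu=1`, the paper's defaults) and `eps=1e-4` are assumptions for illustration.

```python
import numpy as np


def _variance_term(z, gamma=1.0, eps=1e-4):
    # Hinge on the per-dimension standard deviation of ONE view,
    # averaged over the d embedding dimensions.
    std = np.sqrt(z.var(axis=0, ddof=1) + eps)
    return float(np.mean(np.maximum(0.0, gamma - std)))


def _covariance_term(z):
    # Sum of squared off-diagonal covariance entries of ONE view, divided by d.
    n, d = z.shape
    zc = z - z.mean(axis=0)
    cov = (zc.T @ zc) / (n - 1)
    off_diag = cov - np.diag(np.diag(cov))
    return float(np.sum(off_diag ** 2) / d)


def vicreg_terms(z_a, z_b, lam=25.0, mu=25.0, nu=1.0):
    z_a = np.asarray(z_a, dtype=float)
    z_b = np.asarray(z_b, dtype=float)
    # Invariance pairs row i of z_a with row i of z_b, so the two views
    # must be aligned sample by sample and have identical (n, d) shapes.
    if z_a.shape != z_b.shape:
        raise ValueError("VICReg expects two aligned views of shape (n, d)")
    invariance = float(np.mean((z_a - z_b) ** 2))
    variance = _variance_term(z_a) + _variance_term(z_b)
    covariance = _covariance_term(z_a) + _covariance_term(z_b)
    total = lam * invariance + mu * variance + nu * covariance
    return {"invariance": invariance, "variance": variance,
            "covariance": covariance, "total": total}


# Identical, decorrelated views with std > 1 per dimension: every term is zero.
z = np.array([[1, 1], [1, -1], [-1, 1], [-1, -1]], dtype=float)
print(vicreg_terms(z, z)["total"])  # 0.0
```

Only the invariance term couples the two arrays, and it does so row by row. With `labels = [0, 0, 0, 0, 1, 1, 1]` there is no second array whose row i is "the other view" of row i, and the per-view variance and covariance terms need to know how the batch splits into views, which the labels alone do not say.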

If it turns out there's a clever way to generalize VICReg, I'm open to implementing it. I've changed the forward arguments now so that future changes won't require a breaking change.
