With the self-supervised wrapper planned for version 2, the loss function can be used as in the paper by simply applying the wrapper. So from a usability point of view, there's no reason not to have it extend `BaseMetricLossFunction`.
In v2.0, for the sake of consistency, I've converted it to accept the same forward arguments (`embeddings`, `labels`, `indices_tuple`, `ref_emb`, `ref_labels`).
However, the only valid way to call it is still `loss_fn(embeddings, ref_emb=ref_emb)`. Passing in any of the other arguments will raise an exception.
The reason is that I wasn't able to find a good way to compute the various sub-losses in a general way. For example, suppose that instead of passing in `ref_emb`, we passed in `labels = [0, 0, 0, 0, 1, 1, 1]` for a batch of 7 embeddings.
The invariance loss is supposed to be computed between 2 "views" of the inputs. But how would that apply to the labels above? There are 4 views of the 0 label and 3 views of the 1 label.
The variance and covariance losses are supposed to be over 1 view, but how many views are there in the above set of labels?
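To make the discussion concrete, here is a minimal NumPy sketch of the three VICReg terms as defined in the paper, computed on exactly two views `z_a` and `z_b` of shape `(batch_size, dim)`. This is an illustration of why the math assumes a fixed view structure, not the library's implementation; the names `gamma` and `eps` follow the paper's notation.

```python
import numpy as np

def vicreg_terms(z_a, z_b, gamma=1.0, eps=1e-4):
    """Return (invariance, variance, covariance) terms for two views."""
    n, d = z_a.shape
    # Invariance: MSE between the two views -- only defined for exactly 2 views.
    inv = np.mean((z_a - z_b) ** 2)
    # Variance: hinge loss on the per-dimension std of each view separately.
    std_a = np.sqrt(z_a.var(axis=0) + eps)
    std_b = np.sqrt(z_b.var(axis=0) + eps)
    var = np.mean(np.maximum(0.0, gamma - std_a)) + np.mean(np.maximum(0.0, gamma - std_b))
    # Covariance: sum of squared off-diagonal entries of each view's
    # covariance matrix, scaled by the embedding dimension.
    def cov_term(z):
        zc = z - z.mean(axis=0)
        cov = (zc.T @ zc) / (n - 1)
        off_diag = cov - np.diag(np.diag(cov))
        return np.sum(off_diag ** 2) / d
    cov = cov_term(z_a) + cov_term(z_b)
    return inv, var, cov
```

Note that every term hard-codes the two-view (or per-view) structure: the invariance term pairs row `i` of one view with row `i` of the other, and the variance/covariance statistics are taken over one view's batch. A label vector like `[0, 0, 0, 0, 1, 1, 1]` doesn't tell you how to form those pairings or batches.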
If it turns out there's a clever way to generalize VICReg, I'm open to implementing it. I've changed the forward arguments now so that supporting it in the future won't require a breaking change.