3D embedding tensor #565
Comments
Interesting. For now I think this is the best workaround:

```python
from pytorch_metric_learning.utils import common_functions as c_f

c_f.check_shapes = lambda x, y: None
```
Thank you, @KevinMusgrave. I am using the following loss:

```python
class RelevanceLoss(nn.Module):
    def __init__(self, params):
        super(RelevanceLoss, self).__init__()
        self.miner = RelevanceMiner(params.miner)
        # Let Q and D be two batches of matrices (MxN). MaxSimDistance returns a
        # score matrix S, where S[i,j] represents the score between Q[i] and D[j].
        self.criterion = losses.NTXentLoss(
            temperature=params.criterion.temperature, distance=MaxSimDistance()
        )

    def forward(self, query_idx, query_rpr, doc_idx, doc_rpr):
        # query_rpr and doc_rpr are tensors of shape (batch_size, sequence_length, hidden_size)
        miner_outs = self.miner.mine(text_ids=query_idx, label_ids=doc_idx)
        return self.criterion(query_rpr, None, miner_outs, doc_rpr, None)
```

Even after navigating the NTXentLoss code, I couldn't figure out the best place to overwrite the `check_shapes` function.
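The `MaxSimDistance` used above is not shown in the thread. Assuming it is a ColBERT-style late-interaction score (my assumption, not confirmed by the thread), the score matrix S it describes — one score per (query, document) pair of token-embedding matrices — could be sketched like this in NumPy:

```python
import numpy as np

def maxsim_scores(Q, D):
    # Q: (num_queries, q_len, hidden), D: (num_docs, d_len, hidden).
    # Pairwise token similarities for every (query, doc) pair:
    # shape (num_queries, num_docs, q_len, d_len).
    sims = np.einsum("qlh,dmh->qdlm", Q, D)
    # MaxSim: each query token keeps its best-matching doc token,
    # then scores are summed over query tokens -> S of shape (nq, nd).
    return sims.max(axis=-1).sum(axis=-1)

# Tiny example: one query and one document, each with two orthogonal tokens.
Q = np.array([[[1.0, 0.0], [0.0, 1.0]]])
D = np.array([[[1.0, 0.0], [0.0, 1.0]]])
print(maxsim_scores(Q, D))  # [[2.]]
```

Note this consumes the full 3D `(batch, seq_len, hidden)` tensors directly, which is exactly why the library's 2D shape check gets in the way.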
It should work as long as you overwrite it before you use the loss function. See this notebook for an example: https://colab.research.google.com/drive/1jHw9dlFBCAd46CKSaAL4O7vtKQvzxgeG?usp=sharing (It crashes, but it gets past `check_shapes`.)
It seems to be working, thank you.
Hello again @KevinMusgrave. To clarify, does the suggested change have any impact on RAM usage? The loss is converging, but RAM consumption grows with each training step, even though I'm not storing anything and the entire training step runs on the GPU.
No, I don't see why it would affect RAM usage.
In version 1.7.2, I have moved `check_shapes` into the distance classes, so you can override it there:

```python
class CustomDistance(BaseDistance):
    def check_shapes(self, query_emb, ref_emb):
        pass
```
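For readers without the library installed, here is a minimal runnable sketch of that override pattern. The `BaseDistance` below is a hypothetical stand-in reduced to just the shape check, not the real `pytorch_metric_learning.distances.BaseDistance`:

```python
# Hypothetical stand-in for the library's BaseDistance (assumption:
# reduced to the 2D-only shape check discussed in this thread).
class BaseDistance:
    def check_shapes(self, query_emb, ref_emb):
        if query_emb.ndim != 2:
            raise ValueError(
                "embeddings must be a 2D tensor of shape (batch_size, embedding_size)"
            )

class CustomDistance(BaseDistance):
    def check_shapes(self, query_emb, ref_emb):
        # Skip the check so 3D (batch, seq_len, hidden) embeddings pass through.
        pass

class FakeEmb:
    # Minimal object exposing only the attribute the check inspects.
    def __init__(self, ndim):
        self.ndim = ndim

CustomDistance().check_shapes(FakeEmb(3), None)  # no error raised
```

The advantage over the earlier monkey-patch is that the override is scoped to one distance object instead of mutating a shared module function.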
Hello,

I'm getting the

```
ValueError: embeddings must be a 2D tensor of shape (batch_size, embedding_size)
```

message because, indeed, my embedding is a 3D tensor. However, I've provided the distance function. So, is there a workaround for this shape verification?