Add CachedGISTEmbedLoss #2592

Merged: 11 commits merged into UKPLab:master on Apr 16, 2024
Conversation

JacksonCakes
Contributor

As discussed in #2583, this is the GradCache implementation of GISTEmbedLoss. It reduces memory usage while maintaining performance comparable to GISTEmbedLoss.
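For context, here is a minimal usage sketch, assuming the new loss follows GISTEmbedLoss's interface (the model being trained plus a guide model) with an added `mini_batch_size` like the other cached losses; the model names and training data are placeholders, not taken from this PR.

```python
from torch.utils.data import DataLoader

from sentence_transformers import InputExample, SentenceTransformer, losses

model = SentenceTransformer("distilbert-base-uncased")  # model being trained
guide = SentenceTransformer("all-MiniLM-L6-v2")          # small guide model

train_examples = [
    InputExample(texts=["The weather is lovely today.", "It's sunny outside."]),
    InputExample(texts=["He drove to the stadium.", "He took the car to the game."]),
]
# The large batch size is the point: GradCache splits it into `mini_batch_size`
# chunks so activation memory stays bounded while in-batch negatives stay plentiful.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=1024)

train_loss = losses.CachedGISTEmbedLoss(model=model, guide=guide, mini_batch_size=32)
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)
```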

@tomaarsen
Collaborator

Hello!

Thanks a bunch for this! I've been testing this alongside MNRL, CMNRL and normal GIST yesterday and today. It seems to roughly match their performance, though I'm using some simple training & testing data.
I'll assist with finishing up this PR today, which mostly involves making sure that CachedGISTEmbedLoss is mentioned in the documentation.

  • Tom Aarsen

@JacksonCakes
Contributor Author

JacksonCakes commented Apr 15, 2024

Hi! Sorry, I just realized that computing the guide model's similarity over the entire batch, instead of in mini-batches, also adds extra memory usage. That's not ideal: in my own tests I can handle a batch size of up to 4096 with CMNRL, but only 1024 with this loss. I've adjusted the guided part to address this.
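To illustrate the idea, here is a rough, simplified sketch (not the PR's actual code) of computing the guided scores one anchor mini-batch at a time, so no full batch-by-batch intermediate has to live on the GPU alongside the cached embeddings. The function name `guided_scores`, its arguments, and the assumption that all embeddings are already L2-normalized are hypothetical.

```python
import torch
import torch.nn.functional as F


def guided_scores(anchors, positives, guide_anchors, guide_positives,
                  mini_batch_size=32, temperature=0.01):
    """Guided in-batch-negatives loss, computed per anchor mini-batch."""
    batch_size = anchors.size(0)
    losses = []
    for start in range(0, batch_size, mini_batch_size):
        end = min(start + mini_batch_size, batch_size)
        # model and guide similarities for this slice of anchors vs. all positives
        scores = anchors[start:end] @ positives.t() / temperature
        guide_sim = guide_anchors[start:end] @ guide_positives.t()
        # the guide's similarity for each true (anchor, positive) pair in the slice
        positive_sim = guide_sim.diagonal(offset=start).unsqueeze(1)
        # mask candidates the guide rates higher than the true positive:
        # they are likely false negatives
        scores[guide_sim > positive_sim] = float("-inf")
        labels = torch.arange(start, end, device=scores.device)
        losses.append(F.cross_entropy(scores, labels, reduction="sum"))
    return torch.stack(losses).sum() / batch_size


# toy check with normalized random embeddings
a = F.normalize(torch.randn(64, 128), dim=-1)
p = F.normalize(torch.randn(64, 128), dim=-1)
print(guided_scores(a, p, a, p))
```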

@tomaarsen
Collaborator

tomaarsen commented Apr 15, 2024

I think there's a small issue with 3208e61: the loss always seems to be 0. These are some of my logs:
[screenshot: training loss curves for 3208e61 vs. 5c054da]

(Green is 3208e61, Salmon is 5c054da)

Update: The evaluation performance does go up over time, so I suspect the loss is not actually 0; it's just very small (small enough that it rounds to 0.00 in my logs). That said, being so small likely results in underflow/inaccuracies, and the evaluation loss is notably worse than before:
[screenshot: evaluation loss comparison]

Additionally, the memory usage is actually a tad higher:
[screenshot: memory usage comparison]

But that might also be somehow related to the 0 loss.

  • Tom Aarsen

@JacksonCakes
Contributor Author

Ah, my bad. I think it's because I didn't properly offset the diagonal, so the guide mask always selected the diagonal starting from the first element. That can easily introduce -inf values into the scores and lead to strange loss values.
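A small, hypothetical illustration of the offset issue described above (toy tensors, not the PR's code): when anchors are processed in mini-batch slices but compared against all positives, the true-pair entries lie on a diagonal shifted by the slice's start index.

```python
import torch

torch.manual_seed(0)
batch_size, mini_batch_size, dim = 8, 4, 16
guide_anchors = torch.nn.functional.normalize(torch.randn(batch_size, dim), dim=-1)
guide_positives = guide_anchors.clone()  # pretend each positive equals its anchor

start = 4  # slice offset of the second anchor mini-batch
guide_sim = guide_anchors[start:start + mini_batch_size] @ guide_positives.t()

wrong = guide_sim.diagonal()              # reads guide_sim[i, i]: unrelated pairs
right = guide_sim.diagonal(offset=start)  # reads guide_sim[i, start + i]: true pairs

print(wrong)  # generally well below 1.0
print(right)  # all 1.0: the actual anchor/positive similarities
```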

@tomaarsen
Collaborator

After further experimentation, I can confirm that 3215c06 matches the performance (losses, evaluations) of 5c054da exactly, but the former allows for much higher batch sizes. E.g. I was able to set my batch size to an absurd 50k.

Great job!

I'll work on the documentation things that I had mentioned.

  • Tom Aarsen

@JacksonCakes
Contributor Author

Good to know! Thank you for your effort! Happy to help :)

@tomaarsen
Collaborator

I think this is ready! I'll merge it now, so it can be included in tomorrow/Thursday's release. Do feel free to let me know if there were things that you think are missing/suboptimal.

Thanks for your time/work on this!

cc @avsolatorio, @kwang2049

  • Tom Aarsen

@tomaarsen tomaarsen merged commit 38ab549 into UKPLab:master Apr 16, 2024
9 checks passed