This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Reduce mem footprint of get_latent_representation #84

Merged: 3 commits into main from latent-repi on Feb 14, 2024
Conversation

@martinkim0 (Member) commented Feb 7, 2024

Reduces the memory footprint of this method: previously, embeddings were not transferred to the CPU until all batches had been processed. This technically increases runtime, since there is now a device_get call per batch instead of a single larger one at the end, and I believe jax.device_get blocks computation. Still, I think this is a better alternative than running out of GPU memory, and the runtime increase can be mitigated by increasing the batch size.
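The pattern described above can be sketched as follows. This is a hedged illustration, not the actual scvi_v2 code: `encode` is a hypothetical stand-in for the model's inference step, and the key point is that `jax.device_get` inside the loop blocks until each batch's result is ready and copies it to host memory, so at most one batch's embedding lives on the device at a time.

```python
import jax
import jax.numpy as jnp
import numpy as np

def encode(batch):
    # Hypothetical stand-in for the model's per-batch inference step.
    return jnp.tanh(batch @ jnp.ones((batch.shape[1], 4)))

def get_latent(batches):
    reps = []
    for batch in batches:
        z = encode(batch)
        # Per-batch host transfer: jax.device_get blocks and returns a
        # NumPy array, freeing the device buffer before the next batch.
        reps.append(jax.device_get(z))
    return np.concatenate(reps, axis=0)

batches = [jnp.ones((8, 16)) for _ in range(3)]
latent = get_latent(batches)
print(latent.shape)  # (24, 4)
```

The alternative the PR moves away from would collect the `z` arrays on device and call `jax.device_get` once at the end, which is faster (fewer synchronization points) but holds every batch's embedding in GPU memory simultaneously.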

@martinkim0 (Member, Author)

Tests currently failing due to an issue in scvi-tools, which has been fixed. Just need to make sure the CI here installs from main.

@canergen (Member) commented Feb 7, 2024

As an alternative solution: in resolVI, I have a second batch size defining macrobatches, and each macrobatch is copied to host (and the median etc. is computed). We could do that here, but I expect the time difference to be very small, so I wouldn't bother.

@ebezzi (Collaborator) left a comment

This solved the OOM issue, at least for my specific case. Approving!

Review comment on src/scvi_v2/_model.py (outdated, resolved)
@martinkim0 merged commit d43af68 into main on Feb 14, 2024
1 of 8 checks passed
@martinkim0 deleted the latent-repi branch on February 14, 2024 at 21:28
3 participants