Memory leak issues with lax.scan #73

Closed
ericmjl opened this issue Sep 16, 2020 · 0 comments · Fixed by #74

ericmjl (Collaborator) commented Sep 16, 2020

Putting up this issue so that I can keep track of progress.

This is related to google/jax#3348, where the OP used a pattern very similar to the one we use in unirep:

        from functools import partial
        from jax import lax
        fx = partial(apply_fun_scan, p1)  # builds a new partial object on every call
        _, ht_new = lax.scan(fx, p2, inputs)

I suspect we need to get rid of this pattern to avoid memory leaks with JAX. (I think this may be what crashed a colleague's machine at work.) I proposed an idea in a comment on google/jax#3348, which could be a starting point.
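One way to avoid building a fresh `functools.partial` on every call is to thread the parameters through the scan carry, so that `lax.scan` always receives the same module-level step function. This is a minimal sketch that assumes the leak comes from per-call caching keyed on the (always new) `partial` object. That is an assumption, not a confirmed diagnosis, and this is not necessarily what the fix in #74 does. `apply_fun_scan` here is a hypothetical tanh RNN step, not the unirep implementation.

```python
import jax
import jax.numpy as jnp
from jax import lax


def apply_fun_scan(params, carry, x):
    """Hypothetical RNN-style step (NOT unirep's): one hidden-state update per input row."""
    W, b = params
    h = jnp.tanh(carry @ W + x + b)
    return h, h  # (new carry, per-step output)


def scan_step(carry, x):
    # Module-level step function: its identity is stable across calls,
    # unlike a partial(apply_fun_scan, params) created inside each call.
    params, h = carry
    h_new, out = apply_fun_scan(params, h, x)
    # Parameters are returned unchanged so the carry keeps the same structure.
    return (params, h_new), out


@jax.jit
def run_with_params_in_carry(params, h0, inputs):
    # jax.jit caches the compiled function per input shape/dtype, so this
    # Python body (and the scan trace) is re-run only when those change.
    (_, h_final), hs = lax.scan(scan_step, (params, h0), inputs)
    return h_final, hs
```

The outputs should match the original `partial` + `lax.scan` pattern. Only the way the parameters reach the step function changes.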
