Compilation takes way more time in new version #21250
Comments
Well, we'd love to dig into why. However, we'll need a way to reproduce the problem. Given it is apparently due to XLA compilation, it's possible an HLO dump will be enough for us. You can grab one by setting
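The exact setting is cut off in this thread. One common way to get an HLO dump is the XLA flag `--xla_dump_to`, passed through the `XLA_FLAGS` environment variable before JAX initializes its backend. The sketch below assumes that is the flag meant here; the helper name `with_xla_dump` is hypothetical.

```python
import os


def with_xla_dump(dump_dir, env=None):
    """Return a copy of `env` whose XLA_FLAGS asks XLA to dump HLO to `dump_dir`.

    Assumption: `--xla_dump_to` is the intended setting. Other flags already
    present in XLA_FLAGS are kept; a previous --xla_dump_to is replaced.
    """
    env = dict(os.environ if env is None else env)
    flags = [
        flag
        for flag in env.get("XLA_FLAGS", "").split()
        if not flag.startswith("--xla_dump_to=")
    ]
    flags.append(f"--xla_dump_to={dump_dir}")
    env["XLA_FLAGS"] = " ".join(flags)
    return env
```

The returned mapping can be passed as `env=` to `subprocess.run` when launching the script under investigation, or the same flag can be exported in the shell before starting Python. Setting it after JAX has already compiled something will most likely have no effect on earlier compilations.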
Thank you for taking a look into this. I also checked with 0.4.27, and the compilation time is the same as with 0.4.28; maybe this is helpful to you. I hope I did this correctly and am not wasting your time. Here is the zip with two folders for the two different versions: xla_dump.zip
I took a brief look, and in the 0.4.26 version the last file that was generated was
Moreover, in the 0.4.28 dump you can see all the files generated after the initial compilation, and apparently this happens in a
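A quick way to compare two dump folders like the ones in the zip is to count the files, sum their sizes, and find the most recently written one. This is only a rough sketch using the standard library; the function names are hypothetical, and it makes no assumption about how XLA names its dump files.

```python
from pathlib import Path


def summarize_dump(dump_dir):
    """Count files, total size, and most recently modified file under dump_dir."""
    files = [p for p in Path(dump_dir).rglob("*") if p.is_file()]
    newest = max(files, key=lambda p: p.stat().st_mtime, default=None)
    return {
        "n_files": len(files),
        "total_bytes": sum(p.stat().st_size for p in files),
        "newest": newest.name if newest is not None else None,
    }


def compare_dumps(old_dir, new_dir):
    """Summarize both dump folders and report how many more files the new one has."""
    old, new = summarize_dump(old_dir), summarize_dump(new_dir)
    return {"old": old, "new": new, "extra_files": new["n_files"] - old["n_files"]}
```

Calling `compare_dumps` on the two version folders (whatever they are named inside the zip) gives a first indication of whether the newer version compiles more modules or produces much larger ones.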
Any updates on this? I appreciate any help, thank you!
Taking a look at different They are approximately
Checking all possible options for:
jax.config.update("jax_threefry_partitionable", True)
jax.config.update("jax_default_prng_impl", "threefry2x32")
Still the same long compilation time.
Description
Hey JAX team, thanks for the amazing work you are all doing with this project.
I have been working with JAX for quite some time now, and ever since the 0.4.27 update I have been getting 5x to 10x longer compilation times on my code, without any changes to it.
It's quite a bit of code, and I don't know how you can reproduce what I am seeing, so I will just report some numbers from my machine using JAX_LOG_COMPILES:
JAX version 0.4.26:
JAX version 0.4.27 and 0.4.28:
Again, this is the same code with exactly the same dependencies; the only difference is the JAX version. I also noticed that recompilation is triggered for parts of my code in version 0.4.28.
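To get comparable numbers across JAX versions without sharing the whole code base, one option is to time tracing/lowering and XLA compilation separately through the ahead-of-time API. The following is a minimal sketch, not the code from this report: `toy` is a hypothetical stand-in function and `time_compile` is a hypothetical helper.

```python
import time

import jax
import jax.numpy as jnp

# Config-option equivalent of setting the JAX_LOG_COMPILES environment variable;
# it logs each compilation, which also makes unexpected recompilations visible.
jax.config.update("jax_log_compiles", True)


def time_compile(fn, *args):
    """Return (lowering_seconds, compile_seconds, compiled) for fn on args."""
    t0 = time.perf_counter()
    lowered = jax.jit(fn).lower(*args)  # trace and lower to HLO/StableHLO
    t1 = time.perf_counter()
    compiled = lowered.compile()  # XLA compilation only
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1, compiled


def toy(x):
    # Hypothetical stand-in for the real model code.
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((64, 64))
lower_s, compile_s, compiled = time_compile(toy, x)
print(f"jax {jax.__version__}: lowering {lower_s:.3f}s, compile {compile_s:.3f}s")
```

Running the same script under each installed version shows whether the extra time is spent in tracing and lowering or inside XLA itself, which narrows down where a regression might sit.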
What do you guys think?
System info (python version, jaxlib version, accelerator, etc.)