Compilation takes way more time in new version #21250

Open
stergiosba opened this issue May 15, 2024 · 5 comments
Labels
bug (Something isn't working), needs info (More information is required to diagnose & prioritize the issue), NVIDIA GPU (Issues specific to NVIDIA GPUs), XLA

Comments

@stergiosba

stergiosba commented May 15, 2024

Description

Hey JAX team, thanks for the amazing work you are all doing with this project.

I have been working with JAX for quite some time now, and ever since the 0.4.27 update I have been seeing 5x to 10x longer compilation times on my code, without any changes to it.

It's quite a bit of code, and I don't know how you can reproduce what I am seeing, so I will just report some numbers from my machine using JAX_LOG_COMPILES:

Jax Version 0.4.26:

Finished jaxpr to MLIR module conversion jit(train) in 2.1959640979766846 sec
Finished XLA compilation of jit(train) in 21.94405436515808 sec

Jax Version 0.4.27 and 0.4.28:

Finished jaxpr to MLIR module conversion jit(train) in 2.4314699172973633 sec
Finished XLA compilation of jit(train) in 97.583487033844 sec

Again, this is the same code with exactly the same dependencies; the only difference is the JAX version. I also noticed that recompilation is triggered for parts of my code in version 0.4.28.

What do you guys think?
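
For readers who want to collect comparable numbers on their own code, here is a minimal sketch that times the two phases reported in the log lines above by using JAX's ahead-of-time `lower` and `compile` steps. The `train` function, its inputs, and the array shape are hypothetical stand-ins, not the code from this report.

```python
import time

import jax
import jax.numpy as jnp

# Same effect as setting the JAX_LOG_COMPILES=1 environment variable.
jax.config.update("jax_log_compiles", True)


def train(key, x):
    # Hypothetical stand-in for a real training step.
    noise = jax.random.uniform(key, x.shape)
    return jnp.sum((x + noise) ** 2)


key = jax.random.PRNGKey(0)
x = jnp.ones((128, 128))

t0 = time.perf_counter()
lowered = jax.jit(train).lower(key, x)  # tracing + jaxpr to MLIR conversion
t1 = time.perf_counter()
compiled = lowered.compile()            # XLA compilation
t2 = time.perf_counter()

print(f"trace + lower: {t1 - t0:.3f} s")
print(f"XLA compile:   {t2 - t1:.3f} s")
print(compiled(key, x))
```

Running the same script under each installed JAX version gives a side-by-side comparison of the XLA compilation phase alone.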

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='apollo', release='6.5.0-26-generic', version='#26~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Mar 12 10:22:43 UTC 2', machine='x86_64')


$ nvidia-smi
Wed May 15 15:53:39 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3070 Ti     Off |   00000000:01:00.0 Off |                  N/A |
| 33%   42C    P2             71W /  290W |     904MiB /   8192MiB |     19%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      1850      G   /usr/lib/xorg/Xorg                            365MiB |
|    0   N/A  N/A      2091      G   /usr/bin/gnome-shell                           70MiB |
|    0   N/A  N/A    622325      G   ...irefox/4090/usr/lib/firefox/firefox          0MiB |
|    0   N/A  N/A    650168      G   ...erProcess --variations-seed-version         59MiB |
|    0   N/A  N/A    841800      C   python                                        160MiB |
+-----------------------------------------------------------------------------------------+
stergiosba added the bug label May 15, 2024
@hawkinsp
Member

Well, we'd love to dig into why. However, we'll need a way to reproduce the problem.

Given it is apparently due to XLA compilation, it's possible an HLO dump will be enough for us. You can grab one by setting XLA_FLAGS=--xla_dump_to=/somewhere, zipping up somewhere and attaching it to this bug. Or you can give a self-contained runnable Python repro. Up to you.
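
As a rough illustration of the dump workflow described above, the sketch below sets `--xla_dump_to` before JAX is imported, runs one jitted function, and zips the resulting directory. The temporary directory and the toy function `f` are assumptions made for the example; in practice you would run your own program instead.

```python
import os
import shutil
import tempfile

# Hypothetical dump location; any writable directory works.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")

# XLA reads this variable when the backend starts, so set it before importing jax.
existing = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = f"{existing} --xla_dump_to={dump_dir}".strip()

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Toy computation standing in for the real workload.
    return jnp.sin(x) ** 2


f(jnp.arange(8.0)).block_until_ready()

# Writes <dump_dir>.zip next to the dump directory, ready to attach to a bug report.
archive = shutil.make_archive(dump_dir, "zip", root_dir=dump_dir)
print(archive)
print(f"{len(os.listdir(dump_dir))} files dumped")
```

Setting the variable in the shell (`XLA_FLAGS=--xla_dump_to=/somewhere python script.py`) is equivalent and avoids the import-order concern.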

hawkinsp added the XLA, needs info, and NVIDIA GPU labels May 15, 2024
@stergiosba
Author

stergiosba commented May 15, 2024

Thank you for taking a look into this.

I also checked with 0.4.27, and the compilation time is the same as with 0.4.28; maybe this is helpful to you.

I hope I did this correctly and am not wasting your time. Here is the zip with two folders for the two different versions: xla_dump.zip

I took a brief look, and in the 0.4.26 version the last file that was generated was module_0042.jit_train.ptx. This file is significantly longer in the 0.4.28 version (~66000 lines in 0.4.26 vs ~99000 in 0.4.28); however, it is unclear to me what this means.

Moreover, in the 0.4.28 dump you can see all the files generated after the initial compilation. Apparently this happens in a callback that I use for reporting training progress. (Why would this be triggered in the new version?)
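
One generic way to check whether a jitted function is being retraced (and therefore recompiled) is to count how often its Python body runs, since the body executes only during tracing. This is a sketch with a hypothetical `report` function; it does not reproduce the callback from this report, and it only shows shape changes as one possible trigger.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def report(x):
    # This Python side effect runs only while JAX traces the function,
    # not on calls that hit the compilation cache.
    global trace_count
    trace_count += 1
    return jnp.mean(x)


report(jnp.ones(4))  # first call: traced and compiled
report(jnp.ones(4))  # same shape and dtype: cache hit, no retrace
report(jnp.ones(8))  # new shape: traced and compiled again

print(trace_count)  # → 2
```

If the counter grows on every training step, something in the arguments (shape, dtype, or a static value) is changing between calls.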

@stergiosba
Author

Any updates on this? I appreciate any help, thank you!

@stergiosba
Author

Comparing the module_0042.jit_train.sm_8.6_gpu_after_optimizations.txt files from the two dumps, it looks like the new version creates some fused_add_xor operations that do not exist in the old version.

They account for approximately 15000 lines of operations. I don't know if that helps, but nearly all of them refer to _threefry_split and uniform, so they appear to be related to jax.random.
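
A comparison like the one above can be scripted. The sketch below counts lines containing a given string in the `*after_optimizations.txt` files of a dump directory, using only the standard library. The helper name `count_matching_lines` is hypothetical, and the demo file contents are synthetic placeholders, not real HLO.

```python
import tempfile
from pathlib import Path


def count_matching_lines(dump_dir, needle, pattern="*after_optimizations.txt"):
    """Return {file name: number of lines containing `needle`} for a dump directory."""
    counts = {}
    for path in sorted(Path(dump_dir).glob(pattern)):
        with path.open(errors="replace") as fh:
            counts[path.name] = sum(needle in line for line in fh)
    return counts


# Synthetic demo so the sketch runs on its own; the lines below are made up.
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "module_0001.jit_demo.sm_8.6_gpu_after_optimizations.txt"
    sample.write_text(
        "fused_add_xor.1 = placeholder\n"
        "fused_multiply.2 = placeholder\n"
        "fused_add_xor.3 = placeholder\n"
    )
    demo_counts = count_matching_lines(tmp, "fused_add_xor")

print(demo_counts)
```

Pointing the helper at each version's dump folder and searching for `fused_add_xor` or `threefry` gives a quick per-file count to compare.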

@stergiosba
Author

stergiosba commented May 29, 2024

I tried all possible options for:

jax.config.update("jax_threefry_partitionable", True)
jax.config.update("jax_default_prng_impl", "threefry2x32")

The compilation time is still just as long.
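
A sweep over these settings can be automated along the following lines. This is only a sketch: the `train` function is a hypothetical stand-in that uses `jax.random`, and including the `"rbg"` implementation alongside `"threefry2x32"` is an assumption about which options are worth trying, not something stated in this report.

```python
import itertools
import time

import jax
import jax.numpy as jnp


def train(key, x):
    # Hypothetical stand-in that exercises key splitting and sampling.
    k1, k2 = jax.random.split(key)
    return jnp.sum(x * jax.random.uniform(k1, x.shape) + jax.random.normal(k2, x.shape))


x = jnp.ones((64, 64))
results = {}

for impl, partitionable in itertools.product(["threefry2x32", "rbg"], [False, True]):
    jax.config.update("jax_default_prng_impl", impl)
    jax.config.update("jax_threefry_partitionable", partitionable)
    jax.clear_caches()           # avoid reusing traces from a previous setting
    key = jax.random.PRNGKey(0)  # key layout depends on the selected implementation

    t0 = time.perf_counter()
    jax.jit(train).lower(key, x).compile()
    results[(impl, partitionable)] = time.perf_counter() - t0

for (impl, partitionable), seconds in results.items():
    print(f"{impl:>12}  partitionable={partitionable!s:<5}  {seconds:.3f} s")
```

On a toy function the differences will be small; the point is the structure of the sweep, which can wrap a real training step instead.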

2 participants