forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 5
CI: 12/02/24 upstream sync #170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
PiperOrigin-RevId: 699057658
This will allow us to test TPU compatibility with jaxlib at head. Also, enable v4 runners as they are now online. PiperOrigin-RevId: 699155667
PiperOrigin-RevId: 699193349
PiperOrigin-RevId: 699226058
…s. This also requires adding `reduce_p` sharding rule PiperOrigin-RevId: 699244204
http://github.com/openxla/xla/commit/0564969ba385bfc895baad8f64879236bfbc717b. PiperOrigin-RevId: 699295115
PiperOrigin-RevId: 699304829
…Arg` to do the transfers instead of `jit` to avoid blocking. PiperOrigin-RevId: 699336402
…ead of `backend` PiperOrigin-RevId: 699338495
PiperOrigin-RevId: 699408419
http://github.com/openxla/xla/commit/90af2896ab4992ff14a1cd2a75ce02e43f46c090. PiperOrigin-RevId: 699545393
http://github.com/openxla/xla/commit/40d457a268baf95e42cd95709dedef70c0ea2994. PiperOrigin-RevId: 699768724
PiperOrigin-RevId: 699919048
The previous version of the code was too complicated and failed to account for the fact that in an op that broadcasts there does not necessarily exist and operand that has the output shape. Reading through the code now, it's a bit weird that we allow implicit broadcasting of operands with splat layouts, but not any other operands. But I guess that's a thing to implement later. PiperOrigin-RevId: 699983045
PiperOrigin-RevId: 699991540
…th too (which will be deleted after jax 0.4.36 release) PiperOrigin-RevId: 700006186
PiperOrigin-RevId: 700029576
…conflict resolution is now complete. PiperOrigin-RevId: 700042542
…s we have downstream caches and a cpp cache too. If you drop out of cpp cache, things are going to be slow anyways. PiperOrigin-RevId: 700052522
PiperOrigin-RevId: 700052530
PiperOrigin-RevId: 700052796
PiperOrigin-RevId: 700053383
This commit reworks the JAX build CLI to a subcommand based approach where CLI use cases are now defined as subcommands. Two subcommands are defined: build and requirements_update. "build" is to be used when wanting to build a JAX wheel package. "requirements_update" is to be used when wanting to update the requirements_lock.txt files. The new structure offers a clear and organized CLI that enables users to execute specific build tasks without having to navigate through a monolithic script. Each subcommand has specific arguments that apply to its respective build process. In addition, arguments are separated into groups to achieve a cleaner separation and improves the readability when the CLI subcommands are run with `--help`. It also makes it clear as to which parts of the build they affect. E.g: CUDA arguments only apply to CUDA builds, ROCM arguments only apply to ROCM builds, etc. This reduces the complexity and the potential for errors during the build process. Segregating functionalities into distinct subcommands also simplifies the code which should help with the maintenance and future extensions. There is also a transition from using `subprocess.check_output` to `asyncio.create_subprocess_shell` for executing the build commands which allows for streaming logs and helps in showing the build progress in real time. Usage: * Building `jaxlib`: ``` python build/build.py build --wheels=jaxlib --python_version=3.10 ``` * Building `jax-cuda-plugin`: ``` python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10 ``` * Building multiple packages: ``` python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10 ``` * Building `jax-rocm-pjrt`: ``` python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm ``` * Using a local XLA path: ``` python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla ``` * Updating requirements_lock.txt files: ``` python build/build.py requirements_update --python_version=3.10 ``` For more details on each argument and to see available options, run: ``` python build/build.py build --help ``` or ``` python build/build.py requirements_update --help ``` PiperOrigin-RevId: 700075411
PiperOrigin-RevId: 700093685
This commit introduces new CI scripts and environment files for building JAX artifacts. It makes use of the artifact envs inside the "ci/envs/build_artifacts" folder to control the build behavior. For e.g: for building jaxlib, we will need to run `./ci/build_artifacts.sh ./ci/envs/build_artifacts/jaxlib.env` from the JAX GitHub root. PiperOrigin-RevId: 700104283
…nc_copy + grid fixes This change removes the need to flatten the batch dimension into sequence dimensions in the flash attention kernel. The critical thing here is the observation that we can in fact collapse all squeezed dimension into a single one in the TMA descriptor, letting us reduce its rank when necessary. Doing this also uncovered some issues with how we were handling the grid in Pallas:MGPU lowering, which I've fixed. PiperOrigin-RevId: 701035277
Both are very important for FlashAttention and both were poorly mapped to PTX. For exp, we really do not care about denormals when running in approximate mode, since they would produce results so close to 1 that it really doesn't matter. For max, LLVM ended up generating a while bunch of comparisons and selects and failed to take advantage of the max instructions present in GPUs. Both of those changes _significantly_ improve the performance of Mosaic attention kernels for heads smaller than 256 (when the pointwise part dominates the execution time). In one example I looked at, the utilization jumps from 55% to 64%. PiperOrigin-RevId: 701042779
…ce of ints is passed to make_mesh or create_device_mesh Reverts a158e02 PiperOrigin-RevId: 701045239
http://github.com/openxla/xla/commit/fc80c5576b71c986fbd4505a59826f7d433878bc. PiperOrigin-RevId: 701110365
PiperOrigin-RevId: 701142951
…cc-cu12_12-6-85 PiperOrigin-RevId: 701143135
PiperOrigin-RevId: 701143242
…paths-upstream PiperOrigin-RevId: 701143534
Using FMAs can significantly increase the ALU throughput and only increases the precision. We use this capability to reduce the number of operations needed to evaluate the softmax part of attention. PiperOrigin-RevId: 701226007
…ading argument is splat. Also adapted the test to catch a possible regression. The issue appeared in >2 operands. PiperOrigin-RevId: 701306731
PiperOrigin-RevId: 701307324
This commit adds the new CI scripts for running Pytests. It makes use of the pytest envs inside the "ci/envs/run_tests" folder to control the build behavior. For e.g: for running the GPU tests with Pytest, we will need to run `./ci/run_pytest.sh ./ci/envs/run_tests/pytest_gpu.env`. Note that Pytests need JAX wheels to be installed on the system to work. The `install_wheels_locally.sh` script installs these wheels in CI builds. PiperOrigin-RevId: 701331411
http://github.com/openxla/xla/commit/479fb21237319d091ee93e86619c8d4d88bda079. PiperOrigin-RevId: 701368225
http://github.com/openxla/xla/commit/20d4636c743e53f070612d6b4c6ebd03b2b28bf5. PiperOrigin-RevId: 701562320
* Make tree_util.tree_flatten_with_path and tree_map_with_path APIs to be C++-based, to speed up the pytree flattening. * Moves all the key classes down to C++ level, while keeping the APIs unchanged. * Known small caveats: they are no longer Python dataclasses, and pattern matching might make pytype unhappy. * Registered defaultdict and ordereddict via the keypath API now. PiperOrigin-RevId: 701613257
http://github.com/openxla/xla/commit/41e12cc0247edf4ffb1569f2a25c61cec924c755. PiperOrigin-RevId: 701766499
… planner Instead of only allowing a fixed set of layouts that we've hand verified as bank-conflict free, we now simulate the transactions performed within each warp and verify that no bank conflicts happen. If we detect that the simple schedule does not work out, we attempt to partition the threads in a warp into two groups and stagger the transfers in a way that lets us avoid conflicts. This allows us to match the hand-designed transfer schedule I wrote for 32-bit types, and even generalizes it to more cases automatically (e.g. swizzle=32). PiperOrigin-RevId: 701919158
…pstream PiperOrigin-RevId: 701959495
…types list. PiperOrigin-RevId: 701961663
They started failing after we allowed LLVM to perform contractions of adds and muls, but the difference is tiny. PiperOrigin-RevId: 701961845
JehandadKhan
approved these changes
Dec 2, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Daily sync with upstream