@charleshofer
Collaborator

Daily sync with upstream

justinjfu and others added 30 commits November 21, 2024 23:34
This will allow us to test TPU compatibility with jaxlib at head. Also, enable v4 runners as they are now online.

PiperOrigin-RevId: 699155667
…s. This also requires adding `reduce_p` sharding rule

PiperOrigin-RevId: 699244204
…Arg` to do the transfers instead of `jit` to avoid blocking.

PiperOrigin-RevId: 699336402
…ead of `backend`

PiperOrigin-RevId: 699338495
The previous version of the code was too complicated and failed to account
for the fact that in an op that broadcasts there does not necessarily exist
an operand that has the output shape.

Reading through the code now, it's a bit weird that we allow implicit
broadcasting of operands with splat layouts, but not any other operands.
But I guess that's a thing to implement later.

PiperOrigin-RevId: 699983045
PiperOrigin-RevId: 699991540
…th too (which will be deleted after jax 0.4.36 release)

PiperOrigin-RevId: 700006186
…conflict resolution is now complete.

PiperOrigin-RevId: 700042542
…s we have downstream caches and a cpp cache too.

If you drop out of cpp cache, things are going to be slow anyways.

PiperOrigin-RevId: 700052522
PiperOrigin-RevId: 700052796
This commit reworks the JAX build CLI to a subcommand-based approach where CLI use cases are defined as subcommands. Two subcommands are defined: `build` and `requirements_update`. Use `build` to build a JAX wheel package and `requirements_update` to update the requirements_lock.txt files. The new structure offers a clear and organized CLI that lets users execute specific build tasks without having to navigate through a monolithic script.

Each subcommand has specific arguments that apply to its respective build process. In addition, arguments are separated into groups, which gives a cleaner separation and improves readability when the CLI subcommands are run with `--help`. It also makes clear which parts of the build they affect. For example, CUDA arguments only apply to CUDA builds, ROCm arguments only apply to ROCm builds, etc. This reduces the complexity and the potential for errors during the build process. Segregating functionality into distinct subcommands also simplifies the code, which should help with maintenance and future extensions.
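
The subcommand layout described above can be sketched with `argparse` subparsers and argument groups. This is a minimal illustration, not the actual `build/build.py`: the flag names follow the ones used in this commit message, while the defaults, help strings, and the `build_parser` helper are assumptions made for the example.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the two subcommands named in the commit."""
    parser = argparse.ArgumentParser(prog="build.py")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # "build" subcommand: builds one or more wheel packages.
    build = subparsers.add_parser("build", help="Build wheel packages")
    build.add_argument("--wheels", default="jaxlib",
                       help="Comma-separated wheel names (default is assumed)")
    build.add_argument("--python_version", default="3.10")

    # Argument groups only change how `--help` is rendered: related flags
    # are listed together under their own heading.
    cuda = build.add_argument_group("CUDA options")
    cuda.add_argument("--cuda_version")
    cuda.add_argument("--cudnn_version")

    rocm = build.add_argument_group("ROCm options")
    rocm.add_argument("--rocm_version")
    rocm.add_argument("--rocm_path")

    # "requirements_update" subcommand: only needs the Python version.
    update = subparsers.add_parser("requirements_update",
                                   help="Update requirements lock files")
    update.add_argument("--python_version", default="3.10")
    return parser


# Parse an explicit argument list instead of sys.argv to keep the example
# self-contained.
args = build_parser().parse_args(
    ["build", "--wheels=jaxlib,jax-cuda-plugin", "--cuda_version=12.3.2"]
)
print(args.command, args.wheels.split(","), args.cuda_version)
```

Because each subparser owns its arguments, a flag such as `--cuda_version` is rejected when passed to `requirements_update`, which is the kind of error reduction the commit message refers to.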

The build commands are now executed with `asyncio.create_subprocess_shell` instead of `subprocess.check_output`, which allows logs to be streamed and shows build progress in real time.
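
The difference matters because `subprocess.check_output` returns output only after the command exits, whereas an asyncio subprocess can be read line by line while it runs. A minimal sketch of that pattern follows; the `run_and_stream` helper and the `echo` command are placeholders for illustration, not code from the build script.

```python
import asyncio


async def run_and_stream(command: str) -> tuple[int, list[str]]:
    """Run a shell command, echoing each output line as soon as it arrives."""
    process = await asyncio.create_subprocess_shell(
        command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,  # merge stderr into the same stream
    )
    assert process.stdout is not None
    lines = []
    # StreamReader supports async iteration and yields one line at a time.
    async for raw_line in process.stdout:
        line = raw_line.decode(errors="replace")
        print(line, end="")  # shown immediately, not after the command exits
        lines.append(line)
    return_code = await process.wait()
    return return_code, lines


return_code, output = asyncio.run(run_and_stream("echo building && echo done"))
```

The helper also returns the collected lines and the exit code, so a caller can still fail the build on a non-zero status after the log has been streamed.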

Usage:
* Building `jaxlib`:
```
python build/build.py build --wheels=jaxlib --python_version=3.10
```
* Building `jax-cuda-plugin`:
```
python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building multiple packages:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building `jax-rocm-pjrt`:
```
python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm
```
* Using a local XLA path:
```
python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
```
* Updating requirements_lock.txt files:
```
python build/build.py requirements_update --python_version=3.10
```

For more details on each argument and to see available options, run:
```
python build/build.py build --help
```
or
```
python build/build.py requirements_update --help
```

PiperOrigin-RevId: 700075411
This commit introduces new CI scripts and environment files for building JAX artifacts. It uses the artifact envs inside the "ci/envs/build_artifacts" folder to control the build behavior. For example, to build jaxlib, run `./ci/build_artifacts.sh ./ci/envs/build_artifacts/jaxlib.env` from the JAX GitHub root.

PiperOrigin-RevId: 700104283
apaszke and others added 23 commits November 28, 2024 08:35
…nc_copy + grid fixes

This change removes the need to flatten the batch dimension into sequence dimensions
in the flash attention kernel. The critical observation is that we can
in fact collapse all squeezed dimensions into a single one in the TMA descriptor, letting
us reduce its rank when necessary.

Doing this also uncovered some issues with how we were handling the grid in Pallas:MGPU
lowering, which I've fixed.

PiperOrigin-RevId: 701035277
Both are very important for FlashAttention and both were poorly mapped to PTX.
For exp, we really do not care about denormals when running in approximate mode,
since they would produce results so close to 1 that it really doesn't matter.
For max, LLVM ended up generating a whole bunch of comparisons and selects and
failed to take advantage of the max instructions present in GPUs.

Both of those changes _significantly_ improve the performance of Mosaic attention
kernels for heads smaller than 256 (when the pointwise part dominates the execution
time). In one example I looked at, the utilization jumps from 55% to 64%.

PiperOrigin-RevId: 701042779
…ce of ints is passed to make_mesh or create_device_mesh

Reverts a158e02

PiperOrigin-RevId: 701045239
…cc-cu12_12-6-85

PiperOrigin-RevId: 701143135
…paths-upstream

PiperOrigin-RevId: 701143534
Using FMAs can significantly increase the ALU throughput and only increases
the precision. We use this capability to reduce the number of operations
needed to evaluate the softmax part of attention.

PiperOrigin-RevId: 701226007
…ading argument is splat.

Also adapted the test to catch a possible regression. The issue appeared with more than 2 operands.

PiperOrigin-RevId: 701306731
This commit adds the new CI scripts for running Pytests. It uses the pytest envs inside the "ci/envs/run_tests" folder to control the build behavior. For example, to run the GPU tests with Pytest, run `./ci/run_pytest.sh ./ci/envs/run_tests/pytest_gpu.env`. Note that Pytests need JAX wheels to be installed on the system to work. The `install_wheels_locally.sh` script installs these wheels in CI builds.

PiperOrigin-RevId: 701331411
* Make the `tree_util.tree_flatten_with_path` and `tree_map_with_path` APIs C++-based to speed up pytree flattening.

* Move all the key classes down to the C++ level, while keeping the APIs unchanged.
  * Known small caveats: they are no longer Python dataclasses, and pattern matching might make pytype unhappy.

* `defaultdict` and `OrderedDict` are now registered via the keypath API.
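
For readers unfamiliar with path-aware flattening, the idea can be sketched in plain Python: walk the container and pair each leaf with the sequence of keys leading to it. This is an illustration of the concept only, under simplifying assumptions. It handles just dicts, lists, and tuples, uses raw keys and indices instead of JAX's key classes, returns no treedef, and the `flatten_with_path` name is hypothetical.

```python
from typing import Any


def flatten_with_path(tree: Any, path: tuple = ()) -> list[tuple[tuple, Any]]:
    """Return (path, leaf) pairs for a nested dict/list/tuple structure."""
    if isinstance(tree, dict):
        pairs = []
        # Visit keys in sorted order so the result is deterministic.
        for key in sorted(tree):
            pairs.extend(flatten_with_path(tree[key], path + (key,)))
        return pairs
    if isinstance(tree, (list, tuple)):
        pairs = []
        for index, child in enumerate(tree):
            pairs.extend(flatten_with_path(child, path + (index,)))
        return pairs
    # Anything else is treated as a leaf in this simplified sketch.
    return [(path, tree)]


params = {"dense": {"w": 1.0, "b": 0.5}, "scales": [2.0, 3.0]}
for leaf_path, leaf in flatten_with_path(params):
    print(leaf_path, leaf)
```

Each recursive step allocates new tuples and lists in the interpreter, which hints at why moving this traversal into C++ can speed it up for large trees.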

PiperOrigin-RevId: 701613257
… planner

Instead of only allowing a fixed set of layouts that we've hand verified as
bank-conflict free, we now simulate the transactions performed within each
warp and verify that no bank conflicts happen. If we detect that the simple
schedule does not work out, we attempt to partition the threads in a warp
into two groups and stagger the transfers in a way that lets us avoid conflicts.

This allows us to match the hand-designed transfer schedule I wrote for 32-bit
types, and even generalizes it to more cases automatically (e.g. swizzle=32).
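
The simulation idea can be illustrated with a deliberately simplified model. The sketch below assumes 32 shared-memory banks of 4 bytes each and one scalar access per thread, and counts how many distinct words land in the same bank. It is not the planner's actual algorithm, and `conflict_degree` is a hypothetical helper.

```python
from collections import defaultdict

NUM_BANKS = 32         # assumed bank count
BANK_WIDTH_BYTES = 4   # assumed bank width


def conflict_degree(byte_addresses: list[int]) -> int:
    """Return the largest number of distinct words mapped to a single bank.

    A result of 1 means the accesses are conflict-free in this model.
    Threads reading the same word do not conflict (broadcast).
    """
    words_per_bank = defaultdict(set)
    for address in byte_addresses:
        word = address // BANK_WIDTH_BYTES
        words_per_bank[word % NUM_BANKS].add(word)
    return max((len(words) for words in words_per_bank.values()), default=0)


unit_stride = [4 * lane for lane in range(32)]    # consecutive words
column_walk = [128 * lane for lane in range(32)]  # every lane hits bank 0
two_word_stride = [8 * lane for lane in range(32)]

print(conflict_degree(unit_stride))
print(conflict_degree(column_walk))
print(conflict_degree(two_word_stride))
# Issuing the two halves of the warp separately removes the conflict here.
print(conflict_degree(two_word_stride[:16]), conflict_degree(two_word_stride[16:]))
```

In this toy model, the two-word stride conflicts when all 32 lanes issue together but not when each half of the warp is checked on its own, which loosely mirrors the motivation for partitioning the threads into two groups.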

PiperOrigin-RevId: 701919158
They started failing after we allowed LLVM to perform contractions of
adds and muls, but the difference is tiny.

PiperOrigin-RevId: 701961845
@charleshofer charleshofer merged commit 0b2038e into rocm-main Dec 2, 2024
7 checks passed
@charleshofer charleshofer deleted the ci-daily-sync-02-12-2024 branch December 2, 2024 22:56