Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
73fa0f4
[Pallas] Deprecate dictionary compiler_params in favor of dataclass.
justinjfu Nov 22, 2024
34a2f0c
Add a jaxlib at head build to the cloud-tpu-ci-nightly workflow
nitins17 Nov 22, 2024
c0811c9
Adds coverage for spmd-axisname-filtering in shard_map transpose.
jkr26 Nov 22, 2024
7635605
Use with_spec where possible to clean up the code a bit
yashk2810 Nov 22, 2024
21f8885
[sharding_in_types] Make argmax and argmin work with sharding_in_type…
yashk2810 Nov 22, 2024
b1d1dcf
Add linearization rule for pjit_p
dougalm Nov 22, 2024
9f6dbef
Update XLA dependency to use revision
Google-ML-Automation Nov 22, 2024
030ee4a
Merge pull request #25070 from jax-ml:pjit-lin-rule
Google-ML-Automation Nov 22, 2024
8699f5d
When host local inputs on all hosts are the same, use `_DeferredShard…
yashk2810 Nov 23, 2024
b259fde
Fix member access to xla backend. The correct member is `client` inst…
Google-ML-Automation Nov 23, 2024
e53ff2c
[Mosaic][Easy] - Wire up kernel names to MLIR dump
Google-ML-Automation Nov 23, 2024
4d8751b
Update XLA dependency to use revision
Google-ML-Automation Nov 23, 2024
b372ce4
Update XLA dependency to use revision
Google-ML-Automation Nov 24, 2024
69e3f0d
[pallas:mosaic_gpu] Add test for FragmentedArray.bitcast.
petebu Nov 25, 2024
84a9cba
Refactor FFI examples to consolidate several examples into one submod…
dfm Nov 21, 2024
914600a
[Mosaic GPU] Simplify logic for pointwise splat operands
apaszke Nov 25, 2024
e8934b9
[ROCm] Add rocm version information
Ruturaj4 Nov 25, 2024
aa05dc0
Automated Code Change
Google-ML-Automation Nov 25, 2024
c35f8b2
Add abstract mesh context manager to trace_context in the fallback pa…
yashk2810 Nov 25, 2024
9866372
[cuda] Bump nvidia-cuda-nvcc-cu12 dependency to 12.6.85
gspschmid Nov 25, 2024
bb1024f
[SDY] enable `cpu_shardy` for JAX shard_alike test.
Varcho Nov 25, 2024
066859e
[SDY] Enable `test_pjit_array_multi_input_multi_output` since Shardy …
Varcho Nov 25, 2024
84dc9ba
Update ROCm scripts to match new build.py usage
nitins17 Nov 25, 2024
deab6fb
Remove _pjit_lower_cached cache. We can simplify the caching of jit a…
yashk2810 Nov 25, 2024
107bc96
[Mosaic GPU] Support batch dimensions in FA3 MGPU kernel.
justinjfu Nov 25, 2024
95029ab
drop compute capability check
Google-ML-Automation Nov 25, 2024
f22bafa
[SDY] remove TODO for enabling Layouts for Shardy post cl/697715276.
Varcho Nov 25, 2024
6761512
Re-factor build CLI to a subcommand based approach
nitins17 Nov 25, 2024
788f493
Merge pull request #25041 from dfm:ffi-example-refactor
Google-ML-Automation Nov 25, 2024
f7e9f62
Add new CI scripts for building JAX artifacts
nitins17 Nov 25, 2024
ebea435
Update XLA dependency to use revision
Google-ML-Automation Nov 25, 2024
ef7df1a
[pallas_mgpu] Allow trees (eg tuples) to be returned from cond_p expr…
cperivol Nov 26, 2024
c5dc980
[mgpu/pallas_mgpu] Pointwise tanh support
cperivol Nov 26, 2024
59e13f8
Add sharding argument to reshape since it also takes a `shape` argume…
yashk2810 Nov 26, 2024
627debc
Create a `null_mesh_context` internal context manager to handle null …
yashk2810 Nov 26, 2024
f828f2d
[mgpu] Pointwise min
cperivol Nov 26, 2024
024e331
Merge pull request #25084 from ROCm:ci_rocm_version
Google-ML-Automation Nov 26, 2024
16a5607
Use xla_extension_version instead of jaxlib_version
Google-ML-Automation Nov 26, 2024
b6566c8
[mosaic_gpu] Fixed unbounded recursion in `FragmentedArray._pointwise`
superbobry Nov 26, 2024
231967f
[AutoPGLE] Explicitly ignore host callback pointers
Google-ML-Automation Nov 26, 2024
dc11d40
[Pallas TPU] Better error message for lowering `sp.broadcast_to_p`
ayaka14732 Nov 26, 2024
92e18e6
[AutoPGLE] Fix pgle test after removing pjit cache.
Google-ML-Automation Nov 26, 2024
e453fa1
Update XLA dependency to use revision
belitskiy Nov 26, 2024
6763fcf
Fix a weird interaction with `set_local` and empty tuples passed to it.
yashk2810 Nov 26, 2024
d30ec2b
[ROCm] fix jax and wheelhouse relative paths
Ruturaj4 Oct 22, 2024
694de6b
[ROCm] Change run_multi_gpu set opts
Ruturaj4 Nov 26, 2024
bbaec6e
[JAX] Add Python binding for building a colocated Python program
hyeontaek Nov 26, 2024
8df2766
Add argument to override base docker in dockerfile
Ruturaj4 Nov 26, 2024
3d80632
Update http to https in amd artifactory url.
Ruturaj4 Nov 26, 2024
9c42379
Update XLA dependency to use revision
Google-ML-Automation Nov 26, 2024
10fdee3
Move `tsl/platform/{build_config,build_config_root,rules_cc}.bzl` to …
ddunl Nov 26, 2024
afcef67
Install git before actions/checkout
nitins17 Nov 27, 2024
c6866d0
Add a check for return codes of `executor.run` so that we propagate e…
nitins17 Nov 27, 2024
1372669
Add new CI script to run Bazel GPU (non-RBE) jobs
nitins17 Nov 27, 2024
0d2dfea
Add a private `set_mesh` API to enter into sharding_in_types mode. Th…
yashk2810 Nov 27, 2024
47d1960
Update the render documentation job to use the new self-hosted runners
nitins17 Nov 27, 2024
7a2070e
[Mosaic:TPU] Enable broadcast from 1-D vectors
tlongeri Nov 27, 2024
7f14de0
[mosaic_gpu] Warmup before measuring the running time in `profiler.me…
superbobry Nov 27, 2024
03b6945
Integrate LLVM at llvm/llvm-project@b214ca82daee
d0k Nov 27, 2024
f3acfa9
[mgpu] FragentedArray.foreach() can now optionally return a new array
cperivol Nov 27, 2024
8477580
[mgpu pallas] Layout iota operation.
cperivol Nov 27, 2024
d449f12
Fix early exiting when building multiple wheels
nitins17 Nov 27, 2024
df8ecb9
[mgpu] Debug print for mlir vectors.
cperivol Nov 27, 2024
04a4f9b
Merge pull request #25096 from nitins17:update-rocm-ci-scripts
Google-ML-Automation Nov 27, 2024
df6758f
Update XLA dependency to use revision
belitskiy Nov 27, 2024
c2c177e
[AutoPGLE] Update fdo_profile comment.
Google-ML-Automation Nov 27, 2024
6e72592
[Pallas] Fix float -> int casting on Triton backend.
justinjfu Nov 27, 2024
cc5036c
Raise a better error message if anything other than a sequence of int…
yashk2810 Nov 27, 2024
a212a29
Update XLA dependency to use revision
Google-ML-Automation Nov 27, 2024
8c52154
Add experimental JAX roofline API.
epiqueras Nov 27, 2024
132ad25
Update XLA dependency to use revision
Google-ML-Automation Nov 27, 2024
b62ca8b
Rework custom hermetic python instructions.
vam-google Nov 28, 2024
bdee4c3
Merge pull request #25153 from epiqueras:feature/typechecker
Google-ML-Automation Nov 28, 2024
34fe66b
[mgpu] foreach should not try to create an array if it didn't create …
cperivol Nov 28, 2024
a158e02
Reverts cc5036cc18bc585b0d92a4f606956da084effbad
Nov 28, 2024
b09b077
[Mosaic GPU] Add support for fast upcasts of s8 to bf16 for vectors o…
apaszke Nov 28, 2024
14ddb81
[Mosaic GPU] Avoid double-predication when async_copy predicate is sp…
apaszke Nov 28, 2024
d5bfafb
[mgpu] Added a missed case for debug_print types and raise a proper e…
cperivol Nov 28, 2024
b801539
[Pallas][Mosaic GPU] Add support for compressing squeezed dims in asy…
apaszke Nov 28, 2024
db158e6
[Mosaic GPU] Improve the implementation of max and exp
apaszke Nov 28, 2024
456dfeb
[Take 2] Raise a better error message if anything other than a sequen…
yashk2810 Nov 28, 2024
f73de23
Update XLA dependency to use revision
Google-ML-Automation Nov 28, 2024
385e2f4
Merge pull request #25137 from ROCm:ci_enable_https-upstream
Google-ML-Automation Nov 29, 2024
6d4278d
Merge pull request #25091 from gspschmid:gschmid/nvidia-cuda-nvcc-cu1…
Google-ML-Automation Nov 29, 2024
b0df405
Merge pull request #25130 from ROCm:ci_fix_set_options-upstream
Google-ML-Automation Nov 29, 2024
ab79066
Merge pull request #25128 from ROCm:ci_fix_wheelhouse_relative_paths-…
Google-ML-Automation Nov 29, 2024
f10d3eb
[Mosaic GPU] Allow contracting ops into FMAs
apaszke Nov 29, 2024
ea69401
[mgpu] Fixed off-by-one issue in pointwise argument shuffling when le…
cperivol Nov 29, 2024
c3c21c7
[mgpu_pallas] Better support for unsigned integers and floats in iota.
cperivol Nov 29, 2024
031c0ac
Add new CI scripts for running Pytests
nitins17 Nov 29, 2024
47858c4
Update XLA dependency to use revision
Google-ML-Automation Nov 30, 2024
db4b3f2
Update XLA dependency to use revision
Google-ML-Automation Nov 30, 2024
a1dfdc1
C++ tree with path API
IvyZX Dec 1, 2024
e124c05
Update XLA dependency to use revision
Google-ML-Automation Dec 1, 2024
bd66f52
[Mosaic GPU] Add a bank-conflict checker to tiled transfer + transfer…
apaszke Dec 2, 2024
7b32d88
Merge pull request #25136 from ROCm:ci_dockerfile_arg_changes-upstream
Google-ML-Automation Dec 2, 2024
5d5b06c
[jax] Canonicalize dtypes when checking if dtypes present in target d…
chr1sj0nes Dec 2, 2024
aff7714
[Pallas:MGPU] Fix an overly strict precision requirement in tests
apaszke Dec 2, 2024
58e3045
Merge branch 'rocm-main' into ci-daily-sync-02-12-2024
charleshofer Dec 2, 2024
97d201e
Update ci-build.yaml to use specific image
charleshofer Dec 2, 2024
f8b753c
Update ci-build.yaml
charleshofer Dec 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ build:macos_cache_push --config=macos_cache --remote_upload_local_results=true -
build:ci_linux_x86_64 --config=avx_linux --config=avx_posix
build:ci_linux_x86_64 --config=mkl_open_source_only
build:ci_linux_x86_64 --config=clang --verbose_failures=true
build:ci_linux_x86_64 --color=yes

# TODO(b/356695103): We do not have a CPU only toolchain so we use the CUDA
# toolchain for both CPU and GPU builds.
Expand All @@ -203,6 +204,7 @@ build:ci_linux_x86_64_cuda --config=ci_linux_x86_64
# Linux Aarch64 CI configs
build:ci_linux_aarch64_base --config=clang --verbose_failures=true
build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10"
build:ci_linux_aarch64_base --color=yes

build:ci_linux_aarch64 --config=ci_linux_aarch64_base
build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain"
Expand All @@ -221,18 +223,21 @@ build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm
build:ci_darwin_x86_64 --macos_minimum_os=10.14
build:ci_darwin_x86_64 --config=macos_cache_push
build:ci_darwin_x86_64 --verbose_failures=true
build:ci_darwin_x86_64 --color=yes

# Mac Arm64 CI configs
build:ci_darwin_arm64 --macos_minimum_os=11.0
build:ci_darwin_arm64 --config=macos_cache_push
build:ci_darwin_arm64 --verbose_failures=true
build:ci_darwin_arm64 --color=yes

# Windows x86 CI configs
build:ci_windows_amd64 --config=avx_windows
build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true
build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain"
build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl"
build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE
build:ci_windows_amd64 --color=yes

# #############################################################################
# RBE config options below. These inherit the CI configs above and set the
Expand Down
20 changes: 11 additions & 9 deletions .github/workflows/asan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,8 @@ jobs:
run:
shell: bash -l {0}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
path: jax
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: python/cpython
path: cpython
ref: v3.13.0
# Install git before actions/checkout as otherwise it will download the code with the GitHub
# REST API and therefore any subsequent git commands will fail.
- name: Install clang 18
env:
DEBIAN_FRONTEND: noninteractive
Expand All @@ -42,6 +36,14 @@ jobs:
zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
libffi-dev liblzma-dev
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
path: jax
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: python/cpython
path: cpython
ref: v3.13.0
- name: Build CPython with ASAN enabled
env:
ASAN_OPTIONS: detect_leaks=0
Expand All @@ -65,7 +67,7 @@ jobs:
run: |
source ${GITHUB_WORKSPACE}/venv/bin/activate
cd jax
python build/build.py \
python build/build.py build --wheels=jaxlib --verbose \
--bazel_options=--color=yes \
--bazel_options=--copt=-fsanitize=address \
--clang_path=/usr/bin/clang-18
Expand Down
5 changes: 2 additions & 3 deletions .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ jobs:

documentation:
name: Documentation - test code snippets
runs-on: ubuntu-latest
runs-on: ROCM-Ubuntu
timeout-minutes: 10
strategy:
matrix:
Expand Down Expand Up @@ -164,8 +164,7 @@ jobs:
pip install -r docs/requirements.txt
- name: Render documentation
run: |
sphinx-build --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html

sphinx-build -j auto --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html

jax2tf_test:
name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
Expand Down
28 changes: 24 additions & 4 deletions .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ jobs:
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
tpu: [
# {type: "v3-8", cores: "4"}, # Enable when we have the v3/v4 type available
# {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
python-version: ["3.10"]
Expand All @@ -47,14 +47,34 @@ jobs:
# mandates using a specific commit for non-Google actions. We use
# https://github.com/sethvargo/ratchet to pin specific versions.
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
# Checkout XLA at head, if we're building jaxlib at head.
- name: Checkout XLA at head
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
if: ${{ matrix.jaxlib-version == 'head' }}
with:
repository: openxla/xla
path: xla
- name: Install JAX test requirements
run: |
$PYTHON -m pip install -U -r build/test-requirements.txt
$PYTHON -m pip install -U -r build/collect-profile-requirements.txt
- name: Install JAX
run: |
$PYTHON -m pip uninstall -y jax jaxlib libtpu
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
# Build and install jaxlib at head
$PYTHON build/build.py --bazel_options=--config=rbe_linux_x86_64 \
--bazel_options="--override_repository=xla=$(pwd)/xla" \
--bazel_options=--color=yes
$PYTHON -m pip install dist/*.whl
# Install "jax" at head
$PYTHON -m pip install -U -e .
# Install libtpu
$PYTHON -m pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
$PYTHON -m pip install .[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/wheel_win_x64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
python -m pip install -r build/test-requirements.txt
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py `
python.exe build\build.py build --wheels=jaxlib `
--bazel_options=--color=yes `
--bazel_options=--config=win_clang `
--verbose
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/windows_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ jobs:
python -m pip install -r build/test-requirements.txt
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py `
python.exe build\build.py build --wheels=jaxlib `
--bazel_options=--color=yes `
--bazel_options=--config=win_clang
--bazel_options=--config=win_clang `
--verbose

- uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
with:
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
`platforms` instead.
* Hashing of tracers, which has been deprecated since version 0.4.30, now
results in a `TypeError`.
* Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
replaces previous build.py usage. Run `python build/build.py --help` for
more details. Brief overview of the new subcommand options:
* `build`: Builds JAX wheel packages. For e.g., `python build/build.py build --wheels=jaxlib,jax-cuda-pjrt`
* `requirements_update`: Updates requirements_lock.txt files.
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
on the function inputs.
Expand Down
Loading
Loading