Merged

Changes from all commits (41 commits)
4a82fe9
Use `lax.top_k` instead of `jnp.argsort` in Gumbel top-k trick for we…
mar-muel Mar 14, 2025
dadc68b
add experimental lax.optimization_barrier autodiff rules
mattjj Mar 14, 2025
14cb745
Add a C++ implementation of a topological sort.
hawkinsp Mar 14, 2025
7db59cd
Merge pull request #27174 from mattjj:opt-barrier-ad-rules
Google-ML-Automation Mar 15, 2025
3c0027a
mixing modes
yashk2810 Mar 15, 2025
d07d642
Merge pull request #27177 from jax-ml:mixing_modes
Google-ML-Automation Mar 15, 2025
9b0ace4
Support error checking in explicit mode
ayaka14732 Mar 15, 2025
f360e19
Update XLA dependency to use revision
Google-ML-Automation Mar 15, 2025
de8b056
Better docs for jax.lax add/sub/mul/div
jakevdp Mar 15, 2025
466ef6a
Change the way that batching.spec_types is updated.
jpuigcerver Mar 16, 2025
e8b683a
Update XLA dependency to use revision
Google-ML-Automation Mar 16, 2025
761b35c
Merge pull request #27176 from jakevdp:lax-docs
Google-ML-Automation Mar 16, 2025
2bdd9c8
[Mosaic GPU] Add support for fast WGMMA layout changes after 8- to 16…
apaszke Mar 17, 2025
89b21de
[Mosaic GPU] Add support for changing the layout before the upcast
apaszke Mar 17, 2025
a7e5eae
[pallas:mosaic_gpu] `jnp.reduce_sum` now works for >1D arrays
superbobry Mar 17, 2025
55812c5
Update XLA dependency to use revision
Google-ML-Automation Mar 17, 2025
0ff2340
Removed trivial docstrings from JAX tests
superbobry Mar 17, 2025
3649da5
[Mosaic GPU] Make the s4 -> bf16 upcast more flexible when it comes t…
apaszke Mar 17, 2025
031614c
Pin numpy~=2.1.0 in workflow file instead of test-requirements.txt
nitins17 Mar 17, 2025
de9ad6b
Merge pull request #27157 from mar-muel:improve-random-choice-perform…
Google-ML-Automation Mar 17, 2025
3f59fa6
Add replace option to random.categorical to enable sampling without r…
carlosgmartin Mar 13, 2025
9a686e0
[Mosaic GPU] Add initial transform inference rules for `vector.{load,…
bchetioui Mar 17, 2025
be5d13a
Remove code that preserved _original_py_fns on C++ classes.
hawkinsp Mar 17, 2025
ebcae0d
Merge pull request #26980 from carlosgmartin:categorical_replace
Google-ML-Automation Mar 17, 2025
20658fa
Replace cached function get_replicated_hlo_sharding() with a constant.
hawkinsp Mar 17, 2025
4f70471
Fix error in pallas tutorial
Google-ML-Automation Mar 17, 2025
ecf7fde
Add B200 testing to continuous workflow
MichaelHudgins Mar 17, 2025
b74b16f
Merge pull request #27164 from MichaelHudgins:a4-testing
Google-ML-Automation Mar 17, 2025
b496613
Compute tile index using tile-based coordinates
Google-ML-Automation Mar 17, 2025
051687d
[pallas] `pallas_call_p` is now parameterized by a mesh
superbobry Mar 17, 2025
8c35191
Enable `jax.device_put` to a sharding with no local devices.
emilyfertig Mar 17, 2025
f174b00
Replace the uses of `PjRtClient::Compile()` with `PjRtClient::Compile…
changhuilin Mar 18, 2025
549973d
Allow pspec to be passed to device_put if there is a mesh in the surr…
yashk2810 Mar 18, 2025
34cd5b0
[Mosaic GPU] Remove sub-byte conversion restriction
apaszke Mar 18, 2025
38d52a1
[mosaic_gpu] Force flush all cupti activity, then unsubscribe.
chr1sj0nes Mar 18, 2025
d4bd257
[Mosaic GPU] Add a specialized layout for loading 4-bit inputs in WGM…
apaszke Mar 18, 2025
ba2f7c9
[Mosaic GPU] Add transform inference rule for `mgpu.slice_smem`.
bchetioui Mar 18, 2025
7a459f0
Update XLA dependency to use revision
Google-ML-Automation Mar 18, 2025
8da9324
[Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts
apaszke Mar 18, 2025
1e36cbe
[Mosaic GPU] Raise a `NotImplementedError` if `swizzle=16`.
bchetioui Mar 18, 2025
c7b407c
Merge branch 'rocm-main' into ci-upstream-sync-151_1
charleshofer Mar 18, 2025
5 changes: 5 additions & 0 deletions .github/workflows/pytest_cpu.yml
@@ -118,6 +118,11 @@ jobs:
run: |
$JAXCI_PYTHON -m pip install uv~=0.5.30
$JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt
# CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632
if [[ $OS == "linux" && $ARCH == "aarch64" ]]; then
$JAXCI_PYTHON -m uv pip install numpy~=2.1.0
fi
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
3 changes: 2 additions & 1 deletion .github/workflows/pytest_cuda.yml
@@ -54,7 +54,8 @@ jobs:
runs-on: ${{ inputs.runner }}
# TODO: Update to the generic ML ecosystem test containers when they are ready.
container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') ||
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') }}
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') ||
(contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }}
name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"

env:
24 changes: 18 additions & 6 deletions .github/workflows/wheel_tests_continuous.yml
@@ -110,18 +110,30 @@ jobs:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Python values need to match the matrix strategy in the artifact build jobs above
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
# See exclusions for what is fully tested
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu","linux-x86-a4-224-b200-1gpu"]
python: ["3.10",]
cuda: ["12.3", "12.1"]
cuda: ["12.1","12.3","12.8"]
enable-x64: [1, 0]
exclude:
# Run only a single configuration on H100 to save resources
# L4 does not run on cuda 12.8 but tests other configs
- runner: "linux-x86-g2-48-l4-4gpu"
cuda: "12.8"
# H100 runs only a single config: CUDA 12.3, enable-x64=1
- runner: "linux-x86-a3-8g-h100-8gpu"
cuda: "12.8"
- runner: "linux-x86-a3-8g-h100-8gpu"
python: "3.10"
cuda: "12.1"
- runner: "linux-x86-a3-8g-h100-8gpu"
python: "3.10"
enable-x64: 0
enable-x64: "0"
# B200 runs only a single config: CUDA 12.8, enable-x64=1
- runner: "linux-x86-a4-224-b200-1gpu"
enable-x64: "0"
- runner: "linux-x86-a4-224-b200-1gpu"
cuda: "12.1"
- runner: "linux-x86-a4-224-b200-1gpu"
cuda: "12.3"

name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})"
with:
runner: ${{ matrix.runner }}
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
{func}`jax.lax.dynamic_update_slice` and related functions. The default is
true, matching the current behavior. If set to false, JAX does not need to
emit code clamping negative indices, which improves code size.
* Added a `replace` option to {func}`jax.random.categorical` to enable sampling
without replacement.

## jax 0.5.2 (Mar 4, 2025)

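For context, a minimal sketch of how the new `replace` option described in the changelog entry above might be used. The keyword name comes from that entry; the exact call pattern and the default value are assumptions rather than the confirmed API.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))  # unnormalized log-probs over 4 categories

# Existing behavior: independent draws, so the same category index may repeat.
with_replacement = jax.random.categorical(key, logits, shape=(3,))

# Hypothetical usage of the new flag from #26980: three distinct category
# indices drawn without replacement (assumes the keyword is `replace` and
# defaults to True).
without_replacement = jax.random.categorical(key, logits, shape=(3,), replace=False)

print(with_replacement)     # indices in [0, 4); repeats possible
print(without_replacement)  # indices in [0, 4); no repeats
```

Note that without replacement the number of requested samples along the sampled axis cannot exceed the number of categories.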
5 changes: 1 addition & 4 deletions build/test-requirements.txt
@@ -18,7 +18,4 @@ setuptools
matplotlib~=3.8.4; python_version=="3.10"
matplotlib; python_version>="3.11"
opt-einsum
auditwheel

# CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632
numpy~=2.1.0; platform_system == "Linux" and platform_machine == "aarch64"
auditwheel
138 changes: 104 additions & 34 deletions docs/notebooks/explicit-sharding.ipynb
@@ -49,13 +49,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hVi6mApuVw3r",
"outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf"
"id": "hVi6mApuVw3r"
},
"outputs": [],
"source": [
@@ -84,13 +80,13 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mzDIDvj7Vw0k",
"outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434"
"outputId": "09ef049b-461f-47db-bf58-dc10b42fe40a"
},
"outputs": [
{
@@ -119,13 +115,13 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IyPx_-IBVwxr",
"outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499"
"outputId": "0cd3122f-e579-45d7-868d-e42bb0eacddb"
},
"outputs": [
{
@@ -141,7 +137,7 @@
"Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)"
]
},
"execution_count": 3,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -172,13 +168,13 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NO2ulM_QW7a8",
"outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb"
"outputId": "d888371b-080e-4bff-be5d-ea56beda3aac"
},
"outputs": [
{
@@ -208,13 +204,13 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1-TzmA0AXCAf",
"outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71"
"outputId": "1c7cc3ac-4b0e-42b7-facc-c706af10d7d2"
},
"outputs": [
{
@@ -256,13 +252,13 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Gy7ABds3XND3",
"outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b"
"outputId": "0d72dad2-381a-4e96-f771-40d705da1376"
},
"outputs": [
{
@@ -297,13 +293,13 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "grCcotr-XQjY",
"outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a"
"outputId": "c2db656c-809f-49a6-c948-629d6420360c"
},
"outputs": [
{
@@ -324,7 +320,7 @@
" [ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)"
]
},
"execution_count": 7,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -460,13 +456,13 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fpFEaMBcXsJG",
"outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660"
"outputId": "5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef"
},
"outputs": [
{
@@ -479,13 +475,6 @@
"We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n",
"Result type: ShapedArray(int32[4@X,4])\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result type: ShapedArray(int32[4@X,4])\n"
]
}
],
"source": [
@@ -550,13 +539,13 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "geptWrdYX0OM",
"outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f"
"outputId": "b8c3813f-60bb-4ccf-9da7-73462c57963f"
},
"outputs": [
{
@@ -588,7 +577,88 @@
{
"cell_type": "markdown",
"metadata": {
"id": "AQQjzUeGX4P6"
"id": "LZWjgiMZ7uSS"
},
"source": [
"You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over other. For example:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IVzPSkp77uCF",
"outputId": "db80a604-98ac-4343-8677-23729adf7ffc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n",
"x.sharding: ShapedArray(float32[4@X,4@Y])\n",
"\n",
"mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit))\n",
"y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])\n",
"\n",
"z.sharding: ShapedArray(float32[4@X,4@Y])\n",
"\n"
]
},
{
"data": {
"text/plain": [
"Array([[ 1. , 2.682942 , 2.818595 , 1.28224 ],\n",
" [-0.513605 , -0.9178486 , 0.44116902, 2.3139732 ],\n",
" [ 2.9787164 , 1.824237 , -0.08804226, -0.99998045],\n",
" [-0.07314587, 1.840334 , 2.9812148 , 2.3005757 ]], dtype=float32)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import functools\n",
"\n",
"@functools.partial(auto_axes, axes='X')\n",
"def g(y):\n",
" print(f'mesh inside g: {get_abstract_mesh()}')\n",
" print(f'y.sharding inside g: {jax.typeof(y) = }', end='\\n\\n')\n",
" return y * 2\n",
"\n",
"@jax.jit\n",
"def f(arr1):\n",
" print(f'mesh inside f: {get_abstract_mesh()}')\n",
" x = jnp.sin(arr1)\n",
" print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n",
"\n",
" z = g(x, out_shardings=P(\"X\", \"Y\"))\n",
"\n",
" print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n",
" return z + 1\n",
"\n",
"some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n",
"f(some_x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_3sfJjRq8w9f"
},
"source": [
"As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sJcWbfAh7UcO"
},
"source": [
"## Concrete array shardings can mention `Auto` mesh axis\n",
@@ -606,7 +676,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -708,5 +778,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 0
}