{"payload":{"feedbackUrl":"https://github.com/orgs/community/discussions/53140","repo":{"id":154739597,"defaultBranch":"main","name":"jax","ownerLogin":"google","currentUserCanPush":false,"isFork":false,"isEmpty":false,"createdAt":"2018-10-25T21:25:02.000Z","ownerAvatar":"https://avatars.githubusercontent.com/u/1342004?v=4","public":true,"private":false,"isOrgOwned":true},"refInfo":{"name":"","listCacheKey":"v0:1718025682.0","currentOid":""},"activityList":{"items":[{"before":"cd93b46df4d05fedf8a2e9af7d715b84d6854cd2","after":null,"ref":"refs/heads/test_641698941","pushedAt":"2024-06-10T13:21:22.000Z","pushType":"branch_deletion","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"}},{"before":"991797a8a987f241074ccd2a94ad2aaec684537c","after":"cd93b46df4d05fedf8a2e9af7d715b84d6854cd2","ref":"refs/heads/main","pushedAt":"2024-06-10T13:21:21.000Z","pushType":"push","commitsCount":1,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Add initialization annotations (for the benefit of MSAN) to variables that are initialized by external functions.\n\nPiperOrigin-RevId: 641879836","shortMessageHtmlLink":"Add initialization annotations (for the benefit of MSAN) to variables…"}},{"before":"ce4bb670c04efc0f648b15a4bed9de4a132c3b20","after":"cd93b46df4d05fedf8a2e9af7d715b84d6854cd2","ref":"refs/heads/test_641698941","pushedAt":"2024-06-10T13:21:20.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Add initialization annotations (for the benefit of MSAN) to variables that are initialized by external functions.\n\nPiperOrigin-RevId: 641879836","shortMessageHtmlLink":"Add initialization annotations (for the benefit of MSAN) to variables…"}},{"before":null,"after":"ce4bb670c04efc0f648b15a4bed9de4a132c3b20","ref":"refs/heads/test_641698941","pushedAt":"2024-06-10T13:06:30.000Z","pushType":"branch_creation","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Add initialization annotations (for the benefit of MSAN) to variables that are initialized by external functions.\n\nPiperOrigin-RevId: 641698941","shortMessageHtmlLink":"Add initialization annotations (for the benefit of MSAN) to variables…"}},{"before":"5e7ad600e2fc541138896b5590bf7b43472deea6","after":"991797a8a987f241074ccd2a94ad2aaec684537c","ref":"refs/heads/main","pushedAt":"2024-06-10T13:04:03.000Z","pushType":"push","commitsCount":2,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Merge pull request #21765 from hawkinsp:release\n\nPiperOrigin-RevId: 641876244","shortMessageHtmlLink":"Merge pull request #21765 from hawkinsp:release"}},{"before":"5e7ad600e2fc541138896b5590bf7b43472deea6","after":null,"ref":"refs/heads/test_641850715","pushedAt":"2024-06-10T12:59:15.000Z","pushType":"branch_deletion","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"}},{"before":"3b4039c850c23c154bcfcf1d9da88d672c577857","after":"5e7ad600e2fc541138896b5590bf7b43472deea6","ref":"refs/heads/main","pushedAt":"2024-06-10T12:59:13.000Z","pushType":"push","commitsCount":1,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Removed the double re-exporting of Pallas GPU/TPU APIs\n\njax.experimental.pallas.{gpu,tpu} now import directly from the relevant\njax._src.pallas.{triton,mosaic} submodules.\n\nPiperOrigin-RevId: 641875127","shortMessageHtmlLink":"Removed the double re-exporting of Pallas GPU/TPU APIs"}},{"before":"bc29866ba83d0774b39ede55b208b407b1e12920","after":"5e7ad600e2fc541138896b5590bf7b43472deea6","ref":"refs/heads/test_641850715","pushedAt":"2024-06-10T12:59:12.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Removed the double re-exporting of Pallas GPU/TPU APIs\n\njax.experimental.pallas.{gpu,tpu} now import directly from the relevant\njax._src.pallas.{triton,mosaic} submodules.\n\nPiperOrigin-RevId: 641875127","shortMessageHtmlLink":"Removed the double re-exporting of Pallas GPU/TPU APIs"}},{"before":"3b4039c850c23c154bcfcf1d9da88d672c577857","after":null,"ref":"refs/heads/test_641123019","pushedAt":"2024-06-10T12:55:08.000Z","pushType":"branch_deletion","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"}},{"before":"2ade7e75265a7c1da0c6d2cb373b08de631627a1","after":"3b4039c850c23c154bcfcf1d9da88d672c577857","ref":"refs/heads/main","pushedAt":"2024-06-10T12:55:06.000Z","pushType":"push","commitsCount":1,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"[Mosaic GPU] Load LLVM lowering interfaces for all dialects\n\nApparently we were missing interface registration code for LLVM lowering,\nwhich the gpu-to-llvm pass gracefully ignores unless compiled with debug\nassertions enabled. But, simply adding the assertions in fact makes the\npass _too powerful_ and makes it lower _all dialects to LLVM_, which is not\nwhat we want. That's why I've replaced it with a minimal version that is\nonly repsponsible for handling the GPU dialect, making the lowering similar\nto the one prior to extra registrations.\n\nPiperOrigin-RevId: 641874183","shortMessageHtmlLink":"[Mosaic GPU] Load LLVM lowering interfaces for all dialects"}},{"before":"3e853312e2a61ccfe53f1949309700badb0f52c8","after":"3b4039c850c23c154bcfcf1d9da88d672c577857","ref":"refs/heads/test_641123019","pushedAt":"2024-06-10T12:55:05.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"[Mosaic GPU] Load LLVM lowering interfaces for all dialects\n\nApparently we were missing interface registration code for LLVM lowering,\nwhich the gpu-to-llvm pass gracefully ignores unless compiled with debug\nassertions enabled. But, simply adding the assertions in fact makes the\npass _too powerful_ and makes it lower _all dialects to LLVM_, which is not\nwhat we want. That's why I've replaced it with a minimal version that is\nonly repsponsible for handling the GPU dialect, making the lowering similar\nto the one prior to extra registrations.\n\nPiperOrigin-RevId: 641874183","shortMessageHtmlLink":"[Mosaic GPU] Load LLVM lowering interfaces for all dialects"}},{"before":"80244661075857db0c1a0d4e96f2501821569935","after":"bc29866ba83d0774b39ede55b208b407b1e12920","ref":"refs/heads/test_641850715","pushedAt":"2024-06-10T12:45:25.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Removed the double re-exporting of Pallas GPU/TPU APIs\n\njax.experimental.pallas.{gpu,tpu} now import directly from the relevant\njax._src.pallas.{triton,mosaic} submodules.\n\nPiperOrigin-RevId: 641850715","shortMessageHtmlLink":"Removed the double re-exporting of Pallas GPU/TPU APIs"}},{"before":"3daf8045b70f3ab9c742651d87f591f3c9e38869","after":"3e853312e2a61ccfe53f1949309700badb0f52c8","ref":"refs/heads/test_641123019","pushedAt":"2024-06-10T12:39:36.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"[Mosaic GPU] Load LLVM lowering interfaces for all dialects\n\nApparently we were missing interface registration code for LLVM lowering,\nwhich the gpu-to-llvm pass gracefully ignores unless compiled with debug\nassertions enabled. But, simply adding the assertions in fact makes the\npass _too powerful_ and makes it lower _all dialects to LLVM_, which is not\nwhat we want. That's why I've replaced it with a minimal version that is\nonly repsponsible for handling the GPU dialect, making the lowering similar\nto the one prior to extra registrations.\n\nPiperOrigin-RevId: 641123019","shortMessageHtmlLink":"[Mosaic GPU] Load LLVM lowering interfaces for all dialects"}},{"before":"603001909c22eaa2739f78937a86db7e9dd2f660","after":"80244661075857db0c1a0d4e96f2501821569935","ref":"refs/heads/test_641850715","pushedAt":"2024-06-10T12:08:50.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Removed the double re-exporting of Pallas GPU/TPU APIs\n\njax.experimental.pallas.{gpu,tpu} now import directly from the relevant\njax._src.pallas.{triton,mosaic} submodules.\n\nPiperOrigin-RevId: 641850715","shortMessageHtmlLink":"Removed the double re-exporting of Pallas GPU/TPU APIs"}},{"before":"82a36d1b9961bff1b8f8a76c964cf5373cffb45d","after":"603001909c22eaa2739f78937a86db7e9dd2f660","ref":"refs/heads/test_641850715","pushedAt":"2024-06-10T11:19:57.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Removed the double re-exporting of Pallas GPU/TPU APIs\n\njax.experimental.pallas.{gpu,tpu} now import directly from the relevant\njax._src.pallas.{triton,mosaic} submodules.\n\nPiperOrigin-RevId: 641850715","shortMessageHtmlLink":"Removed the double re-exporting of Pallas GPU/TPU APIs"}},{"before":"d54c587b2b5a0e7aefa360fa0e5a5bcbe49d6c83","after":"01644a182abdde9af193e6351cae16bc0fa43ce9","ref":"refs/heads/test_641193132","pushedAt":"2024-06-10T11:17:19.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Removed kernel_regeneration_util from Mosaic\n\nIt was only used for persisting kernel metadata, and that can be done via\njax.named_scope instead.\n\nPiperOrigin-RevId: 641193132","shortMessageHtmlLink":"Removed kernel_regeneration_util from Mosaic"}},{"before":"115510ed6858eb5fb9dad37933d8c44689324cdd","after":"82a36d1b9961bff1b8f8a76c964cf5373cffb45d","ref":"refs/heads/test_641850715","pushedAt":"2024-06-10T11:03:44.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Removed the double re-exporting of Pallas GPU/TPU APIs\n\njax.experimental.pallas.{gpu,tpu} now import directly from the relevant\njax._src.pallas.{triton,mosaic} submodules.\n\nPiperOrigin-RevId: 641850715","shortMessageHtmlLink":"Removed the double re-exporting of Pallas GPU/TPU APIs"}},{"before":null,"after":"115510ed6858eb5fb9dad37933d8c44689324cdd","ref":"refs/heads/test_641850715","pushedAt":"2024-06-10T11:02:35.000Z","pushType":"branch_creation","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Removed the double re-exporting of Pallas GPU/TPU APIs\n\njax.experimental.pallas.{gpu,tpu} now import directly from the relevant\njax._src.pallas.{triton,mosaic} submodules.\n\nPiperOrigin-RevId: 641850715","shortMessageHtmlLink":"Removed the double re-exporting of Pallas GPU/TPU APIs"}},{"before":"030d73f0c4c4b1d2d159c43604e14c873dc066f3","after":"d54c587b2b5a0e7aefa360fa0e5a5bcbe49d6c83","ref":"refs/heads/test_641193132","pushedAt":"2024-06-10T11:02:19.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Removed kernel_regeneration_util from Mosaic\n\nIt was only used for persisting kernel metadata, and that can be done via\njax.named_call instead.\n\nPiperOrigin-RevId: 641193132","shortMessageHtmlLink":"Removed kernel_regeneration_util from Mosaic"}},{"before":"1a828c082295354816efb2c4f27efbbe54574877","after":"030d73f0c4c4b1d2d159c43604e14c873dc066f3","ref":"refs/heads/test_641193132","pushedAt":"2024-06-10T10:45:57.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Removed kernel_regeneration_util from Mosaic\n\nIt was only used for persisting kernel metadata, and that can be done via\njax.named_call instead.\n\nPiperOrigin-RevId: 641193132","shortMessageHtmlLink":"Removed kernel_regeneration_util from Mosaic"}},{"before":"2ade7e75265a7c1da0c6d2cb373b08de631627a1","after":null,"ref":"refs/heads/test_640919642","pushedAt":"2024-06-10T10:13:42.000Z","pushType":"branch_deletion","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"}},{"before":"af95803d0009647c049551f5cac68ac8578206b6","after":"2ade7e75265a7c1da0c6d2cb373b08de631627a1","ref":"refs/heads/main","pushedAt":"2024-06-10T10:13:40.000Z","pushType":"push","commitsCount":1,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"[pallas] Move the hardware_generation query in the code path that needs it\n\nThis change allows us to lower and export Pallas calls even\non machines that do not have TPUs, in many cases.\n\nPiperOrigin-RevId: 641841079","shortMessageHtmlLink":"[pallas] Move the hardware_generation query in the code path that nee…"}},{"before":"a8c88dbaf730c35dbbf1db6aaeea78e2a0622b29","after":"2ade7e75265a7c1da0c6d2cb373b08de631627a1","ref":"refs/heads/test_640919642","pushedAt":"2024-06-10T10:13:39.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"[pallas] Move the hardware_generation query in the code path that needs it\n\nThis change allows us to lower and export Pallas calls even\non machines that do not have TPUs, in many cases.\n\nPiperOrigin-RevId: 641841079","shortMessageHtmlLink":"[pallas] Move the hardware_generation query in the code path that nee…"}},{"before":"617f8ffc3477fc622e263e43548e9b26f04b7e31","after":"a8c88dbaf730c35dbbf1db6aaeea78e2a0622b29","ref":"refs/heads/test_640919642","pushedAt":"2024-06-10T09:46:58.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"[pallas] Move the hardware_generation query in the code path that needs it\n\nThis change allows us to lower and export Pallas calls even\non machines that do not have TPUs, in many cases.\n\nPiperOrigin-RevId: 640919642","shortMessageHtmlLink":"[pallas] Move the hardware_generation query in the code path that nee…"}},{"before":"7df71635c9bd29aaf6aff0ddb66a923805e0e883","after":null,"ref":"refs/heads/test_641720791","pushedAt":"2024-06-10T09:32:19.000Z","pushType":"branch_deletion","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"}},{"before":"8fbe65b4b2e558fe497a6006554e57575a4d25a2","after":"af95803d0009647c049551f5cac68ac8578206b6","ref":"refs/heads/main","pushedAt":"2024-06-10T09:29:16.000Z","pushType":"push","commitsCount":2,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Merge pull request #21759 from rajasekharporeddy:testbranch1\n\nPiperOrigin-RevId: 641831969","shortMessageHtmlLink":"Merge pull request #21759 from rajasekharporeddy:testbranch1"}},{"before":"c9b178822ae35c53b1c3ec3ebaecb602126beb1a","after":"df7a8db35d65549305bc5e87e14c48a3c0d40a44","ref":"refs/heads/test_641341474","pushedAt":"2024-06-10T07:56:09.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Batch `pxla.shard_args` calls triggered by `jax.device_put`\n\nWith this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.\n\nThe api_benchmark indicates that `device_put` with a single array is slower by 10%, but `device_put` with 10 to 1000 arrays is ~30% faster. This seems acceptable since a single-array `device_put` is fast anyway (~45us on my workstation). If this becomes a problem, this can be addressed by rewriting part of `device_put` in C++.\n\nPiperOrigin-RevId: 641341474","shortMessageHtmlLink":"Batch pxla.shard_args calls triggered by jax.device_put"}},{"before":"f32a654eb9ef3bd4d9340a875840e4f2ac6c77c3","after":"c9b178822ae35c53b1c3ec3ebaecb602126beb1a","ref":"refs/heads/test_641341474","pushedAt":"2024-06-10T06:52:20.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Batch `pxla.shard_args` calls triggered by `jax.device_put`\n\nWith this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.\n\nThe api_benchmark indicates that `device_put` with a single array is slower by 10%, but `device_put` with 10 to 1000 arrays is ~30% faster. This seems acceptable since a single-array `device_put` is fast anyway (~45us on my workstation). If this becomes a problem, this can be addressed by rewriting part of `device_put` in C++.\n\nPiperOrigin-RevId: 641341474","shortMessageHtmlLink":"Batch pxla.shard_args calls triggered by jax.device_put"}},{"before":"1d9a5ee1e086e8d441d302c3c19213a7d680f179","after":"f32a654eb9ef3bd4d9340a875840e4f2ac6c77c3","ref":"refs/heads/test_641341474","pushedAt":"2024-06-10T00:46:30.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Batch `pxla.shard_args` calls triggered by `jax.device_put`\n\nWith this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.\n\nThe api_benchmark indicates that `device_put` with a single array is slower by 10%, but `device_put` with 10 to 1000 arrays is ~30% faster. This seems acceptable since a single-array `device_put` is fast anyway (~45us on my workstation). If this becomes a problem, this can be addressed by rewriting part of `device_put` in C++.\n\nPiperOrigin-RevId: 641341474","shortMessageHtmlLink":"Batch pxla.shard_args calls triggered by jax.device_put"}},{"before":"3874b0bc1022c68a61cb756d12ecdc6d371f8805","after":"1d9a5ee1e086e8d441d302c3c19213a7d680f179","ref":"refs/heads/test_641341474","pushedAt":"2024-06-10T00:40:33.000Z","pushType":"force_push","commitsCount":0,"pusher":{"login":"copybara-service[bot]","name":null,"path":"/apps/copybara-service","primaryAvatarUrl":"https://avatars.githubusercontent.com/in/44061?s=80&v=4"},"commit":{"message":"Batch `pxla.shard_args` calls triggered by `jax.device_put`\n\nWith this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.\n\nThe api_benchmark indicates that `device_put` with a single array is slower by 10%, but `device_put` with 10 to 1000 arrays is ~30% faster. This seems acceptable since a single-array `device_put` is fast anyway (~45us on my workstation). If this becomes a problem, this can be addressed by rewriting part of `device_put` in C++.\n\nPiperOrigin-RevId: 641341474","shortMessageHtmlLink":"Batch pxla.shard_args calls triggered by jax.device_put"}}],"hasNextPage":true,"hasPreviousPage":false,"activityType":"all","actor":null,"timePeriod":"all","sort":"DESC","perPage":30,"cursor":"djE6ks8AAAAEYSUy2QA","startCursor":null,"endCursor":null}},"title":"Activity · google/jax"}