Compilation cache does not trigger for jitted functions with `custom_partitioning` ops. After running the JAX program multiple times, there is a separate entry with a different hash for each run.
I believe the cache misses because the `backend_config` parameter to the `CustomSPMDPartitioning` custom call is the address of some descriptor data structure, and that address is different from run to run. As a result, the MLIR is different in different runs and the cache never triggers. Below is the HLO for the function `f`; the `backend_config` here is the address that differs across runs.
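The failure mode can be reproduced without JAX: if the serialized module text embeds a live Python object's address, two otherwise identical programs hash to different cache keys. A minimal sketch of the mechanism (the `Descriptor`, `module_text`, and `cache_key` names are illustrative stand-ins, not JAX or XLA internals):

```python
import hashlib

class Descriptor:
    """Stands in for the per-run descriptor object the custom call points at."""
    pass

def module_text(desc):
    # Hypothetical stand-in for the lowered module: backend_config carries
    # the descriptor's address, like the CustomSPMDPartitioning custom call.
    return f'custom_call target="CustomSPMDPartitioning" backend_config="{id(desc):#x}"'

def cache_key(desc):
    # The cache key is a digest of the module text, so an embedded address
    # makes the "same" program key differently on every run.
    return hashlib.sha256(module_text(desc).encode()).hexdigest()

# Two runs build two descriptor objects at different addresses.
d1, d2 = Descriptor(), Descriptor()
```

Here `cache_key(d1) != cache_key(d2)` even though the program is semantically identical, which is exactly why a fresh cache entry appears on each run.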
Yes. That's correct and how it works at the moment. The custom partitioning ultimately refers to a Python object, which is why it's not stable run to run.
If this only covers the Python callback, could we just remove it from the key?
Not everything is 100% versioned right now (like XLA isn't versioned).
So it would just end up being the end user's responsibility to handle the cache correctly?
Or could we get the Python AST of the callbacks and hash it? I see that as more work, and I'm not sure it would be useful.
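The AST-hashing idea above could be sketched roughly like this (an illustration only, not how JAX computes its cache key; in practice stability would still depend on the Python version and on any state the callback closes over):

```python
import ast
import hashlib

def callback_key(src: str) -> str:
    """Key a callback by its parsed AST instead of its object identity."""
    tree = ast.parse(src)
    # ast.dump() is a deterministic textual form of the parse tree, with no
    # addresses in it, so the digest is stable across processes.
    return hashlib.sha256(ast.dump(tree).encode()).hexdigest()

impl_a = "def cb(x):\n    return x * 2\n"
impl_b = "def cb(x):   # cosmetic changes only\n    return x * 2\n"
impl_c = "def cb(x):\n    return x * 3\n"
```

Cosmetic edits (comments, whitespace) leave the key unchanged, while a real logic change produces a new key, which is the behavior a compilation cache would want.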
Description
Here is the invocation:
Here are the contents of the cache afterwards:
The program with the custom-partitioned op:
System info (python version, jaxlib version, accelerator, etc.)