diff --git a/pathwaysutils/experimental/reshard.py b/pathwaysutils/experimental/reshard.py index 09bfe08..4f2b19a 100644 --- a/pathwaysutils/experimental/reshard.py +++ b/pathwaysutils/experimental/reshard.py @@ -150,24 +150,27 @@ def reshard( # put them back together in the right order. array_info_lambda = lambda: {"arrays": [], "indices": [], "dst_shardings": []} jax_arrays = collections.defaultdict(array_info_lambda) - non_jax_arrays = array_info_lambda() + non_reshardable_arrays = array_info_lambda() for index, (arr, dst_sharding) in enumerate(zip(flat_x, flat_sharding)): if not isinstance(dst_sharding, jax.sharding.Sharding): raise ValueError("`sharding` must contain only `jax.sharding.Sharding`") - if isinstance(arr, jax.Array): + if not isinstance(arr, jax.Array) or ( + hasattr(arr, "dtype") + and jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) + ): + non_reshardable_arrays["arrays"].append(arr) + non_reshardable_arrays["indices"].append(index) + non_reshardable_arrays["dst_shardings"].append(dst_sharding) + else: device_set = frozenset(arr.sharding.device_set) jax_arrays[device_set]["arrays"].append(arr) jax_arrays[device_set]["indices"].append(index) jax_arrays[device_set]["dst_shardings"].append(dst_sharding) - else: - non_jax_arrays["arrays"].append(arr) - non_jax_arrays["indices"].append(index) - non_jax_arrays["dst_shardings"].append(dst_sharding) - - if non_jax_arrays["arrays"]: - non_jax_arrays["arrays"] = jax.device_put( - non_jax_arrays["arrays"], - non_jax_arrays["dst_shardings"], + + if non_reshardable_arrays["arrays"]: + non_reshardable_arrays["arrays"] = jax.device_put( + non_reshardable_arrays["arrays"], + non_reshardable_arrays["dst_shardings"], donate=donate, may_alias=may_alias, ) @@ -186,7 +189,9 @@ def reshard( ).execute(tuple(array_info["arrays"])) result = [None] * len(flat_x) - for arr, idx in zip(non_jax_arrays["arrays"], non_jax_arrays["indices"]): + for arr, idx in zip( + non_reshardable_arrays["arrays"], non_reshardable_arrays["indices"] + ): result[idx] = arr for array_info in jax_arrays.values(): for arr, idx in zip(array_info["arrays"], array_info["indices"]):