diff --git a/src/Compiler.jl b/src/Compiler.jl index 9d98ec2fe0..34cc887c04 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1049,6 +1049,13 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic end end + num_devices = XLA.ClientNumAddressableDevices(client) + if num_devices != 1 + error( + "Unsupported client with multiple addressible devices (we need to pass right shard data)", + ) + end + # compile MLIR module to XLA executable exec = XLA.Compile(client, mod) (