diff --git a/src/Compiler.jl b/src/Compiler.jl index 76e6ab7462..d1d37f3455 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -712,6 +712,9 @@ function compile_mlir!( # Save in the TLS whether we are raising. We identify that condition by # checking whether the user set an explicit list of passes, or chose # `raise=true` to use the default passes. + if backend == "tpu" && raise isa Bool + raise = true + end is_raising = raise isa String || raise activate_raising!(is_raising) @@ -743,7 +746,7 @@ function compile_mlir!( toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))] end - if backend == "cpu" + if backend == "cpu" || backend == "tpu" kern = "lower-kernel{backend=cpu},canonicalize" jit = "lower-jit{openmp=true backend=cpu},symbol-dce" elseif DEBUG_KERNEL[] diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index 9e9d305a42..318267603c 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -184,8 +184,9 @@ for runtime in (:PJRT, :IFRT) state.default_client = cpu # Try TPU if possible, then try GPU (CUDA) + if !Reactant.precompiling() @static if !Sys.isapple() - if Reactant.has_tpu() + if Reactant.has_tpu() dataset_dir = @get_scratch!("libtpu") download_tpu(dataset_dir) try @@ -202,7 +203,6 @@ for runtime in (:PJRT, :IFRT) println(stdout, e) end else - if !Reactant.precompiling() try if was_initialized && haskey(state.clients, "gpu") XLA.free_client(state.clients["gpu"])