Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,9 @@ function compile_mlir!(
# Save in the TLS whether we are raising. We identify that condition by
# checking whether the user set an explicit list of passes, or chose
# `raise=true` to use the default passes.
if backend == "tpu" && raise isa Bool
raise = true
end
is_raising = raise isa String || raise
activate_raising!(is_raising)

Expand Down Expand Up @@ -743,7 +746,7 @@ function compile_mlir!(
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))]
end

if backend == "cpu"
if backend == "cpu" || backend == "tpu"
kern = "lower-kernel{backend=cpu},canonicalize"
jit = "lower-jit{openmp=true backend=cpu},symbol-dce"
elseif DEBUG_KERNEL[]
Expand Down
4 changes: 2 additions & 2 deletions src/xla/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,9 @@ for runtime in (:PJRT, :IFRT)
state.default_client = cpu

# Try TPU if possible, then try GPU (CUDA)
if !Reactant.precompiling()
@static if !Sys.isapple()
if Reactant.has_tpu()
if Reactant.has_tpu()
dataset_dir = @get_scratch!("libtpu")
download_tpu(dataset_dir)
try
Expand All @@ -202,7 +203,6 @@ for runtime in (:PJRT, :IFRT)
println(stdout, e)
end
else
if !Reactant.precompiling()
try
if was_initialized && haskey(state.clients, "gpu")
XLA.free_client(state.clients["gpu"])
Expand Down
Loading