diff --git a/benchmark/aggregate.jl b/benchmark/aggregate.jl index aaf4e25fa5..7d71352255 100644 --- a/benchmark/aggregate.jl +++ b/benchmark/aggregate.jl @@ -15,5 +15,5 @@ for backend in BACKENDS end open(joinpath(dirname(@__FILE__), "results", "combinedbenchmarks.json"), "w") do io - JSON3.pretty(io, JSON3.write(all_results)) + return JSON3.pretty(io, JSON3.write(all_results)) end diff --git a/benchmark/runbenchmarks.jl b/benchmark/runbenchmarks.jl index 1a80c3a0d7..3c389e0ad2 100644 --- a/benchmark/runbenchmarks.jl +++ b/benchmark/runbenchmarks.jl @@ -44,7 +44,7 @@ for (i, (k, v)) in enumerate(results) end open(joinpath(filepath, filename), "w") do io - JSON3.pretty(io, JSON3.write(standardized_results)) + return JSON3.pretty(io, JSON3.write(standardized_results)) end @info "Saved results to $(joinpath(filepath, filename))" diff --git a/deps/build_local.jl b/deps/build_local.jl index 1fa6955c58..8d1a43821e 100644 --- a/deps/build_local.jl +++ b/deps/build_local.jl @@ -252,7 +252,7 @@ run(Cmd(Cmd(build_cmd_list); dir=source_dir)) # Discover built libraries built_libs = filter(readdir(joinpath(source_dir, "bazel-bin"))) do file - endswith(file, "Extra.so") && startswith(file, "lib") + return endswith(file, "Extra.so") && startswith(file, "lib") end lib_path = joinpath(source_dir, "bazel-bin", only(built_libs)) diff --git a/ext/ReactantKernelAbstractionsExt.jl b/ext/ReactantKernelAbstractionsExt.jl index 1e38b81b9a..617075b086 100644 --- a/ext/ReactantKernelAbstractionsExt.jl +++ b/ext/ReactantKernelAbstractionsExt.jl @@ -109,14 +109,26 @@ function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsi return nothing end -Reactant.@reactant_overlay Base.@nospecializeinfer @noinline function ( - obj::KA.Kernel{ReactantBackend} -)( - @nospecialize args...; ndrange=nothing, workgroupsize=nothing -) - return Reactant.call_with_reactant( - Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args... +@static if VERSION < v"1.12-" + Reactant.@reactant_overlay Base.@nospecializeinfer @noinline function ( + obj::KA.Kernel{ReactantBackend} + )( + @nospecialize args...; ndrange=nothing, workgroupsize=nothing ) + return Reactant.call_with_reactant( + Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args... + ) + end +else + Reactant.@reactant_overlay function (obj::KA.Kernel{ReactantBackend})( + args...; ndrange=nothing, workgroupsize=nothing + ) + Base.@_noinline_meta + Base.@_nospecializeinfer_meta + return Reactant.call_with_reactant( + Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args... + ) + end end end diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 6ffcec0ab3..0e73cff527 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -41,18 +41,31 @@ function set_reactant_abi( if length(argtypes) != 1 @static if VERSION < v"1.11.0-" return CallMeta(Union{}, Effects(), NoCallInfo()) - else + elseif VERSION < v"1.12.0-" return CallMeta(Union{}, Union{}, Effects(), NoCallInfo()) + else + return Core.Compiler.Future{Core.Compiler.CallMeta}( + CallMeta(Union{}, Union{}, Effects(), NoCallInfo()) + ) end end @static if VERSION < v"1.11.0-" return CallMeta( Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure() ) - else + elseif VERSION < v"1.12.0-" return CallMeta( Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure() ) + else + return Core.Compiler.Future{Core.Compiler.CallMeta}( + CallMeta( + Core.Const(true), + Union{}, + Core.Compiler.EFFECTS_TOTAL, + MethodResultPure(), + ), + ) end end diff --git a/src/utils.jl b/src/utils.jl index eb8e2b3701..8688603ff1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -369,7 +369,11 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error) end end if Meta.isexpr(inst, :invoke) - omi = inst.args[1]::Core.MethodInstance + omi = if inst.args[1] isa Core.MethodInstance + inst.args[1] + else + (inst.args[1]::Core.CodeInstance).def + end sig = omi.specTypes ft = sig.parameters[1] argsig = sig.parameters[2:end] @@ -518,22 +522,42 @@ function make_oc_ref( if Base.isassigned(oc_captures) return oc_captures[] else - ores = ccall( - :jl_new_opaque_closure_from_code_info, - Any, - (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), - sig, - rt, - rt, - @__MODULE__, - src, - 0, - nothing, - nargs, - isva, - f, - true, - )::Core.OpaqueClosure + ores = @static if VERSION < v"1.11" + ccall( + :jl_new_opaque_closure_from_code_info, + Any, + (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), + sig, + rt, + rt, + @__MODULE__, + src, + 0, + nothing, + nargs, + isva, + f, + true, + )::Core.OpaqueClosure + else + ccall( + :jl_new_opaque_closure_from_code_info, + Any, + (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint, Cint), + sig, # jl_tupletype_t *argt + rt, # jl_value_t *rt_lb + rt, # jl_value_t *rt_ub + @__MODULE__, # jl_module_t *mod + src, # jl_code_info_t *ci + 0, # int lineno + nothing, # jl_value_t *file + nargs, # int nargs + isva, # int isva + f, # jl_value_t *env + true, # int do_compile + true, # int isinferred + )::Core.OpaqueClosure + end oc_captures[] = ores return ores end @@ -725,7 +749,9 @@ function call_with_reactant_generator( src.slotnames = fill(:none, length(ir.argtypes) + 1) src.slotflags = fill(zero(UInt8), length(ir.argtypes)) src.slottypes = copy(ir.argtypes) - src.rettype = rt + @static if VERSION < v"1.12.0-" + src.rettype = rt + end src = CC.ir_to_codeinf!(src, ir) if DEBUG_INTERP[] @@ -747,6 +773,12 @@ function call_with_reactant_generator( # and the REDUB_ARGUMENTS_NAME tuple of input arguments code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] code_info.slotflags = UInt8[0x00, 0x00] + + if VERSION >= v"1.12-" + code_info.nargs = length(code_info.slotnames) + code_info.isva = true + end + n_prepended_slots = 2 overdub_args_slot = Core.SlotNumber(n_prepended_slots) @@ -754,10 +786,18 @@ function call_with_reactant_generator( # into these overdubbed equivalents instead of updating `code_info` in-place. Then, at # the end of the pass, we'll reset `code_info` fields accordingly. overdubbed_code = Any[] - overdubbed_codelocs = Int32[] + + overdubbed_codelocs = @static if isdefined(Core, :DebugInfo) + nothing + else + Int32[] + end + function push_inst!(inst) push!(overdubbed_code, inst) - push!(overdubbed_codelocs, code_info.codelocs[1]) + @static if !isdefined(Core, :DebugInfo) + push!(overdubbed_codelocs, code_info.codelocs[1]) + end return Core.SSAValue(length(overdubbed_code)) end # Rewire the arguments from our tuple input of fn and args, to the corresponding calling convention @@ -781,6 +821,11 @@ function call_with_reactant_generator( iter_args = min(n_actual_args, n_method_args - 1) end + if VERSION >= v"1.12-" + src.nargs = length(src.slottypes) + src.isva = false + end + for i in 1:iter_args actual_argument = Expr( :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset @@ -862,12 +907,9 @@ function call_with_reactant_generator( farg = nothing rep = Expr(:call, make_oc, dict, octup, rt, src, ocnargs, ocva, farg) push_inst!(rep) - Core.SSAValue(length(overdubbed_code)) end - push_inst!(Expr(:call, oc, fn_args[1:end]...)) - - ocres = Core.SSAValue(length(overdubbed_code)) + ocres = push_inst!(Expr(:call, oc, fn_args[1:end]...)) if DEBUG_INTERP[] push_inst!(Expr(:call, safe_print, "ocres", ocres)) @@ -882,7 +924,13 @@ function call_with_reactant_generator( end code_info.code = overdubbed_code - code_info.codelocs = overdubbed_codelocs + + @static if isdefined(Core, :DebugInfo) + code_info.debuginfo = Core.DebugInfo(:none) # Core.DebugInfoStream(overdubbed_codelocs), length(overdubbed_codelocs)) + else + code_info.codelocs = overdubbed_codelocs + end + code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code diff --git a/src/xla/PJRT/LoadedExecutable.jl b/src/xla/PJRT/LoadedExecutable.jl index 65aedbf6d9..02e884f6ae 100644 --- a/src/xla/PJRT/LoadedExecutable.jl +++ b/src/xla/PJRT/LoadedExecutable.jl @@ -105,7 +105,11 @@ function XLA.compile( end function execute_ir(N, M, n_outs, with_device::Bool, nmesh_ids::Int64) - ptr = sizeof(Int) == sizeof(Int64) ? "i64" : "i32" + ptr = @static if VERSION < v"1.12" + sizeof(Int) == sizeof(Int64) ? "i64" : "i32" + else + "ptr" + end cint = sizeof(Cint) == sizeof(Int64) ? "i64" : "i32" args = N > 0 ? ", [$N x $ptr] %inps, [$M x i8] %donated" : "" if with_device