Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark/aggregate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ for backend in BACKENDS
end

open(joinpath(dirname(@__FILE__), "results", "combinedbenchmarks.json"), "w") do io
JSON3.pretty(io, JSON3.write(all_results))
return JSON3.pretty(io, JSON3.write(all_results))
end
2 changes: 1 addition & 1 deletion benchmark/runbenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ for (i, (k, v)) in enumerate(results)
end

open(joinpath(filepath, filename), "w") do io
JSON3.pretty(io, JSON3.write(standardized_results))
return JSON3.pretty(io, JSON3.write(standardized_results))
end

@info "Saved results to $(joinpath(filepath, filename))"
2 changes: 1 addition & 1 deletion deps/build_local.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ run(Cmd(Cmd(build_cmd_list); dir=source_dir))

# Discover built libraries
built_libs = filter(readdir(joinpath(source_dir, "bazel-bin"))) do file
endswith(file, "Extra.so") && startswith(file, "lib")
return endswith(file, "Extra.so") && startswith(file, "lib")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return outside of functions doesn't make much sense

end

lib_path = joinpath(source_dir, "bazel-bin", only(built_libs))
Expand Down
26 changes: 19 additions & 7 deletions ext/ReactantKernelAbstractionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,26 @@ function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsi
return nothing
end

Reactant.@reactant_overlay Base.@nospecializeinfer @noinline function (
obj::KA.Kernel{ReactantBackend}
)(
@nospecialize args...; ndrange=nothing, workgroupsize=nothing
)
return Reactant.call_with_reactant(
Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args...
@static if VERSION < v"1.12-"
Reactant.@reactant_overlay Base.@nospecializeinfer @noinline function (
obj::KA.Kernel{ReactantBackend}
)(
@nospecialize args...; ndrange=nothing, workgroupsize=nothing
)
return Reactant.call_with_reactant(
Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args...
)
end
else
Reactant.@reactant_overlay function (obj::KA.Kernel{ReactantBackend})(
args...; ndrange=nothing, workgroupsize=nothing
)
Base.@_noinline_meta
Base.@_nospecializeinfer_meta
return Reactant.call_with_reactant(
Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args...
)
end
end

end
17 changes: 15 additions & 2 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,31 @@ function set_reactant_abi(
if length(argtypes) != 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
else
elseif VERSION < v"1.12.0-"
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
else
return Core.Compiler.Future{Core.Compiler.CallMeta}(
CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
)
end
end
@static if VERSION < v"1.11.0-"
return CallMeta(
Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure()
)
else
elseif VERSION < v"1.12.0-"
return CallMeta(
Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure()
)
else
return Core.Compiler.Future{Core.Compiler.CallMeta}(
CallMeta(
Core.Const(true),
Union{},
Core.Compiler.EFFECTS_TOTAL,
MethodResultPure(),
),
)
end
end

Expand Down
98 changes: 73 additions & 25 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,11 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
end
end
if Meta.isexpr(inst, :invoke)
omi = inst.args[1]::Core.MethodInstance
omi = if inst.args[1] isa Core.MethodInstance
inst.args[1]
else
(inst.args[1]::Core.CodeInstance).def
end
sig = omi.specTypes
ft = sig.parameters[1]
argsig = sig.parameters[2:end]
Expand Down Expand Up @@ -518,22 +522,42 @@ function make_oc_ref(
if Base.isassigned(oc_captures)
return oc_captures[]
else
ores = ccall(
:jl_new_opaque_closure_from_code_info,
Any,
(Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint),
sig,
rt,
rt,
@__MODULE__,
src,
0,
nothing,
nargs,
isva,
f,
true,
)::Core.OpaqueClosure
ores = @static if VERSION < v"1.11"
ccall(
:jl_new_opaque_closure_from_code_info,
Any,
(Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint),
sig,
rt,
rt,
@__MODULE__,
src,
0,
nothing,
nargs,
isva,
f,
true,
)::Core.OpaqueClosure
else
ccall(
:jl_new_opaque_closure_from_code_info,
Any,
(Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint, Cint),
sig, # jl_tupletype_t *argt
rt, # jl_value_t *rt_lb
rt, # jl_value_t *rt_ub
@__MODULE__, # jl_module_t *mod
src, # jl_code_info_t *ci
0, # int lineno
nothing, # jl_value_t *file
nargs, # int nargs
isva, # int isva
f, # jl_value_t *env
true, # int do_compile
true, # int isinferred
)::Core.OpaqueClosure
end
oc_captures[] = ores
return ores
end
Expand Down Expand Up @@ -725,7 +749,9 @@ function call_with_reactant_generator(
src.slotnames = fill(:none, length(ir.argtypes) + 1)
src.slotflags = fill(zero(UInt8), length(ir.argtypes))
src.slottypes = copy(ir.argtypes)
src.rettype = rt
@static if VERSION < v"1.12.0-"
src.rettype = rt
end
src = CC.ir_to_codeinf!(src, ir)

if DEBUG_INTERP[]
Expand All @@ -747,17 +773,31 @@ function call_with_reactant_generator(
# and the REDUB_ARGUMENTS_NAME tuple of input arguments
code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME]
code_info.slotflags = UInt8[0x00, 0x00]

if VERSION >= v"1.12-"
code_info.nargs = length(code_info.slotnames)
code_info.isva = true
end

n_prepended_slots = 2
overdub_args_slot = Core.SlotNumber(n_prepended_slots)

# For the sake of convenience, the rest of this pass will translate `code_info`'s fields
# into these overdubbed equivalents instead of updating `code_info` in-place. Then, at
# the end of the pass, we'll reset `code_info` fields accordingly.
overdubbed_code = Any[]
overdubbed_codelocs = Int32[]

overdubbed_codelocs = @static if isdefined(Core, :DebugInfo)
nothing
else
Int32[]
end

function push_inst!(inst)
push!(overdubbed_code, inst)
push!(overdubbed_codelocs, code_info.codelocs[1])
@static if !isdefined(Core, :DebugInfo)
push!(overdubbed_codelocs, code_info.codelocs[1])
end
return Core.SSAValue(length(overdubbed_code))
end
# Rewire the arguments from our tuple input of fn and args, to the corresponding calling convention
Expand All @@ -781,6 +821,11 @@ function call_with_reactant_generator(
iter_args = min(n_actual_args, n_method_args - 1)
end

if VERSION >= v"1.12-"
src.nargs = length(src.slottypes)
src.isva = false
end

for i in 1:iter_args
actual_argument = Expr(
:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset
Expand Down Expand Up @@ -862,12 +907,9 @@ function call_with_reactant_generator(
farg = nothing
rep = Expr(:call, make_oc, dict, octup, rt, src, ocnargs, ocva, farg)
push_inst!(rep)
Core.SSAValue(length(overdubbed_code))
end

push_inst!(Expr(:call, oc, fn_args[1:end]...))

ocres = Core.SSAValue(length(overdubbed_code))
ocres = push_inst!(Expr(:call, oc, fn_args[1:end]...))

if DEBUG_INTERP[]
push_inst!(Expr(:call, safe_print, "ocres", ocres))
Expand All @@ -882,7 +924,13 @@ function call_with_reactant_generator(
end

code_info.code = overdubbed_code
code_info.codelocs = overdubbed_codelocs

@static if isdefined(Core, :DebugInfo)
code_info.debuginfo = Core.DebugInfo(:none) # Core.DebugInfoStream(overdubbed_codelocs), length(overdubbed_codelocs))
else
code_info.codelocs = overdubbed_codelocs
end

code_info.ssavaluetypes = length(overdubbed_code)
code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code

Expand Down
6 changes: 5 additions & 1 deletion src/xla/PJRT/LoadedExecutable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ function XLA.compile(
end

function execute_ir(N, M, n_outs, with_device::Bool, nmesh_ids::Int64)
ptr = sizeof(Int) == sizeof(Int64) ? "i64" : "i32"
ptr = @static if VERSION < v"1.12"
sizeof(Int) == sizeof(Int64) ? "i64" : "i32"
else
"ptr"
end
cint = sizeof(Cint) == sizeof(Int64) ? "i64" : "i32"
args = N > 0 ? ", [$N x $ptr] %inps, [$M x i8] %donated" : ""
if with_device
Expand Down
Loading