diff --git a/src/utils.jl b/src/utils.jl index b8eb028494..46b17f40ed 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -259,6 +259,29 @@ end const DEBUG_INTERP = Ref(false) +# Rewrite type unstable calls to recurse into call_with_reactant to ensure +# they continue to use our interpreter. Reset the derived return type +# to Any if our interpreter would change the return type of any result. +# Also rewrite invoke (type stable call) to be :call, since otherwise apparently +# screws up type inference after this (TODO this should be fixed). +function rewrite_insts!(ir, interp) + any_changed = false + for (i, inst) in enumerate(ir.stmts) + @static if VERSION < v"1.11" + changed, next = rewrite_inst(inst[:inst], ir, interp) + Core.Compiler.setindex!(ir.stmts[i], next, :inst) + else + changed, next = rewrite_inst(inst[:stmt], ir, interp) + Core.Compiler.setindex!(ir.stmts[i], next, :stmt) + end + if changed + any_changed = true + Core.Compiler.setindex!(ir.stmts[i], Any, :type) + end + end + return ir, any_changed +end + # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: # 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance @@ -320,72 +343,28 @@ function call_with_reactant_generator( match.spec_types, match.sparams, ) - - result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp)) - frame = Core.Compiler.InferenceState(result, VERSION < v"1.11-" ? :local : :no, interp) #=cache_mode=# - @assert frame !== nothing - Core.Compiler.typeinf(interp, frame) - @static if VERSION >= v"1.11" - # `typeinf` doesn't update the cfg. We need to do it manually. - # frame.cfg = Core.Compiler.compute_basic_blocks(frame.src.code) - end - @assert Core.Compiler.is_inferred(frame) - - method = match.method - - # The original julia code (on 1.11+) has the potential constprop, for now - # we assume this outermost function does not constprop, for ease. - #if Core.Compiler.result_is_constabi(interp, frame.result) - # rt = frame.result.result::Core.Compiler.Const - # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val) - #else - # - opt = Core.Compiler.OptimizationState(frame, interp) - - if DEBUG_INTERP[] - safe_print("opt.src", opt.src) - end - - caller = frame.result - @static if VERSION < v"1.11-" - ir = Core.Compiler.run_passes(opt.src, opt, caller) + method = mi.def + + @static if VERSION < v"1.11" + # For older Julia versions, we vendor in some of the code to prevent + # having to build the MethodInstance twice. + result = CC.InferenceResult(mi, CC.typeinf_lattice(interp)) + frame = CC.InferenceState(result, :no, interp) + @assert !isnothing(frame) + CC.typeinf(interp, frame) + ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing) + rt = CC.widenconst(CC.ignorelimited(result.result)) else - ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller) - @static if VERSION < v"1.12-" - else - Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) - end - end - - if DEBUG_INTERP[] - safe_print("ir1", ir) - end - - # Rewrite type unstable calls to recurse into call_with_reactant to ensure - # they continue to use our interpreter. Reset the derived return type - # to Any if our interpreter would change the return type of any result. - # Also rewrite invoke (type stable call) to be :call, since otherwise apparently - # screws up type inference after this (TODO this should be fixed). - any_changed = false - if should_rewrite_ft(args[1]) && !is_reactant_method(mi) - for (i, inst) in enumerate(ir.stmts) - @static if VERSION < v"1.11" - changed, next = rewrite_inst(inst[:inst], ir, interp) - Core.Compiler.setindex!(ir.stmts[i], next, :inst) - else - changed, next = rewrite_inst(inst[:stmt], ir, interp) - Core.Compiler.setindex!(ir.stmts[i], next, :stmt) - end - if changed - any_changed = true - Core.Compiler.setindex!(ir.stmts[i], Any, :type) - end - end + ir, rt = CC.typeinf_ircode(interp, mi, nothing) end - Core.Compiler.finish(interp, opt, ir, caller) - - src = Core.Compiler.ir_to_codeinf!(opt) + ir, any_changed = rewrite_insts!(ir, interp) + src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ()) + src.slotnames = fill(:none, length(ir.argtypes) + 1) + src.slotflags = fill(zero(UInt8), length(ir.argtypes)) + src.slottypes = copy(ir.argtypes) + src.rettype = rt + src = CC.ir_to_codeinf!(src, ir) if DEBUG_INTERP[] safe_print("src", src) @@ -488,8 +467,6 @@ function call_with_reactant_generator( end end - rt = Base.Experimental.compute_ir_rettype(ir) - # ocva = method.isva ocva = false # method.isva