Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Method overrides using a method overlay table. #122

Closed
wants to merge 2 commits into from

Conversation

maleadt
Copy link
Member

@maleadt maleadt commented Nov 30, 2020

Just prototyping. This overrides MethodInstance->CodeInstance lookups by hacking the cache lookup:

julia> foo() = 0
julia> kernel() = foo()
julia> @show kernel()
kernel() = 0

julia> bar() = 42
julia> GPUCompiler.CI_CACHE.overrides[foo] = [bar]
julia> native_code_llvm(kernel, Tuple{}; debuginfo=:none)
define dso_local i64 @julia_kernel_3647() local_unnamed_addr {
top:
  ret i64 42
}

Not ideal because the cache isn't really meant for this, I think. But it works. I'd rather specialize some functionality from abstract interpretation, but that doesn't seem to work with the IR we emit:

function Core.Compiler.abstract_call_gf_by_type(interp::GPUInterpreter,
                                                @nospecialize(f),
                                                argtypes::Vector{Any},
                                                @nospecialize(atype),
                                                sv::InferenceState,
                                                max_methods::Int = InferenceParams(interp).MAX_METHODS)
    tt = argtypes[2:end]
    if haskey(CI_CACHE.overrides, f)
        for new_f in CI_CACHE.overrides[f]
            hasmethod(new_f, tt) || continue
            @safe_info "Override call to $f with $new_f"
            f = new_f
            argtypes[1] = Core.Compiler.Const(f)
            atype = Tuple{typeof(f), tt...}
        end
    end
    return invoke(Core.Compiler.abstract_call_gf_by_type,
                    Tuple{AbstractInterpreter, typeof(f), typeof(argtypes),
                          typeof(atype), typeof(sv), typeof(max_methods)},
                    interp, f, argtypes, atype, sv, max_methods)
end

function Core.Compiler.transform_result_for_cache(interp::GPUInterpreter, linfo::MethodInstance,
                                    @nospecialize(inferred_result))
    @safe_info "prepare for cache" linfo inferred_result
    invoke(Core.Compiler.transform_result_for_cache, Tuple{AbstractInterpreter, typeof(linfo), typeof(inferred_result)}, interp, linfo, inferred_result)
end

We successfully convince inference to use bar, but the IR still contains references to foo, resulting in bad code:

[ Info: Override call to foo with bar
┌ Info: prepare for cache
│   linfo = MethodInstance for bar()
└   inferred_result = Core.Const(42)
┌ Info: prepare for cache
│   linfo = MethodInstance for kernel()
│   inferred_result =
│    CodeInfo(
│        @ /home/tim/Julia/pkg/GPUCompiler/wip.jl:7 within `kernel'
│    1 ─ %1 = Main.foo::typeof(foo)
│    │   %2 = (isa)(%1, typeof(bar))::Bool
│    └──      goto #3 if not %2
│    2 ─      goto #4
│    3 ─ %5 = Main.foo()::Int64
│    └──      goto #4
│    4 ┄ %7 = φ (#2 => 42, #3 => %5)::Int64
│    └──      return %7
└    )

define dso_local i64 @julia_kernel_2753() local_unnamed_addr {
top:
  %0 = call nonnull {}* @jl_apply_generic({}* inttoptr (i64 140004316834848 to {}*), {}** null, i32 0)
  %1 = bitcast {}* %0 to i64*
  %2 = load i64, i64* %1, align 8
  ret i64 %2
}

I had hoped that we would not need to rewrite the code when using the interpreter.

@maleadt
Copy link
Member Author

maleadt commented Nov 30, 2020

Even cleaner using a MethodTableView:

struct OverlayMethodTable <: Core.Compiler.MethodTableView
    inner::Core.Compiler.MethodTableView
end

Core.Compiler.method_table(interp::GPUInterpreter, sv::InferenceState) = OverlayMethodTable(sv.method_table)

function Core.Compiler.findall(@nospecialize(sig::Type{<:Tuple}), table::OverlayMethodTable; limit::Int=typemax(Int))
    ft = first(sig.parameters)
    tt = Tuple{sig.parameters[2:end]...}
    if haskey(CI_CACHE.overrides, ft.instance)
        for f in CI_CACHE.overrides[ft.instance]
            hasmethod(f, tt) || continue
            sig = Tuple{typeof(f), tt.parameters...}
        end
    end
    Core.Compiler.findall(sig, table.inner; limit)
end

But same issue with the resulting IR.

@jpsamaroo
Copy link
Member

Any chance that this mechanism can be easily extended to dispatch on function signature? Or would that end up with a lot of the same issues that Cassette has?

@maleadt
Copy link
Member Author

maleadt commented Nov 30, 2020

Yeah, sure. The method table query has the full call signature, so we can do with that what we want. What kind of overrides did you have in mind, and how should it work? In the case of, e.g., sin -> CUDAnative.sin we want to override if possible, but retain other methods (that, for example, work with Duals or complex numbers).

@jpsamaroo
Copy link
Member

Pretty much what you show with sin, often we just want to override the "lowest-level" method that makes sense, but still allow higher-level methods to continue working. Another example would be Base.throw: we'd want to be able to only override specific signatures that we know we can handle cleanly, and let the rest fall through (and possibly hit lower_throw!).

@jpsamaroo
Copy link
Member

Usually we should want to always do method overriding, so maybe we can provide a hook like GPUCompiler.override(::typeof(foo), Tuple{#=arg types=#}) = bar which we'll call to get the new MI to replace with?

@jpsamaroo
Copy link
Member

And of course this should dispatch on the current job to allow this to be target-specific...

@maleadt
Copy link
Member Author

maleadt commented Nov 30, 2020

Yeah of course, this is very early development.

@codecov
Copy link

codecov bot commented Feb 11, 2021

Codecov Report

Merging #122 (06daec6) into master (cd8e62e) will decrease coverage by 12.72%.
The diff coverage is 52.17%.

Impacted file tree graph

@@             Coverage Diff             @@
##           master     #122       +/-   ##
===========================================
- Coverage   82.94%   70.21%   -12.73%     
===========================================
  Files          22       22               
  Lines        1589     1558       -31     
===========================================
- Hits         1318     1094      -224     
- Misses        271      464      +193     
Impacted Files Coverage Δ
src/jlgen.jl 64.84% <52.17%> (-13.73%) ⬇️
src/execution.jl 0.00% <0.00%> (-100.00%) ⬇️
src/cache.jl 0.00% <0.00%> (-86.80%) ⬇️
src/reflection.jl 11.34% <0.00%> (-59.57%) ⬇️
src/driver.jl 69.91% <0.00%> (-24.03%) ⬇️
src/ptx.jl 89.20% <0.00%> (-8.64%) ⬇️
src/mcgen.jl 89.18% <0.00%> (-8.11%) ⬇️
src/spirv.jl 85.43% <0.00%> (-7.77%) ⬇️
src/irgen.jl 86.60% <0.00%> (-5.15%) ⬇️
... and 8 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update cd8e62e...0ff7e14. Read the comment docs.

@maleadt maleadt changed the title Try out method overrides. Method overrides using a method overlay table. Feb 12, 2021
@maleadt
Copy link
Member Author

maleadt commented Mar 3, 2021

Now part of #151.

@maleadt maleadt closed this Mar 3, 2021
@maleadt maleadt deleted the tb/method_overlay branch March 9, 2022 10:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants