Skip to content

Commit

Permalink
add MultilineFusion example
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Feb 13, 2024
1 parent 0a89c1b commit a7b68cf
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 0 deletions.
3 changes: 3 additions & 0 deletions test/compiler/plugins/MultilineFusion/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
name = "MultilineFusion"
uuid = "bb4966f2-fd13-4cc8-856b-cab8c274a504"
version = "0.1.0"
204 changes: 204 additions & 0 deletions test/compiler/plugins/MultilineFusion/src/MultilineFusion.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
module MultilineFusion

export multiline_fusion

function multiline_fusion(f, args...)
Base.invoke_within(MultilineFusion(), f, args...)
end

const CC = Core.Compiler
import .CC: SSAValue, GlobalRef

const COMPILER_WORLD = Ref{UInt}(0)
function __init__()
COMPILER_WORLD[] = Base.get_world_counter()
end

struct MLFCompiler <: CC.AbstractCompiler end
CC.compiler_world(::MLFCompiler) = COMPILER_WORLD[]
CC.abstract_interpreter(compiler::MLFCompiler, world::UInt) =
MLFInterp(compiler; world)

struct MLFInterp <: CC.AbstractInterpreter
compiler::MLFCompiler
world::UInt
inf_params::CC.InferenceParams
opt_params::CC.OptimizationParams
inf_cache::Vector{CC.InferenceResult}
function MLFInterp(compiler::MLFCompiler;
world::UInt = Base.get_world_counter(),
inf_params::CC.InferenceParams = CC.InferenceParams(),
opt_params::CC.OptimizationParams = CC.OptimizationParams(),
inf_cache::Vector{CC.InferenceResult} = CC.InferenceResult[])
return new(compiler, world, inf_params, opt_params, inf_cache)
end
end

CC.InferenceParams(interp::MLFInterp) = interp.inf_params
CC.OptimizationParams(interp::MLFInterp) = interp.opt_params
CC.get_inference_world(interp::MLFInterp) = interp.world
CC.get_inference_cache(interp::MLFInterp) = interp.inf_cache
CC.cache_owner(interp::MLFInterp) = interp.compiler

import Core.Compiler: retrieve_code_info, maybe_validate_code
# Replace usage sited of `retrieve_code_info`, OptimizationState is one such, but in all interesting use-cases
# it is derived from an InferenceState. There is a third one in `typeinf_ext` in case the module forbids inference.
function CC.InferenceState(result::CC.InferenceResult, cache_mode::UInt8, interp::MLFInterp)
world = CC.get_inference_world(interp)
src = retrieve_code_info(result.linfo, world)
src === nothing && return nothing
maybe_validate_code(result.linfo, src, "lowered")
src = transform(interp, result.linfo, src)
maybe_validate_code(result.linfo, src, "transformed")
return CC.InferenceState(result, src, cache_mode, interp)
end

function transform(interp, mi, src)
ci = copy(src)
transform!(mi, ci)
return ci
end

import .CC: userefs, UseRefIterator, UseRef

Base.iterate(useref::UseRefIterator, state...) = CC.iterate(useref, state...)
Base.getindex(useref::UseRef) = CC.getindex(useref)
Base.setindex!(useref::UseRef, x) = CC.setindex!(useref, x)

"""
is_ir_element(x, y, code::Vector)
Return `true` if `x === y` or if `x` is an `SSAValue` such that
`is_ir_element(code[x.id], y, code)` is `true`.
See also: [`replace_match!`](@ref), [`insert_statements!`](@ref)
"""
function is_ir_element(x, y, code::Vector)
result = false
while true # break by default
if x === y #
result = true
break
elseif isa(x, Core.SSAValue)
x = code[x.id]
else
break
end
end
return result
end

mutable struct MaterializeAnalysis
def::SSAValue
forward::Any
removable::Bool
MaterializeAnalysis(def, forward) = new(def, forward, true)
end

"""
transform!
The goal of this transform is to perform multi-line fusion of broadcast statements.
This needs to happen before inference since the broadcast machinery depends on inference
to perform the actual fusion.
```
C = A .* B
D = C .+ A
```
Is lowered to
```
1 ─ %1 = Base.broadcasted(Main.:*, A, B)
│ C = Base.materialize(%1)
│ %3 = C
│ %4 = Base.broadcasted(Main.:+, %3, A)
│ %5 = Base.materialize(%4)
│ D = %5
└── return %5
```
We find all calls to `materialize` and find all their uses.
If the materialize statment is used inside a `broadcasted`

Check warning on line 122 in test/compiler/plugins/MultilineFusion/src/MultilineFusion.jl

View workflow job for this annotation

GitHub Actions / Check for new typos

perhaps "statment" should be "statement".
statment we forward the argument to the `broadcasted`.

Check warning on line 123 in test/compiler/plugins/MultilineFusion/src/MultilineFusion.jl

View workflow job for this annotation

GitHub Actions / Check for new typos

perhaps "statment" should be "statement".
If it is only used for broadcasted statements we delete
the call to `Base.materialize`.
```
1 ─ %1 = Base.broadcasted(Main.:*, A, B)
│ C = nothing
│ %3 = C
│ %4 = Base.broadcasted(Main.:+, %1, A)
│ %5 = Base.materialize(%4)
│ D = %5
└── return %5
```
"""
function transform!(mi, src)
materialize = Base.IdDict{Any, MaterializeAnalysis}()
for (i, x) in enumerate(src.code)
isassign = Base.Meta.isexpr(x, :(=))
stmt = isassign ? x.args[2] : x
if Base.Meta.isexpr(stmt, :call)
if is_ir_element(stmt.args[1], GlobalRef(Base, :materialize), src.code)
manalysis = MaterializeAnalysis(SSAValue(i), stmt.args[2])
if isassign
materialize[x.args[1]] = manalysis
end
materialize[SSAValue(i)] = manalysis
continue
end
if is_ir_element(stmt.args[1], GlobalRef(Base, :broadcasted), src.code)
for op in userefs(stmt)
use = op[]
manalysis = get(materialize, use, nothing)
if manalysis !== nothing
op[] = manalysis.forward
end
end
end
elseif isassign
if haskey(materialize, x.args[2])
materialize[x.args[1]] = materialize[x.args[2]]
end
elseif x isa CC.SlotNumber || x isa CC.SSAValue
if haskey(materialize, x)
materialize[SSAValue(i)] = materialize[x]
end
else
for op in userefs(stmt)
use = op[]
manalysis = get(materialize, use, nothing)
if manalysis !== nothing
manalysis.removable = false
end
end
end
end

for manalysis in unique(values(materialize))
if manalysis.removable
x = src.code[manalysis.def.id]
if Base.Meta.isexpr(x, :(=))
x.args[2] = nothing
else
src.code[manalysis.def.id] = nothing
end
end
end
return nothing
end


# precompilation
precompile(CC.abstract_interpreter, (MLFCompiler, UInt))
precompile(CC.typeinf_ext_toplevel, (MLFInterp, CC.MethodInstance))

COMPILER_WORLD[] = Base.get_world_counter()
# Insert code execution statements here

# end precompile
COMPILER_WORLD[] = 0

end # module MultilineFusion

0 comments on commit a7b68cf

Please sign in to comment.