Skip to content

Commit

Permalink
Reapply "run inference in generated function (#76)" (#81)
Browse files Browse the repository at this point in the history
This reverts commit 1b0686f.
  • Loading branch information
Krastanov committed Feb 12, 2024
1 parent dc9f2fc commit 935cb10
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 56 deletions.
4 changes: 0 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
# News

## v0.6.7 - 2024-01-02

- Fix stack overflow errors by reverting the changes introduced in v0.6.6.

## v0.6.6 - 2023-10-08

- Significantly improved performance on Julia 1.10 and newer by more precise inference of slot types.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ license = "MIT"
desc = "C# sharp style generators a.k.a. semi-coroutines for Julia."
authors = ["Ben Lauwens <ben.lauwens@gmail.com>"]
repo = "https://github.com/BenLauwens/ResumableFunctions.jl.git"
version = "0.6.7"
version = "0.6.6"

[deps]
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand Down
55 changes: 35 additions & 20 deletions src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,18 @@ macro resumable(expr::Expr)
func_def[:body] = postwalk(transform_arg_yieldfrom, func_def[:body])
func_def[:body] = postwalk(transform_yieldfrom, func_def[:body])
func_def[:body] = postwalk(x->transform_for(x, ui8), func_def[:body])
slots = get_slots(copy(func_def), arg_dict, __module__)
inferfn, slots = get_slots(copy(func_def), arg_dict, __module__)
type_name = gensym(Symbol(func_def[:name], :_FSMI))
constr_def = copy(func_def)
if isempty(params)
struct_name = :($type_name <: ResumableFunctions.FiniteStateMachineIterator{$rtype})
slot_T = [gensym(s) for s in keys(slots)]
slot_T_sub = [:($k <: $v) for (k, v) in zip(slot_T, values(slots))]
struct_name = :($type_name{$(func_def[:whereparams]...), $(slot_T_sub...)} <: ResumableFunctions.FiniteStateMachineIterator{$rtype})
constr_def[:whereparams] = (func_def[:whereparams]..., slot_T_sub...)
# if there are no where or slot type parameters, we need to use the bare type
if isempty(params) && isempty(slot_T)
constr_def[:name] = :($type_name)
else
struct_name = :($type_name{$(func_def[:whereparams]...)} <: ResumableFunctions.FiniteStateMachineIterator{$rtype})
constr_def[:name] = :($type_name{$(params...)})
constr_def[:name] = :($type_name{$(params...), $(slot_T...)})
end
constr_def[:args] = tuple()
constr_def[:kwargs] = tuple()
Expand All @@ -57,32 +60,44 @@ macro resumable(expr::Expr)
fsmi._state = 0x00
fsmi
end
# the bare/fallback version of the constructor supplies default slot type parameters
# we only need to define this if there there are actually slot defaults to be filled
if !isempty(slot_T)
bareconstr_def = copy(constr_def)
if isempty(params)
bareconstr_def[:name] = :($type_name)
else
bareconstr_def[:name] = :($type_name{$(params...)})
end
bareconstr_def[:whereparams] = func_def[:whereparams]
bareconstr_def[:body] = :($(bareconstr_def[:name]){$(values(slots)...)}())
bareconst_expr = combinedef(bareconstr_def) |> flatten
else
bareconst_expr = nothing
end
constr_expr = combinedef(constr_def) |> flatten
type_expr = :(
mutable struct $struct_name
_state :: UInt8
$((:($slotname :: $slottype) for (slotname, slottype) in slots)...)
$((:($slotname :: $slottype) for (slotname, slottype) in zip(keys(slots), slot_T))...)
$(constr_expr)
$(bareconst_expr)
end
)
@debug type_expr|>MacroTools.striplines
call_def = copy(func_def)
call_def[:rtype] = nothing
if isempty(params)
call_def[:rtype] = nothing
call_def[:body] = quote
fsmi = $type_name()
$((arg !== Symbol("_") ? :(fsmi.$arg = $arg) : nothing for arg in args)...)
$((:(fsmi.$arg = $arg) for arg in kwargs)...)
fsmi
end
fsmi_name = type_name
else
call_def[:rtype] = nothing
call_def[:body] = quote
fsmi = $type_name{$(params...)}()
$((arg !== Symbol("_") ? :(fsmi.$arg = $arg) : nothing for arg in args)...)
$((:(fsmi.$arg = $arg) for arg in kwargs)...)
fsmi
end
fsmi_name = :($type_name{$(params...)})
end
fwd_args, fwd_kwargs = forward_args(call_def)
call_def[:body] = quote
fsmi = ResumableFunctions.typed_fsmi($fsmi_name, $inferfn, $(fwd_args...), $(fwd_kwargs...))
$((arg !== Symbol("_") ? :(fsmi.$arg = $arg) : nothing for arg in args)...)
$((:(fsmi.$arg = $arg) for arg in kwargs)...)
fsmi
end
call_expr = combinedef(call_def) |> flatten
@debug call_expr|>MacroTools.striplines
Expand Down
10 changes: 6 additions & 4 deletions src/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ end
"""
Function that replaces a variable `x` in an expression by `_fsmi.x` where `x` is a known slot.
"""
function transform_slots(expr, symbols::Base.KeySet{Symbol, Dict{Symbol,Any}})
function transform_slots(expr, symbols)
expr isa Expr || return expr
expr.head === :let && return transform_slots_let(expr, symbols)
for i in 1:length(expr.args)
Expand All @@ -114,7 +114,7 @@ end
"""
Function that handles `let` block
"""
function transform_slots_let(expr::Expr, symbols::Base.KeySet{Symbol, Dict{Symbol,Any}})
function transform_slots_let(expr::Expr, symbols)
@capture(expr, let vars_; body_ end)
locals = Set{Symbol}()
(isa(vars, Expr) && vars.head==:(=)) || error("@resumable currently supports only single variable declarations in let blocks, i.e. only let blocks exactly of the form `let i=j; ...; end`. If you need multiple variables, please submit an issue on the issue tracker and consider contributing a patch.")
Expand Down Expand Up @@ -247,14 +247,16 @@ end
"""
Function that replaces a `@yield ret` or `@yield` statement with
```julia
return ret
Base.inferencebarrier(ret)
```
This version is used for inference only.
It makes sure that `val = @yield ret` is inferred as `Any` rather than `typeof(ret)`.
"""
function transform_yield(expr)
_is_yield(expr) || return expr
ret = length(expr.args) > 2 ? expr.args[3:end] : [nothing]
quote
$(ret...)
Base.inferencebarrier($(ret...))
end
end

Expand Down
137 changes: 114 additions & 23 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,42 +28,71 @@ function get_args(func_def::Dict)
arg_list, kwarg_list, arg_dict
end

"""
Takes a function definition and returns the expressions needed to forward the arguments to an inner function.
For example `function foo(a, ::Int, c...; x, y=1, z...)` will
1. moodify the function to `gensym()` nameless arguments
2. return `(:a, gensym(), :(c...)), (:x, :y, :(z...)))`
"""
function forward_args(func_def)
args = []
map!(func_def[:args], func_def[:args]) do arg
name, type, splat, default = splitarg(arg)
name = something(name, gensym())
if splat
push!(args, :($name...))
else
push!(args, name)
end
combinearg(name, type, splat, default)
end
kwargs = []
for arg in func_def[:kwargs]
name, type, splat, default = splitarg(arg)
if splat
push!(kwargs, :($name...))
else
push!(kwargs, name)
end
end
args, kwargs
end

const unused = (Symbol("#temp#"), Symbol("_"), Symbol(""), Symbol("#unused#"), Symbol("#self#"))

"""
Function returning the slots of a function definition
"""
function get_slots(func_def::Dict, args::Dict{Symbol, Any}, mod::Module) :: Dict{Symbol, Any}
function get_slots(func_def::Dict, args::Dict{Symbol, Any}, mod::Module)
slots = Dict{Symbol, Any}()
func_def[:name] = gensym()
func_def[:args] = (func_def[:args]..., func_def[:kwargs]...)
func_def[:kwargs] = []
body = func_def[:body]
# replace yield with inference barrier
func_def[:body] = postwalk(transform_yield, func_def[:body])
# collect items to skip
nosaves = Set{Symbol}()
func_def[:body] = postwalk(x->transform_nosave(x, nosaves), func_def[:body])
# eval function
func_expr = combinedef(func_def) |> flatten
@eval(mod, @noinline $func_expr)
# get typed code
codeinfos = @eval(mod, code_typed($(func_def[:name]), Tuple; optimize=false))
# extract slot names and types
for codeinfo in codeinfos
for (name, type) in collect(zip(codeinfo.first.slotnames, codeinfo.first.slottypes))
name nosaves && (slots[name] = type)
name nosaves && name unused && (slots[name] = Union{type, get(slots, name, Union{})})
end
end
for (argname, argtype) in args
slots[argname] = argtype
end
# remove `catch exc` statements
postwalk(x->remove_catch_exc(x, slots), func_def[:body])
postwalk(x->make_arg_any(x, slots), body)
# set error branches to Any
for (key, val) in slots
if val === Union{}
slots[key] = Any
end
end
delete!(slots, Symbol("#temp#"))
delete!(slots, Symbol("_"))
delete!(slots, Symbol(""))
delete!(slots, Symbol("#unused#"))
delete!(slots, Symbol("#self#"))
slots
return func_def[:name], slots
end

"""
Expand All @@ -74,16 +103,6 @@ function remove_catch_exc(expr, slots::Dict{Symbol, Any})
expr
end

"""
Function changing the type of a slot `arg` of a `arg = @yield ret` or `arg = @yield` statement to `Any`.
"""
function make_arg_any(expr, slots::Dict{Symbol, Any})
@capture(expr, arg_ = ex_) || return expr
_is_yield(ex) || return expr
slots[arg] = Any
expr
end

struct IteratorReturn{T}
value :: T
IteratorReturn(value) = new{typeof(value)}(value)
Expand All @@ -107,3 +126,75 @@ end
isnothing(ret) && return IteratorReturn(nothing)
ret
end

# this is similar to code_typed but it considers the world age
function code_typed_by_type(@nospecialize(tt::Type);
optimize::Bool=true,
world::UInt=Base.get_world_counter(),
interp::Core.Compiler.AbstractInterpreter=Core.Compiler.NativeInterpreter(world))
tt = Base.to_tuple_type(tt)
# look up the method
match, valid_worlds = Core.Compiler.findsup(tt, Core.Compiler.InternalMethodTable(world))
# run inference, normally not allowed in generated functions
frame = Core.Compiler.typeinf_frame(interp, match.method, match.spec_types, match.sparams, optimize)
frame === nothing && error("inference failed")
valid_worlds = Core.Compiler.intersect(valid_worlds, frame.valid_worlds)
return frame.linfo, frame.src, valid_worlds
end

function fsmi_generator(world::UInt, source::LineNumberNode, passtype, fsmitype::Type{Type{T}}, fargtypes) where T
@nospecialize
# get typed code of the inference function evaluated in get_slots
# but this time with concrete argument types
tt = Base.to_tuple_type(fargtypes)
mi, ci, valid_worlds = try
code_typed_by_type(tt; world, optimize=false)
catch err # inference failed, return generic type
Core.println(err)
slots = fieldtypes(T)[2:end]
stub = Core.GeneratedFunctionStub(identity, Core.svec(:pass, :fsmi, :fargs), Core.svec())
if isempty(slots)
return stub(world, source, :(return $T()))
else
return stub(world, source, :(return $T{$(slots...)}()))
end
end
min_world = valid_worlds.min_world
max_world = valid_worlds.max_world
# extract slot types
cislots = Dict{Symbol, Any}()
for (name, type) in collect(zip(ci.slotnames, ci.slottypes))
# take care to widen types that are unstable or Const
type = Core.Compiler.widenconst(type)
cislots[name] = Union{type, get(cislots, name, Union{})}
end
slots = map(slot->get(cislots, slot, Any), fieldnames(T)[2:end])
# generate code to instantiate the concrete type
stub = Core.GeneratedFunctionStub(identity, Core.svec(:pass, :fsmi, :fargs), Core.svec())
if isempty(slots)
exprs = stub(world, source, :(return $T()))
else
exprs = stub(world, source, :(return $T{$(slots...)}()))
end
# lower codeinfo to pass world age and invalidation edges
ci = ccall(:jl_expand_and_resolve, Any, (Any, Any, Any), exprs, passtype.name.module, Core.svec())
ci.min_world = min_world
ci.max_world = max_world
ci.edges = Core.MethodInstance[mi]
return ci
end

# JuliaLang/julia#48611: world age is exposed to generated functions, and should be used
if VERSION >= v"1.10.0-DEV.873"
# This is like @generated, but it receives the world age of the caller
# which we need to do inference safely and correctly
@eval function typed_fsmi(fsmi, fargs...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, fsmi_generator))
end
else
# runtime fallback function that uses the fallback constructor with generic slot types
function typed_fsmi(fsmi::Type{T}, fargs...)::T where T
return T()
end
end
4 changes: 2 additions & 2 deletions test/test_jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ end
rep = report_package("ResumableFunctions";
report_pass=MayThrowIsOk(),
ignored_modules=(

Core.Compiler,
)
)
@show rep
@test length(JET.get_reports(rep)) <= 5
@test length(JET.get_reports(rep)) <= 8
@test_broken length(JET.get_reports(rep)) == 0
end
2 changes: 1 addition & 1 deletion test/test_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,5 +225,5 @@ end
end

@testset "test_unstable" begin
@test_broken collect(test_unstable(3)) == ["number 1", "number 2", "number 3"]
@test collect(test_unstable(3)) == ["number 1", "number 2", "number 3"]
end
2 changes: 1 addition & 1 deletion test/test_performance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ end
if VERSION >= v"1.10.0-DEV.873"
cs = cumsum(1:1000)
@allocated cs() # shake out the compilation overhead
@test_broken (@allocated cs())==0
@test (@allocated cs())==0
end

0 comments on commit 935cb10

Please sign in to comment.