Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "run inference in generated function (#76)" #81

Merged
merged 4 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# News

## v0.6.7 - 2024-01-02

- Fix stack overflow errors by reverting the changes introduced in v0.6.6.

## v0.6.6 - 2023-10-08

- Significantly improved performance on Julia 1.10 and newer by more precise inference of slot types.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ license = "MIT"
desc = "C# sharp style generators a.k.a. semi-coroutines for Julia."
authors = ["Ben Lauwens <ben.lauwens@gmail.com>"]
repo = "https://github.com/BenLauwens/ResumableFunctions.jl.git"
version = "0.6.6"
version = "0.6.7"

[deps]
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand Down
55 changes: 20 additions & 35 deletions src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,15 @@ macro resumable(expr::Expr)
func_def[:body] = postwalk(transform_arg_yieldfrom, func_def[:body])
func_def[:body] = postwalk(transform_yieldfrom, func_def[:body])
func_def[:body] = postwalk(x->transform_for(x, ui8), func_def[:body])
inferfn, slots = get_slots(copy(func_def), arg_dict, __module__)
slots = get_slots(copy(func_def), arg_dict, __module__)
type_name = gensym(Symbol(func_def[:name], :_FSMI))
constr_def = copy(func_def)
slot_T = [gensym(s) for s in keys(slots)]
slot_T_sub = [:($k <: $v) for (k, v) in zip(slot_T, values(slots))]
struct_name = :($type_name{$(func_def[:whereparams]...), $(slot_T_sub...)} <: ResumableFunctions.FiniteStateMachineIterator{$rtype})
constr_def[:whereparams] = (func_def[:whereparams]..., slot_T_sub...)
# if there are no where or slot type parameters, we need to use the bare type
if isempty(params) && isempty(slot_T)
if isempty(params)
struct_name = :($type_name <: ResumableFunctions.FiniteStateMachineIterator{$rtype})
constr_def[:name] = :($type_name)
else
constr_def[:name] = :($type_name{$(params...), $(slot_T...)})
struct_name = :($type_name{$(func_def[:whereparams]...)} <: ResumableFunctions.FiniteStateMachineIterator{$rtype})
constr_def[:name] = :($type_name{$(params...)})
end
constr_def[:args] = tuple()
constr_def[:kwargs] = tuple()
Expand All @@ -60,44 +57,32 @@ macro resumable(expr::Expr)
fsmi._state = 0x00
fsmi
end
# the bare/fallback version of the constructor supplies default slot type parameters
# we only need to define this if there there are actually slot defaults to be filled
if !isempty(slot_T)
bareconstr_def = copy(constr_def)
if isempty(params)
bareconstr_def[:name] = :($type_name)
else
bareconstr_def[:name] = :($type_name{$(params...)})
end
bareconstr_def[:whereparams] = func_def[:whereparams]
bareconstr_def[:body] = :($(bareconstr_def[:name]){$(values(slots)...)}())
bareconst_expr = combinedef(bareconstr_def) |> flatten
else
bareconst_expr = nothing
end
constr_expr = combinedef(constr_def) |> flatten
type_expr = :(
mutable struct $struct_name
_state :: UInt8
$((:($slotname :: $slottype) for (slotname, slottype) in zip(keys(slots), slot_T))...)
$((:($slotname :: $slottype) for (slotname, slottype) in slots)...)
$(constr_expr)
$(bareconst_expr)
end
)
@debug type_expr|>MacroTools.striplines
call_def = copy(func_def)
call_def[:rtype] = nothing
if isempty(params)
fsmi_name = type_name
call_def[:rtype] = nothing
call_def[:body] = quote
fsmi = $type_name()
$((arg !== Symbol("_") ? :(fsmi.$arg = $arg) : nothing for arg in args)...)
$((:(fsmi.$arg = $arg) for arg in kwargs)...)
fsmi
end
else
fsmi_name = :($type_name{$(params...)})
end
fwd_args, fwd_kwargs = forward_args(call_def)
call_def[:body] = quote
fsmi = ResumableFunctions.typed_fsmi($fsmi_name, $inferfn, $(fwd_args...), $(fwd_kwargs...))
$((arg !== Symbol("_") ? :(fsmi.$arg = $arg) : nothing for arg in args)...)
$((:(fsmi.$arg = $arg) for arg in kwargs)...)
fsmi
call_def[:rtype] = nothing
call_def[:body] = quote
fsmi = $type_name{$(params...)}()
$((arg !== Symbol("_") ? :(fsmi.$arg = $arg) : nothing for arg in args)...)
$((:(fsmi.$arg = $arg) for arg in kwargs)...)
fsmi
end
end
call_expr = combinedef(call_def) |> flatten
@debug call_expr|>MacroTools.striplines
Expand Down
10 changes: 4 additions & 6 deletions src/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ end
"""
Function that replaces a variable `x` in an expression by `_fsmi.x` where `x` is a known slot.
"""
function transform_slots(expr, symbols)
function transform_slots(expr, symbols::Base.KeySet{Symbol, Dict{Symbol,Any}})
expr isa Expr || return expr
expr.head === :let && return transform_slots_let(expr, symbols)
for i in 1:length(expr.args)
Expand All @@ -114,7 +114,7 @@ end
"""
Function that handles `let` block
"""
function transform_slots_let(expr::Expr, symbols)
function transform_slots_let(expr::Expr, symbols::Base.KeySet{Symbol, Dict{Symbol,Any}})
@capture(expr, let vars_; body_ end)
locals = Set{Symbol}()
(isa(vars, Expr) && vars.head==:(=)) || error("@resumable currently supports only single variable declarations in let blocks, i.e. only let blocks exactly of the form `let i=j; ...; end`. If you need multiple variables, please submit an issue on the issue tracker and consider contributing a patch.")
Expand Down Expand Up @@ -247,16 +247,14 @@ end
"""
Function that replaces a `@yield ret` or `@yield` statement with
```julia
Base.inferencebarrier(ret)
return ret
```
This version is used for inference only.
It makes sure that `val = @yield ret` is inferred as `Any` rather than `typeof(ret)`.
"""
function transform_yield(expr)
_is_yield(expr) || return expr
ret = length(expr.args) > 2 ? expr.args[3:end] : [nothing]
quote
Base.inferencebarrier($(ret...))
$(ret...)
end
end

Expand Down
137 changes: 23 additions & 114 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,71 +28,42 @@ function get_args(func_def::Dict)
arg_list, kwarg_list, arg_dict
end

"""
Takes a function definition and returns the expressions needed to forward the arguments to an inner function.
For example `function foo(a, ::Int, c...; x, y=1, z...)` will
1. moodify the function to `gensym()` nameless arguments
2. return `(:a, gensym(), :(c...)), (:x, :y, :(z...)))`
"""
function forward_args(func_def)
args = []
map!(func_def[:args], func_def[:args]) do arg
name, type, splat, default = splitarg(arg)
name = something(name, gensym())
if splat
push!(args, :($name...))
else
push!(args, name)
end
combinearg(name, type, splat, default)
end
kwargs = []
for arg in func_def[:kwargs]
name, type, splat, default = splitarg(arg)
if splat
push!(kwargs, :($name...))
else
push!(kwargs, name)
end
end
args, kwargs
end

const unused = (Symbol("#temp#"), Symbol("_"), Symbol(""), Symbol("#unused#"), Symbol("#self#"))

"""
Function returning the slots of a function definition
"""
function get_slots(func_def::Dict, args::Dict{Symbol, Any}, mod::Module)
function get_slots(func_def::Dict, args::Dict{Symbol, Any}, mod::Module) :: Dict{Symbol, Any}
slots = Dict{Symbol, Any}()
func_def[:name] = gensym()
func_def[:args] = (func_def[:args]..., func_def[:kwargs]...)
func_def[:kwargs] = []
# replace yield with inference barrier
body = func_def[:body]
func_def[:body] = postwalk(transform_yield, func_def[:body])
# collect items to skip
nosaves = Set{Symbol}()
func_def[:body] = postwalk(x->transform_nosave(x, nosaves), func_def[:body])
# eval function
func_expr = combinedef(func_def) |> flatten
@eval(mod, @noinline $func_expr)
# get typed code
codeinfos = @eval(mod, code_typed($(func_def[:name]), Tuple; optimize=false))
# extract slot names and types
for codeinfo in codeinfos
for (name, type) in collect(zip(codeinfo.first.slotnames, codeinfo.first.slottypes))
name ∉ nosaves && name ∉ unused && (slots[name] = Union{type, get(slots, name, Union{})})
name ∉ nosaves && (slots[name] = type)
end
end
# remove `catch exc` statements
for (argname, argtype) in args
slots[argname] = argtype
end
postwalk(x->remove_catch_exc(x, slots), func_def[:body])
# set error branches to Any
postwalk(x->make_arg_any(x, slots), body)
for (key, val) in slots
if val === Union{}
slots[key] = Any
end
end
return func_def[:name], slots
delete!(slots, Symbol("#temp#"))
delete!(slots, Symbol("_"))
delete!(slots, Symbol(""))
delete!(slots, Symbol("#unused#"))
delete!(slots, Symbol("#self#"))
slots
end

"""
Expand All @@ -103,6 +74,16 @@ function remove_catch_exc(expr, slots::Dict{Symbol, Any})
expr
end

"""
Function changing the type of a slot `arg` of a `arg = @yield ret` or `arg = @yield` statement to `Any`.
"""
function make_arg_any(expr, slots::Dict{Symbol, Any})
@capture(expr, arg_ = ex_) || return expr
_is_yield(ex) || return expr
slots[arg] = Any
expr
end

struct IteratorReturn{T}
value :: T
IteratorReturn(value) = new{typeof(value)}(value)
Expand All @@ -126,75 +107,3 @@ end
isnothing(ret) && return IteratorReturn(nothing)
ret
end

# this is similar to code_typed but it considers the world age
function code_typed_by_type(@nospecialize(tt::Type);
optimize::Bool=true,
world::UInt=Base.get_world_counter(),
interp::Core.Compiler.AbstractInterpreter=Core.Compiler.NativeInterpreter(world))
tt = Base.to_tuple_type(tt)
# look up the method
match, valid_worlds = Core.Compiler.findsup(tt, Core.Compiler.InternalMethodTable(world))
# run inference, normally not allowed in generated functions
frame = Core.Compiler.typeinf_frame(interp, match.method, match.spec_types, match.sparams, optimize)
frame === nothing && error("inference failed")
valid_worlds = Core.Compiler.intersect(valid_worlds, frame.valid_worlds)
return frame.linfo, frame.src, valid_worlds
end

function fsmi_generator(world::UInt, source::LineNumberNode, passtype, fsmitype::Type{Type{T}}, fargtypes) where T
@nospecialize
# get typed code of the inference function evaluated in get_slots
# but this time with concrete argument types
tt = Base.to_tuple_type(fargtypes)
mi, ci, valid_worlds = try
code_typed_by_type(tt; world, optimize=false)
catch err # inference failed, return generic type
Core.println(err)
slots = fieldtypes(T)[2:end]
stub = Core.GeneratedFunctionStub(identity, Core.svec(:pass, :fsmi, :fargs), Core.svec())
if isempty(slots)
return stub(world, source, :(return $T()))
else
return stub(world, source, :(return $T{$(slots...)}()))
end
end
min_world = valid_worlds.min_world
max_world = valid_worlds.max_world
# extract slot types
cislots = Dict{Symbol, Any}()
for (name, type) in collect(zip(ci.slotnames, ci.slottypes))
# take care to widen types that are unstable or Const
type = Core.Compiler.widenconst(type)
cislots[name] = Union{type, get(cislots, name, Union{})}
end
slots = map(slot->get(cislots, slot, Any), fieldnames(T)[2:end])
# generate code to instantiate the concrete type
stub = Core.GeneratedFunctionStub(identity, Core.svec(:pass, :fsmi, :fargs), Core.svec())
if isempty(slots)
exprs = stub(world, source, :(return $T()))
else
exprs = stub(world, source, :(return $T{$(slots...)}()))
end
# lower codeinfo to pass world age and invalidation edges
ci = ccall(:jl_expand_and_resolve, Any, (Any, Any, Any), exprs, passtype.name.module, Core.svec())
ci.min_world = min_world
ci.max_world = max_world
ci.edges = Core.MethodInstance[mi]
return ci
end

# JuliaLang/julia#48611: world age is exposed to generated functions, and should be used
if VERSION >= v"1.10.0-DEV.873"
# This is like @generated, but it receives the world age of the caller
# which we need to do inference safely and correctly
@eval function typed_fsmi(fsmi, fargs...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, fsmi_generator))
end
else
# runtime fallback function that uses the fallback constructor with generic slot types
function typed_fsmi(fsmi::Type{T}, fargs...)::T where T
return T()
end
end
4 changes: 2 additions & 2 deletions test/test_jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ end
rep = report_package("ResumableFunctions";
report_pass=MayThrowIsOk(),
ignored_modules=(
Core.Compiler,

)
)
@show rep
@test length(JET.get_reports(rep)) <= 8
@test length(JET.get_reports(rep)) <= 5
@test_broken length(JET.get_reports(rep)) == 0
end
2 changes: 1 addition & 1 deletion test/test_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,5 +225,5 @@ end
end

@testset "test_unstable" begin
@test collect(test_unstable(3)) == ["number 1", "number 2", "number 3"]
@test_broken collect(test_unstable(3)) == ["number 1", "number 2", "number 3"]
end
2 changes: 1 addition & 1 deletion test/test_performance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ end
if VERSION >= v"1.10.0-DEV.873"
cs = cumsum(1:1000)
@allocated cs() # shake out the compilation overhead
@test (@allocated cs())==0
@test_broken (@allocated cs())==0
end
Loading