Skip to content

Commit

Permalink
Revert "run inference in generated function (#76)" (#81)
Browse files Browse the repository at this point in the history
* Revert "run inference in generated function (#76)"

This reverts commit 19e9ac5.

* Update CHANGELOG.md

* Bump version to 0.6.7

* Preserve tests as broken
  • Loading branch information
gerlero committed Jan 2, 2024
1 parent c4fc271 commit 1b0686f
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 160 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# News

## v0.6.7 - 2024-01-02

- Fix stack overflow errors by reverting the changes introduced in v0.6.6.

## v0.6.6 - 2023-10-08

- Significantly improved performance on Julia 1.10 and newer by more precise inference of slot types.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ license = "MIT"
desc = "C# sharp style generators a.k.a. semi-coroutines for Julia."
authors = ["Ben Lauwens <ben.lauwens@gmail.com>"]
repo = "https://github.com/BenLauwens/ResumableFunctions.jl.git"
version = "0.6.6"
version = "0.6.7"

[deps]
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand Down
55 changes: 20 additions & 35 deletions src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,15 @@ macro resumable(expr::Expr)
func_def[:body] = postwalk(transform_arg_yieldfrom, func_def[:body])
func_def[:body] = postwalk(transform_yieldfrom, func_def[:body])
func_def[:body] = postwalk(x->transform_for(x, ui8), func_def[:body])
inferfn, slots = get_slots(copy(func_def), arg_dict, __module__)
slots = get_slots(copy(func_def), arg_dict, __module__)
type_name = gensym(Symbol(func_def[:name], :_FSMI))
constr_def = copy(func_def)
slot_T = [gensym(s) for s in keys(slots)]
slot_T_sub = [:($k <: $v) for (k, v) in zip(slot_T, values(slots))]
struct_name = :($type_name{$(func_def[:whereparams]...), $(slot_T_sub...)} <: ResumableFunctions.FiniteStateMachineIterator{$rtype})
constr_def[:whereparams] = (func_def[:whereparams]..., slot_T_sub...)
# if there are no where or slot type parameters, we need to use the bare type
if isempty(params) && isempty(slot_T)
if isempty(params)
struct_name = :($type_name <: ResumableFunctions.FiniteStateMachineIterator{$rtype})
constr_def[:name] = :($type_name)
else
constr_def[:name] = :($type_name{$(params...), $(slot_T...)})
struct_name = :($type_name{$(func_def[:whereparams]...)} <: ResumableFunctions.FiniteStateMachineIterator{$rtype})
constr_def[:name] = :($type_name{$(params...)})
end
constr_def[:args] = tuple()
constr_def[:kwargs] = tuple()
Expand All @@ -60,44 +57,32 @@ macro resumable(expr::Expr)
fsmi._state = 0x00
fsmi
end
# the bare/fallback version of the constructor supplies default slot type parameters
# we only need to define this if there there are actually slot defaults to be filled
if !isempty(slot_T)
bareconstr_def = copy(constr_def)
if isempty(params)
bareconstr_def[:name] = :($type_name)
else
bareconstr_def[:name] = :($type_name{$(params...)})
end
bareconstr_def[:whereparams] = func_def[:whereparams]
bareconstr_def[:body] = :($(bareconstr_def[:name]){$(values(slots)...)}())
bareconst_expr = combinedef(bareconstr_def) |> flatten
else
bareconst_expr = nothing
end
constr_expr = combinedef(constr_def) |> flatten
type_expr = :(
mutable struct $struct_name
_state :: UInt8
$((:($slotname :: $slottype) for (slotname, slottype) in zip(keys(slots), slot_T))...)
$((:($slotname :: $slottype) for (slotname, slottype) in slots)...)
$(constr_expr)
$(bareconst_expr)
end
)
@debug type_expr|>MacroTools.striplines
call_def = copy(func_def)
call_def[:rtype] = nothing
if isempty(params)
fsmi_name = type_name
call_def[:rtype] = nothing
call_def[:body] = quote
fsmi = $type_name()
$((arg !== Symbol("_") ? :(fsmi.$arg = $arg) : nothing for arg in args)...)
$((:(fsmi.$arg = $arg) for arg in kwargs)...)
fsmi
end
else
fsmi_name = :($type_name{$(params...)})
end
fwd_args, fwd_kwargs = forward_args(call_def)
call_def[:body] = quote
fsmi = ResumableFunctions.typed_fsmi($fsmi_name, $inferfn, $(fwd_args...), $(fwd_kwargs...))
$((arg !== Symbol("_") ? :(fsmi.$arg = $arg) : nothing for arg in args)...)
$((:(fsmi.$arg = $arg) for arg in kwargs)...)
fsmi
call_def[:rtype] = nothing
call_def[:body] = quote
fsmi = $type_name{$(params...)}()
$((arg !== Symbol("_") ? :(fsmi.$arg = $arg) : nothing for arg in args)...)
$((:(fsmi.$arg = $arg) for arg in kwargs)...)
fsmi
end
end
call_expr = combinedef(call_def) |> flatten
@debug call_expr|>MacroTools.striplines
Expand Down
10 changes: 4 additions & 6 deletions src/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ end
"""
Function that replaces a variable `x` in an expression by `_fsmi.x` where `x` is a known slot.
"""
function transform_slots(expr, symbols)
function transform_slots(expr, symbols::Base.KeySet{Symbol, Dict{Symbol,Any}})
expr isa Expr || return expr
expr.head === :let && return transform_slots_let(expr, symbols)
for i in 1:length(expr.args)
Expand All @@ -114,7 +114,7 @@ end
"""
Function that handles `let` block
"""
function transform_slots_let(expr::Expr, symbols)
function transform_slots_let(expr::Expr, symbols::Base.KeySet{Symbol, Dict{Symbol,Any}})
@capture(expr, let vars_; body_ end)
locals = Set{Symbol}()
(isa(vars, Expr) && vars.head==:(=)) || error("@resumable currently supports only single variable declarations in let blocks, i.e. only let blocks exactly of the form `let i=j; ...; end`. If you need multiple variables, please submit an issue on the issue tracker and consider contributing a patch.")
Expand Down Expand Up @@ -247,16 +247,14 @@ end
"""
Function that replaces a `@yield ret` or `@yield` statement with
```julia
Base.inferencebarrier(ret)
return ret
```
This version is used for inference only.
It makes sure that `val = @yield ret` is inferred as `Any` rather than `typeof(ret)`.
"""
function transform_yield(expr)
_is_yield(expr) || return expr
ret = length(expr.args) > 2 ? expr.args[3:end] : [nothing]
quote
Base.inferencebarrier($(ret...))
$(ret...)
end
end

Expand Down
137 changes: 23 additions & 114 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,71 +28,42 @@ function get_args(func_def::Dict)
arg_list, kwarg_list, arg_dict
end

"""
Takes a function definition and returns the expressions needed to forward the arguments to an inner function.
For example `function foo(a, ::Int, c...; x, y=1, z...)` will
1. moodify the function to `gensym()` nameless arguments
2. return `(:a, gensym(), :(c...)), (:x, :y, :(z...)))`
"""
function forward_args(func_def)
args = []
map!(func_def[:args], func_def[:args]) do arg
name, type, splat, default = splitarg(arg)
name = something(name, gensym())
if splat
push!(args, :($name...))
else
push!(args, name)
end
combinearg(name, type, splat, default)
end
kwargs = []
for arg in func_def[:kwargs]
name, type, splat, default = splitarg(arg)
if splat
push!(kwargs, :($name...))
else
push!(kwargs, name)
end
end
args, kwargs
end

const unused = (Symbol("#temp#"), Symbol("_"), Symbol(""), Symbol("#unused#"), Symbol("#self#"))

"""
Function returning the slots of a function definition
"""
function get_slots(func_def::Dict, args::Dict{Symbol, Any}, mod::Module)
function get_slots(func_def::Dict, args::Dict{Symbol, Any}, mod::Module) :: Dict{Symbol, Any}
slots = Dict{Symbol, Any}()
func_def[:name] = gensym()
func_def[:args] = (func_def[:args]..., func_def[:kwargs]...)
func_def[:kwargs] = []
# replace yield with inference barrier
body = func_def[:body]
func_def[:body] = postwalk(transform_yield, func_def[:body])
# collect items to skip
nosaves = Set{Symbol}()
func_def[:body] = postwalk(x->transform_nosave(x, nosaves), func_def[:body])
# eval function
func_expr = combinedef(func_def) |> flatten
@eval(mod, @noinline $func_expr)
# get typed code
codeinfos = @eval(mod, code_typed($(func_def[:name]), Tuple; optimize=false))
# extract slot names and types
for codeinfo in codeinfos
for (name, type) in collect(zip(codeinfo.first.slotnames, codeinfo.first.slottypes))
name nosaves && name unused && (slots[name] = Union{type, get(slots, name, Union{})})
name nosaves && (slots[name] = type)
end
end
# remove `catch exc` statements
for (argname, argtype) in args
slots[argname] = argtype
end
postwalk(x->remove_catch_exc(x, slots), func_def[:body])
# set error branches to Any
postwalk(x->make_arg_any(x, slots), body)
for (key, val) in slots
if val === Union{}
slots[key] = Any
end
end
return func_def[:name], slots
delete!(slots, Symbol("#temp#"))
delete!(slots, Symbol("_"))
delete!(slots, Symbol(""))
delete!(slots, Symbol("#unused#"))
delete!(slots, Symbol("#self#"))
slots
end

"""
Expand All @@ -103,6 +74,16 @@ function remove_catch_exc(expr, slots::Dict{Symbol, Any})
expr
end

"""
Function changing the type of a slot `arg` of a `arg = @yield ret` or `arg = @yield` statement to `Any`.
"""
function make_arg_any(expr, slots::Dict{Symbol, Any})
@capture(expr, arg_ = ex_) || return expr
_is_yield(ex) || return expr
slots[arg] = Any
expr
end

struct IteratorReturn{T}
value :: T
IteratorReturn(value) = new{typeof(value)}(value)
Expand All @@ -126,75 +107,3 @@ end
isnothing(ret) && return IteratorReturn(nothing)
ret
end

# this is similar to code_typed but it considers the world age
function code_typed_by_type(@nospecialize(tt::Type);
optimize::Bool=true,
world::UInt=Base.get_world_counter(),
interp::Core.Compiler.AbstractInterpreter=Core.Compiler.NativeInterpreter(world))
tt = Base.to_tuple_type(tt)
# look up the method
match, valid_worlds = Core.Compiler.findsup(tt, Core.Compiler.InternalMethodTable(world))
# run inference, normally not allowed in generated functions
frame = Core.Compiler.typeinf_frame(interp, match.method, match.spec_types, match.sparams, optimize)
frame === nothing && error("inference failed")
valid_worlds = Core.Compiler.intersect(valid_worlds, frame.valid_worlds)
return frame.linfo, frame.src, valid_worlds
end

function fsmi_generator(world::UInt, source::LineNumberNode, passtype, fsmitype::Type{Type{T}}, fargtypes) where T
@nospecialize
# get typed code of the inference function evaluated in get_slots
# but this time with concrete argument types
tt = Base.to_tuple_type(fargtypes)
mi, ci, valid_worlds = try
code_typed_by_type(tt; world, optimize=false)
catch err # inference failed, return generic type
Core.println(err)
slots = fieldtypes(T)[2:end]
stub = Core.GeneratedFunctionStub(identity, Core.svec(:pass, :fsmi, :fargs), Core.svec())
if isempty(slots)
return stub(world, source, :(return $T()))
else
return stub(world, source, :(return $T{$(slots...)}()))
end
end
min_world = valid_worlds.min_world
max_world = valid_worlds.max_world
# extract slot types
cislots = Dict{Symbol, Any}()
for (name, type) in collect(zip(ci.slotnames, ci.slottypes))
# take care to widen types that are unstable or Const
type = Core.Compiler.widenconst(type)
cislots[name] = Union{type, get(cislots, name, Union{})}
end
slots = map(slot->get(cislots, slot, Any), fieldnames(T)[2:end])
# generate code to instantiate the concrete type
stub = Core.GeneratedFunctionStub(identity, Core.svec(:pass, :fsmi, :fargs), Core.svec())
if isempty(slots)
exprs = stub(world, source, :(return $T()))
else
exprs = stub(world, source, :(return $T{$(slots...)}()))
end
# lower codeinfo to pass world age and invalidation edges
ci = ccall(:jl_expand_and_resolve, Any, (Any, Any, Any), exprs, passtype.name.module, Core.svec())
ci.min_world = min_world
ci.max_world = max_world
ci.edges = Core.MethodInstance[mi]
return ci
end

# JuliaLang/julia#48611: world age is exposed to generated functions, and should be used
if VERSION >= v"1.10.0-DEV.873"
# This is like @generated, but it receives the world age of the caller
# which we need to do inference safely and correctly
@eval function typed_fsmi(fsmi, fargs...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, fsmi_generator))
end
else
# runtime fallback function that uses the fallback constructor with generic slot types
function typed_fsmi(fsmi::Type{T}, fargs...)::T where T
return T()
end
end
4 changes: 2 additions & 2 deletions test/test_jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ end
rep = report_package("ResumableFunctions";
report_pass=MayThrowIsOk(),
ignored_modules=(
Core.Compiler,

)
)
@show rep
@test length(JET.get_reports(rep)) <= 8
@test length(JET.get_reports(rep)) <= 5
@test_broken length(JET.get_reports(rep)) == 0
end
2 changes: 1 addition & 1 deletion test/test_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,5 +225,5 @@ end
end

@testset "test_unstable" begin
@test collect(test_unstable(3)) == ["number 1", "number 2", "number 3"]
@test_broken collect(test_unstable(3)) == ["number 1", "number 2", "number 3"]
end
2 changes: 1 addition & 1 deletion test/test_performance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ end
if VERSION >= v"1.10.0-DEV.873"
cs = cumsum(1:1000)
@allocated cs() # shake out the compilation overhead
@test (@allocated cs())==0
@test_broken (@allocated cs())==0
end

2 comments on commit 1b0686f

@gerlero
Copy link
Member Author

@gerlero gerlero commented on 1b0686f Jan 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/98076

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.7 -m "<description of version>" 1b0686ff10c39ba3e69c5eaaa67320cff1eae79a
git push origin v0.6.7

Please sign in to comment.