broadcast noop not inferred #56

Closed
Keno opened this issue Jul 17, 2018 · 4 comments · Fixed by #57

Keno (Contributor) commented Jul 17, 2018

julia> Cassette.@context NoOp

julia> f(a, b) = a .+ b

julia> (@code_typed f(rand(Float32, 10), rand(Float32, 10)))[2]
Array{Float32,1}

julia> (@code_typed Cassette.overdub(NoOp(), f, rand(Float32, 10), rand(Float32, 10)))[2]
Any
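
A minimal plain-Julia sketch (no Cassette loaded) of checking the non-overdubbed baseline programmatically with Base.return_types, rather than by indexing into @code_typed. It assumes Julia 1.x, where Vector{Float32} and Array{Float32,1} are the same type, and it is only a reflection check added for illustration, not part of the original report:

```julia
# Plain-Julia baseline (no Cassette): the un-overdubbed broadcast from the report.
f(a, b) = a .+ b

# Base.return_types gives one inferred return type per matching method;
# here there is exactly one method of `f`.
rts = Base.return_types(f, Tuple{Vector{Float32}, Vector{Float32}})
rt = first(rts)

# A concrete result (an Array{Float32,1}) means inference succeeded; the
# overdubbed call in the report instead came back as `Any`.
println(rt, " concrete? ", isconcretetype(rt))
```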

jrevels (Collaborator) commented Jul 17, 2018

Just leaving this here as a starting point for myself later:

julia> using Cassette, Base.Broadcast

julia> x = rand(1);

julia> Cassette.@context Ctx

julia> @code_typed Cassette.overdub(Ctx(), t -> Base.Broadcast.combine_axes(t...), (x, x))
CodeInfo(
25 1 ─       getfield(%%args, 1)
   │         getfield(%%args, 2)
26 │         getfield(%%args, 1)
   │         getfield(%%args, 2)
   └──       goto 3 if not false
   2 ─       nothing
29 3 ─       getfield(%%args, 1)
   │   %8  = getfield(%%args, 2)::Tuple{Array{Float64,1},Array{Float64,1}}
   │   %9  = :(Base.Broadcast)::Module            │╻       #6
   └──       goto 4 if not false                  ││╻       overdub
   4 ─ %11 = π (%9, Module)                       │││╻       getproperty
   │   %12 = π (:combine_axes, Symbol)            ││││
   │   %13 = :(Base.getfield)::typeof(getfield)   ││││
   └──       goto 5 if not false                  ││││╻       overdub
   5 ┄       Base.getfield(Cassette, :execute)    │││││╻╷╷     recurse
   │         %13(%11, %12)                        ││││││╻       macro expansion
   └──       goto 6                               ││││╻       overdub
   6 ─       goto 7                               ││││
   7 ─       goto 8                               ││╻       overdub
   8 ─ %20 = :(Core._apply)::typeof(Core._apply)  │││╻       execute
   │   %21 = %20(getfield(Main, Symbol("##4#5")){Cassette.Context{nametype(Ctx),Nothing,Cassette.NoPass,Nothing,Nothing},typeof(Base.Broadcast.combine_axes)}(Cassette.Context{nametype(Ctx),Nothing,Cassette.NoPass,Nothing,Nothing}(nametype(Ctx)(), nothing, Cassette.NoPass(), nothing, nothing), Base.Broadcast.combine_axes), %8)::Any
   └──       goto 9                               ││
31 9 ─       getfield(%%args, 1)
   │         getfield(%%args, 2)
32 └──       return %21
) => Any

jrevels (Collaborator) commented Jul 17, 2018

Okay, so the two problems here are _apply and getproperty.

Problematic _apply example:

julia> Cassette.recurse_typed(Ctx(), t -> Base.Broadcast.combine_axes(t...), (rand(1), rand(1)))
1-element Array{Any,1}:
 CodeInfo(
1 1 ─      :(#self# = (Core.getfield)(Core.Compiler.Argument(3), 1))
  │        :(t = (Core.getfield)(Core.Compiler.Argument(3), 2))
  │   %3 = :(Base.Broadcast)::Core.Compiler.Const(Base.Broadcast, false)
  │   %4 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %3, :combine_axes)::Core.Compiler.Const(Base.Broadcast.combine_axes, false)
  │   %5 = Cassette.overdub(%%##recurse_context#376, Core._apply, %4, %%t)::Any
  └──      return %5
) => Any

Problematic getproperty example:

julia> Cassette.recurse_typed(Ctx(), Broadcast.copy, Broadcast.instantiate(Broadcast.broadcasted(+, rand(1), rand(1))))
1-element Array{Any,1}:
 CodeInfo(
    1 ─       :(#self# = (Core.getfield)(Core.Compiler.Argument(3), 1))
    │         :(bc = (Core.getfield)(Core.Compiler.Argument(3), 2))
    │   %14 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %%bc, :f)::Union{typeof(+), Tuple}
    │   %15 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %%bc, :args)::Union{typeof(+), Tuple}

) => Any

Making both _apply and getproperty primitives causes the whole example to infer fully.

I'm okay with making getproperty a default primitive, at least when tagging is disabled (context authors can just unmark it as a primitive if they want to trace into it), though that feels like a workaround for a problem we should really try to fix in the compiler eventually.
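
One plausible reading of the Union{typeof(+), Tuple} results above is that the field name stops being a compile-time constant somewhere along the overdub chain. A plain-Julia sketch of that effect, assuming nothing about Cassette's internals (FArgs, get_f_const and get_field_dyn are hypothetical names, with FArgs standing in for Broadcasted): with a literal field name, inference returns the exact field type; with only a Symbol-typed name, it can only bound the result over all of the struct's fields.

```julia
# Hypothetical stand-in for Broadcasted: a function field plus an argument tuple.
struct FArgs{F,A}
    f::F
    args::A
end

fa = FArgs(+, ([1.0], [2.0]))

# The field name is a literal, so inference sees the constant :f
# and returns the exact field type.
get_f_const(x) = x.f

# The field name arrives only as a value of type Symbol, so inference
# can only return a bound over the types of all fields.
get_field_dyn(x, name::Symbol) = getproperty(x, name)

rt_const = first(Base.return_types(get_f_const, Tuple{typeof(fa)}))
rt_dyn   = first(Base.return_types(get_field_dyn, Tuple{typeof(fa), Symbol}))

println("literal name:    ", rt_const)
println("Symbol-typed:    ", rt_dyn, " (concrete? ", isconcretetype(rt_dyn), ")")
```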

The _apply case is more complicated; I'll try out some alternatives...

jrevels (Collaborator) commented Jul 17, 2018

Argh. Fixing the _apply case above is easy on its own: it seems that switching the default _apply overdub primitive to use nested _apply calls (instead of creating a closure) fixes it.
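
In the Julia versions of that era, Core._apply(f, t1, t2, ...) is what splatting lowered to: it calls f with the concatenation of the argument tuples (the dump above shows combine_axes(t...) becoming Core._apply(%4, %%t)). A plain-Julia sketch of the two shapes being compared, with hypothetical names (evalf stands in for Base.Broadcast._broadcast_getindex_evalf; apply_via_closure and apply_direct are not Cassette's actual primitive definitions, only rough analogues): one forwards the splat through an intermediate closure, the other splats every tuple in a single call. It uses ordinary splatting because newer Julia lowers to Core._apply_iterate instead.

```julia
# Hypothetical stand-in for Base.Broadcast._broadcast_getindex_evalf.
evalf(f, args...) = f(args...)

# Shape 1: forward the splat through an intermediate closure.
apply_via_closure(g, t1::Tuple, t2::Tuple) = (() -> g(t1..., t2...))()

# Shape 2: splat all argument tuples directly in one call
# (what `Core._apply(g, t1, t2)` did in Julia 0.7/1.0).
apply_direct(g, t1::Tuple, t2::Tuple) = g(t1..., t2...)

args = (1.0, 2.0)
r1 = apply_via_closure(evalf, (+,), args)  # evalf(+, 1.0, 2.0) == 3.0
r2 = apply_direct(evalf, (+,), args)

# Both shapes infer concretely in plain Julia.
T = Tuple{typeof(evalf), Tuple{typeof(+)}, Tuple{Float64, Float64}}
rt1 = first(Base.return_types(apply_via_closure, T))
rt2 = first(Base.return_types(apply_direct, T))
println((r1, r2, rt1, rt2))
```

Outside of Cassette both shapes infer Float64, so whatever difference the thread observes only appears once the call is overdubbed.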

However, having _apply as a primitive at all seems to cause the getproperty example above not to infer. I've tracked it down to here:

julia> using Cassette, Base.Broadcast

julia> Cassette.@context Ctx

julia> b = Broadcast.instantiate(Broadcast.broadcasted(+, rand(1), rand(1)))
Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}}(+, ([0.0228267], [0.173769]))

julia> Cassette.recurse_typed(Ctx(), Broadcast._broadcast_getindex, b, 1)
1-element Array{Any,1}:
 CodeInfo(
    1 ─      :(#self# = (Core.getfield)(Core.Compiler.Argument(3), 1))
    │        :(bc = (Core.getfield)(Core.Compiler.Argument(3), 2))
    │        :(I = (Core.getfield)(Core.Compiler.Argument(3), 3))
    │        nothing
551 │   %5 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %%bc, :args)::Tuple{Array{Float64,1},Array{Float64,1}}
    │        :(args = (Cassette.overdub)(Core.Compiler.Argument(2), Base.Broadcast._getindex, Core.SSAValue(5), Core.Compiler.Argument(6)))
552 │   %7 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %%bc, :f)::Core.Compiler.Const(+, false)
    │   %8 = Cassette.overdub(%%##recurse_context#376, Core.tuple, %7)::Core.Compiler.Const((+,), false)
    │   %9 = Cassette.overdub(%%##recurse_context#376, Core._apply, Base.Broadcast._broadcast_getindex_evalf, %8, %%args)::Any
    └──      return %9
) => Any

That's with _apply as a primitive. Here it is without _apply as a primitive:

julia> Cassette.recurse_typed(Ctx(), Broadcast._broadcast_getindex, b, 1)
1-element Array{Any,1}:
 CodeInfo(
    1 ─      :(#self# = (Core.getfield)(Core.Compiler.Argument(3), 1))
    │        :(bc = (Core.getfield)(Core.Compiler.Argument(3), 2))
    │        :(I = (Core.getfield)(Core.Compiler.Argument(3), 3))
    │        nothing
551 │   %5 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %%bc, :args)::Tuple{Array{Float64,1},Array{Float64,1}}
    │        :(args = (Cassette.overdub)(Core.Compiler.Argument(2), Base.Broadcast._getindex, Core.SSAValue(5), Core.Compiler.Argument(6)))
552 │   %7 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %%bc, :f)::Core.Compiler.Const(+, false)
    │   %8 = Cassette.overdub(%%##recurse_context#376, Core.tuple, %7)::Core.Compiler.Const((+,), false)
    │   %9 = Cassette.overdub(%%##recurse_context#376, Core._apply, Base.Broadcast._broadcast_getindex_evalf, %8, %%args)::Float64
    └──      return %9
) => Float64

If we can come up with an _apply primitive definition that works for both this and the previous example, then I believe that should solve this issue entirely.

jrevels (Collaborator) commented Jul 17, 2018

Note that if I set optimize=true on the previous example, it does unroll the _apply call all the way to a final invoke of recurse whose arguments are well-typed, but whose return type is not:

julia> Cassette.recurse_typed(Ctx(), Broadcast._broadcast_getindex, b, 1; optimize=true)
1-element Array{Any,1}:
 CodeInfo(
 # everything until here seems well-inferred
    431 ─ %746 = invoke Cassette.recurse(:($(QuoteNode(Cassette.Context{nametype(Ctx),Nothing,Cassette.NoPass,Nothing,Nothing}(nametype(Ctx)(), nothing, Cassette.NoPass(), nothing, nothing))))::Cassette.Context{nametype(Ctx),Nothing,Cassette.NoPass,Nothing,Nothing}, Base.Broadcast._broadcast_getindex_evalf::typeof(Base.Broadcast._broadcast_getindex_evalf), %742::typeof(+), %743::Float64, %744::Float64)::Any
    └────        goto 432
    432 ─        goto 433
    433 ─        goto 434
    434 ─        goto 435                 overdub
    435 ─        return %746
) => Any

However, this call independently seems to infer fine:

julia> Cassette.recurse_typed(Ctx(), Broadcast._broadcast_getindex_evalf, +, 1.0, 1.0)
1-element Array{Any,1}:
 CodeInfo(
    1 ─      :(#self# = (Core.getfield)(Core.Compiler.Argument(3), 1))
    │        :(f = (Core.getfield)(Core.Compiler.Argument(3), 2))
    │   %3 = Core.getfield(%%##recurse_arguments#377, 3)::Float64
    │   %4 = Core.getfield(%%##recurse_arguments#377, 4)::Float64
    │        :(args = (Core.tuple)(Core.SSAValue(3), Core.SSAValue(4)))
    │        nothing
579 │   %7 = Cassette.overdub(%%##recurse_context#376, Core._apply, %%f, %%args)::Float64
    └──      return %7

Hmm....
