Skip to content

Commit

Permalink
fixup, rm many comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 18, 2022
1 parent b45bc0b commit 4f144fd
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 74 deletions.
10 changes: 7 additions & 3 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,17 @@ end
#####

# Reverse rule for `map(f, x)` over a single array.
#
# Here `map` agrees with `broadcast`, so the work is delegated to the `broadcast` rule via
# `rrule_via_ad`; that meta-rule has 4 different paths and should be fast. The primal
# result and the pullback obtained from it are returned unchanged.
#
# (An alternative that goes through `Broadcast.broadcasted` followed by
# `Broadcast.materialize` was noted as the form Yota prefers; testing prefers this one.)
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray) where {F}
    # Differentiate exactly once: a second `rrule_via_ad` call would evaluate `f` on every
    # element again, doubling the cost and repeating any side effects of `f`.
    y, back = rrule_via_ad(cfg, broadcast, f, x)
    return y, back
end

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F}
if all(==(size(x)), map(size, ys))
# Here too map agrees with broadcast, maybe the test could be more elegant?
y, back = rrule_via_ad(cfg, broadcast, f, x, ys...)
return y, back
end
@debug "rrule(map, f, arrays...)" f
z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...), x, ys...)
function map_pullback_2(dz)
Expand Down
20 changes: 0 additions & 20 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,26 +121,6 @@ end
# Path 4: The most generic path, which saves all the pullbacks. Can be 1000x slower.
# While broadcast makes no guarantee about order of calls, it's cheap to reverse the iteration.

#=
julia> Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), [1,2,3.0])
┌ Debug: split broadcasting generic
│ f = #69 (generic function with 1 method)
│ N = 1
└ @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:126
(14.0, (ZeroTangent(), [2.0, 4.0, 6.0]))
julia> ENV["JULIA_DEBUG"] = nothing
julia> @btime Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), $(rand(1000)));
min 1.321 ms, mean 1.434 ms (23010 allocations, 594.66 KiB) # with unzip_map, as before
min 1.279 ms, mean 1.393 ms (23029 allocations, 595.73 KiB) # with unzip_map_reversed
julia> @btime Yota.grad(xs -> sum(abs2, abs.(xs)), $(randn(1000))); # Debug: split broadcasting derivative
min 2.144 μs, mean 6.620 μs (6 allocations, 23.88 KiB)
=#

function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
@debug("split broadcasting generic", f, N)
ys3, backs = unzip_broadcast(args...) do a...
Expand Down
49 changes: 0 additions & 49 deletions src/rulesets/Base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,57 +17,10 @@ function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) whe
ys, generator_pullback
end

# Needed for Yota, but shouldn't these be automatic?
# Constructor rule for `Base.Generator`: the primal rebuilds the generator from `f` and
# `iter`, and the pullback hands the `f` and `iter` fields of the incoming tangent back
# to the corresponding arguments (the type itself gets `NoTangent()`).
function ChainRulesCore.rrule(::Type{<:Base.Generator}, f, iter)
    function generator_constructor_pullback(dy)
        return (NoTangent(), dy.f, dy.iter)
    end
    primal = Base.Generator(f, iter)
    return primal, generator_constructor_pullback
end
# Constructor rule for `Iterators.ProductIterator`: the primal wraps `iters`, and the
# pullback returns the `iterators` field of the incoming tangent as the gradient for
# `iters` (the type itself gets `NoTangent()`).
function ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters)
    function product_constructor_pullback(dy)
        return (NoTangent(), dy.iterators)
    end
    primal = Iterators.ProductIterator(iters)
    return primal, product_constructor_pullback
end

#=
Yota.grad(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
Diffractor.gradient(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: all field arrays must have same shape
Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: type Array has no field iterators
Yota.grad(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3)
Diffractor.gradient(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) # fails internally
Yota.grad(xs -> sum(abs, [sin(x/y) for (x,y) in zip(xs, 1:2)]), [1,2,3]pi/3)
Diffractor.gradient(xs -> sum(abs, [sin(x/y) for (x,y) in zip(xs, 1:2)]), [1,2,3]pi/3)
Yota.grad(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3)
Diffractor.gradient(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3)
@btime Yota.grad($(rand(1000))) do xs
sum(abs2, [sqrt(x) for x in xs])
end
# Yota min 759.000 μs, mean 800.754 μs (22041 allocations, 549.62 KiB)
# Diffractor min 559.000 μs, mean 622.464 μs (18051 allocations, 612.34 KiB)
# Zygote min 3.198 μs, mean 6.849 μs (20 allocations, 40.11 KiB)
@btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys
zs = map(xs, ys) do x, y
atan(x/y)
end
sum(abs2, zs)
end
# Yota + CR: min 1.598 ms, mean 1.691 ms (38030 allocations, 978.75 KiB)
# Diffractor + CR: min 767.250 μs, mean 847.640 μs (26045 allocations, 838.66 KiB)
# Zygote: min 13.417 μs, mean 22.896 μs (26 allocations, 79.59 KiB) -- 100x faster
=#


#####
##### `zip`
#####


function rrule(::typeof(zip), xs::AbstractArray...)
function zip_pullback(dy)
@debug "zip array pullback" summary(dy)
Expand All @@ -94,8 +47,6 @@ function _unmap_pad(x::AbstractArray, dx::AbstractArray)
@debug "_unmap_pad is extending gradient" length(x) == length(dx)
i1 = firstindex(x)
∇getindex(x, vec(dx), i1:i1+length(dx)-1)
# dx2 = vcat(vec(dx), similar(x, ZeroTangent, length(x) - length(dx)))
# ProjectTo(x)(reshape(dx2, axes(x)))
end
end

Expand Down
2 changes: 2 additions & 0 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@

@testset "map(f, ::Array, ::Array)" begin
test_rrule(map, +, [1.0, 2.0], [3.0, 4.0], check_inferred=false) # NoTangent does not match Union{NoTangent, ZeroTangent}
test_rrule(map, atan, [1 2; 3.0 4.0], [4 5; 6 7.0], check_inferred=false) # same shape => just broadcast

test_rrule(map, /, [1.0, 2.0], [3.0, 4.0, 5.0], check_inferred=false)
test_rrule(map, atan, [1, 2, 3.0], [4 5; 6 7.0], check_inferred=false)

Expand Down
4 changes: 2 additions & 2 deletions test/rulesets/Base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

y2, bk2 = rrule(CFG, collect, Iterators.map(Counter(), [11, 12, 13.0]))
@test y2 == map(Counter(), 11:13)
@test bk2(ones(3))[2].iter == [93, 83, 73]
@test bk2(ones(3))[2].iter == [33, 23, 13]
end
end

Expand All @@ -23,4 +23,4 @@ end
test_rrule(collectzip, rand(3), rand(5))
test_rrule(collectzip, rand(3,2), rand(5))
end
end
end

0 comments on commit 4f144fd

Please sign in to comment.