Slowdown caused by broadcasting macro? #502

Open
mcabbott opened this issue Feb 5, 2020 · 3 comments

mcabbott commented Feb 5, 2020

I expected g1 and g2 here to be identical:

using Zygote, BenchmarkTools
function g1(x)
    y = reshape(x,1,1,:)
    dropdims(sum(x .* x' .* y; dims=(2,3)), dims=(2,3))
end
function g2(x)
    y = reshape(x,1,1,:)
    dropdims(sum(@.(x * x' * y); dims=(2,3)), dims=(2,3))
end
@btime gradient(sum∘g1, $(rand(20))); #  33.147 μs (74 allocations: 266.53 KiB)
@btime gradient(sum∘g2, $(rand(20))); # 962.748 μs (56172 allocations: 2.39 MiB)

Kristoffer Carlsson points me to JuliaLang/julia#29120 as a possible source of this. It looks like I can also trigger it like this:

function g4(x)
    y = reshape(x,1,1,:)
    dropdims(sum(.*(x, x', y); dims=(2,3)), dims=(2,3))
end
@btime gradient(sum∘g4, $(rand(20))); # 960.827 μs (56172 allocations: 2.39 MiB)

mbauman commented May 20, 2020

Yes, this is a dup of JuliaLang/julia#29120 and/or a potential Zygote improvement. Explicitly adding parens here fixes it:

sum(@.((x * x') * y); dims=(2,3))
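
The reason the parentheses matter can be seen from how the expressions parse, as described in JuliaLang/julia#29120. The following is a minimal sketch using only Base. The variable names (dotted, chained, expanded, grouped) are illustrative and do not come from the issue.

```julia
# Explicitly dotted operators parse as nested binary calls: (x .* y) .* z
dotted = :(x .* y .* z)
@show dotted.args

# Undotted `*` is parsed as one n-ary call: *(x, y, z)
chained = :(x * y * z)
@show chained.args

# `@.` dots that single n-ary call, producing one three-argument broadcast
expanded = macroexpand(Main, :(@. x * y * z))
@show expanded

# With explicit parentheses, `@.` produces nested two-argument broadcasts
grouped = macroexpand(Main, :(@. (x * y) * z))
@show grouped
```

The two-argument form can then hit Zygote's specialised rule for broadcasted `*`, while the three-argument form presumably falls back to a slower generic broadcast path.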

I think a potential Zygote-level fix would be to add a 3-arg gradient override for broadcasted *.

@adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y,
  z̄ -> (nothing, unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x)))


cossio commented May 20, 2020

How come adjoint is not being broadcasted in @.(x * x' * y)?


cossio commented May 22, 2020

I think a potential Zygote-level fix would be to add a 3-arg gradient override for broadcasted *.

But would this cover products with more terms, such as @. A * B * C * D * E ...?
Perhaps we need an n-arg override:

@adjoint broadcasted(::typeof(*), x::Numeric...) = ...
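
To illustrate what such an n-arg rule would need to compute, here is a rough, self-contained sketch in plain Julia that does not use Zygote. The names broadcast_prod_with_pullback and unbroadcast_sketch are hypothetical helpers written for this example. They are not Zygote's API, and the sketch favours clarity over speed (it recomputes a partial product for every argument).

```julia
# Hypothetical stand-in for Zygote's `unbroadcast`: sum `dx` over the
# dimensions that broadcasting expanded, so the result has the shape of `x`.
function unbroadcast_sketch(x::AbstractArray, dx::AbstractArray)
    size(x) == size(dx) && return dx
    dims = Tuple(d for d in 1:ndims(dx) if size(x, d) == 1 && size(dx, d) > 1)
    return reshape(sum(dx; dims=dims), size(x))
end

# Elementwise product of two or more arrays, plus a pullback that returns
# one gradient per argument: z̄ times the conjugate of the other factors.
function broadcast_prod_with_pullback(x1::AbstractArray{<:Number},
                                      x2::AbstractArray{<:Number},
                                      rest::AbstractArray{<:Number}...)
    xs = (x1, x2, rest...)
    z = broadcast(*, xs...)
    function pullback(z̄)
        ntuple(length(xs)) do i
            # All factors except the i-th
            others = (xs[1:i-1]..., xs[i+1:end]...)
            partial = foldl((a, b) -> a .* b, others)
            unbroadcast_sketch(xs[i], z̄ .* conj.(partial))
        end
    end
    return z, pullback
end

# Usage with the shapes from the original example
x = [1.0, 2.0, 3.0]
y = reshape(x, 1, 1, :)
z, back = broadcast_prod_with_pullback(x, x', y)
gx, gxt, gy = back(ones(size(z)))
```

Each returned gradient has the shape of its argument, so gx is a vector, gxt is a 1×3 matrix, and gy is a 1×1×3 array.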
