
use LoopVectorization to vectorize activation functions and softmax #199

Merged: 18 commits merged into FluxML:master on Jul 7, 2020

Conversation

@AStupidBear (Contributor) commented May 10, 2020

using BenchmarkTools
using NNlib
using Zygote

Old NNlib

julia> for f in (tanh, σ)
           println(f)
           @btime $f.(x)
           @btime Zygote.gradient($x) do z
               sum($f.(z))
           end
       end
tanh
  88.858 μs (3 allocations: 16.16 KiB)
  114.096 μs (20 allocations: 32.61 KiB)
σ
  48.257 μs (3 allocations: 16.16 KiB)
  56.673 μs (5 allocations: 32.30 KiB)

julia> for d in 1:ndims(x)
           println(softmax, " dim=", d)
           @btime softmax($x, dims = $d)
           @btime Zygote.gradient($x) do z
               sum(softmax(z, dims = $d))
           end
       end
softmax dim=1
  58.959 μs (13 allocations: 32.86 KiB)
  133.949 μs (30 allocations: 98.19 KiB)
softmax dim=2
  51.954 μs (31 allocations: 34.23 KiB)
  110.524 μs (72 allocations: 101.53 KiB)

This PR:

julia> for f in (tanh, σ)
           println(f)
           @btime $f.(x)
           @btime Zygote.gradient($x) do z
               sum($f.(z))
           end
       end
tanh
  9.181 μs (1 allocation: 16.13 KiB)
  11.725 μs (20 allocations: 32.61 KiB)
σ
  2.344 μs (1 allocation: 16.13 KiB)
  4.755 μs (5 allocations: 32.30 KiB)

julia> for d in 1:ndims(x)
           println(softmax, " dim=", d)
           @btime softmax($x, dims = $d)
           @btime Zygote.gradient($x) do z
               sum(softmax(z, dims = $d))
           end
       end
softmax dim=1
  25.194 μs (13 allocations: 32.86 KiB)
  56.679 μs (30 allocations: 98.19 KiB)
softmax dim=2
  9.377 μs (31 allocations: 34.23 KiB)
  25.198 μs (72 allocations: 101.53 KiB)

Other activation functions can be sped up by overloading Base.broadcasted after the adjoint is defined in LoopVectorization (JuliaSIMD/LoopVectorization.jl#108).
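
As a rough sketch of that idea (illustrative only; `myrelu` and the exact method signature are placeholders, not the definitions used in this PR), a broadcast over a dense array can be rerouted through LoopVectorization's vmap by overloading Base.broadcasted:

using LoopVectorization: vmap

# Illustrative scalar activation; stands in for any NNlib activation function.
myrelu(x) = ifelse(x > zero(x), x, zero(x))

# Returning a materialized array from Base.broadcasted short-circuits the lazy
# broadcast machinery, so `myrelu.(x)` evaluates via vmap.
Base.broadcasted(::typeof(myrelu), x::Array{<:Union{Float32,Float64}}) = vmap(myrelu, x)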

@CarloLucibello (Member)

vreduce doesn't support a dims argument yet.

@AStupidBear (Contributor, Author)

@CarloLucibello The next release of LoopVectorization will have that.

@AStupidBear (Contributor, Author)

Tests pass now. Any ideas?

src/activation.jl: review comments (outdated, resolved)
@CarloLucibello (Member)

Is all of this Zygote-friendly?

src/softmax.jl: review comments (outdated, resolved)
@CarloLucibello (Member)

@AStupidBear Bump on this; it would be really nice to have this performance improvement.

@chriselrod commented Jul 2, 2020

FWIW, I added a much faster AVX512-Float32 tanh to SLEEFPirates.
I haven't released it yet, but if it will be used here, I can try to add AVX2 and Float64 versions as well.

These will still probably be slower than:

function tanh_fast(x)
    # tanh(x) == (exp(2x) - 1) / (exp(2x) + 1); only one exp call per element
    exp2x = exp(x + x)
    (exp2x - 1) / (exp2x + 1)
end

But they are a little more accurate.
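
For context, a quick way one might compare speed and accuracy against Base's tanh (illustrative only; the input size is arbitrary):

using BenchmarkTools

x = randn(Float32, 10_000)
@btime tanh.($x)                          # Base tanh
@btime tanh_fast.($x)                     # exp-based approximation above
maximum(abs, tanh.(x) .- tanh_fast.(x))   # rough max absolute error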

@AStupidBear (Contributor, Author)

> Is all of this Zygote-friendly?

Since all of those modifications are in kernel functions used in Zygote's forward and backward passes, they are Zygote-friendly.

@CarloLucibello (Member)

what's the status of this?

@AStupidBear (Contributor, Author)

It's ready to get merged.

julia> using NNlib, StaticArrays; x = SMatrix{2, 2}(rand(2, 2))
2×2 SArray{Tuple{2,2},Float64,2,4} with indices SOneTo(2)×SOneTo(2):
 0.604445  0.330955
 0.975996  0.909042

julia> softmax(x)
2×2 SArray{Tuple{2,2},Float64,2,4} with indices SOneTo(2)×SOneTo(2):
 0.408166  0.359373
 0.591834  0.640627

julia> σ.(x)
2×2 SArray{Tuple{2,2},Float64,2,4} with indices SOneTo(2)×SOneTo(2):
 0.646673  0.581992
 0.726313  0.712804

julia> logsoftmax(x)
2×2 SArray{Tuple{2,2},Float64,2,4} with indices SOneTo(2)×SOneTo(2):
 -0.89608  -1.02339 
 -0.52453  -0.445308

@CarloLucibello (Member)

Alright, let's merge this and tag a new release. If no problems come up, we can then extend this by applying vmap to all activations and providing custom adjoints, right @AStupidBear?

@CarloLucibello merged commit 1f0388d into FluxML:master on Jul 7, 2020
@AStupidBear (Contributor, Author)

@CarloLucibello Yes! But where should we put those definitions? Zygote?

@CarloLucibello (Member)

@AStupidBear (Contributor, Author)

Maybe it's better to dispatch the other activation functions to vmap and then define the adjoint for vmap? Any ideas, @chriselrod?

@CarloLucibello (Member)

I think we can simply copy the adjoint of map, i.e. do something similar to FluxML/Zygote.jl#728.
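
A minimal sketch of that idea, assuming Zygote's @adjoint macro and LoopVectorization's vmap (not the final implementation; the pullback deliberately falls back to a plain map for simplicity):

using Zygote, LoopVectorization

# Sketch: mirror Zygote's adjoint for `map` so calls to vmap stay differentiable.
Zygote.@adjoint function LoopVectorization.vmap(f, x::AbstractArray)
    y = LoopVectorization.vmap(f, x)
    vmap_pullback(ȳ) = (nothing, map((xi, ȳi) -> ȳi * Zygote.gradient(f, xi)[1], x, ȳ))
    return y, vmap_pullback
end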

@ChrisRackauckas (Member)

Yeah, we can just add vmap to the loop I have in that PR. That would then need a LoopVectorization dependency?

@CarloLucibello (Member)

Yes. It also needs some tests.
