Skip to content

Commit

Permalink
Merge pull request #860 from dsweber2/activations
Browse files Browse the repository at this point in the history
Activations
  • Loading branch information
MikeInnes committed Nov 19, 2019
2 parents 2fa3e56 + dea2953 commit 5839e16
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
16 changes: 10 additions & 6 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,23 @@ end
# it might be replaced in the future for better performance
# see issue https://github.com/FluxML/Flux.jl/issues/702
# Johnny Chen -- @johnnychen94
# only slightly changed to better handle interaction with Zygote @dsweber2
"""
activations(c::Chain, input)
Calculate the forward results of each layers in Chain `c` with `input` as model input.
"""
function activations(c::Chain, input)
rst = []
for l in c
x = get(rst, length(rst), input)
push!(rst, l(x))
end
return rst
extraChain(c.layers, input)
end

function extraChain(fs::Tuple, x)
res = first(fs)(x)
return (res, extraChain(Base.tail(fs), res)...)
end

extraChain(::Tuple{}, x) = ()



"""
Dense(in::Integer, out::Integer, σ = identity)
Expand Down
18 changes: 13 additions & 5 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import Flux: activations
@testset "basic" begin
@testset "helpers" begin
@testset "activations" begin
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
x = rand(10)
@test activations(Chain(), x) == []
@test activations(dummy_model, x)[1] == dummy_model[1](x)
@test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
dummy_model = Chain(x->x.^2, x->x .- 3, x -> tan.(x))
x = randn(10)
@test activations(dummy_model, x)[1] == x.^2
@test activations(dummy_model, x)[2] == (x.^2 .- 3)
@test activations(dummy_model, x)[3] == tan.(x.^2 .- 3)

@test activations(Chain(), x) == ()
@test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type
end
end
Expand All @@ -19,6 +21,12 @@ import Flux: activations
# numeric test should be put into testset of corresponding layer
end

@testset "Activations" begin
c = Chain(Dense(3,5,relu), Dense(5,1,relu))
X = Float32.([1.0; 1.0; 1.0])
@test_nowarn gradient(()->Flux.activations(c, X)[2][1], params(c))
end

@testset "Dense" begin
@test length(Dense(10, 5)(randn(10))) == 5
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
Expand Down

0 comments on commit 5839e16

Please sign in to comment.