From 540b7366ec0edd711953223ef44bf342d691127f Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Tue, 10 Sep 2019 00:54:49 -0700 Subject: [PATCH 01/15] make activations zygote friendly --- src/layers/basic.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 83eeee21af..b4b869c52e 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -45,17 +45,15 @@ end # it might be replaced in the future for better performance # see issue https://github.com/FluxML/Flux.jl/issues/702 # Johnny Chen -- @johnnychen94 +# only slightly changed to better handle interaction with Zygote @dsweber2 """ activations(c::Chain, input) Calculate the forward results of each layers in Chain `c` with `input` as model input. """ function activations(c::Chain, input) - rst = [] - for l in c - x = get(rst, length(rst), input) - push!(rst, l(x)) - end - return rst + buffed = accumulate!((x,y)->y(x), Zygote.Buffer([], length(c)), + [l for l in c], dims=1, init=input) + return copy(buffed) end From 38790dd4db5520e6e587783804d1144a3b75ac9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Sun, 8 Sep 2019 16:15:35 +0100 Subject: [PATCH 02/15] Restore purity --- .gitattributes | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitattributes b/.gitattributes index e02ed0b720..4992eb2c25 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,2 @@ paper/* linguist-documentation +CITATION.bib linguist-detectable=false From 82261b5bb7e6783d6a273c8e7803c4fbb28a3dd8 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Tue, 10 Sep 2019 00:54:49 -0700 Subject: [PATCH 03/15] make activations zygote friendly --- src/layers/basic.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 13d5647267..fd187d8c29 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -45,17 +45,15 @@ end # it might be replaced in the future for better performance # see issue https://github.com/FluxML/Flux.jl/issues/702 # Johnny Chen -- @johnnychen94 +# only slightly changed to better handle interaction with Zygote @dsweber2 """ activations(c::Chain, input) Calculate the forward results of each layers in Chain `c` with `input` as model input. """ function activations(c::Chain, input) - rst = [] - for l in c - x = get(rst, length(rst), input) - push!(rst, l(x)) - end - return rst + buffed = accumulate!((x,y)->y(x), Zygote.Buffer([], length(c)), + [l for l in c], dims=1, init=input) + return copy(buffed) end From 1bb25dc1f9c54666d73b516629e0c89033e1c0e2 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Tue, 10 Sep 2019 01:34:12 -0700 Subject: [PATCH 04/15] adding the extra commits broke the accumulate version --- src/layers/basic.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index fd187d8c29..e1e9ab4519 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -51,9 +51,12 @@ end Calculate the forward results of each layers in Chain `c` with `input` as model input. """ function activations(c::Chain, input) - buffed = accumulate!((x,y)->y(x), Zygote.Buffer([], length(c)), - [l for l in c], dims=1, init=input) - return copy(buffed) + res = Zygote.Buffer([], length(c)) + res[1] = c[1](input) + for (i,l) in enumerate(c[2:end]) + res[i+1] = l(res[i]) + end + return copy(res) end From f41219133e8a233c8e0056972641378c4e83c427 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Tue, 10 Sep 2019 10:46:56 -0700 Subject: [PATCH 05/15] deal with empty Chain --- src/layers/basic.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e1e9ab4519..9ef6f19549 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -52,9 +52,11 @@ Calculate the forward results of each layers in Chain `c` with `input` as model """ function activations(c::Chain, input) res = Zygote.Buffer([], length(c)) - res[1] = c[1](input) - for (i,l) in enumerate(c[2:end]) - res[i+1] = l(res[i]) + if length(c) > 0 + res[1] = c[1](input) + for (i,l) in enumerate(c[2:end]) + res[i+1] = l(res[i]) + end end return copy(res) end From 46abfbbd5cd4579e66912996c5ff4b568a01d1ea Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Wed, 11 Sep 2019 17:36:37 -0700 Subject: [PATCH 06/15] recursive way of doing activations --- src/layers/basic.jl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 9ef6f19549..e2e3e56ae3 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -51,16 +51,17 @@ end Calculate the forward results of each layers in Chain `c` with `input` as model input. """ function activations(c::Chain, input) - res = Zygote.Buffer([], length(c)) - if length(c) > 0 - res[1] = c[1](input) - for (i,l) in enumerate(c[2:end]) - res[i+1] = l(res[i]) - end - end - return copy(res) + extraChain(c.layers, input) end +function extraChain(fs::Tuple, x) + res = first(fs)(x) + return (res, extraChain(Base.tail(fs), res)...) +end + +extraChain(::Tuple{}, x) = [] + + """ Dense(in::Integer, out::Integer, σ = identity) From 3b7b780d398bef91f2e793e2293f140d8c3b9241 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Tue, 8 Oct 2019 23:04:31 -0700 Subject: [PATCH 07/15] super simple test --- test/layers/basic.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index cbe250fcca..4edfecc777 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -19,6 +19,12 @@ import Flux: activations # numeric test should be put into testset of corresponding layer end + @testset "Activations" begin + c = Chain(Dense(3,5,relu), Dense(5,1,relu)) + X = Float32.([1.0; 1.0; 1.0]) + @test_nowarn gradient(()->Flux.activations(c, X)[2][1], params(c)) + end + @testset "Dense" begin @test length(Dense(10, 5)(randn(10))) == 5 @test_throws DimensionMismatch Dense(10, 5)(randn(1)) From cdaaca8cfa880b2f45f30379639f347b3ebfd175 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Tue, 10 Sep 2019 00:54:49 -0700 Subject: [PATCH 08/15] make activations zygote friendly --- src/layers/basic.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index f42a9619f9..e8dde1a339 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -44,17 +44,15 @@ end # it might be replaced in the future for better performance # see issue https://github.com/FluxML/Flux.jl/issues/702 # Johnny Chen -- @johnnychen94 +# only slightly changed to better handle interaction with Zygote @dsweber2 """ activations(c::Chain, input) Calculate the forward results of each layers in Chain `c` with `input` as model input. """ function activations(c::Chain, input) - rst = [] - for l in c - x = get(rst, length(rst), input) - push!(rst, l(x)) - end - return rst + buffed = accumulate!((x,y)->y(x), Zygote.Buffer([], length(c)), + [l for l in c], dims=1, init=input) + return copy(buffed) end From d0202a2945bf86a7827075c77642405b25c752fe Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Tue, 10 Sep 2019 01:34:12 -0700 Subject: [PATCH 09/15] adding the extra commits broke the accumulate version --- src/layers/basic.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e8dde1a339..2d86da858c 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -50,9 +50,12 @@ end Calculate the forward results of each layers in Chain `c` with `input` as model input. """ function activations(c::Chain, input) - buffed = accumulate!((x,y)->y(x), Zygote.Buffer([], length(c)), - [l for l in c], dims=1, init=input) - return copy(buffed) + res = Zygote.Buffer([], length(c)) + res[1] = c[1](input) + for (i,l) in enumerate(c[2:end]) + res[i+1] = l(res[i]) + end + return copy(res) end From 99679f7e16b2244ace129e9c6288b4ab2159a452 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Tue, 10 Sep 2019 10:46:56 -0700 Subject: [PATCH 10/15] deal with empty Chain --- src/layers/basic.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 2d86da858c..c3783567fa 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -51,9 +51,11 @@ Calculate the forward results of each layers in Chain `c` with `input` as model """ function activations(c::Chain, input) res = Zygote.Buffer([], length(c)) - res[1] = c[1](input) - for (i,l) in enumerate(c[2:end]) - res[i+1] = l(res[i]) + if length(c) > 0 + res[1] = c[1](input) + for (i,l) in enumerate(c[2:end]) + res[i+1] = l(res[i]) + end end return copy(res) end From 6475f6a43eba8feab5f34a7dc2cf0f86d1d7c0fc Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Wed, 11 Sep 2019 17:36:37 -0700 Subject: [PATCH 11/15] recursive way of doing activations --- src/layers/basic.jl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index c3783567fa..b92bc919ea 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -50,16 +50,17 @@ end Calculate the forward results of each layers in Chain `c` with `input` as model input. """ function activations(c::Chain, input) - res = Zygote.Buffer([], length(c)) - if length(c) > 0 - res[1] = c[1](input) - for (i,l) in enumerate(c[2:end]) - res[i+1] = l(res[i]) - end - end - return copy(res) + extraChain(c.layers, input) end +function extraChain(fs::Tuple, x) + res = first(fs)(x) + return (res, extraChain(Base.tail(fs), res)...) +end + +extraChain(::Tuple{}, x) = [] + + """ Dense(in::Integer, out::Integer, σ = identity) From db92b0e3ce3d5cb06a11b6cf77e74e1e0d56b2f1 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Tue, 8 Oct 2019 23:04:31 -0700 Subject: [PATCH 12/15] super simple test --- test/layers/basic.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index cbe250fcca..4edfecc777 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -19,6 +19,12 @@ import Flux: activations # numeric test should be put into testset of corresponding layer end + @testset "Activations" begin + c = Chain(Dense(3,5,relu), Dense(5,1,relu)) + X = Float32.([1.0; 1.0; 1.0]) + @test_nowarn gradient(()->Flux.activations(c, X)[2][1], params(c)) + end + @testset "Dense" begin @test length(Dense(10, 5)(randn(10))) == 5 @test_throws DimensionMismatch Dense(10, 5)(randn(1)) From 0fe3ac4e770de17a46d37809238a6deae06f98a3 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Tue, 8 Oct 2019 23:05:22 -0700 Subject: [PATCH 13/15] bring activations into function call --- src/layers/basic.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index b92bc919ea..db4914246f 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -31,6 +31,8 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) (c::Chain)(x) = applychain(c.layers, x) +(c::Chain)(x, i) = extraChain(c.layers, x)[i] + Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) function Base.show(io::IO, c::Chain) From 58c794702d030b61a3744f1a180e9ab65113682b Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Thu, 14 Nov 2019 14:05:53 -0800 Subject: [PATCH 14/15] simpler test --- src/layers/basic.jl | 4 ++-- test/layers/basic.jl | 12 +++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index db4914246f..75f18e3c6a 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -31,7 +31,7 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) (c::Chain)(x) = applychain(c.layers, x) -(c::Chain)(x, i) = extraChain(c.layers, x)[i] +(c::Chain)(x) = extraChain(c.layers, x) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) @@ -60,7 +60,7 @@ function extraChain(fs::Tuple, x) return (res, extraChain(Base.tail(fs), res)...) end -extraChain(::Tuple{}, x) = [] +extraChain(::Tuple{}, x) = () diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 4edfecc777..0ff1776db8 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -4,11 +4,13 @@ import Flux: activations @testset "basic" begin @testset "helpers" begin @testset "activations" begin - dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax) - x = rand(10) - @test activations(Chain(), x) == [] - @test activations(dummy_model, x)[1] == dummy_model[1](x) - @test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2] + dummy_model = Chain(x->x.^2, x->x .- 3, x -> tan.(x)) + x = randn(10) + @test activations(dummy_model, x)[1] == x.^2 + @test activations(dummy_model, x)[2] == (x.^2 .- 3) + @test activations(dummy_model, x)[3] == tan.(x.^2 .- 3) + + @test activations(Chain(), x) == () @test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type end end From 20eb840882752228a49130aed0712da389f6db1a Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Fri, 15 Nov 2019 12:03:08 -0800 Subject: [PATCH 15/15] keeping activations separate --- src/layers/basic.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 75f18e3c6a..2a46520818 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -31,8 +31,6 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) (c::Chain)(x) = applychain(c.layers, x) -(c::Chain)(x) = extraChain(c.layers, x) - Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) function Base.show(io::IO, c::Chain)