From 1818e6c9180ae49ba87ca38251bc24fbb9cbc049 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 12 Aug 2020 02:31:15 -0400 Subject: [PATCH 1/3] add some more adjoints and test adjoints Fixes https://github.com/SciML/RecursiveArrayTools.jl/issues/111 --- Project.toml | 8 +++++--- src/zygote.jl | 8 ++++++++ test/adjoints.jl | 39 +++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 53 insertions(+), 3 deletions(-) create mode 100644 test/adjoints.jl diff --git a/Project.toml b/Project.toml index e7a995dc..46f8d36e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecursiveArrayTools" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" authors = ["Chris Rackauckas "] -version = "2.5.0" +version = "2.6.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -21,11 +21,13 @@ ZygoteRules = "0.2" julia = "1.3" [extras] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["NLsolve", "OrdinaryDiffEq", "Test", "Unitful", "Random"] +test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Test", "Unitful", "Random", "Zygote"] diff --git a/src/zygote.jl b/src/zygote.jl index 7ec77d3c..6132263b 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -23,3 +23,11 @@ ZygoteRules.@adjoint function ArrayPartition(x...) end ArrayPartition(x...),ArrayPartition_adjoint end + +ZygoteRules.@adjoint function VectorOfArray(u) + VectorOfArray(u),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],) +end + +ZygoteRules.@adjoint function DiffEqArray(u,t) + DiffEqArray(u,t),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],nothing) +end diff --git a/test/adjoints.jl b/test/adjoints.jl new file mode 100644 index 00000000..406267a9 --- /dev/null +++ b/test/adjoints.jl @@ -0,0 +1,39 @@ +using RecursiveArrayTools, Zygote, ForwardDiff, Test + +function loss(x) + sum(abs2,Array(VectorOfArray([x .* i for i in 1:5]))) +end + +function loss2(x) + sum(abs2,Array(DiffEqArray([x .* i for i in 1:5],1:5))) +end + +function loss3(x) + y = VectorOfArray([x .* i for i in 1:5]) + tmp = 0.0 + for i in 1:5, j in 1:5 + tmp += y[i,j] + end + tmp +end + +function loss4(x) + y = DiffEqArray([x .* i for i in 1:5],1:5) + tmp = 0.0 + for i in 1:5, j in 1:5 + tmp += y[i,j] + end + tmp +end + +function loss5(x) + sum(abs2,Array(ArrayPartition([x .* i for i in 1:5]...))) +end + +x = float.(6:10) +loss(x) +@test Zygote.gradient(loss,x)[1] == ForwardDiff.gradient(loss,x) +@test Zygote.gradient(loss2,x)[1] == ForwardDiff.gradient(loss2,x) +@test Zygote.gradient(loss3,x)[1] == ForwardDiff.gradient(loss3,x) +@test Zygote.gradient(loss4,x)[1] == ForwardDiff.gradient(loss4,x) +@test Zygote.gradient(loss5,x)[1] == ForwardDiff.gradient(loss5,x) diff --git a/test/runtests.jl b/test/runtests.jl index 300a84f4..f6c0de8a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,4 +9,5 @@ using Test @time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end @time @testset "Linear Algebra Tests" begin include("linalg.jl") end @time @testset "Upstream Tests" begin include("upstream.jl") end + @time @testset "Adjoint Tests" begin include("adjoints.jl") end end From 998432b25d8e724215f658c2feab664052c12f40 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 12 Aug 2020 02:38:48 -0400 Subject: [PATCH 2/3] update print --- test/utils_test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils_test.jl b/test/utils_test.jl index 2d4642d2..b89d498b 100644 --- a/test/utils_test.jl +++ b/test/utils_test.jl @@ -27,7 +27,7 @@ AofuSA = [@SVector [2.0u"kg",3.0u"kg"] for i in 1:5] @test recursive_unitless_eltype(AofuSA) == SVector{2,Float64} A = [ArrayPartition(ones(1),ones(1)),] -@test repr("text/plain", A) == "1-element Array{ArrayPartition{Float64,Tuple{Array{Float64,1},Array{Float64,1}}},1}:\n [1.0][1.0]" +@test repr("text/plain", A) == "1-element Array{ArrayPartition{Float64,Tuple{Array{Float64,1},Array{Float64,1}}},1}:\n ([1.0], [1.0])" function test_recursive_bottom_eltype() function test_value(val::Any, expected_type::Type) From 9a94f7a8a77683db209e56e6e5f7e301fd72ca46 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 12 Aug 2020 03:20:30 -0400 Subject: [PATCH 3/3] remove print test --- test/utils_test.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/utils_test.jl b/test/utils_test.jl index b89d498b..15de73ee 100644 --- a/test/utils_test.jl +++ b/test/utils_test.jl @@ -27,7 +27,6 @@ AofuSA = [@SVector [2.0u"kg",3.0u"kg"] for i in 1:5] @test recursive_unitless_eltype(AofuSA) == SVector{2,Float64} A = [ArrayPartition(ones(1),ones(1)),] -@test repr("text/plain", A) == "1-element Array{ArrayPartition{Float64,Tuple{Array{Float64,1},Array{Float64,1}}},1}:\n ([1.0], [1.0])" function test_recursive_bottom_eltype() function test_value(val::Any, expected_type::Type)