From 62ce0e26f38366d318f4c53a242e7cb78c2366dc Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 6 Jun 2019 11:43:06 +0100 Subject: [PATCH] Fix issue with tuples --- src/grad.jl | 2 +- test/grad.jl | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/grad.jl b/src/grad.jl index a18f2084..5843dd5d 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -104,6 +104,6 @@ function to_vec(x::Tuple) x_vecs, x_backs = zip(map(to_vec, x)...) sz = cumsum([map(length, x_vecs)...]) return vcat(x_vecs...), function(v) - return ntuple(n->x_backs[n](v[sz[n]-length(x[n])+1:sz[n]]), length(x)) + return ntuple(n->x_backs[n](v[sz[n]-length(x_vecs[n])+1:sz[n]]), length(x)) end end diff --git a/test/grad.jl b/test/grad.jl index 7379e7a4..a798bc84 100644 --- a/test/grad.jl +++ b/test/grad.jl @@ -1,5 +1,19 @@ using FDM: grad, jacobian, _jvp, _j′vp, jvp, j′vp, to_vec +# Dummy type where length(x::DummyType) ≠ length(first(to_vec(x))) +struct DummyType{TX<:Matrix} + X::TX +end + +function FDM.to_vec(x::DummyType) + x_vec, back = to_vec(x.X) + return x_vec, x_vec -> DummyType(back(x_vec)) +end + +Base.:(==)(x::DummyType, y::DummyType) = x.X == y.X +Base.length(x::DummyType) = size(x.X, 1) + + @testset "grad" begin @testset "grad" begin @@ -47,6 +61,7 @@ using FDM: grad, jacobian, _jvp, _j′vp, jvp, j′vp, to_vec test_to_vec(UpperTriangular(randn(13, 13))) test_to_vec(Symmetric(randn(11, 11))) test_to_vec(Diagonal(randn(7))) + test_to_vec(DummyType(randn(2, 9))) @testset "$T" for T in (Adjoint, Transpose) test_to_vec(T(randn(4, 4))) @@ -60,6 +75,8 @@ using FDM: grad, jacobian, _jvp, _j′vp, jvp, j′vp, to_vec test_to_vec((randn(4), randn(4, 3, 2), 1)) test_to_vec((5, randn(4, 3, 2), UpperTriangular(randn(4, 4)), 2.5)) test_to_vec(((6, 5), 3, randn(3, 2, 0, 1))) + test_to_vec((DummyType(randn(2, 7)), DummyType(randn(3, 9)))) + test_to_vec((DummyType(randn(3, 2)), randn(11, 8))) end end