From bc5389b6d4470653366c79a36b184c5bab96ddce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=5BA=1B=5BA=1B=5BA=1B=5BB=1B=5BB=1B=5BBwilltebbutt?= Date: Sun, 21 Apr 2019 14:41:14 +0100 Subject: [PATCH 1/2] Use x_vec length --- src/grad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/grad.jl b/src/grad.jl index 6216cba4..0aca916a 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -99,6 +99,6 @@ function to_vec(x::Tuple) x_vecs, x_backs = zip(map(to_vec, x)...) sz = cumsum([map(length, x_vecs)...]) return vcat(x_vecs...), function(v) - return ntuple(n->x_backs[n](v[sz[n]-length(x[n])+1:sz[n]]), length(x)) + return ntuple(n->x_backs[n](v[sz[n]-length(x_vecs[n])+1:sz[n]]), length(x)) end end From 430c742acc68ba499ff9405f4da8fc93ee126c74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=5BA=1B=5BA=1B=5BA=1B=5BB=1B=5BB=1B=5BBwilltebbutt?= Date: Sun, 21 Apr 2019 14:55:23 +0100 Subject: [PATCH 2/2] Improve Diagonal implementation and test Tuple implementation --- src/grad.jl | 5 ++++- test/grad.jl | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/grad.jl b/src/grad.jl index 0aca916a..c3ddc6a4 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -92,7 +92,10 @@ function to_vec(x::T) where {T<:LinearAlgebra.AbstractTriangular} return x_vec, x_vec->T(reshape(back(x_vec), size(x))) end to_vec(x::Symmetric) = vec(Matrix(x)), x_vec->Symmetric(reshape(x_vec, size(x))) -to_vec(X::Diagonal) = vec(Matrix(X)), x_vec->Diagonal(reshape(x_vec, size(X)...)) +function to_vec(X::Diagonal) + diag_vec, from_diag = to_vec(X.diag) + return diag_vec, x_vec->Diagonal(from_diag(x_vec)) +end # Non-array data structures. function to_vec(x::Tuple) diff --git a/test/grad.jl b/test/grad.jl index ae7c8d73..f15fefe1 100644 --- a/test/grad.jl +++ b/test/grad.jl @@ -54,6 +54,7 @@ using FDM: grad, jacobian, _jvp, _j′vp, jvp, j′vp, to_vec test_to_vec((randn(4), randn(4, 3, 2), 1)) test_to_vec((5, randn(4, 3, 2), UpperTriangular(randn(4, 4)), 2.5)) test_to_vec(((6, 5), 3, randn(3, 2, 0, 1))) + test_to_vec((Diagonal(randn(7)),)) end end