diff --git a/src/grad.jl b/src/grad.jl index 6216cba4..c3ddc6a4 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -92,13 +92,16 @@ function to_vec(x::T) where {T<:LinearAlgebra.AbstractTriangular} return x_vec, x_vec->T(reshape(back(x_vec), size(x))) end to_vec(x::Symmetric) = vec(Matrix(x)), x_vec->Symmetric(reshape(x_vec, size(x))) -to_vec(X::Diagonal) = vec(Matrix(X)), x_vec->Diagonal(reshape(x_vec, size(X)...)) +function to_vec(X::Diagonal) + diag_vec, from_diag = to_vec(X.diag) + return diag_vec, x_vec->Diagonal(from_diag(x_vec)) +end # Non-array data structures. function to_vec(x::Tuple) x_vecs, x_backs = zip(map(to_vec, x)...) sz = cumsum([map(length, x_vecs)...]) return vcat(x_vecs...), function(v) - return ntuple(n->x_backs[n](v[sz[n]-length(x[n])+1:sz[n]]), length(x)) + return ntuple(n->x_backs[n](v[sz[n]-length(x_vecs[n])+1:sz[n]]), length(x)) end end diff --git a/test/grad.jl b/test/grad.jl index ae7c8d73..f15fefe1 100644 --- a/test/grad.jl +++ b/test/grad.jl @@ -54,6 +54,7 @@ using FDM: grad, jacobian, _jvp, _j′vp, jvp, j′vp, to_vec test_to_vec((randn(4), randn(4, 3, 2), 1)) test_to_vec((5, randn(4, 3, 2), UpperTriangular(randn(4, 4)), 2.5)) test_to_vec(((6, 5), 3, randn(3, 2, 0, 1))) + test_to_vec((Diagonal(randn(7)),)) end end