Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FiniteDifferences"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.0"
version = "0.12.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
20 changes: 19 additions & 1 deletion src/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,25 @@ end
# Base case -- if x is already a Vector{<:Real} there's no conversion necessary.
to_vec(x::Vector{<:Real}) = (x, identity)

# Fallback method for `to_vec`. Won't always do what you wanted, but should be fine a decent
# chunk of the time.
function to_vec(x::T) where {T}
Base.isstructtype(T) || throw(error("Expected a struct type"))
isempty(fieldnames(T)) && return (Bool[], _ -> x) # Singleton types

val_vecs_and_backs = map(name -> to_vec(getfield(x, name)), fieldnames(T))
vals = first.(val_vecs_and_backs)
backs = last.(val_vecs_and_backs)

v, vals_from_vec = to_vec(vals)
function structtype_from_vec(v::Vector{<:Real})
val_vecs = vals_from_vec(v)
vals = map((b, v) -> b(v), backs, val_vecs)
return T(vals...)
end
return v, structtype_from_vec
end

function to_vec(x::AbstractVector)
x_vecs_and_backs = map(to_vec, x)
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
Expand Down Expand Up @@ -169,7 +188,6 @@ function FiniteDifferences.to_vec(x::Composite{P}) where{P}
return x_vec, Composite_from_vec
end


function FiniteDifferences.to_vec(x::AbstractZero)
function AbstractZero_from_vec(x_vec::Vector)
return x
Expand Down
11 changes: 11 additions & 0 deletions test/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ struct ThreeFields
c
end

# For testing nested fallback for structs
struct Singleton end
struct Nested
x::ThreeFields
y::Singleton
end

Base.size(x::FillVector) = (x.len,)
Base.getindex(x::FillVector, n::Int) = x.x
Expand Down Expand Up @@ -168,4 +174,9 @@ end
x_vec, from_vec = to_vec(x)
@test_throws MethodError from_vec(randn(10))
end

@testset "fallback" begin
nested = Nested(ThreeFields(1.0, 2.0, "Three"), Singleton())
test_to_vec(nested; check_inferred=false) # map
end
end