-
The following is basically a copy-paste from #45 (comment), as I think this covers some of the discussion.
For this one in particular we have an implementation in DynamicPPL that can potentially be moved to its own package if we really want to: https://github.com/TuringLang/DynamicPPL.jl/blob/b23acff013a9111c8ce2c89dbf5339e76234d120/src/utils.jl#L434-L473 But this has a couple of issues:
(1) can be addressed by instead taking a closure approach à la Functors.jl:

```julia
using Distributions, LinearAlgebra, PDMats

# A diagonal-covariance `MvNormal` stores its covariance as a `PDiagMat`.
function flatten(d::MvNormal{<:Real,<:PDiagMat})
    dim = length(d)
    # Rebuild the distribution from the flat vector `x` (mean first, then variances).
    function MvNormal_unflatten(x)
        return MvNormal(x[1:dim], Diagonal(x[dim+1:end]))
    end
    return vcat(d.μ, diag(d.Σ)), MvNormal_unflatten
end
```

For (2), we have a couple of immediate options:

For (a) we'd have something like:

```julia
# `V` and `F` are declared as type parameters so that they can be forwarded to `Distribution`.
abstract type WrapperDistribution{V,F,D<:Distribution{V,F}} <: Distribution{V,F} end

# HACK: Probably shouldn't do this.
inner_dist(x::WrapperDistribution) = x.inner

# TODO: Specialize further on `x` to avoid hitting default implementations?
Distributions.logpdf(d::WrapperDistribution, x) = logpdf(inner_dist(d), x)

# Etc.

struct MeanParameterized{V,F,D<:Distribution{V,F}} <: WrapperDistribution{V,F,D}
    inner::D
end

function flatten(d::MeanParameterized{<:Any,<:Any,<:MvNormal})
    μ = mean(d.inner)
    # Only the mean is exposed; the covariance is captured from `d`.
    function MeanParameterized_MvNormal_unflatten(x)
        return MeanParameterized(MvNormal(x, d.inner.Σ))
    end
    return μ, MeanParameterized_MvNormal_unflatten
end
```

Pros:

For (b) we'd have something like:

```julia
struct MeanOnly end

function flatten(::MeanOnly, d::MvNormal)
    μ = mean(d)
    # No wrapper type here: the covariance is captured from `d` and a plain `MvNormal` is rebuilt.
    function MvNormal_meanonly_unflatten(x)
        return MvNormal(x, d.Σ)
    end
    return μ, MvNormal_meanonly_unflatten
end
```

Pros:
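To make the closure approach in (1) concrete, here is a minimal, dependency-free sketch of the round trip. `DiagGaussian` is a hypothetical stand-in for a diagonal `MvNormal` (it is not a Distributions.jl type), chosen only so the example runs without any packages.

```julia
# Hypothetical stand-in for a diagonal-covariance Gaussian.
struct DiagGaussian{V<:AbstractVector{<:Real}}
    μ::V    # mean
    σ2::V   # diagonal of the covariance
end

# Return the flat parameter vector together with a closure that rebuilds the struct.
function flatten(d::DiagGaussian)
    dim = length(d.μ)
    function unflatten(x::AbstractVector{<:Real})
        return DiagGaussian(x[1:dim], x[(dim + 1):end])
    end
    return vcat(d.μ, d.σ2), unflatten
end

d = DiagGaussian([0.0, 1.0], [1.0, 4.0])
θ, unflatten = flatten(d)   # θ == [0.0, 1.0, 1.0, 4.0]
d2 = unflatten(2 .* θ)      # rebuild from an updated flat vector
```

The closure captures `dim`, so the caller never needs to know how the flat vector is laid out. That is the main appeal over a separate `unflatten(d, x)` method.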
-
I have a few additional considerations.
Questions:
-
Stumbled upon this while reading through some interesting GSoC proposals. My main advice would be to keep the scope as narrow as possible. Trying for a very general system quite quickly leads to JuliaGaussianProcesses/ParameterHandling.jl#43, and down that road is madness 😅.
-
I think I'm leaning toward Provided that we make our simple-enough-to-maintain pre-packaged functors (again, another reason to use the location-scale abstraction IMO), I think it should be good. I also guess the flows in
-
Maybe worth adding a comment as to why this was closed, @Red-Portal? It will be useful for future reference :)
-
Overview
Flattening of parameters is an issue that has been discussed many times in the Julia community, and there are numerous attempts at addressing this:
This is also an issue we're facing here in AdvancedVI.jl, and it becomes particularly annoying when combined with Distributions.jl as we might want to work with different parameterizations of a given distribution, etc.
So the issue is:
How do we "flatten" nested structs, etc. into something we can give AD frameworks, i.e. `AbstractVector{<:Real}`, and allow specification of exactly which parameters are considered "learnable"?
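As a rough illustration of the question, the sketch below flattens only a chosen subset of fields from a `NamedTuple` into a single vector and returns a closure that writes updated values back. The helper name `flatten_learnable` and the example parameter names are hypothetical, and each learnable field is assumed to be a non-empty vector of reals. This is not an existing API in AdvancedVI.jl or any other package.

```julia
# Hypothetical helper: flatten only the fields listed in `learnable`.
function flatten_learnable(params::NamedTuple, learnable::Tuple{Vararg{Symbol}})
    # Assumes every learnable field is an AbstractVector{<:Real}.
    blocks = [collect(params[k]) for k in learnable]
    offsets = cumsum([0; length.(blocks)])   # block boundaries in the flat vector
    θ = reduce(vcat, blocks)

    function unflatten(x::AbstractVector{<:Real})
        updated = NamedTuple{learnable}(
            ntuple(i -> x[(offsets[i] + 1):offsets[i + 1]], length(learnable)),
        )
        # Learnable fields are replaced; every other field is kept untouched.
        return merge(params, updated)
    end
    return θ, unflatten
end

params = (μ = [0.0, 1.0], σ = [1.0, 2.0], ν = [3.0])
θ, unflatten = flatten_learnable(params, (:μ, :ν))   # σ is treated as fixed
new_params = unflatten(θ .+ 1)
```

Here `θ` contains only the entries of `μ` and `ν`, so an AD framework differentiating with respect to `θ` never sees `σ`.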