Adding a NamedTupleVariate #1762

Open
sethaxen opened this issue Sep 2, 2023 · 3 comments

Comments

@sethaxen
Contributor

sethaxen commented Sep 2, 2023

It could be useful to add a NamedTupleVariate, with the necessary defaults, to this package. A concrete use case is a product distribution whose individual components are easy to access by name.

Here's a barebones implementation:

using Distributions, Random

# New variate form whose samples are NamedTuples
abstract type NamedTupleVariate <: VariateForm end

struct NamedTupleProductDistribution{Tnames,Tdists,eltypes,S<:ValueSupport} <:
       Distribution{NamedTupleVariate,S}
    dists::NamedTuple{Tnames,Tdists}
end
function NamedTupleProductDistribution(
    dists::NamedTuple{K,V}
) where {K,V<:Tuple{Vararg{Distribution}}}
    eltypes = Tuple{map(eltype, values(dists))...}
    # would be better to allow mixed ValueSupports here
    vs = Distributions._product_valuesupport(dists)
    return NamedTupleProductDistribution{K,V,eltypes,vs}(dists)
end

function Distributions.product_distribution(
    dists::NamedTuple{K,V}
) where {K,V<:Tuple{Vararg{Distribution}}}
    return NamedTupleProductDistribution(dists)
end

# Sample type: a NamedTuple whose field types are the components' eltypes
function Distributions.eltype(::NamedTupleProductDistribution{K,<:Any,V}) where {K,V}
    return NamedTuple{K,V}
end

function Distributions.insupport(
    dist::NamedTupleProductDistribution{K}, x::NamedTuple{K}
) where {K}
    return all(Base.splat(insupport), zip(dist.dists, x))
end

function Distributions.pdf(
    dist::NamedTupleProductDistribution{K}, x::NamedTuple{K}
) where {K}
    return exp(logpdf(dist, x))
end
function Distributions.logpdf(
    dist::NamedTupleProductDistribution{K}, x::NamedTuple{K}
) where {K}
    return mapreduce(logpdf, +, dist.dists, x)
end

# Draw one sample from each component with the same RNG, keeping the names
function Distributions.rand(
    rng::AbstractRNG, dist::NamedTupleProductDistribution{K}
) where {K}
    return NamedTuple{K}(map(Base.Fix1(rand, rng), dist.dists))
end

And here is an example of its usage:

julia> dist = product_distribution((; x=Normal(), y=Dirichlet(3, 1)));

julia> x = rand(dist)
(x = 0.35877545055901744, y = [0.5772411031132897, 0.2856859832919011, 0.13707291359480941])

julia> insupport(dist, x)
true

julia> pdf(dist, x)
0.7481503903444227
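For comparison, the core idea can be sketched with only existing Distributions.jl calls, assuming independent components held in a plain NamedTuple rather than the proposed NamedTupleProductDistribution type. The names dists, sample and joint_logpdf are illustrative only. The sketch draws one value per component, keeps the field names, and sums the component log-densities, mirroring what the rand and logpdf methods above do:

```julia
using Distributions, Random

# Assumption: a plain NamedTuple of independent component distributions,
# used here in place of the proposed NamedTupleProductDistribution type.
dists = (; x = Normal(), y = Dirichlet(3, 1))

# Seeded RNG so the draw is reproducible.
rng = MersenneTwister(42)

# One draw per component; `map` over a NamedTuple keeps the field names.
sample = map(d -> rand(rng, d), dists)

# For independent components, the joint log-density is the sum of the
# component log-densities (the same rule as the logpdf method above).
joint_logpdf = sum(map(logpdf, values(dists), values(sample)))

# Components stay addressable by name.
println(dists.y)
println(sample.y)
println(joint_logpdf)
```

A dedicated type would package this behind rand, logpdf, insupport and product_distribution, while still exposing each component by name through dist.dists.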
@aplavin
Contributor

aplavin commented Sep 2, 2023

See https://github.com/invenia/KeyedDistributions.jl for an array-based take on this, built on keyed arrays.
But yeah, a version based on Base NamedTuples would also be useful.

@Red-Portal

Hi, I also think this would be a really useful idea. Are there any plans to move this forward?

@sethaxen
Contributor Author

> Hi, I also think this would be a really useful idea. Are there any plans to move this forward?

I had a local draft, which I just pushed to #1803. It still needs tests and some clarification about eltypes, but it works.
