
Mixed <: ValueSupport and Product with tuple storage #906

@baggepinnen


I needed a Product distribution with mixed continuous/discrete support for a project, so I implemented the types

struct Mixed <: ValueSupport
struct TupleProduct <: MultivariateDistribution
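The rule for choosing `Mixed` can be exercised on its own. The `TupleProduct` constructor in the Implementation section uses this rule: all discrete components give `Discrete`, all continuous components give `Continuous`, and anything else gives `Mixed`. Below is a minimal sketch of that rule. It assumes hypothetical stand-ins (`Support`, `Discrete`, `Continuous`, `Mixed`, `ToyNormal`, `ToyBinomial`, `support`, `joint_support`) in place of the Distributions.jl types and `Distributions.value_support`.

```julia
# Hypothetical stand-ins for Distributions.jl's value-support types.
abstract type Support end
struct Discrete <: Support end
struct Continuous <: Support end
struct Mixed <: Support end

# Hypothetical components; each declares its support through a method on its type.
struct ToyNormal
    μ::Float64
    σ::Float64
end
struct ToyBinomial
    n::Int
    p::Float64
end
support(::Type{ToyNormal}) = Continuous
support(::Type{ToyBinomial}) = Discrete

# Same rule as the TupleProduct constructor: all Discrete -> Discrete,
# all Continuous -> Continuous, anything else -> Mixed.
function joint_support(v::Tuple)
    all(support(typeof(d)) == Discrete for d in v) && return Discrete
    all(support(typeof(d)) == Continuous for d in v) && return Continuous
    return Mixed
end

joint_support((ToyNormal(0.0, 2.0), ToyNormal(0.0, 2.0)))                      # Continuous
joint_support((ToyBinomial(1, 0.5),))                                          # Discrete
joint_support((ToyNormal(0.0, 2.0), ToyNormal(0.0, 2.0), ToyBinomial(1, 0.5))) # Mixed
```

Because `all` over an empty collection is `true`, an empty tuple is classified as `Discrete`. The original constructor has the same behaviour, since it performs the same two `all` checks in the same order.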

Would you be interested in a PR with (parts of) this implementation, possibly with a different name?

Some code and benchmarks are below.

julia> dt = TupleProduct((Normal(0,2), Normal(0,2), Binomial())) # Mixed value support
TupleProduct{3,LowLevelParticleFilters.Mixed,Tuple{Normal{Float64},Normal{Float64},Binomial{Float64}}}(
v: (Normal{Float64}(μ=0.0, σ=2.0), Normal{Float64}(μ=0.0, σ=2.0), Binomial{Float64}(n=1, p=0.5))
)
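Tuple storage matters most in the mixed case. A `Vector` holding both a normal and a binomial component must widen its element type to a common abstract supertype. A tuple's type, by contrast, records each element's concrete type. Below is a minimal sketch of that difference. It uses hypothetical stand-in types (`ToyDist`, `ToyNormal`, `ToyBinomial`), not the Distributions.jl types.

```julia
# Hypothetical stand-ins for univariate distributions (not Distributions.jl types).
abstract type ToyDist end

struct ToyNormal <: ToyDist
    μ::Float64
    σ::Float64
end

struct ToyBinomial <: ToyDist
    n::Int
    p::Float64
end

# Vector storage: the array literal promotes to the common supertype, which is abstract.
as_vector = [ToyNormal(0.0, 2.0), ToyNormal(0.0, 2.0), ToyBinomial(1, 0.5)]

# Tuple storage: the tuple type lists every element's concrete type.
as_tuple = (ToyNormal(0.0, 2.0), ToyNormal(0.0, 2.0), ToyBinomial(1, 0.5))

println(isconcretetype(eltype(as_vector)))  # false: element type is the abstract ToyDist
println(isconcretetype(typeof(as_tuple)))   # true: Tuple{ToyNormal, ToyNormal, ToyBinomial}
```

With the abstract element type, each element access in a loop over `as_vector` generally has to be dispatched at run time. With the concrete tuple type, the compiler can specialise on every component.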

A small benchmark

The package where I've implemented this is called LowLevelParticleFilters; it also defines some methods for distributions and static arrays. With this package loaded, the timings are as follows:

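using BenchmarkTools, Distributions, StaticArrays, LowLevelParticleFilters
sv = @SVector randn(2)                         # static 2-vector
d  = Product([Normal(0,2), Normal(0,2)])       # Distributions.jl product, Vector storage
dt = TupleProduct((Normal(0,2), Normal(0,2)))  # tuple storage (this proposal)
dm = MvNormal(2, 2)                            # zero-mean isotropic 2-D normal, σ = 2
@btime logpdf($d,$(Vector(sv)))  # 32.449 ns (1 allocation: 32 bytes)
@btime logpdf($dt,$(Vector(sv))) # 21.141 ns (0 allocations: 0 bytes)
@btime logpdf($dm,$(Vector(sv))) # 48.745 ns (1 allocation: 96 bytes)

@btime logpdf($d,$sv)  # 22.651 ns (0 allocations: 0 bytes)
@btime logpdf($dt,$sv) # 0.021 ns (0 allocations: 0 bytes)
@btime logpdf($dm,$sv) # 0.021 ns (0 allocations: 0 bytes)
# Sub-nanosecond times usually mean the call was constant-folded away, because $sv is
# interpolated as a constant; BenchmarkTools suggests $(Ref(sv))[] to prevent this.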

Without LowLevelParticleFilters and its special static-array methods, the SVector timings are essentially the same as the Vector timings above:

@btime logpdf($d,$sv) # 32.621 ns (1 allocation: 32 bytes)
@btime logpdf($dm,$sv) # 46.415 ns (1 allocation: 96 bytes)

Implementation

Key points are below; the rest of the implementation is here.

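using Distributions

# Value support for products that mix Continuous and Discrete components
struct Mixed <: ValueSupport end

# Independent univariate components stored in a tuple, so each keeps its concrete type
struct TupleProduct{N,S,V<:NTuple{N,UnivariateDistribution}} <: MultivariateDistribution{S}
    v::V
    function TupleProduct(v::V) where {N,V<:NTuple{N,UnivariateDistribution}}
        # All Discrete -> Discrete, all Continuous -> Continuous, otherwise Mixed
        all(Distributions.value_support(typeof(d)) == Discrete for d in v) &&
            return new{N,Discrete,V}(v)
        all(Distributions.value_support(typeof(d)) == Continuous for d in v) &&
            return new{N,Continuous,V}(v)
        return new{N,Mixed,V}(v)
    end
end

# Dimension of the product (used by logpdf's length check before calling _logpdf)
Base.length(::TupleProduct{N}) where {N} = N

# Joint log-density of independent components: an unrolled sum of the N marginal logpdfs
@generated function Distributions._logpdf(d::TupleProduct{N}, x::AbstractVector{<:Real}) where N
    :(Base.Cartesian.@ncall $N Base.:+ i->logpdf(d.v[i], x[i]))
end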
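The `@generated` method above unrolls the sum of marginal logpdfs with `Base.Cartesian.@ncall`. Below is a minimal sketch of an alternative that avoids a generated function by using `ntuple` with a `Val(N)` argument. It is not the author's code. It assumes hypothetical stand-in types (`ToyNormal`, `ToyBernoulli`, `ToyTupleProduct`) and a hypothetical `toylogpdf` function in place of the Distributions.jl API, so it runs with Base Julia alone.

```julia
# Hypothetical toy components (not Distributions.jl types): one continuous, one discrete.
struct ToyNormal
    μ::Float64
    σ::Float64
end
struct ToyBernoulli
    p::Float64
end

# Normal log-density and Bernoulli log-probability mass.
toylogpdf(d::ToyNormal, x::Real) = -0.5 * ((x - d.μ) / d.σ)^2 - log(d.σ) - 0.5 * log(2π)
toylogpdf(d::ToyBernoulli, x::Real) = x == 1 ? log(d.p) : x == 0 ? log1p(-d.p) : -Inf

# Independent components stored in a tuple; N is the number of components.
struct ToyTupleProduct{N,V<:NTuple{N,Any}}
    v::V
    ToyTupleProduct(v::NTuple{N,Any}) where {N} = new{N,typeof(v)}(v)
end

# Joint log-density of independent components = sum of the marginal log-densities.
# ntuple with Val(N) is unrolled for a fixed N, so no @generated function is needed.
function toylogpdf(d::ToyTupleProduct{N}, x::AbstractVector{<:Real}) where {N}
    length(x) == N || throw(DimensionMismatch("expected length $N, got $(length(x))"))
    return sum(ntuple(i -> toylogpdf(d.v[i], x[i]), Val(N)))
end

# Mixed product: two continuous components and one discrete component.
dt = ToyTupleProduct((ToyNormal(0.0, 2.0), ToyNormal(0.0, 2.0), ToyBernoulli(0.5)))
lp = toylogpdf(dt, [0.5, -1.0, 1.0])
```

The discrete value is stored as `1.0` in the `Vector{Float64}`, and `x == 1` still matches it. Whether this version compiles to code as tight as the `@ncall` expansion is not measured here; it would need its own benchmark.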
