I needed a Product distribution with mixed continuous/discrete support for a project and thus implemented the types
struct Mixed <: ValueSupport
TupleProduct <: MultivariateDistribution
Would you be interested in a PR with (parts of) this implementation, possibly with a different name?
Some code and benchmarks are below.
julia> dt = TupleProduct((Normal(0,2), Normal(0,2), Binomial())) # Mixed value support
TupleProduct{3,LowLevelParticleFilters.Mixed,Tuple{Normal{Float64},Normal{Float64},Binomial{Float64}}}(
v: (Normal{Float64}(μ=0.0, σ=2.0), Normal{Float64}(μ=0.0, σ=2.0), Binomial{Float64}(n=1, p=0.5))
)
A small benchmark
The package where I've implemented this is called LowLevelParticleFilters; it also defines some methods for distributions and static arrays. With this package loaded, we get the following timings:
using BenchmarkTools, Distributions, LowLevelParticleFilters, StaticArrays  # StaticArrays provides @SVector
sv = @SVector randn(2)
d = Product([Normal(0,2), Normal(0,2)])
dt = TupleProduct((Normal(0,2), Normal(0,2)))
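dm = MvNormal(2, 2)  # zero-mean isotropic normal in 2 dimensions, standard deviation 2 (older Distributions.jl constructor)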
@btime logpdf($d,$(Vector(sv))) # 32.449 ns (1 allocation: 32 bytes)
@btime logpdf($dt,$(Vector(sv))) # 21.141 ns (0 allocations: 0 bytes)
@btime logpdf($dm,$(Vector(sv))) # 48.745 ns (1 allocation: 96 bytes)
@btime logpdf($d,$sv) # 22.651 ns (0 allocations: 0 bytes)
@btime logpdf($dt,$sv) # 0.021 ns (0 allocations: 0 bytes)
@btime logpdf($dm,$sv) # 0.021 ns (0 allocations: 0 bytes)
If LowLevelParticleFilters and its special methods for static arrays are not loaded, SVector and Vector give identical timings:
@btime logpdf($d,$sv) # 32.621 ns (1 allocation: 32 bytes)
@btime logpdf($dm,$sv) # 46.415 ns (1 allocation: 96 bytes)
Implementation
The key parts are shown below; the rest of the implementation is in LowLevelParticleFilters.
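# Value support for products that mix continuous and discrete components
struct Mixed <: ValueSupport end

struct TupleProduct{N,S,V<:NTuple{N,UnivariateDistribution}} <: MultivariateDistribution{S}
    v::V  # the N univariate component distributions, stored as a tuple
    function TupleProduct(v::V) where {N,V<:NTuple{N,UnivariateDistribution}}
        # every component discrete: the product is Discrete
        all(Distributions.value_support(typeof(d)) == Discrete for d in v) &&
            return new{N,Discrete,V}(v)
        # every component continuous: the product is Continuous
        all(Distributions.value_support(typeof(d)) == Continuous for d in v) &&
            return new{N,Continuous,V}(v)
        # otherwise the support is mixed
        return new{N,Mixed,V}(v)
    end
end

# Sum of the component log-densities, unrolled at compile time into
# logpdf(d.v[1], x[1]) + ... + logpdf(d.v[N], x[N])
@generated function Distributions._logpdf(d::TupleProduct{N}, x::AbstractVector{<:Real}) where N
    :(Base.Cartesian.@ncall $N Base.:+ i->logpdf(d.v[i], x[i]))
end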
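The remaining methods live in LowLevelParticleFilters and are not reproduced in this issue. The following is a minimal sketch, not the author's code, of what they might look like against the multivariate interface Distributions.jl documents for new distributions (`length`, `_rand!`, `_logpdf`), assuming a recent Distributions.jl. `support_of` is a hypothetical helper that picks the value support by dispatch instead of the runtime `all` checks; storing `Mixed` samples as `Float64` is an assumption; and the `map`-based `_logpdf` is an alternative to the `@generated` version above.

```julia
using Distributions, Random

struct Mixed <: ValueSupport end

# Hypothetical helper (not in the original): choose the value support by dispatch
# on the tuple's element types, so the compiler resolves it from the types alone.
support_of(::Tuple{}) = Discrete  # matches `all` over an empty tuple in the original constructor
support_of(::Tuple{Vararg{DiscreteUnivariateDistribution}}) = Discrete
support_of(::Tuple{Vararg{ContinuousUnivariateDistribution}}) = Continuous
support_of(::Tuple{Vararg{UnivariateDistribution}}) = Mixed

struct TupleProduct{N,S,V<:NTuple{N,UnivariateDistribution}} <: MultivariateDistribution{S}
    v::V
    function TupleProduct(v::V) where {N,V<:NTuple{N,UnivariateDistribution}}
        return new{N,support_of(v),V}(v)
    end
end

# Number of components.
Base.length(::TupleProduct{N}) where {N} = N

# Assumption: mixed samples are stored as Float64 (discrete draws become whole-number floats).
# Continuous and Discrete products keep the eltype Distributions.jl already defines for them.
Base.eltype(::Type{<:TupleProduct{<:Any,Mixed}}) = Float64

# Draw each component independently into the preallocated vector x.
function Distributions._rand!(rng::AbstractRNG, d::TupleProduct, x::AbstractVector{<:Real})
    for (i, di) in enumerate(d.v)
        x[i] = rand(rng, di)
    end
    return x
end

# Alternative to the @generated _logpdf: `map` over two equal-length tuples,
# which Julia unrolls for small N. (N == 0 is not handled: `sum` of an empty tuple errors.)
function Distributions._logpdf(d::TupleProduct{N}, x::AbstractVector{<:Real}) where {N}
    return sum(map(logpdf, d.v, ntuple(i -> x[i], Val(N))))
end

# Usage: two continuous components and one discrete component give Mixed support.
dt = TupleProduct((Normal(0, 2), Normal(0, 2), Binomial()))
lp = logpdf(dt, [0.1, -0.3, 1.0])          # goes through the size check, then _logpdf
y = rand!(MersenneTwister(1), dt, zeros(3)) # fills y via _rand!
```

Dispatching on `Tuple{Vararg{...}}` means the support is decided purely from the component types, so the constructor should be type-stable without relying on the compiler folding the `all` checks.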