Description
Drawing a variable from a Dirichlet distribution introduces a type instability that slows down sampling dramatically (roughly 25x on my laptop):
using Turing
using InteractiveUtils  # provides @code_warntype outside the REPL

@model MarginalizedGMM(x, K, ::Type{T}=Vector{Float64}) where {T} = begin
    N = length(x)
    μ = T(undef, K)
    σ = T(undef, K)
    for i in 1:K
        μ[i] ~ Normal(0, 5)
        σ[i] ~ Gamma()
    end
    w ~ Dirichlet(K, 1.0)
    # w = T([0.75, 0.25])  # much faster with this line in place of the one above
    for i in 1:N
        x[i] ~ Distributions.UnivariateGMM(μ, σ, Categorical(w))
    end
    return (μ::T, σ::T, w::T)
end

x = [randn(150) .- 2; randn(50) .+ 2]
gmm = MarginalizedGMM(x, 2)
varinfo = Turing.VarInfo(gmm)
spl = Turing.SampleFromPrior()
@code_warntype gmm.f(varinfo, spl, Turing.DefaultContext(), gmm)