diff --git a/docs/Project.toml b/docs/Project.toml
index 3d6a9ee4e..fbb4973b6 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
diff --git a/docs/src/fit.md b/docs/src/fit.md
index 24ddbad5c..b04341667 100644
--- a/docs/src/fit.md
+++ b/docs/src/fit.md
@@ -74,6 +74,23 @@ fit_mle(Categorical, x) # equivalent to fit_mle(Categorical, max(x), x)
 
 fit_mle(Categorical, x, w)
 ```
+It is also possible to pass a distribution instance directly, as `fit_mle(d::Distribution, x[, w])`. This form avoids the extra arguments:
+
+```julia
+fit_mle(Binomial(n, 0.1), x)
+# equivalent to fit_mle(Binomial, ntrials(Binomial(n, 0.1)), x); here the parameter 0.1 is not used
+
+fit_mle(Categorical(p), x)
+# equivalent to fit_mle(Categorical, ncategories(Categorical(p)), x); here only the length of p is used, not its values
+
+d = product_distribution([Exponential(0.5), Normal(11.3, 3.2)])
+fit_mle(d, x)
+# equivalent to product_distribution([fit_mle(Exponential, x[1, :]), fit_mle(Normal, x[2, :])]); again, the parameters of d are not used
+```
+
+Note that for standard distributions, the parameter values of `d` are not used by `fit_mle`; only the “structure” of `d` is passed on.
+However, for more complex maximum likelihood estimation requiring optimization, e.g., the EM algorithm, one could use `d` as an initial guess.
+
 ## Sufficient Statistics
 
 For many distributions, the estimation can be based on (sum of) sufficient statistics computed from a dataset. To simplify implementation, for such distributions, we implement `suffstats` method instead of `fit_mle` directly:
diff --git a/docs/src/mixture.md b/docs/src/mixture.md
index e6b24b103..2ae7ab904 100644
--- a/docs/src/mixture.md
+++ b/docs/src/mixture.md
@@ -106,4 +106,4 @@ rand!(::AbstractMixtureModel, ::AbstractArray)
 ## Estimation
 
 There are several methods for the estimation of mixture models from data, and this problem remains an open research topic.
-This package does not provide facilities for estimating mixture models. One can resort to other packages, *e.g.* [*GaussianMixtures.jl*](https://github.com/davidavdav/GaussianMixtures.jl), for this purpose.
+This package does not provide facilities for estimating mixture models. One can resort to other packages, *e.g.* [*GaussianMixtures.jl*](https://github.com/davidavdav/GaussianMixtures.jl) or [*ExpectationMaximization.jl*](https://github.com/dmetivie/ExpectationMaximization.jl), for this purpose.
diff --git a/src/genericfit.jl b/src/genericfit.jl
index 2b4b5e67e..f5cb7c72b 100644
--- a/src/genericfit.jl
+++ b/src/genericfit.jl
@@ -30,5 +30,7 @@ fit_mle(dt::Type{D}, x::AbstractArray, w::AbstractArray) where {D<:UnivariateDis
 fit_mle(dt::Type{D}, x::AbstractMatrix) where {D<:MultivariateDistribution} = fit_mle(D, suffstats(D, x))
 fit_mle(dt::Type{D}, x::AbstractMatrix, w::AbstractArray) where {D<:MultivariateDistribution} = fit_mle(D, suffstats(D, x, w))
 
+fit_mle(dist::Distribution, args...) = fit_mle(typeof(dist), args...)
+
 fit(dt::Type{D}, x) where {D<:Distribution} = fit_mle(D, x)
 fit(dt::Type{D}, args...) where {D<:Distribution} = fit_mle(D, args...)
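As a quick, illustrative sketch of the generic instance-based entry point added in `src/genericfit.jl` above (not part of the patch; the data and sample size are made up):

```julia
using Distributions

x = rand(Exponential(0.5), 10_000)   # synthetic data

# Passing an instance forwards to the existing type-based method via typeof(dist):
d1 = fit_mle(Exponential, x)         # type-based API
d2 = fit_mle(Exponential(0.5), x)    # instance-based API; 0.5 only fixes the type, its value is ignored
params(d1) == params(d2)             # true
```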
diff --git a/src/multivariate/product.jl b/src/multivariate/product.jl
index 3bb0f0d9e..32366fed3 100644
--- a/src/multivariate/product.jl
+++ b/src/multivariate/product.jl
@@ -36,6 +36,8 @@ function Base.eltype(::Type{<:Product{S,T}}) where {S<:ValueSupport,
     return eltype(T)
 end
 
+params(g::Product) = params.(g.v)
+
 _rand!(rng::AbstractRNG, d::Product, x::AbstractVector{<:Real}) = map!(Base.Fix1(rand, rng), x, d.v)
 
 function _logpdf(d::Product, x::AbstractVector{<:Real})
@@ -60,3 +62,17 @@ maximum(d::Product) = map(maximum, d.v)
 function product_distribution(dists::V) where {S<:ValueSupport,T<:UnivariateDistribution{S},V<:AbstractVector{T}}
     return Product{S,T,V}(dists)
 end
+
+#### Fitting
+
+"""
+    fit_mle(g::Product, x::AbstractMatrix)
+    fit_mle(g::Product, x::AbstractMatrix, γ::AbstractVector)
+
+The `fit_mle` of a multivariate `Product` distribution `g` is the `product_distribution` of the `fit_mle` of each component of `g`.
+"""
+function fit_mle(g::Product, x::AbstractMatrix, args...)
+    d = size(x, 1)
+    length(g) == d || throw(DimensionMismatch("The dimensions of g and x are inconsistent."))
+    return product_distribution([fit_mle(g.v[s], y, args...) for (s, y) in enumerate(eachrow(x))])
+end
diff --git a/src/product.jl b/src/product.jl
index 7a4904ae7..e862abc32 100644
--- a/src/product.jl
+++ b/src/product.jl
@@ -74,6 +74,8 @@ end
 
 size(d::ProductDistribution) = d.size
 
+params(d::ArrayOfUnivariateDistribution) = params.(d.dists)
+
 mean(d::ProductDistribution) = reshape(mapreduce(vec ∘ mean, vcat, d.dists), size(d))
 var(d::ProductDistribution) = reshape(mapreduce(vec ∘ var, vcat, d.dists), size(d))
 cov(d::ProductDistribution) = Diagonal(vec(var(d)))
@@ -244,3 +246,22 @@ function product_distribution(dists::AbstractVector{<:Normal})
     σ2 = map(var, dists)
     return MvNormal(µ, Diagonal(σ2))
 end
+
+#### Fitting
+promote_sample(::Type{dT}, x::AbstractArray{T}) where {T<:Real, dT<:Real} = T <: dT ? x : convert.(dT, x)
+
+"""
+    fit_mle(dists::ArrayOfUnivariateDistribution, x::AbstractArray)
+    fit_mle(dists::ArrayOfUnivariateDistribution, x::AbstractArray, γ::AbstractVector)
+
+The `fit_mle` of an `ArrayOfUnivariateDistribution` `dists` is the `product_distribution` of the `fit_mle` of each component of `dists`.
+"""
+function fit_mle(dists::VectorOfUnivariateDistribution, x::AbstractMatrix{<:Real}, args...)
+    length(dists) == size(x, 1) || throw(DimensionMismatch("The dimensions of dists and x are inconsistent."))
+    return product_distribution([fit_mle(d, promote_sample(eltype(d), x[s, :]), args...) for (s, d) in enumerate(dists.dists)])
+end
+
+function fit_mle(dists::ArrayOfUnivariateDistribution, x::AbstractArray, args...)
+    size(dists) == size(first(x)) || throw(DimensionMismatch("The dimensions of dists and x are inconsistent."))
+    return product_distribution([fit_mle(d, promote_sample(eltype(d), [x[i][s] for i in eachindex(x)]), args...) for (s, d) in enumerate(dists.dists)])
+end
\ No newline at end of file
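In the same illustrative spirit, a minimal sketch of the component-wise product fit added in `src/product.jl` above, mirroring the docs example (sample size is arbitrary):

```julia
using Distributions

d = product_distribution([Exponential(0.5), Normal(11.3, 3.2)])
x = rand(d, 10_000)      # 2 × 10_000 matrix, one column per sample

d_fit = fit_mle(d, x)    # fits each component on its own row of x
params(d_fit)            # ≈ [(0.5,), (11.3, 3.2)] up to sampling error
```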
diff --git a/src/univariate/discrete/binomial.jl b/src/univariate/discrete/binomial.jl
index f4102cbb8..3aa24545c 100644
--- a/src/univariate/discrete/binomial.jl
+++ b/src/univariate/discrete/binomial.jl
@@ -196,6 +196,7 @@ suffstats(::Type{T}, data::BinomData, w::AbstractArray{<:Real}) where {T<:Binomi
 
 fit_mle(::Type{T}, ss::BinomialStats) where {T<:Binomial} = T(ss.n, ss.ns / (ss.ne * ss.n))
 
+fit_mle(d::T, x::AbstractArray{<:Integer}) where {T<:Binomial} = fit_mle(T, suffstats(T, ntrials(d), x))
 fit_mle(::Type{T}, n::Integer, x::AbstractArray{<:Integer}) where {T<:Binomial}= fit_mle(T, suffstats(T, n, x))
 fit_mle(::Type{T}, n::Integer, x::AbstractArray{<:Integer}, w::AbstractArray{<:Real}) where {T<:Binomial} = fit_mle(T, suffstats(T, n, x, w))
 fit_mle(::Type{T}, data::BinomData) where {T<:Binomial} = fit_mle(T, suffstats(T, data))
diff --git a/src/univariate/discrete/categorical.jl b/src/univariate/discrete/categorical.jl
index 1dc08960c..a49810d45 100644
--- a/src/univariate/discrete/categorical.jl
+++ b/src/univariate/discrete/categorical.jl
@@ -156,6 +156,7 @@ function fit_mle(::Type{<:Categorical}, ss::CategoricalStats)
     Categorical(normalize!(ss.h, 1))
 end
 
+fit_mle(d::T, x::AbstractArray{<:Integer}) where {T<:Categorical} = fit_mle(T, ncategories(d), x)
 function fit_mle(::Type{<:Categorical}, k::Integer, x::AbstractArray{T}) where T<:Integer
     Categorical(normalize!(add_categorical_counts!(zeros(k), x), 1), check_args=false)
 end
diff --git a/test/fit.jl b/test/fit.jl
index 4483dd1ec..d587d0d87 100644
--- a/test/fit.jl
+++ b/test/fit.jl
@@ -115,6 +115,11 @@ end
            @test d isa D
            @test ntrials(d) == 100
            @test succprob(d) ≈ 0.3 atol=0.01
+
+           d2 = @inferred fit_mle(D(100, 0.5), x)
+           @test d2 isa D
+           @test ntrials(d2) == 100
+           @test succprob(d2) ≈ 0.3 atol = 0.01
        end
    end
 end
@@ -141,6 +146,10 @@
    @test isa(d2, Categorical)
    @test probs(d2) == probs(d)
 
+   d3 = fit_mle(Categorical(p), x)
+   @test isa(d3, Categorical)
+   @test probs(d3) == probs(d)
+
    ss = suffstats(Categorical, (3, x), w)
    h = Float64[sum(w[x .== i]) for i = 1 : 3]
    @test isa(ss, Distributions.CategoricalStats)
@@ -414,6 +423,23 @@ end
 end
 
+@testset "Testing fit_mle for ProductDistribution" begin
+    dists = [product_distribution([Exponential(0.5), Normal(11.3, 3.2)]), product_distribution([Exponential(0.5) Normal(11.3, 3.2)
+                                                                                                 Bernoulli(0.2) Normal{Float32}(0f0,0f0)])]
+    for func in funcs, dist in dists
+        x = rand(dist, N)
+        w = func[1](N)
+
+        d = fit_mle(dist, x)
+        @test isa(d, typeof(dist))
+        @test isapprox(collect.(params(d)), collect.(params(dist)), atol=0.1)
+
+        d = fit_mle(dist, x, w)
+        @test isa(d, typeof(dist))
+        @test isapprox(collect.(params(d)), collect.(params(dist)), atol=0.1)
+    end
+end
+
 @testset "Testing fit for InverseGaussian" begin
     for func in funcs, dist in (InverseGaussian, InverseGaussian{Float64})
         x = rand(dist(3.9, 2.1), n0)
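Finally, an illustrative sketch of the instance-based `Binomial` and `Categorical` methods exercised by the tests above (synthetic data, arbitrary sample sizes):

```julia
using Distributions

x = rand(Binomial(100, 0.3), 10_000)
b = fit_mle(Binomial(100, 0.5), x)    # ntrials is read from the instance; 0.5 is ignored
ntrials(b), succprob(b)               # (100, ≈ 0.3)

y = rand(Categorical([0.2, 0.3, 0.5]), 10_000)
c = fit_mle(Categorical([1/3, 1/3, 1/3]), y)   # only ncategories(...) = 3 is used
probs(c)                              # ≈ [0.2, 0.3, 0.5]
```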