-
Notifications
You must be signed in to change notification settings - Fork 416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Weighted MLE fit for Laplace distribution with unit test #1310
Conversation
function fit_mle(::Type{<:Laplace}, x::AbstractArray{T}, w::AbstractArray{T}) where {T <: Real} | ||
sp = sortperm(x) | ||
n = length(x) | ||
sw = sum(w) | ||
highsum = sw | ||
lowsum = zero(T) | ||
idx = 0 | ||
for i = 1:n | ||
lowsum += w[sp[i]] | ||
highsum -= w[sp[i]] | ||
if lowsum >= highsum | ||
idx = sp[i] | ||
break | ||
end | ||
end | ||
μ = x[idx] | ||
θ = zero(T) | ||
for i = 1:length(x) | ||
θ += w[i] * abs(x[i] - μ) | ||
end | ||
θ /= sw | ||
return Laplace(μ, θ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would suggest to use AbstractWeights
(from StatsBase) since this allows to use different types of weights and existing optimized algorithms. E.g., one should then probably just use median(x, weights)
defined in StatsBase and mean(abs.(x .- m), weights)
(also defined in StatsBase). This creates more memory allocations than in the unweighted case but I think it is fine since the code complexity (and hence the risk of introducing bugs) is much lower.
Thanks for the feedback, David. I tested your suggested implementation vs mine in the attached code. My understanding is the MLE estimate for the location parameter should minimize the weighted absolute deviation. Running on my machine, I got
d_fit = Laplace{Float64}(μ=5.409932332163653, θ=3.0460352087395486)
sb_wmed = 5.889103597724628 <-- This is the StatsBase weighted median
Looking at a plot of the weigted absolute deviation, I think my estimate is more reasonable.
***@***.***D7374A.B3A79190]
Regards,
Sam
Sent from Mail<https://go.microsoft.com/fwlink/?LinkId=550986> for Windows 10
From: David ***@***.***>
Sent: Thursday, April 22, 2021 1:52 AM
To: ***@***.***>
Cc: ***@***.***>; ***@***.***>
Subject: Re: [JuliaStats/Distributions.jl] Weighted MLE fit for Laplace distribution with unit test (#1310)
@devmotion commented on this pull request.
In src/univariate/continuous/laplace.jl<https://na01.safelinks.protection.outlook.com/?url=https%3A%2F%2Fgithub.com%2FJuliaStats%2FDistributions.jl%2Fpull%2F1310%23discussion_r618098662&data=04%7C01%7C%7Ce836a144cf11455714da08d90552d6e1%7C84df9e7fe9f640afb435aaaaaaaaaaaa%7C1%7C0%7C637546675611647390%7CUnknown%7CTWFpbGZsb3d8eyJWIjoiMC4wLjAwMDAiLCJQIjoiV2luMzIiLCJBTiI6Ik1haWwiLCJXVCI6Mn0%3D%7C1000&sdata=Zd12u7OYVhuMigjh3zq7DxlEyTyEE6cJIGmwONa4X2Y%3D&reserved=0>:
+function fit_mle(::Type{<:Laplace}, x::AbstractArray{T}, w::AbstractArray{T}) where {T <: Real}
+ sp = sortperm(x)
+ n = length(x)
+ sw = sum(w)
+ highsum = sw
+ lowsum = zero(T)
+ idx = 0
+ for i = 1:n
+ lowsum += w[sp[i]]
+ highsum -= w[sp[i]]
+ if lowsum >= highsum
+ idx = sp[i]
+ break
+ end
+ end
+ μ = x[idx]
+ θ = zero(T)
+ for i = 1:length(x)
+ θ += w[i] * abs(x[i] - μ)
+ end
+ θ /= sw
+ return Laplace(μ, θ)
I would suggest to use AbstractWeights (from StatsBase) since this allows to use different types of weights and existing optimized algorithms. E.g., one should then probably just use median(x, weights) defined in StatsBase and mean(abs.(x .- m), weights) (also defined in StatsBase). This creates more memory allocations than in the unweighted case but I think it is fine since the code complexity (and hence the risk of introducing bugs) is much lower.
—
You are receiving this because you authored the thread.
Reply to this email directly, view it on GitHub<https://na01.safelinks.protection.outlook.com/?url=https%3A%2F%2Fgithub.com%2FJuliaStats%2FDistributions.jl%2Fpull%2F1310%23pullrequestreview-641798960&data=04%7C01%7C%7Ce836a144cf11455714da08d90552d6e1%7C84df9e7fe9f640afb435aaaaaaaaaaaa%7C1%7C0%7C637546675611647390%7CUnknown%7CTWFpbGZsb3d8eyJWIjoiMC4wLjAwMDAiLCJQIjoiV2luMzIiLCJBTiI6Ik1haWwiLCJXVCI6Mn0%3D%7C1000&sdata=tlsaIudKfcLjngO4ZVDNIzwaeL4CWtOard1mkXbyPYY%3D&reserved=0>, or unsubscribe<https://na01.safelinks.protection.outlook.com/?url=https%3A%2F%2Fgithub.com%2Fnotifications%2Funsubscribe-auth%2FAFPIHY7TSFCZEWPORM6CPV3TJ62SPANCNFSM43LQT6KQ&data=04%7C01%7C%7Ce836a144cf11455714da08d90552d6e1%7C84df9e7fe9f640afb435aaaaaaaaaaaa%7C1%7C0%7C637546675611657383%7CUnknown%7CTWFpbGZsb3d8eyJWIjoiMC4wLjAwMDAiLCJQIjoiV2luMzIiLCJBTiI6Ik1haWwiLCJXVCI6Mn0%3D%7C1000&sdata=yqvKfyMzAt5Se1DZin7aZU7I3wFUJn8vlySZ65uK%2BbU%3D&reserved=0>.
|
Unfortunately, I can only see parts of your message, so it is difficult to reply to it. |
I ran the code in the attached file. It uses the weighted fit_mle from my pull request and the weighted median calculation from StatsBase. Let me know if you can run it. Thanks. |
I don't think there's a problem with the StatsBase version. The only difference is that the implementation of the weighted median in StatsBase interpolates whereas you don't interpolate (see https://juliastats.org/StatsBase.jl/latest/scalarstats/#Statistics.quantile). So for small sample sizes the estimate of the location can be slightly larger but there's nothing fundamentally wrong with it (and also the scale, i.e., the average weighted absolute deviation, is usually very similar). You can also see that for larger sample sizes the estimates converge to the same value (and, of course, also fluctuate a lot less than for only 11 samples): julia> function fit_mle2(x::AbstractArray{<:Real}, w::AbstractWeights{<:Real})
m = median(x, w)
scale = mean(abs.(x .- m), w)
return Laplace(m, scale)
end
fit_mle2 (generic function with 1 method)
julia> function compare_mle(x, w)
pr = fit_mle(Laplace, x, w)
alt = fit_mle2(x, weights(w))
return (; pr, alt)
end
compare_mle (generic function with 1 method)
julia> compare_mle(rand(Laplace(5, 3), 11), rand(11))
(pr = Laplace{Float64}(μ=6.211073416980449, θ=3.9581446545423478), alt = Laplace{Float64}(μ=6.767319600558441, θ=3.9712873816902974))
julia> compare_mle(rand(Laplace(5, 3), 11), rand(11))
(pr = Laplace{Float64}(μ=5.253611359044937, θ=2.005751951517064), alt = Laplace{Float64}(μ=4.713454315205316, θ=2.0800523494694185))
julia> compare_mle(rand(Laplace(5, 3), 100_000), rand(100_000))
(pr = Laplace{Float64}(μ=4.97717614535248, θ=2.9890249614443305), alt = Laplace{Float64}(μ=4.9768921974534, θ=2.989024961660979))
julia> compare_mle(rand(Laplace(5, 3), 100_000), rand(100_000))
(pr = Laplace{Float64}(μ=5.00979868748012, θ=2.9953653545191705), alt = Laplace{Float64}(μ=5.009795936865125, θ=2.9953653545471393)) An additional argument for using StatsBase here is that also for the unweighted estimation interpolation is performed if there is an even number of samples. |
Thanks for the careful review. I have more questions that are better for a separate PR. |
I implemented a weighted MLE for the Laplace distribution.