-
Notifications
You must be signed in to change notification settings - Fork 410
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
314 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Uniform distribution on the n-ball in R^n | ||
# | ||
# The implementation here follows: | ||
# | ||
# - Wikipedia: | ||
# https://en.wikipedia.org/wiki/N-sphere | ||
|
||
struct UniformBall{T<:Real} <: ContinuousMultivariateDistribution | ||
n::Int | ||
logV::T # log normalization constant | ||
|
||
function UniformBall{T}(n::Int) where {T<:Real} | ||
n >= 0 || error("n must be non-negative") | ||
|
||
if n > 0 | ||
logV = convert(T, (n/2)*log(π) - loggamma((n/2) + 1)) | ||
else | ||
logV = zero(T) | ||
end | ||
new{T}(n, logV) | ||
end | ||
end | ||
UniformBall(n::Int) = UniformBall{Float64}(n) | ||
|
||
show(io::IO, d::UniformBall) = show(io, d, (:n,)) | ||
|
||
### Conversions | ||
convert(::Type{UniformBall{T}}, d::UniformBall) where {T<:Real} = UniformBall{T}(d.n) | ||
|
||
|
||
### Basic properties | ||
|
||
length(d::UniformBall) = d.n | ||
|
||
mean(d::UniformBall{T}) where T = zeros(T, length(d)) | ||
meandir(d::UniformBall) = mean(d) | ||
concentration(d::UniformBall{T}) where T = zero(T) | ||
|
||
cov(d::UniformBall{T}) where T = Diagonal{T}(var(d)) | ||
var(d::UniformBall{T}) where T = ones(T, length(d)) / (2 + length(d)) | ||
|
||
insupport(d::UniformBall, x::AbstractVector{T}) where {T<:Real} = (norm(x) <= one(T)) || isunitvec(x) | ||
params(d::UniformBall) = (d.n,) | ||
# @inline partype(d::UniformBall{T}) where {T<:Real} = T | ||
|
||
### Evaluation | ||
|
||
_logpdf(d::UniformBall, x::AbstractVector{T}) where {T<:Real} = insupport(d, x) ? -d.logV : -T(Inf) | ||
|
||
entropy(d::UniformBall) = d.logV | ||
|
||
### Sampling | ||
|
||
sampler(d::UniformBall{T}) where T = UniformBallSampler(d.n) | ||
|
||
for A in [:AbstractVector, :AbstractMatrix] | ||
@eval function _rand!(rng::AbstractRNG, d::UniformBall, x::$A) | ||
if length(d) == 1 | ||
# in 1D, reduces to U[-1, 1] | ||
for i in eachindex(x) | ||
@inbounds x[i] = 2*rand(rng) - 1 | ||
end | ||
else | ||
_rand!(rng, sampler(d), x) | ||
end | ||
return x | ||
end | ||
end | ||
|
||
|
||
### Estimation | ||
fit_mle(::Type{<:UniformBall}, X::Matrix{T}) where {T <: Real} = UniformBall{T}(size(X, 1)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# Uniform distribution on the n-sphere in R^{n+1} | ||
# | ||
# The implementation here follows: | ||
# | ||
# - Wikipedia: | ||
# https://en.wikipedia.org/wiki/N-sphere | ||
|
||
struct UniformSpherical{T<:Real} <: ContinuousMultivariateDistribution | ||
n::Int | ||
logS::T # log normalization constant | ||
|
||
function UniformSpherical{T}(n::Int) where {T<:Real} | ||
n >= 0 || error("n must be non-negative") | ||
|
||
ln2 = log(2) | ||
if n > 0 | ||
p = (n+1) / 2.0 | ||
logS = ln2 + p*(log2π - ln2) - loggamma(p) | ||
else | ||
logS = ln2 | ||
end | ||
new{T}(n, convert(T, logS)) | ||
end | ||
end | ||
UniformSpherical(n::Int) = UniformSpherical{Float64}(n) | ||
|
||
show(io::IO, d::UniformSpherical) = show(io, d, (:n,)) | ||
|
||
### Conversions | ||
convert(::Type{UniformSpherical{T}}, d::UniformSpherical) where {T<:Real} = UniformSpherical{T}(d.n) | ||
|
||
|
||
### Basic properties | ||
|
||
length(d::UniformSpherical) = d.n + 1 | ||
|
||
mean(d::UniformSpherical{T}) where T = zeros(T, length(d)) | ||
meandir(d::UniformSpherical) = mean(d) # should this error out instead? | ||
|
||
cov(d::UniformSpherical{T}) where T = Diagonal{T}(var(d)) | ||
var(d::UniformSpherical{T}) where T = ones(T, length(d)) / length(d) | ||
|
||
concentration(d::UniformSpherical{T}) where T = zero(T) | ||
|
||
insupport(d::UniformSpherical, x::AbstractVector{T}) where {T<:Real} = isunitvec(x) | ||
params(d::UniformSpherical) = (d.n,) | ||
# @inline partype(d::UniformSpherical{T}) where {T<:Real} = T | ||
|
||
### Evaluation | ||
|
||
_logpdf(d::UniformSpherical, x::AbstractVector{T}) where {T<:Real} = insupport(d, x) ? -d.logS : -T(Inf) | ||
|
||
entropy(d::UniformSpherical) = d.logS | ||
|
||
|
||
### Sampling | ||
|
||
sampler(d::UniformSpherical{T}) where T = UniformSphericalSampler(d.n) | ||
|
||
|
||
for A in [:AbstractVector, :AbstractMatrix] | ||
@eval function _rand!(rng::AbstractRNG, d::UniformSpherical, x::$A) | ||
if length(d) == 1 | ||
# in 1D, reduces to a U{-1, 1} | ||
for i in eachindex(x) | ||
@inbounds x[i] = rand(rng, (-1, 1)) | ||
end | ||
else | ||
_rand!(rng, sampler(d), x) | ||
end | ||
return x | ||
end | ||
end | ||
|
||
|
||
### Estimation | ||
fit_mle(::Type{<:UniformSpherical}, X::Matrix{T}) where {T <: Real} = UniformSpherical{T}(size(X, 1) - 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Sampler for Uniform Ball | ||
|
||
struct UniformBallSampler | ||
n::Int | ||
end | ||
|
||
|
||
function _rand!(rng::AbstractRNG, spl::UniformBallSampler, x::AbstractVector) | ||
n = spl.n | ||
# defer to UniformSphericalSampler for calculation of unit-vector | ||
_rand!(rng, UniformSphericalSampler(n-1), x) | ||
|
||
# re-scale x | ||
u = rand(rng) | ||
r = (u^inv(n)) | ||
x .*= r | ||
return x | ||
end | ||
|
||
|
||
function _rand!(rng::AbstractRNG, spl::UniformBallSampler, x::AbstractMatrix) | ||
for j in axes(x, 2) | ||
_rand!(rng, spl, view(x,:,j)) | ||
end | ||
return x | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Sampler for Uniform Spherical | ||
|
||
struct UniformSphericalSampler | ||
n::Int | ||
end | ||
|
||
|
||
function _rand!(rng::AbstractRNG, spl::UniformSphericalSampler, x::AbstractVector) | ||
n = spl.n | ||
s = 0.0 | ||
@inbounds for i = 1:(n+1) | ||
x[i] = xi = randn(rng) | ||
s += abs2(xi) | ||
end | ||
|
||
# normalize x | ||
r = inv(sqrt(s)) | ||
x .*= r | ||
return x | ||
end | ||
|
||
|
||
function _rand!(rng::AbstractRNG, spl::UniformSphericalSampler, x::AbstractMatrix) | ||
for j in axes(x, 2) | ||
_rand!(rng, spl, view(x,:,j)) | ||
end | ||
return x | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Tests for Uniform ball distribution | ||
|
||
using Distributions | ||
using Test | ||
|
||
function test_uniformball(n::Int) | ||
d = UniformBall(n) | ||
@test length(d) == n | ||
@test mean(d) == zeros(length(d)) | ||
@test diag(cov(d)) == var(d) | ||
@test d == typeof(d)(params(d)...) | ||
@test d == deepcopy(d) | ||
@test partype(d) == Float64 | ||
|
||
# conversions | ||
@test typeof(convert(UniformBall{Float32}, d)) == UniformBall{Float32} | ||
|
||
# Support | ||
x = normalize(rand(length(d))) | ||
|
||
if length(d) > 0 | ||
@test !insupport(d, 100*x) | ||
@test pdf(d, 100*x) == 0.0 | ||
else | ||
@test insupport(d, 100*x) | ||
@test pdf(d, 100*x) == 1.0 | ||
end | ||
|
||
@test insupport(d, x) | ||
@test pdf(d, x) != 0.0 | ||
|
||
@test insupport(d, 0.1*x) | ||
@test pdf(d, 0.1*x) != 0.0 | ||
|
||
# Sampling | ||
X = [rand(d) for _ in 1:100_000] | ||
@test isapprox(mean(d), mean(X), atol=0.01) | ||
@test isapprox(var(d), var(X), atol=0.01) | ||
@test isapprox(cov(d), cov(X), atol=0.01) | ||
end | ||
|
||
|
||
## General testing | ||
|
||
@testset "Testing UniformBall at $n" for n in 0:10 | ||
test_uniformball(n) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Tests for Uniform Spherical distribution | ||
|
||
using Distributions | ||
using Test | ||
|
||
function test_uniformspherical(n::Int) | ||
d = UniformSpherical(n) | ||
@test length(d) == n+1 | ||
@test mean(d) == zeros(length(d)) | ||
@test diag(cov(d)) == var(d) | ||
@test d == typeof(d)(params(d)...) | ||
@test d == deepcopy(d) | ||
@test partype(d) == Float64 | ||
|
||
# conversions | ||
@test typeof(convert(UniformSpherical{Float32}, d)) == UniformSpherical{Float32} | ||
|
||
# Support | ||
x = normalize(rand(length(d))) | ||
|
||
@test !insupport(d, 100*x) | ||
@test pdf(d, 100*x) == 0.0 | ||
|
||
@test insupport(d, x) | ||
@test pdf(d, x) != 0.0 | ||
|
||
@test !insupport(d, 0.1*x) | ||
@test pdf(d, 0.1*x) == 0.0 | ||
|
||
# Sampling | ||
X = [rand(d) for _ in 1:100_000] | ||
@test isapprox(mean(d), mean(X), atol=0.01) | ||
@test isapprox(var(d), var(X), atol=0.01) | ||
@test isapprox(cov(d), cov(X), atol=0.01) | ||
end | ||
|
||
|
||
## General testing | ||
|
||
@testset "Testing UniformSpherical at $n" for n in 0:10 | ||
test_uniformspherical(n) | ||
end |