Skip to content

Commit

Permalink
add rand for ProbabilitySimplex (#604)
Browse files Browse the repository at this point in the history
* squash

* type params simpler

* Fix change_metric, change_representer and project for ProbabilitySimplex (#627)

* Add Exports, Improve GHA Caches. (#626)

* Add a few exports, mainly reexports from ManifoldsBase.

* Improve caching.

i.e. doing it correctly.

* 📖🩹

at least a bit,

* 📖🩹 2

* export FisherRaoMetric.

* new change metric

* add exp and test

* Update test/manifolds/probability_simplex.jl

* Update src/manifolds/ProbabilitySimplexEuclideanMetric.jl

* Update src/manifolds/ProbabilitySimplex.jl

Co-authored-by: Seth Axen <seth@sethaxen.com>

---------

Co-authored-by: Mateusz Baran <mateuszbaran89@gmail.com>
Co-authored-by: Ronny Bergmann <git@ronnybergmann.net>
Co-authored-by: Seth Axen <seth@sethaxen.com>
  • Loading branch information
4 people committed Jun 24, 2023
1 parent dae9e52 commit d24f6d2
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/Manifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ include("manifolds/SymplecticStiefel.jl")
include("manifolds/Tucker.jl")
#
include("manifolds/ProbabilitySimplex.jl")
include("manifolds/ProbabilitySimplexEuclideanMetric.jl")
include("manifolds/GeneralUnitaryMatrices.jl")
include("manifolds/Unitary.jl")
include("manifolds/Rotations.jl")
Expand Down
41 changes: 40 additions & 1 deletion src/manifolds/ProbabilitySimplex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ function check_vector(M::ProbabilitySimplex, p, X; kwargs...)
if !isapprox(sum(X), 0.0; kwargs...)
return DomainError(
sum(X),
"The vector $(X) is not a tangent vector to $(p) on $(M), since its elements to not sum up to 0.",
"The vector $(X) is not a tangent vector to $(p) on $(M), since its elements do not sum up to 0.",
)
end
return nothing
Expand Down Expand Up @@ -354,6 +354,45 @@ function parallel_transport_to!(M::ProbabilitySimplex{N}, Y, p, X, q) where {N}
return amplitude_to_simplex_diff!(M, Y, q_s, Ys)
end

@doc raw"""
rand(::ProbabilitySimplex; vector_at=nothing, σ::Real=1.0)
When `vector_at` is `nothing`, return a random (uniform over the Fisher-Rao metric; that is, uniform with respect to the `n`-sphere whose positive orthant is mapped to the simplex).
point `x` on the [`ProbabilitySimplex`](@ref) manifold `M` according to the isometric embedding into
the `n`-sphere by normalizing the vector length of a sample from a multivariate Gaussian. See [^Marsaglia1972].
When `vector_at` is not `nothing`, return a (Gaussian) random vector from the tangent space
``T_{p}\mathrm{\Delta}^n``by shifting a multivariate Gaussian with standard deviation `σ`
to have a zero component sum.
[^Marsaglia1972]:
> Marsaglia, G.:
> _Choosing a Point from the Surface of a Sphere_.
> Annals of Mathematical Statistics, 43 (2): 645–646, 1972.
> doi: [10.1214/aoms/1177692644](https://doi.org/10.1214/aoms/1177692644)
"""
rand(::ProbabilitySimplex; σ::Real=1.0)

function Random.rand!(
rng::AbstractRNG,
M::ProbabilitySimplex,
pX;
vector_at=nothing,
σ=one(eltype(pX)),
)
if isnothing(vector_at)
Random.randn!(rng, pX)
LinearAlgebra.normalize!(pX, 2)
pX .= abs2.(pX)
else
Random.randn!(rng, pX)
pX .= (pX .- mean(pX)) .* σ
change_metric!(M, pX, EuclideanMetric(), vector_at, pX)
end
return pX
end

@doc raw"""
project(M::ProbabilitySimplex, p, Y)
Expand Down
46 changes: 46 additions & 0 deletions src/manifolds/ProbabilitySimplexEuclideanMetric.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
exp!(::MetricManifold{ℝ,<:ProbabilitySimplex,<:EuclideanMetric}, q, p, X) = (q .= p .+ X)
function exp!(
::MetricManifold{ℝ,<:ProbabilitySimplex,<:EuclideanMetric},
q,
p,
X,
t::Number,
)
return (q .= p .+ t .* X)
end

@doc raw"""
rand(::MetricManifold{ℝ,<:ProbabilitySimplex,<:EuclideanMetric}; vector_at=nothing, σ::Real=1.0)
When `vector_at` is `nothing`, return a random (uniform) point `x` on the [`ProbabilitySimplex`](@ref) with the Euclidean metric
manifold `M` by normalizing independent exponential draws to unit sum, see [^Devroye1986], Theorems 2.1 and 2.2 on p. 207 and 208, respectively.
When `vector_at` is not `nothing`, return a (Gaussian) random vector from the tangent space
``T_{p}\mathrm{\Delta}^n``by shifting a multivariate Gaussian with standard deviation `σ`
to have a zero component sum.
[^Devroye1986]:
> Devroye, L.:
> _Non-Uniform Random Variate Generation_.
> Springer New York, NY, 1986.
> doi: [10.1007/978-1-4613-8643-8](https://doi.org/10.1007/978-1-4613-8643-8)
"""
rand(::MetricManifold{ℝ,<:ProbabilitySimplex,<:EuclideanMetric}; σ::Real=1.0)

function Random.rand!(
rng::AbstractRNG,
M::MetricManifold{ℝ,<:ProbabilitySimplex,<:EuclideanMetric},
pX;
vector_at=nothing,
σ=one(eltype(pX)),
)
if isnothing(vector_at)
Random.randexp!(rng, pX)
LinearAlgebra.normalize!(pX, 1)
else
Random.randn!(rng, pX)
pX .= (pX .- mean(pX)) .* σ
end
return pX
end
20 changes: 20 additions & 0 deletions test/manifolds/probability_simplex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ include("../utils.jl")

@testset "Probability simplex" begin
M = ProbabilitySimplex(2)
M_euc = MetricManifold(M, EuclideanMetric())

@test M^4 == MultinomialMatrices(3, 4)
p = [0.1, 0.7, 0.2]
q = [0.3, 0.6, 0.1]
Expand Down Expand Up @@ -60,6 +62,24 @@ include("../utils.jl")
retraction_methods=[SoftmaxRetraction()],
test_inplace=true,
vector_transport_methods=[ParallelTransport()],
test_rand_point=true,
test_rand_tvector=true,
rand_tvector_atol_multiplier=20.0,
)
test_manifold(
M_euc,
pts,
test_exp_log=false,
test_injectivity_radius=false,
test_project_tangent=true,
test_musical_isomorphisms=true,
test_vee_hat=false,
is_tangent_atol_multiplier=40.0,
default_inverse_retraction_method=nothing,
test_inplace=false,
test_rand_point=true,
test_rand_tvector=true,
rand_tvector_atol_multiplier=40.0,
)
end
end
Expand Down

2 comments on commit d24f6d2

@mateuszbaran
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/86223

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.71 -m "<description of version>" d24f6d249f01b6074bf1c76b8e3e35ae6a1f4913
git push origin v0.8.71

Please sign in to comment.