diff --git a/docs/src/distributions.md b/docs/src/distributions.md index dfc1398..e97fbb8 100644 --- a/docs/src/distributions.md +++ b/docs/src/distributions.md @@ -23,3 +23,10 @@ BoolDistribution ```@docs Deterministic ``` + +## Uniform + +```@docs +Uniform +UnsafeUniform +``` diff --git a/src/POMDPModelTools.jl b/src/POMDPModelTools.jl index 76a9af9..2c935cf 100644 --- a/src/POMDPModelTools.jl +++ b/src/POMDPModelTools.jl @@ -72,6 +72,11 @@ export Deterministic include("distributions/deterministic.jl") +export + Uniform, + UnsafeUniform +include("distributions/uniform.jl") + # convenient implementations include("convenient_implementations.jl") diff --git a/src/distributions/uniform.jl b/src/distributions/uniform.jl new file mode 100644 index 0000000..a4f4520 --- /dev/null +++ b/src/distributions/uniform.jl @@ -0,0 +1,62 @@ +struct Uniform{T<:AbstractSet} + set::T +end + +""" + Uniform(collection) + +Create a uniform categorical distribution over a collection of objects. + +The objects in the collection must be unique (this is tested on construction), and will be stored in a `Set`. To avoid this overhead, use `UnsafeUniform`. +""" +function Uniform(c) + set = Set(c) + if length(c) > length(set) + error(""" + Error constructing Uniform($c). + + Objects must be unique (that is, length(Set(c)) == length(c)). + """ + ) + end + return Uniform(set) +end + +support(d::Uniform) = d.set +sampletype(::Type{Uniform{T}}) where T = eltype(T) + +function pdf(d::Uniform, s) + if s in d.set + return 1.0/length(d.set) + else + return 0.0 + end +end + +""" + UnsafeUniform(collection) + +Create a uniform categorical distribution over a collection of objects. + +No checks are performed to ensure uniqueness or check whether an object is actually in the set when evaluating the pdf. +""" +struct UnsafeUniform{T} + collection::T +end + +pdf(d::UnsafeUniform, s) = 1.0/length(d.collection) +support(d::UnsafeUniform) = d.collection +sampletype(::Type{UnsafeUniform{T}}) where T = eltype(T) + +# Common Implementations + +const Unif = Union{Uniform,UnsafeUniform} + +rand(rng::AbstractRNG, d::Unif) = rand(rng, support(d)) +mean(d::Unif) = mean(support(d)) +mode(d::Unif) = mode(support(d)) + +function weighted_iterator(d::Unif) + p = 1.0/length(support(d)) + return (x=>p for x in support(d)) +end diff --git a/test/runtests.jl b/test/runtests.jl index c1f48d3..90bb01f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,7 @@ using Test using Pkg using POMDPSimulators using POMDPPolicies +import Distributions.Categorical @testset "ordered" begin include("test_ordered_spaces.jl") @@ -30,6 +31,9 @@ end @testset "deterministic" begin include("test_deterministic.jl") end +@testset "uniform" begin + include("test_uniform.jl") +end @testset "terminalstate" begin include("test_terminal_state.jl") end diff --git a/test/test_uniform.jl b/test/test_uniform.jl new file mode 100644 index 0000000..b27b412 --- /dev/null +++ b/test/test_uniform.jl @@ -0,0 +1,51 @@ +d = Uniform([1]) + +@test rand(d) == 1 +@test rand(MersenneTwister(4), d) == 1 +@test collect(support(d)) == [1] +@test sampletype(d) == typeof(1) +@test sampletype(typeof(d)) == typeof(1) +@test pdf(d, 0) == 0.0 +@test pdf(d, 1) == 1.0 +@test mode(d) == 1 +@test mean(d) == 1 +@test typeof(mean(d)) == typeof(mean([1])) +@test collect(weighted_iterator(d)) == [1=>1.0] + +d2 = Uniform((:symbol,)) +@test rand(d2) == :symbol +@test rand(MersenneTwister(4), d2) == :symbol +@test collect(support(d2)) == [:symbol] +@test sampletype(d2) == typeof(:symbol) +@test sampletype(typeof(d2)) == typeof(:symbol) +@test pdf(d2, :another) == 0.0 +@test pdf(d2, :symbol) == 1.0 +@test mode(d2) == :symbol +@test collect(weighted_iterator(d2)) == [:symbol=>1.0] + +# uniqueness test +@test_throws ErrorException Uniform((:symbol, :symbol)) + +d3 = UnsafeUniform([1]) + +@test rand(d3) == 1 +@test rand(MersenneTwister(4), d3) == 1 +@test collect(support(d3)) == [1] +@test sampletype(d3) == typeof(1) +@test sampletype(typeof(d3)) == typeof(1) +@test pdf(d3, 1) == 1.0 +@test mean(d3) == 1 +@test mode(d3) == 1 +@test typeof(mean(d3)) == typeof(mean([1])) +@test collect(weighted_iterator(d3)) == [1=>1.0] + +d4 = UnsafeUniform((:symbol,)) +@test rand(d4) == :symbol +@test rand(MersenneTwister(4), d4) == :symbol +@test collect(support(d4)) == [:symbol] +@test sampletype(d4) == typeof(:symbol) +@test sampletype(typeof(d4)) == typeof(:symbol) +# @test pdf(d4, :another) == 0.0 # this will not work +@test pdf(d4, :symbol) == 1.0 +@test mode(d4) == :symbol +@test collect(weighted_iterator(d4)) == [:symbol=>1.0] diff --git a/test/test_weighted_iteration.jl b/test/test_weighted_iteration.jl index a114f51..5a949f7 100644 --- a/test/test_weighted_iteration.jl +++ b/test/test_weighted_iteration.jl @@ -1,5 +1,4 @@ let - using Distributions dist = Categorical([0.4, 0.5, 0.1]) c = collect(weighted_iterator(dist)) @test c == [1=>0.4, 2=>0.5, 3=>0.1]