From 0b4419033852d198dc7d3d8685d13622cdd1563a Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Mon, 3 Jun 2019 08:54:39 -0700 Subject: [PATCH 1/4] added a uniform distribution over a collection of objects --- docs/src/distributions.md | 6 ++++++ src/POMDPModelTools.jl | 4 ++++ src/distributions/uniform.jl | 35 +++++++++++++++++++++++++++++++++ test/runtests.jl | 4 ++++ test/test_uniform.jl | 22 +++++++++++++++++++++ test/test_weighted_iteration.jl | 1 - 6 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 src/distributions/uniform.jl create mode 100644 test/test_uniform.jl diff --git a/docs/src/distributions.md b/docs/src/distributions.md index dfc1398..a7fb73f 100644 --- a/docs/src/distributions.md +++ b/docs/src/distributions.md @@ -23,3 +23,9 @@ BoolDistribution ```@docs Deterministic ``` + +## Uniform + +```@docs +Uniform +``` diff --git a/src/POMDPModelTools.jl b/src/POMDPModelTools.jl index 16ab65c..a66a045 100644 --- a/src/POMDPModelTools.jl +++ b/src/POMDPModelTools.jl @@ -71,6 +71,10 @@ export Deterministic include("distributions/deterministic.jl") +export + Uniform +include("distributions/uniform.jl") + # convenient implementations include("convenient_implementations.jl") diff --git a/src/distributions/uniform.jl b/src/distributions/uniform.jl new file mode 100644 index 0000000..eb14148 --- /dev/null +++ b/src/distributions/uniform.jl @@ -0,0 +1,35 @@ +""" + Uniform(collection) + +Create a categorical distribution over a collection of objects. +""" +mutable struct Uniform{C, T} + collection::C + _set::Union{Set{T}, Nothing} # keep track of what's in the collection to make pdf more efficient +end + +Uniform(c) = Uniform{typeof(c), eltype(c)}(c, nothing) +Uniform(c::Set) = Uniform{typeof(c), eltype(c)}(c, c) + +rand(rng::AbstractRNG, d::Uniform) = rand(rng, d.collection) +rand(rng::AbstractRNG, d::Uniform{<:NamedTuple}) = d.collection[rand(rng, 1:length(d.collection))] + +support(d::Uniform) = d.collection +sampletype(::Type{Uniform{C, T}}) where {C,T} = T + +function pdf(d::Uniform, s) + d._set = something(d._set, Set(d.collection)) + if s in d._set + return 1/length(d.collection) + else + return 0.0 + end +end + +mode(d::Uniform) = mode(d.collection) +mean(d::Uniform) = mean(d.collection) + +function weighted_iterator(d::Uniform) + p = 1/length(d.collection) + return (x=>p for x in d.collection) +end diff --git a/test/runtests.jl b/test/runtests.jl index 9341857..0887699 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,7 @@ using Random using Test using Pkg using POMDPSimulators +import Distributions.Categorical @testset "ordered" begin include("test_ordered_spaces.jl") @@ -29,6 +30,9 @@ end @testset "deterministic" begin include("test_deterministic.jl") end +@testset "uniform" begin + include("test_uniform.jl") +end @testset "terminalstate" begin include("test_terminal_state.jl") end diff --git a/test/test_uniform.jl b/test/test_uniform.jl new file mode 100644 index 0000000..cd13e82 --- /dev/null +++ b/test/test_uniform.jl @@ -0,0 +1,22 @@ +d = Uniform([1]) + +@test rand(d) == 1 +@test rand(MersenneTwister(4), d) == 1 +@test collect(support(d)) == [1] +@test sampletype(d) == typeof(1) +@test sampletype(typeof(d)) == typeof(1) +@test pdf(d, 0) == 0.0 +@test pdf(d, 1) == 1.0 +@test mode(d) == 1 +@test mean(d) == 1 +@test typeof(mean(d)) == typeof(mean([1])) + +d2 = Uniform((:symbol,)) +@test rand(d2) == :symbol +@test rand(MersenneTwister(4), d2) == :symbol +@test collect(support(d2)) == [:symbol] +@test sampletype(d2) == typeof(:symbol) +@test sampletype(typeof(d2)) == typeof(:symbol) +@test pdf(d2, :another) == 0.0 +@test pdf(d2, :symbol) == 1.0 +@test mode(d2) == :symbol diff --git a/test/test_weighted_iteration.jl b/test/test_weighted_iteration.jl index a114f51..5a949f7 100644 --- a/test/test_weighted_iteration.jl +++ b/test/test_weighted_iteration.jl @@ -1,5 +1,4 @@ let - using Distributions dist = Categorical([0.4, 0.5, 0.1]) c = collect(weighted_iterator(dist)) @test c == [1=>0.4, 2=>0.5, 3=>0.1] From 56859d97c2bd3d921c6af10a47ebe8728887377b Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Tue, 4 Jun 2019 10:14:35 -0700 Subject: [PATCH 2/4] Uniform checks uniqueness, added UnsafeUniform --- docs/src/distributions.md | 1 + src/POMDPModelTools.jl | 3 +- src/distributions/uniform.jl | 69 +++++++++++++++++++++++++++--------- test/test_uniform.jl | 29 +++++++++++++++ 4 files changed, 84 insertions(+), 18 deletions(-) diff --git a/docs/src/distributions.md b/docs/src/distributions.md index a7fb73f..e97fbb8 100644 --- a/docs/src/distributions.md +++ b/docs/src/distributions.md @@ -28,4 +28,5 @@ Deterministic ```@docs Uniform +UnsafeUniform ``` diff --git a/src/POMDPModelTools.jl b/src/POMDPModelTools.jl index 7cac943..2c935cf 100644 --- a/src/POMDPModelTools.jl +++ b/src/POMDPModelTools.jl @@ -73,7 +73,8 @@ export include("distributions/deterministic.jl") export - Uniform + Uniform, + UnsafeUniform include("distributions/uniform.jl") # convenient implementations diff --git a/src/distributions/uniform.jl b/src/distributions/uniform.jl index eb14148..c618944 100644 --- a/src/distributions/uniform.jl +++ b/src/distributions/uniform.jl @@ -1,35 +1,70 @@ +struct Uniform{T<:AbstractSet} + set::T +end + """ Uniform(collection) -Create a categorical distribution over a collection of objects. +Create a uniform categorical distribution over a collection of objects. + +The objects in the collection must be unique (this is tested on construction), and will be stored in a `Set`. To avoid this overhead, use `UnsafeUniform`. """ -mutable struct Uniform{C, T} - collection::C - _set::Union{Set{T}, Nothing} # keep track of what's in the collection to make pdf more efficient -end +function Uniform(c) + set = Set(c) + if length(c) > length(set) + error(""" + Error constructing Uniform($c). -Uniform(c) = Uniform{typeof(c), eltype(c)}(c, nothing) -Uniform(c::Set) = Uniform{typeof(c), eltype(c)}(c, c) + Objects must be unique (that is, length(Set(c)) == length(c)). + """ + ) + end + return Uniform(set) +end -rand(rng::AbstractRNG, d::Uniform) = rand(rng, d.collection) -rand(rng::AbstractRNG, d::Uniform{<:NamedTuple}) = d.collection[rand(rng, 1:length(d.collection))] +rand(rng::AbstractRNG, d::Uniform) = rand(rng, d.set) -support(d::Uniform) = d.collection -sampletype(::Type{Uniform{C, T}}) where {C,T} = T +support(d::Uniform) = d.set +sampletype(::Type{Uniform{T}}) where T = eltype(T) function pdf(d::Uniform, s) - d._set = something(d._set, Set(d.collection)) - if s in d._set - return 1/length(d.collection) + if s in d.set + return 1.0/length(d.set) else return 0.0 end end -mode(d::Uniform) = mode(d.collection) -mean(d::Uniform) = mean(d.collection) +mean(d::Uniform) = mean(d.set) +mode(d::Uniform) = mode(d.set) function weighted_iterator(d::Uniform) - p = 1/length(d.collection) + p = 1.0/length(d.set) + return (x=>p for x in d.set) +end + +""" + UnsafeUniform(collection) + +Create a uniform categorical distribution over a collection of objects. + +No checks are performed to ensure uniqueness or check whether an object is actually in the set when evaluating the pdf. +""" +struct UnsafeUniform{T} + collection::T +end + +rand(rng::AbstractRNG, d::UnsafeUniform) = rand(rng, d.collection) + +support(d::UnsafeUniform) = d.collection +sampletype(::Type{UnsafeUniform{T}}) where T = eltype(T) + +pdf(d::UnsafeUniform, s) = 1.0/length(d.collection) + +mean(d::UnsafeUniform) = mean(d.collection) +mode(d::UnsafeUniform) = mode(d.collection) + +function weighted_iterator(d::UnsafeUniform) + p = 1.0/length(d.collection) return (x=>p for x in d.collection) end diff --git a/test/test_uniform.jl b/test/test_uniform.jl index cd13e82..b27b412 100644 --- a/test/test_uniform.jl +++ b/test/test_uniform.jl @@ -10,6 +10,7 @@ d = Uniform([1]) @test mode(d) == 1 @test mean(d) == 1 @test typeof(mean(d)) == typeof(mean([1])) +@test collect(weighted_iterator(d)) == [1=>1.0] d2 = Uniform((:symbol,)) @test rand(d2) == :symbol @@ -20,3 +21,31 @@ d2 = Uniform((:symbol,)) @test pdf(d2, :another) == 0.0 @test pdf(d2, :symbol) == 1.0 @test mode(d2) == :symbol +@test collect(weighted_iterator(d2)) == [:symbol=>1.0] + +# uniqueness test +@test_throws ErrorException Uniform((:symbol, :symbol)) + +d3 = UnsafeUniform([1]) + +@test rand(d3) == 1 +@test rand(MersenneTwister(4), d3) == 1 +@test collect(support(d3)) == [1] +@test sampletype(d3) == typeof(1) +@test sampletype(typeof(d3)) == typeof(1) +@test pdf(d3, 1) == 1.0 +@test mean(d3) == 1 +@test mode(d3) == 1 +@test typeof(mean(d3)) == typeof(mean([1])) +@test collect(weighted_iterator(d3)) == [1=>1.0] + +d4 = UnsafeUniform((:symbol,)) +@test rand(d4) == :symbol +@test rand(MersenneTwister(4), d4) == :symbol +@test collect(support(d4)) == [:symbol] +@test sampletype(d4) == typeof(:symbol) +@test sampletype(typeof(d4)) == typeof(:symbol) +# @test pdf(d4, :another) == 0.0 # this will not work +@test pdf(d4, :symbol) == 1.0 +@test mode(d4) == :symbol +@test collect(weighted_iterator(d4)) == [:symbol=>1.0] From 8db5e7ba98d58578709e7fe4ab9b7c9e2218796e Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Tue, 4 Jun 2019 10:26:18 -0700 Subject: [PATCH 3/4] put code into common section --- src/distributions/uniform.jl | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/distributions/uniform.jl b/src/distributions/uniform.jl index c618944..80a7f01 100644 --- a/src/distributions/uniform.jl +++ b/src/distributions/uniform.jl @@ -22,8 +22,7 @@ function Uniform(c) return Uniform(set) end -rand(rng::AbstractRNG, d::Uniform) = rand(rng, d.set) - +# rand(rng::AbstractRNG, d::Uniform) = rand(rng, d.set) support(d::Uniform) = d.set sampletype(::Type{Uniform{T}}) where T = eltype(T) @@ -35,13 +34,13 @@ function pdf(d::Uniform, s) end end -mean(d::Uniform) = mean(d.set) -mode(d::Uniform) = mode(d.set) - -function weighted_iterator(d::Uniform) - p = 1.0/length(d.set) - return (x=>p for x in d.set) -end +# mean(d::Uniform) = mean(d.set) +# mode(d::Uniform) = mode(d.set) +# +# function weighted_iterator(d::Uniform) +# p = 1.0/length(d.set) +# return (x=>p for x in d.set) +# end """ UnsafeUniform(collection) @@ -54,17 +53,18 @@ struct UnsafeUniform{T} collection::T end -rand(rng::AbstractRNG, d::UnsafeUniform) = rand(rng, d.collection) - +pdf(d::UnsafeUniform, s) = 1.0/length(d.collection) support(d::UnsafeUniform) = d.collection sampletype(::Type{UnsafeUniform{T}}) where T = eltype(T) -pdf(d::UnsafeUniform, s) = 1.0/length(d.collection) +# common +const Unif = Union{Uniform,UnsafeUniform} -mean(d::UnsafeUniform) = mean(d.collection) -mode(d::UnsafeUniform) = mode(d.collection) +rand(rng::AbstractRNG, d::Unif) = rand(rng, support(d)) +mean(d::Unif) = mean(support(d)) +mode(d::Unif) = mode(support(d)) -function weighted_iterator(d::UnsafeUniform) - p = 1.0/length(d.collection) - return (x=>p for x in d.collection) +function weighted_iterator(d::Unif) + p = 1.0/length(support(d)) + return (x=>p for x in support(d)) end From 47936b394a705655acbde52da127cf6d695669bc Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Thu, 6 Jun 2019 09:33:59 -0700 Subject: [PATCH 4/4] deleted commented code --- src/distributions/uniform.jl | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/distributions/uniform.jl b/src/distributions/uniform.jl index 80a7f01..a4f4520 100644 --- a/src/distributions/uniform.jl +++ b/src/distributions/uniform.jl @@ -22,7 +22,6 @@ function Uniform(c) return Uniform(set) end -# rand(rng::AbstractRNG, d::Uniform) = rand(rng, d.set) support(d::Uniform) = d.set sampletype(::Type{Uniform{T}}) where T = eltype(T) @@ -34,14 +33,6 @@ function pdf(d::Uniform, s) end end -# mean(d::Uniform) = mean(d.set) -# mode(d::Uniform) = mode(d.set) -# -# function weighted_iterator(d::Uniform) -# p = 1.0/length(d.set) -# return (x=>p for x in d.set) -# end - """ UnsafeUniform(collection) @@ -57,7 +48,8 @@ pdf(d::UnsafeUniform, s) = 1.0/length(d.collection) support(d::UnsafeUniform) = d.collection sampletype(::Type{UnsafeUniform{T}}) where T = eltype(T) -# common +# Common Implementations + const Unif = Union{Uniform,UnsafeUniform} rand(rng::AbstractRNG, d::Unif) = rand(rng, support(d))