Skip to content

Commit

Permalink
Merge pull request #19 from JuliaPOMDP/uniform
Browse files Browse the repository at this point in the history
added a uniform distribution over a collection of objects
  • Loading branch information
zsunberg committed Jun 6, 2019
2 parents 715451c + 47936b3 commit 0527629
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 1 deletion.
7 changes: 7 additions & 0 deletions docs/src/distributions.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,10 @@ BoolDistribution
```@docs
Deterministic
```

## Uniform

```@docs
Uniform
UnsafeUniform
```
5 changes: 5 additions & 0 deletions src/POMDPModelTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ export
Deterministic
include("distributions/deterministic.jl")

export
Uniform,
UnsafeUniform
include("distributions/uniform.jl")

# convenient implementations
include("convenient_implementations.jl")

Expand Down
62 changes: 62 additions & 0 deletions src/distributions/uniform.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
struct Uniform{T<:AbstractSet}
set::T
end

"""
Uniform(collection)
Create a uniform categorical distribution over a collection of objects.
The objects in the collection must be unique (this is tested on construction), and will be stored in a `Set`. To avoid this overhead, use `UnsafeUniform`.
"""
function Uniform(c)
set = Set(c)
if length(c) > length(set)
error("""
Error constructing Uniform($c).
Objects must be unique (that is, length(Set(c)) == length(c)).
"""
)
end
return Uniform(set)
end

support(d::Uniform) = d.set
sampletype(::Type{Uniform{T}}) where T = eltype(T)

function pdf(d::Uniform, s)
if s in d.set
return 1.0/length(d.set)
else
return 0.0
end
end

"""
UnsafeUniform(collection)
Create a uniform categorical distribution over a collection of objects.
No checks are performed to ensure uniqueness or check whether an object is actually in the set when evaluating the pdf.
"""
struct UnsafeUniform{T}
collection::T
end

pdf(d::UnsafeUniform, s) = 1.0/length(d.collection)
support(d::UnsafeUniform) = d.collection
sampletype(::Type{UnsafeUniform{T}}) where T = eltype(T)

# Common Implementations

const Unif = Union{Uniform,UnsafeUniform}

rand(rng::AbstractRNG, d::Unif) = rand(rng, support(d))
mean(d::Unif) = mean(support(d))
mode(d::Unif) = mode(support(d))

function weighted_iterator(d::Unif)
p = 1.0/length(support(d))
return (x=>p for x in support(d))
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Test
using Pkg
using POMDPSimulators
using POMDPPolicies
import Distributions.Categorical

@testset "ordered" begin
include("test_ordered_spaces.jl")
Expand All @@ -30,6 +31,9 @@ end
@testset "deterministic" begin
include("test_deterministic.jl")
end
@testset "uniform" begin
include("test_uniform.jl")
end
@testset "terminalstate" begin
include("test_terminal_state.jl")
end
Expand Down
51 changes: 51 additions & 0 deletions test/test_uniform.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
d = Uniform([1])

@test rand(d) == 1
@test rand(MersenneTwister(4), d) == 1
@test collect(support(d)) == [1]
@test sampletype(d) == typeof(1)
@test sampletype(typeof(d)) == typeof(1)
@test pdf(d, 0) == 0.0
@test pdf(d, 1) == 1.0
@test mode(d) == 1
@test mean(d) == 1
@test typeof(mean(d)) == typeof(mean([1]))
@test collect(weighted_iterator(d)) == [1=>1.0]

d2 = Uniform((:symbol,))
@test rand(d2) == :symbol
@test rand(MersenneTwister(4), d2) == :symbol
@test collect(support(d2)) == [:symbol]
@test sampletype(d2) == typeof(:symbol)
@test sampletype(typeof(d2)) == typeof(:symbol)
@test pdf(d2, :another) == 0.0
@test pdf(d2, :symbol) == 1.0
@test mode(d2) == :symbol
@test collect(weighted_iterator(d2)) == [:symbol=>1.0]

# uniqueness test
@test_throws ErrorException Uniform((:symbol, :symbol))

d3 = UnsafeUniform([1])

@test rand(d3) == 1
@test rand(MersenneTwister(4), d3) == 1
@test collect(support(d3)) == [1]
@test sampletype(d3) == typeof(1)
@test sampletype(typeof(d3)) == typeof(1)
@test pdf(d3, 1) == 1.0
@test mean(d3) == 1
@test mode(d3) == 1
@test typeof(mean(d3)) == typeof(mean([1]))
@test collect(weighted_iterator(d3)) == [1=>1.0]

d4 = UnsafeUniform((:symbol,))
@test rand(d4) == :symbol
@test rand(MersenneTwister(4), d4) == :symbol
@test collect(support(d4)) == [:symbol]
@test sampletype(d4) == typeof(:symbol)
@test sampletype(typeof(d4)) == typeof(:symbol)
# @test pdf(d4, :another) == 0.0 # this will not work
@test pdf(d4, :symbol) == 1.0
@test mode(d4) == :symbol
@test collect(weighted_iterator(d4)) == [:symbol=>1.0]
1 change: 0 additions & 1 deletion test/test_weighted_iteration.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
let
using Distributions
dist = Categorical([0.4, 0.5, 0.1])
c = collect(weighted_iterator(dist))
@test c == [1=>0.4, 2=>0.5, 3=>0.1]
Expand Down

0 comments on commit 0527629

Please sign in to comment.