Skip to content

Commit

Permalink
added pretty printing of distributions (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Jun 10, 2019
1 parent fadf4c3 commit 70e2070
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 2 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

[compat]
Distributions = ">= 0.17"
Expand All @@ -19,8 +20,8 @@ julia = "1"
POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4"
POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "POMDPModels", "POMDPSimulators", "POMDPPolicies", "Pkg"]
7 changes: 7 additions & 0 deletions docs/src/distributions.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

POMDPModelTools contains several utility distributions to be used in the POMDPs `transition` and `observation` functions. These implement the appropriate methods of the functions in the [distributions interface](http://juliapomdp.github.io/POMDPs.jl/latest/interfaces/#distributions).

This package also supplies [`showdistribution`](@ref) for pretty printing distributions as unicode bar graphs to the terminal.

## Sparse Categorical (`SparseCat`)

`SparseCat` is a sparse categorical distribution which is specified by simply providing a list of possible values (states or observations) and the probabilities corresponding to those particular objects.
Expand Down Expand Up @@ -30,3 +32,8 @@ Deterministic
Uniform
UnsafeUniform
```

## Pretty Printing
```@docs
showdistribution
```
5 changes: 5 additions & 0 deletions src/POMDPModelTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using POMDPs
using Random
using LinearAlgebra
using SparseArrays
using UnicodePlots

import POMDPs: actions, n_actions, actionindex
import POMDPs: states, n_states, stateindex
Expand Down Expand Up @@ -84,4 +85,8 @@ export
evaluate
include("policy_evaluation.jl")

export
showdistribution
include("distributions/pretty_printing.jl")

end # module
2 changes: 2 additions & 0 deletions src/distributions/bool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@ support(d::BoolDistribution) = [true, false]
Base.hash(d::BoolDistribution) = hash(d.p)

Base.length(d::BoolDistribution) = 2

Base.show(io::IO, m::MIME"text/plain", d::BoolDistribution) = showdistribution(io, m, d, title="BoolDistribution")
36 changes: 36 additions & 0 deletions src/distributions/pretty_printing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
showdistribution([io], [mime], d)
Show a UnicodePlots.barplot representation of a distribution.
# Keyword Arguments
- `title::String=string(typeof(d))*" distribution"`: title for the barplot.
"""
function showdistribution(io::IO, mime::MIME"text/plain", d; title=string(typeof(d))*" distribution")
limited = get(io, :limit, false)
strings = String[]
probs = Float64[]

rows = first(get(io, :displaysize, displaysize(io)))
rows -= 6 # Yuck! This magic number is also in Base.print_matrix

if limited && rows > 1
for (x,p) in Iterators.take(weighted_iterator(d), rows-1)
push!(strings, sprint(show, x)) # maybe this should have conext=:compact=>true
push!(probs, p)
end

push!(strings, "<everything else>")
push!(probs, 1.0-sum(probs))
else
for (x,p) in weighted_iterator(d)
push!(strings, sprint(show, x))
push!(probs, p)
end
end
show(io, mime, barplot(strings, probs, title=title))
end

showdistribution(io::IO, d; kwargs...) = showdistribution(io, MIME("text/plain"), d; kwargs...)
showdistribution(d; kwargs...) = showdistribution(stdout, d; kwargs...)
2 changes: 2 additions & 0 deletions src/distributions/sparse_cat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,5 @@ function mode(d::SparseCat)
end
return bestv
end

Base.show(io::IO, m::MIME"text/plain", d::SparseCat) = showdistribution(io, m, d, title="SparseCat distribution")
4 changes: 4 additions & 0 deletions src/distributions/uniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ function pdf(d::Uniform, s)
end
end

Base.show(io::IO, m::MIME"text/plain", d::Uniform) = showdistribution(io, m, d, title="Uniform distribution")

"""
UnsafeUniform(collection)
Expand Down Expand Up @@ -60,3 +62,5 @@ function weighted_iterator(d::Unif)
p = 1.0/length(support(d))
return (x=>p for x in support(d))
end

Base.show(io::IO, m::MIME"text/plain", d::UnsafeUniform) = showdistribution(io, m, d, title="UnsafeUniform distribution")
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,7 @@ end
@testset "evaluation" begin
include("test_evaluation.jl")
end

@testset "pretty printing" begin
include("test_pretty_printing.jl")
end
4 changes: 3 additions & 1 deletion test/test_bool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ let

# testing hash
@test hash(d) == hash(d.p)
end

@test sprint((io,d)->show(io,MIME("text/plain"),d), d) == sprint((io,d)->showdistribution(io,d,title="BoolDistribution"), d)
end
8 changes: 8 additions & 0 deletions test/test_pretty_printing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
d = SparseCat([1,2], [0.5, 0.5])
@test sprint(showdistribution, d) == " SparseCat{Array{Int64,1},Array{Float64,1}} distribution\n ┌ ┐ \n 1 ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.5 \n 2 ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.5 \n └ ┘ "

d = SparseCat(1:50, fill(1/50, 50))
iob = IOBuffer()
io = IOContext(iob, :limit=>true, :displaysize=>(10, 7))
showdistribution(io, d)
@test String(take!(iob)) == " SparseCat{UnitRange{Int64},Array{Float64,1}} distribution\n ┌ ┐ \n 1 ┤■ 0.02 \n 2 ┤■ 0.02 \n 3 ┤■ 0.02 \n <everything else> ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.94 \n └ ┘ "
2 changes: 2 additions & 0 deletions test/test_sparse_cat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@ let
@test isapprox(count(samples.==:d)/N, pdf(d,:d), atol=0.005)

@test_throws ErrorException rand(Random.GLOBAL_RNG, SparseCat([1], [0.0]))

@test sprint((io,d)->show(io,MIME("text/plain"),d), d) == sprint((io,d)->showdistribution(io,d,title="SparseCat distribution"), d)
end
6 changes: 6 additions & 0 deletions test/test_uniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ d = Uniform([1])
@test typeof(mean(d)) == typeof(mean([1]))
@test collect(weighted_iterator(d)) == [1=>1.0]

@test sprint((io,d)->show(io,MIME("text/plain"),d), d) == sprint((io,d)->showdistribution(io,d,title="Uniform distribution"), d)

d2 = Uniform((:symbol,))
@test rand(d2) == :symbol
@test rand(MersenneTwister(4), d2) == :symbol
Expand All @@ -26,6 +28,8 @@ d2 = Uniform((:symbol,))
# uniqueness test
@test_throws ErrorException Uniform((:symbol, :symbol))



d3 = UnsafeUniform([1])

@test rand(d3) == 1
Expand All @@ -39,6 +43,8 @@ d3 = UnsafeUniform([1])
@test typeof(mean(d3)) == typeof(mean([1]))
@test collect(weighted_iterator(d3)) == [1=>1.0]

@test sprint((io,d3)->show(io,MIME("text/plain"),d3), d3) == sprint((io,d3)->showdistribution(io,d3,title="UnsafeUniform distribution"), d3)

d4 = UnsafeUniform((:symbol,))
@test rand(d4) == :symbol
@test rand(MersenneTwister(4), d4) == :symbol
Expand Down

0 comments on commit 70e2070

Please sign in to comment.