From 70e20706cba4eaf749f39493d3370145eb6ff6a0 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Mon, 10 Jun 2019 09:46:18 -0700 Subject: [PATCH] added pretty printing of distributions (#22) --- Project.toml | 3 ++- docs/src/distributions.md | 7 ++++++ src/POMDPModelTools.jl | 5 ++++ src/distributions/bool.jl | 2 ++ src/distributions/pretty_printing.jl | 36 ++++++++++++++++++++++++++++ src/distributions/sparse_cat.jl | 2 ++ src/distributions/uniform.jl | 4 ++++ test/runtests.jl | 4 ++++ test/test_bool.jl | 4 +++- test/test_pretty_printing.jl | 8 +++++++ test/test_sparse_cat.jl | 2 ++ test/test_uniform.jl | 6 +++++ 12 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 src/distributions/pretty_printing.jl create mode 100644 test/test_pretty_printing.jl diff --git a/Project.toml b/Project.toml index 2cea045..0d9163b 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" [compat] Distributions = ">= 0.17" @@ -19,8 +20,8 @@ julia = "1" POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca" POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4" POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test", "POMDPModels", "POMDPSimulators", "POMDPPolicies", "Pkg"] diff --git a/docs/src/distributions.md b/docs/src/distributions.md index e97fbb8..8169899 100644 --- a/docs/src/distributions.md +++ b/docs/src/distributions.md @@ -2,6 +2,8 @@ POMDPModelTools contains several utility distributions to be used in the POMDPs `transition` and `observation` functions. These implement the appropriate methods of the functions in the [distributions interface](http://juliapomdp.github.io/POMDPs.jl/latest/interfaces/#distributions). +This package also supplies [`showdistribution`](@ref) for pretty printing distributions as unicode bar graphs to the terminal. + ## Sparse Categorical (`SparseCat`) `SparseCat` is a sparse categorical distribution which is specified by simply providing a list of possible values (states or observations) and the probabilities corresponding to those particular objects. @@ -30,3 +32,8 @@ Deterministic Uniform UnsafeUniform ``` + +## Pretty Printing +```@docs +showdistribution +``` diff --git a/src/POMDPModelTools.jl b/src/POMDPModelTools.jl index 2c935cf..cec0ab2 100644 --- a/src/POMDPModelTools.jl +++ b/src/POMDPModelTools.jl @@ -4,6 +4,7 @@ using POMDPs using Random using LinearAlgebra using SparseArrays +using UnicodePlots import POMDPs: actions, n_actions, actionindex import POMDPs: states, n_states, stateindex @@ -84,4 +85,8 @@ export evaluate include("policy_evaluation.jl") +export + showdistribution +include("distributions/pretty_printing.jl") + end # module diff --git a/src/distributions/bool.jl b/src/distributions/bool.jl index 4b0b26d..50af37b 100644 --- a/src/distributions/bool.jl +++ b/src/distributions/bool.jl @@ -29,3 +29,5 @@ support(d::BoolDistribution) = [true, false] Base.hash(d::BoolDistribution) = hash(d.p) Base.length(d::BoolDistribution) = 2 + +Base.show(io::IO, m::MIME"text/plain", d::BoolDistribution) = showdistribution(io, m, d, title="BoolDistribution") diff --git a/src/distributions/pretty_printing.jl b/src/distributions/pretty_printing.jl new file mode 100644 index 0000000..da617dd --- /dev/null +++ b/src/distributions/pretty_printing.jl @@ -0,0 +1,36 @@ +""" + showdistribution([io], [mime], d) + +Show a UnicodePlots.barplot representation of a distribution. + +# Keyword Arguments + +- `title::String=string(typeof(d))*" distribution"`: title for the barplot. +""" +function showdistribution(io::IO, mime::MIME"text/plain", d; title=string(typeof(d))*" distribution") + limited = get(io, :limit, false) + strings = String[] + probs = Float64[] + + rows = first(get(io, :displaysize, displaysize(io))) + rows -= 6 # Yuck! This magic number is also in Base.print_matrix + + if limited && rows > 1 + for (x,p) in Iterators.take(weighted_iterator(d), rows-1) + push!(strings, sprint(show, x)) # maybe this should have conext=:compact=>true + push!(probs, p) + end + + push!(strings, "") + push!(probs, 1.0-sum(probs)) + else + for (x,p) in weighted_iterator(d) + push!(strings, sprint(show, x)) + push!(probs, p) + end + end + show(io, mime, barplot(strings, probs, title=title)) +end + +showdistribution(io::IO, d; kwargs...) = showdistribution(io, MIME("text/plain"), d; kwargs...) +showdistribution(d; kwargs...) = showdistribution(stdout, d; kwargs...) diff --git a/src/distributions/sparse_cat.jl b/src/distributions/sparse_cat.jl index d39318a..37b3f46 100644 --- a/src/distributions/sparse_cat.jl +++ b/src/distributions/sparse_cat.jl @@ -110,3 +110,5 @@ function mode(d::SparseCat) end return bestv end + +Base.show(io::IO, m::MIME"text/plain", d::SparseCat) = showdistribution(io, m, d, title="SparseCat distribution") diff --git a/src/distributions/uniform.jl b/src/distributions/uniform.jl index a4f4520..65bd8c8 100644 --- a/src/distributions/uniform.jl +++ b/src/distributions/uniform.jl @@ -33,6 +33,8 @@ function pdf(d::Uniform, s) end end +Base.show(io::IO, m::MIME"text/plain", d::Uniform) = showdistribution(io, m, d, title="Uniform distribution") + """ UnsafeUniform(collection) @@ -60,3 +62,5 @@ function weighted_iterator(d::Unif) p = 1.0/length(support(d)) return (x=>p for x in support(d)) end + +Base.show(io::IO, m::MIME"text/plain", d::UnsafeUniform) = showdistribution(io, m, d, title="UnsafeUniform distribution") diff --git a/test/runtests.jl b/test/runtests.jl index 90bb01f..9d3fd13 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,3 +62,7 @@ end @testset "evaluation" begin include("test_evaluation.jl") end + +@testset "pretty printing" begin + include("test_pretty_printing.jl") +end diff --git a/test/test_bool.jl b/test/test_bool.jl index a08cbff..50907d2 100644 --- a/test/test_bool.jl +++ b/test/test_bool.jl @@ -13,4 +13,6 @@ let # testing hash @test hash(d) == hash(d.p) -end \ No newline at end of file + + @test sprint((io,d)->show(io,MIME("text/plain"),d), d) == sprint((io,d)->showdistribution(io,d,title="BoolDistribution"), d) +end diff --git a/test/test_pretty_printing.jl b/test/test_pretty_printing.jl new file mode 100644 index 0000000..6ac8fe5 --- /dev/null +++ b/test/test_pretty_printing.jl @@ -0,0 +1,8 @@ +d = SparseCat([1,2], [0.5, 0.5]) +@test sprint(showdistribution, d) == " SparseCat{Array{Int64,1},Array{Float64,1}} distribution\n ┌ ┐ \n 1 ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.5 \n 2 ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.5 \n └ ┘ " + +d = SparseCat(1:50, fill(1/50, 50)) +iob = IOBuffer() +io = IOContext(iob, :limit=>true, :displaysize=>(10, 7)) +showdistribution(io, d) +@test String(take!(iob)) == " SparseCat{UnitRange{Int64},Array{Float64,1}} distribution\n ┌ ┐ \n 1 ┤■ 0.02 \n 2 ┤■ 0.02 \n 3 ┤■ 0.02 \n ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.94 \n └ ┘ " diff --git a/test/test_sparse_cat.jl b/test/test_sparse_cat.jl index bca338c..3d5d93d 100644 --- a/test/test_sparse_cat.jl +++ b/test/test_sparse_cat.jl @@ -31,4 +31,6 @@ let @test isapprox(count(samples.==:d)/N, pdf(d,:d), atol=0.005) @test_throws ErrorException rand(Random.GLOBAL_RNG, SparseCat([1], [0.0])) + + @test sprint((io,d)->show(io,MIME("text/plain"),d), d) == sprint((io,d)->showdistribution(io,d,title="SparseCat distribution"), d) end diff --git a/test/test_uniform.jl b/test/test_uniform.jl index b27b412..8fc8da5 100644 --- a/test/test_uniform.jl +++ b/test/test_uniform.jl @@ -12,6 +12,8 @@ d = Uniform([1]) @test typeof(mean(d)) == typeof(mean([1])) @test collect(weighted_iterator(d)) == [1=>1.0] +@test sprint((io,d)->show(io,MIME("text/plain"),d), d) == sprint((io,d)->showdistribution(io,d,title="Uniform distribution"), d) + d2 = Uniform((:symbol,)) @test rand(d2) == :symbol @test rand(MersenneTwister(4), d2) == :symbol @@ -26,6 +28,8 @@ d2 = Uniform((:symbol,)) # uniqueness test @test_throws ErrorException Uniform((:symbol, :symbol)) + + d3 = UnsafeUniform([1]) @test rand(d3) == 1 @@ -39,6 +43,8 @@ d3 = UnsafeUniform([1]) @test typeof(mean(d3)) == typeof(mean([1])) @test collect(weighted_iterator(d3)) == [1=>1.0] +@test sprint((io,d3)->show(io,MIME("text/plain"),d3), d3) == sprint((io,d3)->showdistribution(io,d3,title="UnsafeUniform distribution"), d3) + d4 = UnsafeUniform((:symbol,)) @test rand(d4) == :symbol @test rand(MersenneTwister(4), d4) == :symbol