Skip to content

Commit

Permalink
update hist_analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
Wu-Chenyang committed May 30, 2021
1 parent a9d274f commit 9b0faab
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "AdaOPS"
uuid = "eadfb9d8-44f1-454c-a5eb-0663ee7d74a1"
repo = "git@github.com:LAMDA-POMDP/AdaOPS.jl.git"
version = "0.5.2"
version = "0.5.3"

[deps]
BasicPOMCP = "d721219e-3fc6-5570-a8ef-e5402f47c49e"
Expand Down
1 change: 1 addition & 0 deletions src/AdaOPS.jl
Expand Up @@ -17,6 +17,7 @@ using Statistics
using StaticArrays
using Distributions
using Plots
using Plots.PlotMeasures

using MCTS
import MCTS: convert_to_policy
Expand Down
42 changes: 25 additions & 17 deletions src/analysis.jl
Expand Up @@ -6,19 +6,19 @@ function info_analysis(info::Dict)
println("Number of belief node expanded:", D.b)
m = length.(view(D.ba_particles, 1:D.ba))
println(@sprintf("m: mean±std = %5.2f±%4.2f", mean(m), std(m)))
println("Confidence interval (0.1, 0.9) = ", quantile(m, (0.1, 0.9)))
println("90% Confidence interval = ", quantile(m, (0.05, 0.95)))
branch = length.(view(D.ba_children, 1:D.ba))
println(@sprintf("Number of observation branchs: mean±std = %5.2f±%4.2f", mean(branch), std(branch)))
println("Confidence interval (0.1, 0.9) = ", quantile(branch, (0.1, 0.9)))
println("90% Confidence interval = ", quantile(branch, (0.05, 0.95)))
end
depth = info[:depth]
println("Times of exploration: ", length(depth))
println(@sprintf("Depth of exploration: mean±std = %5.2f±%4.2f", mean(depth), std(depth)))
println("Confidence interval (0.1, 0.9) = ", quantile(depth, (0.1, 0.9)))
println("90% Confidence interval = ", quantile(depth, (0.05, 0.95)))
return nothing
end

function hist_analysis(hist::H, display_mean_and_std::Bool = false) where H<:AbstractSimHistory
function hist_analysis(hist::H; display_mean_and_std::Bool = false, layout=(1,4), font_size=12, margin=40px, figure_size=(1700,400)) where H<:AbstractSimHistory
infos = ainfo_hist(hist)

median_d = Float64[]
Expand All @@ -29,24 +29,25 @@ function hist_analysis(hist::H, display_mean_and_std::Bool = false) where H<:Abs

for info in infos
depth = info[:depth]
l_d, m_d, u_d = quantile(depth, (0.1, 0.5, 0.9))
l_d, m_d, u_d = quantile(depth, (0.05, 0.5, 0.95))
push!(median_d, m_d)
push!(lower_d, m_d-l_d)
push!(upper_d, u_d-m_d)
push!(mean_d, mean(depth))
push!(std_d, std(depth))
end
p1 = plot(median_d, ribbon=(lower_d, upper_d), xaxis="Steps", yaxis="Depth of exploration", label="quantile")
if display_mean_and_std
plot!(p1, mean_d, ribbon=std_d, label="mean", legend=:best)
p1 = plot(median_d, ribbon=(lower_d, upper_d), xaxis="Steps", yaxis="Depth of exploration", label="quantile", xtickfontsize=font_size, ytickfontsize=font_size, xguidefontsize=font_size, yguidefontsize=font_size, legendfontsize=font_size)
plot!(p1, mean_d, ribbon=std_d, label="mean", xtickfontsize=font_size, ytickfontsize=font_size, xguidefontsize=font_size, yguidefontsize=font_size, legendfontsize=font_size)
else
p1 = plot(median_d, ribbon=(lower_d, upper_d), xaxis="Steps", yaxis="Depth of exploration", legend=false, xtickfontsize=font_size, ytickfontsize=font_size, xguidefontsize=font_size, yguidefontsize=font_size, legendfontsize=font_size)
end

D = get(first(infos), :tree, nothing)
if D === nothing
display(p1)
else
num_anode = Int[]
num_bnode = Int[]

median_m = Float64[]
lower_m = Float64[] # lower quantile
Expand All @@ -62,31 +63,38 @@ function hist_analysis(hist::H, display_mean_and_std::Bool = false) where H<:Abs

for info in infos
D = info[:tree]

push!(num_anode, D.ba)
push!(num_bnode, D.b)

m = length.(view(D.ba_particles, 1:D.ba))
l_m, m_m, u_m = quantile(m, (0.1, 0.5, 0.9))
l_m, m_m, u_m = quantile(m, (0.05, 0.5, 0.95))
push!(median_m, m_m)
push!(lower_m, m_m-l_m)
push!(upper_m, u_m-m_m)
push!(mean_m, mean(m))
push!(std_m, std(m))

branch = length.(view(D.ba_children, 1:D.ba))
l_b, m_b, u_b = quantile(branch, (0.1, 0.5, 0.9))
l_b, m_b, u_b = quantile(branch, (0.05, 0.5, 0.95))
push!(median_branch, m_b)
push!(lower_branch, m_b-l_b)
push!(upper_branch, u_b-m_b)
push!(mean_branch, mean(branch))
push!(std_branch, std(branch))
end
p2 = plot(hcat(num_anode,num_bnode), label=["Action" "Belief"], xaxis="Steps", yaxis="Nodes expanded", legend=:best)
p3 = plot(median_m, ribbon=(lower_m, upper_m), xaxis="Steps", yaxis="Particles used", label="quantile")
p4 = plot(median_branch, ribbon=(lower_branch, upper_branch), xaxis="Steps", yaxis="Obs. Num.", label="quantile")
base = 10^floor(log(10, mean(num_anode)))
p2 = plot(num_anode, legend=false, xaxis="Steps", yaxis="Action Nodes Expanded", yformatter=y->y/base, xtickfontsize=font_size, ytickfontsize=font_size, xguidefontsize=font_size, yguidefontsize=font_size, legendfontsize=font_size)
annotate!(p2, [(0.9, maximum(num_anode) + (maximum(num_anode)-minimum(num_anode)) * 0.07, Plots.text(@sprintf("\$\\times10^{%d}\$", round(Int, log(10,base))), font_size, :black, :center, "courier"))])
if display_mean_and_std
plot!(p3, mean_m, ribbon=std_m, label="mean", legend=:best)
plot!(p4, mean_branch, ribbon=std_branch, label="mean", legend=:best)
p3 = plot(median_m, ribbon=(lower_m, upper_m), xaxis="Steps", yaxis="Number of Particles", label="quantile", xtickfontsize=font_size, ytickfontsize=font_size, xguidefontsize=font_size, yguidefontsize=font_size, legendfontsize=font_size)
p4 = plot(median_branch, ribbon=(lower_branch, upper_branch), xaxis="Steps", yaxis="Number of Observations", label="quantile", xtickfontsize=font_size, ytickfontsize=font_size, xguidefontsize=font_size, yguidefontsize=font_size, legendfontsize=font_size)
plot!(p3, mean_m, ribbon=std_m, label="mean", xtickfontsize=font_size, ytickfontsize=font_size, xguidefontsize=font_size, yguidefontsize=font_size, legendfontsize=font_size)
plot!(p3, mean_branch, ribbon=std_branch, label="mean", xtickfontsize=font_size, ytickfontsize=font_size, xguidefontsize=font_size, yguidefontsize=font_size, legendfontsize=font_size)
else
p3 = plot(median_m, ribbon=(lower_m, upper_m), xaxis="Steps", yaxis="Number of Particles", legend=false, xtickfontsize=font_size, ytickfontsize=font_size, xguidefontsize=font_size, yguidefontsize=font_size, legendfontsize=font_size)
p4 = plot(median_branch, ribbon=(lower_branch, upper_branch), xaxis="Steps", yaxis="Number of Observations", legend=false, xtickfontsize=font_size, ytickfontsize=font_size, xguidefontsize=font_size, yguidefontsize=font_size, legendfontsize=font_size)
end
display(plot(p1, p2, p3, p4, layout = (2, 2)))
display(plot(p1, p2, p3, p4, layout=layout, size=figure_size, margin=margin))
end
return nothing
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Expand Up @@ -10,6 +10,8 @@ using ParticleFilters
using BeliefUpdaters
using StaticArrays
using POMDPPolicies
using Plots
theme(:mute)

# include("baby_sanity_check.jl")
include("independent_bounds.jl")
Expand Down

2 comments on commit 9b0faab

@Wu-Chenyang
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/37901

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.3 -m "<description of version>" 9b0faab55468215988d6695ffca7d78e0700d2b6
git push origin v0.5.3

Please sign in to comment.