From f783c3be47c6e948426899a6e84241f0052f3ad9 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Thu, 30 Nov 2017 10:57:41 -0800 Subject: [PATCH] added L and U to the visualization --- src/ARDESPOT.jl | 5 ++- src/tree.jl | 5 ++- src/visualization.jl | 97 +++++++++++++++++++++++++++++++++++++------- 3 files changed, 89 insertions(+), 18 deletions(-) diff --git a/src/ARDESPOT.jl b/src/ARDESPOT.jl index 5542e3a..2ae7954 100644 --- a/src/ARDESPOT.jl +++ b/src/ARDESPOT.jl @@ -33,7 +33,9 @@ export FullyObservableValueUB, DefaultPolicyLB, bounds, - init_bounds + init_bounds, + + ReportWhenUsed include("random.jl") @@ -79,7 +81,6 @@ include("tree.jl") include("planner.jl") include("pomdps_glue.jl") -# include("tree_printing.jl") include("visualization.jl") include("exceptions.jl") diff --git a/src/tree.jl b/src/tree.jl index e1758ba..64e6f78 100644 --- a/src/tree.jl +++ b/src/tree.jl @@ -15,6 +15,8 @@ struct DESPOT{S,A,O} ba_rho::Vector{Float64} # needed for backup ba_Rsum::Vector{Float64} # needed for backup ba_action::Vector{A} + + _discount::Float64 # for inferring L in visualization end function DESPOT(p::DESPOTPlanner, b_0) @@ -45,7 +47,8 @@ function DESPOT(p::DESPOTPlanner, b_0) Float64[], Float64[], Float64[], - A[] + A[], + discount(p.pomdp) ) end diff --git a/src/visualization.jl b/src/visualization.jl index 2ef2ad2..e73ffa6 100644 --- a/src/visualization.jl +++ b/src/visualization.jl @@ -7,14 +7,16 @@ function D3Trees.D3Tree(D::DESPOT; title="DESPOT Tree", kwargs...) text = Vector{String}(len) tt = fill("", len) link_style = fill("", len) + L = calc_L(D) for b in 1:lenb children[b] = D.children[b] .+ lenb text[b] = @sprintf(""" o:%s (|Φ|:%3d) - U:%6.2f + L:%6.2f, U:%6.2f l:%6.2f, μ:%6.2f, l₀:%6.2f""", b==1 ? "" : string(D.obs[b]), length(D.scenarios[b]), + L[b], D.U[b], D.l[b], D.mu[b], @@ -23,6 +25,7 @@ function D3Trees.D3Tree(D::DESPOT; title="DESPOT Tree", kwargs...) tt[b] = """ o: $(b==1 ? "" : string(D.obs[b])) |Φ|: $(length(D.scenarios[b])) + L: $(L[b]) U: $(D.U[b]) l: $(D.l[b]) μ: $(D.mu[b]) @@ -34,20 +37,31 @@ function D3Trees.D3Tree(D::DESPOT; title="DESPOT Tree", kwargs...) for ba in D.children[b] link_style[ba+lenb] = "stroke-width:$link_width" end - end - for ba in 1:lenba - children[ba+lenb] = D.ba_children[ba] - text[ba+lenb] = @sprintf(""" - a:%s (ρ:%6.2f) - l:%6.2f μ:%6.2f""", - D.ba_action[ba], D.ba_rho[ba], ba_l(D, ba), D.ba_mu[ba]) - tt[ba+lenb] = """ - a: $(D.ba_action[ba]) - ρ: $(D.ba_rho[ba]) - l: $(ba_l(D, ba)) - μ: $(D.ba_mu[ba]) - $(length(D.ba_children[ba])) children - """ + + for ba in D.children[b] + weighted_sum_U = 0.0 + for bp in D.ba_children[ba] + weighted_sum_U += length(D.scenarios[bp]) * D.U[bp] + end + U = (D.ba_Rsum[ba] + infer_discount(D) * weighted_sum_U) / length(D.scenarios[b]) + children[ba+lenb] = D.ba_children[ba] + text[ba+lenb] = @sprintf(""" + a:%s (ρ:%6.2f) + L:%6.2f, U:%6.2f, + l:%6.2f μ:%6.2f""", + D.ba_action[ba], D.ba_rho[ba], + L[ba+lenb], U, + ba_l(D, ba), D.ba_mu[ba]) + tt[ba+lenb] = """ + a: $(D.ba_action[ba]) + ρ: $(D.ba_rho[ba]) + L: $(L[ba+lenb]) + U: $U + l: $(ba_l(D, ba)) + μ: $(D.ba_mu[ba]) + $(length(D.ba_children[ba])) children + """ + end end return D3Tree(children; text=text, @@ -60,3 +74,56 @@ end Base.show(io::IO, mime::MIME"text/html", D::DESPOT) = show(io, mime, D3Tree(D)) Base.show(io::IO, mime::MIME"text/plain", D::DESPOT) = show(io, mime, D3Tree(D)) + +""" +Return a vector of lower bounds L of length lenb+lenba, with b nodes first followed by ba nodes. +""" +function calc_L(D::DESPOT) + lenb = length(D.children) + lenba = length(D.ba_children) + if lenb == 1 + @assert lenba == 0 + return [D.l_0[1]] + end + len = lenb + lenba + cache = fill(NaN, len) + disc = infer_discount(D) + fill_L!(cache, D, 1, disc) + return cache +end + +function infer_discount(D::DESPOT) + # @assert !isempty(D.children[1]) + # K = length(D.scenarios[0]) + # firstba = first(D.children[1]) + # lambda = D.ba_rsum[firstba]/K - D.ba_rho[firstba] + disc = D._discount + return disc +end + + +""" +Fill all the elements of the cache for b and children of b and return L[b] +""" +function fill_L!(cache::Vector{Float64}, D::DESPOT, b::Int, disc::Float64) + K = length(D.scenarios[1]) + lenb = length(D.children) + if isempty(D.children[b]) + L = D.l_0[b]*K/(length(D.scenarios[b])*disc^D.Delta[b]) + cache[b] = L + return L + else + max_L = -Inf + for ba in D.children[b] + weighted_sum_L = 0.0 + for bp in D.ba_children[ba] + weighted_sum_L += length(D.scenarios[bp]) * fill_L!(cache, D, bp, disc) + end + new_L = (D.ba_Rsum[ba] + disc * weighted_sum_L) / length(D.scenarios[b]) + cache[lenb+ba] = new_L + max_L = max(max_L, new_L) + end + cache[b] = max_L + return max_L + end +end