Skip to content

Commit

Permalink
added L and U to the visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Nov 30, 2017
1 parent 0b35b4e commit f783c3b
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 18 deletions.
5 changes: 3 additions & 2 deletions src/ARDESPOT.jl
Expand Up @@ -33,7 +33,9 @@ export
FullyObservableValueUB,
DefaultPolicyLB,
bounds,
init_bounds
init_bounds,

ReportWhenUsed


include("random.jl")
Expand Down Expand Up @@ -79,7 +81,6 @@ include("tree.jl")
include("planner.jl")
include("pomdps_glue.jl")

# include("tree_printing.jl")
include("visualization.jl")
include("exceptions.jl")

Expand Down
5 changes: 4 additions & 1 deletion src/tree.jl
Expand Up @@ -15,6 +15,8 @@ struct DESPOT{S,A,O}
ba_rho::Vector{Float64} # needed for backup
ba_Rsum::Vector{Float64} # needed for backup
ba_action::Vector{A}

_discount::Float64 # for inferring L in visualization
end

function DESPOT(p::DESPOTPlanner, b_0)
Expand Down Expand Up @@ -45,7 +47,8 @@ function DESPOT(p::DESPOTPlanner, b_0)
Float64[],
Float64[],
Float64[],
A[]
A[],
discount(p.pomdp)
)
end

Expand Down
97 changes: 82 additions & 15 deletions src/visualization.jl
Expand Up @@ -7,14 +7,16 @@ function D3Trees.D3Tree(D::DESPOT; title="DESPOT Tree", kwargs...)
text = Vector{String}(len)
tt = fill("", len)
link_style = fill("", len)
L = calc_L(D)
for b in 1:lenb
children[b] = D.children[b] .+ lenb
text[b] = @sprintf("""
o:%s (|Φ|:%3d)
U:%6.2f
L:%6.2f, U:%6.2f
l:%6.2f, μ:%6.2f, l₀:%6.2f""",
b==1 ? "<root>" : string(D.obs[b]),
length(D.scenarios[b]),
L[b],
D.U[b],
D.l[b],
D.mu[b],
Expand All @@ -23,6 +25,7 @@ function D3Trees.D3Tree(D::DESPOT; title="DESPOT Tree", kwargs...)
tt[b] = """
o: $(b==1 ? "<root>" : string(D.obs[b]))
|Φ|: $(length(D.scenarios[b]))
L: $(L[b])
U: $(D.U[b])
l: $(D.l[b])
μ: $(D.mu[b])
Expand All @@ -34,20 +37,31 @@ function D3Trees.D3Tree(D::DESPOT; title="DESPOT Tree", kwargs...)
for ba in D.children[b]
link_style[ba+lenb] = "stroke-width:$link_width"
end
end
for ba in 1:lenba
children[ba+lenb] = D.ba_children[ba]
text[ba+lenb] = @sprintf("""
a:%s (ρ:%6.2f)
l:%6.2f μ:%6.2f""",
D.ba_action[ba], D.ba_rho[ba], ba_l(D, ba), D.ba_mu[ba])
tt[ba+lenb] = """
a: $(D.ba_action[ba])
ρ: $(D.ba_rho[ba])
l: $(ba_l(D, ba))
μ: $(D.ba_mu[ba])
$(length(D.ba_children[ba])) children
"""

for ba in D.children[b]
weighted_sum_U = 0.0
for bp in D.ba_children[ba]
weighted_sum_U += length(D.scenarios[bp]) * D.U[bp]
end
U = (D.ba_Rsum[ba] + infer_discount(D) * weighted_sum_U) / length(D.scenarios[b])
children[ba+lenb] = D.ba_children[ba]
text[ba+lenb] = @sprintf("""
a:%s (ρ:%6.2f)
L:%6.2f, U:%6.2f,
l:%6.2f μ:%6.2f""",
D.ba_action[ba], D.ba_rho[ba],
L[ba+lenb], U,
ba_l(D, ba), D.ba_mu[ba])
tt[ba+lenb] = """
a: $(D.ba_action[ba])
ρ: $(D.ba_rho[ba])
L: $(L[ba+lenb])
U: $U
l: $(ba_l(D, ba))
μ: $(D.ba_mu[ba])
$(length(D.ba_children[ba])) children
"""
end
end
return D3Tree(children;
text=text,
Expand All @@ -60,3 +74,56 @@ end

Base.show(io::IO, mime::MIME"text/html", D::DESPOT) = show(io, mime, D3Tree(D))
Base.show(io::IO, mime::MIME"text/plain", D::DESPOT) = show(io, mime, D3Tree(D))

"""
Return a vector of lower bounds L of length lenb+lenba, with b nodes first followed by ba nodes.
"""
function calc_L(D::DESPOT)
lenb = length(D.children)
lenba = length(D.ba_children)
if lenb == 1
@assert lenba == 0
return [D.l_0[1]]
end
len = lenb + lenba
cache = fill(NaN, len)
disc = infer_discount(D)
fill_L!(cache, D, 1, disc)
return cache
end

function infer_discount(D::DESPOT)
# @assert !isempty(D.children[1])
# K = length(D.scenarios[0])
# firstba = first(D.children[1])
# lambda = D.ba_rsum[firstba]/K - D.ba_rho[firstba]
disc = D._discount
return disc
end


"""
Fill all the elements of the cache for b and children of b and return L[b]
"""
function fill_L!(cache::Vector{Float64}, D::DESPOT, b::Int, disc::Float64)
K = length(D.scenarios[1])
lenb = length(D.children)
if isempty(D.children[b])
L = D.l_0[b]*K/(length(D.scenarios[b])*disc^D.Delta[b])
cache[b] = L
return L
else
max_L = -Inf
for ba in D.children[b]
weighted_sum_L = 0.0
for bp in D.ba_children[ba]
weighted_sum_L += length(D.scenarios[bp]) * fill_L!(cache, D, bp, disc)
end
new_L = (D.ba_Rsum[ba] + disc * weighted_sum_L) / length(D.scenarios[b])
cache[lenb+ba] = new_L
max_L = max(max_L, new_L)
end
cache[b] = max_L
return max_L
end
end

0 comments on commit f783c3b

Please sign in to comment.