Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Energy plots #329

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 136 additions & 3 deletions src/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
@shorthands pooleddensity
@shorthands traceplot
@shorthands corner
@userplot EnergyPlot
#@shorthands energyplot

struct _TracePlot; c; val; end
struct _MeanPlot; c; val; end
struct _DensityPlot; c; val; end
struct _HistogramPlot; c; val; end
struct _AutocorPlot; lags; val; end
#struct _EnergyPlot; marginal_energy; energy_transition; p_type; n_chains; end

# define alias functions for old syntax
const translationdict = Dict(
Expand All @@ -18,10 +21,10 @@ const translationdict = Dict(
:density => _DensityPlot,
:histogram => _HistogramPlot,
:autocorplot => _AutocorPlot,
:pooleddensity => _DensityPlot
:pooleddensity => _DensityPlot,
)

const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :corner)
const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :corner, :energyplot)

@recipe f(c::Chains, s::Symbol) = c, [s]

Expand All @@ -30,7 +33,8 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
colordim = :chain,
barbounds = (-Inf, Inf),
maxlag = nothing,
append_chains = false
append_chains = false,
plot_type = :density
)
st = get(plotattributes, :seriestype, :traceplot)
c = append_chains || st == :pooleddensity ? pool_chain(chains) : chains
Expand Down Expand Up @@ -64,6 +68,17 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
ac_mat = convert(Array, ac)
val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :]
_AutocorPlot(lags, val)
#elseif st == :energyplot
# p_type = plot_type
# energy_section = get(c, :hamiltonian_energy)
# #@show energy_section
# #@show params.hamiltonian_energy
# n_chains = (append_chains ? 1 : size(c, 3))
# energy_data = (append_chains ? vec(energy_section.hamiltonian_energy.data) : energy_section.hamiltonian_energy.data)
# mean_energy = vec(mean(energy_data, dims = 1))
# marginal_energy = [energy_data[:,i] .- mean_energy[i] for i in 1:n_chains]
# energy_transition = [energy_data[2:end,i] .- energy_data[1:end-1,i] for i in 1:n_chains]
# _EnergyPlot(marginal_energy, energy_transition, p_type, n_chains)
elseif st ∈ supportedplots
translationdict[st](c, val)
else
Expand Down Expand Up @@ -184,3 +199,121 @@ end
ar = collect(Array(corner.c.value[:, corner.parameters,i]) for i in chains(corner.c))
RecipesBase.recipetype(:cornerplot, vcat(ar...))
end

#function compute_energy(
# chains::Chains,
# combined = false,
# plot_type = :density
#)
# st = get(plotattributes, :seriestype, :traceplot)
#
# if st == :energyplot
# p_type = plot_type
# params = get(chains, :hamiltonian_energy)
# n_chains = (combined ? 1 : size(chains, 3))
# energy_data = (combined ? vec(params.hamiltonian_energy.data) : params.hamiltonian_energy.data)
# mean_energy = vec(mean(energy_data, dims = 1))
# marginal_energy = energy_data[:,i] .- mean_energy[i]
# energy_transition = energy_data[2:end,i] .- energy_data[1:end-1,i]
# _EnergyPlot(marginal_energy, energy_transition, p_type, n_chains)
# else
#
# end
#end

#@recipe function f(
# chains::Chains;
# plot_type = :density,
# append_chains = false
#)
#
# st = get(plotattributes, :seriestype, :traceplot)
# if st == :energyplot
# p_type = plot_type
# energy_section = get(chains, :hamiltonian_energy)
# #@show energy_section
# #@show params.hamiltonian_energy
# n_chains = (append_chains ? 1 : size(chains, 3))
# energy_data = (append_chains ? vec(energy_section.hamiltonian_energy.data) : energy_section.hamiltonian_energy.data)
# mean_energy = vec(mean(energy_data, dims = 1))
# marginal_energy = [energy_data[:,i] .- mean_energy[i] for i in 1:n_chains]
# energy_transition = [energy_data[2:end,i] .- energy_data[1:end-1,i] for i in 1:n_chains]
# _EnergyPlot(marginal_energy, energy_transition, p_type, n_chains)
# elseif st ∈ supportedplots
# translationdict[st](c, val)
# end
#end

function compute_energy(
chains::Chains,
combined = false,
plot_type = :density
)
p_type = plot_type
params = get(chains, :hamiltonian_energy)
isempty(params) && error("EnergyPlot receives a Chains object containing only the
:internals section. Please use Chains(chain, [:internals]) to create it")
n_chains = (combined ? 1 : size(chains, 3))
energy_data = (combined ? vec(params.hamiltonian_energy.data) : params.hamiltonian_energy.data)
mean_energy = vec(mean(energy_data, dims = 1))
marginal_energy = [energy_data[:,i] .- mean_energy[i] for i in 1:n_chains]
energy_transition = [energy_data[2:end,i] .- energy_data[1:end-1,i] for i in 1:n_chains]
return marginal_energy, energy_transition, p_type, n_chains
end

@recipe function f(
p::EnergyPlot;
combined = false,
plot_type = :density
)

c = p.args[1]
#p_type = plot_type
#params = get(c, :hamiltonian_energy)
#isempty(params) && error("EnergyPlot receives a Chains object containing only the
# :internals section. Please use Chains(chain, [:internals]) to create it")
#n_chains = (combined ? 1 : size(c, 3))
#energy_data = (combined ? vec(params.hamiltonian_energy.data) : params.hamiltonian_energy.data)
#mean_energy = vec(mean(energy_data, dims = 1))
#marginal_energy = [energy_data[:,i] .- mean_energy[i] for i in 1:n_chains]
#energy_transition = [energy_data[2:end,i] .- energy_data[1:end-1,i] for i in 1:n_chains]
marginal_energy, energy_transition, p_type, n_chains = compute_energy(c, combined, plot_type)
k = 0
for i in 1:n_chains
k += 1
title --> "Chain $(MCMCChains.chains(c)[i])"
subplot := i
@series begin
seriestype := p_type
label --> "Marginal energy"
marginal_energy[i]
end

@series begin
seriestype := p_type
label --> "Energy transition"
energy_transition[i]
end
end
end

#@recipe function f(p::_EnergyPlot)
#
# k = 0
# for i in 1:p.n_chains
# k = 1
# @series begin
# subplot := i
# seriestype := p.p_type
# label --> "Marginal energy"
# p.marginal_energy[i]
# end
#
# @series begin
# subplot := i
# seriestype := p.p_type
# label --> "Energy transition"
# p.energy_transition[i]
# end
# end
#end