Skip to content

Commit

Permalink
Fix warning message
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulinaMartin96 committed Jul 13, 2021
1 parent 01d56d5 commit e4c9765
Showing 1 changed file with 39 additions and 46 deletions.
85 changes: 39 additions & 46 deletions src/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct _MeanPlot; c; val; end
struct _DensityPlot; c; val; end
struct _HistogramPlot; c; val; end
struct _AutocorPlot; lags; val; end
struct _ViolinPlot; parameters; val; total_chains; end
struct _ViolinPlot; par; val; end

# define alias functions for old syntax
const translationdict = Dict(
Expand All @@ -33,7 +33,9 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
colordim = :chain,
barbounds = (-Inf, Inf),
maxlag = nothing,
append_chains = false
append_chains = false,
sections = chains.name_map[:parameters],
combined = true
)
st = get(plotattributes, :seriestype, :traceplot)
c = append_chains || st == :pooleddensity ? pool_chain(chains) : chains
Expand Down Expand Up @@ -72,6 +74,39 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
else
range(c), val
end

total_chains = i
if st == :violinplot
n_iter, n_par, n_chains = size(chains)
if combined
colordim := :chain
par = string.(reshape(repeat(sections, inner = n_iter), n_iter, n_par))[:,i]
val = Array(chains)[:,i]
_ViolinPlot(par, val)
elseif combined == false
if colordim == :chain
par_names = ["$(sections[i]).Chain $j" for i in 1:n_par, j in 1:n_chains]
pars = string.(reshape(repeat(vec(par_names), inner = n_iter), (n_iter, n_par, n_chains)))
val = chains.value[:,i,:]
par = pars[:,i,:]
elseif colordim == :parameter
par_vec = repeat(sections, inner = n_iter)
pars = string.(reshape(repeat(par_vec, n_chains, 1), (n_iter, n_par, n_chains)))
val = chains.value[:,:,i]
par = pars[:,:,i]
label --> string.(names(c))
else
throw(ArgumentError("`colordim` must be one of `:chain` or `:parameter`"))
end
_ViolinPlot(par, val)
else
throw(ArgumentError("In `ViolinPlots` `Chains` can be combined or separated "))
end
elseif st supportedplots
translationdict[st](c, val)
else
range(c), val
end
end

@recipe function f(p::_DensityPlot)
Expand Down Expand Up @@ -188,59 +223,17 @@ end
RecipesBase.recipetype(:cornerplot, vcat(ar...))
end

@recipe function f(
chains::Chains;
sections::Vector{Symbol} = chains.name_map[:parameters],
combined = true
)

st = get(plotattributes, :seriestype, :traceplot)
total_chains = 0
if st == :violinplot
if combined
n_iter, n_parameters = size(Array(chains))
parameters = string.(repeat(sections, inner = n_iter))
val = vec(Array(chains))
total_chains = Integer(size(chains.value.data)[3])
_ViolinPlot(parameters, val, total_chains)
elseif combined == false
n_parameters = length(sections)
chain_arr = Array(chains, append_chains = false)
val_vec = [chain_arr[j][:,i]
for i in 1:n_parameters
for j in 1:length(chain_arr)]
n_iter = length(val_vec[1])
total_chains = length(val_vec)
val = zeros(Float64, n_iter, total_chains)
for i in 1:total_chains
val[:,i] = val_vec[:][i]
end
val = vec(val)
parameters_names = ["param $(sections[i]).Chain $j"
for i in 1:n_parameters
for j in 1:length(chain_arr)]
parameters = string.(repeat(parameters_names, inner = n_iter))
_ViolinPlot(parameters, val, total_chains)
else
error("Symbol names are interpreted as parameter names, only compatible with ",
"`colordim = :chain`")
end
end
end

@recipe function f(p::_ViolinPlot)
@series begin
seriestype := :violin
xaxis --> "Parameter"
size --> (200*p.total_chains, 500)
p.parameters, p.val
p.par, p.val
end

@series begin
seriestype := :boxplot
bar_width --> 0.1
linewidth --> 2
fillalpha --> 0.8
p.parameters, p.val
p.par, p.val
end
end

0 comments on commit e4c9765

Please sign in to comment.