Skip to content

Commit

Permalink
Merge e80b014 into 681a3bc
Browse files Browse the repository at this point in the history
  • Loading branch information
GWeindel committed Mar 19, 2019
2 parents 681a3bc + e80b014 commit 0f7ede7
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions arviz/plots/forestplot.py
@@ -1,5 +1,5 @@
"""Forest plot."""
from collections import defaultdict
from collections import defaultdict, OrderedDict
from itertools import tee

import numpy as np
Expand Down Expand Up @@ -39,10 +39,7 @@ def plot_forest(
ridgeplot_overlap=2,
figsize=None,
):
"""Forest plot to compare credible intervals from a number of distributions.
Generates a forest plot of 100*(credible_interval)% credible intervals from
a trace or list of traces.
"""Forest plot to compare credible intervals from a number of distributions, generates a forest plot of 100*(credible_interval)% credible intervals from a trace or list of traces.
Parameters
----------
Expand Down Expand Up @@ -96,14 +93,11 @@ def plot_forest(
Returns
-------
gridspec : matplotlib GridSpec
Examples
--------
Forestpĺot
.. plot::
:context: close-figs
>>> import arviz as az
>>> non_centered_data = az.load_arviz_data('non_centered_eight')
>>> fig, axes = az.plot_forest(non_centered_data,
Expand All @@ -113,12 +107,9 @@ def plot_forest(
>>> ridgeplot_overlap=3,
>>> figsize=(9, 7))
>>> axes[0].set_title('Estimated theta for 8 schools model')
Ridgeplot
.. plot::
:context: close-figs
>>> fig, axes = az.plot_forest(non_centered_data,
>>> kind='ridgeplot',
>>> var_names=['theta'],
Expand Down Expand Up @@ -245,7 +236,14 @@ def __init__(self, datasets, var_names, model_names, combined, colors):
self.model_names = list(reversed(model_names)) # y-values are upside down

if var_names is None:
self.var_names = list(set.union(*[set(datum.data_vars) for datum in self.data]))
if len(self.data) > 1:
self.var_names = list(
set().union(*[OrderedDict(datum.data_vars) for datum in self.data])
)
else:
self.var_names = list(
reversed(*[OrderedDict(datum.data_vars) for datum in self.data])
)
else:
self.var_names = list(reversed(var_names)) # y-values are upside down

Expand Down Expand Up @@ -484,7 +482,7 @@ def iterator(self):
grouped_data = [datum.groupby("chain") for datum in self.data]
skip_dims = set()

label_dict = {}
label_dict = OrderedDict()
for name, grouped_datum in zip(self.model_names, grouped_data):
for _, sub_data in grouped_datum:
datum_iter = xarray_var_iter(
Expand All @@ -496,7 +494,7 @@ def iterator(self):
for _, selection, values in datum_iter:
label = make_label(self.var_name, selection, position="beside")
if label not in label_dict:
label_dict[label] = {}
label_dict[label] = OrderedDict()
if name not in label_dict[label]:
label_dict[label][name] = []
label_dict[label][name].append(values)
Expand Down

0 comments on commit 0f7ede7

Please sign in to comment.