diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 825c8dbb3e..5d3c59b262 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -1,6 +1,5 @@ """Utilities for plotting.""" from itertools import product -from collections import OrderedDict as odict import numpy as np import matplotlib.pyplot as plt @@ -204,7 +203,7 @@ def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, revers if var_name in data: new_dims = [dim for dim in data[var_name].dims if dim not in skip_dims] vals = [data[var_name][dim].values for dim in new_dims] - dims = [odict((k, v) for k, v in zip(new_dims, prod)) for prod in product(*vals)] + dims = [{k : v for k, v in zip(new_dims, prod)} for prod in product(*vals)] if reverse_selections: dims = reversed(dims) diff --git a/arviz/tests/test_plot_utils.py b/arviz/tests/test_plot_utils.py index 7fdef23bb0..9ba379fe9e 100644 --- a/arviz/tests/test_plot_utils.py +++ b/arviz/tests/test_plot_utils.py @@ -26,7 +26,11 @@ def test_dataset_to_numpy_not_combined(sample_dataset): # pylint: disable=inval # 2 vars x 2 chains assert len(var_names) == 4 - assert (data == np.concatenate((mu, tau), axis=0)).all() + mu_tau = np.concatenate((mu, tau), axis=0) + tau_mu = np.concatenate((tau, mu), axis=0) + deqmt = data == mu_tau + deqtm = data == tau_mu + assert deqmt.all() or deqtm.all() def test_dataset_to_numpy_combined(sample_dataset): @@ -34,6 +38,8 @@ def test_dataset_to_numpy_combined(sample_dataset): var_names, data = xarray_to_ndarray(data, combined=True) assert len(var_names) == 2 + if var_names[0] == 'tau': + data = data[::-1] assert (data[0] == mu.reshape(1, 6)).all() assert (data[1] == tau.reshape(1, 6)).all()