From c7d9385afa2dae7f44abd788608359dc561f6260 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sat, 7 Dec 2019 17:01:55 -0800 Subject: [PATCH 01/41] Add find plotting method --- arviz/plots/backends/__init__.py | 1 - arviz/plots/backends/matplotlib/__init__.py | 1 + arviz/plots/distplot.py | 21 +++++----------- arviz/plots/plot_utils.py | 28 +++++++++++++++++++++ arviz/tests/test_plot_utils.py | 9 +++++++ 5 files changed, 44 insertions(+), 16 deletions(-) diff --git a/arviz/plots/backends/__init__.py b/arviz/plots/backends/__init__.py index 8225ecf543..2857dbde0c 100644 --- a/arviz/plots/backends/__init__.py +++ b/arviz/plots/backends/__init__.py @@ -1,2 +1 @@ """ArviZ plotting backends.""" -from .bokeh import * diff --git a/arviz/plots/backends/matplotlib/__init__.py b/arviz/plots/backends/matplotlib/__init__.py index fa13544a7b..7bc076a0ff 100644 --- a/arviz/plots/backends/matplotlib/__init__.py +++ b/arviz/plots/backends/matplotlib/__init__.py @@ -1 +1,2 @@ """Matplotlib Plotting Backend.""" +from .mpl_distplot import _plot_dist_mpl \ No newline at end of file diff --git a/arviz/plots/distplot.py b/arviz/plots/distplot.py index eabdb2faef..ab5bbfda44 100644 --- a/arviz/plots/distplot.py +++ b/arviz/plots/distplot.py @@ -1,7 +1,11 @@ # pylint: disable=unexpected-keyword-arg """Plot distribution as histogram or kernel density estimates.""" +<<<<<<< HEAD from .backends import check_bokeh_version from .plot_utils import get_bins +======= +from .plot_utils import get_bins, get_plotting_method +>>>>>>> Add find plotting method def plot_dist( @@ -186,19 +190,6 @@ def plot_dist( **kwargs, ) - if backend is None or backend.lower() in ("mpl", "matplotlib"): - from .backends.matplotlib.mpl_distplot import _plot_dist_mpl - - ax = _plot_dist_mpl(**dist_plot_args) - elif backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_distplot import _plot_dist_bokeh - - dist_plot_args.pop("textsize") - dist_plot_args["show"] = show - ax = _plot_dist_bokeh(**dist_plot_args) - else: - raise NotImplementedError( - 'Backend {} not implemented. Use {{"matplotlib", "bokeh"}}'.format(backend) - ) + method = get_plotting_method("plot_dist", "distplot", "bokeh") + ax = method(dist_plot_args) return ax diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index c462547560..9a9bdb763f 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -1,6 +1,7 @@ """Utilities for plotting.""" import warnings from itertools import product, tee +import importlib import numpy as np import matplotlib.pyplot as plt @@ -11,6 +12,8 @@ from ..utils import conditional_jit from ..rcparams import rcParams +from . import backends + def make_2d(ary): """Convert any array into a 2d numpy array. @@ -622,3 +625,28 @@ def filter_plotters_list(plotters, plot_kind): ) return plotters[:max_plots] return plotters + + +def get_plotting_method(plot_name, plot_module, backend): + """Returns plotting function for correct backend""" + _backend = {"mpl":"matplotlib", "bokeh":"bokeh", "matplotlib":"matplotlib"} + + try: + backend = _backend[backend] + except KeyError: + raise KeyError("Backend {} is not implemented. Try backend in {}".format(backend, set(_backend.values))) + + if backend == "bokeh": + try: + import bokeh + assert bokeh.__version__ >= "1.4.0" + except (ImportError, AssertionError): + raise ImportError("'bokeh' backend needs Bokeh (1.4.0+) installed.") + + # module = importlib.import_module("arviz.plots.backends.bokeh.bokeh_distplot") + module = importlib.import_module("arviz.plots.backends.{backend}.{backend}_{plot_module}".format(backend=backend, plot_module=plot_module)) + # module = importlib.import_module("arviz.plots.backends.{backend}.{backend}_distplot".format(backend=backend)) + plotting_method = getattr(module, "_{plot_name}_{backend}".format(plot_name=plot_name, backend=backend)) + # plotting_method = getattr(module, "_{plot_name}_{backend") + + return plotting_method diff --git a/arviz/tests/test_plot_utils.py b/arviz/tests/test_plot_utils.py index 3c426a7331..1c5874605b 100644 --- a/arviz/tests/test_plot_utils.py +++ b/arviz/tests/test_plot_utils.py @@ -12,6 +12,7 @@ get_coords, filter_plotters_list, format_sig_figs, + get_plotting_method ) from ..rcparams import rc_context @@ -190,3 +191,11 @@ def test_filter_plotter_list_warning(): with pytest.warns(SyntaxWarning, match="test warning"): plotters_filtered = filter_plotters_list(plotters, "test warning") assert len(plotters_filtered) == 5 + + +def test_bokeh_import(): + """Tests that correct method is returned on bokeh import""" + method = get_plotting_method("plot_dist", "distplot", "bokeh") + + from arviz.plots.backends.bokeh.bokeh_distplot import _plot_dist_bokeh + assert method is _plot_dist_bokeh From cbfe9c2a1bb145f75a0a20181a8759ad9644fc06 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sat, 7 Dec 2019 17:10:16 -0800 Subject: [PATCH 02/41] Add a couple of fixes --- arviz/plots/backends/bokeh/bokeh_distplot.py | 1 + arviz/plots/distplot.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/arviz/plots/backends/bokeh/bokeh_distplot.py b/arviz/plots/backends/bokeh/bokeh_distplot.py index 45aa8ccef4..9e5d272e3f 100644 --- a/arviz/plots/backends/bokeh/bokeh_distplot.py +++ b/arviz/plots/backends/bokeh/bokeh_distplot.py @@ -29,6 +29,7 @@ def _plot_dist_bokeh( hist_kwargs=None, ax=None, show=True, + **kwargs ): if ax is None: diff --git a/arviz/plots/distplot.py b/arviz/plots/distplot.py index ab5bbfda44..ebd9ddeff7 100644 --- a/arviz/plots/distplot.py +++ b/arviz/plots/distplot.py @@ -190,6 +190,6 @@ def plot_dist( **kwargs, ) - method = get_plotting_method("plot_dist", "distplot", "bokeh") - ax = method(dist_plot_args) + method = get_plotting_method("plot_dist", "distplot", backend) + ax = method(**dist_plot_args) return ax From 2c6017c83acca16c9cf18da2ce22910c740e8b23 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sat, 7 Dec 2019 17:10:54 -0800 Subject: [PATCH 03/41] Remove unneeded lines --- arviz/plots/plot_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 9a9bdb763f..4982aa0892 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -643,10 +643,7 @@ def get_plotting_method(plot_name, plot_module, backend): except (ImportError, AssertionError): raise ImportError("'bokeh' backend needs Bokeh (1.4.0+) installed.") - # module = importlib.import_module("arviz.plots.backends.bokeh.bokeh_distplot") module = importlib.import_module("arviz.plots.backends.{backend}.{backend}_{plot_module}".format(backend=backend, plot_module=plot_module)) - # module = importlib.import_module("arviz.plots.backends.{backend}.{backend}_distplot".format(backend=backend)) plotting_method = getattr(module, "_{plot_name}_{backend}".format(plot_name=plot_name, backend=backend)) - # plotting_method = getattr(module, "_{plot_name}_{backend") return plotting_method From 27d5f317df35b84e8d485fbc31301616f691897d Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sat, 7 Dec 2019 17:11:06 -0800 Subject: [PATCH 04/41] Add black --- arviz/plots/backends/matplotlib/__init__.py | 2 +- arviz/plots/plot_utils.py | 17 +++++++++++++---- arviz/tests/test_plot_utils.py | 3 ++- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/arviz/plots/backends/matplotlib/__init__.py b/arviz/plots/backends/matplotlib/__init__.py index 7bc076a0ff..6818ed643c 100644 --- a/arviz/plots/backends/matplotlib/__init__.py +++ b/arviz/plots/backends/matplotlib/__init__.py @@ -1,2 +1,2 @@ """Matplotlib Plotting Backend.""" -from .mpl_distplot import _plot_dist_mpl \ No newline at end of file +from .mpl_distplot import _plot_dist_mpl diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 4982aa0892..9248484df7 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -629,21 +629,30 @@ def filter_plotters_list(plotters, plot_kind): def get_plotting_method(plot_name, plot_module, backend): """Returns plotting function for correct backend""" - _backend = {"mpl":"matplotlib", "bokeh":"bokeh", "matplotlib":"matplotlib"} + _backend = {"mpl": "matplotlib", "bokeh": "bokeh", "matplotlib": "matplotlib"} try: backend = _backend[backend] except KeyError: - raise KeyError("Backend {} is not implemented. Try backend in {}".format(backend, set(_backend.values))) + raise KeyError( + "Backend {} is not implemented. Try backend in {}".format(backend, set(_backend.values)) + ) if backend == "bokeh": try: import bokeh + assert bokeh.__version__ >= "1.4.0" except (ImportError, AssertionError): raise ImportError("'bokeh' backend needs Bokeh (1.4.0+) installed.") - module = importlib.import_module("arviz.plots.backends.{backend}.{backend}_{plot_module}".format(backend=backend, plot_module=plot_module)) - plotting_method = getattr(module, "_{plot_name}_{backend}".format(plot_name=plot_name, backend=backend)) + module = importlib.import_module( + "arviz.plots.backends.{backend}.{backend}_{plot_module}".format( + backend=backend, plot_module=plot_module + ) + ) + plotting_method = getattr( + module, "_{plot_name}_{backend}".format(plot_name=plot_name, backend=backend) + ) return plotting_method diff --git a/arviz/tests/test_plot_utils.py b/arviz/tests/test_plot_utils.py index 1c5874605b..1f887b78bf 100644 --- a/arviz/tests/test_plot_utils.py +++ b/arviz/tests/test_plot_utils.py @@ -12,7 +12,7 @@ get_coords, filter_plotters_list, format_sig_figs, - get_plotting_method + get_plotting_method, ) from ..rcparams import rc_context @@ -198,4 +198,5 @@ def test_bokeh_import(): method = get_plotting_method("plot_dist", "distplot", "bokeh") from arviz.plots.backends.bokeh.bokeh_distplot import _plot_dist_bokeh + assert method is _plot_dist_bokeh From 4066acb8d6e12f038e00435f0fca89e98a64a99a Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sat, 7 Dec 2019 17:58:38 -0800 Subject: [PATCH 05/41] Fix bug in method --- arviz/plots/plot_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 9248484df7..a18bb9137e 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -635,7 +635,9 @@ def get_plotting_method(plot_name, plot_module, backend): backend = _backend[backend] except KeyError: raise KeyError( - "Backend {} is not implemented. Try backend in {}".format(backend, set(_backend.values)) + "Backend {} is not implemented. Try backend in {}".format( + backend, set(_backend.values()) + ) ) if backend == "bokeh": From ad1140eafb4f101505bf812ab0fc0ac26f2d7527 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sat, 7 Dec 2019 18:00:01 -0800 Subject: [PATCH 06/41] Add none to backend routing --- arviz/plots/plot_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index a18bb9137e..67d05732fc 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -629,7 +629,12 @@ def filter_plotters_list(plotters, plot_kind): def get_plotting_method(plot_name, plot_module, backend): """Returns plotting function for correct backend""" - _backend = {"mpl": "matplotlib", "bokeh": "bokeh", "matplotlib": "matplotlib"} + _backend = { + "mpl": "matplotlib", + "bokeh": "bokeh", + "matplotlib": "matplotlib", + None: "matplotlib", + } try: backend = _backend[backend] From f5e4ca4571874c017fbb109ab6cf0c3ba89d59a5 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sun, 8 Dec 2019 19:32:33 -0800 Subject: [PATCH 07/41] Remove missed rebase issues --- arviz/plots/distplot.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/arviz/plots/distplot.py b/arviz/plots/distplot.py index ebd9ddeff7..b5a4da0c6d 100644 --- a/arviz/plots/distplot.py +++ b/arviz/plots/distplot.py @@ -1,11 +1,8 @@ # pylint: disable=unexpected-keyword-arg """Plot distribution as histogram or kernel density estimates.""" -<<<<<<< HEAD from .backends import check_bokeh_version from .plot_utils import get_bins -======= from .plot_utils import get_bins, get_plotting_method ->>>>>>> Add find plotting method def plot_dist( From cbb35a3087555b8450c5447ad531e8ec623a9eb7 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sun, 8 Dec 2019 19:33:06 -0800 Subject: [PATCH 08/41] Move version check to plot_utils --- arviz/plots/plot_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 67d05732fc..22da553623 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -3,17 +3,16 @@ from itertools import product, tee import importlib +import packaging import numpy as np import matplotlib.pyplot as plt import matplotlib as mpl import xarray as xr -from .backends import check_bokeh_version + from ..utils import conditional_jit from ..rcparams import rcParams -from . import backends - def make_2d(ary): """Convert any array into a 2d numpy array. @@ -214,7 +213,6 @@ def _create_axes_grid(length_plotters, rows, cols, backend=None, **kwargs): kwargs.setdefault("constrained_layout", True) if backend == "bokeh": - check_bokeh_version() from bokeh.plotting import figure bokeh_dpi = rcParams["plot.bokeh.figure.dpi"] @@ -648,8 +646,8 @@ def get_plotting_method(plot_name, plot_module, backend): if backend == "bokeh": try: import bokeh + assert packaging.version.parse(bokeh.__version__) >= packaging.version.parse("1.4.0") - assert bokeh.__version__ >= "1.4.0" except (ImportError, AssertionError): raise ImportError("'bokeh' backend needs Bokeh (1.4.0+) installed.") From 7ca488dee7ac88b48006fc15fd5a132abdb24454 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sun, 8 Dec 2019 19:40:08 -0800 Subject: [PATCH 09/41] Add show argument forwarding to distplot --- arviz/plots/distplot.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/arviz/plots/distplot.py b/arviz/plots/distplot.py index b5a4da0c6d..f1ef5ad702 100644 --- a/arviz/plots/distplot.py +++ b/arviz/plots/distplot.py @@ -184,6 +184,9 @@ def plot_dist( pcolormesh_kwargs=pcolormesh_kwargs, hist_kwargs=hist_kwargs, ax=ax, + + # TODO: Change this to be a backend kwarg + show=show, **kwargs, ) From 55b1f946be8074f3379771542fe3c482f6fdc144 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sun, 8 Dec 2019 19:41:16 -0800 Subject: [PATCH 10/41] Add temporary backend check for passing tests --- arviz/plots/backends/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/arviz/plots/backends/__init__.py b/arviz/plots/backends/__init__.py index 2857dbde0c..d3df238c0a 100644 --- a/arviz/plots/backends/__init__.py +++ b/arviz/plots/backends/__init__.py @@ -1 +1,4 @@ """ArviZ plotting backends.""" + +# TODO: Get rid of this line once check_bokeh_version is removed +from .bokeh import check_bokeh_version From 83a5f79466915a7e34e9ddfc28542b8e43ae7664 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sun, 8 Dec 2019 20:18:30 -0800 Subject: [PATCH 11/41] Add backend plot arg architecture --- arviz/plots/backends/bokeh/__init__.py | 2 ++ arviz/plots/backends/bokeh/bokeh_distplot.py | 3 +++ arviz/plots/distplot.py | 14 ++++++-------- arviz/plots/plot_utils.py | 19 +++++++++++++++++-- arviz/tests/test_plots_bokeh.py | 4 ++-- 5 files changed, 30 insertions(+), 12 deletions(-) diff --git a/arviz/plots/backends/bokeh/__init__.py b/arviz/plots/backends/bokeh/__init__.py index 293cc11b24..6c621d855c 100644 --- a/arviz/plots/backends/bokeh/__init__.py +++ b/arviz/plots/backends/bokeh/__init__.py @@ -2,6 +2,8 @@ """Bokeh Plotting Backend.""" import packaging +# Set plot generic bokeh keyword arg defaults if none provided +KWARG_DEFAULTS = {"show": True} def output_notebook(*args, **kwargs): """Wrap bokeh.plotting.output_notebook.""" diff --git a/arviz/plots/backends/bokeh/bokeh_distplot.py b/arviz/plots/backends/bokeh/bokeh_distplot.py index 9e5d272e3f..2345acef6d 100644 --- a/arviz/plots/backends/bokeh/bokeh_distplot.py +++ b/arviz/plots/backends/bokeh/bokeh_distplot.py @@ -28,6 +28,8 @@ def _plot_dist_bokeh( pcolormesh_kwargs=None, hist_kwargs=None, ax=None, + + # Backend Kwargs show=True, **kwargs ): @@ -76,6 +78,7 @@ def _plot_dist_bokeh( pcolormesh_kwargs=pcolormesh_kwargs, ax=ax, backend="bokeh", + # TODO: Revisit this when I refactor backend args for kde show=False, ) else: diff --git a/arviz/plots/distplot.py b/arviz/plots/distplot.py index f1ef5ad702..31dabbd6bd 100644 --- a/arviz/plots/distplot.py +++ b/arviz/plots/distplot.py @@ -28,7 +28,7 @@ def plot_dist( hist_kwargs=None, ax=None, backend=None, - show=True, + backend_kwargs=None, **kwargs ): """Plot distribution as histogram or kernel density estimates. @@ -90,8 +90,9 @@ def plot_dist( Matplotlib axes or bokeh figures. backend: str, optional Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". - show: bool, optional - If True, call bokeh.plotting.show. + backend_kwargs: bool, optional + These are kwargs specific to the backend being used. For additional documentation + check the plotting method of the backend Returns ------- @@ -184,12 +185,9 @@ def plot_dist( pcolormesh_kwargs=pcolormesh_kwargs, hist_kwargs=hist_kwargs, ax=ax, - - # TODO: Change this to be a backend kwarg - show=show, **kwargs, ) - method = get_plotting_method("plot_dist", "distplot", backend) - ax = method(**dist_plot_args) + method, backend_kwargs = get_plotting_method("plot_dist", "distplot", backend, backend_kwargs) + ax = method(**dist_plot_args, **backend_kwargs) return ax diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 22da553623..ac28a2c66c 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -625,7 +625,7 @@ def filter_plotters_list(plotters, plot_kind): return plotters -def get_plotting_method(plot_name, plot_module, backend): +def get_plotting_method(plot_name, plot_module, backend, user_backend_kwargs): """Returns plotting function for correct backend""" _backend = { "mpl": "matplotlib", @@ -651,13 +651,28 @@ def get_plotting_method(plot_name, plot_module, backend): except (ImportError, AssertionError): raise ImportError("'bokeh' backend needs Bokeh (1.4.0+) installed.") + # Perform import of plotting method + # TODO: Convert module import to top level for all plots module = importlib.import_module( "arviz.plots.backends.{backend}.{backend}_{plot_module}".format( backend=backend, plot_module=plot_module ) ) + plotting_method = getattr( module, "_{plot_name}_{backend}".format(plot_name=plot_name, backend=backend) ) - return plotting_method + # Get default backend args and combine with user provided values + default_backend_temp_module = importlib.import_module( + "arviz.plots.backends.{}".format(backend), + ) + + default_backend_kwargs = getattr(default_backend_temp_module, "KWARG_DEFAULTS") + + if user_backend_kwargs is None: + user_backend_kwargs = {} + + combined_backend_kwargs = {**default_backend_kwargs, **user_backend_kwargs} + + return plotting_method, combined_backend_kwargs diff --git a/arviz/tests/test_plots_bokeh.py b/arviz/tests/test_plots_bokeh.py index f6ad50a847..486fc24b9f 100644 --- a/arviz/tests/test_plots_bokeh.py +++ b/arviz/tests/test_plots_bokeh.py @@ -88,7 +88,7 @@ def get_ax(): ) def test_plot_density_float(models, kwargs): obj = [getattr(models, model_fit) for model_fit in ["model_1", "model_2"]] - axes = plot_density(obj, backend="bokeh", show=False, **kwargs) + axes = plot_density(obj, backend="bokeh", **kwargs) assert axes.shape[0] >= 6 assert axes.shape[0] >= 3 @@ -184,7 +184,7 @@ def test_plot_kde_cumulative(continuous_model, kwargs): @pytest.mark.parametrize("kwargs", [{"kind": "hist"}, {"kind": "kde"}]) def test_plot_dist(continuous_model, kwargs): - axes = plot_dist(continuous_model["x"], backend="bokeh", show=False, **kwargs) + axes = plot_dist(continuous_model["x"], backend="bokeh", backend_kwargs={"show": False}, **kwargs) assert axes From e3eb3639f7b0d0d34b74588672dba78865a1d080 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sun, 8 Dec 2019 20:35:14 -0800 Subject: [PATCH 12/41] Update distplot tests --- arviz/plots/backends/bokeh/__init__.py | 2 +- arviz/plots/backends/bokeh/bokeh_distplot.py | 1 + arviz/plots/plot_utils.py | 2 +- arviz/tests/test_plots_bokeh.py | 7 ++++++- 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/arviz/plots/backends/bokeh/__init__.py b/arviz/plots/backends/bokeh/__init__.py index 6c621d855c..bccab3c6dc 100644 --- a/arviz/plots/backends/bokeh/__init__.py +++ b/arviz/plots/backends/bokeh/__init__.py @@ -3,7 +3,7 @@ import packaging # Set plot generic bokeh keyword arg defaults if none provided -KWARG_DEFAULTS = {"show": True} +BACKEND_KWARG_DEFAULTS = {"show": True} def output_notebook(*args, **kwargs): """Wrap bokeh.plotting.output_notebook.""" diff --git a/arviz/plots/backends/bokeh/bokeh_distplot.py b/arviz/plots/backends/bokeh/bokeh_distplot.py index 2345acef6d..c61918f9c0 100644 --- a/arviz/plots/backends/bokeh/bokeh_distplot.py +++ b/arviz/plots/backends/bokeh/bokeh_distplot.py @@ -84,6 +84,7 @@ def _plot_dist_bokeh( else: raise TypeError('Invalid "kind":{}. Select from {{"auto","kde","hist"}}'.format(kind)) + # TODO: Temporary setting just to make sure tests work. This needs to be removed if show: bkp.show(ax, toolbar_location="above") return ax diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index ac28a2c66c..c87a4cd433 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -668,7 +668,7 @@ def get_plotting_method(plot_name, plot_module, backend, user_backend_kwargs): "arviz.plots.backends.{}".format(backend), ) - default_backend_kwargs = getattr(default_backend_temp_module, "KWARG_DEFAULTS") + default_backend_kwargs = getattr(default_backend_temp_module, "BACKEND_KWARG_DEFAULTS") if user_backend_kwargs is None: user_backend_kwargs = {} diff --git a/arviz/tests/test_plots_bokeh.py b/arviz/tests/test_plots_bokeh.py index 486fc24b9f..83de778d5c 100644 --- a/arviz/tests/test_plots_bokeh.py +++ b/arviz/tests/test_plots_bokeh.py @@ -74,6 +74,11 @@ def get_ax(): return ax +@pytest.fixture(scope="session") +def backend_kwargs(): + return {"show":False} + + @pytest.mark.parametrize( "kwargs", [ @@ -88,7 +93,7 @@ def get_ax(): ) def test_plot_density_float(models, kwargs): obj = [getattr(models, model_fit) for model_fit in ["model_1", "model_2"]] - axes = plot_density(obj, backend="bokeh", **kwargs) + axes = plot_density(obj, backend="bokeh", show=False, **kwargs) assert axes.shape[0] >= 6 assert axes.shape[0] >= 3 From 32a3671f46ec863d27e520aebb25c808ddb3bfcc Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sun, 8 Dec 2019 20:38:09 -0800 Subject: [PATCH 13/41] Change backend kwargs for distplot to fixture --- arviz/tests/test_plots_bokeh.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arviz/tests/test_plots_bokeh.py b/arviz/tests/test_plots_bokeh.py index 83de778d5c..951f43276f 100644 --- a/arviz/tests/test_plots_bokeh.py +++ b/arviz/tests/test_plots_bokeh.py @@ -188,8 +188,8 @@ def test_plot_kde_cumulative(continuous_model, kwargs): @pytest.mark.parametrize("kwargs", [{"kind": "hist"}, {"kind": "kde"}]) -def test_plot_dist(continuous_model, kwargs): - axes = plot_dist(continuous_model["x"], backend="bokeh", backend_kwargs={"show": False}, **kwargs) +def test_plot_dist(continuous_model, kwargs, backend_kwargs): + axes = plot_dist(continuous_model["x"], backend="bokeh", backend_kwargs=backend_kwargs, **kwargs) assert axes From 31ce94bf62c3af86670073dd2df274345052cb54 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sun, 8 Dec 2019 23:31:54 -0800 Subject: [PATCH 14/41] Change backend parameter order --- arviz/plots/backends/bokeh/bokeh_distplot.py | 12 ++++++++---- arviz/plots/distplot.py | 5 +++-- arviz/plots/plot_utils.py | 14 +------------- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/arviz/plots/backends/bokeh/bokeh_distplot.py b/arviz/plots/backends/bokeh/bokeh_distplot.py index c61918f9c0..4a7befaea7 100644 --- a/arviz/plots/backends/bokeh/bokeh_distplot.py +++ b/arviz/plots/backends/bokeh/bokeh_distplot.py @@ -5,6 +5,7 @@ from ...kdeplot import plot_kde from ...plot_utils import get_bins from ....rcparams import rcParams +from . import BACKEND_KWARG_DEFAULTS def _plot_dist_bokeh( @@ -28,12 +29,15 @@ def _plot_dist_bokeh( pcolormesh_kwargs=None, hist_kwargs=None, ax=None, - - # Backend Kwargs - show=True, + backend_kwargs=None, **kwargs ): + if backend_kwargs is None: + backend_kwargs = {} + + backend_kwargs = {**BACKEND_KWARG_DEFAULTS, **backend_kwargs} + if ax is None: tools = rcParams["plot.bokeh.tools"] output_backend = rcParams["plot.bokeh.output_backend"] @@ -85,7 +89,7 @@ def _plot_dist_bokeh( raise TypeError('Invalid "kind":{}. Select from {{"auto","kde","hist"}}'.format(kind)) # TODO: Temporary setting just to make sure tests work. This needs to be removed - if show: + if backend_kwargs["show"] is True: bkp.show(ax, toolbar_location="above") return ax diff --git a/arviz/plots/distplot.py b/arviz/plots/distplot.py index 31dabbd6bd..a0667e909d 100644 --- a/arviz/plots/distplot.py +++ b/arviz/plots/distplot.py @@ -185,9 +185,10 @@ def plot_dist( pcolormesh_kwargs=pcolormesh_kwargs, hist_kwargs=hist_kwargs, ax=ax, + backend_kwargs=backend_kwargs, **kwargs, ) - method, backend_kwargs = get_plotting_method("plot_dist", "distplot", backend, backend_kwargs) - ax = method(**dist_plot_args, **backend_kwargs) + method = get_plotting_method("plot_dist", "distplot", backend, backend_kwargs) + ax = method(**dist_plot_args) return ax diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index c87a4cd433..13945c4d22 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -663,16 +663,4 @@ def get_plotting_method(plot_name, plot_module, backend, user_backend_kwargs): module, "_{plot_name}_{backend}".format(plot_name=plot_name, backend=backend) ) - # Get default backend args and combine with user provided values - default_backend_temp_module = importlib.import_module( - "arviz.plots.backends.{}".format(backend), - ) - - default_backend_kwargs = getattr(default_backend_temp_module, "BACKEND_KWARG_DEFAULTS") - - if user_backend_kwargs is None: - user_backend_kwargs = {} - - combined_backend_kwargs = {**default_backend_kwargs, **user_backend_kwargs} - - return plotting_method, combined_backend_kwargs + return plotting_method From 9a2ddba24c6939782bf512cb068ebe5254491541 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Mon, 9 Dec 2019 07:36:39 -0800 Subject: [PATCH 15/41] WIP Commit --- arviz/plots/traceplot.py | 77 ++++++++++++--------------------- arviz/tests/test_plots_bokeh.py | 4 +- 2 files changed, 29 insertions(+), 52 deletions(-) diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index 7eb038c1e0..b25a1ba5a9 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -1,5 +1,5 @@ """Plot kde or histograms and values from MCMC samples.""" -from .backends import check_bokeh_version +from .plot_utils import get_plotting_method def plot_trace( @@ -20,7 +20,7 @@ def plot_trace( hist_kwargs=None, trace_kwargs=None, backend=None, - show=True, + backend_kwargs=None, **kwargs ): """Plot distribution (histogram or kernel density estimates) and sampled values. @@ -113,52 +113,29 @@ def plot_trace( >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines) """ - if backend is None or backend.lower() in ("mpl", "matplotlib"): - from .backends.matplotlib.mpl_traceplot import _plot_trace_mpl - - axes = _plot_trace_mpl( - data, - var_names=var_names, - coords=coords, - divergences=divergences, - figsize=figsize, - textsize=textsize, - rug=rug, - lines=lines, - compact=compact, - combined=combined, - legend=legend, - plot_kwargs=plot_kwargs, - fill_kwargs=fill_kwargs, - rug_kwargs=rug_kwargs, - hist_kwargs=hist_kwargs, - trace_kwargs=trace_kwargs, - ) - elif backend.lower() == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_traceplot import _plot_trace_bokeh - - axes = _plot_trace_bokeh( - data, - var_names=var_names, - coords=coords, - divergences=divergences, - figsize=figsize, - rug=rug, - lines=lines, - compact=compact, - combined=combined, - legend=legend, - plot_kwargs=plot_kwargs, - fill_kwargs=fill_kwargs, - rug_kwargs=rug_kwargs, - hist_kwargs=hist_kwargs, - trace_kwargs=trace_kwargs, - show=show, - **kwargs, - ) - else: - raise NotImplementedError( - 'Backend {} not implemented. Use {{"matplotlib", "bokeh"}}'.format(backend) - ) + + # TODO: Check if this can be further simplified + trace_plot_args = dict( + data=data, + var_names = var_names, + coords = coords, + divergences = divergences, + figsize = figsize, + # textsize = textsize, + rug = rug, + lines = lines, + compact = compact, + combined = combined, + legend = legend, + plot_kwargs = plot_kwargs, + fill_kwargs = fill_kwargs, + rug_kwargs = rug_kwargs, + hist_kwargs = hist_kwargs, + trace_kwargs = trace_kwargs, + ) + + method, backend_kwargs = get_plotting_method("plot_trace", "traceplot", backend, backend_kwargs) + axes = method(**trace_plot_args, **backend_kwargs) + return axes + diff --git a/arviz/tests/test_plots_bokeh.py b/arviz/tests/test_plots_bokeh.py index 951f43276f..4743431de8 100644 --- a/arviz/tests/test_plots_bokeh.py +++ b/arviz/tests/test_plots_bokeh.py @@ -135,8 +135,8 @@ def test_plot_density_bad_kwargs(models): {"lines": [("mu", {}, 8)]}, ], ) -def test_plot_trace(models, kwargs): - axes = plot_trace(models.model_1, backend="bokeh", show=False, **kwargs) +def test_plot_trace(models, kwargs, backend_kwargs): + axes = plot_trace(models.model_1, backend="bokeh", backend_kwargs=backend_kwargs, **kwargs) assert axes.shape From 5ec7a75713560194c05b4b21eae42a79311c18b1 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Mon, 9 Dec 2019 09:36:49 -0800 Subject: [PATCH 16/41] WIP Commit for debugging --- arviz/plots/backends/bokeh/bokeh_traceplot.py | 127 +++++------------- .../backends/matplotlib/mpl_traceplot.py | 25 ---- arviz/plots/traceplot.py | 109 +++++++++++++-- 3 files changed, 130 insertions(+), 131 deletions(-) diff --git a/arviz/plots/backends/bokeh/bokeh_traceplot.py b/arviz/plots/backends/bokeh/bokeh_traceplot.py index 3758005cb9..228e4b55fe 100644 --- a/arviz/plots/backends/bokeh/bokeh_traceplot.py +++ b/arviz/plots/backends/bokeh/bokeh_traceplot.py @@ -1,122 +1,60 @@ # pylint: disable=all """Bokeh Traceplot.""" +from typing import Dict from collections.abc import Iterable -from itertools import cycle -import warnings import bokeh.plotting as bkp from bokeh.models import ColumnDataSource, Dash, Span from bokeh.models.annotations import Title from bokeh.layouts import gridplot -import matplotlib.pyplot as plt import numpy as np - -from ....data import convert_to_dataset +from . import BACKEND_KWARG_DEFAULTS from ...distplot import plot_dist -from ...plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_coords +from ...plot_utils import xarray_var_iter, make_label from ....rcparams import rcParams -from ....utils import _var_names + +BACKEND_KWARG_DEFAULTS["tools"] = rcParams["plot.bokeh.tools"] +BACKEND_KWARG_DEFAULTS["output_backend"] = rcParams["plot.bokeh.output_backend"] +BACKEND_KWARG_DEFAULTS["output_backend"] = rcParams["plot.bokeh.output_backend"] def _plot_trace_bokeh( data, - var_names=None, - coords=None, - divergences="bottom", - figsize=None, - rug=False, - lines=None, - compact=False, - combined=False, - legend=False, - plot_kwargs=None, - fill_kwargs=None, - rug_kwargs=None, - hist_kwargs=None, - trace_kwargs=None, - backend_kwargs=None, - show=True, + var_names, + divergences, + figsize, + rug, + lines, + combined, + legend, + plot_kwargs: [Dict], + fill_kwargs: [Dict], + rug_kwargs: [Dict], + hist_kwargs: [Dict], + trace_kwargs: [Dict], + plotters, + divergence_data, + colors, + backend_kwargs: [Dict] ): - if divergences: - try: - divergence_data = convert_to_dataset(data, group="sample_stats").diverging - except (ValueError, AttributeError): # No sample_stats, or no `.diverging` - divergences = False - if coords is None: - coords = {} - - data = get_coords(convert_to_dataset(data, group="posterior"), coords) - var_names = _var_names(var_names, data) - - if divergences: - divergence_data = get_coords( - divergence_data, {k: v for k, v in coords.items() if k in ("chain", "draw")} - ) - - if lines is None: - lines = () - - num_colors = len(data.chain) + 1 if combined else len(data.chain) - colors = [ - prop - for _, prop in zip( - range(num_colors), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"]) - ) - ] - - if compact: - skip_dims = set(data.dims) - {"chain", "draw"} - else: - skip_dims = set() - - plotters = list(xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims)) - max_plots = rcParams["plot.max_subplots"] - max_plots = len(plotters) if max_plots is None else max_plots - if len(plotters) > max_plots: - warnings.warn( - "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number " - "of variables to plot ({len_plotters}), generating only {max_plots} " - "plots".format(max_plots=max_plots, len_plotters=len(plotters)), - SyntaxWarning, - ) - plotters = plotters[:max_plots] - - if figsize is None: - figsize = (12, len(plotters) * 2) - - if trace_kwargs is None: - trace_kwargs = {} - - trace_kwargs.setdefault("alpha", 0.35) - - if hist_kwargs is None: - hist_kwargs = {} - if plot_kwargs is None: - plot_kwargs = {} - if fill_kwargs is None: - fill_kwargs = {} - if rug_kwargs is None: - rug_kwargs = {} - - hist_kwargs.setdefault("alpha", 0.35) - - figsize, _, _, _, linewidth, _ = _scale_fig_size(figsize, 10, rows=len(plotters), cols=2) - - trace_kwargs.setdefault("line_width", linewidth) - plot_kwargs.setdefault("line_width", linewidth) + # If divergences are plotted they must be provided + assert divergences is not False and divergence_data is not None + # Set plot default backend kwargs if backend_kwargs is None: - backend_kwargs = dict() + backend_kwargs = {} + + backend_kwargs = {**BACKEND_KWARG_DEFAULTS, **backend_kwargs} - backend_kwargs.setdefault("tools", rcParams["plot.bokeh.tools"]) - backend_kwargs.setdefault("output_backend", rcParams["plot.bokeh.output_backend"]) backend_kwargs.setdefault( "height", int(figsize[1] * rcParams["plot.bokeh.figure.dpi"] // len(plotters)) ) backend_kwargs.setdefault("width", int(figsize[0] * rcParams["plot.bokeh.figure.dpi"] // 2)) + # Temporary + backend_kwargs.pop("show") axes = [] for i in range(len(plotters)): if i != 0: @@ -303,7 +241,8 @@ def _plot_trace_bokeh( axes[idx, 0].add_glyph(tmp_cds, glyph_density) axes[idx, 1].add_glyph(tmp_cds, glyph_trace) - if show: + # if backend_kwargs["show"]: + if True: grid = gridplot([list(item) for item in axes], toolbar_location="above") bkp.show(grid) diff --git a/arviz/plots/backends/matplotlib/mpl_traceplot.py b/arviz/plots/backends/matplotlib/mpl_traceplot.py index 9652143021..d5b03ab8e6 100644 --- a/arviz/plots/backends/matplotlib/mpl_traceplot.py +++ b/arviz/plots/backends/matplotlib/mpl_traceplot.py @@ -117,33 +117,8 @@ def _plot_trace_mpl( >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines) """ - if divergences: - try: - divergence_data = convert_to_dataset(data, group="sample_stats").diverging - except (ValueError, AttributeError): # No sample_stats, or no `.diverging` - divergences = False - if coords is None: - coords = {} - data = get_coords(convert_to_dataset(data, group="posterior"), coords) - var_names = _var_names(var_names, data) - - if divergences: - divergence_data = get_coords( - divergence_data, {k: v for k, v in coords.items() if k in ("chain", "draw")} - ) - - if lines is None: - lines = () - - num_colors = len(data.chain) + 1 if combined else len(data.chain) - colors = [ - prop - for _, prop in zip( - range(num_colors), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"]) - ) - ] if compact: skip_dims = set(data.dims) - {"chain", "draw"} diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index b25a1ba5a9..9bd740e297 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -1,5 +1,13 @@ """Plot kde or histograms and values from MCMC samples.""" -from .plot_utils import get_plotting_method +from itertools import cycle +import warnings + +import matplotlib.pyplot as plt + +from .plot_utils import get_plotting_method, get_coords, xarray_var_iter, _scale_fig_size +from ..data import convert_to_dataset +from ..utils import _var_names +from ..rcparams import rcParams def plot_trace( @@ -69,8 +77,6 @@ def plot_trace( Extra keyword arguments passed to `plt.plot` backend: str, optional Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". - show: bool, optional - If True, call bokeh.plotting.show. Returns ------- @@ -114,28 +120,107 @@ def plot_trace( """ + # TODO: This can be simplified somehow I feel like + if divergences: + try: + divergence_data = convert_to_dataset(data, group="sample_stats").diverging + except (ValueError, AttributeError): # No sample_stats, or no `.diverging` + divergences=False + + if coords is None: + coords = {} + + if divergences: + divergence_data = get_coords( + divergence_data, {k: v for k, v in coords.items() if k in ("chain", "draw")} + ) + else: + divergence_data = False + + data = get_coords(convert_to_dataset(data, group="posterior"), coords) + var_names = _var_names(var_names, data) + + if lines is None: + lines = () + + num_colors = len(data.chain) + 1 if combined else len(data.chain) + + # TODO: matplotlib is always required by arviz. Can we get rid of it? + colors = [ + prop + for _, prop in zip( + range(num_colors), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"]) + ) + ] + + if compact: + skip_dims = set(data.dims) - {"chain", "draw"} + else: + skip_dims = set() + + plotters = list(xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims)) + max_plots = rcParams["plot.max_subplots"] + max_plots = len(plotters) if max_plots is None else max_plots + if len(plotters) > max_plots: + warnings.warn( + "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number " + "of variables to plot ({len_plotters}), generating only {max_plots} " + "plots".format(max_plots=max_plots, len_plotters=len(plotters)), + SyntaxWarning, + ) + plotters = plotters[:max_plots] + + if figsize is None: + figsize = (12, len(plotters) * 2) + + if trace_kwargs is None: + trace_kwargs = {} + trace_kwargs.setdefault("alpha", 0.35) + + if hist_kwargs is None: + hist_kwargs = {} + hist_kwargs.setdefault("alpha", 0.35) + + if plot_kwargs is None: + plot_kwargs = {} + if fill_kwargs is None: + fill_kwargs = {} + if rug_kwargs is None: + rug_kwargs = {} + + figsize, _, _, _, linewidth, _ = _scale_fig_size(figsize, 10, rows=len(plotters), cols=2) + trace_kwargs.setdefault("line_width", linewidth) + plot_kwargs.setdefault("line_width", linewidth) + # TODO: Check if this can be further simplified trace_plot_args = dict( + # User Kwargs data=data, - var_names = var_names, - coords = coords, + var_names=var_names, + # coords = coords, divergences = divergences, figsize = figsize, - # textsize = textsize, - rug = rug, + rug=rug, lines = lines, - compact = compact, - combined = combined, - legend = legend, plot_kwargs = plot_kwargs, fill_kwargs = fill_kwargs, rug_kwargs = rug_kwargs, hist_kwargs = hist_kwargs, trace_kwargs = trace_kwargs, + # compact = compact, + combined = combined, + legend = legend, + + # Generated kwargs + divergence_data = divergence_data, + # skip_dims=skip_dims, + plotters=plotters, + colors=colors, + backend_kwargs=backend_kwargs ) - method, backend_kwargs = get_plotting_method("plot_trace", "traceplot", backend, backend_kwargs) - axes = method(**trace_plot_args, **backend_kwargs) + method = get_plotting_method("plot_trace", "traceplot", backend, backend_kwargs) + axes = method(**trace_plot_args) return axes From 7c83af43435f128d53a149e4122dbb317f857554 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Mon, 9 Dec 2019 22:40:27 -0800 Subject: [PATCH 17/41] Fix traceplot and jointplot --- arviz/plots/backends/bokeh/bokeh_jointplot.py | 3 ++- arviz/plots/backends/bokeh/bokeh_traceplot.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/arviz/plots/backends/bokeh/bokeh_jointplot.py b/arviz/plots/backends/bokeh/bokeh_jointplot.py index a223ee3525..074ce0ca6f 100644 --- a/arviz/plots/backends/bokeh/bokeh_jointplot.py +++ b/arviz/plots/backends/bokeh/bokeh_jointplot.py @@ -92,8 +92,9 @@ def _plot_joint( rotated=rotate, ax=ax_, backend="bokeh", - show=False, + backend_kwargs={"show":False}, **marginal_kwargs + ) if show: diff --git a/arviz/plots/backends/bokeh/bokeh_traceplot.py b/arviz/plots/backends/bokeh/bokeh_traceplot.py index 228e4b55fe..187260bd17 100644 --- a/arviz/plots/backends/bokeh/bokeh_traceplot.py +++ b/arviz/plots/backends/bokeh/bokeh_traceplot.py @@ -16,7 +16,6 @@ BACKEND_KWARG_DEFAULTS["tools"] = rcParams["plot.bokeh.tools"] BACKEND_KWARG_DEFAULTS["output_backend"] = rcParams["plot.bokeh.output_backend"] -BACKEND_KWARG_DEFAULTS["output_backend"] = rcParams["plot.bokeh.output_backend"] def _plot_trace_bokeh( @@ -40,7 +39,8 @@ def _plot_trace_bokeh( ): # If divergences are plotted they must be provided - assert divergences is not False and divergence_data is not None + if divergences is not False: + assert divergence_data is not None # Set plot default backend kwargs if backend_kwargs is None: @@ -53,8 +53,9 @@ def _plot_trace_bokeh( ) backend_kwargs.setdefault("width", int(figsize[0] * rcParams["plot.bokeh.figure.dpi"] // 2)) - # Temporary - backend_kwargs.pop("show") + # Used near end for whether to show plot or not, can't be passed to bkp.figure + show = backend_kwargs.pop("show") + axes = [] for i in range(len(plotters)): if i != 0: @@ -241,8 +242,7 @@ def _plot_trace_bokeh( axes[idx, 0].add_glyph(tmp_cds, glyph_density) axes[idx, 1].add_glyph(tmp_cds, glyph_trace) - # if backend_kwargs["show"]: - if True: + if show is True: grid = gridplot([list(item) for item in axes], toolbar_location="above") bkp.show(grid) @@ -297,7 +297,7 @@ def _plot_chains_bokeh( fill_kwargs=fill_kwargs, rug_kwargs=rug_kwargs, backend="bokeh", - show=False, + backend_kwargs={"show":False} ) if combined: @@ -314,5 +314,5 @@ def _plot_chains_bokeh( fill_kwargs=fill_kwargs, rug_kwargs=rug_kwargs, backend="bokeh", - show=False, + backend_kwargs={"show":False} ) From 1b9c9e872ca9055e9f8a051e1c5c8f1bf953f2cd Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Mon, 9 Dec 2019 22:45:29 -0800 Subject: [PATCH 18/41] Make error message more verbose --- arviz/plots/plot_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 13945c4d22..baff4316d3 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -649,7 +649,8 @@ def get_plotting_method(plot_name, plot_module, backend, user_backend_kwargs): assert packaging.version.parse(bokeh.__version__) >= packaging.version.parse("1.4.0") except (ImportError, AssertionError): - raise ImportError("'bokeh' backend needs Bokeh (1.4.0+) installed.") + raise ImportError("'bokeh' backend needs Bokeh (1.4.0+) installed." + " Please upgrade or install") # Perform import of plotting method # TODO: Convert module import to top level for all plots From dea511d2b3cd6ec48ab509e81e0a362c03b558a5 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 11 Dec 2019 07:29:25 -0800 Subject: [PATCH 19/41] Rearrange bokeh imports --- arviz/plots/backends/bokeh/__init__.py | 23 +++++++++++++++++++ ...{bokeh_autocorrplot.py => autocorrplot.py} | 8 +++---- .../{bokeh_compareplot.py => compareplot.py} | 2 +- .../{bokeh_densityplot.py => densityplot.py} | 8 +++---- .../bokeh/{bokeh_distplot.py => distplot.py} | 4 ++-- .../bokeh/{bokeh_elpdplot.py => elpdplot.py} | 7 +++--- .../{bokeh_energyplot.py => energyplot.py} | 4 ++-- .../bokeh/{bokeh_essplot.py => essplot.py} | 10 +++----- .../{bokeh_forestplot.py => forestplot.py} | 10 ++++---- .../bokeh/{bokeh_hpdplot.py => hpdplot.py} | 4 ++-- .../{bokeh_jointplot.py => jointplot.py} | 4 ++-- .../bokeh/{bokeh_kdeplot.py => kdeplot.py} | 8 +++---- .../bokeh/{bokeh_khatplot.py => khatplot.py} | 5 ++-- .../{bokeh_loopitplot.py => loopitplot.py} | 4 ++-- .../bokeh/{bokeh_mcseplot.py => mcseplot.py} | 6 ++--- .../bokeh/{bokeh_pairplot.py => pairplot.py} | 5 ++-- ...{bokeh_parallelplot.py => parallelplot.py} | 4 ++-- ...okeh_posteriorplot.py => posteriorplot.py} | 9 ++++---- .../bokeh/{bokeh_ppcplot.py => ppcplot.py} | 2 +- .../bokeh/{bokeh_rankplot.py => rankplot.py} | 8 +++---- .../{bokeh_traceplot.py => traceplot.py} | 8 +++---- .../{bokeh_violinplot.py => violinplot.py} | 4 ++-- 22 files changed, 83 insertions(+), 64 deletions(-) rename arviz/plots/backends/bokeh/{bokeh_autocorrplot.py => autocorrplot.py} (98%) rename arviz/plots/backends/bokeh/{bokeh_compareplot.py => compareplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_densityplot.py => densityplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_distplot.py => distplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_elpdplot.py => elpdplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_energyplot.py => energyplot.py} (97%) rename arviz/plots/backends/bokeh/{bokeh_essplot.py => essplot.py} (97%) rename arviz/plots/backends/bokeh/{bokeh_forestplot.py => forestplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_hpdplot.py => hpdplot.py} (95%) rename arviz/plots/backends/bokeh/{bokeh_jointplot.py => jointplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_kdeplot.py => kdeplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_khatplot.py => khatplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_loopitplot.py => loopitplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_mcseplot.py => mcseplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_pairplot.py => pairplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_parallelplot.py => parallelplot.py} (94%) rename arviz/plots/backends/bokeh/{bokeh_posteriorplot.py => posteriorplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_ppcplot.py => ppcplot.py} (100%) rename arviz/plots/backends/bokeh/{bokeh_rankplot.py => rankplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_traceplot.py => traceplot.py} (99%) rename arviz/plots/backends/bokeh/{bokeh_violinplot.py => violinplot.py} (99%) diff --git a/arviz/plots/backends/bokeh/__init__.py b/arviz/plots/backends/bokeh/__init__.py index bccab3c6dc..b886d1a0bc 100644 --- a/arviz/plots/backends/bokeh/__init__.py +++ b/arviz/plots/backends/bokeh/__init__.py @@ -2,9 +2,32 @@ """Bokeh Plotting Backend.""" import packaging +from .autocorrplot import plot_autocorr +from .compareplot import plot_compare +from .densityplot import plot_density +from .distplot import plot_dist +from .elpdplot import plot_elpd +from .energyplot import plot_energy +from .essplot import plot_ess +from .forestplot import plot_forest +from .hpdplot import plot_hpdplot +from .jointplot import plot_joint +from .kdeplot import plot_kde +from .khatplot import plot_khat +from .loopitplot import plot_loo_pit +from .mcseplot import plot_mcse +from .pairplot import plot_pair +from .parallelplot import plot_parallel +from .posteriorplot import plot_posterior +from .rankplot import plot_rank +from .traceplot import plot_trace +from .violinplot import plot_violin + # Set plot generic bokeh keyword arg defaults if none provided BACKEND_KWARG_DEFAULTS = {"show": True} + + def output_notebook(*args, **kwargs): """Wrap bokeh.plotting.output_notebook.""" import bokeh.plotting as bkp diff --git a/arviz/plots/backends/bokeh/bokeh_autocorrplot.py b/arviz/plots/backends/bokeh/autocorrplot.py similarity index 98% rename from arviz/plots/backends/bokeh/bokeh_autocorrplot.py rename to arviz/plots/backends/bokeh/autocorrplot.py index 69616777c2..4c2acbb45b 100644 --- a/arviz/plots/backends/bokeh/bokeh_autocorrplot.py +++ b/arviz/plots/backends/bokeh/autocorrplot.py @@ -1,14 +1,14 @@ """Bokeh Autocorrplot.""" -import numpy as np import bokeh.plotting as bkp -from bokeh.models.annotations import Title +import numpy as np from bokeh.layouts import gridplot +from bokeh.models.annotations import Title -from ....stats import autocorr from ...plot_utils import make_label +from ....stats import autocorr -def _plot_autocorr( +def plot_autocorr( axes, plotters, max_lag, line_width, combined=False, show=True, ): for (var_name, selection, x), ax_ in zip(plotters, axes.flatten()): diff --git a/arviz/plots/backends/bokeh/bokeh_compareplot.py b/arviz/plots/backends/bokeh/compareplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_compareplot.py rename to arviz/plots/backends/bokeh/compareplot.py index 6114e34b1f..9d0ce2ff90 100644 --- a/arviz/plots/backends/bokeh/bokeh_compareplot.py +++ b/arviz/plots/backends/bokeh/compareplot.py @@ -5,7 +5,7 @@ from ....rcparams import rcParams -def _compareplot( +def plot_compare( ax, comp_df, figsize, diff --git a/arviz/plots/backends/bokeh/bokeh_densityplot.py b/arviz/plots/backends/bokeh/densityplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_densityplot.py rename to arviz/plots/backends/bokeh/densityplot.py index 413a0e27aa..e1f579eabf 100644 --- a/arviz/plots/backends/bokeh/bokeh_densityplot.py +++ b/arviz/plots/backends/bokeh/densityplot.py @@ -1,16 +1,16 @@ """Bokeh Densityplot.""" import bokeh.plotting as bkp -from bokeh.models.annotations import Title -from bokeh.layouts import gridplot import numpy as np +from bokeh.layouts import gridplot +from bokeh.models.annotations import Title -from ....stats import hpd from ...kdeplot import _fast_kde from ...plot_utils import make_label +from ....stats import hpd from ....stats.stats_utils import histogram -def _plot_density( +def plot_density( ax, all_labels, to_plot, diff --git a/arviz/plots/backends/bokeh/bokeh_distplot.py b/arviz/plots/backends/bokeh/distplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_distplot.py rename to arviz/plots/backends/bokeh/distplot.py index 4a7befaea7..32ba626764 100644 --- a/arviz/plots/backends/bokeh/bokeh_distplot.py +++ b/arviz/plots/backends/bokeh/distplot.py @@ -2,13 +2,13 @@ import bokeh.plotting as bkp import numpy as np +from . import BACKEND_KWARG_DEFAULTS from ...kdeplot import plot_kde from ...plot_utils import get_bins from ....rcparams import rcParams -from . import BACKEND_KWARG_DEFAULTS -def _plot_dist_bokeh( +def plot_dist( values, values2=None, color="C0", diff --git a/arviz/plots/backends/bokeh/bokeh_elpdplot.py b/arviz/plots/backends/bokeh/elpdplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_elpdplot.py rename to arviz/plots/backends/bokeh/elpdplot.py index fcd6723cd4..e535c4235a 100644 --- a/arviz/plots/backends/bokeh/bokeh_elpdplot.py +++ b/arviz/plots/backends/bokeh/elpdplot.py @@ -2,16 +2,15 @@ import warnings import bokeh.plotting as bkp -from bokeh.models.annotations import Title -from bokeh.layouts import gridplot import numpy as np - +from bokeh.layouts import gridplot +from bokeh.models.annotations import Title from ...plot_utils import _scale_fig_size from ....rcparams import rcParams -def _plot_elpd( +def plot_elpd( ax, models, pointwise_data, diff --git a/arviz/plots/backends/bokeh/bokeh_energyplot.py b/arviz/plots/backends/bokeh/energyplot.py similarity index 97% rename from arviz/plots/backends/bokeh/bokeh_energyplot.py rename to arviz/plots/backends/bokeh/energyplot.py index 630827f02f..3c62362ee5 100644 --- a/arviz/plots/backends/bokeh/bokeh_energyplot.py +++ b/arviz/plots/backends/bokeh/energyplot.py @@ -2,13 +2,13 @@ import bokeh.plotting as bkp from bokeh.models import Label +from .distplot import _histplot_bokeh_op from ...kdeplot import plot_kde -from .bokeh_distplot import _histplot_bokeh_op from ....rcparams import rcParams from ....stats import bfmi as e_bfmi -def _plot_energy( +def plot_energy( ax, series, energy, kind, bfmi, figsize, line_width, fill_kwargs, plot_kwargs, bw, legend, show, ): if ax is None: diff --git a/arviz/plots/backends/bokeh/bokeh_essplot.py b/arviz/plots/backends/bokeh/essplot.py similarity index 97% rename from arviz/plots/backends/bokeh/bokeh_essplot.py rename to arviz/plots/backends/bokeh/essplot.py index 62635de9fa..cd4a375412 100644 --- a/arviz/plots/backends/bokeh/bokeh_essplot.py +++ b/arviz/plots/backends/bokeh/essplot.py @@ -1,23 +1,19 @@ # pylint: disable=all """Bokeh ESS plots.""" import bokeh.plotting as bkp +import numpy as np +from bokeh.layouts import gridplot from bokeh.models import Dash, Span, ColumnDataSource from bokeh.models.annotations import Title -from bokeh.layouts import gridplot -import numpy as np from scipy.stats import rankdata - from ...plot_utils import ( make_label, _create_axes_grid, - get_coords, - filter_plotters_list, ) -from ....rcparams import rcParams -def _plot_ess( +def plot_ess( ax, plotters, xdata, diff --git a/arviz/plots/backends/bokeh/bokeh_forestplot.py b/arviz/plots/backends/bokeh/forestplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_forestplot.py rename to arviz/plots/backends/bokeh/forestplot.py index 999c460b82..8d38930401 100644 --- a/arviz/plots/backends/bokeh/bokeh_forestplot.py +++ b/arviz/plots/backends/bokeh/forestplot.py @@ -3,16 +3,16 @@ from collections import defaultdict, OrderedDict from itertools import cycle, tee -import numpy as np import bokeh.plotting as bkp +import matplotlib.pyplot as plt +import numpy as np +from bokeh.layouts import gridplot from bokeh.models import Band, ColumnDataSource from bokeh.models.annotations import Title from bokeh.models.tickers import FixedTicker -from bokeh.layouts import gridplot -import matplotlib.pyplot as plt -from ...plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins from ...kdeplot import _fast_kde +from ...plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins from ....rcparams import rcParams from ....stats import hpd from ....stats.diagnostics import _ess, _rhat @@ -27,7 +27,7 @@ def pairwise(iterable): return zip(first, second) -def _plot_forest( +def plot_forest( ax, datasets, var_names, diff --git a/arviz/plots/backends/bokeh/bokeh_hpdplot.py b/arviz/plots/backends/bokeh/hpdplot.py similarity index 95% rename from arviz/plots/backends/bokeh/bokeh_hpdplot.py rename to arviz/plots/backends/bokeh/hpdplot.py index 76c1e859b7..e9485ea3b8 100644 --- a/arviz/plots/backends/bokeh/bokeh_hpdplot.py +++ b/arviz/plots/backends/bokeh/hpdplot.py @@ -2,13 +2,13 @@ from itertools import cycle import bokeh.plotting as bkp -from matplotlib.pyplot import rcParams as mpl_rcParams import numpy as np +from matplotlib.pyplot import rcParams as mpl_rcParams from ....rcparams import rcParams -def _plot_hpdplot(ax, x_data, y_data, plot_kwargs, fill_kwargs, show): +def plot_hpdplot(ax, x_data, y_data, plot_kwargs, fill_kwargs, show): if ax is None: tools = rcParams["plot.bokeh.tools"] output_backend = rcParams["plot.bokeh.output_backend"] diff --git a/arviz/plots/backends/bokeh/bokeh_jointplot.py b/arviz/plots/backends/bokeh/jointplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_jointplot.py rename to arviz/plots/backends/bokeh/jointplot.py index 074ce0ca6f..40c9051181 100644 --- a/arviz/plots/backends/bokeh/bokeh_jointplot.py +++ b/arviz/plots/backends/bokeh/jointplot.py @@ -1,7 +1,7 @@ """Bokeh jointplot.""" import bokeh.plotting as bkp -from bokeh.layouts import gridplot import numpy as np +from bokeh.layouts import gridplot from ...distplot import plot_dist from ...kdeplot import plot_kde @@ -9,7 +9,7 @@ from ....rcparams import rcParams -def _plot_joint( +def plot_joint( ax, figsize, plotters, diff --git a/arviz/plots/backends/bokeh/bokeh_kdeplot.py b/arviz/plots/backends/bokeh/kdeplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_kdeplot.py rename to arviz/plots/backends/bokeh/kdeplot.py index 93051c8875..d40e1ca20d 100644 --- a/arviz/plots/backends/bokeh/bokeh_kdeplot.py +++ b/arviz/plots/backends/bokeh/kdeplot.py @@ -4,17 +4,17 @@ from numbers import Integral import bokeh.plotting as bkp -from bokeh.models import ColumnDataSource, Dash, Range1d import matplotlib._contour as _contour +import numpy as np +from bokeh.models import ColumnDataSource, Dash, Range1d +from matplotlib.cm import get_cmap from matplotlib.colors import rgb2hex from matplotlib.pyplot import rcParams as mpl_rcParams -from matplotlib.cm import get_cmap -import numpy as np from ....rcparams import rcParams -def _plot_kde_bokeh( +def plot_kde( density, lower, upper, diff --git a/arviz/plots/backends/bokeh/bokeh_khatplot.py b/arviz/plots/backends/bokeh/khatplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_khatplot.py rename to arviz/plots/backends/bokeh/khatplot.py index a04efd53a9..5560caf744 100644 --- a/arviz/plots/backends/bokeh/bokeh_khatplot.py +++ b/arviz/plots/backends/bokeh/khatplot.py @@ -2,15 +2,14 @@ from collections.abc import Iterable import bokeh.plotting as bkp -from bokeh.models import ColumnDataSource, Span - import numpy as np +from bokeh.models import ColumnDataSource, Span from ....rcparams import rcParams from ....stats.stats_utils import histogram -def _plot_khat( +def plot_khat( ax, figsize, xdata, diff --git a/arviz/plots/backends/bokeh/bokeh_loopitplot.py b/arviz/plots/backends/bokeh/loopitplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_loopitplot.py rename to arviz/plots/backends/bokeh/loopitplot.py index 0ec535f078..e001a1a19b 100644 --- a/arviz/plots/backends/bokeh/bokeh_loopitplot.py +++ b/arviz/plots/backends/bokeh/loopitplot.py @@ -1,13 +1,13 @@ """Bokeh loopitplot.""" -import numpy as np import bokeh.plotting as bkp +import numpy as np from ...hpdplot import plot_hpd from ...kdeplot import _fast_kde from ....rcparams import rcParams -def _plot_loo_pit( +def plot_loo_pit( ax, figsize, ecdf, diff --git a/arviz/plots/backends/bokeh/bokeh_mcseplot.py b/arviz/plots/backends/bokeh/mcseplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_mcseplot.py rename to arviz/plots/backends/bokeh/mcseplot.py index 6ad005ffc6..b2d35be9b2 100644 --- a/arviz/plots/backends/bokeh/bokeh_mcseplot.py +++ b/arviz/plots/backends/bokeh/mcseplot.py @@ -1,10 +1,10 @@ """Bokeh mcseplot.""" import bokeh.plotting as bkp +import numpy as np +from bokeh.layouts import gridplot from bokeh.models import ColumnDataSource, Dash, Span from bokeh.models.annotations import Title -from bokeh.layouts import gridplot -import numpy as np from scipy.stats import rankdata from ...plot_utils import ( @@ -14,7 +14,7 @@ from ....stats.stats_utils import quantile as _quantile -def _plot_mcse( +def plot_mcse( ax, plotters, length_plotters, diff --git a/arviz/plots/backends/bokeh/bokeh_pairplot.py b/arviz/plots/backends/bokeh/pairplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_pairplot.py rename to arviz/plots/backends/bokeh/pairplot.py index c233c5ce82..c52319b618 100644 --- a/arviz/plots/backends/bokeh/bokeh_pairplot.py +++ b/arviz/plots/backends/bokeh/pairplot.py @@ -2,8 +2,9 @@ import warnings from uuid import uuid4 -import numpy as np + import bokeh.plotting as bkp +import numpy as np from bokeh.layouts import gridplot from bokeh.models import ColumnDataSource @@ -12,7 +13,7 @@ from ....rcparams import rcParams -def _plot_pair( +def plot_pair( ax, _posterior, numvars, diff --git a/arviz/plots/backends/bokeh/bokeh_parallelplot.py b/arviz/plots/backends/bokeh/parallelplot.py similarity index 94% rename from arviz/plots/backends/bokeh/bokeh_parallelplot.py rename to arviz/plots/backends/bokeh/parallelplot.py index 3fa0392200..e28133ad1a 100644 --- a/arviz/plots/backends/bokeh/bokeh_parallelplot.py +++ b/arviz/plots/backends/bokeh/parallelplot.py @@ -1,12 +1,12 @@ """Bokeh Parallel coordinates plot.""" import bokeh.plotting as bkp -from bokeh.models.tickers import FixedTicker import numpy as np +from bokeh.models.tickers import FixedTicker from ....rcparams import rcParams -def _plot_parallel(ax, diverging_mask, _posterior, var_names, figsize, show): +def plot_parallel(ax, diverging_mask, _posterior, var_names, figsize, show): if ax is None: tools = rcParams["plot.bokeh.tools"] output_backend = rcParams["plot.bokeh.output_backend"] diff --git a/arviz/plots/backends/bokeh/bokeh_posteriorplot.py b/arviz/plots/backends/bokeh/posteriorplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_posteriorplot.py rename to arviz/plots/backends/bokeh/posteriorplot.py index 4fa3573195..cc7963bb67 100644 --- a/arviz/plots/backends/bokeh/bokeh_posteriorplot.py +++ b/arviz/plots/backends/bokeh/posteriorplot.py @@ -1,12 +1,13 @@ """Bokeh Plot posterior densities.""" from typing import Optional from numbers import Number +from typing import Optional + import bokeh.plotting as bkp +import numpy as np +from bokeh.layouts import gridplot from bokeh.models import ColumnDataSource from bokeh.models.annotations import Title -from bokeh.layouts import gridplot - -import numpy as np from scipy.stats import mode from ...kdeplot import plot_kde, _fast_kde @@ -19,7 +20,7 @@ from ....stats import hpd -def _plot_posterior( +def plot_posterior( ax, length_plotters, rows, diff --git a/arviz/plots/backends/bokeh/bokeh_ppcplot.py b/arviz/plots/backends/bokeh/ppcplot.py similarity index 100% rename from arviz/plots/backends/bokeh/bokeh_ppcplot.py rename to arviz/plots/backends/bokeh/ppcplot.py index 476601a183..b9cb59ced4 100644 --- a/arviz/plots/backends/bokeh/bokeh_ppcplot.py +++ b/arviz/plots/backends/bokeh/ppcplot.py @@ -1,7 +1,7 @@ """Bokeh Posterior predictive plot.""" import bokeh.plotting as bkp -from bokeh.layouts import gridplot import numpy as np +from bokeh.layouts import gridplot from ...kdeplot import plot_kde, _fast_kde from ...plot_utils import ( diff --git a/arviz/plots/backends/bokeh/bokeh_rankplot.py b/arviz/plots/backends/bokeh/rankplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_rankplot.py rename to arviz/plots/backends/bokeh/rankplot.py index 057ff7bcbf..9fb5a5bc9a 100644 --- a/arviz/plots/backends/bokeh/bokeh_rankplot.py +++ b/arviz/plots/backends/bokeh/rankplot.py @@ -1,11 +1,11 @@ """Bokeh rankplot.""" import bokeh.plotting as bkp +import numpy as np +import scipy.stats +from bokeh.layouts import gridplot from bokeh.models import Span from bokeh.models.annotations import Title from bokeh.models.tickers import FixedTicker -from bokeh.layouts import gridplot -import numpy as np -import scipy.stats from ...plot_utils import ( _create_axes_grid, @@ -14,7 +14,7 @@ from ....stats.stats_utils import histogram -def _plot_rank( +def plot_rank( axes, length_plotters, rows, diff --git a/arviz/plots/backends/bokeh/bokeh_traceplot.py b/arviz/plots/backends/bokeh/traceplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_traceplot.py rename to arviz/plots/backends/bokeh/traceplot.py index 187260bd17..cada0eec92 100644 --- a/arviz/plots/backends/bokeh/bokeh_traceplot.py +++ b/arviz/plots/backends/bokeh/traceplot.py @@ -1,13 +1,13 @@ # pylint: disable=all """Bokeh Traceplot.""" -from typing import Dict from collections.abc import Iterable +from typing import Dict import bokeh.plotting as bkp +import numpy as np +from bokeh.layouts import gridplot from bokeh.models import ColumnDataSource, Dash, Span from bokeh.models.annotations import Title -from bokeh.layouts import gridplot -import numpy as np from . import BACKEND_KWARG_DEFAULTS from ...distplot import plot_dist @@ -18,7 +18,7 @@ BACKEND_KWARG_DEFAULTS["output_backend"] = rcParams["plot.bokeh.output_backend"] -def _plot_trace_bokeh( +def plot_trace( data, var_names, divergences, diff --git a/arviz/plots/backends/bokeh/bokeh_violinplot.py b/arviz/plots/backends/bokeh/violinplot.py similarity index 99% rename from arviz/plots/backends/bokeh/bokeh_violinplot.py rename to arviz/plots/backends/bokeh/violinplot.py index 4cba285f08..332206dada 100644 --- a/arviz/plots/backends/bokeh/bokeh_violinplot.py +++ b/arviz/plots/backends/bokeh/violinplot.py @@ -1,8 +1,8 @@ """Bokeh Violinplot.""" import bokeh.plotting as bkp +import numpy as np from bokeh.layouts import gridplot from bokeh.models.annotations import Title -import numpy as np from ...kdeplot import _fast_kde from ...plot_utils import get_bins, make_label, _create_axes_grid @@ -10,7 +10,7 @@ from ....stats.stats_utils import histogram -def _plot_violin( +def plot_violin( ax, plotters, figsize, From cffe99c7125ff0a33c83c9bd554d80877c1f905c Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 11 Dec 2019 07:37:08 -0800 Subject: [PATCH 20/41] Renamespace matplotlib --- arviz/plots/backends/matplotlib/__init__.py | 21 ++++++++++++++++++- .../{mpl_autocorrplot.py => autocorrplot.py} | 2 +- .../{mpl_compareplot.py => compareplot.py} | 2 +- .../{mpl_densityplot.py => densityplot.py} | 2 +- .../{mpl_distplot.py => distplot.py} | 2 +- .../{mpl_elpdplot.py => elpdplot.py} | 2 +- .../{mpl_energyplot.py => energyplot.py} | 2 +- .../matplotlib/{mpl_essplot.py => essplot.py} | 2 +- .../{mpl_forestplot.py => forestplot.py} | 2 +- .../matplotlib/{mpl_hpdplot.py => hpdplot.py} | 2 +- .../{mpl_jointplot.py => jointplot.py} | 2 +- .../matplotlib/{mpl_kdeplot.py => kdeplot.py} | 2 +- .../{mpl_khatplot.py => khatplot.py} | 2 +- .../{mpl_loopitplot.py => loopitplot.py} | 2 +- .../{mpl_mcseplot.py => mcseplot.py} | 2 +- .../{mpl_pairplot.py => pairplot.py} | 2 +- .../{mpl_parallelplot.py => parallelplot.py} | 2 +- ...{mpl_posteriorplot.py => posteriorplot.py} | 2 +- .../matplotlib/{mpl_ppcplot.py => ppcplot.py} | 2 +- .../{mpl_rankplot.py => rankplot.py} | 2 +- .../{mpl_traceplot.py => traceplot.py} | 2 +- .../{mpl_violinplot.py => violinplot.py} | 2 +- 22 files changed, 41 insertions(+), 22 deletions(-) rename arviz/plots/backends/matplotlib/{mpl_autocorrplot.py => autocorrplot.py} (97%) rename arviz/plots/backends/matplotlib/{mpl_compareplot.py => compareplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_densityplot.py => densityplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_distplot.py => distplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_elpdplot.py => elpdplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_energyplot.py => energyplot.py} (98%) rename arviz/plots/backends/matplotlib/{mpl_essplot.py => essplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_forestplot.py => forestplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_hpdplot.py => hpdplot.py} (77%) rename arviz/plots/backends/matplotlib/{mpl_jointplot.py => jointplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_kdeplot.py => kdeplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_khatplot.py => khatplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_loopitplot.py => loopitplot.py} (98%) rename arviz/plots/backends/matplotlib/{mpl_mcseplot.py => mcseplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_pairplot.py => pairplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_parallelplot.py => parallelplot.py} (97%) rename arviz/plots/backends/matplotlib/{mpl_posteriorplot.py => posteriorplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_ppcplot.py => ppcplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_rankplot.py => rankplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_traceplot.py => traceplot.py} (99%) rename arviz/plots/backends/matplotlib/{mpl_violinplot.py => violinplot.py} (99%) diff --git a/arviz/plots/backends/matplotlib/__init__.py b/arviz/plots/backends/matplotlib/__init__.py index 6818ed643c..adfbb1f0c3 100644 --- a/arviz/plots/backends/matplotlib/__init__.py +++ b/arviz/plots/backends/matplotlib/__init__.py @@ -1,2 +1,21 @@ """Matplotlib Plotting Backend.""" -from .mpl_distplot import _plot_dist_mpl +from .autocorrplot import plot_autocorr +from .compareplot import plot_compare +from .densityplot import plot_density +from .distplot import plot_dist +from .elpdplot import plot_elpd +from .energyplot import plot_energy +from .essplot import plot_ess +from .forestplot import plot_forest +from .hpdplot import plot_hpdplot +from .jointplot import plot_joint +from .kdeplot import plot_kde +from .khatplot import plot_khat +from .loopitplot import plot_loo_pit +from .mcseplot import plot_mcse +from .pairplot import plot_pair +from .parallelplot import plot_parallel +from .posteriorplot import plot_posterior +from .rankplot import plot_rank +from .traceplot import plot_trace +from .violinplot import plot_violin \ No newline at end of file diff --git a/arviz/plots/backends/matplotlib/mpl_autocorrplot.py b/arviz/plots/backends/matplotlib/autocorrplot.py similarity index 97% rename from arviz/plots/backends/matplotlib/mpl_autocorrplot.py rename to arviz/plots/backends/matplotlib/autocorrplot.py index 8efae1e9f5..1d77c37649 100644 --- a/arviz/plots/backends/matplotlib/mpl_autocorrplot.py +++ b/arviz/plots/backends/matplotlib/autocorrplot.py @@ -5,7 +5,7 @@ from ...plot_utils import make_label -def _plot_autocorr( +def plot_autocorr( axes, plotters, max_lag, linewidth, titlesize, combined=False, xt_labelsize=None, ): for (var_name, selection, x), ax_ in zip(plotters, axes.flatten()): diff --git a/arviz/plots/backends/matplotlib/mpl_compareplot.py b/arviz/plots/backends/matplotlib/compareplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_compareplot.py rename to arviz/plots/backends/matplotlib/compareplot.py index 9880260cc0..8dd0f3f68f 100644 --- a/arviz/plots/backends/matplotlib/mpl_compareplot.py +++ b/arviz/plots/backends/matplotlib/compareplot.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt -def _compareplot( +def plot_compare( ax, comp_df, figsize, diff --git a/arviz/plots/backends/matplotlib/mpl_densityplot.py b/arviz/plots/backends/matplotlib/densityplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_densityplot.py rename to arviz/plots/backends/matplotlib/densityplot.py index 3e9f6c4372..6d226ba39f 100644 --- a/arviz/plots/backends/matplotlib/mpl_densityplot.py +++ b/arviz/plots/backends/matplotlib/densityplot.py @@ -6,7 +6,7 @@ from ...plot_utils import make_label -def _plot_density( +def plot_density( ax, all_labels, to_plot, diff --git a/arviz/plots/backends/matplotlib/mpl_distplot.py b/arviz/plots/backends/matplotlib/distplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_distplot.py rename to arviz/plots/backends/matplotlib/distplot.py index a038dcfe53..cf1f5aebad 100644 --- a/arviz/plots/backends/matplotlib/mpl_distplot.py +++ b/arviz/plots/backends/matplotlib/distplot.py @@ -4,7 +4,7 @@ from ...kdeplot import plot_kde -def _plot_dist_mpl( +def plot_dist( values, values2=None, color="C0", diff --git a/arviz/plots/backends/matplotlib/mpl_elpdplot.py b/arviz/plots/backends/matplotlib/elpdplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_elpdplot.py rename to arviz/plots/backends/matplotlib/elpdplot.py index 473b7718f4..647cff6740 100644 --- a/arviz/plots/backends/matplotlib/mpl_elpdplot.py +++ b/arviz/plots/backends/matplotlib/elpdplot.py @@ -12,7 +12,7 @@ from ....rcparams import rcParams -def _plot_elpd( +def plot_elpd( ax, models, pointwise_data, diff --git a/arviz/plots/backends/matplotlib/mpl_energyplot.py b/arviz/plots/backends/matplotlib/energyplot.py similarity index 98% rename from arviz/plots/backends/matplotlib/mpl_energyplot.py rename to arviz/plots/backends/matplotlib/energyplot.py index 13f3f9c244..504baf884f 100644 --- a/arviz/plots/backends/matplotlib/mpl_energyplot.py +++ b/arviz/plots/backends/matplotlib/energyplot.py @@ -5,7 +5,7 @@ from ....stats import bfmi as e_bfmi -def _plot_energy( +def plot_energy( ax, series, energy, diff --git a/arviz/plots/backends/matplotlib/mpl_essplot.py b/arviz/plots/backends/matplotlib/essplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_essplot.py rename to arviz/plots/backends/matplotlib/essplot.py index 202b6057ed..21fddccbec 100644 --- a/arviz/plots/backends/matplotlib/mpl_essplot.py +++ b/arviz/plots/backends/matplotlib/essplot.py @@ -9,7 +9,7 @@ ) -def _plot_ess( +def plot_ess( ax, plotters, xdata, diff --git a/arviz/plots/backends/matplotlib/mpl_forestplot.py b/arviz/plots/backends/matplotlib/forestplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_forestplot.py rename to arviz/plots/backends/matplotlib/forestplot.py index 53e3427bf3..e8c780078a 100644 --- a/arviz/plots/backends/matplotlib/mpl_forestplot.py +++ b/arviz/plots/backends/matplotlib/forestplot.py @@ -21,7 +21,7 @@ def pairwise(iterable): return zip(first, second) -def _plot_forest( +def plot_forest( ax, datasets, var_names, diff --git a/arviz/plots/backends/matplotlib/mpl_hpdplot.py b/arviz/plots/backends/matplotlib/hpdplot.py similarity index 77% rename from arviz/plots/backends/matplotlib/mpl_hpdplot.py rename to arviz/plots/backends/matplotlib/hpdplot.py index cca86418f7..e6129a084c 100644 --- a/arviz/plots/backends/matplotlib/mpl_hpdplot.py +++ b/arviz/plots/backends/matplotlib/hpdplot.py @@ -2,7 +2,7 @@ from matplotlib.pyplot import gca -def _plot_hpdplot(ax, x_data, y_data, plot_kwargs, fill_kwargs): +def plot_hpdplot(ax, x_data, y_data, plot_kwargs, fill_kwargs): if ax is None: ax = gca() ax.plot(x_data, y_data, **plot_kwargs) diff --git a/arviz/plots/backends/matplotlib/mpl_jointplot.py b/arviz/plots/backends/matplotlib/jointplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_jointplot.py rename to arviz/plots/backends/matplotlib/jointplot.py index 706140aa37..089c669915 100644 --- a/arviz/plots/backends/matplotlib/mpl_jointplot.py +++ b/arviz/plots/backends/matplotlib/jointplot.py @@ -7,7 +7,7 @@ from ...plot_utils import make_label -def _plot_joint( +def plot_joint( ax, figsize, plotters, diff --git a/arviz/plots/backends/matplotlib/mpl_kdeplot.py b/arviz/plots/backends/matplotlib/kdeplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_kdeplot.py rename to arviz/plots/backends/matplotlib/kdeplot.py index 0d52744dd3..d94627d247 100644 --- a/arviz/plots/backends/matplotlib/mpl_kdeplot.py +++ b/arviz/plots/backends/matplotlib/kdeplot.py @@ -5,7 +5,7 @@ from ...plot_utils import _scale_fig_size -def _plot_kde_mpl( +def plot_kde( density, lower, upper, diff --git a/arviz/plots/backends/matplotlib/mpl_khatplot.py b/arviz/plots/backends/matplotlib/khatplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_khatplot.py rename to arviz/plots/backends/matplotlib/khatplot.py index 88a8f934db..93d9e4314b 100644 --- a/arviz/plots/backends/matplotlib/mpl_khatplot.py +++ b/arviz/plots/backends/matplotlib/khatplot.py @@ -9,7 +9,7 @@ from ....stats.stats_utils import histogram -def _plot_khat( +def plot_khat( hover_label, hover_format, ax, diff --git a/arviz/plots/backends/matplotlib/mpl_loopitplot.py b/arviz/plots/backends/matplotlib/loopitplot.py similarity index 98% rename from arviz/plots/backends/matplotlib/mpl_loopitplot.py rename to arviz/plots/backends/matplotlib/loopitplot.py index 485e1dc3a1..7a5f21dc9c 100644 --- a/arviz/plots/backends/matplotlib/mpl_loopitplot.py +++ b/arviz/plots/backends/matplotlib/loopitplot.py @@ -6,7 +6,7 @@ from ...hpdplot import plot_hpd -def _plot_loo_pit( +def plot_loo_pit( ax, figsize, ecdf, diff --git a/arviz/plots/backends/matplotlib/mpl_mcseplot.py b/arviz/plots/backends/matplotlib/mcseplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_mcseplot.py rename to arviz/plots/backends/matplotlib/mcseplot.py index c242e038a3..a04025bf98 100644 --- a/arviz/plots/backends/matplotlib/mpl_mcseplot.py +++ b/arviz/plots/backends/matplotlib/mcseplot.py @@ -10,7 +10,7 @@ ) -def _plot_mcse( +def plot_mcse( ax, plotters, length_plotters, diff --git a/arviz/plots/backends/matplotlib/mpl_pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_pairplot.py rename to arviz/plots/backends/matplotlib/pairplot.py index 6cf91b44e4..4980a33bf6 100644 --- a/arviz/plots/backends/matplotlib/mpl_pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -11,7 +11,7 @@ from ....rcparams import rcParams -def _plot_pair( +def plot_pair( ax, _posterior, numvars, diff --git a/arviz/plots/backends/matplotlib/mpl_parallelplot.py b/arviz/plots/backends/matplotlib/parallelplot.py similarity index 97% rename from arviz/plots/backends/matplotlib/mpl_parallelplot.py rename to arviz/plots/backends/matplotlib/parallelplot.py index 1b7818c5bc..0b74315170 100644 --- a/arviz/plots/backends/matplotlib/mpl_parallelplot.py +++ b/arviz/plots/backends/matplotlib/parallelplot.py @@ -3,7 +3,7 @@ import numpy as np -def _plot_parallel( +def plot_parallel( ax, colornd, colord, diff --git a/arviz/plots/backends/matplotlib/mpl_posteriorplot.py b/arviz/plots/backends/matplotlib/posteriorplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_posteriorplot.py rename to arviz/plots/backends/matplotlib/posteriorplot.py index f4fd32c434..2c3bd2a382 100644 --- a/arviz/plots/backends/matplotlib/mpl_posteriorplot.py +++ b/arviz/plots/backends/matplotlib/posteriorplot.py @@ -14,7 +14,7 @@ ) -def _plot_posterior( +def plot_posterior( ax, length_plotters, rows, diff --git a/arviz/plots/backends/matplotlib/mpl_ppcplot.py b/arviz/plots/backends/matplotlib/ppcplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_ppcplot.py rename to arviz/plots/backends/matplotlib/ppcplot.py index d118215926..12cd6d9341 100644 --- a/arviz/plots/backends/matplotlib/mpl_ppcplot.py +++ b/arviz/plots/backends/matplotlib/ppcplot.py @@ -10,7 +10,7 @@ from ....stats.stats_utils import histogram -def _plot_ppc( +def plot_ppc( ax, length_plotters, rows, diff --git a/arviz/plots/backends/matplotlib/mpl_rankplot.py b/arviz/plots/backends/matplotlib/rankplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_rankplot.py rename to arviz/plots/backends/matplotlib/rankplot.py index b2fe845c51..b132bfba36 100644 --- a/arviz/plots/backends/matplotlib/mpl_rankplot.py +++ b/arviz/plots/backends/matplotlib/rankplot.py @@ -9,7 +9,7 @@ from ....stats.stats_utils import histogram -def _plot_rank( +def plot_rank( axes, length_plotters, rows, diff --git a/arviz/plots/backends/matplotlib/mpl_traceplot.py b/arviz/plots/backends/matplotlib/traceplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_traceplot.py rename to arviz/plots/backends/matplotlib/traceplot.py index d5b03ab8e6..34711adc05 100644 --- a/arviz/plots/backends/matplotlib/mpl_traceplot.py +++ b/arviz/plots/backends/matplotlib/traceplot.py @@ -14,7 +14,7 @@ from ....rcparams import rcParams -def _plot_trace_mpl( +def plot_trace( data, var_names=None, coords=None, diff --git a/arviz/plots/backends/matplotlib/mpl_violinplot.py b/arviz/plots/backends/matplotlib/violinplot.py similarity index 99% rename from arviz/plots/backends/matplotlib/mpl_violinplot.py rename to arviz/plots/backends/matplotlib/violinplot.py index 725149d842..53bfc0765f 100644 --- a/arviz/plots/backends/matplotlib/mpl_violinplot.py +++ b/arviz/plots/backends/matplotlib/violinplot.py @@ -7,7 +7,7 @@ from ...plot_utils import get_bins, make_label, _create_axes_grid -def _plot_violin( +def plot_violin( ax, plotters, figsize, From d01638b4c41524551791db5412c399a385bdef72 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 11 Dec 2019 09:23:04 -0800 Subject: [PATCH 21/41] Ensure all bokeh tests are passing --- arviz/plots/autocorrplot.py | 11 ++++------- arviz/plots/backends/__init__.py | 3 --- arviz/plots/backends/bokeh/__init__.py | 10 +++++----- arviz/plots/backends/bokeh/hpdplot.py | 2 +- arviz/plots/backends/bokeh/ppcplot.py | 2 +- arviz/plots/backends/matplotlib/__init__.py | 5 ++++- arviz/plots/backends/matplotlib/hpdplot.py | 2 +- arviz/plots/compareplot.py | 12 ++++-------- arviz/plots/densityplot.py | 12 ++++-------- arviz/plots/distplot.py | 1 - arviz/plots/elpdplot.py | 13 ++++--------- arviz/plots/energyplot.py | 13 ++++--------- arviz/plots/essplot.py | 13 ++++--------- arviz/plots/forestplot.py | 14 ++++---------- arviz/plots/hpdplot.py | 13 ++++--------- arviz/plots/jointplot.py | 13 ++++--------- arviz/plots/kdeplot.py | 19 +++++++------------ arviz/plots/khatplot.py | 14 +++++--------- arviz/plots/loopitplot.py | 15 ++++++--------- arviz/plots/mcseplot.py | 13 ++++--------- arviz/plots/pairplot.py | 14 ++++---------- arviz/plots/parallelplot.py | 12 ++++-------- arviz/plots/plot_utils.py | 4 ++-- arviz/plots/posteriorplot.py | 12 ++++-------- arviz/plots/ppcplot.py | 12 ++++-------- arviz/plots/rankplot.py | 12 ++++-------- arviz/plots/violinplot.py | 13 ++++--------- 27 files changed, 96 insertions(+), 183 deletions(-) diff --git a/arviz/plots/autocorrplot.py b/arviz/plots/autocorrplot.py index 8d49a36598..656a6f486f 100644 --- a/arviz/plots/autocorrplot.py +++ b/arviz/plots/autocorrplot.py @@ -1,7 +1,6 @@ """Autocorrelation plot of data.""" import numpy as np -from .backends import check_bokeh_version from ..data import convert_to_dataset from .plot_utils import ( _scale_fig_size, @@ -9,6 +8,7 @@ xarray_var_iter, _create_axes_grid, filter_plotters_list, + get_plotting_method ) from ..utils import _var_names @@ -136,17 +136,14 @@ def plot_autocorr( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_autocorrplot import _plot_autocorr autocorr_plot_args.pop("xt_labelsize") autocorr_plot_args.pop("titlesize") autocorr_plot_args["line_width"] = autocorr_plot_args.pop("linewidth") autocorr_plot_args["show"] = show - axes = _plot_autocorr(**autocorr_plot_args) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_autocorrplot import _plot_autocorr - axes = _plot_autocorr(**autocorr_plot_args) + # TODO: Add backend kwargs + method = get_plotting_method("plot_autocorr", "autocorrplot", backend, {}) + axes = method(**autocorr_plot_args) return axes diff --git a/arviz/plots/backends/__init__.py b/arviz/plots/backends/__init__.py index d3df238c0a..2857dbde0c 100644 --- a/arviz/plots/backends/__init__.py +++ b/arviz/plots/backends/__init__.py @@ -1,4 +1 @@ """ArviZ plotting backends.""" - -# TODO: Get rid of this line once check_bokeh_version is removed -from .bokeh import check_bokeh_version diff --git a/arviz/plots/backends/bokeh/__init__.py b/arviz/plots/backends/bokeh/__init__.py index b886d1a0bc..234c835d8d 100644 --- a/arviz/plots/backends/bokeh/__init__.py +++ b/arviz/plots/backends/bokeh/__init__.py @@ -2,6 +2,9 @@ """Bokeh Plotting Backend.""" import packaging +# Set plot generic bokeh keyword arg defaults if none provided +BACKEND_KWARG_DEFAULTS = {"show": True} + from .autocorrplot import plot_autocorr from .compareplot import plot_compare from .densityplot import plot_density @@ -10,7 +13,7 @@ from .energyplot import plot_energy from .essplot import plot_ess from .forestplot import plot_forest -from .hpdplot import plot_hpdplot +from .hpdplot import plot_hpd from .jointplot import plot_joint from .kdeplot import plot_kde from .khatplot import plot_khat @@ -18,15 +21,12 @@ from .mcseplot import plot_mcse from .pairplot import plot_pair from .parallelplot import plot_parallel +from .ppcplot import plot_ppc from .posteriorplot import plot_posterior from .rankplot import plot_rank from .traceplot import plot_trace from .violinplot import plot_violin -# Set plot generic bokeh keyword arg defaults if none provided -BACKEND_KWARG_DEFAULTS = {"show": True} - - def output_notebook(*args, **kwargs): """Wrap bokeh.plotting.output_notebook.""" diff --git a/arviz/plots/backends/bokeh/hpdplot.py b/arviz/plots/backends/bokeh/hpdplot.py index e9485ea3b8..b515114e84 100644 --- a/arviz/plots/backends/bokeh/hpdplot.py +++ b/arviz/plots/backends/bokeh/hpdplot.py @@ -8,7 +8,7 @@ from ....rcparams import rcParams -def plot_hpdplot(ax, x_data, y_data, plot_kwargs, fill_kwargs, show): +def plot_hpd(ax, x_data, y_data, plot_kwargs, fill_kwargs, show): if ax is None: tools = rcParams["plot.bokeh.tools"] output_backend = rcParams["plot.bokeh.output_backend"] diff --git a/arviz/plots/backends/bokeh/ppcplot.py b/arviz/plots/backends/bokeh/ppcplot.py index b9cb59ced4..5bdb24889c 100644 --- a/arviz/plots/backends/bokeh/ppcplot.py +++ b/arviz/plots/backends/bokeh/ppcplot.py @@ -11,7 +11,7 @@ from ....stats.stats_utils import histogram -def _plot_ppc( +def plot_ppc( ax, length_plotters, rows, diff --git a/arviz/plots/backends/matplotlib/__init__.py b/arviz/plots/backends/matplotlib/__init__.py index adfbb1f0c3..c6be8d3c81 100644 --- a/arviz/plots/backends/matplotlib/__init__.py +++ b/arviz/plots/backends/matplotlib/__init__.py @@ -1,4 +1,6 @@ """Matplotlib Plotting Backend.""" +BACKEND_KWARG_DEFAULTS = {"show": True, "anotherkey":"test"} + from .autocorrplot import plot_autocorr from .compareplot import plot_compare from .densityplot import plot_density @@ -7,7 +9,7 @@ from .energyplot import plot_energy from .essplot import plot_ess from .forestplot import plot_forest -from .hpdplot import plot_hpdplot +from .hpdplot import plot_hpd from .jointplot import plot_joint from .kdeplot import plot_kde from .khatplot import plot_khat @@ -16,6 +18,7 @@ from .pairplot import plot_pair from .parallelplot import plot_parallel from .posteriorplot import plot_posterior +from .ppcplot import plot_ppc from .rankplot import plot_rank from .traceplot import plot_trace from .violinplot import plot_violin \ No newline at end of file diff --git a/arviz/plots/backends/matplotlib/hpdplot.py b/arviz/plots/backends/matplotlib/hpdplot.py index e6129a084c..e0761a9ad8 100644 --- a/arviz/plots/backends/matplotlib/hpdplot.py +++ b/arviz/plots/backends/matplotlib/hpdplot.py @@ -2,7 +2,7 @@ from matplotlib.pyplot import gca -def plot_hpdplot(ax, x_data, y_data, plot_kwargs, fill_kwargs): +def plot_hpd(ax, x_data, y_data, plot_kwargs, fill_kwargs): if ax is None: ax = gca() ax.plot(x_data, y_data, **plot_kwargs) diff --git a/arviz/plots/compareplot.py b/arviz/plots/compareplot.py index 2f4cfa9dc2..91b7416a77 100644 --- a/arviz/plots/compareplot.py +++ b/arviz/plots/compareplot.py @@ -1,8 +1,7 @@ """Summary plot for model comparison.""" import numpy as np -from .backends import check_bokeh_version -from .plot_utils import _scale_fig_size +from .plot_utils import _scale_fig_size, get_plotting_method def plot_compare( @@ -130,17 +129,14 @@ def plot_compare( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_compareplot import _compareplot compareplot_kwargs["line_width"] = compareplot_kwargs.pop("linewidth") compareplot_kwargs.pop("ax_labelsize") compareplot_kwargs.pop("xt_labelsize") compareplot_kwargs["show"] = show - ax = _compareplot(**compareplot_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_compareplot import _compareplot - ax = _compareplot(**compareplot_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_compare", "compareplot", backend, {}) + ax = method(**compareplot_kwargs) return ax diff --git a/arviz/plots/densityplot.py b/arviz/plots/densityplot.py index 214367a8b7..416551566a 100644 --- a/arviz/plots/densityplot.py +++ b/arviz/plots/densityplot.py @@ -4,7 +4,6 @@ import matplotlib.pyplot as plt -from .backends import check_bokeh_version from ..data import convert_to_dataset from .plot_utils import ( _scale_fig_size, @@ -12,6 +11,7 @@ xarray_var_iter, default_grid, _create_axes_grid, + get_plotting_method ) from ..utils import _var_names from ..rcparams import rcParams @@ -238,18 +238,14 @@ def plot_density( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_densityplot import _plot_density plot_density_kwargs["line_width"] = plot_density_kwargs.pop("linewidth") plot_density_kwargs.pop("titlesize") plot_density_kwargs.pop("xt_labelsize") plot_density_kwargs["show"] = show plot_density_kwargs.pop("n_data") - _plot_density(**plot_density_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_densityplot import _plot_density - - _plot_density(**plot_density_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_density", "densityplot", backend, {}) + ax = method(**plot_density_kwargs) return ax diff --git a/arviz/plots/distplot.py b/arviz/plots/distplot.py index a0667e909d..020cf0f408 100644 --- a/arviz/plots/distplot.py +++ b/arviz/plots/distplot.py @@ -1,6 +1,5 @@ # pylint: disable=unexpected-keyword-arg """Plot distribution as histogram or kernel density estimates.""" -from .backends import check_bokeh_version from .plot_utils import get_bins from .plot_utils import get_bins, get_plotting_method diff --git a/arviz/plots/elpdplot.py b/arviz/plots/elpdplot.py index 612fb04e68..867030e441 100644 --- a/arviz/plots/elpdplot.py +++ b/arviz/plots/elpdplot.py @@ -4,12 +4,12 @@ import matplotlib.cm as cm from matplotlib.lines import Line2D -from .backends import check_bokeh_version from ..data import convert_to_inference_data from .plot_utils import ( get_coords, format_coords_as_labels, color_from_dim, + get_plotting_method ) from ..stats import waic, loo, ELPDData from ..rcparams import rcParams @@ -207,17 +207,12 @@ def plot_elpd( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_elpdplot import _plot_elpd - elpd_plot_kwargs.pop("legend") elpd_plot_kwargs.pop("handles") elpd_plot_kwargs.pop("color") elpd_plot_kwargs["show"] = show - ax = _plot_elpd(**elpd_plot_kwargs) # pylint: disable=unexpected-keyword-arg - elif backend == "matplotlib": - from .backends.matplotlib.mpl_elpdplot import _plot_elpd - - ax = _plot_elpd(**elpd_plot_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_elpd", "elpdplot", backend, {}) + ax = method(**elpd_plot_kwargs) return ax diff --git a/arviz/plots/energyplot.py b/arviz/plots/energyplot.py index 1c0639dba3..246661dc83 100644 --- a/arviz/plots/energyplot.py +++ b/arviz/plots/energyplot.py @@ -3,9 +3,8 @@ from matplotlib.pyplot import rcParams import numpy as np -from .backends import check_bokeh_version from ..data import convert_to_dataset -from .plot_utils import _scale_fig_size +from .plot_utils import _scale_fig_size, get_plotting_method def plot_energy( @@ -132,18 +131,14 @@ def plot_energy( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_energyplot import _plot_energy plot_energy_kwargs.pop("xt_labelsize") plot_energy_kwargs["line_width"] = plot_energy_kwargs.pop("linewidth") plot_energy_kwargs["show"] = show if kind in {"hist", "histogram"}: plot_energy_kwargs["legend"] = False - ax = _plot_energy(**plot_energy_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_energyplot import _plot_energy - - ax = _plot_energy(**plot_energy_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_energy", "energyplot", backend, {}) + ax = method(**plot_energy_kwargs) return ax diff --git a/arviz/plots/essplot.py b/arviz/plots/essplot.py index de96a838cf..985df10ed9 100644 --- a/arviz/plots/essplot.py +++ b/arviz/plots/essplot.py @@ -2,7 +2,6 @@ import numpy as np import xarray as xr -from .backends import check_bokeh_version from ..data import convert_to_dataset from ..stats import ess from .plot_utils import ( @@ -11,6 +10,7 @@ default_grid, get_coords, filter_plotters_list, + get_plotting_method ) from ..utils import _var_names @@ -317,14 +317,9 @@ def plot_ess( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_essplot import _plot_ess - essplot_kwargs["show"] = show - ax = _plot_ess(**essplot_kwargs) - else: - from .backends.matplotlib.mpl_essplot import _plot_ess - - ax = _plot_ess(**essplot_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_ess", "essplot", backend, {}) + ax = method(**essplot_kwargs) return ax diff --git a/arviz/plots/forestplot.py b/arviz/plots/forestplot.py index 2866281c57..7d08ce60c5 100644 --- a/arviz/plots/forestplot.py +++ b/arviz/plots/forestplot.py @@ -1,7 +1,6 @@ """Forest plot.""" -from .backends import check_bokeh_version from ..data import convert_to_dataset -from .plot_utils import get_coords +from .plot_utils import get_coords, get_plotting_method from ..utils import _var_names @@ -176,14 +175,9 @@ def plot_forest( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_forestplot import _plot_forest - plot_forest_kwargs["show"] = show - axes = _plot_forest(**plot_forest_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_forestplot import _plot_forest - - axes = _plot_forest(**plot_forest_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_forest", "forestplot", backend, {}) + axes = method(**plot_forest_kwargs) return axes diff --git a/arviz/plots/hpdplot.py b/arviz/plots/hpdplot.py index ee8d26e625..9c85f04f92 100644 --- a/arviz/plots/hpdplot.py +++ b/arviz/plots/hpdplot.py @@ -3,8 +3,8 @@ from scipy.interpolate import griddata from scipy.signal import savgol_filter -from .backends import check_bokeh_version from ..stats import hpd +from .plot_utils import get_plotting_method def plot_hpd( @@ -104,14 +104,9 @@ def plot_hpd( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_hpdplot import _plot_hpdplot - hpdplot_kwargs["show"] = show - ax = _plot_hpdplot(**hpdplot_kwargs) - else: - from .backends.matplotlib.mpl_hpdplot import _plot_hpdplot - - ax = _plot_hpdplot(**hpdplot_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_hpd", "hpdplot", backend, {}) + ax = method(**hpdplot_kwargs) return ax diff --git a/arviz/plots/jointplot.py b/arviz/plots/jointplot.py index b5d86a82f1..646c9b1378 100644 --- a/arviz/plots/jointplot.py +++ b/arviz/plots/jointplot.py @@ -1,7 +1,6 @@ """Joint scatter plot of two variables.""" -from .backends import check_bokeh_version from ..data import convert_to_dataset -from .plot_utils import _scale_fig_size, xarray_var_iter, get_coords +from .plot_utils import _scale_fig_size, xarray_var_iter, get_coords, get_plotting_method from ..utils import _var_names @@ -166,18 +165,14 @@ def plot_joint( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_jointplot import _plot_joint plot_joint_kwargs.pop("ax_labelsize") plot_joint_kwargs["marginal_kwargs"]["plot_kwargs"]["line_width"] = plot_joint_kwargs[ "marginal_kwargs" ]["plot_kwargs"].pop("linewidth") plot_joint_kwargs["show"] = show - axes = _plot_joint(**plot_joint_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_jointplot import _plot_joint - - axes = _plot_joint(**plot_joint_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_joint", "jointplot", backend, {}) + axes = method(**plot_joint_kwargs) return axes diff --git a/arviz/plots/kdeplot.py b/arviz/plots/kdeplot.py index 1284eeac67..3dbcf5afad 100644 --- a/arviz/plots/kdeplot.py +++ b/arviz/plots/kdeplot.py @@ -6,10 +6,10 @@ from scipy.sparse import coo_matrix import xarray as xr -from .backends import check_bokeh_version from ..data import InferenceData from ..utils import conditional_jit, _stack from ..stats.stats_utils import histogram +from .plot_utils import get_plotting_method def plot_kde( @@ -211,21 +211,16 @@ def plot_kde( legend=legend, **kwargs, ) - if backend is None or backend.lower() in ("mpl", "matplotlib"): - from .backends.matplotlib.mpl_kdeplot import _plot_kde_mpl - ax = _plot_kde_mpl(**kde_plot_args) - elif backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_kdeplot import _plot_kde_bokeh + if backend == "bokeh": kde_plot_args["show"] = show kde_plot_args.pop("textsize") - ax = _plot_kde_bokeh(**kde_plot_args) - else: - raise NotImplementedError( - 'Backend {} not implemented. Use {{"matplotlib", "bokeh"}}'.format(backend) - ) + + # TODO: Add backend kwargs + method = get_plotting_method("plot_kde", "kdeplot", backend, {}) + ax = method(**kde_plot_args) + return ax diff --git a/arviz/plots/khatplot.py b/arviz/plots/khatplot.py index c02d09f1a1..610c863bbf 100644 --- a/arviz/plots/khatplot.py +++ b/arviz/plots/khatplot.py @@ -5,12 +5,12 @@ import numpy as np from xarray import DataArray -from .backends import check_bokeh_version from .plot_utils import ( _scale_fig_size, get_coords, color_from_dim, format_coords_as_labels, + get_plotting_method ) from ..stats import ELPDData @@ -220,8 +220,6 @@ def plot_khat( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_khatplot import _plot_khat plot_khat_kwargs.pop("hover_label") plot_khat_kwargs.pop("hover_format") @@ -235,10 +233,8 @@ def plot_khat( plot_khat_kwargs.pop("cmap") plot_khat_kwargs.pop("color") plot_khat_kwargs["show"] = show - ax = _plot_khat(**plot_khat_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_khatplot import _plot_khat - - ax = _plot_khat(**plot_khat_kwargs) - return ax + # TODO: Add backend kwargs + method = get_plotting_method("plot_khat", "khatplot", backend, {}) + axes = method(**plot_khat_kwargs) + return axes diff --git a/arviz/plots/loopitplot.py b/arviz/plots/loopitplot.py index 1c9c5e4988..065b670b93 100644 --- a/arviz/plots/loopitplot.py +++ b/arviz/plots/loopitplot.py @@ -4,9 +4,8 @@ from matplotlib.colors import to_rgb, rgb_to_hsv, hsv_to_rgb, to_hex from xarray import DataArray -from .backends import check_bokeh_version from ..stats import loo_pit as _loo_pit -from .plot_utils import _scale_fig_size +from .plot_utils import _scale_fig_size, get_plotting_method from .kdeplot import _fast_kde @@ -235,8 +234,6 @@ def plot_loo_pit( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_loopitplot import _plot_loo_pit if ( loo_pit_kwargs["hpd_kwargs"] is not None @@ -249,10 +246,10 @@ def plot_loo_pit( loo_pit_kwargs.pop("xt_labelsize") loo_pit_kwargs.pop("credible_interval") loo_pit_kwargs["show"] = show - ax = _plot_loo_pit(**loo_pit_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_loopitplot import _plot_loo_pit - ax = _plot_loo_pit(**loo_pit_kwargs) - return ax + # TODO: Add backend kwargs + method = get_plotting_method("plot_loo_pit", "loopitplot", backend, {}) + axes = method(**loo_pit_kwargs) + + return axes diff --git a/arviz/plots/mcseplot.py b/arviz/plots/mcseplot.py index 60a853693f..0c659f86b1 100644 --- a/arviz/plots/mcseplot.py +++ b/arviz/plots/mcseplot.py @@ -2,7 +2,6 @@ import numpy as np import xarray as xr -from .backends import check_bokeh_version from ..data import convert_to_dataset from ..stats import mcse from .plot_utils import ( @@ -11,6 +10,7 @@ default_grid, get_coords, filter_plotters_list, + get_plotting_method ) from ..utils import _var_names @@ -182,9 +182,6 @@ def plot_mcse( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_mcseplot import _plot_mcse - mcse_kwargs.pop("kwargs") mcse_kwargs.pop("text_x") mcse_kwargs.pop("text_va") @@ -193,10 +190,8 @@ def plot_mcse( mcse_kwargs.pop("ax_labelsize") mcse_kwargs.pop("titlesize") mcse_kwargs["show"] = show - ax = _plot_mcse(**mcse_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_mcseplot import _plot_mcse - - ax = _plot_mcse(**mcse_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_mcse", "mcseplot", backend, {}) + ax = method(**mcse_kwargs) return ax diff --git a/arviz/plots/pairplot.py b/arviz/plots/pairplot.py index 00b56e3c7f..78e27745be 100644 --- a/arviz/plots/pairplot.py +++ b/arviz/plots/pairplot.py @@ -2,9 +2,8 @@ import warnings import numpy as np -from .backends import check_bokeh_version from ..data import convert_to_dataset, convert_to_inference_data -from .plot_utils import xarray_to_ndarray, get_coords +from .plot_utils import xarray_to_ndarray, get_coords, get_plotting_method from ..utils import _var_names @@ -192,18 +191,13 @@ def plot_pair( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_pairplot import _plot_pair - pairplot_kwargs.pop("gridsize", None) pairplot_kwargs.pop("colorbar", None) pairplot_kwargs.pop("divergences_kwargs", None) pairplot_kwargs.pop("hexbin_values", None) pairplot_kwargs["show"] = show - ax = _plot_pair(**pairplot_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_pairplot import _plot_pair - - ax = _plot_pair(**pairplot_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_pair", "pairplot", backend, {}) + ax = method(**pairplot_kwargs) return ax diff --git a/arviz/plots/parallelplot.py b/arviz/plots/parallelplot.py index 6e602189a8..5f0b176a71 100644 --- a/arviz/plots/parallelplot.py +++ b/arviz/plots/parallelplot.py @@ -2,9 +2,8 @@ import numpy as np from scipy.stats.mstats import rankdata -from .backends import check_bokeh_version from ..data import convert_to_dataset -from .plot_utils import _scale_fig_size, xarray_to_ndarray, get_coords +from .plot_utils import _scale_fig_size, xarray_to_ndarray, get_coords, get_plotting_method from ..utils import _var_names, _numba_var from ..stats.stats_utils import stats_variance_2d as svar @@ -139,8 +138,6 @@ def plot_parallel( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_parallelplot import _plot_parallel parallel_kwargs["show"] = show parallel_kwargs.pop("textsize") @@ -149,10 +146,9 @@ def plot_parallel( parallel_kwargs.pop("colord") parallel_kwargs.pop("colornd") parallel_kwargs.pop("shadend") - ax = _plot_parallel(**parallel_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_parallelplot import _plot_parallel - ax = _plot_parallel(**parallel_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_parallel", "parallelplot", backend, {}) + ax = method(**parallel_kwargs) return ax diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index baff4316d3..27f8d89fab 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -655,13 +655,13 @@ def get_plotting_method(plot_name, plot_module, backend, user_backend_kwargs): # Perform import of plotting method # TODO: Convert module import to top level for all plots module = importlib.import_module( - "arviz.plots.backends.{backend}.{backend}_{plot_module}".format( + "arviz.plots.backends.{backend}.{plot_module}".format( backend=backend, plot_module=plot_module ) ) plotting_method = getattr( - module, "_{plot_name}_{backend}".format(plot_name=plot_name, backend=backend) + module, plot_name ) return plotting_method diff --git a/arviz/plots/posteriorplot.py b/arviz/plots/posteriorplot.py index 4dced01f56..e5fc5195f4 100644 --- a/arviz/plots/posteriorplot.py +++ b/arviz/plots/posteriorplot.py @@ -1,7 +1,6 @@ """Plot posterior densities.""" from typing import Optional -from .backends import check_bokeh_version from ..data import convert_to_dataset from .plot_utils import ( xarray_var_iter, @@ -9,6 +8,7 @@ default_grid, get_coords, filter_plotters_list, + get_plotting_method ) from ..utils import _var_names @@ -211,16 +211,12 @@ def plot_posterior( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_posteriorplot import _plot_posterior posteriorplot_kwargs.pop("xt_labelsize") posteriorplot_kwargs.pop("titlesize") posteriorplot_kwargs["show"] = show - ax = _plot_posterior(**posteriorplot_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_posteriorplot import _plot_posterior - - ax = _plot_posterior(**posteriorplot_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_posterior", "posteriorplot", backend, {}) + ax = method(**posteriorplot_kwargs) return ax diff --git a/arviz/plots/ppcplot.py b/arviz/plots/ppcplot.py index 6e0ec70b72..b47d964eea 100644 --- a/arviz/plots/ppcplot.py +++ b/arviz/plots/ppcplot.py @@ -4,12 +4,12 @@ import logging import numpy as np -from .backends import check_bokeh_version from .plot_utils import ( xarray_var_iter, _scale_fig_size, default_grid, filter_plotters_list, + get_plotting_method ) from ..utils import _var_names @@ -291,8 +291,6 @@ def plot_ppc( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_ppcplot import _plot_ppc ppcplot_kwargs.pop("animated") ppcplot_kwargs.pop("animation_kwargs") @@ -300,10 +298,8 @@ def plot_ppc( ppcplot_kwargs.pop("xt_labelsize") ppcplot_kwargs.pop("ax_labelsize") ppcplot_kwargs["show"] = show - axes = _plot_ppc(**ppcplot_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_ppcplot import _plot_ppc - - axes = _plot_ppc(**ppcplot_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_ppc", "ppcplot", backend, {}) + axes = method(**ppcplot_kwargs) return axes diff --git a/arviz/plots/rankplot.py b/arviz/plots/rankplot.py index c9fb29d9e9..b115355e36 100644 --- a/arviz/plots/rankplot.py +++ b/arviz/plots/rankplot.py @@ -2,7 +2,6 @@ from itertools import cycle import matplotlib.pyplot as plt -from .backends import check_bokeh_version from ..data import convert_to_dataset from .plot_utils import ( _scale_fig_size, @@ -10,6 +9,7 @@ default_grid, filter_plotters_list, _sturges_formula, + get_plotting_method ) from ..utils import _var_names @@ -160,16 +160,12 @@ def plot_rank( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_rankplot import _plot_rank rankplot_kwargs.pop("ax_labelsize") rankplot_kwargs.pop("titlesize") rankplot_kwargs["show"] = show - axes = _plot_rank(**rankplot_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_rankplot import _plot_rank - - axes = _plot_rank(**rankplot_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_rankjplot", "rankplot", backend, {}) + axes = method(**rankplot_kwargs) return axes diff --git a/arviz/plots/violinplot.py b/arviz/plots/violinplot.py index a6d21c9a6b..bb41a0426a 100644 --- a/arviz/plots/violinplot.py +++ b/arviz/plots/violinplot.py @@ -1,7 +1,6 @@ """Plot posterior traces as violin plot.""" -from .backends import check_bokeh_version from ..data import convert_to_dataset -from .plot_utils import _scale_fig_size, xarray_var_iter, filter_plotters_list, default_grid +from .plot_utils import _scale_fig_size, xarray_var_iter, filter_plotters_list, default_grid, get_plotting_method from ..utils import _var_names @@ -100,16 +99,12 @@ def plot_violin( ) if backend == "bokeh": - check_bokeh_version() - from .backends.bokeh.bokeh_violinplot import _plot_violin violinplot_kwargs.pop("ax_labelsize") violinplot_kwargs.pop("xt_labelsize") violinplot_kwargs["show"] = show - ax = _plot_violin(**violinplot_kwargs) # pylint: disable=unexpected-keyword-arg - else: - from .backends.matplotlib.mpl_violinplot import _plot_violin - - ax = _plot_violin(**violinplot_kwargs) + # TODO: Add backend kwargs + method = get_plotting_method("plot_violin", "violinplot", backend, {}) + ax = method(**violinplot_kwargs) return ax From bf9d5a5b9a97da0d85f25b9150f199a652e23e98 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 11 Dec 2019 09:38:31 -0800 Subject: [PATCH 22/41] Change backend kwargs for matplotlib --- arviz/plots/backends/matplotlib/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/plots/backends/matplotlib/__init__.py b/arviz/plots/backends/matplotlib/__init__.py index c6be8d3c81..879d415d05 100644 --- a/arviz/plots/backends/matplotlib/__init__.py +++ b/arviz/plots/backends/matplotlib/__init__.py @@ -1,5 +1,5 @@ """Matplotlib Plotting Backend.""" -BACKEND_KWARG_DEFAULTS = {"show": True, "anotherkey":"test"} +BACKEND_KWARG_DEFAULTS = {} from .autocorrplot import plot_autocorr from .compareplot import plot_compare From af6e7d29fbbe122e8dcda067e0d095bde1a6c9d8 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 11 Dec 2019 09:48:44 -0800 Subject: [PATCH 23/41] WIP Commit for rebase --- arviz/plots/backends/matplotlib/traceplot.py | 128 ++++++++++--------- arviz/plots/traceplot.py | 6 +- 2 files changed, 69 insertions(+), 65 deletions(-) diff --git a/arviz/plots/backends/matplotlib/traceplot.py b/arviz/plots/backends/matplotlib/traceplot.py index 34711adc05..051f341ccf 100644 --- a/arviz/plots/backends/matplotlib/traceplot.py +++ b/arviz/plots/backends/matplotlib/traceplot.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt from matplotlib.lines import Line2D import numpy as np +from . import BACKEND_KWARG_DEFAULTS from ....data import convert_to_dataset @@ -13,24 +14,28 @@ from ....utils import _var_names from ....rcparams import rcParams +# TODO: Change this to RcParams +BACKEND_KWARG_DEFAULTS["textsize"] = 10 def plot_trace( data, - var_names=None, - coords=None, - divergences="bottom", - figsize=None, - textsize=None, - rug=False, - lines=None, - compact=False, - combined=False, - legend=False, - plot_kwargs=None, - fill_kwargs=None, - rug_kwargs=None, - hist_kwargs=None, - trace_kwargs=None, + var_names, + divergences, + figsize, + rug, + lines, + combined, + legend, + plot_kwargs, + fill_kwargs, + rug_kwargs, + hist_kwargs, + trace_kwargs, + plotters, + divergence_data, + colors, + backend_kwargs + ): """Plot distribution (histogram or kernel density estimates) and sampled values. @@ -44,23 +49,16 @@ def plot_trace( Refer to documentation of az.convert_to_dataset for details var_names : string, or list of strings One or more variables to be plotted. - coords : mapping, optional - Coordinates of var_names to be plotted. Passed to `Dataset.sel` divergences : {"bottom", "top", None, False} Plot location of divergences on the traceplots. Options are "bottom", "top", or False-y. figsize : figure size tuple If None, size is (12, variables * 2) - textsize: float - Text size scaling factor for labels, titles and lines. If None it will be autoscaled based - on figsize. rug : bool If True adds a rugplot. Defaults to False. Ignored for 2D KDE. Only affects continuous variables. lines : tuple Tuple of (var_name, {'coord': selection}, [line, positions]) to be overplotted as vertical lines on the density and horizontal lines on the trace. - compact : bool - Plot multidimensional variables in a single plot. combined : bool Flag for combining multiple chains into a single line. If False (default), chains will be plotted separately. @@ -117,50 +115,54 @@ def plot_trace( >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines) """ - - - - if compact: - skip_dims = set(data.dims) - {"chain", "draw"} - else: - skip_dims = set() - - plotters = list(xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims)) - max_plots = rcParams["plot.max_subplots"] - max_plots = len(plotters) if max_plots is None else max_plots - if len(plotters) > max_plots: - warnings.warn( - "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number " - "of variables to plot ({len_plotters}), generating only {max_plots} " - "plots".format(max_plots=max_plots, len_plotters=len(plotters)), - SyntaxWarning, - ) - plotters = plotters[:max_plots] - - if figsize is None: - figsize = (12, len(plotters) * 2) - - if trace_kwargs is None: - trace_kwargs = {} - - trace_kwargs.setdefault("alpha", 0.35) - - if hist_kwargs is None: - hist_kwargs = {} - if plot_kwargs is None: - plot_kwargs = {} - if fill_kwargs is None: - fill_kwargs = {} - if rug_kwargs is None: - rug_kwargs = {} - - hist_kwargs.setdefault("alpha", 0.35) + # Set plot default backend kwargs + if backend_kwargs is None: + backend_kwargs = {} + + backend_kwargs = {**BACKEND_KWARG_DEFAULTS, **backend_kwargs} + # if compact: + # skip_dims = set(data.dims) - {"chain", "draw"} + # else: + # skip_dims = set() + # + # plotters = list(xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims)) + # max_plots = rcParams["plot.max_subplots"] + # max_plots = len(plotters) if max_plots is None else max_plots + # if len(plotters) > max_plots: + # warnings.warn( + # "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number " + # "of variables to plot ({len_plotters}), generating only {max_plots} " + # "plots".format(max_plots=max_plots, len_plotters=len(plotters)), + # SyntaxWarning, + # ) + # plotters = plotters[:max_plots] + + # if figsize is None: + # figsize = (12, len(plotters) * 2) + # + # if trace_kwargs is None: + # trace_kwargs = {} + # + # trace_kwargs.setdefault("alpha", 0.35) + + # if hist_kwargs is None: + # hist_kwargs = {} + # if plot_kwargs is None: + # plot_kwargs = {} + # if fill_kwargs is None: + # fill_kwargs = {} + # if rug_kwargs is None: + # rug_kwargs = {} + # + # hist_kwargs.setdefault("alpha", 0.35) figsize, _, titlesize, xt_labelsize, linewidth, _ = _scale_fig_size( - figsize, textsize, rows=len(plotters), cols=2 + figsize, backend_kwargs["textsize"], rows=len(plotters), cols=2 ) - trace_kwargs.setdefault("linewidth", linewidth) - plot_kwargs.setdefault("linewidth", linewidth) + + # TODO: This is breaking plotting for some reason + # trace_kwargs.setdefault("linewidth", linewidth) + # plot_kwargs.setdefault("linewidth", linewidth) _, axes = plt.subplots( len(plotters), 2, squeeze=False, figsize=figsize, constrained_layout=True diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index 9bd740e297..5244119071 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -189,8 +189,10 @@ def plot_trace( rug_kwargs = {} figsize, _, _, _, linewidth, _ = _scale_fig_size(figsize, 10, rows=len(plotters), cols=2) - trace_kwargs.setdefault("line_width", linewidth) - plot_kwargs.setdefault("line_width", linewidth) + + # This is where the issue is + trace_kwargs.setdefault("linewidth", linewidth) + plot_kwargs.setdefault("linewidth", linewidth) # TODO: Check if this can be further simplified trace_plot_args = dict( From a4629d255cb79f13624cbbcfde504d561d6e5575 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 11 Dec 2019 09:55:48 -0800 Subject: [PATCH 24/41] Fix wrong name in rankplot --- arviz/plots/rankplot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/plots/rankplot.py b/arviz/plots/rankplot.py index b115355e36..a4ab971ebb 100644 --- a/arviz/plots/rankplot.py +++ b/arviz/plots/rankplot.py @@ -166,6 +166,6 @@ def plot_rank( rankplot_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_rankjplot", "rankplot", backend, {}) + method = get_plotting_method("plot_rank", "rankplot", backend, {}) axes = method(**rankplot_kwargs) return axes From 031b380b3072be7a1b5e3598ef28b63e31c7054d Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 11 Dec 2019 09:56:03 -0800 Subject: [PATCH 25/41] Reverse linewidth to ensure bokeh tests pass --- arviz/plots/traceplot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index 5244119071..4c8a127c52 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -191,8 +191,8 @@ def plot_trace( figsize, _, _, _, linewidth, _ = _scale_fig_size(figsize, 10, rows=len(plotters), cols=2) # This is where the issue is - trace_kwargs.setdefault("linewidth", linewidth) - plot_kwargs.setdefault("linewidth", linewidth) + trace_kwargs.setdefault("line_width", linewidth) + plot_kwargs.setdefault("line_width", linewidth) # TODO: Check if this can be further simplified trace_plot_args = dict( From 5814b8445c477e14a7ba6f97d25d6a5b772982ce Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 11 Dec 2019 20:57:42 -0800 Subject: [PATCH 26/41] Amazingly make all tests pass --- arviz/plots/backends/bokeh/traceplot.py | 8 +++- arviz/plots/backends/matplotlib/distplot.py | 45 +++++++++++--------- arviz/plots/backends/matplotlib/traceplot.py | 4 +- arviz/plots/traceplot.py | 4 -- 4 files changed, 33 insertions(+), 28 deletions(-) diff --git a/arviz/plots/backends/bokeh/traceplot.py b/arviz/plots/backends/bokeh/traceplot.py index cada0eec92..af4741ea3f 100644 --- a/arviz/plots/backends/bokeh/traceplot.py +++ b/arviz/plots/backends/bokeh/traceplot.py @@ -11,13 +11,12 @@ from . import BACKEND_KWARG_DEFAULTS from ...distplot import plot_dist -from ...plot_utils import xarray_var_iter, make_label +from ...plot_utils import xarray_var_iter, make_label, _scale_fig_size from ....rcparams import rcParams BACKEND_KWARG_DEFAULTS["tools"] = rcParams["plot.bokeh.tools"] BACKEND_KWARG_DEFAULTS["output_backend"] = rcParams["plot.bokeh.output_backend"] - def plot_trace( data, var_names, @@ -56,6 +55,11 @@ def plot_trace( # Used near end for whether to show plot or not, can't be passed to bkp.figure show = backend_kwargs.pop("show") + figsize, _, _, _, linewidth, _ = _scale_fig_size(figsize, 10, rows=len(plotters), cols=2) + + trace_kwargs.setdefault("line_width", linewidth) + plot_kwargs.setdefault("line_width", linewidth) + axes = [] for i in range(len(plotters)): if i != 0: diff --git a/arviz/plots/backends/matplotlib/distplot.py b/arviz/plots/backends/matplotlib/distplot.py index cf1f5aebad..8e6946a957 100644 --- a/arviz/plots/backends/matplotlib/distplot.py +++ b/arviz/plots/backends/matplotlib/distplot.py @@ -1,4 +1,5 @@ """Matplotlib distplot.""" +import warnings import matplotlib.pyplot as plt from ...kdeplot import plot_kde @@ -6,27 +7,31 @@ def plot_dist( values, - values2=None, - color="C0", - kind="auto", - cumulative=False, - label=None, - rotated=False, - rug=False, - bw=4.5, - quantiles=None, - contour=True, - fill_last=True, - textsize=None, - plot_kwargs=None, - fill_kwargs=None, - rug_kwargs=None, - contour_kwargs=None, - contourf_kwargs=None, - pcolormesh_kwargs=None, - hist_kwargs=None, - ax=None, + values2, + color, + kind, + cumulative, + label, + rotated, + rug, + bw, + quantiles, + contour, + fill_last, + textsize, + plot_kwargs, + fill_kwargs, + rug_kwargs, + contour_kwargs, + contourf_kwargs, + pcolormesh_kwargs, + hist_kwargs, + ax, + backend_kwargs ): + if backend_kwargs is not None: + warnings.warn(("Argument backend_kwargs has not effect in matplotlib.plot_dist" + "Supplied value won't be used")) if ax is None: ax = plt.gca() diff --git a/arviz/plots/backends/matplotlib/traceplot.py b/arviz/plots/backends/matplotlib/traceplot.py index 051f341ccf..d540caa472 100644 --- a/arviz/plots/backends/matplotlib/traceplot.py +++ b/arviz/plots/backends/matplotlib/traceplot.py @@ -161,8 +161,8 @@ def plot_trace( ) # TODO: This is breaking plotting for some reason - # trace_kwargs.setdefault("linewidth", linewidth) - # plot_kwargs.setdefault("linewidth", linewidth) + trace_kwargs.setdefault("linewidth", linewidth) + plot_kwargs.setdefault("linewidth", linewidth) _, axes = plt.subplots( len(plotters), 2, squeeze=False, figsize=figsize, constrained_layout=True diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index 4c8a127c52..ce2190b562 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -188,11 +188,7 @@ def plot_trace( if rug_kwargs is None: rug_kwargs = {} - figsize, _, _, _, linewidth, _ = _scale_fig_size(figsize, 10, rows=len(plotters), cols=2) - # This is where the issue is - trace_kwargs.setdefault("line_width", linewidth) - plot_kwargs.setdefault("line_width", linewidth) # TODO: Check if this can be further simplified trace_plot_args = dict( From 65186dce819eecdee3414d602a52f384ad33291e Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 11 Dec 2019 21:01:29 -0800 Subject: [PATCH 27/41] Add black formatting --- arviz/plots/autocorrplot.py | 2 +- arviz/plots/backends/bokeh/jointplot.py | 3 +- arviz/plots/backends/bokeh/traceplot.py | 7 +++-- arviz/plots/backends/matplotlib/__init__.py | 2 +- arviz/plots/backends/matplotlib/distplot.py | 10 +++++-- arviz/plots/backends/matplotlib/traceplot.py | 4 +-- arviz/plots/densityplot.py | 2 +- arviz/plots/elpdplot.py | 7 +---- arviz/plots/essplot.py | 2 +- arviz/plots/khatplot.py | 2 +- arviz/plots/loopitplot.py | 1 - arviz/plots/mcseplot.py | 2 +- arviz/plots/plot_utils.py | 10 +++---- arviz/plots/posteriorplot.py | 2 +- arviz/plots/ppcplot.py | 2 +- arviz/plots/rankplot.py | 2 +- arviz/plots/traceplot.py | 30 +++++++++----------- arviz/plots/violinplot.py | 8 +++++- arviz/tests/test_plots_bokeh.py | 6 ++-- 19 files changed, 53 insertions(+), 51 deletions(-) diff --git a/arviz/plots/autocorrplot.py b/arviz/plots/autocorrplot.py index 656a6f486f..d15cc40199 100644 --- a/arviz/plots/autocorrplot.py +++ b/arviz/plots/autocorrplot.py @@ -8,7 +8,7 @@ xarray_var_iter, _create_axes_grid, filter_plotters_list, - get_plotting_method + get_plotting_method, ) from ..utils import _var_names diff --git a/arviz/plots/backends/bokeh/jointplot.py b/arviz/plots/backends/bokeh/jointplot.py index 40c9051181..d87276e947 100644 --- a/arviz/plots/backends/bokeh/jointplot.py +++ b/arviz/plots/backends/bokeh/jointplot.py @@ -92,9 +92,8 @@ def plot_joint( rotated=rotate, ax=ax_, backend="bokeh", - backend_kwargs={"show":False}, + backend_kwargs={"show": False}, **marginal_kwargs - ) if show: diff --git a/arviz/plots/backends/bokeh/traceplot.py b/arviz/plots/backends/bokeh/traceplot.py index af4741ea3f..483ab5df18 100644 --- a/arviz/plots/backends/bokeh/traceplot.py +++ b/arviz/plots/backends/bokeh/traceplot.py @@ -17,6 +17,7 @@ BACKEND_KWARG_DEFAULTS["tools"] = rcParams["plot.bokeh.tools"] BACKEND_KWARG_DEFAULTS["output_backend"] = rcParams["plot.bokeh.output_backend"] + def plot_trace( data, var_names, @@ -34,7 +35,7 @@ def plot_trace( plotters, divergence_data, colors, - backend_kwargs: [Dict] + backend_kwargs: [Dict], ): # If divergences are plotted they must be provided @@ -301,7 +302,7 @@ def _plot_chains_bokeh( fill_kwargs=fill_kwargs, rug_kwargs=rug_kwargs, backend="bokeh", - backend_kwargs={"show":False} + backend_kwargs={"show": False}, ) if combined: @@ -318,5 +319,5 @@ def _plot_chains_bokeh( fill_kwargs=fill_kwargs, rug_kwargs=rug_kwargs, backend="bokeh", - backend_kwargs={"show":False} + backend_kwargs={"show": False}, ) diff --git a/arviz/plots/backends/matplotlib/__init__.py b/arviz/plots/backends/matplotlib/__init__.py index 879d415d05..18fe8cff50 100644 --- a/arviz/plots/backends/matplotlib/__init__.py +++ b/arviz/plots/backends/matplotlib/__init__.py @@ -21,4 +21,4 @@ from .ppcplot import plot_ppc from .rankplot import plot_rank from .traceplot import plot_trace -from .violinplot import plot_violin \ No newline at end of file +from .violinplot import plot_violin diff --git a/arviz/plots/backends/matplotlib/distplot.py b/arviz/plots/backends/matplotlib/distplot.py index 8e6946a957..bc7ea0bc6c 100644 --- a/arviz/plots/backends/matplotlib/distplot.py +++ b/arviz/plots/backends/matplotlib/distplot.py @@ -27,11 +27,15 @@ def plot_dist( pcolormesh_kwargs, hist_kwargs, ax, - backend_kwargs + backend_kwargs, ): if backend_kwargs is not None: - warnings.warn(("Argument backend_kwargs has not effect in matplotlib.plot_dist" - "Supplied value won't be used")) + warnings.warn( + ( + "Argument backend_kwargs has not effect in matplotlib.plot_dist" + "Supplied value won't be used" + ) + ) if ax is None: ax = plt.gca() diff --git a/arviz/plots/backends/matplotlib/traceplot.py b/arviz/plots/backends/matplotlib/traceplot.py index d540caa472..bff5674429 100644 --- a/arviz/plots/backends/matplotlib/traceplot.py +++ b/arviz/plots/backends/matplotlib/traceplot.py @@ -17,6 +17,7 @@ # TODO: Change this to RcParams BACKEND_KWARG_DEFAULTS["textsize"] = 10 + def plot_trace( data, var_names, @@ -34,8 +35,7 @@ def plot_trace( plotters, divergence_data, colors, - backend_kwargs - + backend_kwargs, ): """Plot distribution (histogram or kernel density estimates) and sampled values. diff --git a/arviz/plots/densityplot.py b/arviz/plots/densityplot.py index 416551566a..1fd6b49574 100644 --- a/arviz/plots/densityplot.py +++ b/arviz/plots/densityplot.py @@ -11,7 +11,7 @@ xarray_var_iter, default_grid, _create_axes_grid, - get_plotting_method + get_plotting_method, ) from ..utils import _var_names from ..rcparams import rcParams diff --git a/arviz/plots/elpdplot.py b/arviz/plots/elpdplot.py index 867030e441..2b775abbee 100644 --- a/arviz/plots/elpdplot.py +++ b/arviz/plots/elpdplot.py @@ -5,12 +5,7 @@ from matplotlib.lines import Line2D from ..data import convert_to_inference_data -from .plot_utils import ( - get_coords, - format_coords_as_labels, - color_from_dim, - get_plotting_method -) +from .plot_utils import get_coords, format_coords_as_labels, color_from_dim, get_plotting_method from ..stats import waic, loo, ELPDData from ..rcparams import rcParams diff --git a/arviz/plots/essplot.py b/arviz/plots/essplot.py index 985df10ed9..371dde1534 100644 --- a/arviz/plots/essplot.py +++ b/arviz/plots/essplot.py @@ -10,7 +10,7 @@ default_grid, get_coords, filter_plotters_list, - get_plotting_method + get_plotting_method, ) from ..utils import _var_names diff --git a/arviz/plots/khatplot.py b/arviz/plots/khatplot.py index 610c863bbf..26b8425a20 100644 --- a/arviz/plots/khatplot.py +++ b/arviz/plots/khatplot.py @@ -10,7 +10,7 @@ get_coords, color_from_dim, format_coords_as_labels, - get_plotting_method + get_plotting_method, ) from ..stats import ELPDData diff --git a/arviz/plots/loopitplot.py b/arviz/plots/loopitplot.py index 065b670b93..42e5c84758 100644 --- a/arviz/plots/loopitplot.py +++ b/arviz/plots/loopitplot.py @@ -247,7 +247,6 @@ def plot_loo_pit( loo_pit_kwargs.pop("credible_interval") loo_pit_kwargs["show"] = show - # TODO: Add backend kwargs method = get_plotting_method("plot_loo_pit", "loopitplot", backend, {}) axes = method(**loo_pit_kwargs) diff --git a/arviz/plots/mcseplot.py b/arviz/plots/mcseplot.py index 0c659f86b1..2a4d95d040 100644 --- a/arviz/plots/mcseplot.py +++ b/arviz/plots/mcseplot.py @@ -10,7 +10,7 @@ default_grid, get_coords, filter_plotters_list, - get_plotting_method + get_plotting_method, ) from ..utils import _var_names diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 27f8d89fab..5cccd88061 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -646,11 +646,13 @@ def get_plotting_method(plot_name, plot_module, backend, user_backend_kwargs): if backend == "bokeh": try: import bokeh + assert packaging.version.parse(bokeh.__version__) >= packaging.version.parse("1.4.0") except (ImportError, AssertionError): - raise ImportError("'bokeh' backend needs Bokeh (1.4.0+) installed." - " Please upgrade or install") + raise ImportError( + "'bokeh' backend needs Bokeh (1.4.0+) installed." " Please upgrade or install" + ) # Perform import of plotting method # TODO: Convert module import to top level for all plots @@ -660,8 +662,6 @@ def get_plotting_method(plot_name, plot_module, backend, user_backend_kwargs): ) ) - plotting_method = getattr( - module, plot_name - ) + plotting_method = getattr(module, plot_name) return plotting_method diff --git a/arviz/plots/posteriorplot.py b/arviz/plots/posteriorplot.py index e5fc5195f4..c97eb0507a 100644 --- a/arviz/plots/posteriorplot.py +++ b/arviz/plots/posteriorplot.py @@ -8,7 +8,7 @@ default_grid, get_coords, filter_plotters_list, - get_plotting_method + get_plotting_method, ) from ..utils import _var_names diff --git a/arviz/plots/ppcplot.py b/arviz/plots/ppcplot.py index b47d964eea..72d62e2a60 100644 --- a/arviz/plots/ppcplot.py +++ b/arviz/plots/ppcplot.py @@ -9,7 +9,7 @@ _scale_fig_size, default_grid, filter_plotters_list, - get_plotting_method + get_plotting_method, ) from ..utils import _var_names diff --git a/arviz/plots/rankplot.py b/arviz/plots/rankplot.py index a4ab971ebb..8a7e1c84b7 100644 --- a/arviz/plots/rankplot.py +++ b/arviz/plots/rankplot.py @@ -9,7 +9,7 @@ default_grid, filter_plotters_list, _sturges_formula, - get_plotting_method + get_plotting_method, ) from ..utils import _var_names diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index ce2190b562..5277eb7507 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -125,7 +125,7 @@ def plot_trace( try: divergence_data = convert_to_dataset(data, group="sample_stats").diverging except (ValueError, AttributeError): # No sample_stats, or no `.diverging` - divergences=False + divergences = False if coords is None: coords = {} @@ -188,37 +188,33 @@ def plot_trace( if rug_kwargs is None: rug_kwargs = {} - - # TODO: Check if this can be further simplified trace_plot_args = dict( # User Kwargs data=data, var_names=var_names, # coords = coords, - divergences = divergences, - figsize = figsize, + divergences=divergences, + figsize=figsize, rug=rug, - lines = lines, - plot_kwargs = plot_kwargs, - fill_kwargs = fill_kwargs, - rug_kwargs = rug_kwargs, - hist_kwargs = hist_kwargs, - trace_kwargs = trace_kwargs, + lines=lines, + plot_kwargs=plot_kwargs, + fill_kwargs=fill_kwargs, + rug_kwargs=rug_kwargs, + hist_kwargs=hist_kwargs, + trace_kwargs=trace_kwargs, # compact = compact, - combined = combined, - legend = legend, - + combined=combined, + legend=legend, # Generated kwargs - divergence_data = divergence_data, + divergence_data=divergence_data, # skip_dims=skip_dims, plotters=plotters, colors=colors, - backend_kwargs=backend_kwargs + backend_kwargs=backend_kwargs, ) method = get_plotting_method("plot_trace", "traceplot", backend, backend_kwargs) axes = method(**trace_plot_args) return axes - diff --git a/arviz/plots/violinplot.py b/arviz/plots/violinplot.py index bb41a0426a..9ecc062651 100644 --- a/arviz/plots/violinplot.py +++ b/arviz/plots/violinplot.py @@ -1,6 +1,12 @@ """Plot posterior traces as violin plot.""" from ..data import convert_to_dataset -from .plot_utils import _scale_fig_size, xarray_var_iter, filter_plotters_list, default_grid, get_plotting_method +from .plot_utils import ( + _scale_fig_size, + xarray_var_iter, + filter_plotters_list, + default_grid, + get_plotting_method, +) from ..utils import _var_names diff --git a/arviz/tests/test_plots_bokeh.py b/arviz/tests/test_plots_bokeh.py index 4743431de8..2b9ecd7746 100644 --- a/arviz/tests/test_plots_bokeh.py +++ b/arviz/tests/test_plots_bokeh.py @@ -76,7 +76,7 @@ def get_ax(): @pytest.fixture(scope="session") def backend_kwargs(): - return {"show":False} + return {"show": False} @pytest.mark.parametrize( @@ -189,7 +189,9 @@ def test_plot_kde_cumulative(continuous_model, kwargs): @pytest.mark.parametrize("kwargs", [{"kind": "hist"}, {"kind": "kde"}]) def test_plot_dist(continuous_model, kwargs, backend_kwargs): - axes = plot_dist(continuous_model["x"], backend="bokeh", backend_kwargs=backend_kwargs, **kwargs) + axes = plot_dist( + continuous_model["x"], backend="bokeh", backend_kwargs=backend_kwargs, **kwargs + ) assert axes From 0b3fabd6586d03c1914daa0f2c113439aefb6ebc Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 11 Dec 2019 21:15:00 -0800 Subject: [PATCH 28/41] Fix broken tests --- arviz/tests/test_plot_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arviz/tests/test_plot_utils.py b/arviz/tests/test_plot_utils.py index 1f887b78bf..f4c1d8c886 100644 --- a/arviz/tests/test_plot_utils.py +++ b/arviz/tests/test_plot_utils.py @@ -197,6 +197,6 @@ def test_bokeh_import(): """Tests that correct method is returned on bokeh import""" method = get_plotting_method("plot_dist", "distplot", "bokeh") - from arviz.plots.backends.bokeh.bokeh_distplot import _plot_dist_bokeh + from arviz.plots.backends.bokeh.distplot import plot_dist - assert method is _plot_dist_bokeh + assert method is plot_dist From 7b69a0be5a6fa1c30c03acc5010caa46b6ec4c1a Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 11 Dec 2019 21:28:12 -0800 Subject: [PATCH 29/41] Various fixes to tests and linting --- arviz/plots/autocorrplot.py | 2 +- arviz/plots/backends/bokeh/__init__.py | 2 +- arviz/plots/backends/bokeh/distplot.py | 2 +- arviz/plots/backends/bokeh/posteriorplot.py | 1 - arviz/plots/backends/matplotlib/__init__.py | 1 + arviz/plots/backends/matplotlib/traceplot.py | 44 +------------------- arviz/plots/compareplot.py | 2 +- arviz/plots/densityplot.py | 2 +- arviz/plots/distplot.py | 2 +- arviz/plots/elpdplot.py | 2 +- arviz/plots/energyplot.py | 2 +- arviz/plots/essplot.py | 2 +- arviz/plots/forestplot.py | 2 +- arviz/plots/hpdplot.py | 2 +- arviz/plots/jointplot.py | 2 +- arviz/plots/kdeplot.py | 2 +- arviz/plots/khatplot.py | 2 +- arviz/plots/loopitplot.py | 2 +- arviz/plots/mcseplot.py | 2 +- arviz/plots/pairplot.py | 2 +- arviz/plots/parallelplot.py | 2 +- arviz/plots/plot_utils.py | 2 +- arviz/plots/posteriorplot.py | 2 +- arviz/plots/ppcplot.py | 2 +- arviz/plots/rankplot.py | 2 +- arviz/plots/traceplot.py | 7 +--- arviz/plots/violinplot.py | 2 +- arviz/tests/test_plots_bokeh.py | 4 +- 28 files changed, 29 insertions(+), 74 deletions(-) diff --git a/arviz/plots/autocorrplot.py b/arviz/plots/autocorrplot.py index d15cc40199..0fc2cceff5 100644 --- a/arviz/plots/autocorrplot.py +++ b/arviz/plots/autocorrplot.py @@ -143,7 +143,7 @@ def plot_autocorr( autocorr_plot_args["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_autocorr", "autocorrplot", backend, {}) + method = get_plotting_method("plot_autocorr", "autocorrplot", backend) axes = method(**autocorr_plot_args) return axes diff --git a/arviz/plots/backends/bokeh/__init__.py b/arviz/plots/backends/bokeh/__init__.py index 234c835d8d..d7abe8afc7 100644 --- a/arviz/plots/backends/bokeh/__init__.py +++ b/arviz/plots/backends/bokeh/__init__.py @@ -1,4 +1,4 @@ -# pylint: disable=no-member,invalid-name,redefined-outer-name +# pylint: disable=no-member,invalid-name,redefined-outer-name, wrong-import-position """Bokeh Plotting Backend.""" import packaging diff --git a/arviz/plots/backends/bokeh/distplot.py b/arviz/plots/backends/bokeh/distplot.py index 32ba626764..49d4dcec2d 100644 --- a/arviz/plots/backends/bokeh/distplot.py +++ b/arviz/plots/backends/bokeh/distplot.py @@ -30,7 +30,7 @@ def plot_dist( hist_kwargs=None, ax=None, backend_kwargs=None, - **kwargs + **kwargs # pylint: disable=unused-argument ): if backend_kwargs is None: diff --git a/arviz/plots/backends/bokeh/posteriorplot.py b/arviz/plots/backends/bokeh/posteriorplot.py index cc7963bb67..c7df820c7e 100644 --- a/arviz/plots/backends/bokeh/posteriorplot.py +++ b/arviz/plots/backends/bokeh/posteriorplot.py @@ -1,5 +1,4 @@ """Bokeh Plot posterior densities.""" -from typing import Optional from numbers import Number from typing import Optional diff --git a/arviz/plots/backends/matplotlib/__init__.py b/arviz/plots/backends/matplotlib/__init__.py index 18fe8cff50..0672d43631 100644 --- a/arviz/plots/backends/matplotlib/__init__.py +++ b/arviz/plots/backends/matplotlib/__init__.py @@ -1,3 +1,4 @@ +# pylint: disable= wrong-import-position """Matplotlib Plotting Backend.""" BACKEND_KWARG_DEFAULTS = {} diff --git a/arviz/plots/backends/matplotlib/traceplot.py b/arviz/plots/backends/matplotlib/traceplot.py index bff5674429..3b285b3104 100644 --- a/arviz/plots/backends/matplotlib/traceplot.py +++ b/arviz/plots/backends/matplotlib/traceplot.py @@ -1,6 +1,4 @@ """Matplotlib Traceplot.""" -from itertools import cycle -import warnings import matplotlib.pyplot as plt from matplotlib.lines import Line2D @@ -8,11 +6,8 @@ from . import BACKEND_KWARG_DEFAULTS -from ....data import convert_to_dataset from ...distplot import plot_dist -from ...plot_utils import _scale_fig_size, get_bins, xarray_var_iter, make_label, get_coords -from ....utils import _var_names -from ....rcparams import rcParams +from ...plot_utils import _scale_fig_size, get_bins, make_label # TODO: Change this to RcParams BACKEND_KWARG_DEFAULTS["textsize"] = 10 @@ -20,7 +15,7 @@ def plot_trace( data, - var_names, + var_names, # pylint: disable=unused-argument divergences, figsize, rug, @@ -120,41 +115,6 @@ def plot_trace( backend_kwargs = {} backend_kwargs = {**BACKEND_KWARG_DEFAULTS, **backend_kwargs} - # if compact: - # skip_dims = set(data.dims) - {"chain", "draw"} - # else: - # skip_dims = set() - # - # plotters = list(xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims)) - # max_plots = rcParams["plot.max_subplots"] - # max_plots = len(plotters) if max_plots is None else max_plots - # if len(plotters) > max_plots: - # warnings.warn( - # "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number " - # "of variables to plot ({len_plotters}), generating only {max_plots} " - # "plots".format(max_plots=max_plots, len_plotters=len(plotters)), - # SyntaxWarning, - # ) - # plotters = plotters[:max_plots] - - # if figsize is None: - # figsize = (12, len(plotters) * 2) - # - # if trace_kwargs is None: - # trace_kwargs = {} - # - # trace_kwargs.setdefault("alpha", 0.35) - - # if hist_kwargs is None: - # hist_kwargs = {} - # if plot_kwargs is None: - # plot_kwargs = {} - # if fill_kwargs is None: - # fill_kwargs = {} - # if rug_kwargs is None: - # rug_kwargs = {} - # - # hist_kwargs.setdefault("alpha", 0.35) figsize, _, titlesize, xt_labelsize, linewidth, _ = _scale_fig_size( figsize, backend_kwargs["textsize"], rows=len(plotters), cols=2 diff --git a/arviz/plots/compareplot.py b/arviz/plots/compareplot.py index 91b7416a77..4b04bb3119 100644 --- a/arviz/plots/compareplot.py +++ b/arviz/plots/compareplot.py @@ -136,7 +136,7 @@ def plot_compare( compareplot_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_compare", "compareplot", backend, {}) + method = get_plotting_method("plot_compare", "compareplot", backend) ax = method(**compareplot_kwargs) return ax diff --git a/arviz/plots/densityplot.py b/arviz/plots/densityplot.py index 1fd6b49574..9e39ad1e3e 100644 --- a/arviz/plots/densityplot.py +++ b/arviz/plots/densityplot.py @@ -246,6 +246,6 @@ def plot_density( plot_density_kwargs.pop("n_data") # TODO: Add backend kwargs - method = get_plotting_method("plot_density", "densityplot", backend, {}) + method = get_plotting_method("plot_density", "densityplot", backend) ax = method(**plot_density_kwargs) return ax diff --git a/arviz/plots/distplot.py b/arviz/plots/distplot.py index 020cf0f408..f3abc46451 100644 --- a/arviz/plots/distplot.py +++ b/arviz/plots/distplot.py @@ -188,6 +188,6 @@ def plot_dist( **kwargs, ) - method = get_plotting_method("plot_dist", "distplot", backend, backend_kwargs) + method = get_plotting_method("plot_dist", "distplot", backend) ax = method(**dist_plot_args) return ax diff --git a/arviz/plots/elpdplot.py b/arviz/plots/elpdplot.py index 2b775abbee..1f946479de 100644 --- a/arviz/plots/elpdplot.py +++ b/arviz/plots/elpdplot.py @@ -208,6 +208,6 @@ def plot_elpd( elpd_plot_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_elpd", "elpdplot", backend, {}) + method = get_plotting_method("plot_elpd", "elpdplot", backend) ax = method(**elpd_plot_kwargs) return ax diff --git a/arviz/plots/energyplot.py b/arviz/plots/energyplot.py index 246661dc83..9fa9558264 100644 --- a/arviz/plots/energyplot.py +++ b/arviz/plots/energyplot.py @@ -139,6 +139,6 @@ def plot_energy( plot_energy_kwargs["legend"] = False # TODO: Add backend kwargs - method = get_plotting_method("plot_energy", "energyplot", backend, {}) + method = get_plotting_method("plot_energy", "energyplot", backend) ax = method(**plot_energy_kwargs) return ax diff --git a/arviz/plots/essplot.py b/arviz/plots/essplot.py index 371dde1534..dd3f2953f7 100644 --- a/arviz/plots/essplot.py +++ b/arviz/plots/essplot.py @@ -320,6 +320,6 @@ def plot_ess( essplot_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_ess", "essplot", backend, {}) + method = get_plotting_method("plot_ess", "essplot", backend) ax = method(**essplot_kwargs) return ax diff --git a/arviz/plots/forestplot.py b/arviz/plots/forestplot.py index 7d08ce60c5..89dcf77dd5 100644 --- a/arviz/plots/forestplot.py +++ b/arviz/plots/forestplot.py @@ -178,6 +178,6 @@ def plot_forest( plot_forest_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_forest", "forestplot", backend, {}) + method = get_plotting_method("plot_forest", "forestplot", backend) axes = method(**plot_forest_kwargs) return axes diff --git a/arviz/plots/hpdplot.py b/arviz/plots/hpdplot.py index 9c85f04f92..e7eda33547 100644 --- a/arviz/plots/hpdplot.py +++ b/arviz/plots/hpdplot.py @@ -107,6 +107,6 @@ def plot_hpd( hpdplot_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_hpd", "hpdplot", backend, {}) + method = get_plotting_method("plot_hpd", "hpdplot", backend) ax = method(**hpdplot_kwargs) return ax diff --git a/arviz/plots/jointplot.py b/arviz/plots/jointplot.py index 646c9b1378..879c2d30e9 100644 --- a/arviz/plots/jointplot.py +++ b/arviz/plots/jointplot.py @@ -173,6 +173,6 @@ def plot_joint( plot_joint_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_joint", "jointplot", backend, {}) + method = get_plotting_method("plot_joint", "jointplot", backend) axes = method(**plot_joint_kwargs) return axes diff --git a/arviz/plots/kdeplot.py b/arviz/plots/kdeplot.py index 3dbcf5afad..80f8c36727 100644 --- a/arviz/plots/kdeplot.py +++ b/arviz/plots/kdeplot.py @@ -218,7 +218,7 @@ def plot_kde( kde_plot_args.pop("textsize") # TODO: Add backend kwargs - method = get_plotting_method("plot_kde", "kdeplot", backend, {}) + method = get_plotting_method("plot_kde", "kdeplot", backend) ax = method(**kde_plot_args) return ax diff --git a/arviz/plots/khatplot.py b/arviz/plots/khatplot.py index 26b8425a20..9e309881da 100644 --- a/arviz/plots/khatplot.py +++ b/arviz/plots/khatplot.py @@ -235,6 +235,6 @@ def plot_khat( plot_khat_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_khat", "khatplot", backend, {}) + method = get_plotting_method("plot_khat", "khatplot", backend) axes = method(**plot_khat_kwargs) return axes diff --git a/arviz/plots/loopitplot.py b/arviz/plots/loopitplot.py index 42e5c84758..b282728cbf 100644 --- a/arviz/plots/loopitplot.py +++ b/arviz/plots/loopitplot.py @@ -248,7 +248,7 @@ def plot_loo_pit( loo_pit_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_loo_pit", "loopitplot", backend, {}) + method = get_plotting_method("plot_loo_pit", "loopitplot", backend) axes = method(**loo_pit_kwargs) return axes diff --git a/arviz/plots/mcseplot.py b/arviz/plots/mcseplot.py index 2a4d95d040..46b11584b5 100644 --- a/arviz/plots/mcseplot.py +++ b/arviz/plots/mcseplot.py @@ -192,6 +192,6 @@ def plot_mcse( mcse_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_mcse", "mcseplot", backend, {}) + method = get_plotting_method("plot_mcse", "mcseplot", backend) ax = method(**mcse_kwargs) return ax diff --git a/arviz/plots/pairplot.py b/arviz/plots/pairplot.py index 78e27745be..4ec8e8fc56 100644 --- a/arviz/plots/pairplot.py +++ b/arviz/plots/pairplot.py @@ -198,6 +198,6 @@ def plot_pair( pairplot_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_pair", "pairplot", backend, {}) + method = get_plotting_method("plot_pair", "pairplot", backend) ax = method(**pairplot_kwargs) return ax diff --git a/arviz/plots/parallelplot.py b/arviz/plots/parallelplot.py index 5f0b176a71..2efb3f0f87 100644 --- a/arviz/plots/parallelplot.py +++ b/arviz/plots/parallelplot.py @@ -148,7 +148,7 @@ def plot_parallel( parallel_kwargs.pop("shadend") # TODO: Add backend kwargs - method = get_plotting_method("plot_parallel", "parallelplot", backend, {}) + method = get_plotting_method("plot_parallel", "parallelplot", backend) ax = method(**parallel_kwargs) return ax diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 5cccd88061..ec089cbd4a 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -625,7 +625,7 @@ def filter_plotters_list(plotters, plot_kind): return plotters -def get_plotting_method(plot_name, plot_module, backend, user_backend_kwargs): +def get_plotting_method(plot_name, plot_module, backend): """Returns plotting function for correct backend""" _backend = { "mpl": "matplotlib", diff --git a/arviz/plots/posteriorplot.py b/arviz/plots/posteriorplot.py index c97eb0507a..3b1ee09cbf 100644 --- a/arviz/plots/posteriorplot.py +++ b/arviz/plots/posteriorplot.py @@ -217,6 +217,6 @@ def plot_posterior( posteriorplot_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_posterior", "posteriorplot", backend, {}) + method = get_plotting_method("plot_posterior", "posteriorplot", backend) ax = method(**posteriorplot_kwargs) return ax diff --git a/arviz/plots/ppcplot.py b/arviz/plots/ppcplot.py index 72d62e2a60..29c490c6e1 100644 --- a/arviz/plots/ppcplot.py +++ b/arviz/plots/ppcplot.py @@ -300,6 +300,6 @@ def plot_ppc( ppcplot_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_ppc", "ppcplot", backend, {}) + method = get_plotting_method("plot_ppc", "ppcplot", backend) axes = method(**ppcplot_kwargs) return axes diff --git a/arviz/plots/rankplot.py b/arviz/plots/rankplot.py index 8a7e1c84b7..1235611a3f 100644 --- a/arviz/plots/rankplot.py +++ b/arviz/plots/rankplot.py @@ -166,6 +166,6 @@ def plot_rank( rankplot_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_rank", "rankplot", backend, {}) + method = get_plotting_method("plot_rank", "rankplot", backend) axes = method(**rankplot_kwargs) return axes diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index 5277eb7507..8841135288 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -16,7 +16,6 @@ def plot_trace( coords=None, divergences="bottom", figsize=None, - textsize=None, rug=False, lines=None, compact=False, @@ -29,7 +28,6 @@ def plot_trace( trace_kwargs=None, backend=None, backend_kwargs=None, - **kwargs ): """Plot distribution (histogram or kernel density estimates) and sampled values. @@ -49,9 +47,6 @@ def plot_trace( Plot location of divergences on the traceplots. Options are "bottom", "top", or False-y. figsize : figure size tuple If None, size is (12, variables * 2) - textsize: float - Text size scaling factor for labels, titles and lines. If None it will be autoscaled based - on figsize. Not implemented for bokeh backend. rug : bool If True adds a rugplot. Defaults to False. Ignored for 2D KDE. Only affects continuous variables. @@ -214,7 +209,7 @@ def plot_trace( backend_kwargs=backend_kwargs, ) - method = get_plotting_method("plot_trace", "traceplot", backend, backend_kwargs) + method = get_plotting_method("plot_trace", "traceplot", backend) axes = method(**trace_plot_args) return axes diff --git a/arviz/plots/violinplot.py b/arviz/plots/violinplot.py index 9ecc062651..3239d926d2 100644 --- a/arviz/plots/violinplot.py +++ b/arviz/plots/violinplot.py @@ -111,6 +111,6 @@ def plot_violin( violinplot_kwargs["show"] = show # TODO: Add backend kwargs - method = get_plotting_method("plot_violin", "violinplot", backend, {}) + method = get_plotting_method("plot_violin", "violinplot", backend) ax = method(**violinplot_kwargs) return ax diff --git a/arviz/tests/test_plots_bokeh.py b/arviz/tests/test_plots_bokeh.py index 2b9ecd7746..9ab53845c5 100644 --- a/arviz/tests/test_plots_bokeh.py +++ b/arviz/tests/test_plots_bokeh.py @@ -141,14 +141,14 @@ def test_plot_trace(models, kwargs, backend_kwargs): def test_plot_trace_discrete(discrete_model): - axes = plot_trace(discrete_model, backend="bokeh", show=False) + axes = plot_trace(discrete_model, backend="bokeh") assert axes.shape def test_plot_trace_max_subplots_warning(models): with pytest.warns(SyntaxWarning): with rc_context(rc={"plot.max_subplots": 1}): - axes = plot_trace(models.model_1, backend="bokeh", show=False) + axes = plot_trace(models.model_1, backend="bokeh") assert axes.shape From 6606d5cbec608b732dab3506d2eadb6b2d656f29 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Thu, 12 Dec 2019 20:49:19 -0800 Subject: [PATCH 30/41] Fix docstrings --- arviz/plots/backends/bokeh/autocorrplot.py | 1 + arviz/plots/backends/bokeh/compareplot.py | 2 +- arviz/plots/backends/bokeh/densityplot.py | 1 + arviz/plots/backends/bokeh/distplot.py | 1 + arviz/plots/backends/bokeh/elpdplot.py | 1 + arviz/plots/backends/bokeh/energyplot.py | 1 + arviz/plots/backends/bokeh/essplot.py | 1 + arviz/plots/backends/bokeh/forestplot.py | 1 + arviz/plots/backends/bokeh/hpdplot.py | 2 ++ arviz/plots/backends/bokeh/jointplot.py | 1 + arviz/plots/backends/bokeh/kdeplot.py | 1 + arviz/plots/backends/bokeh/khatplot.py | 1 + arviz/plots/backends/bokeh/loopitplot.py | 2 +- arviz/plots/backends/bokeh/mcseplot.py | 1 + arviz/plots/backends/bokeh/pairplot.py | 1 + arviz/plots/backends/bokeh/parallelplot.py | 1 + arviz/plots/backends/bokeh/posteriorplot.py | 1 + arviz/plots/backends/bokeh/ppcplot.py | 1 + arviz/plots/backends/bokeh/rankplot.py | 2 +- arviz/plots/backends/bokeh/traceplot.py | 2 +- arviz/plots/backends/bokeh/violinplot.py | 2 +- arviz/plots/backends/matplotlib/autocorrplot.py | 1 + arviz/plots/backends/matplotlib/compareplot.py | 1 + arviz/plots/backends/matplotlib/densityplot.py | 1 + arviz/plots/backends/matplotlib/distplot.py | 1 + arviz/plots/backends/matplotlib/elpdplot.py | 1 + arviz/plots/backends/matplotlib/energyplot.py | 1 + arviz/plots/backends/matplotlib/essplot.py | 1 + arviz/plots/backends/matplotlib/forestplot.py | 1 + arviz/plots/backends/matplotlib/hpdplot.py | 1 + arviz/plots/backends/matplotlib/jointplot.py | 1 + arviz/plots/backends/matplotlib/kdeplot.py | 1 + arviz/plots/backends/matplotlib/khatplot.py | 1 + arviz/plots/backends/matplotlib/loopitplot.py | 2 +- arviz/plots/backends/matplotlib/mcseplot.py | 1 + arviz/plots/backends/matplotlib/pairplot.py | 1 + arviz/plots/backends/matplotlib/parallelplot.py | 1 + arviz/plots/backends/matplotlib/posteriorplot.py | 1 + arviz/plots/backends/matplotlib/ppcplot.py | 1 + arviz/plots/backends/matplotlib/rankplot.py | 2 +- arviz/plots/backends/matplotlib/traceplot.py | 4 ++-- arviz/plots/backends/matplotlib/violinplot.py | 1 + arviz/plots/plot_utils.py | 2 +- arviz/plots/traceplot.py | 2 -- arviz/tests/test_plots_bokeh.py | 8 ++++---- 45 files changed, 49 insertions(+), 16 deletions(-) diff --git a/arviz/plots/backends/bokeh/autocorrplot.py b/arviz/plots/backends/bokeh/autocorrplot.py index 4c2acbb45b..2bec597404 100644 --- a/arviz/plots/backends/bokeh/autocorrplot.py +++ b/arviz/plots/backends/bokeh/autocorrplot.py @@ -11,6 +11,7 @@ def plot_autocorr( axes, plotters, max_lag, line_width, combined=False, show=True, ): + """Bokeh autocorrelation plot.""" for (var_name, selection, x), ax_ in zip(plotters, axes.flatten()): x_prime = x if combined: diff --git a/arviz/plots/backends/bokeh/compareplot.py b/arviz/plots/backends/bokeh/compareplot.py index 9d0ce2ff90..9925c0000f 100644 --- a/arviz/plots/backends/bokeh/compareplot.py +++ b/arviz/plots/backends/bokeh/compareplot.py @@ -20,7 +20,7 @@ def plot_compare( step, show, ): - + """Bokeh compareplot.""" if ax is None: tools = rcParams["plot.bokeh.tools"] output_backend = rcParams["plot.bokeh.output_backend"] diff --git a/arviz/plots/backends/bokeh/densityplot.py b/arviz/plots/backends/bokeh/densityplot.py index e1f579eabf..228d5beba7 100644 --- a/arviz/plots/backends/bokeh/densityplot.py +++ b/arviz/plots/backends/bokeh/densityplot.py @@ -26,6 +26,7 @@ def plot_density( data_labels, show, ): + """Bokeh density plot.""" axis_map = {label: ax_ for label, ax_ in zip(all_labels, ax.flatten())} if data_labels is None: data_labels = {} diff --git a/arviz/plots/backends/bokeh/distplot.py b/arviz/plots/backends/bokeh/distplot.py index 49d4dcec2d..07c6db66b1 100644 --- a/arviz/plots/backends/bokeh/distplot.py +++ b/arviz/plots/backends/bokeh/distplot.py @@ -32,6 +32,7 @@ def plot_dist( backend_kwargs=None, **kwargs # pylint: disable=unused-argument ): + """Bokeh distplot.""" if backend_kwargs is None: backend_kwargs = {} diff --git a/arviz/plots/backends/bokeh/elpdplot.py b/arviz/plots/backends/bokeh/elpdplot.py index e535c4235a..3387f47c42 100644 --- a/arviz/plots/backends/bokeh/elpdplot.py +++ b/arviz/plots/backends/bokeh/elpdplot.py @@ -25,6 +25,7 @@ def plot_elpd( threshold, show, ): + """Bokeh elpd plot.""" if numvars == 2: (figsize, _, _, _, _, markersize) = _scale_fig_size( figsize, textsize, numvars - 1, numvars - 1 diff --git a/arviz/plots/backends/bokeh/energyplot.py b/arviz/plots/backends/bokeh/energyplot.py index 3c62362ee5..bdb046254d 100644 --- a/arviz/plots/backends/bokeh/energyplot.py +++ b/arviz/plots/backends/bokeh/energyplot.py @@ -11,6 +11,7 @@ def plot_energy( ax, series, energy, kind, bfmi, figsize, line_width, fill_kwargs, plot_kwargs, bw, legend, show, ): + """Bokeh energy plot.""" if ax is None: tools = rcParams["plot.bokeh.tools"] output_backend = rcParams["plot.bokeh.output_backend"] diff --git a/arviz/plots/backends/bokeh/essplot.py b/arviz/plots/backends/bokeh/essplot.py index cd4a375412..cfba9b76b0 100644 --- a/arviz/plots/backends/bokeh/essplot.py +++ b/arviz/plots/backends/bokeh/essplot.py @@ -47,6 +47,7 @@ def plot_ess( hline_kwargs, show, ): + """Bokeh essplot.""" if ax is None: _, ax = _create_axes_grid( len(plotters), diff --git a/arviz/plots/backends/bokeh/forestplot.py b/arviz/plots/backends/bokeh/forestplot.py index 8d38930401..e1f37acfc0 100644 --- a/arviz/plots/backends/bokeh/forestplot.py +++ b/arviz/plots/backends/bokeh/forestplot.py @@ -51,6 +51,7 @@ def plot_forest( r_hat, show, ): + """Bokeh forest plot.""" plot_handler = PlotHandler( datasets, var_names=var_names, model_names=model_names, combined=combined, colors=colors ) diff --git a/arviz/plots/backends/bokeh/hpdplot.py b/arviz/plots/backends/bokeh/hpdplot.py index b515114e84..d5fd5eb8b5 100644 --- a/arviz/plots/backends/bokeh/hpdplot.py +++ b/arviz/plots/backends/bokeh/hpdplot.py @@ -9,6 +9,8 @@ def plot_hpd(ax, x_data, y_data, plot_kwargs, fill_kwargs, show): + """Bokeh hpd plot.""" + if ax is None: tools = rcParams["plot.bokeh.tools"] output_backend = rcParams["plot.bokeh.output_backend"] diff --git a/arviz/plots/backends/bokeh/jointplot.py b/arviz/plots/backends/bokeh/jointplot.py index d87276e947..0022951a21 100644 --- a/arviz/plots/backends/bokeh/jointplot.py +++ b/arviz/plots/backends/bokeh/jointplot.py @@ -22,6 +22,7 @@ def plot_joint( marginal_kwargs, show, ): + """Bokeh joint plot.""" if ax is None: tools = rcParams["plot.bokeh.tools"] output_backend = rcParams["plot.bokeh.output_backend"] diff --git a/arviz/plots/backends/bokeh/kdeplot.py b/arviz/plots/backends/bokeh/kdeplot.py index d40e1ca20d..33b6b7f95c 100644 --- a/arviz/plots/backends/bokeh/kdeplot.py +++ b/arviz/plots/backends/bokeh/kdeplot.py @@ -42,6 +42,7 @@ def plot_kde( legend=True, show=True, ): + """Bokeh kde plot.""" if ax is None: tools = rcParams["plot.bokeh.tools"] output_backend = rcParams["plot.bokeh.output_backend"] diff --git a/arviz/plots/backends/bokeh/khatplot.py b/arviz/plots/backends/bokeh/khatplot.py index 5560caf744..e44d966fc3 100644 --- a/arviz/plots/backends/bokeh/khatplot.py +++ b/arviz/plots/backends/bokeh/khatplot.py @@ -23,6 +23,7 @@ def plot_khat( bin_format, show, ): + """Bokeh khat plot.""" if ax is None: tools = rcParams["plot.bokeh.tools"] output_backend = rcParams["plot.bokeh.output_backend"] diff --git a/arviz/plots/backends/bokeh/loopitplot.py b/arviz/plots/backends/bokeh/loopitplot.py index e001a1a19b..33c789ad9c 100644 --- a/arviz/plots/backends/bokeh/loopitplot.py +++ b/arviz/plots/backends/bokeh/loopitplot.py @@ -29,7 +29,7 @@ def plot_loo_pit( plot_kwargs, show, ): - + """Bokeh loo pit plot.""" if ax is None: tools = rcParams["plot.bokeh.tools"] output_backend = rcParams["plot.bokeh.output_backend"] diff --git a/arviz/plots/backends/bokeh/mcseplot.py b/arviz/plots/backends/bokeh/mcseplot.py index b2d35be9b2..07f9fd30be 100644 --- a/arviz/plots/backends/bokeh/mcseplot.py +++ b/arviz/plots/backends/bokeh/mcseplot.py @@ -36,6 +36,7 @@ def plot_mcse( _linewidth, show, ): + """Bokeh mcse plot.""" if ax is None: _, ax = _create_axes_grid(length_plotters, rows, cols, figsize=figsize, backend="bokeh") diff --git a/arviz/plots/backends/bokeh/pairplot.py b/arviz/plots/backends/bokeh/pairplot.py index c52319b618..a9bc79d76c 100644 --- a/arviz/plots/backends/bokeh/pairplot.py +++ b/arviz/plots/backends/bokeh/pairplot.py @@ -28,6 +28,7 @@ def plot_pair( flat_var_names, show, ): + """Bokeh pair plot.""" if numvars == 2: (figsize, _, _, _, _, _) = _scale_fig_size(figsize, textsize, numvars - 1, numvars - 1) diff --git a/arviz/plots/backends/bokeh/parallelplot.py b/arviz/plots/backends/bokeh/parallelplot.py index e28133ad1a..d171e7b962 100644 --- a/arviz/plots/backends/bokeh/parallelplot.py +++ b/arviz/plots/backends/bokeh/parallelplot.py @@ -7,6 +7,7 @@ def plot_parallel(ax, diverging_mask, _posterior, var_names, figsize, show): + """Bokeh parallel plot.""" if ax is None: tools = rcParams["plot.bokeh.tools"] output_backend = rcParams["plot.bokeh.output_backend"] diff --git a/arviz/plots/backends/bokeh/posteriorplot.py b/arviz/plots/backends/bokeh/posteriorplot.py index c7df820c7e..109e1d4b29 100644 --- a/arviz/plots/backends/bokeh/posteriorplot.py +++ b/arviz/plots/backends/bokeh/posteriorplot.py @@ -39,6 +39,7 @@ def plot_posterior( kwargs, show, ): + """Bokeh posterior plot.""" if ax is None: _, ax = _create_axes_grid( length_plotters, rows, cols, figsize=figsize, squeeze=False, backend="bokeh" diff --git a/arviz/plots/backends/bokeh/ppcplot.py b/arviz/plots/backends/bokeh/ppcplot.py index 5bdb24889c..8e991510c6 100644 --- a/arviz/plots/backends/bokeh/ppcplot.py +++ b/arviz/plots/backends/bokeh/ppcplot.py @@ -31,6 +31,7 @@ def plot_ppc( show, num_pp_samples, ): + """Bokeh ppc plot.""" if ax is None: _, axes = _create_axes_grid(length_plotters, rows, cols, figsize=figsize, backend="bokeh") else: diff --git a/arviz/plots/backends/bokeh/rankplot.py b/arviz/plots/backends/bokeh/rankplot.py index 9fb5a5bc9a..214a9bacf5 100644 --- a/arviz/plots/backends/bokeh/rankplot.py +++ b/arviz/plots/backends/bokeh/rankplot.py @@ -28,7 +28,7 @@ def plot_rank( labels, show, ): - + """Bokeh rank plot.""" if axes is None: _, axes = _create_axes_grid( length_plotters, diff --git a/arviz/plots/backends/bokeh/traceplot.py b/arviz/plots/backends/bokeh/traceplot.py index 483ab5df18..abd0cbe83b 100644 --- a/arviz/plots/backends/bokeh/traceplot.py +++ b/arviz/plots/backends/bokeh/traceplot.py @@ -37,7 +37,7 @@ def plot_trace( colors, backend_kwargs: [Dict], ): - + """Bokeh traceplot.""" # If divergences are plotted they must be provided if divergences is not False: assert divergence_data is not None diff --git a/arviz/plots/backends/bokeh/violinplot.py b/arviz/plots/backends/bokeh/violinplot.py index 332206dada..ea8974d5cc 100644 --- a/arviz/plots/backends/bokeh/violinplot.py +++ b/arviz/plots/backends/bokeh/violinplot.py @@ -35,7 +35,7 @@ def plot_violin( squeeze=False, backend="bokeh", ) - + """Bokeh violin plot.""" ax = np.atleast_1d(ax) for (var_name, selection, x), ax_ in zip(plotters, ax.flatten()): diff --git a/arviz/plots/backends/matplotlib/autocorrplot.py b/arviz/plots/backends/matplotlib/autocorrplot.py index 1d77c37649..87d7a84fe3 100644 --- a/arviz/plots/backends/matplotlib/autocorrplot.py +++ b/arviz/plots/backends/matplotlib/autocorrplot.py @@ -8,6 +8,7 @@ def plot_autocorr( axes, plotters, max_lag, linewidth, titlesize, combined=False, xt_labelsize=None, ): + """Matplotlib autocorrplot.""" for (var_name, selection, x), ax_ in zip(plotters, axes.flatten()): x_prime = x if combined: diff --git a/arviz/plots/backends/matplotlib/compareplot.py b/arviz/plots/backends/matplotlib/compareplot.py index 8dd0f3f68f..89f407108f 100644 --- a/arviz/plots/backends/matplotlib/compareplot.py +++ b/arviz/plots/backends/matplotlib/compareplot.py @@ -18,6 +18,7 @@ def plot_compare( xt_labelsize, step, ): + """Matplotlib compare plot.""" if ax is None: _, ax = plt.subplots(figsize=figsize, constrained_layout=True) diff --git a/arviz/plots/backends/matplotlib/densityplot.py b/arviz/plots/backends/matplotlib/densityplot.py index 6d226ba39f..26690870a1 100644 --- a/arviz/plots/backends/matplotlib/densityplot.py +++ b/arviz/plots/backends/matplotlib/densityplot.py @@ -24,6 +24,7 @@ def plot_density( n_data, data_labels, ): + """Matplotlib densityplot.""" axis_map = {label: ax_ for label, ax_ in zip(all_labels, ax.flatten())} for m_idx, plotters in enumerate(to_plot): diff --git a/arviz/plots/backends/matplotlib/distplot.py b/arviz/plots/backends/matplotlib/distplot.py index bc7ea0bc6c..df6d98a2e1 100644 --- a/arviz/plots/backends/matplotlib/distplot.py +++ b/arviz/plots/backends/matplotlib/distplot.py @@ -29,6 +29,7 @@ def plot_dist( ax, backend_kwargs, ): + """Matplotlib distplot.""" if backend_kwargs is not None: warnings.warn( ( diff --git a/arviz/plots/backends/matplotlib/elpdplot.py b/arviz/plots/backends/matplotlib/elpdplot.py index 647cff6740..d4c2ccf89e 100644 --- a/arviz/plots/backends/matplotlib/elpdplot.py +++ b/arviz/plots/backends/matplotlib/elpdplot.py @@ -29,6 +29,7 @@ def plot_elpd( handles, color, ): + """Matplotlib elpd plot.""" if numvars == 2: (figsize, ax_labelsize, titlesize, xt_labelsize, _, markersize) = _scale_fig_size( diff --git a/arviz/plots/backends/matplotlib/energyplot.py b/arviz/plots/backends/matplotlib/energyplot.py index 504baf884f..21dc029fdb 100644 --- a/arviz/plots/backends/matplotlib/energyplot.py +++ b/arviz/plots/backends/matplotlib/energyplot.py @@ -19,6 +19,7 @@ def plot_energy( bw, legend, ): + """Matplotlib energy plot.""" if ax is None: _, ax = plt.subplots(figsize=figsize, constrained_layout=True) diff --git a/arviz/plots/backends/matplotlib/essplot.py b/arviz/plots/backends/matplotlib/essplot.py index 21fddccbec..bb4d1b971c 100644 --- a/arviz/plots/backends/matplotlib/essplot.py +++ b/arviz/plots/backends/matplotlib/essplot.py @@ -42,6 +42,7 @@ def plot_ess( rug_kwargs, hline_kwargs, ): + """Matplotlib ess plot.""" if ax is None: _, ax = _create_axes_grid( len(plotters), rows, cols, figsize=figsize, squeeze=False, constrained_layout=True diff --git a/arviz/plots/backends/matplotlib/forestplot.py b/arviz/plots/backends/matplotlib/forestplot.py index e8c780078a..b046b70282 100644 --- a/arviz/plots/backends/matplotlib/forestplot.py +++ b/arviz/plots/backends/matplotlib/forestplot.py @@ -44,6 +44,7 @@ def plot_forest( ess, r_hat, ): + """Matplotlib forest plot.""" plot_handler = PlotHandler( datasets, var_names=var_names, model_names=model_names, combined=combined, colors=colors ) diff --git a/arviz/plots/backends/matplotlib/hpdplot.py b/arviz/plots/backends/matplotlib/hpdplot.py index e0761a9ad8..5f68c59c91 100644 --- a/arviz/plots/backends/matplotlib/hpdplot.py +++ b/arviz/plots/backends/matplotlib/hpdplot.py @@ -3,6 +3,7 @@ def plot_hpd(ax, x_data, y_data, plot_kwargs, fill_kwargs): + """Matplotlib hpd plot.""" if ax is None: ax = gca() ax.plot(x_data, y_data, **plot_kwargs) diff --git a/arviz/plots/backends/matplotlib/jointplot.py b/arviz/plots/backends/matplotlib/jointplot.py index 089c669915..8cb03a4cf9 100644 --- a/arviz/plots/backends/matplotlib/jointplot.py +++ b/arviz/plots/backends/matplotlib/jointplot.py @@ -20,6 +20,7 @@ def plot_joint( gridsize, marginal_kwargs, ): + """Matplotlib joint plot.""" if ax is None: # Instantiate figure and grid fig, _ = plt.subplots(0, 0, figsize=figsize, constrained_layout=True) diff --git a/arviz/plots/backends/matplotlib/kdeplot.py b/arviz/plots/backends/matplotlib/kdeplot.py index d94627d247..959ad878d0 100644 --- a/arviz/plots/backends/matplotlib/kdeplot.py +++ b/arviz/plots/backends/matplotlib/kdeplot.py @@ -33,6 +33,7 @@ def plot_kde( ax=None, legend=True, ): + """Matplotlib kde plot.""" if ax is None: ax = plt.gca() diff --git a/arviz/plots/backends/matplotlib/khatplot.py b/arviz/plots/backends/matplotlib/khatplot.py index 93d9e4314b..b72efd5356 100644 --- a/arviz/plots/backends/matplotlib/khatplot.py +++ b/arviz/plots/backends/matplotlib/khatplot.py @@ -33,6 +33,7 @@ def plot_khat( n_data_points, bin_format, ): + """Matplotlib khat plot.""" if hover_label and mpl.get_backend() not in mpl.rcsetup.interactive_bk: hover_label = False warnings.warn( diff --git a/arviz/plots/backends/matplotlib/loopitplot.py b/arviz/plots/backends/matplotlib/loopitplot.py index 7a5f21dc9c..3775474ea4 100644 --- a/arviz/plots/backends/matplotlib/loopitplot.py +++ b/arviz/plots/backends/matplotlib/loopitplot.py @@ -30,7 +30,7 @@ def plot_loo_pit( credible_interval, plot_kwargs, ): - + """Matplotlib loo pit plot.""" if ax is None: _, ax = plt.subplots(1, 1, figsize=figsize, constrained_layout=True) diff --git a/arviz/plots/backends/matplotlib/mcseplot.py b/arviz/plots/backends/matplotlib/mcseplot.py index a04025bf98..05fe62cfa0 100644 --- a/arviz/plots/backends/matplotlib/mcseplot.py +++ b/arviz/plots/backends/matplotlib/mcseplot.py @@ -38,6 +38,7 @@ def plot_mcse( ax_labelsize, titlesize, ): + """Matplotlib mcseplot.""" if ax is None: _, ax = _create_axes_grid( length_plotters, rows, cols, figsize=figsize, squeeze=False, constrained_layout=True diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index 4980a33bf6..da2ba9cb2c 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -28,6 +28,7 @@ def plot_pair( divergences_kwargs, flat_var_names, ): + """Matplotlib pairplot.""" if numvars == 2: (figsize, ax_labelsize, _, xt_labelsize, _, _) = _scale_fig_size( figsize, textsize, numvars - 1, numvars - 1 diff --git a/arviz/plots/backends/matplotlib/parallelplot.py b/arviz/plots/backends/matplotlib/parallelplot.py index 0b74315170..f1c93c1b11 100644 --- a/arviz/plots/backends/matplotlib/parallelplot.py +++ b/arviz/plots/backends/matplotlib/parallelplot.py @@ -16,6 +16,7 @@ def plot_parallel( legend, figsize, ): + """Matplotlib parallel plot.""" if ax is None: _, ax = plt.subplots(figsize=figsize, constrained_layout=True) diff --git a/arviz/plots/backends/matplotlib/posteriorplot.py b/arviz/plots/backends/matplotlib/posteriorplot.py index 2c3bd2a382..7afd8af077 100644 --- a/arviz/plots/backends/matplotlib/posteriorplot.py +++ b/arviz/plots/backends/matplotlib/posteriorplot.py @@ -35,6 +35,7 @@ def plot_posterior( kwargs, titlesize, ): + """Matplotlib posterior plot.""" if ax is None: _, ax = _create_axes_grid( length_plotters, rows, cols, figsize=figsize, squeeze=False, constrained_layout=True diff --git a/arviz/plots/backends/matplotlib/ppcplot.py b/arviz/plots/backends/matplotlib/ppcplot.py index 12cd6d9341..d4924521ed 100644 --- a/arviz/plots/backends/matplotlib/ppcplot.py +++ b/arviz/plots/backends/matplotlib/ppcplot.py @@ -34,6 +34,7 @@ def plot_ppc( animation_kwargs, num_pp_samples, ): + """Matplotlib ppc plot.""" if ax is None: fig, axes = _create_axes_grid(length_plotters, rows, cols, figsize=figsize) else: diff --git a/arviz/plots/backends/matplotlib/rankplot.py b/arviz/plots/backends/matplotlib/rankplot.py index b132bfba36..0085ab4446 100644 --- a/arviz/plots/backends/matplotlib/rankplot.py +++ b/arviz/plots/backends/matplotlib/rankplot.py @@ -24,7 +24,7 @@ def plot_rank( ax_labelsize, titlesize, ): - + """Matplotlib rankplot..""" if axes is None: _, axes = _create_axes_grid(length_plotters, rows, cols, figsize=figsize, squeeze=False) diff --git a/arviz/plots/backends/matplotlib/traceplot.py b/arviz/plots/backends/matplotlib/traceplot.py index 3b285b3104..b24428691c 100644 --- a/arviz/plots/backends/matplotlib/traceplot.py +++ b/arviz/plots/backends/matplotlib/traceplot.py @@ -1,4 +1,4 @@ -"""Matplotlib Traceplot.""" +"""Matplotlib traceplot.""" import matplotlib.pyplot as plt from matplotlib.lines import Line2D @@ -15,7 +15,7 @@ def plot_trace( data, - var_names, # pylint: disable=unused-argument + var_names, # pylint: disable=unused-argument divergences, figsize, rug, diff --git a/arviz/plots/backends/matplotlib/violinplot.py b/arviz/plots/backends/matplotlib/violinplot.py index 53bfc0765f..ae06fe6071 100644 --- a/arviz/plots/backends/matplotlib/violinplot.py +++ b/arviz/plots/backends/matplotlib/violinplot.py @@ -23,6 +23,7 @@ def plot_violin( xt_labelsize, quartiles, ): + """Matplotlib violin plot.""" if ax is None: _, ax = _create_axes_grid( len(plotters), rows, cols, sharey=sharey, figsize=figsize, squeeze=False diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index ec089cbd4a..670716edae 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -626,7 +626,7 @@ def filter_plotters_list(plotters, plot_kind): def get_plotting_method(plot_name, plot_module, backend): - """Returns plotting function for correct backend""" + """Return plotting function for correct backend.""" _backend = { "mpl": "matplotlib", "bokeh": "bokeh", diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index 8841135288..249c80601b 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -114,8 +114,6 @@ def plot_trace( >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines) """ - - # TODO: This can be simplified somehow I feel like if divergences: try: divergence_data = convert_to_dataset(data, group="sample_stats").diverging diff --git a/arviz/tests/test_plots_bokeh.py b/arviz/tests/test_plots_bokeh.py index 9ab53845c5..f5e7ebbf01 100644 --- a/arviz/tests/test_plots_bokeh.py +++ b/arviz/tests/test_plots_bokeh.py @@ -140,15 +140,15 @@ def test_plot_trace(models, kwargs, backend_kwargs): assert axes.shape -def test_plot_trace_discrete(discrete_model): - axes = plot_trace(discrete_model, backend="bokeh") +def test_plot_trace_discrete(discrete_mode, backend_kwargs): + axes = plot_trace(discrete_model, backend="bokeh", backend_kwargs=backend_kwargs) assert axes.shape -def test_plot_trace_max_subplots_warning(models): +def test_plot_trace_max_subplots_warning(models, backend_kwargs): with pytest.warns(SyntaxWarning): with rc_context(rc={"plot.max_subplots": 1}): - axes = plot_trace(models.model_1, backend="bokeh") + axes = plot_trace(models.model_1, backend="bokeh", backend_kwargs=backend_kwargs) assert axes.shape From 905b55f5f071bbbcf7d39ac5f1f9b5d9b2c74a68 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Thu, 12 Dec 2019 21:09:51 -0800 Subject: [PATCH 31/41] Add more doc fixes --- arviz/plots/backends/bokeh/distplot.py | 1 - arviz/plots/backends/bokeh/hpdplot.py | 1 - arviz/plots/backends/bokeh/violinplot.py | 2 +- arviz/plots/backends/matplotlib/compareplot.py | 1 - arviz/plots/backends/matplotlib/elpdplot.py | 1 - 5 files changed, 1 insertion(+), 5 deletions(-) diff --git a/arviz/plots/backends/bokeh/distplot.py b/arviz/plots/backends/bokeh/distplot.py index 07c6db66b1..b186c10f13 100644 --- a/arviz/plots/backends/bokeh/distplot.py +++ b/arviz/plots/backends/bokeh/distplot.py @@ -33,7 +33,6 @@ def plot_dist( **kwargs # pylint: disable=unused-argument ): """Bokeh distplot.""" - if backend_kwargs is None: backend_kwargs = {} diff --git a/arviz/plots/backends/bokeh/hpdplot.py b/arviz/plots/backends/bokeh/hpdplot.py index d5fd5eb8b5..5a2b49667b 100644 --- a/arviz/plots/backends/bokeh/hpdplot.py +++ b/arviz/plots/backends/bokeh/hpdplot.py @@ -10,7 +10,6 @@ def plot_hpd(ax, x_data, y_data, plot_kwargs, fill_kwargs, show): """Bokeh hpd plot.""" - if ax is None: tools = rcParams["plot.bokeh.tools"] output_backend = rcParams["plot.bokeh.output_backend"] diff --git a/arviz/plots/backends/bokeh/violinplot.py b/arviz/plots/backends/bokeh/violinplot.py index ea8974d5cc..dd4735ebbe 100644 --- a/arviz/plots/backends/bokeh/violinplot.py +++ b/arviz/plots/backends/bokeh/violinplot.py @@ -25,6 +25,7 @@ def plot_violin( quartiles, show, ): + """Bokeh violin plot.""" if ax is None: _, ax = _create_axes_grid( len(plotters), @@ -35,7 +36,6 @@ def plot_violin( squeeze=False, backend="bokeh", ) - """Bokeh violin plot.""" ax = np.atleast_1d(ax) for (var_name, selection, x), ax_ in zip(plotters, ax.flatten()): diff --git a/arviz/plots/backends/matplotlib/compareplot.py b/arviz/plots/backends/matplotlib/compareplot.py index 89f407108f..d6f0ac54de 100644 --- a/arviz/plots/backends/matplotlib/compareplot.py +++ b/arviz/plots/backends/matplotlib/compareplot.py @@ -19,7 +19,6 @@ def plot_compare( step, ): """Matplotlib compare plot.""" - if ax is None: _, ax = plt.subplots(figsize=figsize, constrained_layout=True) diff --git a/arviz/plots/backends/matplotlib/elpdplot.py b/arviz/plots/backends/matplotlib/elpdplot.py index d4c2ccf89e..0f372dc14f 100644 --- a/arviz/plots/backends/matplotlib/elpdplot.py +++ b/arviz/plots/backends/matplotlib/elpdplot.py @@ -30,7 +30,6 @@ def plot_elpd( color, ): """Matplotlib elpd plot.""" - if numvars == 2: (figsize, ax_labelsize, titlesize, xt_labelsize, _, markersize) = _scale_fig_size( figsize, textsize, numvars - 1, numvars - 1 From c0c6fd786babc6dda2dd2d447ec4abd2d0f64df4 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Fri, 13 Dec 2019 06:14:12 -0800 Subject: [PATCH 32/41] Fix additional linting errors --- .pylintrc | 5 ++++- arviz/data/io_pyro.py | 1 + arviz/plots/traceplot.py | 2 +- arviz/tests/test_plots_bokeh.py | 2 +- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.pylintrc b/.pylintrc index 7c1aa5fe1b..cf3cb0fc2c 100644 --- a/.pylintrc +++ b/.pylintrc @@ -67,7 +67,10 @@ disable=missing-docstring, import-outside-toplevel, no-else-continue, unnecessary-comprehension, - unsubscriptable-object + unsubscriptable-object, + + #TODO: Remove this + fixme # Enable the message, report, category or checker with the given id(s). You can diff --git a/arviz/data/io_pyro.py b/arviz/data/io_pyro.py index aec94d0ff6..0763df0b2f 100644 --- a/arviz/data/io_pyro.py +++ b/arviz/data/io_pyro.py @@ -1,3 +1,4 @@ +# pylint: disable=cyclic-import """Pyro-specific conversion code.""" import logging import numpy as np diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index 249c80601b..32360df267 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt -from .plot_utils import get_plotting_method, get_coords, xarray_var_iter, _scale_fig_size +from .plot_utils import get_plotting_method, get_coords, xarray_var_iter from ..data import convert_to_dataset from ..utils import _var_names from ..rcparams import rcParams diff --git a/arviz/tests/test_plots_bokeh.py b/arviz/tests/test_plots_bokeh.py index f5e7ebbf01..90d2a5145d 100644 --- a/arviz/tests/test_plots_bokeh.py +++ b/arviz/tests/test_plots_bokeh.py @@ -140,7 +140,7 @@ def test_plot_trace(models, kwargs, backend_kwargs): assert axes.shape -def test_plot_trace_discrete(discrete_mode, backend_kwargs): +def test_plot_trace_discrete(discrete_model, backend_kwargs): axes = plot_trace(discrete_model, backend="bokeh", backend_kwargs=backend_kwargs) assert axes.shape From 075379edd1a938832f0b5b1a2155622c98c29ca2 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Fri, 13 Dec 2019 06:29:28 -0800 Subject: [PATCH 33/41] Fix additional linting error --- arviz/plots/distplot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/arviz/plots/distplot.py b/arviz/plots/distplot.py index f3abc46451..5e01f5613b 100644 --- a/arviz/plots/distplot.py +++ b/arviz/plots/distplot.py @@ -1,6 +1,5 @@ # pylint: disable=unexpected-keyword-arg """Plot distribution as histogram or kernel density estimates.""" -from .plot_utils import get_bins from .plot_utils import get_bins, get_plotting_method From 643fdf9f87919ad9d4b2bd3c5550911e6c19a5a2 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Fri, 13 Dec 2019 07:03:55 -0800 Subject: [PATCH 34/41] Try lint fix again --- arviz/data/io_numpyro.py | 1 + arviz/data/io_pyro.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/data/io_numpyro.py b/arviz/data/io_numpyro.py index 5b0926f69f..43c83571af 100644 --- a/arviz/data/io_numpyro.py +++ b/arviz/data/io_numpyro.py @@ -1,3 +1,4 @@ +# pylint: disable=cyclic-import """NumPyro-specific conversion code.""" import logging import numpy as np diff --git a/arviz/data/io_pyro.py b/arviz/data/io_pyro.py index 0763df0b2f..aec94d0ff6 100644 --- a/arviz/data/io_pyro.py +++ b/arviz/data/io_pyro.py @@ -1,4 +1,3 @@ -# pylint: disable=cyclic-import """Pyro-specific conversion code.""" import logging import numpy as np From e1672a4ab479bd64117286cfb332c8d8169b3e73 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Fri, 13 Dec 2019 07:17:53 -0800 Subject: [PATCH 35/41] Add additional lint ignore --- arviz/data/io_pyro.py | 1 + 1 file changed, 1 insertion(+) diff --git a/arviz/data/io_pyro.py b/arviz/data/io_pyro.py index aec94d0ff6..0763df0b2f 100644 --- a/arviz/data/io_pyro.py +++ b/arviz/data/io_pyro.py @@ -1,3 +1,4 @@ +# pylint: disable=cyclic-import """Pyro-specific conversion code.""" import logging import numpy as np From 607bcdc4ae074e7cf8800edf6762119ec01e8398 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Fri, 13 Dec 2019 07:32:15 -0800 Subject: [PATCH 36/41] Remove cyclic import entirely --- .pylintrc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index cf3cb0fc2c..f897f96f12 100644 --- a/.pylintrc +++ b/.pylintrc @@ -68,8 +68,9 @@ disable=missing-docstring, no-else-continue, unnecessary-comprehension, unsubscriptable-object, + cyclic-import - #TODO: Remove this + #TODO: Remove this once todos are done fixme From 86cf36998f6e393edf47b3a68705cfbb68ff373b Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Fri, 13 Dec 2019 07:55:44 -0800 Subject: [PATCH 37/41] Fix missing comma --- .pylintrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index f897f96f12..0dff1b7c6f 100644 --- a/.pylintrc +++ b/.pylintrc @@ -68,7 +68,7 @@ disable=missing-docstring, no-else-continue, unnecessary-comprehension, unsubscriptable-object, - cyclic-import + cyclic-import, #TODO: Remove this once todos are done fixme From 89717536b91daad7ea0a0241e52f43c54b0fd395 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Fri, 13 Dec 2019 16:29:12 -0800 Subject: [PATCH 38/41] Fix examples --- examples/bokeh/bokeh_plot_dist.py | 4 ++-- examples/bokeh/bokeh_plot_trace.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/bokeh/bokeh_plot_dist.py b/examples/bokeh/bokeh_plot_dist.py index 1aa03750f4..e794c9a7aa 100644 --- a/examples/bokeh/bokeh_plot_dist.py +++ b/examples/bokeh/bokeh_plot_dist.py @@ -16,8 +16,8 @@ ax_poisson = bkp.figure(**figure_kwargs) ax_normal = bkp.figure(**figure_kwargs) -az.plot_dist(a, color="black", label="Poisson", ax=ax_poisson, backend="bokeh", show=False) -az.plot_dist(b, color="red", label="Gaussian", ax=ax_normal, backend="bokeh", show=False) +az.plot_dist(a, color="black", label="Poisson", ax=ax_poisson, backend="bokeh", backend_kwargs={"show":False}) +az.plot_dist(b, color="red", label="Gaussian", ax=ax_normal, backend="bokeh", backend_kwargs={"show":False}) ax = row(ax_poisson, ax_normal) bkp.show(ax) diff --git a/examples/bokeh/bokeh_plot_trace.py b/examples/bokeh/bokeh_plot_trace.py index 103a71cfc2..fd1937ac25 100644 --- a/examples/bokeh/bokeh_plot_trace.py +++ b/examples/bokeh/bokeh_plot_trace.py @@ -7,4 +7,4 @@ import arviz as az data = az.load_arviz_data("non_centered_eight") -ax = az.plot_trace(data, var_names=("tau", "mu"), backend="bokeh", show=True) +ax = az.plot_trace(data, var_names=("tau", "mu"), backend="bokeh", backend_kwargs={"show":True}) From 70e6f86522ecac4eb9ae0fb679df9c91d2c8e079 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Fri, 13 Dec 2019 16:29:30 -0800 Subject: [PATCH 39/41] Update container script to include shell with bind mounts --- scripts/container.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/scripts/container.sh b/scripts/container.sh index 2eb3c1d15b..a55afa9e11 100755 --- a/scripts/container.sh +++ b/scripts/container.sh @@ -28,6 +28,11 @@ if [[ $* == *--clear-cache* ]]; then fi +if [[ $* == *--shell* ]]; then + echo "Starting Arviz Container Shell" + docker run -it --mount type=bind,source="$(pwd)",target=/opt/arviz --name arviz_shell --rm arviz:latest /bin/bash +fi + if [[ $* == *--sphinx-build* ]]; then echo "Build docs with sphinx" docker run --mount type=bind,source="$(pwd)",target=/opt/arviz --name arviz_sphinx --rm arviz:latest bash -c \ From ce044f44c21ff4a048318cd983c4941ed05f284d Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Fri, 13 Dec 2019 16:35:01 -0800 Subject: [PATCH 40/41] Remove unneeded comment --- arviz/plots/backends/matplotlib/traceplot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/arviz/plots/backends/matplotlib/traceplot.py b/arviz/plots/backends/matplotlib/traceplot.py index b24428691c..9282ed480f 100644 --- a/arviz/plots/backends/matplotlib/traceplot.py +++ b/arviz/plots/backends/matplotlib/traceplot.py @@ -120,7 +120,6 @@ def plot_trace( figsize, backend_kwargs["textsize"], rows=len(plotters), cols=2 ) - # TODO: This is breaking plotting for some reason trace_kwargs.setdefault("linewidth", linewidth) plot_kwargs.setdefault("linewidth", linewidth) From dea8432f35842a88df246500d3f2026b7643477d Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Fri, 13 Dec 2019 16:47:45 -0800 Subject: [PATCH 41/41] Add black to examples --- examples/bokeh/bokeh_plot_dist.py | 13 +++++++++++-- examples/bokeh/bokeh_plot_trace.py | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/examples/bokeh/bokeh_plot_dist.py b/examples/bokeh/bokeh_plot_dist.py index e794c9a7aa..e428def85d 100644 --- a/examples/bokeh/bokeh_plot_dist.py +++ b/examples/bokeh/bokeh_plot_dist.py @@ -16,8 +16,17 @@ ax_poisson = bkp.figure(**figure_kwargs) ax_normal = bkp.figure(**figure_kwargs) -az.plot_dist(a, color="black", label="Poisson", ax=ax_poisson, backend="bokeh", backend_kwargs={"show":False}) -az.plot_dist(b, color="red", label="Gaussian", ax=ax_normal, backend="bokeh", backend_kwargs={"show":False}) +az.plot_dist( + a, + color="black", + label="Poisson", + ax=ax_poisson, + backend="bokeh", + backend_kwargs={"show": False}, +) +az.plot_dist( + b, color="red", label="Gaussian", ax=ax_normal, backend="bokeh", backend_kwargs={"show": False} +) ax = row(ax_poisson, ax_normal) bkp.show(ax) diff --git a/examples/bokeh/bokeh_plot_trace.py b/examples/bokeh/bokeh_plot_trace.py index fd1937ac25..aef11cd073 100644 --- a/examples/bokeh/bokeh_plot_trace.py +++ b/examples/bokeh/bokeh_plot_trace.py @@ -7,4 +7,4 @@ import arviz as az data = az.load_arviz_data("non_centered_eight") -ax = az.plot_trace(data, var_names=("tau", "mu"), backend="bokeh", backend_kwargs={"show":True}) +ax = az.plot_trace(data, var_names=("tau", "mu"), backend="bokeh", backend_kwargs={"show": True})