Skip to content

Commit

Permalink
Merge pull request #90 from aloctavodia/autocorrplot
Browse files Browse the repository at this point in the history
autocorrplot: share x-y labels
  • Loading branch information
ColCarroll committed May 21, 2018
2 parents 461d07f + 8a99fe2 commit d5ff245
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 34 deletions.
22 changes: 11 additions & 11 deletions arviz/plots/autocorrplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import matplotlib.pyplot as plt

from .plot_utils import get_axis, _scale_text
from .plot_utils import _scale_text
from ..utils import get_varnames, trace_to_dataframe


Expand Down Expand Up @@ -41,13 +41,13 @@ def autocorrplot(trace, varnames=None, max_lag=100, symmetric_plot=False, combin
varnames = get_varnames(trace, varnames)

if figsize is None:
figsize = (6, len(varnames) * 2)
figsize = (12, len(varnames) * 2)

textsize, linewidth, _ = _scale_text(figsize, textsize=textsize)
textsize, linewidth, _ = _scale_text(figsize, textsize, 1)

nchains = trace.columns.value_counts()[0]
ax = get_axis(ax, len(varnames), nchains, squeeze=False, sharex=True, sharey=True,
figsize=figsize)
fig, ax = plt.subplots(len(varnames), nchains, squeeze=False, sharex=True, sharey=True,
figsize=figsize)

max_lag = min(len(trace) - 1, max_lag)

Expand All @@ -59,12 +59,6 @@ def autocorrplot(trace, varnames=None, max_lag=100, symmetric_plot=False, combin
data = trace[varname].values[:, j]
ax[i, j].acorr(data, detrend=plt.mlab.detrend_mean, maxlags=max_lag, lw=linewidth)

if j == 0:
ax[i, j].set_ylabel("correlation", fontsize=textsize)

if i == len(varnames) - 1:
ax[i, j].set_xlabel("lag", fontsize=textsize)

if not symmetric_plot:
ax[i, j].set_xlim(0, max_lag)

Expand All @@ -73,4 +67,10 @@ def autocorrplot(trace, varnames=None, max_lag=100, symmetric_plot=False, combin
else:
ax[i, j].set_title(varname, fontsize=textsize)
ax[i, j].tick_params(labelsize=textsize)

fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor='none', top='off', bottom='off', left='off', right='off')
plt.grid(False)
plt.xlabel("Lag", fontsize=textsize)
plt.ylabel("Correlation", fontsize=textsize)
return ax
21 changes: 0 additions & 21 deletions arviz/plots/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,6 @@
import matplotlib.pyplot as plt


def get_axis(ax, default_rows, default_columns, **default_kwargs):
"""Verifies the provided axis is of the correct shape, and creates one if needed.
Args:
ax: matplotlib axis or None
default_rows: int, expected rows in axis
default_columns: int, expected columns in axis
**default_kwargs: keyword arguments to pass to plt.subplot
Returns:
axis, or raises an error
"""

default_shape = (default_rows, default_columns)
if ax is None:
_, ax = plt.subplots(*default_shape, **default_kwargs)
elif ax.shape != default_shape:
raise ValueError('Subplots with shape %r required' % (default_shape,))
return ax


def make_2d(ary):
"""Convert any array into a 2d numpy array.
Expand Down
4 changes: 2 additions & 2 deletions arviz/plots/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import matplotlib.pyplot as plt

from .kdeplot import fast_kde
from .plot_utils import get_axis, make_2d, get_bins, _scale_text
from .plot_utils import make_2d, get_bins, _scale_text
from ..utils import get_varnames, trace_to_dataframe


Expand Down Expand Up @@ -67,7 +67,7 @@ def traceplot(trace, varnames=None, figsize=None, textsize=None, lines=None, com

textsize, linewidth, _ = _scale_text(figsize, textsize=textsize, scale_ratio=1)

ax = get_axis(ax, len(varnames), 2, squeeze=False, figsize=figsize)
_, ax = plt.subplots(len(varnames), 2, squeeze=False, figsize=figsize)

for i, varname in enumerate(varnames):
if priors is not None:
Expand Down

0 comments on commit d5ff245

Please sign in to comment.