Skip to content

Commit

Permalink
Merge pull request #88 from aloctavodia/figsize
Browse files Browse the repository at this point in the history
Change default figsize values
  • Loading branch information
ColCarroll committed May 20, 2018
2 parents 4930832 + c706e0c commit b7b8a39
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 38 deletions.
5 changes: 3 additions & 2 deletions arviz/plots/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ def pairplot(trace, varnames=None, figsize=None, textsize=None, kind='scatter',
numvars = len(varnames)

if figsize is None:
figsize = (8 + numvars, 8 + numvars)
figsize = (2 * numvars, 2 * numvars)

if textsize is None:
textsize, _, markersize = _scale_text(figsize, textsize=textsize, scale_ratio=1.5)
scale_ratio = (6 / numvars) ** 0.75
textsize, _, markersize = _scale_text(figsize, textsize=textsize, scale_ratio=scale_ratio)

if numvars < 2:
raise Exception('Number of variables to be plotted must be 2 or greater.')
Expand Down
33 changes: 18 additions & 15 deletions arviz/plots/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,16 @@ def get_bins(ary, max_bins=50, fenceposts=2):
return bins


def _create_axes_grid(figsize, trace):
def _create_axes_grid(trace, figsize, ax):
"""
Parameters
----------
figsize : tuple
Figure size.
trace : dict or DataFrame
dictionary with ppc samples of DataFrame with posterior samples
figsize : tuple
figure size
ax : matplotlib axes
Returns
-------
fig : matplotlib figure
Expand All @@ -97,15 +99,16 @@ def _create_axes_grid(figsize, trace):
l_trace = len(trace)
else:
l_trace = trace.shape[1]
if l_trace == 1:
fig, ax = plt.subplots(figsize=figsize)
else:
n_rows = np.ceil(l_trace / 2.0).astype(int)
if figsize is None:
figsize = (12, n_rows * 2.5)
fig, ax = plt.subplots(n_rows, 2, figsize=figsize)
ax = ax.reshape(2 * n_rows)
if l_trace % 2 == 1:
ax[-1].set_axis_off()
ax = ax[:-1]
return fig, ax
if figsize is None:
figsize = (8, 2 + l_trace + (l_trace % 2))
if ax is None:
if l_trace == 1:
_, ax = plt.subplots(figsize=figsize)
else:
n_rows = np.ceil(l_trace / 2.0).astype(int)
_, ax = plt.subplots(n_rows, 2, figsize=figsize)
ax = ax.reshape(2 * n_rows)
if l_trace % 2 == 1:
ax[-1].set_axis_off()
ax = ax[:-1]
return ax, figsize
28 changes: 13 additions & 15 deletions arviz/plots/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def posteriorplot(trace, varnames=None, figsize=None, textsize=None, alpha=0.05,
Controls formatting for floating point numbers
point_estimate: str
Must be in ('mode', 'mean', 'median')
rope: list or numpy array
Lower and upper values of the Region Of Practical Equivalence
rope: tuple of list of tuples
Lower and upper values of the Region Of Practical Equivalence. If a list is provided, its
length should match the number of variables.
ref_val: float or list-like
display the percentage below and above the values in ref_val. If a list is provided, its
length should match the number of variables.
Expand Down Expand Up @@ -66,11 +67,7 @@ def posteriorplot(trace, varnames=None, figsize=None, textsize=None, alpha=0.05,
varnames = expand_variable_names(trace, varnames)
trace = trace[varnames]

if figsize is None:
figsize = (8, 8)

if ax is None:
_, ax = _create_axes_grid(figsize, trace)
ax, figsize = _create_axes_grid(trace, figsize, ax)

textsize, linewidth, _ = _scale_text(figsize, textsize, 1.5)

Expand Down Expand Up @@ -109,16 +106,16 @@ def display_ref_val(ref_val):
ref_in_posterior = "{} <{:g}< {}".format(format_as_percent(less_than_ref_probability, 1),
ref_val,
format_as_percent(greater_than_ref_probability, 1))
ax.axvline(ref_val, ymin=0.02, ymax=.75, color='C1', lw=linewidth, alpha=0.65)
ax.axvline(ref_val, ymin=0.05, ymax=.75, color='C1', lw=linewidth, alpha=0.65)
ax.text(trace_values.mean(), plot_height * 0.6, ref_in_posterior, size=textsize,
color='C1', horizontalalignment='center')
color='C1', weight='semibold', horizontalalignment='center')

def display_rope(rope):
ax.plot(rope, (plot_height * 0.02, plot_height * 0.02), lw=linewidth*5, color='C2',
alpha=0.75)
solid_capstyle='round')
text_props = dict(size=textsize, horizontalalignment='center', color='C2')
ax.text(rope[0], plot_height * 0.14, rope[0], **text_props)
ax.text(rope[1], plot_height * 0.14, rope[1], **text_props)
ax.text(rope[0], plot_height * 0.2, rope[0], weight='semibold', **text_props)
ax.text(rope[1], plot_height * 0.2, rope[1], weight='semibold', **text_props)

def display_point_estimate():
if not point_estimate:
Expand All @@ -144,14 +141,15 @@ def display_point_estimate():

def display_hpd():
hpd_intervals = hpd(trace_values, alpha=alpha)
ax.plot(hpd_intervals, (plot_height * 0.02, plot_height * 0.02), lw=linewidth, color='k')
ax.plot(hpd_intervals, (plot_height * 0.02, plot_height * 0.02), lw=linewidth*2, color='k',
solid_capstyle='round')
ax.text(hpd_intervals[0], plot_height * 0.07,
hpd_intervals[0].round(round_to),
size=textsize, horizontalalignment='center')
ax.text(hpd_intervals[1], plot_height * 0.07,
hpd_intervals[1].round(round_to),
size=textsize, horizontalalignment='center')
ax.text((hpd_intervals[0] + hpd_intervals[1]) / 2, plot_height * 0.2,
ax.text((hpd_intervals[0] + hpd_intervals[1]) / 2, plot_height * 0.3,
format_as_percent(1 - alpha) + ' HPD',
size=textsize, horizontalalignment='center')

Expand All @@ -168,7 +166,7 @@ def format_axes():

if kind == 'kde' and isinstance(trace_values.iloc[0], float):
kdeplot(trace_values,
fill_alpha=kwargs.pop('alpha', 1),
fill_alpha=kwargs.pop('fill_alpha', 0.35),
bw=bw,
ax=ax,
lw=linewidth,
Expand Down
8 changes: 2 additions & 6 deletions arviz/plots/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,9 @@ def ppcplot(data, ppc_sample, kind='kde', mean=True, figsize=None, textsize=None
-------
ax : matplotlib axes
"""
if figsize is None:
figsize = (6, 5)
ax, figsize = _create_axes_grid(ppc_sample, figsize, ax)

if ax is None:
_, ax = _create_axes_grid(figsize, ppc_sample)

textsize, linewidth, _ = _scale_text(figsize, textsize)
textsize, linewidth, _ = _scale_text(figsize, textsize, 1.5)

for ax_, (var, ppss) in zip(np.atleast_1d(ax), ppc_sample.items()):
if kind == 'kde':
Expand Down

0 comments on commit b7b8a39

Please sign in to comment.