Skip to content

Commit

Permalink
Hexplot Example and other fixes (#217)
Browse files Browse the repository at this point in the history
* Add more verbose exception and remove padding

Add hexplot gallery example

* Fix end of file spacing for examples

* Fix linting errors
  • Loading branch information
canyon289 authored Sep 8, 2018
1 parent 2e0bc44 commit 9534b19
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 11 deletions.
15 changes: 9 additions & 6 deletions arviz/plots/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,17 @@ def pairplot(data, var_names=None, coords=None, figsize=None, textsize=None, kin
gs : matplotlib gridspec
"""
if kind not in ['scatter', 'hexbin']:
raise ValueError('Plot type {} not recognized.'.format(kind))
valid_kinds = ['scatter', 'hexbin']
if kind not in valid_kinds:
raise ValueError(('Plot type {} not recognized.'
'Plot type must be in {}').format(kind, valid_kinds))

if coords is None:
coords = {}

if plot_kwargs is None:
plot_kwargs = {}

# Get posterior draws and combine chains
posterior_data = convert_to_dataset(data, group='posterior')
_var_names, _posterior = xarray_to_nparray(posterior_data.sel(**coords),
Expand Down Expand Up @@ -129,18 +132,18 @@ def pairplot(data, var_names=None, coords=None, figsize=None, textsize=None, kin
if i == j == 0 and colorbar:
hexbin = ax.hexbin(var1, var2, mincnt=1, gridsize=gridsize, **plot_kwargs)
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='7%', pad=0.1)
cax = divider.append_axes('right', size='7%')
cbar = plt.colorbar(hexbin,
ticks=[hexbin.norm.vmin, hexbin.norm.vmax],
cax=cax)
cbar.ax.set_yticklabels(['low', 'high'], fontsize=textsize)
divider.append_axes('top', size='7%', pad=0.1).set_axis_off()
divider.append_axes('top', size='7%').set_axis_off()

else:
ax.hexbin(var1, var2, mincnt=1, gridsize=gridsize, **plot_kwargs)
divider = make_axes_locatable(ax)
divider.append_axes('right', size='7%', pad=0.1).set_axis_off()
divider.append_axes('top', size='7%', pad=0.1).set_axis_off()
divider.append_axes('right', size='7%').set_axis_off()
divider.append_axes('top', size='7%').set_axis_off()

if divergences:
ax.scatter(var1[diverging_mask], var2[diverging_mask],
Expand Down
13 changes: 13 additions & 0 deletions examples/hexpairplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
Hexbin PairPlot
===============
_thumb: .2, .5
"""
import arviz as az

az.style.use('arviz-darkgrid')

centered = az.load_arviz_data('centered_eight')

az.pairplot(centered, var_names=['theta', "mu"], kind='hexbin', colorbar=True, divergences=True)
2 changes: 0 additions & 2 deletions examples/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,3 @@

coords = {'school': ['Choate', 'Deerfield']}
az.pairplot(centered, var_names=['theta', 'mu', 'tau'], coords=coords, divergences=True)


1 change: 0 additions & 1 deletion examples/parallelplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,3 @@
centered_eight_trace = pm.sample()

az.parallelplot(centered_eight_trace, var_names=['theta', 'tau', 'mu'])

2 changes: 1 addition & 1 deletion examples/ridgeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
colors='white',
r_hat=False,
n_eff=False)
axes[0].set_title('Estimated theta for eight schools model', fontsize=11)
axes[0].set_title('Estimated theta for eight schools model', fontsize=11)
1 change: 0 additions & 1 deletion examples/violinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,3 @@

non_centered = az.load_arviz_data('non_centered_eight')
az.violintraceplot(non_centered, var_names=["mu", "tau"], textsize=8)

0 comments on commit 9534b19

Please sign in to comment.