Skip to content

Commit

Permalink
Merge pull request #173 from ColCarroll/move_to_netcdf
Browse files Browse the repository at this point in the history
Add an `InferenceData` object
  • Loading branch information
ColCarroll committed Aug 26, 2018
2 parents c301673 + 8c60516 commit affe364
Show file tree
Hide file tree
Showing 25 changed files with 252 additions and 144 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ ignore-on-opaque-inference=yes
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
ignored-classes=optparse.Values,thread._local,_thread._local,netCDF4

# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
Expand Down
7 changes: 5 additions & 2 deletions arviz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# pylint: disable=wildcard-import
# pylint: disable=wildcard-import,invalid-name,wrong-import-position
__version__ = '0.1.0'
from matplotlib.pyplot import style

config = {'default_data_directory': '.arviz_data'}

from .inference_data import InferenceData
from .plots import *
from .stats import *
from .utils import trace_to_dataframe, save_data, load_data, convert_to_xarray, load_arviz_data
from .utils import trace_to_dataframe, load_data, convert_to_netcdf, load_arviz_data
40 changes: 40 additions & 0 deletions arviz/inference_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import netCDF4 as nc
import xarray as xr


class InferenceData():
    """Container for accessing netCDF files using xarray.

    The group names of the file are read when the object is created. A group
    is opened as an xarray dataset the first time it is requested as an
    attribute, and the result is stored on the instance for later accesses.
    """

    def __init__(self, filename):
        """Attach to a netcdf file.

        This will inspect the netcdf for the available groups, so that they can be
        later loaded into memory.

        Parameters
        ----------
        filename : str
            netcdf4 file that contains groups for accessing with xarray.

        Raises
        ------
        FileNotFoundError
            If `filename` is the empty string.
        """
        if filename == '':  # netcdf freezes in this case
            raise FileNotFoundError("No such file b''")
        self._filename = filename
        self._nc_dataset = nc.Dataset(filename, mode='r')
        self._groups = self._nc_dataset.groups

    def __repr__(self):
        """Return the file name followed by one line per available group."""
        return 'Inference data from "{filename}" with groups:\n\t> {options}'.format(
            filename=self._filename,
            options='\n\t> '.join(self._groups)
        )

    def __getattr__(self, name):
        """Lazy load xarray DataSets when they are requested.

        Python only calls this hook after normal attribute lookup has failed.

        Raises
        ------
        AttributeError
            If `name` is neither a regular attribute nor a group of the file.
        """
        # Read `_groups` straight from the instance dict: going through
        # `self._groups` would re-enter this hook and recurse without end
        # whenever `_groups` has not been assigned yet (instance built with
        # `__new__`, copy/pickle reconstruction, `__init__` that raised early).
        groups = self.__dict__.get('_groups', ())
        if name in groups:
            dataset = xr.open_dataset(self._filename, group=name)
            # Store the dataset so later accesses bypass `__getattr__`.
            setattr(self, name, dataset)
            return dataset
        # Produces the standard AttributeError for unknown names.
        return self.__getattribute__(name)

    def __dir__(self):
        """Enable tab-completion in IPython and Jupyter environments."""
        return super(InferenceData, self).__dir__() + list(self._groups.keys())
10 changes: 5 additions & 5 deletions arviz/plots/autocorrplot.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import numpy as np

from .plot_utils import _scale_text, default_grid, make_label, xarray_var_iter, _create_axes_grid
from ..utils import convert_to_xarray
from ..utils import convert_to_netcdf
from ..stats.diagnostics import autocorr


def autocorrplot(posterior, var_names=None, max_lag=100, combined=False,
def autocorrplot(data, var_names=None, max_lag=100, combined=False,
figsize=None, textsize=None):
"""
Bar plot of the autocorrelation function for a posterior.
Parameters
----------
posterior : xarray, or object that can be converted (pystan or pymc3 draws)
Posterior samples
data : inference_data, or object that can be converted (pystan or pymc3 draws)
Must contain posterior data
var_names : list of variable names, optional
Variables to be plotted, if None all variables are plotted.
Vector-value stochastics are handled automatically.
Expand All @@ -32,7 +32,7 @@ def autocorrplot(posterior, var_names=None, max_lag=100, combined=False,
-------
axes : matplotlib axes
"""
data = convert_to_xarray(posterior)
data = convert_to_netcdf(data).posterior

plotters = list(xarray_var_iter(data, var_names, combined))
length_plotters = len(plotters)
Expand Down
12 changes: 5 additions & 7 deletions arviz/plots/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

from .kdeplot import fast_kde
from ..stats import hpd
from ..utils import convert_to_xarray
from ..utils import convert_to_netcdf
from .plot_utils import _scale_text, make_label, xarray_var_iter


def densityplot(data, data_labels=None, var_names=None, credible_interval=0.94,
point_estimate='mean', colors='cycle', outline=True, hpd_markers='', shade=0.,
bw=4.5, figsize=None, textsize=None, skip_first=0):
bw=4.5, figsize=None, textsize=None):
"""
Generates KDE plots for continuous variables and histograms for discrete ones.
Plots are truncated at their 100*(1-alpha)% credible intervals. Plots are grouped per variable
Expand Down Expand Up @@ -51,8 +51,6 @@ def densityplot(data, data_labels=None, var_names=None, credible_interval=0.94,
Figure size. If None, size is (6, number of variables * 2)
textsize: int
Text size for labels and legend. If None it will be autoscaled based on figsize.
skip_first : int
Number of first samples not shown in plots (burn-in).
Returns
-------
Expand All @@ -61,10 +59,10 @@ def densityplot(data, data_labels=None, var_names=None, credible_interval=0.94,
"""
if not isinstance(data, (list, tuple)):
datasets = [convert_to_xarray(data)]
datasets = [convert_to_netcdf(data)]
else:
datasets = [convert_to_xarray(d) for d in data]
datasets = [data.where(data.draw >= skip_first).dropna('draw') for data in datasets]
datasets = [convert_to_netcdf(d) for d in data]
datasets = [data.posterior for data in datasets]

if point_estimate not in ('mean', 'median', None):
raise ValueError(
Expand Down
5 changes: 3 additions & 2 deletions arviz/plots/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ..stats.diagnostics import _get_neff, _get_rhat
from ..stats import hpd
from .plot_utils import _scale_text, xarray_var_iter, make_label
from ..utils import convert_to_xarray
from ..utils import convert_to_netcdf
from .kdeplot import fast_kde


Expand Down Expand Up @@ -157,7 +157,8 @@ def __init__(self, data, var_names, model_names, combined, colors):
if not isinstance(data, (list, tuple)):
data = [data]

self.data = [convert_to_xarray(datum) for datum in reversed(data)] # y-values upside down
# y-values upside down
self.data = [convert_to_netcdf(datum).posterior for datum in reversed(data)]

if model_names is None:
if len(self.data) > 1:
Expand Down
4 changes: 2 additions & 2 deletions arviz/plots/jointplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from matplotlib.ticker import NullFormatter

from .kdeplot import kdeplot
from ..utils import convert_to_xarray
from ..utils import convert_to_netcdf
from .plot_utils import _scale_text, get_bins, xarray_var_iter, make_label


Expand Down Expand Up @@ -41,7 +41,7 @@ def jointplot(data, var_names=None, coords=None, figsize=None, textsize=None, ki
ax_hist_y : matplotlib axes, y (right) distribution
"""

data = convert_to_xarray(data)
data = convert_to_netcdf(data).posterior
if coords is None:
coords = {}

Expand Down
6 changes: 3 additions & 3 deletions arviz/plots/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from scipy.stats import mode
from .kdeplot import kdeplot, fast_kde
from ..stats import hpd
from ..utils import convert_to_xarray
from ..utils import convert_to_netcdf
from .plot_utils import xarray_var_iter, _scale_text, make_label, default_grid, _create_axes_grid


Expand Down Expand Up @@ -89,7 +89,7 @@ def posteriorplot(data, var_names=None, coords=None, figsize=None, textsize=None
.. plot::
:context: close-figs
>>> az.posteriorplot(non_centered, var_names=('mu', 'theta_tilde',), rope=(-1, 1))
>>> az.posteriorplot(non_centered, var_names=("mu", 'theta_tilde',), rope=(-1, 1))
Plot Region of Practical Equivalence for selected distributions
Expand Down Expand Up @@ -128,7 +128,7 @@ def posteriorplot(data, var_names=None, coords=None, figsize=None, textsize=None
>>> az.posteriorplot(non_centered, var_names=('mu', 'theta_tilde',), credible_interval=.94)
"""
data = convert_to_xarray(data)
data = convert_to_netcdf(data).posterior

if coords is None:
coords = {}
Expand Down
4 changes: 2 additions & 2 deletions arviz/plots/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np

from .kdeplot import kdeplot
from ..utils import convert_to_xarray
from ..utils import convert_to_netcdf
from .plot_utils import _scale_text, get_bins, xarray_var_iter, make_label


Expand Down Expand Up @@ -38,7 +38,7 @@ def traceplot(data, var_names=None, coords=None, figsize=None, textsize=None, li
-------
axes : matplotlib axes
"""
data = convert_to_xarray(data)
data = convert_to_netcdf(data).posterior

if coords is None:
coords = {}
Expand Down
8 changes: 4 additions & 4 deletions arviz/plots/violintraceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .kdeplot import fast_kde
from .plot_utils import get_bins, _scale_text, xarray_var_iter, make_label
from ..stats import hpd
from ..utils import convert_to_xarray
from ..utils import convert_to_netcdf


def violintraceplot(data, var_names=None, quartiles=True, credible_interval=0.94, shade=0.35,
Expand All @@ -18,7 +18,7 @@ def violintraceplot(data, var_names=None, quartiles=True, credible_interval=0.94
Parameters
----------
data : xarray, or object that can be converted (pystan or pymc3 draws)
data : InferenceData, or object that can be converted (pystan or pymc3 draws)
Posterior samples
var_names: list, optional
List of variables to plot (defaults to None, which results in all variables plotted)
Expand Down Expand Up @@ -50,7 +50,7 @@ def violintraceplot(data, var_names=None, quartiles=True, credible_interval=0.94
"""

data = convert_to_xarray(data)
data = convert_to_netcdf(data).posterior
plotters = list(xarray_var_iter(data, var_names=var_names, combined=True))

if kwargs_shade is None:
Expand Down Expand Up @@ -94,7 +94,7 @@ def violintraceplot(data, var_names=None, quartiles=True, credible_interval=0.94

def _violinplot(val, shade, bw, ax, **kwargs_shade):
"""
Auxiliar function to plot violinplots
Auxiliary function to plot violinplots
"""
density, low_b, up_b = fast_kde(val, bw=bw)
x = np.linspace(low_b, up_b, len(density))
Expand Down
24 changes: 24 additions & 0 deletions arviz/tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,35 @@
import os
import pickle
import shutil
import sys
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
import pystan

import arviz


class BaseArvizTest():
    """Base class for arviz test classes.

    Points arviz's default data directory at a temporary directory for the
    lifetime of the test class, seeds numpy's global random state before each
    test and closes all matplotlib figures after each test.
    """

    @classmethod
    def setup_class(cls):
        """Remember the configured data directory and swap in a temporary one."""
        settings = arviz.config
        cls.default_data_directory = settings['default_data_directory']
        cls.tempdir = tempfile.mkdtemp()
        settings['default_data_directory'] = cls.tempdir

    @classmethod
    def teardown_class(cls):
        """Restore the configured data directory and delete the temporary one."""
        settings = arviz.config
        settings['default_data_directory'] = cls.default_data_directory
        shutil.rmtree(cls.tempdir)

    def setup_method(self):
        """Seed numpy's global random state with 1."""
        np.random.seed(1)

    def teardown_method(self):
        """Close every open matplotlib figure."""
        plt.close('all')


def eight_schools_params():
"""Share setup for eight schools"""
Expand Down
6 changes: 4 additions & 2 deletions arviz/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
import pymc3 as pm
import pytest

from .helpers import eight_schools_params, load_cached_models
from .helpers import eight_schools_params, load_cached_models, BaseArvizTest
from ..plots import (densityplot, traceplot, energyplot, posteriorplot, autocorrplot, forestplot,
parallelplot, pairplot, jointplot, ppcplot, violintraceplot)


class SetupPlots():
class SetupPlots(BaseArvizTest):

@classmethod
def setup_class(cls):
super().setup_class()
cls.data = eight_schools_params()
models = load_cached_models(draws=500, chains=2)
model, cls.short_trace = models['pymc3']
Expand All @@ -22,6 +23,7 @@ def setup_class(cls):
cls.df_trace = DataFrame({'a': np.random.poisson(2.3, 100)})

def teardown_method(self):
super().teardown_method()
plt.close('all')


Expand Down

0 comments on commit affe364

Please sign in to comment.