Skip to content

Commit

Permalink
Merge pull request #186 from ColCarroll/update-convert
Browse files (browse the repository at this point in the history)
Update object conversion api
  • Loading branch information
ColCarroll committed Sep 2, 2018
2 parents 9e3b634 + 23211c7 commit 46dd9ed
Show file tree
Hide file tree
Showing 13 changed files with 449 additions and 645 deletions.
4 changes: 1 addition & 3 deletions arviz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
__version__ = '0.1.0'
from matplotlib.pyplot import style

config = {'default_data_directory': '.arviz_data'}

from .inference_data import InferenceData
from .plots import *
from .stats import *
from .utils import trace_to_dataframe, load_data, convert_to_netcdf, load_arviz_data
from .utils import *
44 changes: 26 additions & 18 deletions arviz/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
class InferenceData():
"""Container for accessing netCDF files using xarray."""

def __init__(self, filename):
def __init__(self, *_, **kwargs):
"""Attach to a netcdf file.
This will inspect the netcdf for the available groups, so that they can be
Expand All @@ -16,25 +16,33 @@ def __init__(self, filename):
filename : str
netcdf4 file that contains groups for accessing with xarray.
"""
if filename == '': # netcdf freezes in this case
raise FileNotFoundError("No such file b''")
self._filename = filename
self._nc_dataset = nc.Dataset(filename, mode='r')
self._groups = self._nc_dataset.groups
self._groups = []
for key, dataset in kwargs.items():
if dataset is None:
continue
elif not isinstance(dataset, xr.Dataset):
raise ValueError('Arguments to InferenceData must be xarray Datasets '
'(argument "{}" was type "{}")'.format(key, type(dataset)))
setattr(self, key, dataset)
self._groups.append(key)

def __repr__(self):
return 'Inference data from "{filename}" with groups:\n\t> {options}'.format(
filename=self._filename,
return 'Inference data with groups:\n\t> {options}'.format(
options='\n\t> '.join(self._groups)
)

def __getattr__(self, name):
"""Lazy load xarray DataSets when they are requested"""
if name in self._groups:
setattr(self, name, xr.open_dataset(self._filename, group=name))
return getattr(self, name)
return self.__getattribute__(name)

def __dir__(self):
"""Enable tab-completion in iPython and Jupyter environments"""
return super(InferenceData, self).__dir__() + list(self._groups.keys())
@staticmethod
def from_netcdf(filename):
groups = {}
for group in nc.Dataset(filename, mode='r').groups:
groups[group] = xr.open_dataset(filename, group=group)
return InferenceData(**groups)

def to_netcdf(self, filename):
mode = 'w' # overwrite first, then append
for group in self._groups:
data = getattr(self, group)
data.to_netcdf(filename, mode=mode, group=group)
data.close()
mode = 'a'
return filename
4 changes: 2 additions & 2 deletions arviz/plots/autocorrplot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from .plot_utils import _scale_text, default_grid, make_label, xarray_var_iter, _create_axes_grid
from ..utils import convert_to_netcdf
from ..utils import convert_to_dataset
from ..stats.diagnostics import autocorr


Expand Down Expand Up @@ -32,7 +32,7 @@ def autocorrplot(data, var_names=None, max_lag=100, combined=False,
-------
axes : matplotlib axes
"""
data = convert_to_netcdf(data).posterior
data = convert_to_dataset(data, 'posterior')

plotters = list(xarray_var_iter(data, var_names, combined))
length_plotters = len(plotters)
Expand Down
7 changes: 3 additions & 4 deletions arviz/plots/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from .kdeplot import fast_kde
from ..stats import hpd
from ..utils import convert_to_netcdf
from ..utils import convert_to_dataset
from .plot_utils import _scale_text, make_label, xarray_var_iter


Expand Down Expand Up @@ -59,10 +59,9 @@ def densityplot(data, data_labels=None, var_names=None, credible_interval=0.94,
"""
if not isinstance(data, (list, tuple)):
datasets = [convert_to_netcdf(data)]
datasets = [convert_to_dataset(data, 'posterior')]
else:
datasets = [convert_to_netcdf(d) for d in data]
datasets = [data.posterior for data in datasets]
datasets = [convert_to_dataset(d, 'posterior') for d in data]

if point_estimate not in ('mean', 'median', None):
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions arviz/plots/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ..stats.diagnostics import _get_neff, _get_rhat
from ..stats import hpd
from .plot_utils import _scale_text, xarray_var_iter, make_label
from ..utils import convert_to_netcdf
from ..utils import convert_to_dataset
from .kdeplot import fast_kde


Expand Down Expand Up @@ -158,7 +158,7 @@ def __init__(self, data, var_names, model_names, combined, colors):
data = [data]

# y-values upside down
self.data = [convert_to_netcdf(datum).posterior for datum in reversed(data)]
self.data = [convert_to_dataset(datum, 'posterior') for datum in reversed(data)]

if model_names is None:
if len(self.data) > 1:
Expand Down
4 changes: 2 additions & 2 deletions arviz/plots/jointplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from matplotlib.ticker import NullFormatter

from .kdeplot import kdeplot
from ..utils import convert_to_netcdf
from ..utils import convert_to_dataset
from .plot_utils import _scale_text, get_bins, xarray_var_iter, make_label


Expand Down Expand Up @@ -41,7 +41,7 @@ def jointplot(data, var_names=None, coords=None, figsize=None, textsize=None, ki
ax_hist_y : matplotlib axes, y (right) distribution
"""

data = convert_to_netcdf(data).posterior
data = convert_to_dataset(data, 'posterior')
if coords is None:
coords = {}

Expand Down
4 changes: 2 additions & 2 deletions arviz/plots/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from scipy.stats import mode
from .kdeplot import kdeplot, fast_kde
from ..stats import hpd
from ..utils import convert_to_netcdf
from ..utils import convert_to_dataset
from .plot_utils import xarray_var_iter, _scale_text, make_label, default_grid, _create_axes_grid


Expand Down Expand Up @@ -128,7 +128,7 @@ def posteriorplot(data, var_names=None, coords=None, figsize=None, textsize=None
>>> az.posteriorplot(non_centered, var_names=('mu', 'theta_tilde',), credible_interval=.94)
"""
data = convert_to_netcdf(data).posterior
data = convert_to_dataset(data, 'posterior')

if coords is None:
coords = {}
Expand Down
4 changes: 2 additions & 2 deletions arviz/plots/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np

from .kdeplot import kdeplot
from ..utils import convert_to_netcdf
from ..utils import convert_to_dataset
from .plot_utils import _scale_text, get_bins, xarray_var_iter, make_label


Expand Down Expand Up @@ -38,7 +38,7 @@ def traceplot(data, var_names=None, coords=None, figsize=None, textsize=None, li
-------
axes : matplotlib axes
"""
data = convert_to_netcdf(data).posterior
data = convert_to_dataset(data, 'posterior')

if coords is None:
coords = {}
Expand Down
4 changes: 2 additions & 2 deletions arviz/plots/violintraceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .kdeplot import fast_kde
from .plot_utils import get_bins, _scale_text, xarray_var_iter, make_label
from ..stats import hpd
from ..utils import convert_to_netcdf
from ..utils import convert_to_dataset


def violintraceplot(data, var_names=None, quartiles=True, credible_interval=0.94, shade=0.35,
Expand Down Expand Up @@ -50,7 +50,7 @@ def violintraceplot(data, var_names=None, quartiles=True, credible_interval=0.94
"""

data = convert_to_netcdf(data).posterior
data = convert_to_dataset(data, 'posterior')
plotters = list(xarray_var_iter(data, var_names=var_names, combined=True))

if kwargs_shade is None:
Expand Down
12 changes: 3 additions & 9 deletions arviz/tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,22 @@
import os
import pickle
import shutil
import sys
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
import pystan

import arviz


class BaseArvizTest():

@classmethod
def setup_class(cls):
cls.default_data_directory = arviz.config['default_data_directory']
cls.tempdir = tempfile.mkdtemp()
arviz.config['default_data_directory'] = cls.tempdir
pass

@classmethod
def teardown_class(cls):
arviz.config['default_data_directory'] = cls.default_data_directory
shutil.rmtree(cls.tempdir)
pass

def setup_method(self):
np.random.seed(1)
Expand Down

0 comments on commit 46dd9ed

Please sign in to comment.