In [34]:
from arviz.data.base import generate_dims_coords
from arviz.data import load_arviz_data
from arviz.data.converters import convert_to_inference_data
from line_profiler import LineProfiler
import numpy as np
from numpy import array, average, dot
import numba
from copy import deepcopy
import datetime
import warnings
from arviz.plots.kdeplot import _fast_kde_2d as f2
import numpy as np
import pkg_resources
import xarray as xr
import timeit
from scipy.signal import gaussian, convolve, convolve2d  # pylint: disable=no-name-in-module
from scipy.sparse import coo_matrix
from collections import OrderedDict
from collections.abc import Sequence
from copy import copy as ccopy, deepcopy
from datetime import datetime
import netCDF4 as nc
import numpy as np
import xarray as xr


In [2]:
# Line-profile the arviz implementation of generate_dims_coords (imported from
# arviz.data.base at the top). A later cell redefines generate_dims_coords in this
# namespace, so re-running this cell after that one profiles the notebook copy instead.
lp = LineProfiler()
wrapper = lp(generate_dims_coords)
wrapper((500,600,80), 'x')
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.000579 s
File: /home/banzee/Desktop/arviz/arviz/data/base.py
Function: generate_dims_coords at line 30

Line #      Hits         Time  Per Hit   % Time  Line Contents
    30                                           def generate_dims_coords(shape, var_name, dims=None, coords=None, default_dims=None):
    31                                               """Generate default dimensions and coordinates for a variable.
    32                                           
    33                                               Parameters
    34                                               ----------
    35                                               shape : tuple[int]
    36                                                   Shape of the variable
    37                                               var_name : str
    38                                                   Name of the variable. Used in the default name, if necessary
    39                          

In [3]:
# NOTE(review): the names look swapped relative to what the functions do: `range_` is the
# numba-compiled version and `range_jit` is plain NumPy. Later cells call both names
# (`range_` inside the *_jit helpers, `%timeit range_jit(...)`), so they are kept as-is.
@numba.njit
def range_(x):
    """Return ``np.arange(x)``, compiled with numba in nopython mode."""
    return np.arange(x)


def range_jit(x):
    """Return ``np.arange(x)`` using plain NumPy (no numba, despite the name)."""
    return np.arange(x)

In [4]:
# NOTE(review): this appears to be the first call of the numba-compiled `range_`, so the
# measurement likely includes JIT compilation (note the 16x slowest/fastest spread and
# "1 loop each" in the output); call `range_(100)` once before timing to compare fairly.
%timeit range_(100)

The slowest run took 16.34 times longer than the fastest. This could mean that an intermediate result is being cached.
3.92 µs ± 5.66 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
%timeit range_jit(100)

1.46 µs ± 96.1 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [6]:
def generate_dims_coords(shape, var_name, dims=None, coords=None, default_dims=None):
    """Generate default dimension names and coordinates for a variable.

    Parameters
    ----------
    shape : sequence of int
        Shape of the variable (excluding any ``default_dims`` such as chain/draw).
    var_name : str
        Name of the variable, used to build default names ``{var_name}_dim_{idx}``.
    dims : sequence of str or None, optional
        Explicit dimension names. Missing trailing entries and ``None`` entries are
        replaced by default names. Lists and tuples are both accepted; the caller's
        object is never modified.
    coords : dict, optional
        Explicit coordinates keyed by dimension name. Dimensions without one get
        ``np.arange(dim_len)``. The caller's dict is never modified.
    default_dims : sequence of str, optional
        Dimension names that are not counted against ``shape`` when checking
        whether too many dims were given.

    Returns
    -------
    dims : list of str
        One name per axis of ``shape`` (plus any extra names that were passed in).
    coords : dict
        Only the coordinates whose key appears in ``dims``.
    """
    if default_dims is None:
        default_dims = []

    # Private, mutable copies: ``list`` also makes tuple input usable by the
    # append / item-assignment fill-in loop below.
    dims = [] if dims is None else list(dims)
    coords = {} if coords is None else deepcopy(coords)

    # Count the same dims in the message as in the check, so the reported number
    # matches the comparison that triggered the warning.
    extra_dims = [dim for dim in dims if dim not in default_dims]
    if len(extra_dims) > len(shape):
        warnings.warn(
            (
                "In variable {var_name}, there are "
                + "more dims ({dims_len}) given than exist ({shape_len}). "
                + "Passed array should have shape (chains, draws, *shape)"
            ).format(var_name=var_name, dims_len=len(extra_dims), shape_len=len(shape)),
            SyntaxWarning,
        )

    for idx, dim_len in enumerate(shape):
        default_name = "{var_name}_dim_{idx}".format(var_name=var_name, idx=idx)
        if idx >= len(dims):
            dims.append(default_name)
        elif dims[idx] is None:
            dims[idx] = default_name
        dim_name = dims[idx]
        if dim_name not in coords:
            coords[dim_name] = np.arange(dim_len)

    # Drop coordinates for dimensions this variable does not use.
    coords = {key: coord for key, coord in coords.items() if key in dims}
    return dims, coords



def generate_dims_coords_jit(shape, var_name, dims=None, coords=None, default_dims=None):
    """Generate default dimension names and coordinates, using ``range_`` for coords.

    Parameters
    ----------
    shape : sequence of int
        Shape of the variable (excluding any ``default_dims`` such as chain/draw).
    var_name : str
        Name of the variable, used to build default names ``{var_name}_dim_{idx}``.
    dims : sequence of str or None, optional
        Explicit dimension names. Missing trailing entries and ``None`` entries are
        replaced by default names. Lists and tuples are both accepted; the caller's
        object is never modified.
    coords : dict, optional
        Explicit coordinates keyed by dimension name. Dimensions without one get
        ``range_(dim_len)``. The caller's dict is never modified.
    default_dims : sequence of str, optional
        Dimension names that are not counted against ``shape`` when checking
        whether too many dims were given.

    Returns
    -------
    dims : list of str
        One name per axis of ``shape`` (plus any extra names that were passed in).
    coords : dict
        Only the coordinates whose key appears in ``dims``.
    """
    if default_dims is None:
        default_dims = []

    # Private, mutable copies: ``list`` also makes tuple input usable by the
    # append / item-assignment fill-in loop below.
    dims = [] if dims is None else list(dims)
    coords = {} if coords is None else deepcopy(coords)

    # Count the same dims in the message as in the check, so the reported number
    # matches the comparison that triggered the warning.
    extra_dims = [dim for dim in dims if dim not in default_dims]
    if len(extra_dims) > len(shape):
        warnings.warn(
            (
                "In variable {var_name}, there are "
                + "more dims ({dims_len}) given than exist ({shape_len}). "
                + "Passed array should have shape (chains, draws, *shape)"
            ).format(var_name=var_name, dims_len=len(extra_dims), shape_len=len(shape)),
            SyntaxWarning,
        )

    for idx, dim_len in enumerate(shape):
        default_name = "{var_name}_dim_{idx}".format(var_name=var_name, idx=idx)
        if idx >= len(dims):
            dims.append(default_name)
        elif dims[idx] is None:
            dims[idx] = default_name
        dim_name = dims[idx]
        if dim_name not in coords:
            # numba-compiled arange variant defined earlier in the notebook
            coords[dim_name] = range_(dim_len)

    # Drop coordinates for dimensions this variable does not use.
    coords = {key: coord for key, coord in coords.items() if key in dims}
    return dims, coords


In [7]:
%timeit generate_dims_coords((10000,10000), 'x')

55.6 µs ± 3.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [8]:
%timeit generate_dims_coords_jit((10000,10000), 'x')

35.8 µs ± 3.51 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [9]:
%timeit generate_dims_coords((10,190), 'x')

20 µs ± 1.33 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [10]:
%timeit generate_dims_coords_jit((10,190), 'x')

20 µs ± 1.04 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [11]:
def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None):
    """Wrap an array in an ``xr.DataArray`` whose leading dims are chain and draw.

    Parameters
    ----------
    ary : array-like
        Promoted with ``np.atleast_2d``; axis 0 is read as chains, axis 1 as draws,
        and any remaining axes are the variable's own shape.
    var_name : str
        Used to build default names for the variable's own dimensions.
    coords : dict, optional
        Coordinates keyed by dimension name.
    dims : list of str, optional
        Dimension names; "chain"/"draw" are prepended when absent.

    Returns
    -------
    xr.DataArray
    """
    values = np.atleast_2d(ary)
    n_chains, n_samples = values.shape[0], values.shape[1]
    extra_shape = list(values.shape[2:])

    if n_chains > n_samples:
        warnings.warn(
            "More chains ({n_chains}) than draws ({n_samples}). "
            "Passed array should have shape (chains, draws, *shape)".format(
                n_chains=n_chains, n_samples=n_samples
            ),
            SyntaxWarning,
        )

    dims, coords = generate_dims_coords(
        extra_shape, var_name, dims=dims, coords=coords, default_dims=["chain", "draw"]
    )

    # Prepend missing sample dims; "chain" is processed last so it ends up first.
    for sample_dim in ("draw", "chain"):
        if sample_dim not in dims:
            dims = [sample_dim] + dims

    for sample_dim, size in (("chain", n_chains), ("draw", n_samples)):
        if sample_dim not in coords:
            coords[sample_dim] = np.arange(size)

    # Keep exactly one index coordinate per dimension, in dims order.
    index_coords = {}
    for name in dims:
        index_coords[name] = xr.IndexVariable((name,), data=coords[name])
    return xr.DataArray(values, coords=index_coords, dims=dims)


def numpy_to_data_array_jit(ary, *, var_name="data", coords=None, dims=None):
    """Wrap an array in an ``xr.DataArray`` (chain, draw, ...) using the numba helpers.

    Default coordinates come from ``range_`` and dimension handling from
    ``generate_dims_coords_jit``.

    Parameters
    ----------
    ary : array-like
        Promoted with ``np.atleast_2d``; axis 0 is read as chains, axis 1 as draws,
        and any remaining axes are the variable's own shape.
    var_name : str
        Used to build default names for the variable's own dimensions.
    coords : dict, optional
        Coordinates keyed by dimension name.
    dims : list of str, optional
        Dimension names; "chain"/"draw" are prepended when absent.

    Returns
    -------
    xr.DataArray
    """
    values = np.atleast_2d(ary)
    n_chains, n_samples = values.shape[0], values.shape[1]
    extra_shape = list(values.shape[2:])

    if n_chains > n_samples:
        warnings.warn(
            "More chains ({n_chains}) than draws ({n_samples}). "
            "Passed array should have shape (chains, draws, *shape)".format(
                n_chains=n_chains, n_samples=n_samples
            ),
            SyntaxWarning,
        )

    dims, coords = generate_dims_coords_jit(
        extra_shape, var_name, dims=dims, coords=coords, default_dims=["chain", "draw"]
    )

    # Prepend missing sample dims; "chain" is processed last so it ends up first.
    for sample_dim in ("draw", "chain"):
        if sample_dim not in dims:
            dims = [sample_dim] + dims

    for sample_dim, size in (("chain", n_chains), ("draw", n_samples)):
        if sample_dim not in coords:
            coords[sample_dim] = range_(size)

    # Keep exactly one index coordinate per dimension, in dims order.
    index_coords = {}
    for name in dims:
        index_coords[name] = xr.IndexVariable((name,), data=coords[name])
    return xr.DataArray(values, coords=index_coords, dims=dims)

In [12]:
# Benchmark inputs, drawn from a fixed seed so timings use identical data after a
# kernel restart. `data` has more rows than columns, so numpy_to_data_array emits its
# "More chains than draws" warning; `linear` is 1-D and becomes one row via np.atleast_2d.
rng = np.random.default_rng(42)
data = rng.standard_normal((10000, 100))
linear = rng.standard_normal(1000000)
small = rng.standard_normal((100, 100))

In [13]:
%timeit numpy_to_data_array(data)

  if sys.path[0] == '':


472 µs ± 81.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [14]:
%timeit numpy_to_data_array_jit(data)



409 µs ± 50 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [15]:
%timeit numpy_to_data_array(linear)

2.84 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [16]:
%timeit numpy_to_data_array_jit(linear)

3.03 ms ± 123 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
%timeit numpy_to_data_array(small)

414 µs ± 67.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [18]:
%timeit numpy_to_data_array_jit(small)

362 µs ± 35 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [19]:
'Very Similar Performance. Up for reconsideration'

'Very Similar Performance. Up for reconsideration'

In [20]:
# Bottleneck of the dict -> Dataset conversion: numpy_to_data_array (benchmarked above)

In [21]:
""""""""""""""""""""""""""""""""""""""""""""""Converters"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'"Converters'

In [22]:
# Input for the profiling run below, drawn from a fixed seed so runs are reproducible.
data = np.random.default_rng(42).standard_normal((10000, 100))

In [23]:
# Line-profile arviz's convert_to_inference_data on the (10000, 100) ndarray from the
# previous cell.
lp = LineProfiler()
wrapper = lp(convert_to_inference_data)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.004205 s
File: /home/banzee/Desktop/arviz/arviz/data/converters.py
Function: convert_to_inference_data at line 16

Line #      Hits         Time  Per Hit   % Time  Line Contents
    16                                           def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None, **kwargs):
    17                                               r"""Convert a supported object to an InferenceData object.
    18                                           
    19                                               This function sends `obj` to the right conversion function. It is idempotent,
    20                                               in that it will return arviz.InferenceData objects unchanged.
    21                                           
    22                                               Parameters
    23                                               ----------
    24                                               obj : d



In [24]:
# Bottleneck is dict to dataset. Refer above

In [25]:
""""""""""""""""""""""""""""""""""""""""""""""""""""DATASETS.PY"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'"DATASETS.PY'

In [26]:
# Nothing to improve here

In [27]:
"""""""""""""""""""""""""""""""""""""""""""""""""""""io_dict"""""""""""""""""""""""""""""""""""""""""""""""""""""""""'""'

'""io_dict""'

In [28]:
# Bottleneck is dict to dataset. Refer above

In [29]:
"""""""""""""""""""""""""""""""""""""""""io_netcdf"""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'""io_netcdf'

In [30]:
# Bottlenecks: InferenceData and convert_to_inference_data

In [31]:
"""""""""""""""""""""""""""""""""""""""""""""Inference_data"""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'Inference_data'

In [48]:
class InferenceData:
    """Container for accessing netCDF files using xarray."""

    def __init__(self, **kwargs):
        """Initialize InferenceData object from keyword xarray datasets.

        Examples
        --------
        InferenceData(posterior=posterior, prior=prior)

        Parameters
        ----------
        kwargs :
            Keyword arguments of xarray datasets. ``None`` values are skipped; every
            other value is stored as an attribute named after its keyword and its
            name is recorded in ``self._groups`` (in argument order).

        Raises
        ------
        ValueError
            If a value is neither ``None`` nor an ``xr.Dataset``.
        """
        self._groups = []
        for key, dataset in kwargs.items():
            if dataset is None:
                continue
            elif not isinstance(dataset, xr.Dataset):
                raise ValueError(
                    "Arguments to InferenceData must be xarray Datasets "
                    '(argument "{}" was type "{}")'.format(key, type(dataset))
                )
            setattr(self, key, dataset)
            self._groups.append(key)

    def __repr__(self):
        """Make string representation of object."""
        return "Inference data with groups:\n\t> {options}".format(
            options="\n\t> ".join(self._groups)
        )

    @staticmethod
    def from_netcdf(filename):
        """Initialize object from a netcdf file.

        Expects that the file will have groups, each of which can be loaded by xarray.

        Parameters
        ----------
        filename : str
            location of netcdf file

        Returns
        -------
        InferenceData object
        """
        groups = {}
        # Only the group names are read with netCDF4; each group is then opened by xarray.
        with nc.Dataset(filename, mode="r") as data:
            data_groups = list(data.groups)

        for group in data_groups:
            with xr.open_dataset(filename, group=group) as data:
                groups[group] = data
        return InferenceData(**groups)

    def to_netcdf(self, filename, compress=True):
        """Write InferenceData to file using netcdf4.

        Parameters
        ----------
        filename : str
            Location to write to
        compress : bool
            Whether to compress result. Note this saves disk space, but may make
            saving and loading somewhat slower (default: True). Only fixed-size
            numeric/bool/byte-string variables are compressed.

        Returns
        -------
        str
            Location of netcdf file
        """
        mode = "w"  # overwrite first, then append
        if self._groups:  # check's whether a group is present or not.
            for group in self._groups:
                data = getattr(self, group)
                kwargs = {}
                if compress:
                    # netCDF4 compression is limited to fixed-size types, so skip
                    # unicode/object (variable-length string) variables such as
                    # string coordinates instead of requesting zlib for them.
                    kwargs["encoding"] = {
                        var_name: {"zlib": True}
                        for var_name, values in data.variables.items()
                        if values.dtype.kind in {"b", "i", "u", "f", "c", "S"}
                    }
                data.to_netcdf(filename, mode=mode, group=group, **kwargs)
                data.close()
                mode = "a"
        else:  # creates a netcdf file for an empty InferenceData object.
            empty_netcdf_file = nc.Dataset(filename, mode="w", format="NETCDF4")
            empty_netcdf_file.close()
        return filename

    def __add__(self, other):
        """Concatenate two InferenceData objects."""
        return concat(self, other, copy=True, inplace=False)

    def sel(self, inplace=True, **kwargs):
        """Perform an xarray selection on all groups.

        Loops over all groups to perform Dataset.sel(key=item)
        for every kwarg if key is a dimension of the dataset; keys that are not
        dimensions of a given group are ignored for that group.

        Parameters
        ----------
        inplace : bool
            If True (default), modify this InferenceData object inplace and return
            None. Otherwise, leave it untouched and return a modified deep copy.
        **kwargs : mapping
            It must be accepted by Dataset.sel()

        Returns
        -------
        InferenceData or None
            The modified copy when ``inplace`` is False, otherwise None.
        """
        out = self if inplace else deepcopy(self)
        for group in self._groups:
            dataset = getattr(self, group)
            valid_keys = set(kwargs).intersection(dataset.dims)
            dataset = dataset.sel(**{key: kwargs[key] for key in valid_keys})
            setattr(out, group, dataset)
        return None if inplace else out


def _merge_group_attrs(group0_attrs, group_attrs, current_time):
    """Merge ``group_attrs`` into ``group0_attrs`` in place and return it (concat helper).

    Equal values are skipped. ``created_at`` / ``previous_created_at`` are folded into a
    ``previous_created_at`` history list, with ``created_at`` set to ``current_time``.
    Other conflicting keys are collected under ``combined_<key>``; new keys are copied.
    """
    for attr_key, attr_values in group_attrs.items():
        group0_attr_values = group0_attrs.get(attr_key, None)
        equality = attr_values == group0_attr_values
        if hasattr(equality, "__iter__"):
            equality = np.all(equality)
        if equality:
            continue
        if attr_key in ("created_at", "previous_created_at"):
            # Test key membership: ``hasattr`` on a dict checks attributes, never keys,
            # which would reset the history on every merge.
            if "previous_created_at" not in group0_attrs:
                group0_attrs["previous_created_at"] = []
                if group0_attr_values is not None:
                    group0_attrs["previous_created_at"].append(group0_attr_values)
            elif not isinstance(group0_attrs["previous_created_at"], list):
                group0_attrs["previous_created_at"] = [group0_attrs["previous_created_at"]]
            if attr_key == "previous_created_at":
                if not isinstance(attr_values, list):
                    attr_values = [attr_values]
                group0_attrs["previous_created_at"].extend(attr_values)
                continue
            # "created_at": stamp the merge time and keep the other timestamp in history
            if group0_attr_values != current_time:
                group0_attrs[attr_key] = current_time
            group0_attrs["previous_created_at"].append(attr_values)
        elif attr_key in group0_attrs:
            combined_key = "combined_{}".format(attr_key)
            if combined_key not in group0_attrs:
                group0_attrs[combined_key] = [group0_attr_values]
            group0_attrs[combined_key].append(attr_values)
        else:
            group0_attrs[attr_key] = attr_values
    return group0_attrs


def _concat_group_along_dim(group0_data, group_data, dim, reset_dim, current_time):
    """Check two datasets of one group match, then concatenate them along ``dim``.

    ``group0_data`` comes first in the result. Neither input is modified.
    Raises TypeError on mismatched variables/dimensions or a missing ``dim``.
    """
    group_vars = group_data.data_vars
    group0_vars = group0_data.data_vars

    for var in group0_vars:
        if var not in group_vars:
            raise TypeError("Mismatch between the variables.")

    for var in group_vars:
        if var not in group0_vars:
            raise TypeError("Mismatch between the variables.")
        # Item access, not getattr: variables named like Dataset methods stay reachable.
        var_dims = group_data[var].dims
        var0_dims = group0_data[var].dims
        if var_dims != var0_dims:
            raise TypeError("Mismatch between the dimensions.")
        if dim not in var_dims or dim not in var0_dims:
            raise TypeError("Dimension {} missing.".format(dim))

    concatenated_group = xr.concat((group0_data, group_data), dim=dim)
    if reset_dim:
        concatenated_group[dim] = range(concatenated_group[dim].size)

    group0_attrs = deepcopy(getattr(group0_data, "attrs", OrderedDict()))
    concatenated_group.attrs = _merge_group_attrs(
        group0_attrs, getattr(group_data, "attrs", {}), current_time
    )
    return concatenated_group


def _merge_observed_data(group0_data, group_data, dim, current_time):
    """Merge ``group_data`` into ``group0_data`` (modified in place) and return it.

    New variables are added. Shared variables are concatenated along ``dim`` only
    when both carry it; otherwise the existing variable is kept.
    """
    group0_vars = group0_data.data_vars
    for var in group_data.data_vars:
        var_data = group_data[var]
        if var not in group0_vars:
            group0_data[var] = var_data
        else:
            var0_data = group0_data[var]
            if dim in var_data.dims and dim in var0_data.dims:
                # NOTE(review): assumes no other variable of the group shares `dim`;
                # xarray may reject the new length otherwise -- confirm if it matters.
                group0_data[var] = xr.concat((var0_data, var_data), dim=dim)

    group0_attrs = getattr(group0_data, "attrs", OrderedDict())
    group0_data.attrs = _merge_group_attrs(
        group0_attrs, getattr(group_data, "attrs", {}), current_time
    )
    return group0_data


# NOTE(review): object mode (forceobj=True) compiles this Python/xarray-heavy code without
# type specialisation, so a speedup is unlikely; kept because this notebook benchmarks it.
@numba.jit(forceobj=True)
def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
    """Concatenate InferenceData objects.

    Concatenates over `group`, `chain` or `draw`.
    By default (``dim=None``, equivalently ``dim="group"``) concatenates over
    unique groups. To concatenate over `chain` or `draw` function
    needs identical groups and variables; data of later arguments is placed
    after that of earlier ones.

    The `variables` in the `observed_data` group are merged instead: variables
    missing from the first object are added, and shared variables are
    concatenated only if they carry `dim`.

    Parameters
    ----------
    *args : InferenceData
        Variable length InferenceData list or
        Sequence of InferenceData.
    dim : str, optional
        One of "group", "chain", "draw" (case-insensitive). If None or "group",
        concatenates over unique groups.
    copy : bool
        If True, groups that are taken over unchanged are deep-copied into the
        result instead of being shared with the inputs.
    inplace : bool
        If True, merge args to first object.
    reset_dim : bool
        Valid only if dim is not None. If True, the `dim` coordinate is
        renumbered 0..n-1 after concatenation.

    Returns
    -------
    InferenceData
        A new InferenceData object by default.
        When `inplace==True` merge args to first arg and return `None`

    Raises
    ------
    TypeError
        For non-InferenceData input, an invalid `dim`, overlapping groups when
        concatenating over groups, or mismatched groups/variables/dimensions
        when concatenating over `chain`/`draw`.
    """
    # pylint: disable=too-many-nested-blocks, too-many-branches, too-many-statements
    # Unwrap a single Sequence argument first so that ``concat([])`` counts as empty.
    if len(args) == 1 and isinstance(args[0], Sequence):
        args = args[0]

    if len(args) == 0:
        return None if inplace else InferenceData()

    # assert that all args are InferenceData
    for i, arg in enumerate(args):
        if not isinstance(arg, InferenceData):
            raise TypeError(
                "Concatenating is supported only "
                "between InferenceData objects. Input arg {} is {}".format(i, type(arg))
            )

    if dim is not None and dim.lower() not in {"group", "chain", "draw"}:
        msg = "Invalid `dim`: {}. Valid `dim` are {}".format(dim, '{"group", "chain", "draw"}')
        raise TypeError(msg)
    # "group" names the default, group-wise concatenation.
    dim = None if dim is None or dim.lower() == "group" else dim.lower()

    if len(args) == 1:
        if inplace:
            return None
        return deepcopy(args[0]) if copy else args[0]

    current_time = str(datetime.now())
    arg0 = args[0]
    # Keep order for python 3.5
    inference_data_dict = OrderedDict()

    if dim is None:
        arg0_groups = ccopy(arg0._groups)
        args_groups = OrderedDict()
        # Concat over unique groups: each group may come from exactly one argument.
        for arg in args[1:]:
            for group in arg._groups:
                if group in args_groups or group in arg0_groups:
                    msg = (
                        "Concatenating overlapping groups is not supported unless `dim` is defined."
                    )
                    msg += " Valid dimensions are `chain` and `draw`."
                    raise TypeError(msg)
                group_data = getattr(arg, group)
                args_groups[group] = deepcopy(group_data) if copy else group_data
        # add arg0 to args_groups if inplace is False
        if not inplace:
            for group in arg0_groups:
                group_data = getattr(arg0, group)
                args_groups[group] = deepcopy(group_data) if copy else group_data

        basic_order = [
            "posterior",
            "posterior_predictive",
            "sample_stats",
            "prior",
            "prior_predictive",
            "sample_stats_prior",
            "observed_data",
        ]
        other_groups = [group for group in args_groups if group not in basic_order]

        for group in basic_order + other_groups:
            if group not in args_groups:
                continue
            if inplace:
                arg0._groups.append(group)
                setattr(arg0, group, args_groups[group])
            else:
                inference_data_dict[group] = args_groups[group]
        if inplace:
            other_groups = [
                group for group in arg0_groups if group not in basic_order
            ] + other_groups
            sorted_groups = [group for group in basic_order + other_groups if group in arg0._groups]
            setattr(arg0, "_groups", sorted_groups)
            return None
        return InferenceData(**inference_data_dict)

    # Concatenate along `dim`. arg0_groups aliases arg0._groups so that an
    # observed_data group adopted inplace is visible to later arguments.
    arg0_groups = arg0._groups
    for arg in args[1:]:
        for group0 in arg0_groups:
            if group0 not in arg._groups and group0 != "observed_data":
                raise TypeError("Mismatch between the groups.")

        for group in arg._groups:
            group_data = getattr(arg, group)

            if group != "observed_data":
                if group not in arg0_groups:
                    raise TypeError("Mismatch between the groups.")
                if not inplace and group in inference_data_dict:
                    group0_data = inference_data_dict[group]
                else:
                    group0_data = getattr(arg0, group)
                merged = _concat_group_along_dim(
                    group0_data, group_data, dim, reset_dim, current_time
                )
            elif not inplace and group in inference_data_dict:
                merged = _merge_observed_data(
                    inference_data_dict[group], group_data, dim, current_time
                )
            elif group in arg0_groups:
                group0_data = getattr(arg0, group)
                if not inplace:
                    # never mutate the first argument when a new object is requested
                    group0_data = deepcopy(group0_data)
                merged = _merge_observed_data(group0_data, group_data, dim, current_time)
            else:
                # first observed_data seen: adopt this argument's own dataset
                merged = deepcopy(group_data) if copy else group_data
                if inplace:
                    arg0._groups.append(group)

            if inplace:
                setattr(arg0, group, merged)
            else:
                inference_data_dict[group] = merged

    if inplace:
        return None

    # Carry over groups only arg0 has (the group checks above allow this only for
    # observed_data), then order groups as in arg0 followed by any new ones.
    for group in arg0_groups:
        if group not in inference_data_dict:
            group_data = getattr(arg0, group)
            inference_data_dict[group] = deepcopy(group_data) if copy else group_data
    ordered_groups = list(arg0_groups) + [
        group for group in inference_data_dict if group not in arg0_groups
    ]
    result_groups = OrderedDict()
    for group in ordered_groups:
        result_groups[group] = inference_data_dict[group]
    return InferenceData(**result_groups)