In [118]:
from arviz.data.base import generate_dims_coords
from arviz.data import load_arviz_data
from arviz.data.converters import convert_to_inference_data
from line_profiler import LineProfiler
import numpy as np
from numpy import array, average, dot
import numba
import pandas as pd
from copy import deepcopy
import datetime
import warnings
from arviz.plots.kdeplot import _fast_kde_2d as f2
import numpy as np
import pkg_resources
import xarray as xr
import timeit
from scipy.signal import gaussian, convolve, convolve2d  # pylint: disable=no-name-in-module
from scipy.sparse import coo_matrix
from collections import OrderedDict,defaultdict
from collections.abc import Sequence
from copy import copy as ccopy, deepcopy
from datetime import datetime
import netCDF4 as nc
import numpy as np
import xarray as xr
import re


In [2]:
lp = LineProfiler()
wrapper = lp(generate_dims_coords)
wrapper((500,600,80), 'x')
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.000435 s
File: /home/banzee/Desktop/arviz/arviz/data/base.py
Function: generate_dims_coords at line 30

Line #      Hits         Time  Per Hit   % Time  Line Contents
    30                                           def generate_dims_coords(shape, var_name, dims=None, coords=None, default_dims=None):
    31                                               """Generate default dimensions and coordinates for a variable.
    32                                           
    33                                               Parameters
    34                                               ----------
    35                                               shape : tuple[int]
    36                                                   Shape of the variable
    37                                               var_name : str
    38                                                   Name of the variable. Used in the default name, if necessary
    39                          

In [3]:
@numba.njit
def range_(x):
    return np.arange(x)


def range_jit(x):
    return np.arange(x)

In [4]:
%timeit range_(100)

The slowest run took 15.71 times longer than the fastest. This could mean that an intermediate result is being cached.
3.15 µs ± 4.96 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
%timeit range_jit(100)

1.4 µs ± 165 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [6]:
def generate_dims_coords(shape, var_name, dims=None, coords=None, default_dims=None):
    if default_dims is None:
        default_dims = []
    if dims is None:
        dims = []
    if len([dim for dim in dims if dim not in default_dims]) > len(shape):
        warnings.warn(
            (
                "In variable {var_name}, there are "
                + "more dims ({dims_len}) given than exist ({shape_len}). "
                + "Passed array should have shape (chains, draws, *shape)"
            ).format(var_name=var_name, dims_len=len(dims), shape_len=len(shape)),
            SyntaxWarning,
        )
    if coords is None:
        coords = {}

    coords = deepcopy(coords)
    dims = deepcopy(dims)

    for idx, dim_len in enumerate(shape):
        if (len(dims) < idx + 1) or (dims[idx] is None):
            dim_name = "{var_name}_dim_{idx}".format(var_name=var_name, idx=idx)
            if len(dims) < idx + 1:
                dims.append(dim_name)
            else:
                dims[idx] = dim_name
        dim_name = dims[idx]
        if dim_name not in coords:
            coords[dim_name] = np.arange(dim_len)
    coords = {key: coord for key, coord in coords.items() if any(key == dim for dim in dims)}
    return dims, coords



def generate_dims_coords_jit(shape, var_name, dims=None, coords=None, default_dims=None):
    if default_dims is None:
        default_dims = []
    if dims is None:
        dims = []
    if len([dim for dim in dims if dim not in default_dims]) > len(shape):
        warnings.warn(
            (
                "In variable {var_name}, there are "
                + "more dims ({dims_len}) given than exist ({shape_len}). "
                + "Passed array should have shape (chains, draws, *shape)"
            ).format(var_name=var_name, dims_len=len(dims), shape_len=len(shape)),
            SyntaxWarning,
        )
    if coords is None:
        coords = {}

    coords = deepcopy(coords)
    dims = deepcopy(dims)
    
    for idx, dim_len in enumerate(shape):
        if (len(dims) < idx + 1) or (dims[idx] is None):
            dim_name = "{var_name}_dim_{idx}".format(var_name=var_name, idx=idx)
            if len(dims) < idx + 1:
                dims.append(dim_name)
            else:
                dims[idx] = dim_name
        dim_name = dims[idx]
        if dim_name not in coords:
            coords[dim_name] = range_(dim_len)
    coords = {key: coord for key, coord in coords.items() if any(key == dim for dim in dims)}
    return dims, coords


In [7]:
@numba.jit(forceobj=True)
def jit_loop(dims,coords,shape,var_name):
    for idx,dim_len in enumerate(shape):
        if (len(dims)<idx+1) or (dims[idx] is None):
            dim_name = "{var_name}_dim_{idx}".format(var_name=var_name, idx=idx)
            if len(dims)<idx+1:
                dims.append(dim_name)
            else:
                dims[idx] = dim_name
        dim_name = dims[idx]
        if dim_name not in coords:
            coords[dim_name] = range_(dim_len)
        return coords, dims
            

In [8]:
%timeit generate_dims_coords((10000,100,10,5), 'x')

48.7 µs ± 4.57 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [9]:
%timeit generate_dims_coords_jit((10000,100,10,5), 'x')

36.7 µs ± 4.51 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [10]:
%timeit generate_dims_coords((10,), 'x')

14.5 µs ± 2.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [11]:
%timeit generate_dims_coords_jit((10,), 'x')

13 µs ± 944 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [12]:
@numba.njit
def one_de(x):
    return np.atleast_1d(x)

@numba.njit
def two_de(x):
    return np.atleast_2d(x)

In [13]:
def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None):
    # manage and transform copies
    default_dims = ["chain", "draw"]
    ary = np.atleast_2d(ary)
    n_chains, n_samples, *shape = ary.shape
    if n_chains > n_samples:
        warnings.warn(
            "More chains ({n_chains}) than draws ({n_samples}). "
            "Passed array should have shape (chains, draws, *shape)".format(
                n_chains=n_chains, n_samples=n_samples
            ),
            SyntaxWarning,
        )

    dims, coords = generate_dims_coords(
        shape, var_name, dims=dims, coords=coords, default_dims=default_dims
    )

    # reversed order for default dims: 'chain', 'draw'
    if "draw" not in dims:
        dims = ["draw"] + dims
    if "chain" not in dims:
        dims = ["chain"] + dims

    if "chain" not in coords:
        coords["chain"] = np.arange(n_chains)
    if "draw" not in coords:
        coords["draw"] = np.arange(n_samples)

    # filter coords based on the dims
    coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in dims}
    return xr.DataArray(ary, coords=coords, dims=dims)


def numpy_to_data_array_jit(ary, *, var_name="data", coords=None, dims=None):
    # manage and transform copies
    default_dims = ["chain", "draw"]
    ary = two_de(ary)
    n_chains, n_samples, *shape = ary.shape
    if n_chains > n_samples:
        warnings.warn(
            "More chains ({n_chains}) than draws ({n_samples}). "
            "Passed array should have shape (chains, draws, *shape)".format(
                n_chains=n_chains, n_samples=n_samples
            ),
            SyntaxWarning,
        )

    dims, coords = generate_dims_coords(
        shape, var_name, dims=dims, coords=coords, default_dims=default_dims
    )

    # reversed order for default dims: 'chain', 'draw'
    if "draw" not in dims:
        dims = ["draw"] + dims
    if "chain" not in dims:
        dims = ["chain"] + dims

    if "chain" not in coords:
        coords["chain"] = range_(n_chains)
    if "draw" not in coords:
        coords["draw"] = range_(n_samples)

    # filter coords based on the dims
    coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in dims}
    return xr.DataArray(ary, coords=coords, dims=dims)

In [14]:
data = np.random.randn(10000,100)
linear = np.random.randn(1000000)
small = np.random.randn(100,100)
small_linear = np.random.randn(100)

In [15]:
%timeit numpy_to_data_array(data)

  if sys.path[0] == '':


497 µs ± 82.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [16]:
%timeit numpy_to_data_array_jit(data)



448 µs ± 62.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [17]:
%timeit numpy_to_data_array(linear)

2.69 ms ± 396 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
%timeit numpy_to_data_array_jit(linear)

2.09 ms ± 115 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [19]:
%timeit numpy_to_data_array(small)

289 µs ± 5.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [20]:
%timeit numpy_to_data_array_jit(small)

316 µs ± 36.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [21]:
%timeit numpy_to_data_array(small_linear)

320 µs ± 37.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [22]:
%timeit numpy_to_data_array_jit(small_linear)

297 µs ± 10.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [23]:
'Very Similar Performance. Up for reconsideration'

'Very Similar Performance. Up for reconsideration'

In [24]:
# Dict to dataset bottleneck ---> numpy_to_ndarray

In [25]:
""""""""""""""""""""""""""""""""""""""""""""""Converters"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'"Converters'

In [26]:
data = np.random.randn(10000,100)

In [27]:
lp = LineProfiler()
wrapper = lp(convert_to_inference_data)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.002719 s
File: /home/banzee/Desktop/arviz/arviz/data/converters.py
Function: convert_to_inference_data at line 16

Line #      Hits         Time  Per Hit   % Time  Line Contents
    16                                           def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None, **kwargs):
    17                                               r"""Convert a supported object to an InferenceData object.
    18                                           
    19                                               This function sends `obj` to the right conversion function. It is idempotent,
    20                                               in that it will return arviz.InferenceData objects unchanged.
    21                                           
    22                                               Parameters
    23                                               ----------
    24                                               obj : d



In [28]:
# Bottleneck is dict to dataset. Refer above

In [29]:
""""""""""""""""""""""""""""""""""""""""""""""""""""DATASETS.PY"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'"DATASETS.PY'

In [30]:
# Nothing to improve here

In [31]:
"""""""""""""""""""""""""""""""""""""""""""""""""""""io_dict"""""""""""""""""""""""""""""""""""""""""""""""""""""""""'""'

'""io_dict""'

In [32]:
# Bottleneck is dict to dataset. Refer above

In [33]:
"""""""""""""""""""""""""""""""""""""""""io_netcdf"""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'""io_netcdf'

In [34]:
#Bottlenecks---->Inference data and convert_to_inference_data

In [35]:
"""""""""""""""""""""""""""""""""""""""""""""Inference_data"""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'Inference_data'

In [36]:
class InferenceData:
    """Container for accessing netCDF files using xarray."""

    def __init__(self, **kwargs):
        """Initialize InferenceData object from keyword xarray datasets.

        Examples
        --------
        InferenceData(posterior=posterior, prior=prior)

        Parameters
        ----------
        kwargs :
            Keyword arguments of xarray datasets
        """
        self._groups = []
        for key, dataset in kwargs.items():
            if dataset is None:
                continue
            elif not isinstance(dataset, xr.Dataset):
                raise ValueError(
                    "Arguments to InferenceData must be xarray Datasets "
                    '(argument "{}" was type "{}")'.format(key, type(dataset))
                )
            setattr(self, key, dataset)
            self._groups.append(key)

    def __repr__(self):
        """Make string representation of object."""
        return "Inference data with groups:\n\t> {options}".format(
            options="\n\t> ".join(self._groups)
        )

    @staticmethod
    def from_netcdf(filename):
        """Initialize object from a netcdf file.

        Expects that the file will have groups, each of which can be loaded by xarray.

        Parameters
        ----------
        filename : str
            location of netcdf file

        Returns
        -------
        InferenceData object
        """
        groups = {}
        with nc.Dataset(filename, mode="r") as data:
            data_groups = list(data.groups)

        for group in data_groups:
            with xr.open_dataset(filename, group=group) as data:
                groups[group] = data
        return InferenceData(**groups)

    def to_netcdf(self, filename, compress=True):
        """Write InferenceData to file using netcdf4.

        Parameters
        ----------
        filename : str
            Location to write to
        compress : bool
            Whether to compress result. Note this saves disk space, but may make
            saving and loading somewhat slower (default: True).

        Returns
        -------
        str
            Location of netcdf file
        """
        mode = "w"  # overwrite first, then append
        if self._groups:  # check's whether a group is present or not.
            for group in self._groups:
                data = getattr(self, group)
                kwargs = {}
                if compress:
                    kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
                data.to_netcdf(filename, mode=mode, group=group, **kwargs)
                data.close()
                mode = "a"
        else:  # creates a netcdf file for an empty InferenceData object.
            empty_netcdf_file = nc.Dataset(filename, mode="w", format="NETCDF4")
            empty_netcdf_file.close()
        return filename

    def __add__(self, other):
        """Concatenate two InferenceData objects."""
        return concat(self, other, copy=True, inplace=False)

    def sel(self, inplace=True, **kwargs):
        """Perform an xarray selection on all groups.

        Loops over all groups to perform Dataset.sel(key=item)
        for every kwarg if key is a dimension of the dataset.
        The selection is performed inplace.

        Parameters
        ----------
        inplace : bool
            If True, modify the InferenceData object inplace, otherwise, return the modified copy.
        **kwargs : mapping
            It must be accepted by Dataset.sel()
        """
        out = self if inplace else deepcopy(self)
        for group in self._groups:
            dataset = getattr(self, group)
            valid_keys = set(kwargs.keys()).intersection(dataset.dims)
            dataset = dataset.sel(**{key: kwargs[key] for key in valid_keys})
            setattr(out, group, dataset)
        if inplace:
            return None
        else:
            return out


@numba.jit(forceobj=True)
def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
    """Concatenate InferenceData objects.

    Concatenates over `group`, `chain` or `draw`.
    By default concatenates over unique groups.
    To concatenate over `chain` or `draw` function
    needs identical groups and variables.

    The `variables` in the `data` -group are merged if `dim` are not found.


    Parameters
    ----------
    *args : InferenceData
        Variable length InferenceData list or
        Sequence of InferenceData.
    dim : str, optional
        If defined, concatenated over the defined dimension.
        Dimension which is concatenated. If None, concatenates over
        unique groups.
    copy : bool
        If True, groups are copied to the new InferenceData object.
        Used only if `dim` is None.
    inplace : bool
        If True, merge args to first object.
    reset_dim : bool
        Valid only if dim is not None.

    Returns
    -------
    InferenceData
        A new InferenceData object by default.
        When `inplace==True` merge args to first arg and return `None`
    """
    # pylint: disable=undefined-loop-variable, too-many-nested-blocks
    if len(args) == 0:
        if inplace:
            return
        return InferenceData()

    if len(args) == 1 and isinstance(args[0], Sequence):
        args = args[0]

    # assert that all args are InferenceData
    for i, arg in enumerate(args):
        if not isinstance(arg, InferenceData):
            raise TypeError(
                "Concatenating is supported only"
                "between InferenceData objects. Input arg {} is {}".format(i, type(arg))
            )

    if dim is not None and dim.lower() not in {"group", "chain", "draw"}:
        msg = "Invalid `dim`: {}. Valid `dim` are {}".format(dim, '{"group", "chain", "draw"}')
        raise TypeError(msg)
    dim = dim.lower() if dim is not None else dim

    if len(args) == 1 and isinstance(args[0], InferenceData):
        if inplace:
            return None
        else:
            if copy:
                return deepcopy(args[0])
            else:
                return args[0]

    current_time = str(datetime.now())

    if not inplace:
        # Keep order for python 3.5
        inference_data_dict = OrderedDict()

    if dim is None:
        arg0 = args[0]
        arg0_groups = ccopy(arg0._groups)
        args_groups = dict()
        # check if groups are independent
        # Concat over unique groups
        for arg in args[1:]:
            for group in arg._groups:
                if group in args_groups or group in arg0_groups:
                    msg = (
                        "Concatenating overlapping groups is not supported unless `dim` is defined."
                    )
                    msg += " Valid dimensions are `chain` and `draw`."
                    raise TypeError(msg)
            group_data = getattr(arg, group)
            args_groups[group] = deepcopy(group_data) if copy else group_data
        # add arg0 to args_groups if inplace is False
        if not inplace:
            for group in arg0_groups:
                group_data = getattr(arg0, group)
                args_groups[group] = deepcopy(group_data) if copy else group_data

        basic_order = [
            "posterior",
            "posterior_predictive",
            "sample_stats",
            "prior",
            "prior_predictive",
            "sample_stats_prior",
            "observed_data",
        ]
        other_groups = [group for group in args_groups if group not in basic_order]

        for group in basic_order + other_groups:
            if group not in args_groups:
                continue
            if inplace:
                arg0._groups.append(group)
                setattr(arg0, group, args_groups[group])
            else:
                inference_data_dict[group] = args_groups[group]
        if inplace:
            other_groups = [
                group for group in arg0_groups if group not in basic_order
            ] + other_groups
            sorted_groups = [group for group in basic_order + other_groups if group in arg0._groups]
            setattr(arg0, "_groups", sorted_groups)
    else:
        arg0 = args[0]
        arg0_groups = arg0._groups
        for arg in args[1:]:
            for group0 in arg0_groups:
                if group0 not in arg._groups:
                    if group0 == "observed_data":
                        continue
                    msg = "Mismatch between the groups."
                    raise TypeError(msg)
            for group in arg._groups:
                if group != "observed_data":
                    # assert that groups are equal
                    if group not in arg0_groups:
                        msg = "Mismatch between the groups."
                        raise TypeError(msg)

                    # assert that variables are equal
                    group_data = getattr(arg, group)
                    group_vars = group_data.data_vars

                    if not inplace and group in inference_data_dict:
                        group0_data = inference_data_dict[group]
                    else:
                        group0_data = getattr(arg0, group)
                    group0_vars = group0_data.data_vars

                    for var in group0_vars:
                        if var not in group_vars:
                            msg = "Mismatch between the variables."
                            raise TypeError(msg)

                    for var in group_vars:
                        if var not in group0_vars:
                            msg = "Mismatch between the variables."
                            raise TypeError(msg)
                        var_dims = getattr(group_data, var).dims
                        var0_dims = getattr(group0_data, var).dims
                        if var_dims != var0_dims:
                            msg = "Mismatch between the dimensions."
                            raise TypeError(msg)

                        if dim not in var_dims or dim not in var0_dims:
                            msg = "Dimension {} missing.".format(dim)
                            raise TypeError(msg)

                    # xr.concat
                    concatenated_group = xr.concat((group_data, group0_data), dim=dim)
                    if reset_dim:
                        concatenated_group[dim] = range(concatenated_group[dim].size)

                    # handle attrs
                    if hasattr(group0_data, "attrs"):
                        group0_attrs = deepcopy(getattr(group0_data, "attrs"))
                    else:
                        group0_attrs = OrderedDict()

                    if hasattr(group_data, "attrs"):
                        group_attrs = getattr(group_data, "attrs")
                    else:
                        group_attrs = dict()

                    # gather attrs results to group0_attrs
                    for attr_key, attr_values in group_attrs.items():
                        group0_attr_values = group0_attrs.get(attr_key, None)
                        equality = attr_values == group0_attr_values
                        if hasattr(equality, "__iter__"):
                            equality = np.all(equality)
                        if equality:
                            continue
                        # handle special cases:
                        if attr_key in ("created_at", "previous_created_at"):
                            # check the defaults
                            if not hasattr(group0_attrs, "previous_created_at"):
                                group0_attrs["previous_created_at"] = []
                                if group0_attr_values is not None:
                                    group0_attrs["previous_created_at"].append(group0_attr_values)
                            # check previous values
                            if attr_key == "previous_created_at":
                                if not isinstance(attr_values, list):
                                    attr_values = [attr_values]
                                group0_attrs["previous_created_at"].extend(attr_values)
                                continue
                            # update "created_at"
                            if group0_attr_values != current_time:
                                group0_attrs[attr_key] = current_time
                            group0_attrs["previous_created_at"].append(attr_values)

                        elif attr_key in group0_attrs:
                            combined_key = "combined_{}".format(attr_key)
                            if combined_key not in group0_attrs:
                                group0_attrs[combined_key] = [group0_attr_values]
                            group0_attrs[combined_key].append(attr_values)
                        else:
                            group0_attrs[attr_key] = attr_values
                    # update attrs
                    setattr(concatenated_group, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, concatenated_group)
                    else:
                        inference_data_dict[group] = concatenated_group
                else:
                    # observed_data
                    if group not in arg0_groups:
                        setattr(arg0, group, deepcopy(group_data) if copy else group_data)
                        arg0._groups.append(group)
                        continue

                    # assert that variables are equal
                    group_data = getattr(arg, group)
                    group_vars = group_data.data_vars

                    group0_data = getattr(arg0, group)
                    if not inplace:
                        group0_data = deepcopy(group0_data)
                    group0_vars = group0_data.data_vars

                    for var in group_vars:
                        if var not in group0_vars:
                            var_data = getattr(group_data, var)
                            arg0.observed_data[var] = var_data
                        else:
                            var_data = getattr(group_data, var)
                            var0_data = getattr(group0_data, var)
                            if dim in var_data.dims and dim in var0_data.dims:
                                concatenated_var = xr.concat((group_data, group0_data), dim=dim)
                                group0_data[var] = concatenated_var

                    # handle attrs
                    if hasattr(group0_data, "attrs"):
                        group0_attrs = getattr(group0_data, "attrs")
                    else:
                        group0_attrs = OrderedDict()

                    if hasattr(group_data, "attrs"):
                        group_attrs = getattr(group_data, "attrs")
                    else:
                        group_attrs = dict()

                    # gather attrs results to group0_attrs
                    for attr_key, attr_values in group_attrs.items():
                        group0_attr_values = group0_attrs.get(attr_key, None)
                        equality = attr_values == group0_attr_values
                        if hasattr(equality, "__iter__"):
                            equality = np.all(equality)
                        if equality:
                            continue
                        # handle special cases:
                        if attr_key in ("created_at", "previous_created_at"):
                            # check the defaults
                            if not hasattr(group0_attrs, "previous_created_at"):
                                group0_attrs["previous_created_at"] = []
                                if group0_attr_values is not None:
                                    group0_attrs["previous_created_at"].append(group0_attr_values)
                            # check previous values
                            if attr_key == "previous_created_at":
                                if not isinstance(attr_values, list):
                                    attr_values = [attr_values]
                                group0_attrs["previous_created_at"].extend(attr_values)
                                continue
                            # update "created_at"
                            if group0_attr_values != current_time:
                                group0_attrs[attr_key] = current_time
                            group0_attrs["previous_created_at"].append(attr_values)

                        elif attr_key in group0_attrs:
                            combined_key = "combined_{}".format(attr_key)
                            if combined_key not in group0_attrs:
                                group0_attrs[combined_key] = [group0_attr_values]
                            group0_attrs[combined_key].append(attr_values)

                        else:
                            group0_attrs[attr_key] = attr_values
                    # update attrs
                    setattr(group0_data, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, group0_data)
                    else:
                        inference_data_dict[group] = group0_data

    return None if inplace else InferenceData(**inference_data_dict)

In [37]:
# from emcee jit the loop and stack b/w lines 180 and 200

In [38]:
# io pyro jit the loops in var names

In [39]:
c = np.random.randn(10)

In [40]:
@numba.njit
def ex_jit(x):
    return np.expand_dims(x, axis=0)


In [41]:
%timeit  ex_jit(c)

1.06 µs ± 152 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [42]:
%timeit np.expand_dims(c, axis=0)

2.1 µs ± 114 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [43]:
# njit expand dims in io_tfp and everywhere else

In [44]:
data = np.random.randn(10000,1000)

In [45]:
%timeit np.atleast_1d(data)

1.44 µs ± 229 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [46]:
%timeit one_de(data)

473 ns ± 18.5 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [47]:
%timeit np.atleast_2d(data)

1.3 µs ± 15 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [48]:
%timeit two_de(data)

505 ns ± 40.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [49]:
# for io_pymc3 jit the expand-dims and at-least one d methods. Loop jitting can also be useful

In [50]:
# for frompystan work with the last 4 methods

In [51]:
from pystan import stan

In [52]:
model_code = '''
... parameters {
...   real y[2];
... }
... model {
...   y[1] ~ normal(0, 1);
...   y[2] ~ double_exponential(0, 2);
... }'''
fit1 = stan(model_code=model_code, iter=100)

INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_42c0ed8065b95a4b695f212f17d17458 NOW.


In [53]:
fit1

Inference for Stan model: anon_model_42c0ed8065b95a4b695f212f17d17458.
4 chains, each with iter=100; warmup=50; thin=1; 
post-warmup draws per chain=50, total post-warmup draws=200.

       mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
y[0]  -0.02    0.07    1.0  -2.06  -0.75  -0.09   0.66   2.17    200   0.99
y[1]  -0.39    0.37   3.02  -8.11  -1.95  -0.12   1.45   5.08     65    1.0
lp__  -1.63     0.2   1.25  -4.53  -2.36  -1.29  -0.71  -0.13     41   1.06

Samples were drawn using NUTS at Sun Jul 28 13:04:36 2019.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

In [99]:
@numba.njit
def check_dim(x):
    return np.prod(np.asarray(x))

In [100]:
def get_draws(fit, variables=None, ignore=None):
    """Extract draws from PyStan fit."""
    if ignore is None:
        ignore = []
    if fit.mode == 1:
        msg = "Model in mode 'test_grad'. Sampling is not conducted."
        raise AttributeError(msg)

    if fit.mode == 2 or fit.sim.get("samples") is None:
        msg = "Fit doesn't contain samples."
        raise AttributeError(msg)

    dtypes = infer_dtypes(fit)

    if variables is None:
        variables = fit.sim["pars_oi"]
    elif isinstance(variables, str):
        variables = [variables]
    variables = list(variables)

    for var, dim in zip(fit.sim["pars_oi"], fit.sim["dims_oi"]):
        if var in variables and np.prod(dim) == 0:
            del variables[variables.index(var)]

    ndraws = [s - w for s, w in zip(fit.sim["n_save"], fit.sim["warmup2"])]
    nchain = len(fit.sim["samples"])

    # check if the values are in 0-based (<=2.17) or 1-based indexing (>=2.18)
    shift = 1
    if any(dim and np.prod(dim) != 0 for dim in fit.sim["dims_oi"]):
        # choose variable with lowest number of dims > 1
        par_idx = min(
            (dim, i) for i, dim in enumerate(fit.sim["dims_oi"]) if (dim and np.prod(dim) != 0)
        )[1]
        offset = int(sum(map(np.product, fit.sim["dims_oi"][:par_idx])))
        par_offset = int(np.product(fit.sim["dims_oi"][par_idx]))
        par_keys = fit.sim["fnames_oi"][offset : offset + par_offset]
        shift = len(par_keys)
        for item in par_keys:
            _, shape = item.replace("]", "").split("[")
            shape_idx_min = min(int(shape_value) for shape_value in shape.split(","))
            if shape_idx_min < shift:
                shift = shape_idx_min
        # If shift is higher than 1, this will probably mean that Stan
        # has implemented sparse structure (saves only non-zero parts),
        # but let's hope that dims are still corresponding the full shape
        shift = int(min(shift, 1))

    var_keys = OrderedDict((var, []) for var in fit.sim["pars_oi"])
    for key in fit.sim["fnames_oi"]:
        var, *tails = key.split("[")
        loc = [Ellipsis]
        for tail in tails:
            loc = []
            for i in tail[:-1].split(","):
                loc.append(int(i) - shift)
        var_keys[var].append((key, loc))

    shapes = dict(zip(fit.sim["pars_oi"], fit.sim["dims_oi"]))

    variables = [var for var in variables if var not in ignore]

    data = OrderedDict()

    for var in variables:
        if var in data:
            continue
        keys_locs = var_keys.get(var, [(var, [Ellipsis])])
        shape = shapes.get(var, [])
        dtype = dtypes.get(var)

        ndraw = max(ndraws)
        ary_shape = [nchain, ndraw] + shape
        ary = np.empty(ary_shape, dtype=dtype, order="F")
        for chain, (pyholder, ndraw) in enumerate(zip(fit.sim["samples"], ndraws)):
            axes = [chain, slice(None)]
            for key, loc in keys_locs:
                ary_slice = tuple(axes + loc)
                ary[ary_slice] = pyholder.chains[key][-ndraw:]
        data[var] = ary
    return data

def get_draws_jit(fit, variables=None, ignore=None):
    """Extract draws from PyStan fit."""
    if ignore is None:
        ignore = []
    if fit.mode == 1:
        msg = "Model in mode 'test_grad'. Sampling is not conducted."
        raise AttributeError(msg)

    if fit.mode == 2 or fit.sim.get("samples") is None:
        msg = "Fit doesn't contain samples."
        raise AttributeError(msg)

    dtypes = infer_dtypes(fit)

    if variables is None:
        variables = fit.sim["pars_oi"]
    elif isinstance(variables, str):
        variables = [variables]
    variables = list(variables)

    for var, dim in zip(fit.sim["pars_oi"], fit.sim["dims_oi"]):
        if var in variables and np.prod(dim) == 0:
            del variables[variables.index(var)]

    ndraws = [s - w for s, w in zip(fit.sim["n_save"], fit.sim["warmup2"])]
    nchain = len(fit.sim["samples"])

    # check if the values are in 0-based (<=2.17) or 1-based indexing (>=2.18)
    shift = 1
    if any(dim and np.prod(dim) != 0 for dim in fit.sim["dims_oi"]):
        # choose variable with lowest number of dims > 1
        par_idx = min(
            (dim, i) for i, dim in enumerate(fit.sim["dims_oi"]) if (dim and np.prod(dim) != 0)
        )[1]
        offset = int(sum(map(np.product, fit.sim["dims_oi"][:par_idx])))
        par_offset = int(np.product(fit.sim["dims_oi"][par_idx]))
        par_keys = fit.sim["fnames_oi"][offset : offset + par_offset]
        shift = len(par_keys)
        for item in par_keys:
            _, shape = item.replace("]", "").split("[")
            shape_idx_min = min(int(shape_value) for shape_value in shape.split(","))
            if shape_idx_min < shift:
                shift = shape_idx_min
        # If shift is higher than 1, this will probably mean that Stan
        # has implemented sparse structure (saves only non-zero parts),
        # but let's hope that dims are still corresponding the full shape
        shift = int(min(shift, 1))

    var_keys = OrderedDict((var, []) for var in fit.sim["pars_oi"])
    for key in fit.sim["fnames_oi"]:
        var, *tails = key.split("[")
        loc = [Ellipsis]
        for tail in tails:
            loc = []
            for i in tail[:-1].split(","):
                loc.append(int(i) - shift)
        var_keys[var].append((key, loc))

    shapes = dict(zip(fit.sim["pars_oi"], fit.sim["dims_oi"]))

    variables = [var for var in variables if var not in ignore]

    data = OrderedDict()
    for var in variables:
        if var in data:
            continue
        keys_locs = var_keys.get(var, [(var, [Ellipsis])])
        shape = shapes.get(var, [])
        dtype = dtypes.get(var)

        ndraw = max(ndraws)
        ary_shape = [nchain, ndraw] + shape
        ary = np.empty(ary_shape, dtype=dtype, order="F")
        for chain, (pyholder, ndraw) in enumerate(zip(fit.sim["samples"], ndraws)):
            axes = [chain, slice(None)]
            for key, loc in keys_locs:
                ary_slice = tuple(axes + loc)
                ary[ary_slice] = pyholder.chains[key][-ndraw:]
        data[var] = ary
    return data

@numba.jit(forceobj=True)
def loop_lifter(data,variables,var_keys,shapes,dtypes,ndraws,fit,nchain):
    for var in variables:
        if var in data:
            continue
        keys_locs = var_keys.get(var, [(var, [Ellipsis])])
        shape = shapes.get(var, [])
        dtype = dtypes.get(var)

        ndraw = max(ndraws)
        ary_shape = [nchain, ndraw] + shape
        ary = np.empty(ary_shape, dtype=dtype, order="F")
        for chain, (pyholder, ndraw) in enumerate(zip(fit.sim["samples"], ndraws)):
            axes = [chain, slice(None)]
            for key, loc in keys_locs:
                ary_slice = tuple(axes + loc)
                ary[ary_slice] = pyholder.chains[key][-ndraw:]
        data[var] = ary
    return data
    
def infer_dtypes(fit, model=None):
    """Infer dtypes from Stan model code.

    Function strips out generated quantities block and searchs for `int`
    dtypes after stripping out comments inside the block.
    """
    pattern_remove_comments = re.compile(
        r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', re.DOTALL | re.MULTILINE
    )
    stan_integer = r"int"
    stan_limits = r"(?:\<[^\>]+\>)*"  # ignore group: 0 or more <....>
    stan_param = r"([^;=\s\[]+)"  # capture group: ends= ";", "=", "[" or whitespace
    stan_ws = r"\s*"  # 0 or more whitespace
    pattern_int = re.compile(
        "".join((stan_integer, stan_ws, stan_limits, stan_ws, stan_param)), re.IGNORECASE
    )
    if model is None:
        stan_code = fit.get_stancode()
        model_pars = fit.model_pars
    else:
        stan_code = model.program_code
        model_pars = fit.param_names
    # remove deprecated comments
    stan_code = "\n".join(
        line if "#" not in line else line[: line.find("#")] for line in stan_code.splitlines()
    )
    stan_code = re.sub(pattern_remove_comments, "", stan_code)
    stan_code = stan_code.split("generated quantities")[-1]
    dtypes = re.findall(pattern_int, stan_code)
    dtypes = {item.strip(): "int" for item in dtypes if item.strip() in model_pars}
    return dtypes



In [101]:
%timeit get_draws(fit=fit1)

210 µs ± 22.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [110]:
#no point jitting

In [56]:
lp = LineProfiler()
wrapper = lp(get_draws)
wrapper(fit1)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.099299 s
File: <ipython-input-54-3c8d3439d7c5>
Function: get_draws at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def get_draws(fit, variables=None, ignore=None):
     2                                               """Extract draws from PyStan fit."""
     3         1          8.0      8.0      0.0      if ignore is None:
     4         1          6.0      6.0      0.0          ignore = []
     5         1          7.0      7.0      0.0      if fit.mode == 1:
     6                                                   msg = "Model in mode 'test_grad'. Sampling is not conducted."
     7                                                   raise AttributeError(msg)
     8                                           
     9         1          8.0      8.0      0.0      if fit.mode == 2 or fit.sim.get("samples") is None:
    10                                                   msg = "Fit

In [104]:
@numba.njit
def check_dim(x,y):
    return np.product(x,y)

In [107]:
x=np.random.randn(1000000)
y=np.random.randn(1000000)

In [111]:
######################################cmdstan########################################

In [287]:
def _unpack_dataframes(dfs):
    col_groups = defaultdict(list)
    columns = dfs[0].columns
    for col in columns:
        key, *loc = col.split(".")
        loc = tuple(int(i) - 1 for i in loc)
        col_groups[key].append((col, loc))

    chains = len(dfs)
    draws = len(dfs[0])
    sample = {}
    for key, cols_locs in col_groups.items():
        ndim = np.array([loc for _, loc in cols_locs]).max(0) + 1
        sample[key] = np.full((chains, draws, *ndim), np.nan)
        for col, loc in cols_locs:
            for chain_id, df in enumerate(dfs):
                draw = df[col].values
                if loc == ():
                    sample[key][chain_id, :] = draw
                else:
                    axis1_all = range(sample[key].shape[1])
                    slicer = (chain_id, axis1_all, *loc)
                    sample[key][slicer] = draw
    return sample

def _unpack_dataframes_jit(dfs):
    col_groups = defaultdict(list)
    columns = dfs[0].columns
    for col in columns:
        key, *loc = col.split(".")
        loc = tuple(int(i) - 1 for i in loc)
        col_groups[key].append((col, loc))
    sample = looper_jit(col_groups,columns,key,loc,)
    return sample

@numba.jit(nopython=False,forceobj=True)
def looper_jit(col_groups,columns,key,loc):
    chains = len(dfs)
    draws = len(dfs[0])
    sample = {}

    for key, cols_locs in col_groups.items():
        locs = []
        for _,loc in cols_locs:
            locs.append(loc)
        
        ndim = np.array(locs).max(0) + 1
        sample[key] = np.full((chains, draws, *ndim),np.nan)
        for col, loc in cols_locs:
            for chain_id, df in enumerate(dfs):
                draw = df[col].values
                if loc == ():
                    sample[key][chain_id, :] = draw
                else:
                    axis1_all = range(sample[key].shape[1])
                    slicer = (chain_id, axis1_all, *loc)
                    sample[key][slicer] = draw
    return sample

@numba.njit
def full(shape):
    return np.full(shape,np.nan)

In [288]:
x = pd.DataFrame({'1234': np.random.randn(10000),
                 })

In [289]:
dfs = [x]

In [290]:
%timeit _unpack_dataframes(dfs=dfs)

88 µs ± 2.38 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [291]:
_unpack_dataframes_jit(dfs)

ValueError: operands could not be broadcast together with shapes (2,) (0,) 

In [284]:
# jitting is throwing erros

In [293]:
def _unpack_frame(data, columns, valid_cols):
    draws, chains, *_ = data.shape

    column_groups = defaultdict(list)
    column_locs = defaultdict(list)
    for i, col in enumerate(columns):
        col_base, *col_tail = col.split(".")
        if len(col_tail):
            column_groups[col_base].append(tuple(map(int, col_tail)))
        column_locs[col_base].append(i)
    dims = {}
    for colname, col_dims in column_groups.items():
        dims[colname] = tuple(np.array(col_dims).max(0))
    sample = {}
    valid_base_cols = []
    for col in valid_cols:
        base_col, *_ = col.split(".")
        if base_col not in valid_base_cols:
            valid_base_cols.append(base_col)

    # extract each wanted parameter to ndarray with correct shape
    for key in valid_base_cols:
        ndim = dims.get(key, None)
        shape_location = column_groups.get(key, None)
        if ndim is not None:
            sample[key] = np.full((chains, draws, *ndim), np.nan)
        if shape_location is None:
            # reorder draw, chain -> chain, draw
            i, = column_locs[key]
            sample[key] = np.swapaxes(data[..., i], 0, 1)
        else:
            for i, shape_loc in zip(column_locs[key], shape_location):
                # location to insert extracted array
                shape_loc = tuple([Ellipsis] + [j - 1 for j in shape_loc])
                # reorder draw, chain -> chain, draw and insert to ndarray
                sample[key][shape_loc] = np.swapaxes(data[..., i], 0, 1)
    return sample

In [None]:
#nothing  to see here