Skip to content

Commit

Permalink
Merge pull request #327 from CUQI-DTU/allow_lazy_import_of_arviz
Browse files Browse the repository at this point in the history
Allow lazy import of arviz
  • Loading branch information
nabriis committed Dec 7, 2023
2 parents f9fe3a6 + ef42e8f commit 2a8abe2
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion cuqi/samples/_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@
from cuqi.array import CUQIarray
from cuqi.utilities import force_ndarray
from copy import copy
import arviz # Plotting tool
from numbers import Number

try:
import arviz # Plotting tool
except ImportError:
arviz = None

def _check_for_arviz():
if arviz is None:
raise ImportError("The arviz package is required for this functionality. Please install arviz using `pip install arviz`.")


class Samples(object):
"""
Expand Down Expand Up @@ -598,6 +606,7 @@ def plot_autocorrelation(self, variable_indices=None, max_lag=None, combined=Tru
datadict = self.to_arviz_inferencedata(variable_indices)

# Plot autocorrelation using arviz
_check_for_arviz()
axis = arviz.plot_autocorr(datadict, max_lag=max_lag, combined=combined, **kwargs)

return axis
Expand Down Expand Up @@ -664,6 +673,7 @@ def plot_trace(self, variable_indices=None, exact=None, combined=True, tight_lay
kwargs["lines"] = tuple([(par_names[i], {}, exact[i]) for i in range(len(par_names))])

# Plot using arviz
_check_for_arviz()
ax = arviz.plot_trace(datadict, combined=combined, **kwargs)

# Improves subplot spacing
Expand Down Expand Up @@ -704,6 +714,7 @@ def plot_pair(self, variable_indices=None, kind="scatter", marginals=False, **kw
# Convert to arviz InferenceData object
datadict = self.to_arviz_inferencedata(variable_indices)

_check_for_arviz()
ax = arviz.plot_pair(datadict, kind=kind, marginals=marginals, **kwargs)

return ax
Expand Down Expand Up @@ -751,6 +762,7 @@ def compute_ess(self, **kwargs):
-------
Numpy array with effective sample size for each variable.
"""
_check_for_arviz()
ESS_xarray = arviz.ess(self.to_arviz_inferencedata(), **kwargs)
ESS_items = ESS_xarray.items()
ESS = np.empty(len(ESS_items))
Expand Down Expand Up @@ -806,6 +818,7 @@ def compute_rhat(self, chains, **kwargs):
datadict = dict(zip(variables,samples))

# Compute rhat
_check_for_arviz()
RHAT_xarray = arviz.rhat(datadict, **kwargs)

# Convert to numpy array
Expand Down Expand Up @@ -844,6 +857,7 @@ def plot_violin(self, variable_indices=None, **kwargs):
datadict = self.to_arviz_inferencedata(variable_indices)

# Plot using arviz
_check_for_arviz()
ax = arviz.plot_violin(datadict, **kwargs)

return ax

0 comments on commit 2a8abe2

Please sign in to comment.