Skip to content

Commit

Permalink
Adding maximum posterior points for v0.26.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Samreay committed Apr 26, 2018
1 parent a33baa9 commit 6cbd912
Show file tree
Hide file tree
Showing 15 changed files with 674 additions and 163 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ post, it can be solved by explicitly install the `matplotlib` dependency `dvipng

### Update History

##### 0.26.0
* Adding ability to pass in a power to raise the surface to for each chain.
* Adding methods to retrieve the maximum posterior point: `Analysis.get_max_posteriors`
* Adding ability to plot maximum posterior points. Can control `marker_size`, `marker_style`, `marker_alpha`, and whether to plot contours, points or both.
* Finishing migration of configuration options you can specify when adding chains rather than configuring all chains with `configure`.
##### 0.25.2
* (Attempting to) enable fully automated releases to Github, PyPI, Zenodo and conda.

Expand Down
53 changes: 53 additions & 0 deletions chainconsumer/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@


class Analysis(object):

summaries = ["max", "mean", "cumulative", "max_symmetric", "max_shortest", "max_central"]

def __init__(self, parent):
self.parent = parent
self._logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -143,6 +146,54 @@ def get_summary(self, squeeze=True, parameters=None, chains=None):
return results[0]
return results

def get_max_posteriors(self, parameters=None, squeeze=True, chains=None):
""" Gets the maximum posterior point in parameter space from the passed parameters.
Requires the chains to have set `posterior` values.
Parameters
----------
parameters : str|list[str]
The parameters to find
squeeze : bool, optional
Squeeze the summaries. If you only have one chain, squeeze will not return
a length one list, just the single summary. If this is false, you will
get a length one list.
chains : list[int|str], optional
A list of the chains to get a summary of.
Returns
-------
list of two-tuples
One entry per chain, two-tuple represents the max-likelihood coordinate
"""

results = []
if chains is None:
chains = self.parent.chains
else:
if isinstance(chains, (int, str)):
chains = [chains]
chains = [self.parent.chains[self.parent._get_chain(c)] for c in chains]

if isinstance(parameters, str):
parameters = [parameters]

for chain in chains:
if chain.posterior_max_index is None:
results.append(None)
continue
res = {}
params_to_find = parameters if parameters is not None else chain.parameters
for p in params_to_find:
if p in chain.parameters:
res[p] = chain.posterior_max_params[p]
results.append(res)

if squeeze and len(results) == 1:
return results[0]
return results

def get_parameter_summary(self, chain, parameter):
# Ensure config has been called so we get the statistics set in config
if not self.parent._configured:
Expand Down Expand Up @@ -262,6 +313,8 @@ def _get_smoothed_histogram(self, chain, parameter):
bins, smooth = get_smoothed_bins(smooth, bins, data, chain.weights)

hist, edges = np.histogram(data, bins=bins, normed=True, weights=chain.weights)
if chain.power is not None:
hist = hist ** chain.power
edge_centers = 0.5 * (edges[1:] + edges[:-1])
xs = np.linspace(edge_centers[0], edge_centers[-1], 10000)

Expand Down
87 changes: 80 additions & 7 deletions chainconsumer/chain.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,46 @@
import logging
import numpy as np

from .colors import Colors
from .analysis import Analysis


class Chain(object):

colors = Colors() # Static colors object to do color mapping

def __init__(self, chain, parameters, name, weights=None, posterior=None, walkers=None,
grid=False, num_free_params=None, num_eff_data_points=None, color=None, linewidth=None,
linestyle=None, kde=None, shade_alpha=None):
grid=False, num_free_params=None, num_eff_data_points=None, power=None,
statistics="max", color=None, linestyle=None, linewidth=None, cloud=None,
shade=None, shade_alpha=None, shade_gradient=None, bar_shade=None,
bins=None, kde=None, smooth=None, color_params=None, plot_color_params=None,
cmap=None, num_cloud=None, plot_contour=True, plot_point=False, marker_style=None,
marker_size=None, marker_alpha=None):
self.chain = chain
self.parameters = parameters
self.name = name

self.posterior_max_index = None
self.posterior_max_params = {}

if weights is None:
weights = np.ones(chain.shape[0])
weights = weights.squeeze()

if posterior is not None:
posterior = posterior.squeeze()
self.posterior_max_index = np.argmax(posterior)
for i, p in enumerate(parameters):
self.posterior_max_params[p] = chain[self.posterior_max_index, i]

self.weights = weights
self.posterior = posterior
self.walkers = walkers
self.grid = grid
self.num_free_params = num_free_params
self.num_eff_data_points = num_eff_data_points
self.power = power

self._logger = logging.getLevelName(self.__class__.__name__)

# Storing config overrides
Expand All @@ -35,9 +53,64 @@ def __init__(self, chain, parameters, name, weights=None, posterior=None, walker
self.summaries = {}
self.config = {}

self.configure(statistics=statistics, color=color, linestyle=linestyle,
linewidth=linewidth, cloud=cloud, shade=shade, shade_alpha=shade_alpha,
shade_gradient=shade_gradient, bar_shade=bar_shade, bins=bins,
kde=kde, smooth=smooth, color_params=color_params,
plot_color_params=plot_color_params, cmap=cmap, num_cloud=num_cloud,
plot_contour=plot_contour, plot_point=plot_point, marker_style=marker_style,
marker_size=marker_size, marker_alpha=marker_alpha)
self.validate_chain()
self.validated_params = set()

def configure(self, statistics=None, color=None, linestyle=None, linewidth=None, cloud=None,
shade=None, shade_alpha=None, shade_gradient=None, bar_shade=None,
bins=None, kde=None, smooth=None, color_params=None, plot_color_params=None,
cmap=None, num_cloud=None, marker_style=None, marker_size=None, marker_alpha=None,
plot_contour=True, plot_point=False):

if statistics is not None:
assert isinstance(statistics, str), "statistics should be a string"
assert statistics in list(Analysis.summaries), \
"statistics %s not recognised. Should be in %s" % (statistics, Analysis.summaries)
self.config["statistics"] = statistics

if color is not None:
color = self.colors.format(color)
self.config["color"] = color

# See I wish I didnt have to do this, but I get too many issues raised when people
# pass in the weirdest stuff and expect it to work.
self._validate_config("linestyle", linestyle, str)
self._validate_config("linewidth", linewidth, int, float)
self._validate_config("cloud", cloud, bool)
self._validate_config("shade", shade, bool)
self._validate_config("shade_alpha", shade_alpha, int, float)
self._validate_config("shade_gradient", shade_gradient, int, float)
self._validate_config("bar_shade", bar_shade, bool)
self._validate_config("bins", bins, int, float)
self._validate_config("kde", kde, int, float, bool)
self._validate_config("smooth", smooth, int, float, bool)
self._validate_config("color_params", color_params, str)
self._validate_config("plot_color_params", plot_color_params, bool)
self._validate_config("cmap", cmap, str)
self._validate_config("num_cloud", num_cloud, int, float)
self._validate_config("marker_style", marker_style, str)
self._validate_config("marker_size", marker_size, int, float)
self._validate_config("marker_alpha", marker_alpha, int, float)
self._validate_config("plot_contour", plot_contour, bool)
self._validate_config("plot_point", plot_point, bool)

def update_unset_config(self, name, value):
if self.config.get(name) is None:
self.config[name] = value

def _validate_config(self, name, value, *types):
if value is not None:
assert isinstance(value, tuple(types)), \
"%s, which is %s, should be type of: %s" % (name, value, " or ".join([t.__name__ for t in types]))
self.config[name] = value

def validate_chain(self):
# So many people request help when the pass in junk data without realising it.
# Let's try and flag this as quickly as we can.
Expand All @@ -62,7 +135,7 @@ def validate_chain(self):
if self.posterior is not None:
assert len(self.posterior.shape) == 1, "posterior should be a 1D array, have instead %s" % str(self.posterior.shape)
assert self.posterior.size == self.chain.shape[0], "Chain %s has %d steps but %d log-posterior values" % \
(self.name, self.posterior.size, self.chain.shape[0])
(self.name, self.chain.shape[0], self.posterior.size)
assert np.all(np.isfinite(self.posterior)), "Chain %s has NaN or inf in the log-posterior" % self.name
if self.num_free_params is not None:
assert isinstance(self.num_free_params, (int, float)), \
Expand All @@ -75,10 +148,10 @@ def validate_chain(self):
assert np.isfinite(self.num_eff_data_points), "num_eff_data_points is either infinite or NaN"
assert self.num_eff_data_points > 0, "num_eff_data_points must be positive"

def reset_config(self):
self.config = {}
self.summaries = {}
self.validated_params = set()
# def reset_config(self):
# self.config = {}
# self.summaries = {}
# self.validated_params = set()

def get_summary(self, param, callback):
if param in self.summaries:
Expand Down

0 comments on commit 6cbd912

Please sign in to comment.