Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/plot_5_emcee_arviz_numpyro.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
# Using external samples easily
# Using external samples

`emcee`, `arviz`, and `numpyro` are all popular MCMC packages. ChainConsumer
provides class methods to turn results from these packages into chains efficiently.
Expand Down
79 changes: 79 additions & 0 deletions docs/examples/plot_7_multimodal_chains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
# Multimodal distributions

`ChainConsumer` can handle cases where the distributions of your chains are multimodal.
"""

import numpy as np
import pandas as pd

from chainconsumer import Chain, ChainConsumer
from chainconsumer.statistics import SummaryStatistic

# %%
# First, let's build some dummy data

rng = np.random.default_rng(42)
size = 60_000

eta = rng.normal(loc=0.0, scale=0.8, size=size)

phi = np.asarray(
[rng.gamma(shape=2.5, scale=0.4, size=size // 2) - 3.0, 3.0 - rng.gamma(shape=5.0, scale=0.35, size=(size // 2))]
).flatten()

rng.shuffle(phi)

df = pd.DataFrame({"eta": eta, "phi": phi})

# %%
# To build a multimodal chain, you simply have to pass `multimodal=True` when building the chain. To work, it requires
# you to specify `SummaryStatistic.HDI` as the summary statistic.

chain_multimodal = Chain(
samples=df.copy(),
name="posterior-multimodal",
statistics=SummaryStatistic.HDI,
multimodal=True, # <- Here
)

# %%
# Now, if you add this `Chain` to a plotter, it will try to look for sub-intervals and display them.

cc = ChainConsumer()
cc.add_chain(chain_multimodal)
fig = cc.plotter.plot()

# %%
# Let's compare with what would happen if you don't use a multimodal chain. We use the same data as before but don't
# tell `ChainConsumer` that we expect the chains to be multimodal.

chain_unimodal = Chain(samples=df.copy(), name="posterior-unimodal", statistics=SummaryStatistic.HDI, multimodal=False)

cc.add_chain(chain_unimodal)
fig = cc.plotter.plot()

# %%
# Let's try with even more modes.

eta = np.asarray(
[
rng.normal(loc=-3, scale=0.8, size=size // 3),
rng.normal(loc=0.0, scale=0.8, size=size // 3),
rng.normal(loc=+3, scale=0.8, size=size // 3),
]
).flatten()


rng.shuffle(eta)

df = pd.DataFrame({"eta": eta, "phi": phi})

chain_multimodal = Chain(
samples=df.copy(), name="posterior-multimodal", statistics=SummaryStatistic.HDI, multimodal=True
)

cc = ChainConsumer()
cc.add_chain(chain_multimodal)
fig = cc.plotter.plot()
fig.tight_layout()
81 changes: 81 additions & 0 deletions docs/resources/generate_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import rc
from scipy.stats import gamma

from chainconsumer import Chain, ChainConsumer
from chainconsumer.statistics import SummaryStatistic

# Activate latex text rendering
rc("font", family="serif", serif=["Computer Modern Roman"], size=13)
rc("text", usetex=True)
matplotlib.rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"

x = np.linspace(0, 5, 100)

loc = 4
scale = 0.45

fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True, height_ratios=[0.5, 0.5], figsize=(5, 5))
axs[0].plot(x, gamma.pdf(x, a=loc, scale=scale), color="black")
axs[1].plot(x, gamma.cdf(x, a=loc, scale=scale), color="black")


axs[1].set_xlabel("$x$")
axs[0].set_ylabel("$P(x)$")
axs[1].set_ylabel("$C(x)$")
axs[0].set_xlim(0, 5.0)
axs[0].set_ylim(0, 0.6)
axs[1].set_ylim(0, 1)

samples = pd.DataFrame.from_dict({"gamma": gamma.rvs(size=10_000_000, a=loc, scale=scale)})

summary_list = [
(SummaryStatistic.MAX, "MAX"),
(SummaryStatistic.CUMULATIVE, "CUMULATIVE"),
(SummaryStatistic.MEAN, "MEAN"),
(SummaryStatistic.HDI, "HDI"),
]

chains = []

for summary, name in summary_list:
chains.append(Chain(samples=samples, statistics=summary, name=name))

cc = ChainConsumer()

summary_result = cc.analysis.get_summary(chains=chains, columns=["gamma"])

for (_summary, name), color, linestyle, marker_style in zip(
summary_list,
["r", "g", "b", "y"],
[":", "--", "-", "-."],
["o", "^", "s", "*"],
strict=False,
):
bound = summary_result[name]["gamma"]

x_min, x_mid, x_max = bound.lower, bound.center, bound.upper

axs[0].scatter(x_mid, gamma.pdf(x_mid, a=loc, scale=scale), label=name, zorder=10, color=color, marker=marker_style)
axs[1].scatter(x_mid, gamma.cdf(x_mid, a=loc, scale=scale), zorder=10, color=color, marker=marker_style)

axs[0].vlines(
x=x_min, ymin=0, ymax=gamma.pdf(x_min, a=loc, scale=scale), color=color, linestyle=linestyle, alpha=0.5
)
axs[0].vlines(
x=x_max, ymin=0, ymax=gamma.pdf(x_max, a=loc, scale=scale), color=color, linestyle=linestyle, alpha=0.5
)

axs[1].hlines(
xmin=0, xmax=x_min, y=gamma.cdf(x_min, a=loc, scale=scale), color=color, linestyle=linestyle, alpha=0.5
)
axs[1].hlines(
xmin=0, xmax=x_max, y=gamma.cdf(x_max, a=loc, scale=scale), color=color, linestyle=linestyle, alpha=0.5
)

axs[0].legend(fontsize=8)
plt.tight_layout()
plt.savefig("stats.png", bbox_inches="tight")
Binary file modified docs/resources/stats.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ In general, this is the flow:

## Statistics

When summarising chains, ChainConsumer offers several different methods. The below image shows the upper and lower bounds and central points for the "MEAN", "CUMULATIVE", and "MAX" methods respectively. The "MAX_CENTRAL" method is the blue central value and the red bounds.
When summarising chains, ChainConsumer offers several different methods. The below image shows the upper and lower bounds and central points for the `MAX`, `CUMULATIVE`, `MEAN` and `HDI` methods respectively, with their associated bounds.

![](resources/stats.png)

Expand Down
Loading