Skip to content

Commit

Permalink
Add altair traceplot (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
ColCarroll authored and aloctavodia committed May 22, 2018
1 parent d5ff245 commit d607c7d
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 10 deletions.
55 changes: 52 additions & 3 deletions arviz/plots/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,56 @@

from .kdeplot import fast_kde
from .plot_utils import make_2d, get_bins, _scale_text
from ..utils import get_varnames, trace_to_dataframe
from ..utils import get_varnames, trace_to_dataframe, untransform_varnames
from ..compat import altair as alt


def _var_to_traceplot(dataframe, varname, brush):
df = dataframe.reset_index().melt(id_vars='index')

trace = alt.Chart().mark_line().encode(
alt.X('index:Q', title='Sample'),
alt.Y('value:Q', title=varname),
color=alt.Color('variable:N'),
opacity=alt.value(0.4 + 0.6 / len(df.variable.unique())),
).properties(
selection=brush,
width=600,
height=200
)

if all(np.issubdtype(dtype, np.dtype('int')) for dtype in dataframe.dtypes.values):
base = alt.Chart().mark_bar()
else:
base = alt.Chart().mark_line()

kde = base.encode(
x=alt.X('value:Q', bin=alt.Bin(maxbins=100), title=varname),
y=alt.Y('count():Q', title='Number of Samples'),
color=alt.Color('variable:N'),
).transform_filter(
brush.ref()
).properties(
height=200
)

return alt.hconcat(kde, trace, data=df)


def traceplot_altair(dataframe):
"""Interactive traceplot using Altair
"""
all_vars, _ = untransform_varnames(dataframe.columns)
brush = alt.selection_interval(encodings=['x'])
charts = []
for base_name, varnames in all_vars.items():
charts.append(_var_to_traceplot(dataframe.loc[:, varnames], base_name, brush))
return alt.vconcat(*charts)


def traceplot(trace, varnames=None, figsize=None, textsize=None, lines=None, combined=False,
grid=True, shade=0.35, priors=None, prior_shade=1, prior_style='--', bw=4.5,
skip_first=0, ax=None):
skip_first=0, ax=None, altair=False):
"""Plot samples histograms and values.
Parameters
Expand Down Expand Up @@ -52,16 +96,21 @@ def traceplot(trace, varnames=None, figsize=None, textsize=None, lines=None, com
>>> pymc3.traceplot(trace, ax=axs)
Creates own axes by default.
altair : bool
Should returned plot be an altair chart.
Returns
-------
ax : matplotlib axes
"""
trace = trace_to_dataframe(trace[skip_first:], combined)
trace = trace_to_dataframe(trace[skip_first:], combined=combined)
varnames = get_varnames(trace, varnames)

if altair:
return traceplot_altair(trace.loc[:, varnames])

if figsize is None:
figsize = (12, len(varnames) * 2)

Expand Down
3 changes: 2 additions & 1 deletion arviz/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .utils import (trace_to_dataframe, get_stats, expand_variable_names, get_varnames,
_create_flat_names, log_post_trace, save_trace, load_trace)
_create_flat_names, log_post_trace, save_trace, load_trace,
untransform_varnames)
53 changes: 48 additions & 5 deletions arviz/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,55 @@
import gzip
import lzma
import os
import re

import numpy as np
import pandas as pd

__all__ = ['expand_variable_names', 'get_stats', 'get_varnames', 'log_post_trace',
'trace_to_dataframe', 'save_trace', 'load_trace']

def untransform_varnames(varnames):
"""Map transformed variable names back to their originals.
Mainly useful for dealing with PyMC3 traces.
Example
-------
untransform_varnames(['eta__0', 'eta__1', 'theta', 'theta_log__'])
{'eta': {'eta__0', 'eta_1'}, 'theta': {'theta'}}, {'theta': {'theta_log__'}}
Parameters
----------
varnames : iterable of strings
All the varnames from a trace
Returns
-------
(dict, dict)
A dictionary of names to vector names, and names to transformed names
"""
# Captures tau_log____0 or tau_log__, but not tau__0
transformed_vec_ptrn = re.compile(r'^(.*)__(?:__\d+)$')
# Captures tau__0 and tau_log____0, so use after the above
vec_ptrn = re.compile(r'^(.*)__\d+$')

varname_map = {}
transformed = {}
for varname in varnames:
has_match = False
for ptrn, mapper in ((transformed_vec_ptrn, transformed), (vec_ptrn, varname_map)):
match = ptrn.match(varname)
if match:
base_name = match.group(1)
if base_name not in mapper:
mapper[base_name] = set()
mapper[base_name].add(varname)
has_match = True
if not has_match:
if varname not in varname_map:
varname_map[varname] = set()
varname_map[varname].add(varname)
return varname_map, transformed


def expand_variable_names(trace, varnames):
Expand All @@ -17,7 +60,7 @@ def expand_variable_names(trace, varnames):
tmp = []
for vtrace in pd.unique(trace.columns):
for varname in varnames:
if '{}__'.format(varname) in vtrace or varname in vtrace:
if vtrace == varname or vtrace.startswith('{}__'.format(varname)):
tmp.append(vtrace)
return np.unique(tmp)

Expand Down Expand Up @@ -161,7 +204,7 @@ def _create_flat_names(varname, shape):

def _is_transformed_name(name):
"""
Quickly check if a name was transformed with `get_transormed_name`
Quickly check if a name was transformed with `get_transformed_name`
Parameters
----------
Expand All @@ -171,7 +214,7 @@ def _is_transformed_name(name):
Returns
-------
bool
Boolean, whether the string could have been produced by `get_transormed_name`
Boolean, whether the string could have been produced by `get_transformed_name`
"""
return name.endswith('__') and name.count('_') >= 3

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
REQUIREMENTS_FILE = os.path.join(PROJECT_ROOT, 'requirements.txt')

with open(REQUIREMENTS_FILE) as buff:
install_reqs = buff.read.splitlines()
install_reqs = buff.read().splitlines()


def copy_styles():
Expand Down

0 comments on commit d607c7d

Please sign in to comment.