In [1]:
from pathlib import Path
import pandas as pd
from bokeh.plotting import figure, ColumnDataSource, show
from bokeh.models import HoverTool
from bokeh.models.mappers import LinearColorMapper
from bokeh.palettes import brewer
from bokeh.io import output_notebook
from math import sqrt, log, exp

# Configure bokeh so that show() renders plots inline in notebook cells.
output_notebook()

In [29]:
def comet_chart(df, columns=None, **kwargs):
    """
    Generates comet-charts as originally described by Zan Armstrong with bokeh.figure.patches().
    See https://www.zanarmstrong.com/#/infovisresearch/ for details.

    A comet chart compares two scenarios: (weight_start, value_start) vs (weight_end, value_end).
    Weight_start and weight_end denote the size of the population; value_start and value_end
    denote the metric. Each record is drawn as a triangle whose tip sits at the start point and
    whose base is centred on the end point; weight is plotted on x and value on y.

    df:      Pandas dataframe with at least four numeric columns holding the comet data.
             Every column (plus the index, via reset_index) is listed in the hover tooltip,
             so preceding columns may hold the hierarchy of subpopulations or segments.
    columns: Optional list of 4 column names in the order
             ['weight_start', 'weight_end', 'value_start', 'value_end'].
             When None, the last four columns of df are used and are read positionally as
             (value_start, value_end, weight_start, weight_end). Note that this positional
             default differs from the order expected for an explicit `columns` list.
    kwargs:  Accepted but currently unused.

    Returns a tuple (source, plot): the ColumnDataSource backing the glyphs and the bokeh figure.

    Raises ValueError when `columns` does not contain exactly four names, or when `columns`
    is None and df has fewer than four columns.
    """

    def values_to_points(_id, weight_start, weight_end, value_start, value_end):
        """Returns dict with xs, ys, delta_weight for single comet"""
        a = weight_end - weight_start
        b = value_end - value_start
        # The base half-width is 1/16 of the comet length, laid out perpendicular to the
        # direction (a, b). Scaling the perpendicular (b, -a) by 1/16 gives that offset
        # without dividing by the length, so a zero-length comet (start == end) collapses
        # to a single point instead of raising ZeroDivisionError.
        off_x = b / 16
        off_y = a / 16
        return {
            '_ids': _id,
            '_delta_weight': a,
            '_xs': [weight_start, weight_end + off_x, weight_end - off_x],
            '_ys': [value_start, value_end - off_y, value_end + off_y]}

    # reset index so the comet rows (numbered 0..n-1) line up with the data for joining
    _df = df.reset_index()

    # pick the four data series
    if columns is None:
        if df.shape[1] < 4:
            raise ValueError(
                'comet_chart needs at least four data columns, got %d' % df.shape[1])
        # positional default: (value_start, value_end, weight_start, weight_end)
        tail = _df.iloc[:, -4:]
        value_start, value_end, weight_start, weight_end = (
            tail.iloc[:, pos] for pos in range(4))
    else:
        columns = list(columns)
        if len(columns) != 4:
            raise ValueError(
                'columns must list exactly four names '
                '(weight_start, weight_end, value_start, value_end), got %d' % len(columns))
        weight_start, weight_end, value_start, value_end = (
            _df[name] for name in columns)

    # parse data into dataframe of comets
    comets = [
        values_to_points(_id, w0, w1, v0, v1)
        for _id, (w0, w1, v0, v1)
        in enumerate(zip(weight_start, weight_end, value_start, value_end))]
    # explicit column list keeps the frame well-formed even when there are no records
    cdf = pd.DataFrame(comets, columns=['_ids', '_delta_weight', '_xs', '_ys'])
    jdf = pd.concat([_df, cdf], axis=1, join='inner')
    source = ColumnDataSource(jdf)

    # configure plot
    # Braces let bokeh resolve column names containing spaces or punctuation.
    tooltips = [("index", "$index"), ("(x,y)", "($x, $y)")]
    tooltips += [(str(name), '@{%s}' % name) for name in _df.columns]
    hover = HoverTool(tooltips=tooltips)

    plot = figure(tools=[hover, 'box_zoom', 'reset'])

    # TO DO: add option to choose palette
    delta = cdf['_delta_weight']
    has_rows = len(delta) > 0
    # with no records there is no range to map; leave the bounds unset
    color_mapper = LinearColorMapper(palette=brewer['RdBu'][11],
                                     high=delta.max() if has_rows else None,
                                     low=delta.min() if has_rows else None,
                                     )
    plot.patches('_xs', '_ys', source=source,
                 fill_color={'field': '_delta_weight', 'transform': color_mapper},
                 fill_alpha=0.7,
                 line_color={'field': '_delta_weight', 'transform': color_mapper},
                 )

    # customize plot

    return (source, plot)
        

In [30]:
# test using CDC wonder dataset
data = Path.cwd() / 'data.csv'
# DataFrame.from_csv has been removed from pandas; read_csv with index_col=0 and
# parse_dates=True reproduces its defaults (first column as index, dates parsed).
df = pd.read_csv(data, index_col=0, parse_dates=True)

# calculate log values of the four data columns; comet_chart then picks up the
# four log_* columns because they are appended last
for col in df.columns[-4:]:
    df['log_' + col] = df[col].map(log)
source, plot = comet_chart(df)
show(plot)

In [4]:
#TO DO: add hover tool with population/segment characteristics, _weight_value
#TO DO: add brushed sortable bar chart

In [10]:
# Preview the first rows of df1.
# NOTE(review): df1 is not assigned in any cell shown here, so on a fresh kernel this
# raises NameError unless it is defined elsewhere — confirm where it comes from.
df1.head()

Unnamed: 0,state,birthweight,startvalue,endvalue,startweight,endweight,log_startvalue,log_endvalue,log_startweight,log_endweight,_delta_weight,_ids,_xs,_ys
0,Ohio,2500 - 2999 grams,5.53,4.88,101227,109151,1.710188,1.585145,11.525121,11.600488,0.075367,0,"[11.5251207987, 11.5926723614, 11.6083026859]","[1.71018781553, 1.58043479955, 1.58985564018]"
1,Ohio,1500 - 1999 grams,29.19,26.96,9078,9904,3.373826,3.294354,9.113609,9.200694,0.087085,1,"[9.11360918302, 9.19572700125, 9.2056609886]","[3.37382618487, 3.2889114853, 3.29979708679]"
2,Ohio,1000 - 1499 grams,67.39,59.15,4526,5038,4.210497,4.080077,8.417594,8.524764,0.107171,2,"[8.41759382619, 8.51661320388, 8.53291570995]","[4.21049663896, 4.073378426, 4.08677475484]"
3,Ohio,4000+ grams,1.96,1.24,65199,46140,0.672944,0.215111,11.085199,10.739436,-0.345764,3,"[11.0851994104, 10.7108209634, 10.7680501001]","[0.672944473242, 0.236721622032, 0.193501137202]"
4,Ohio,500 - 999 grams,336.39,287.95,3377,3445,5.818271,5.662787,8.124743,8.144679,0.019936,4,"[8.12474302039, 8.13496141175, 8.15439695515]","[5.81827120114, 5.66154084376, 5.66403286414]"


In [11]:
# Build a hover tooltip spec: two built-in entries, then one (label, field) pair per df1 column.
# NOTE(review): the per-column fields use a '$' prefix here, whereas comet_chart references
# data columns with '@'; presumably '@' is what is intended for column lookups — confirm.
tooltips = [("index", "$index"), ("(x,y)", "($x, $y)")]
tooltips.extend((name, '$' + name) for name in df1.columns)
tooltips

[('index', '$index'),
 ('(x,y)', '($x, $y)'),
 ('state', '$state'),
 ('birthweight', '$birthweight'),
 ('startvalue', '$startvalue'),
 ('endvalue', '$endvalue'),
 ('startweight', '$startweight'),
 ('endweight', '$endweight'),
 ('log_startvalue', '$log_startvalue'),
 ('log_endvalue', '$log_endvalue'),
 ('log_startweight', '$log_startweight'),
 ('log_endweight', '$log_endweight'),
 ('_delta_weight', '$_delta_weight'),
 ('_ids', '$_ids'),
 ('_xs', '$_xs'),
 ('_ys', '$_ys')]

In [18]:
# Inspect the 'state' column held by the ColumnDataSource returned from comet_chart.
source.data['state']

0             Ohio
1             Ohio
2             Ohio
3             Ohio
4             Ohio
5             Ohio
6             Ohio
7             Ohio
8             Ohio
9          Georgia
10         Georgia
11         Georgia
12         Georgia
13         Georgia
14         Georgia
15         Georgia
16         Georgia
17         Georgia
18      New Jersey
19      New Jersey
20      New Jersey
21      New Jersey
22      New Jersey
23      New Jersey
24      New Jersey
25      New Jersey
26      New Jersey
27         Florida
28         Florida
29         Florida
          ...     
62        New York
63        New York
64        Illinois
65        Illinois
66        Illinois
67        Illinois
68        Illinois
69        Illinois
70        Illinois
71        Illinois
72        Illinois
73    Pennsylvania
74    Pennsylvania
75    Pennsylvania
76    Pennsylvania
77    Pennsylvania
78    Pennsylvania
79    Pennsylvania
80    Pennsylvania
81    Pennsylvania
82      California
83      Cali