In [1]:
# default_exp visualize.bokeh

# Visualize Bokeh

> Use bokeh to do visualize

In [2]:
#hide
from nbdev.showdoc import *

In [3]:
#export
from data_tool.imports import *
from bokeh.palettes import Spectral, Category20, Viridis256, Cividis256, Turbo256, Set2
from bokeh.io import output_file, show
from bokeh.palettes import Spectral6, Spectral
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, LabelSet, RangeTool, HoverTool, Label, Legend
from bokeh.transform import cumsum
from bokeh.plotting import figure
from bokeh.io import output_notebook, show
from bokeh.layouts import column
from bokeh.plotting.figure import get_range, get_scale

In [4]:
#export
def count2colors(count: int):
    """
    colors from `bokeh.palettes`
    range:
        1:12    -> Spectral, 
        12:21   -> Category20, 
        21:257  -> Viridis256, 
        257:769 -> Turbo256+Cividis256+Viridis256
    """
    if count <= 0:
        raise ValueError(f"Invalid count={count}, must > 0")
    elif count <= 3:
        return Set2[3][:count]
    elif count <= 8:
        return Set2[count]
    elif count <= 20:
        return Category20[count]
    elif count <= 256:
        index = np.arange(256)
        return tuple(np.array(Viridis256)[np.isin(index, np.arange(0, 256, 256/count, dtype=np.uint8))])
    elif count <= 256 * 3:
        return tuple(np.random.choice(Turbo256 + Cividis256 + Viridis256, count))
    else:
        raise ValueError(f"Invalid count={count}, must <= 768")

In [5]:
try:
    count2colors(-1)
except ValueError as e:
    logger.error(e)

2021-08-16 08:19:04.531 | ERROR    | __main__:<module>:4 - Invalid count=-1, must > 0


In [6]:
count2colors(5)

('#66c2a5', '#fc8d62', '#8da0cb', '#e78ac3', '#a6d854')

In [7]:
count2colors(18)

('#1f77b4',
 '#aec7e8',
 '#ff7f0e',
 '#ffbb78',
 '#2ca02c',
 '#98df8a',
 '#d62728',
 '#ff9896',
 '#9467bd',
 '#c5b0d5',
 '#8c564b',
 '#c49c94',
 '#e377c2',
 '#f7b6d2',
 '#7f7f7f',
 '#c7c7c7',
 '#bcbd22',
 '#dbdb8d')

In [8]:
try:
    count2colors(1000)
except ValueError as e:
    logger.error(e)

2021-08-16 08:19:08.830 | ERROR    | __main__:<module>:4 - Invalid count=1000, must <= 768


In [9]:
#export
def bar_figure(data: dict, title="统计", with_label=True, y_log_multiple_thresh=100, eps=1e-1, toolbar_location=None, **kwargs):
    x = list(data.keys())
    y = list(data.values())
    
    source = ColumnDataSource(data=dict(x=x, y=y, color=count2colors(len(x))))

    p = figure(
        x_range=x, title=title,
        y_axis_type="log" if max(y)/(min(y)+eps) > y_log_multiple_thresh else 'linear',
        toolbar_location=toolbar_location, **kwargs)
    p.vbar(x='x', top='y', width=0.9, bottom=0.1, color='color', legend_field="x", source=source)
    p.legend.orientation = "vertical"
    p.legend.location = "top_right"
    
    if with_label:
        labels = LabelSet(
            x='x', y='y', text='y', level='glyph',
            x_offset=-10, y_offset=2, source=source, render_mode='canvas', text_font_size="9pt")
        p.add_layout(labels)
    p.xgrid.grid_line_color = None
    p.xaxis.major_label_orientation = math.pi/12
    return p

In [10]:
output_notebook()

In [11]:
test_data = {
    'A': 10000,
    'B': 1000,
    'C': 100,
    'D': 10
}
p = bar_figure(test_data, toolbar_location=None)
show(p)

In [12]:
show(p)

In [13]:
test_data = {
    'A': 100,
    'B': 88,
    'C': 99,
    'D': 10
}
p = bar_figure(test_data)
show(p)

In [14]:
#export
def line_figure_datetime(
        data_list: list, fig_width=800, fig_height=500, x_dtype=np.datetime64,
        select_title:str=None, draw_circle=False, tooltips_metadata:list=None,
        x_name='x', y_name='y', legend_labels:list=None, tools='save,reset,pan', **kwargs
    ):
    if isinstance(data_list, dict): data_list = [data_list]
    if isinstance(tooltips_metadata, dict): tooltips_metadata = [tooltips_metadata]
    if not (isinstance(legend_labels, list) and len(legend_labels)) == len(data_list): legend_labels = None
    if not (isinstance(tooltips_metadata, list) and len(tooltips_metadata)) == len(data_list):
        tooltips_metadata = [{}]*len(data_list)
    p = figure(
        plot_height=int(0.65*fig_height), plot_width=fig_width,
        x_axis_type="datetime", x_axis_location="above",
        background_fill_color="#efefef", x_range=None,
        tools=tools, **kwargs
    )
    print(p.tools)
    colors = count2colors(len(data_list))
    p.yaxis.axis_label = y_name

    select_title = select_title if select_title is not None else f"拖动选择 {x_name} 区间"
    select = figure(
        title=select_title, tools=tools,
        plot_height=int(0.35*fig_height), plot_width=fig_width, y_range=p.y_range,
        x_axis_type="datetime", y_axis_type=None, background_fill_color="#efefef"
    )

    max_x = []
    for i, data in enumerate(data_list):
        if not isinstance(data, dict):
            logger.error(f"invalid data_list[{i}] data format, must be dict")
        x = [str(o) for o in data.keys()]
        if x_dtype: x = list(np.array(x, dtype=x_dtype))
        y = list(data.values())
        metadata = {}
        if isinstance(tooltips_metadata[i], dict):
            for k, v in tooltips_metadata[i].items():
                if isinstance(v, list) and len(v) == len(x):
                    metadata[k] = v
                else:
                    logger.warning(f"tooltips_metadata->{k} values invalid, type={type(v)} length={len(v)}")
        tooltips = [(x_name, "@{"f"{x_name}""}{%F}"),(y_name, "@{"f"{y_name}""}")]
        for k in metadata: tooltips.append((k, "@{"f"{k}""}"))
        column_data = {
            x_name: x,
            y_name: y
        }
        source = ColumnDataSource(data=dict(**column_data, **metadata))
        legend_kwargs = {} if legend_labels is None else {'legend_label': legend_labels[i]}
        p.line(x_name, y_name, source=source, line_color=colors[i], **legend_kwargs)
        if draw_circle: p.circle(x_name, y_name, fill_color='white', size=6, source=source, line_color=colors[i], **legend_kwargs)
        if len(x) > len(max_x): max_x = x
        p.add_tools(HoverTool(tooltips=tooltips, formatters={"@{"f"{x_name}""}": 'datetime'}))

        select.line(x_name, y_name, source=source, line_color=colors[i])
    p.x_range = get_range((max_x[len(max_x)//2-min(len(max_x)//8, 300)], max_x[len(max_x)//2+min(len(max_x)//8, 300)]))
    p.x_scale = get_scale(p.x_range, 'datetime')

    range_tool = RangeTool(x_range=p.x_range)
    range_tool.overlay.fill_color = "navy"
    range_tool.overlay.fill_alpha = 0.2

    select.ygrid.grid_line_color = None
    select.add_tools(range_tool)
    select.toolbar.active_multi = range_tool
    return column(p, select)

In [15]:
#hide
"""
# old version
def line_figure_datetime(
        data: dict, fig_width=800, fig_height=500, x_dtype=np.datetime64,
        select_title:str=None, draw_circle=False, tooltips_metadata:dict=None,
        x_name='x', y_name='y', toolbar_location=None, **kwargs
    ):
    x = [str(o) for o in data.keys()]
    if x_dtype: x = list(np.array(x, dtype=x_dtype))
    y = list(data.values())
    metadata = {}
    if isinstance(tooltips_metadata, dict):
        for k, v in tooltips_metadata.items():
            if isinstance(v, list) and len(v) == len(x):
                metadata[k] = v
            else:
                logger.warning(f"tooltips_metadata->{k} values invalid, type={type(v)} length={len(v)}")
    tooltips = [(x_name, "@{"f"{x_name}""}{%F}"),(y_name, "@{"f"{y_name}""}")]
    for k in metadata: tooltips.append((k, "@{"f"{k}""}"))
    column_data = {
        x_name: x,
        y_name: y
    }
    source = ColumnDataSource(data=dict(**column_data, **metadata))
    p = figure(
        plot_height=int(0.65*fig_height), plot_width=fig_width, tools="xpan",
        x_axis_type="datetime", x_axis_location="above",
        background_fill_color="#efefef", x_range=(x[len(x)//2-min(len(x)//8, 300)], x[len(x)//2+min(len(x)//8, 300)]),
        toolbar_location=toolbar_location, **kwargs
    )

    p.line(x_name, y_name, source=source)
    if draw_circle: p.circle(x_name, y_name, fill_color='white', size=6, source=source)
    p.yaxis.axis_label = y_name
    p.add_tools(HoverTool(tooltips=tooltips, formatters={"@{"f"{x_name}""}": 'datetime'}, mode='vline'))
    # if isinstance(legend_list, list) and len(legend_list) == (len(column_data)+len(metadata)):
    #     legend_x = list(column_data.keys()) + list(metadata.keys())
    #     for l_x, l_y in zip(legend_x, legend_list):
    #         p.circle(0, 0, size=0.00000001, color= "#ffffff", legend_label=f"{l_x}: {l_y}")
    #     p.legend.label_text_font_style = "italic"
    #     # p.legend.label_text_color = "navy"
    #     p.legend.orientation = "vertical"
    #     p.legend.location = "top_right"
    #     p.legend.label_text_font_size = "8pt"
    select_title = select_title if select_title is not None else f"拖动选择 {x_name} 区间"
    select = figure(
        title=select_title, toolbar_location=toolbar_location,
        plot_height=int(0.35*fig_height), plot_width=fig_width, y_range=p.y_range,
        x_axis_type="datetime", y_axis_type=None, background_fill_color="#efefef"
    )

    range_tool = RangeTool(x_range=p.x_range)
    range_tool.overlay.fill_color = "navy"
    range_tool.overlay.fill_alpha = 0.2

    select.line(x_name, y_name, source=source)
    select.ygrid.grid_line_color = None
    select.add_tools(range_tool)
    select.toolbar.active_multi = range_tool
    return column(p, select)
"""



In [16]:
from bokeh.sampledata.stocks import AAPL

test_data = { x:y for x,y in zip(AAPL['date'][:300], AAPL['close'][:300])}
p = line_figure_datetime(test_data, title="趋势图", draw_circle=True)
show(p)

[SaveTool(id='1306', ...), ResetTool(id='1307', ...), PanTool(id='1308', ...)]


In [17]:
AAPL.keys()


dict_keys(['date', 'open', 'high', 'low', 'close', 'volume', 'adj_close'])

In [18]:
tooltips_metadata = {'体量': AAPL['volume'][:300]}
p = line_figure_datetime(test_data, title="趋势图", draw_circle=True,
                         tooltips_metadata=tooltips_metadata, x_name='时间',
                         y_name='价格')
show(p)

[SaveTool(id='1628', ...), ResetTool(id='1629', ...), PanTool(id='1630', ...)]


In [19]:
#export
def get_count_data_datetime(date_df, count_column, resample_mode='d', extra_meta_column:str=None):
    resample_date_data = date_df.resample(resample_mode).sum()[count_column].to_dict()
    count_data = {}
    extra_output_data = {k: [] for k in date_df[extra_meta_column].value_counts().index} if isinstance(extra_meta_column, str) else None
    for i, k in enumerate(resample_date_data):
        if k not in date_df.index:
            continue
        count_data[k] = resample_date_data[k]
        if isinstance(extra_output_data, dict):
            day_df = date_df.loc[k]
            for extra_k in extra_output_data: extra_output_data[extra_k].append(0)
            try:
                if hasattr(day_df, 'iterrows'):
                    for j, row in day_df.iterrows():
                        extra_output_data[row[extra_meta_column]][-1] += row[count_column]
                else:
                    extra_output_data[day_df[extra_meta_column]][-1] += day_df[count_column]
            except Exception as e:
                logger.error(day_df)
                raise e
                break
    return (count_data, extra_output_data) if extra_output_data is not None else count_data

In [20]:
import pandas as pd
date_df = pd.DataFrame(AAPL)
date_df['date'] = pd.to_datetime(date_df['date'])
date_df = date_df.set_index('date')
date_df.head()

Unnamed: 0_level_0,open,high,low,close,volume,adj_close
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2000-03-01,118.56,132.06,118.5,130.31,38478000,31.68
2000-03-02,127.0,127.94,120.69,122.0,11136800,29.66
2000-03-03,124.87,128.23,120.0,128.0,11565200,31.12
2000-03-06,126.0,129.13,125.0,125.69,7520000,30.56
2000-03-07,126.44,127.44,121.12,122.87,9767600,29.87


In [21]:
len(get_count_data_datetime(date_df, count_column='volume'))

3270

In [22]:
p = line_figure_datetime(get_count_data_datetime(date_df, count_column='volume'), title="趋势图")
show(p)

[SaveTool(id='1988', ...), ResetTool(id='1989', ...), PanTool(id='1990', ...)]
