In [None]:
import sys
sys.path.append('../src')

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go

import cb_utils

In [None]:
# Load and shape dataframe
za = cb_utils.get_table('vw_ds_all_z_yr', use_cache=True)

# Quarterly utilization columns (also the parcats dimensions further down).
quarter_cols = [f'uzn_attd_pcs_q{i}_of_zs' for i in range(1, 5)]

# Round the numeric measures to the nearest 0.5 (x*2, round, /2) -- presumably
# to keep the number of distinct parcats categories manageable.
x = (za[['attd_pcs_auth_utlzn',
         'attd_pcs_auth_hrs',
         'attd_pcs_visit_hrs',
         *quarter_cols]] * 2).round() / 2

# Utilization is rounded to a whole number. Round the RAW value: rounding the
# already half-rounded value hits a .5 tie that pandas resolves to the even
# integer (2.7 -> 2.5 -> 2 instead of 3; 3.3 -> 3.5 -> 4 instead of 3).
x['attd_pcs_auth_utlzn'] = za['attd_pcs_auth_utlzn'].round()
x['lob'] = za['lob']
x['grp'] = za['grp']

# Keep only rows that have a value for every quarter.
x2 = x[x[quarter_cols].notnull().all(axis=1)]

# Inputs to turn into a function
df = x2.copy()
x_label = 'attd_pcs_auth_hrs'   # scatter x column
y_label = 'attd_pcs_auth_utlzn' # scatter y column
mask = df['grp'] == 2           # True rows -> "Group 2" traces, False rows -> "Group 3" traces

f = go.FigureWidget(plotly.subplots.make_subplots(rows=2))

# Trace indices, in the order the traces are added below. The callback and the
# update-menu visibility lists rely on this exact ordering.
scatter_traces = [0, 1]         # Group 2 scatter, Group 3 scatter
parcat_traces  = [2, 3, 4]      # all rows, Group 2 rows, Group 3 rows
categorical_dimensions = [f'uzn_attd_pcs_q{i}_of_zs' for i in range(1, 5)] # for parcats

# Parcats line colors: 0 = unselected, 1 = selected. Color names must be CSS
# names -- 'lightgray' (no space); 'light gray' is not a CSS color name.
colorscale = [[0, 'lightgray'], [1, 'firebrick']] # for parcats

# Scatter traces: one per group so the update-menu buttons can toggle each one.
# Added first, so they become trace 0 (Group 2 rows) and trace 1 (all other rows).
# Marker size scales with the row's grp value; color follows the y value.
for rows, trace_name in ((mask, 'Group 2'), (~mask, 'Group 3')):
    group = df.loc[rows]
    f.add_scatter(
        x=group[x_label],
        y=group[y_label],
        mode='markers',
        name=trace_name,
        marker={"size": 5 * group['grp'], "color": group[y_label]},
        marker_colorscale=plotly.colors.sequential.Viridis,
    )

# Parcats: one trace per row subset, all in the same domain below the scatter.
# You can't stack parcats on the same subplot, so there is a separate "total"
# parcat (all rows) plus one per group. The group ones start hidden, and the
# update-menu buttons switch between them. The resulting trace order
# (2: all rows, 3: Group 2, 4: Group 3) must match parcat_traces.
for subset, is_visible in [(df, True), (df.loc[mask], False), (df.loc[~mask], False)]:
    f.add_parcats(
        domain={'y': [0, 0.4]},
        dimensions=[dict(values=subset[label], label=label) for label in categorical_dimensions],
        # Line colors are only ever 0 (unselected) or 1 (selected), so the color
        # range must be [0, 1]; with cmax=3 a selected line was drawn a third of
        # the way along the colorscale instead of in the "selected" color.
        line={'colorscale': colorscale,
              'cmin': 0,
              'cmax': 1,
              'color': np.zeros(len(subset), dtype='uint8'),
              'shape': 'hspline'},
        visible=is_visible,
    )

# Chart formatting: figure height, axis titles, and lasso as the default drag mode.
# The scatter's y-axis is limited to the top of the figure; the parcats sit below.
f.update_layout(
    height=800,
    xaxis={'title': 'attd_pcs_auth_hrs'},
    yaxis={'title': 'Auth Utilization', 'domain': [0.6, 1]},
    dragmode='lasso',
    hovermode='closest',
)


def _restyle_button(label, visible, showlegend=None):
    """Build one update-menu button that restyles all five traces.

    ``visible`` (and ``showlegend``, when given) are per-trace lists in trace
    order: scatter G2, scatter G3, parcats all, parcats G2, parcats G3.
    """
    style = {"type": ["scatter", "scatter", "parcats", "parcats", "parcats"],
             "visible": visible}
    if showlegend is not None:
        style["showlegend"] = showlegend
    return dict(args=[style], label=label, method="restyle")


f.update_layout(
    updatemenus=[
        dict(
            buttons=[
                _restyle_button("All", [True, True, True, False, False]),
                _restyle_button("Group 2",
                                [True, False, False, True, False],
                                [True, True, False, False, False]),
                _restyle_button("Group 3",
                                [False, True, False, False, True],
                                [True, True, False, False, False]),
            ],
            type='buttons',
            direction="left",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.1,
            xanchor="left",
            y=1.1,
            yanchor="top",
        ),
    ]
)

# Linked-brushing callback: keeps the scatter selections and the parcats
# highlighting in sync (lasso/box drags on scatters, category clicks on parcats).
def update_color(trace, points, state):
    """Propagate a selection made on one trace to all five traces.

    ``points.point_inds`` are positions *within the trace that fired*. The
    traces hold different row subsets of ``df`` (0 and 3: Group 2 rows;
    1 and 4: Group 3 rows; 2: all rows). So the indices are first mapped back to
    positional rows of ``df``, then projected onto each trace's own rows.

    Parameters
    ----------
    trace : the trace that received the event (unused; ``points.trace_index``
        identifies it).
    points : plotly ``Points`` for that trace (uses ``trace_index`` and
        ``point_inds``).
    state : selector / input-device state from plotly (unused).
    """
    fired = points.trace_index
    inds = list(points.point_inds)

    # NOTE(review): FigureWidget appears to dispatch each event to every trace,
    # passing empty point_inds to the traces that were not involved. For a
    # parcats click, those empty calls would wipe the selection just made, so
    # they are skipped. Confirm against the installed plotly version.
    if fired in parcat_traces and not inds:
        return

    # Positional rows of df behind each trace, in the order the traces were added.
    group2_rows = np.flatnonzero(mask)
    group3_rows = np.flatnonzero(~mask)
    rows_by_trace = {0: group2_rows, 1: group3_rows, 2: np.arange(len(df)),
                     3: group2_rows, 4: group3_rows}
    rows = rows_by_trace[fired]

    if fired in parcat_traces:
        # A category click replaces the whole selection.
        selected = np.zeros(len(df), dtype=bool)
    else:
        # A scatter selection only owns its own group's rows. Merge it into the
        # current selection, which is kept as the 0/1 line colors of the
        # all-rows parcats (trace 2).
        selected = np.asarray(f.data[2].line.color).astype(bool)
        selected[rows] = False
    selected[rows[inds]] = True

    # Push the selection to every trace in one widget update. When nothing is
    # selected, clear selectedpoints instead of dimming every marker.
    anything_selected = selected.any()
    with f.batch_update():
        for i in scatter_traces:
            f.data[i].selectedpoints = (np.flatnonzero(selected[rows_by_trace[i]])
                                        if anything_selected else None)
        for i in parcat_traces:
            f.data[i].line.color = selected[rows_by_trace[i]].astype('uint8')
# Lasso/box selections on either scatter and category clicks on any parcats
# trace all go through the same linked-brushing callback.
for scatter in (f.data[i] for i in scatter_traces):
    scatter.on_selection(update_color)

for parcat in (f.data[i] for i in parcat_traces):
    parcat.on_click(update_color)

# Display the widget as this cell's output.
f