In [32]:
import pandas as pd
import numpy as np
pslice=pd.IndexSlice
import asyncio
import numpy as np
import panel as pn
import box
from box import Box
from itertools import combinations
import plotly
from plotly import express as px
import plotly.graph_objects as go
import logging
logging.basicConfig(level=logging.INFO)
from functools import partial
import time
from copy import copy,deepcopy
from dataclasses import dataclass
import inspect
import typing
import param
pn.extension(
    'plotly', 
    admin=True,
    global_loading_spinner=True
             )

### temporary imports
import importlib
# from pympler.asizeof import asizeof
# import sidecar as SC

from genCodeModule import genCode as gc
# %run -i './genCodeModule/genCode.py'
from helper import sankey_helper as h
from helper import sankey_style as ss
# for editing
importlib.reload(h)
importlib.reload(ss)
importlib.reload(gc)

plotly.io.renderers.default = "plotly_mimetype+notebook"
from IPython.display import HTML

#### gen_filter_set

def gen_filter_set(
    timing_filter=h.filters.timing_filter.vals[0],
    user_filter=h.filters.user_filter.vals[0],
    week_filter=h.filters.week_filter.vals[:1].tolist(),
    position_filter=h.filters.position_filter.vals.to_list(),
    starter_filter=h.filters.starter_filter.vals.to_list(),
    level_filter=h.filters.level_filter.vals.to_list(),
):
    """Bundle the six dashboard filters into a single ``Box``.

    Every argument defaults to a value taken from ``h.filters`` when this
    function is defined:
    - the first timing value;
    - the first user;
    - only the first week;
    - every position, starter and level value.

    Returns:
        Box with keys ``timing_filter``, ``user_filter``, ``week_filter``,
        ``position_filter``, ``starter_filter`` and ``level_filter``.
    """
    filter_values = {
        'timing_filter': timing_filter,
        'user_filter': user_filter,
        'week_filter': week_filter,
        'position_filter': position_filter,
        'starter_filter': starter_filter,
        'level_filter': level_filter,
    }
    return Box(filter_values)

def apply_filter_set(fs=None,df_levels=h.df_levels):
    """Return the rows of ``df_levels`` that match a filter set (result is usually called df_l).

    Args:
        fs: filter ``Box`` as built by ``gen_filter_set``. ``None`` uses
            ``gen_filter_set()`` defaults.
        df_levels: table to filter. Defaults to ``h.df_levels``, bound when this
            function is defined.

    Returns:
        A new DataFrame with a fresh 0..n-1 index. ``fs.level_filter`` is not
        applied here; it only controls which columns are drawn later.
    """
    if fs is None:
        fs = gen_filter_set()
    keep = (
        (df_levels.timing == fs.timing_filter)
        & (df_levels.rid == fs.user_filter)
        & df_levels.week.isin(fs.week_filter)
        & df_levels.primary_pos.isin(fs.position_filter)
        & df_levels.starter_bn.isin(fs.starter_filter)
    )
    return df_levels[keep].reset_index(drop=True)


def get_links_from_levels(levels):
    """Split an ordered sequence of levels into its adjacent pairs.

    ``['a', 'b', 'c']`` -> ``[['a', 'b'], ['b', 'c']]``. Fewer than two
    levels gives an empty list. Each pair is a slice of ``levels``, so it keeps
    the sequence's slice type.
    """
    pairs = []
    for start in range(len(levels) - 1):
        pairs.append(levels[start:start + 2])
    return pairs

def inner_gbapply_df_ll(col):
    """Collapse one column of a link group to a single value.

    Applied column-by-column inside ``get_df_ll``'s groupby-apply, where the
    group's rows are already sorted by descending ``pts``. Rules by column name:

    - ``pid``, ``source_level``, ``target_level``: first value.
    - ``slot``: all values joined with ``'|'``.
    - ``source_id``, ``target_id``: distinct values joined with ``'|'``.
    - ``pts``: sum.
    - ``link_label_body_row``: values joined with ``'<br>'``.

    Any other column name yields ``None``.
    """
    name = col.name
    if name == 'pid' or name in ('source_level', 'target_level'):
        return col.values[0]
    if name == 'slot':
        return '|'.join(col.values)
    if name in ('source_id', 'target_id'):
        return '|'.join(col.drop_duplicates().values)
    if name == 'pts':
        return col.sum()
    if name == 'link_label_body_row':
        return '<br>'.join(col)
    return None

def get_df_ll(df_l,levels):
    """Build the Sankey link table: one row per (pid, source node, target node).

    With more than two levels the function recurses over each adjacent pair
    (see ``get_links_from_levels``) and stacks the results.

    For a pair ``[src, tgt]`` it:
      1. sums ``pts`` per (pid, slot, source_id, target_id, source_level,
         target_level) and drops groups below 0.05 pts;
      2. builds a hover line per row with ``h.get_link_label_body_row(slot, pts)``;
      3. merges rows sharing (pid, source_id, target_id) with
         ``inner_gbapply_df_ll``, largest ``pts`` first;
      4. adds ``link_label_header`` via ``h.get_label_header(pid)``.

    Args:
        df_l: filtered level table. Needs ``pid``, ``slot``, ``pts`` and
            every column named in ``levels``.
        levels: ordered level column names (at least two).

    Returns:
        DataFrame of links.
    """
    if len(levels)>2:
        return pd.concat([get_df_ll(df_l,link) for link in get_links_from_levels(levels)],ignore_index=True)
    df_ll = (
        df_l.assign(
            source_id=df_l[levels[0]],
            target_id=df_l[levels[1]],
            source_level=levels[0],
            target_level=levels[1],
        )
        .groupby(
            ['pid','slot','source_id','target_id','source_level','target_level'],
            as_index=False,
        )[['pts']].sum()
        .query('pts>=.05')
        # Explicit copy so the column assignment below acts on an owned frame,
        # not on a query() result.
        .copy()
    )
    # A comprehension instead of DataFrame.apply(axis=1) is faster. Unlike the
    # row-wise apply, it also yields a plain (possibly empty) column when no link
    # passes the threshold; the apply returns a DataFrame there.
    df_ll['link_label_body_row'] = [
        h.get_link_label_body_row(slot, pts) for slot, pts in zip(df_ll.slot, df_ll.pts)
    ]
    df_ll = df_ll.groupby(
        ['pid','source_id','target_id'],as_index=False,group_keys=False,sort=False,
    )[df_ll.columns].apply(
        lambda df:df.sort_values('pts',ascending=False).apply(inner_gbapply_df_ll),
    )
    df_ll['link_label_header']=df_ll.pid.map(h.get_label_header)
    return df_ll

def get_pid_node_order(df_l):
    """Order player ids for node placement.

    Each pid is represented by its highest-``pts`` row. Those rows are ordered
    by ``slot``, following the key order of ``h.label_maps.slot``, and then by
    descending ``pts`` within a slot.

    Returns:
        Series of pid values with a 0..n-1 index.
    """
    # maybe: df_l.query('pts>=.05') ???
    def _sort_key(col):
        # Only the slot column is remapped (to its rank); pts sorts as-is.
        if col.name != 'slot':
            return col
        slot_order = list(h.label_maps.slot.keys())
        return col.map(slot_order.index)

    best_row_per_pid = (
        df_l.sort_values('pts', ascending=False)
        .drop_duplicates(subset='pid', keep='first')
    )
    ordered = best_row_per_pid.sort_values(
        ['slot', 'pts'],
        key=_sort_key,
        ascending=[True, False],
        ignore_index=True,
    )
    return ordered.pid
    
def get_node_order(df_l,levels):
    """Return the full node display order as a list.

    Non-pid nodes always follow ``h.non_pid_nodes``. When ``'pid'`` is one of
    ``levels``, the pid order from ``get_pid_node_order`` is placed in front.
    """
    if 'pid' not in levels:
        return h.non_pid_nodes.tolist()
    pid_order = get_pid_node_order(df_l)
    return np.hstack((pid_order, h.non_pid_nodes)).tolist()

def get_node_labels(df_l,df_ll,levels,level=None):
    """Build the hover-body text for Sankey nodes.

    Args:
        df_l: filtered level table (needs ``pid``, ``pts`` and the level columns).
        df_ll: link table from ``get_df_ll``.
        levels: ordered level column names.
        level: a single level to label. ``None`` means "every level in
            ``levels``": the per-level results are concatenated.

    Returns:
        DataFrame whose columns are renamed to ``node_id`` and
        ``node_label_body_row``. The body is a ``'<br>'``-joined list of rows
        formatted by ``h.get_node_label_body_row``.
    """
    if level is None:
        # the comprehension variable ``level`` shadows the parameter only inside the comprehension
        return pd.concat([get_node_labels(df_l,df_ll,levels,level) for level in levels],ignore_index=True)
        
    if level=='pid':
        # pid nodes: break each player's pts down by the *second* level (levels[1]),
        # largest first.
        # NOTE(review): if levels[1] is itself 'pid' (e.g. levels == [x, 'pid']) this
        # groups by ['pid', 'pid'], which looks unintended -- TODO confirm.
        pid_node_hoverlabels = df_l.groupby(['pid',levels[1]],as_index=False)[['pts']].sum().sort_values(['pid','pts'],ascending=[True,False])
        pid_node_hoverlabels['node_label_body_row']=pid_node_hoverlabels.apply(lambda row: h.get_node_label_body_row(row[levels[1]],row.pts),axis=1)
        # one string per pid from the rows with pts >= .05; the apply result column is
        # unnamed (None), hence the rename of None below
        pid_node_hoverlabels=\
        pid_node_hoverlabels.groupby('pid',as_index=False,group_keys=False,sort=False).apply(lambda df:'<br>'.join(df.query('pts>=.05').node_label_body_row),
                                                                                             include_groups=False).rename(columns={'pid':'node_id',None:'node_label_body_row'})
        return pid_node_hoverlabels
    else:
        # Non-pid nodes come from the link table. The first level only appears as a
        # link source; every later level is looked up on the target side.
        source_or_target = 'source' if level==levels[0] else 'target'
        # pts per (node, pid), largest contribution first within each node
        non_pid_node_hoverlabels = df_ll.query(f'{source_or_target}_level==@level').groupby([f'{source_or_target}_id','pid'],as_index=False)[['pts']].sum()\
            .sort_values([f'{source_or_target}_id','pts'],ascending=[True,False],ignore_index=True)
        
        non_pid_node_hoverlabels['node_label_body_row']=non_pid_node_hoverlabels.apply(lambda row: h.get_node_label_body_row(row.pid,row.pts),axis=1)
        # keep at most the top 5 contributing pids (with pts >= .05) per node
        non_pid_node_hoverlabels=\
        non_pid_node_hoverlabels.groupby(f'{source_or_target}_id',as_index=False,group_keys=False,sort=False).apply(lambda df:'<br>'.join(df.query('pts>=.05').iloc[:5].node_label_body_row),
                                                                                             include_groups=False).rename(columns={f'{source_or_target}_id':'node_id',None:'node_label_body_row'})
        # add ellipse
        # non_pid_node_hoverlabels['node_label_body_row']=non_pid_node_hoverlabels.node_label_body_row+'<br>...'
        return non_pid_node_hoverlabels


def _get_df_l_gb(df_l=None,df_nf=None,levels=None,sankeyPrep=None):
    """Total ``pts`` per (level, node_id), limited to nodes present in ``df_nf``.

    Args:
        df_l: level table with a ``pts`` column and one column per level.
        df_nf: node table. Its ``node_id`` order defines the output order.
        levels: level column names to melt.
        sankeyPrep: if given, ``df_l``, ``df_nf`` and ``levels`` are taken from
            it (``levels`` from ``sankeyPrep.fs.level_filter``).

    Returns:
        DataFrame with columns ``level``, ``node_id`` and ``pts``. It keeps only
        totals strictly above 0.05, is sorted by position in ``df_nf.node_id``,
        and has a 0..n-1 index.
    """
    if sankeyPrep is not None:
        df_l = sankeyPrep.df_l
        df_nf = sankeyPrep.df_nf
        levels = sankeyPrep.fs.level_filter
    # TODO: DANGER: the node_id membership filter below should probably be removed...
    long_form = df_l.melt(
        id_vars='pts', value_vars=levels, var_name='level', value_name='node_id',
    )
    per_node = long_form.groupby(['level', 'node_id'], as_index=False)[['pts']].sum()
    per_node = per_node[per_node.pts > .05]
    per_node = per_node[per_node.node_id.isin(df_nf.node_id.values)]
    display_order = df_nf.node_id.tolist()
    return per_node.sort_values(
        'node_id', key=lambda col: col.map(display_order.index), ignore_index=True,
    )

def _get_space_size(df_nf,levels,frac_empty=.5): #.7 or higher more desirable
    """Vertical gap between nodes, and total window height, both in pts.

    The level with the most nodes ("level0") sets the scale: its nodes plus the
    gaps between them fill the window, and the gaps take ``frac_empty`` of it.

    Args:
        df_nf: per-node table with ``level``, ``node_id`` and ``pts`` columns.
        levels: unused; kept for call compatibility.
        frac_empty: fraction of the window reserved for gaps (0 <= frac_empty < 1).

    Returns:
        ``(space_size, window_pts)``. When level0 has a single node there are no
        gaps, so ``space_size`` is 0.0 instead of the previous division by zero
        (which produced inf and broke the y positions).
    """
    level0=df_nf.groupby('level').node_id.count().sort_values(ascending=False).index[0]
    df_nf_level0=df_nf.query('level==@level0')
    total_pts = df_nf_level0.pts.sum()
    window_pts=total_pts/(1-frac_empty)
    num_spaces=df_nf_level0.shape[0]-1
    if num_spaces <= 0:
        return 0.0, window_pts
    # Chosen so that num_spaces*space_size/(num_spaces*space_size+total_pts) == frac_empty
    space_size=((frac_empty*total_pts)/((1 - frac_empty)*num_spaces))
    return space_size,window_pts # to turn off :: # return 0,total_pts

def _get_y(df,space_size):
    """Vertical centre of each node when stacked in row order with gaps between them.

    Each row occupies ``pts`` and is followed by a gap of ``space_size``. A row's
    value is the running total of (pts + gap) up to and including that row, minus
    the trailing gap, minus half the row's own ``pts``.

    Args:
        df: frame with a numeric ``pts`` column, already in display order.
        space_size: gap between consecutive nodes, in pts.

    Returns:
        Unnamed Series aligned to ``df.index``. Empty input gives an empty Series.
    """
    # Vectorised version of the former eval -> cumsum -> row-wise apply chain.
    # The arithmetic is the same, but it avoids a Python-level loop per row, and
    # apply(axis=1) used to return a DataFrame (not a Series) for an empty frame.
    stacked_end = (df.pts + space_size).cumsum() - space_size
    centres = stacked_end - df.pts/2
    centres.name = None  # keep the unnamed result the row-wise apply produced
    return centres

def _get_xy(df_l=None,df_nf=None,levels=None,sankeyPrep=None):
    """Compute normalised Sankey node positions (x, y) per (level, node_id).

    Args:
        df_l, df_nf, levels: as for ``_get_df_l_gb``.
        sankeyPrep: if given, overrides the three arguments above.

    Returns:
        The ``_get_df_l_gb`` table with added ``x`` and ``y`` columns, and
        ``attrs['space_size']`` / ``attrs['window_pts']`` set.
    """
    if sankeyPrep is not None: df_l,df_nf,levels=sankeyPrep.df_l,sankeyPrep.df_nf,sankeyPrep.fs.level_filter

    df_lgb=_get_df_l_gb(df_l=df_l,df_nf=df_nf,levels=levels)
    
    # x: levels evenly spaced from 0.01 to 0.99, in the order given by ``levels``
    x_map=pd.Series(np.linspace(.01,.99,len(levels)),index=levels)
    df_lgb['x']=x_map.loc[df_lgb.level].values
    
    # y (in pts first): stack each level's nodes with a common gap size
    space_size,window_pts = _get_space_size(df_lgb,levels)
    df_lgb['y']=df_lgb.groupby('level').apply(lambda df: _get_y(df,space_size),include_groups=False).reset_index(level=0,drop=True)
    # shift to bottom
    # Each level is shifted so its last node's far edge (y + pts/2) lands on
    # window_pts (scaled by _ratio_adjust); the result is then divided by
    # window_pts to give a 0..1 coordinate.
    _ratio_adjust=1 #1 most desirable
    df_lgb['y']=df_lgb.groupby('level',as_index=False).apply(lambda df: df.y+_ratio_adjust*(window_pts-df.eval('y+pts/2').max()),include_groups=False).reset_index(level=0,drop=True)/window_pts
    
    # expose the layout constants to callers (get_df_nodes_filtered merges these attrs)
    df_lgb.attrs['space_size']=space_size
    df_lgb.attrs['window_pts']=window_pts
    return df_lgb


def get_df_nodes_filtered(df_l,df_ll,levels,dfsb=h.dfsb,df_nodes_all=h.df_nodes_all):
    """Node table for the Sankey: which nodes, in what order, with labels and positions.

    Steps:
      1. Keep the ``df_nodes_all`` nodes that appear in any ``levels`` column of a
         ``df_l`` row with ``pts >= .05``.
      2. Order them with ``get_node_order``.
      3. Attach hover text (``node_label_body`` and ``node_label_header``).
      4. Left-merge ``pts``/``x``/``y`` from ``_get_xy`` and copy its attrs
         (``space_size``, ``window_pts``).

    ``dfsb`` is unused and kept for signature compatibility.
    """
    # df_nodes_all=h.df_nodes_all # get_df_nodes_all(dfsb)
    visible_ids = np.unique(df_l.query('pts>=.05')[levels])
    df_nf = df_nodes_all[df_nodes_all.node_id.isin(visible_ids)].reset_index(drop=True)

    node_order = get_node_order(df_l,levels)
    df_nf = df_nf.sort_values(
        'node_id', key=lambda col: col.map(node_order.index), ignore_index=True,
    )

    labels_by_node = get_node_labels(df_l,df_ll,levels).set_index('node_id')
    df_nf['node_label_body'] = labels_by_node.loc[df_nf.node_id].node_label_body_row.values
    df_nf['node_label_header'] = df_nf.node_id.map(h.get_label_header)

    df_nf_xy = _get_xy(df_l=df_l,df_nf=df_nf,levels=levels)
    df_nf = df_nf.merge(df_nf_xy[['node_id','pts','x','y']], how='left', on='node_id')
    df_nf.attrs |= df_nf_xy.attrs
    return df_nf
    
def sort_df_ll(df_l,df_ll,df_nf,levels):
    """Sort link rows by slot rank, then pid display order, then descending pts.

    - Slot rank follows the key order of ``h.label_maps.slot``. A multi-slot
      link (``'A|B'``) ranks by its first slot.
    - The pid order is the node order of ``df_nf`` when ``'pid'`` is the first
      level, else ``get_pid_node_order(df_l)``.

    Returns:
        A sorted copy of ``df_ll`` with a 0..n-1 index.
    """
    def _rank(col):
        if col.name == 'slot':
            first_slot = col.map(lambda slot: slot.split('|')[0])
            return first_slot.map(list(h.label_maps.slot.keys()).index)
        if col.name == 'pts':
            # negate so the largest pts come first under an ascending sort
            return -col
        if col.name == 'pid':
            if levels[0] != 'pid':
                pid_order = get_pid_node_order(df_l).to_list()
            else:
                pid_order = df_nf.node_id.tolist()
            return col.map(pid_order.index)
        return None

    return df_ll.sort_values(['slot', 'pid', 'pts'], key=_rank, ignore_index=True)

def get_node_or_link_colors(node_or_link,df_l=None,df_ll=None,df_nf=None,alt=False):
    """Colours for Sankey nodes or links.

    node_or_link :: 'node' or 'link'
    - links need df_l and df_ll
    - nodes need df_nf

    An existing ``color`` column on the target frame (``df_nf`` or ``df_ll``) is
    returned as-is. Otherwise colours are looked up in ``h.df_colors`` by the
    ``primary_pos`` of each node, or of each link's pid. The lookup uses the
    ``color`` column when ``alt == False`` and ``color_alt`` otherwise. Any other
    ``node_or_link`` returns ``None``.
    """
    color_col = 'color' if alt == False else 'color_alt'  # noqa: E712 (keeps equality semantics)
    if node_or_link == 'node':
        if 'color' in df_nf.columns:
            return df_nf.color
        return h.df_colors[color_col].loc[df_nf.primary_pos]
    if node_or_link == 'link':
        if 'color' in df_ll.columns:
            return df_ll.color
        pos_by_pid = df_l.drop_duplicates(subset='pid').set_index('pid').primary_pos
        return h.df_colors[color_col].loc[pos_by_pid.loc[df_ll.pid]]
    return None

def mask_node_labels(df_nf):
    """Labels to draw on the nodes.

    Without a ``color`` column every ``node_label`` is returned unchanged. With
    one, only lawngreen (highlighted) nodes keep their label; the rest get ``''``.
    """
    if 'color' not in df_nf.columns:
        return df_nf.node_label
    is_highlighted = df_nf.color == 'lawngreen'
    return df_nf.node_label.where(is_highlighted, '')
    
def get_node_hovertemplate(main_or_hover):
    """Plotly hovertemplate for Sankey nodes.

    The bold-italic label is followed by the bold customdata. On the main
    figure the secondary box shows the node value; otherwise it is empty.
    """
    body = '<b><i>%{label}</i></b><br><b>%{customdata}</b>'
    if main_or_hover == 'main':
        return body + '<extra>%{value}</extra>'
    return body + '<extra></extra>'

def get_link_hovertemplate(main_or_hover):
    """Plotly hovertemplate for Sankey links.

    Returns ``None`` (plotly's default hover) for the hover figure. Otherwise
    it returns label + customdata, with the link value in the secondary box.
    """
    if main_or_hover == 'hover':
        return None
    return '<b><i>%{label}</i></b><br>%{customdata}' + '<extra>%{value}</extra>'

#### get_sankey_object

def get_sankey_object(df_l,df_ll,df_nf):
    """Assemble the ``go.Sankey`` trace from the level, link and node tables.

    ``df_l.attrs['main_or_hover']`` selects the hover styling: gray hover boxes
    with a lawngreen border on the hover figure, plotly defaults on the main one.
    Link endpoints are converted to node indices by position in ``df_nf.node_id``.
    Node ``x``/``y`` are passed only when those columns exist.
    """
    mode = df_l.attrs['main_or_hover']
    is_main = mode == 'main'
    node_position = df_nf.node_id.tolist().index  # list position == plotly node index

    node_spec = dict(
        label=mask_node_labels(df_nf),
        customdata=h.apply_fmt(df_nf.node_label_body),
        hovertemplate=get_node_hovertemplate(mode),
        hoverlabel=dict(
            align='right',
            font=dict(color='white', family=h.font_family),
            bgcolor=None if is_main else 'gray',
            bordercolor=None if is_main else 'lawngreen',
        ),
        color=get_node_or_link_colors('node',df_nf=df_nf),
        x=df_nf.x if 'x' in df_nf.columns else None,
        y=df_nf.y if 'y' in df_nf.columns else None,
        pad=50,
    )
    link_spec = dict(
        source=df_ll.source_id.map(node_position),
        target=df_ll.target_id.map(node_position),
        value=df_ll.pts,
        label=h.apply_fmt(df_ll.link_label_header),
        customdata=h.apply_fmt(df_ll.link_label_body_row),
        hovertemplate=get_link_hovertemplate(mode),
        hoverlabel=dict(
            align='right',
            font=dict(color='white', family=h.font_family),
        ),
        color=get_node_or_link_colors('link',df_l=df_l,df_ll=df_ll),
    )
    return go.Sankey(
        node=node_spec,
        link=link_spec,
        arrangement='perpendicular',
        valueformat='d',
        valuesuffix=' pts',
        domain=dict(y=[.05,.95], x=[.01,.99]),
    )

def get_sankey_figure(sankey_object):
    """Wrap a Sankey trace in a ``go.Figure`` with the dashboard's font and background."""
    layout = dict(
        font_family=h.font_family,
        font_color='white',
        paper_bgcolor=h.bg_color,
        autosize=True,
        margin=dict(pad=0),
        template='plotly',
    )
    return go.Figure(data=[sankey_object], layout=layout)
    
def get_plotly_pane(fig):
    """Wrap ``fig`` in a Panel Plotly pane that stretches to fill its container."""
    return pn.pane.Plotly(fig, sizing_mode='stretch_both')

class SankeyPrep:
    """Everything needed to draw the main (non-hover) Sankey for one filter set.

    Attributes:
        fs: the filter ``Box`` used.
        df_l: filtered level rows (``apply_filter_set(fs)`` unless ``df_l`` is
            given). Its ``attrs['main_or_hover']`` is set to ``'main'``.
        df_ll: link table, sorted with ``sort_df_ll``.
        df_nf: node table with labels and x/y positions.
        sank, fig, pane: plotly trace, figure and Panel pane built from the above.

    Not a dataclass: with a hand-written ``__init__`` and no declared fields,
    ``@dataclass`` only generated an ``__eq__`` under which every instance
    compared equal, and it made instances unhashable.
    """
    def __init__(self,fs=None,df_l=None):
        # None -> default filter set, built per call instead of one Box bound at class definition
        if fs is None:
            fs = gen_filter_set()
        self.fs = fs
        self.df_l = df_l if df_l is not None else apply_filter_set(fs)
        self.df_ll=get_df_ll(self.df_l,fs.level_filter)
        self.df_nf = get_df_nodes_filtered(self.df_l,self.df_ll,fs.level_filter)
        # the link order depends on df_nf's node order, hence sorting after df_nf exists
        self.df_ll=sort_df_ll(self.df_l,self.df_ll,self.df_nf,fs.level_filter)
        
        # read by get_sankey_object to pick the main-figure hover styling
        self.df_l.attrs['main_or_hover']='main'

        self.sank = get_sankey_object(self.df_l,self.df_ll,self.df_nf)
        self.fig = get_sankey_figure(self.sank)
        self.pane = get_plotly_pane(self.fig)

class SankeyHover:
    """Highlighted Sankey for one hovered node.

    Rows of ``sankeyPrep.df_l`` whose ``<level>`` column equals ``hoverNode``
    ("in") get lawngreen links. All other rows ("out") use the alternate palette
    (``alt=True``). Nodes touched by an "in" link are also coloured lawngreen.

    Args:
        sankeyPrep: the ``SankeyPrep`` whose data is being explored.
        hoverNode: node id under the cursor. Its level is looked up in
            ``h.df_nodes_all``.

    Attributes:
        hoverNode, level: the hovered node and its level name.
        df_l, df_ll, df_nf: in+out level, link and node tables
            (``df_l.attrs['main_or_hover'] == 'hover'``).
        sank, fig, pane: plotly trace, figure and Panel pane.

    Not a dataclass: with a hand-written ``__init__`` and no declared fields,
    ``@dataclass`` only generated an ``__eq__`` under which every instance
    compared equal, and it made instances unhashable.
    """
    def __init__(self,sankeyPrep,hoverNode):
        self.hoverNode=hoverNode
        # presumably node ids are unique, so squeeze() yields a scalar level name -- TODO confirm
        self.level = h.df_nodes_all.query('node_id==@hoverNode').level.squeeze()
        df_lin = sankeyPrep.df_l.query(f'{self.level}==@hoverNode')
        df_lout = sankeyPrep.df_l.query(f'{self.level}!=@hoverNode')
        
        df_llin = get_df_ll(df_lin,sankeyPrep.fs.level_filter)
        df_llout = get_df_ll(df_lout,sankeyPrep.fs.level_filter)
        
        df_llin['color']='lawngreen'
        df_llout['color']=get_node_or_link_colors('link',df_l=df_lout,df_ll=df_llout,alt=True).values
        
        df_linout = pd.concat((df_lin,df_lout),ignore_index=True)
        df_llinout = pd.concat((df_llin,df_llout),ignore_index=True)
        
        df_nfinout=get_df_nodes_filtered(df_linout,df_llinout,sankeyPrep.fs.level_filter)
        df_nfinout['color']=get_node_or_link_colors('node',df_linout,df_llinout,df_nfinout,alt=True).values
        # override with lawngreen for every node that is an endpoint of an "in" link
        df_nfinout['color']=df_nfinout.color.where(~df_nfinout.node_id.isin(np.unique(df_llin[['source_id','target_id']])),'lawngreen')
        
        df_llinout=sort_df_ll(df_linout,df_llinout,df_nfinout,sankeyPrep.fs.level_filter)
        
        self.df_l=df_linout
        self.df_ll=df_llinout
        self.df_nf=df_nfinout
        
        # read by get_sankey_object to pick the hover-figure styling
        self.df_l.attrs['main_or_hover']='hover'
        
        self.sank = get_sankey_object(self.df_l,self.df_ll,self.df_nf)
        self.fig = get_sankey_figure(self.sank)
        # comment out pane when in production
        self.pane = get_plotly_pane(self.fig)

#### boxify

@gc.listicate('objs')
def boxify(objs,title=None,include_header=True,align_header='center'):
    """Wrap widgets in a padded ``pn.Column`` with an optional ``<h2>`` title.

    The objects are laid out in a centred ``pn.Row`` below the header. ``objs``
    is presumably normalised to a list by ``gc.listicate`` (TODO confirm). The
    top padding is 0px with a header and 10px without one.
    """
    show_header = include_header == True  # noqa: E712 (keeps equality semantics)
    top_pad = 0 if show_header else 10
    container = pn.Column(styles={'padding': f"{top_pad}px 10px 10px 10px"})
    if show_header:
        container.append(
            pn.pane.HTML(
                f'<h2>{title}</h2>',
                margin=0,
                align=align_header,
                styles=dict(padding='0px'),
                tags=['text','boxify'],
            )
        )
    container.append(pn.Row(*objs, align='center'))
    return container

def get_level_box(sankeyWidget):
    """Titled 'levels' box: a CheckButtonGroup that drives ``sankeyWidget.levelParam``.

    Any change in the checked buttons is pushed straight into ``levelParam``.
    """
    level_buttons = pn.widgets.CheckButtonGroup(
        options=sankeyWidget.param.levelParam.names,
        value=sankeyWidget.levelParam,
        align='center',
        button_style='outline',
        button_type='light',
        margin=0,
        styles=dict(padding='0px'),
        tags=['lw1','level'],
    )

    def update_sankeyWidget(event):
        sankeyWidget.param.update(levelParam=event.new)

    level_buttons.param.watch(update_sankeyWidget, 'value')

    level_box = boxify(level_buttons, title='levels')
    level_box.param.update(height=110, tags=['level_widget_box','text'])
    return level_box

def get_timing_box(sankeyWidget):
    """Titled 'perspective' box: a RadioButtonGroup that drives ``sankeyWidget.timingParam``.

    The selected option is pushed straight into ``timingParam``.
    """
    timing_buttons = pn.widgets.RadioButtonGroup(
        options=sankeyWidget.param.timingParam.names,
        value=sankeyWidget.timingParam,
        align='center',
        button_style='outline',
        button_type='light',
        margin=0,
        styles=dict(padding='0px'),
        tags=['tw1','timing'],
    )

    def update_sankeyWidget(event):
        sankeyWidget.param.update(timingParam=event.new)

    timing_buttons.param.watch(update_sankeyWidget, 'value')

    timing_box = boxify(timing_buttons, title='perspective')
    timing_box.param.update(height=110, tags=['timing_widget_box','text'])
    return timing_box

def get_position_box(sankeyWidget):
    """Titled 'player position' box with offense and defense position toggles.

    Two vertical CheckButtonGroups (``h.pos_list_O`` and ``h.pos_list_D``) each
    sit under an 'offense' / 'defense' button. The button toggles its whole
    group: all off if every position is checked, otherwise all on. Changes are
    merged into ``sankeyWidget.positionParam``. The positions of the other group
    are preserved.
    """
    pwO=\
        pn.widgets.CheckButtonGroup(
            options=h.pos_list_O,
            value=[pos for pos in h.pos_list_O if pos in sankeyWidget.positionParam],
            orientation='vertical',
            align='start',
            button_style='outline',
            button_type='light',
            tags=['pwO','text','position'],
        )
    pwO_button = pn.widgets.Button(
        name = 'offense',
        align='center',
        tags=['pwO_button','text','position'],
    )
    pwD=pwO.clone(
        options=h.pos_list_D,
        value=[pos for pos in h.pos_list_D if pos in sankeyWidget.positionParam],
        tags=['pwD','text','position'],
    )
    pwD_button = pwO_button.clone(
        name='defense',
        tags=['pwD_button','text','position'],
    )
    col_styles={
        'border-left':'2px white solid',
    }
    pbox=\
    boxify(
        pn.Column(
            pn.Column(
                pwO_button,pwO,
                styles = col_styles,
            ),
            pn.Spacer(height=15),
            pn.Column(
                pwD_button,pwD,
                styles = col_styles,
            )
        ),
        title='player<br>position'
    )
    pbox.param.update(
        width=150,
        tags=['position_widget_box']
    )
    def update_sankeyWidget(event):
        # Parsed as: shrink OR (grow AND some newly-checked value is not yet in positionParam).
        # The "not yet in positionParam" guard skips the redundant update that follows
        # offense_defense_button_callback, which sets positionParam before the group value.
        # On update, this group's old positions are replaced by event.new; the other
        # group's positions are kept.
        if (len(event.new)<len(event.old)) or (len(event.new)>len(event.old)) \
            and any([v not in sankeyWidget.positionParam for v in event.new]):
            sankeyWidget.param.update(
                positionParam=[v for v in sankeyWidget.positionParam if v not in event.obj.options] + event.new)
    
    def offense_defense_button_callback(event):
        # choose the group this button controls (pw_this) and the other one (pw_other)
        pw_this,pw_other = [[pwO,pwD],[pwD,pwO]]['pwD_button' in event.obj.tags]
        # Order matters in both branches: positionParam is updated first, then the group's
        # value, so the value watcher above either recomputes the same list (turning off)
        # or finds nothing new to add (turning on).
        if set(pw_this.value)==set(pw_this.options): # if all on, turn all off
            sankeyWidget.param.update(
                positionParam= [v for v in sankeyWidget.positionParam if v in pw_other.options])
            pw_this.param.update(value=[])
        else: # else turn all on
            sankeyWidget.param.update(
                positionParam= [v for v in sankeyWidget.positionParam if v in pw_other.options] + pw_this.options)
            pw_this.param.update(value=pw_this.options)
                 
    pwO.param.watch(update_sankeyWidget,'value')
    pwD.param.watch(update_sankeyWidget,'value')
    pwO_button.on_click(offense_defense_button_callback)
    pwD_button.on_click(offense_defense_button_callback)
    return pbox

def get_week_box(sankeyWidget):
    """Titled 'weeks' box: two sliders and a text label that drive ``weekParam``.

    - ``ww1`` picks a single week. When released it sets ``weekParam=[week]``
      and moves ``ww2`` to the same week.
    - ``ww2`` extends the selection. When released it sets ``weekParam`` to
      every week between ``ww1`` and ``ww2``, inclusive.
    - ``wws`` shows 'wk N' or 'wks A - B' and updates while the sliders move.
    """
    ww1 =pn.widgets.IntSlider(value=sankeyWidget.weekParam[0],
                              start=sankeyWidget.param.weekParam.objects[0],
                              end=sankeyWidget.param.weekParam.objects[-1],
                              step=1,
                              show_value=False,
                              tags=['ww1','week'],
                              # NOTE(review): value_throttled normally holds the released slider
                              # value; passing True here presumably only seeds it -- confirm intent.
                              value_throttled=True,
                              tooltips=False,
                              design=pn.theme.Material,
                              stylesheets=[':host { --design-primary-color: lawngreen; }'],
                             )
    ww2 = ww1.clone(tags=['ww2','week'],value=sankeyWidget.weekParam[-1])
    wws = pn.pane.Str(tags=['wws','text','week'],
                      align='center',
                      margin=0,
                      styles=dict(
                          padding='0px',
                      ))
    
    wwbox = \
    boxify(pn.Column(*[wws,ww1,ww2],
                     width=120,
                     styles=dict(padding='0px 10px')
                    ),
           title='weeks'
          )
    wwbox.param.update(
        width=150,
        tags=['week_widget_box']
    )
    def update_wws(event=None):
        # Single-week text when both sliders agree or when ww1 (the single-week
        # slider) is the one moving; otherwise show the ordered range.
        ww1_val = int(ww1.value)
        ww2_val = int(ww2.value)
        if (ww1_val == ww2_val) or (False if event is None else ('ww1' in event.obj.tags)):
            week_string=f'wk {ww1_val}'
        else:
            week_string = f'wks {min(ww1_val,ww2_val)} - {max(ww1_val,ww2_val)}'
        wws.param.update(object=week_string)
    
    def align_ww2(event):
        # Sets ww2.value (not value_throttled), so ww2's weekParam watcher is not triggered.
        ww2.param.update(value=event.new)
    # initialise the label before any slider moves
    update_wws()
    
    def update_sankeyWidget(event):
        if 'ww1' in event.obj.tags:
            sankeyWidget.param.update(weekParam = [event.new])
        elif 'ww2' in event.obj.tags:
            sankeyWidget.param.update(weekParam = np.arange(min(ww1.value,event.new),max(ww1.value,event.new)+1).tolist())
    
    
    # label follows live 'value'; weekParam only follows released 'value_throttled'
    ww1.param.watch(update_wws,'value')
    ww2.param.watch(update_wws,'value')
    ww1.param.watch(align_ww2,'value_throttled')
    
    ww1.param.watch(update_sankeyWidget,'value_throttled')
    ww2.param.watch(update_sankeyWidget,'value_throttled')
    
    return wwbox

def get_user_box(sankeyWidget):
    """Titled 'user' box: a MenuButton that selects ``sankeyWidget.userParam``.

    Menu items are ``(name, str(id))`` pairs from ``userParam.names``. Clicking
    one sets ``userParam=int(id)`` and relabels the button. The code indexes the
    name lists with ``id - 1``, so it assumes 1-based user ids in list order
    (TODO confirm).
    """
    user_names = sankeyWidget.param.userParam.names
    menu_items = [(label, str(user_id)) for label, user_id in user_names.items()]
    user_menu = pn.widgets.MenuButton(
        name=list(user_names)[sankeyWidget.fs.user_filter-1],
        items=menu_items,
        width=120,
        styles=dict(
            padding='0px 10px 10px 10px',
            margin='0px 0px',
        ),
        align='center',
        tags=['uw1','text','user'],
    )
    user_box = boxify(user_menu, title='user')
    user_box.param.update(
        width=150,
        height=110,
        tags=['user_widget_box'],
    )

    def update_sankeyWidget(event):
        sankeyWidget.param.update(userParam=int(event.new))

    def update_name(event):
        user_menu.param.update(name=h.filters.user_filter.fmt[int(event.new)-1])

    # registration order preserved: parameter first, then the button label
    user_menu.param.watch(update_sankeyWidget, 'clicked')
    user_menu.param.watch(update_name, 'clicked')
    return user_box

def get_starter_box(sankeyWidget):
    """Titled 'starter / bn' box: a vertical CheckButtonGroup that drives ``starterParam``."""
    starter_buttons = pn.widgets.CheckButtonGroup(
        options=sankeyWidget.param.starterParam.names,
        value=sankeyWidget.starterParam,
        orientation='vertical',
        align='center',
        button_style='outline',
        button_type='light',
        margin=0,
        styles=dict(padding='0px'),
        tags=['sw1','text','starter'],
    )

    def update_sankeyWidget(event):
        sankeyWidget.param.update(starterParam = event.new)

    starter_buttons.param.watch(update_sankeyWidget, 'value')

    starter_box = boxify(starter_buttons, title='starter<br>/ bn')
    starter_box.param.update(width=150, tags=['starter_widget_box'])
    return starter_box

def get_boxes(sankeyWidget):
    """Build every filter box for ``sankeyWidget``, keyed by role, in display order."""
    builders = (
        ('timingBox', get_timing_box),
        ('levelBox', get_level_box),
        ('positionBox', get_position_box),
        ('weekBox', get_week_box),
        ('userBox', get_user_box),
        ('starterBox', get_starter_box),
    )
    return Box({key: build(sankeyWidget) for key, build in builders})

#### SankeyWidget

class SankeyWidget(param.Parameterized):
    """Interactive Sankey dashboard driven by six filter parameters.

    Typical use: ``SankeyWidget(**params)()``. Calling the instance builds:
    - the main Sankey (``self.prep``);
    - the Plotly pane that shows it (``self.spane``);
    - the hover callback wiring;
    - the full layout (``self.mainViewPane``).

    Attributes set at call time:
        fs: filter ``Box`` built from the current parameter values.
        prep: ``SankeyPrep`` for ``fs`` (the main, un-highlighted figure).
        shov: most recent ``SankeyHover`` (set by ``hoverView``).
        spane: Plotly pane whose ``object`` is swapped between main and hover figures.
        boxes, resetButton, mainViewPane: widgets and layout built by ``mainView``.
    """
    # prep :: name of object w/ df_l,..., main pane
    # shov :: name of transient object with hover pane
    fs=gen_filter_set() # for defaults
    timingParam=param.Selector(
        default=fs.timing_filter,
        objects=h.filters.timing_filter.options_dict,
    )
    userParam=param.Selector(
        default=fs.user_filter,
        objects=h.filters.user_filter.options_dict,
    )
    weekParam=param.ListSelector(
        default=fs.week_filter,
        objects=h.filters.week_filter.vals.tolist(),
    )
    positionParam=param.ListSelector(
        default=fs.position_filter,
        objects=h.filters.position_filter.vals.to_list(),
    )
    starterParam=param.ListSelector(
        default=fs.starter_filter,
        objects=h.filters.starter_filter.options_dict,
    )
    levelParam=param.ListSelector(
        default=fs.level_filter,
        objects=h.filters.level_filter.options_dict,
    )    

    def __call__(self):
        """Build the main figure, pane, hover wiring and layout, then return ``self``."""
        self.fs=gen_filter_set(**self._filter2param())
        # Reuse self.fs (as ``view`` does) instead of building a second identical filter set.
        self.prep=SankeyPrep(self.fs)
        self.spane=self.prep.pane.clone()
        add_hover_callback(self)
        self.mainViewPane=self.mainView()
        return self
        
    def _filter2param(self):
        """Map the current ``<stem>Param`` values to ``gen_filter_set`` keyword arguments."""
        filterStems=['timing','user','week','position','starter','level']
        return {f'{stem}_filter':getattr(self,f'{stem}Param') for stem in filterStems}
        
    @pn.depends('timingParam','userParam','weekParam','positionParam','starterParam','levelParam')
    def view(self):
        """Rebuild ``prep`` for the current parameters and show its figure in ``spane``."""
        self.fs=gen_filter_set(**self._filter2param())
        self.prep=SankeyPrep(self.fs)
        self.spane.param.update(object=self.prep.fig)
        return self.spane
        
    def hoverView(self,node='QB'):
        """Build the highlight view for ``node`` from ``prep``, store it as ``shov`` and return it."""
        self.shov=SankeyHover(self.prep,node)
        return self.shov

    def mainView(self):
        """Create the filter boxes and reset button, then load and style the page template."""
        self.boxes = get_boxes(self)
        self.resetButton = self._get_resetButton()
        mainViewPane = ss.load_template(self)
        ss.style_pane(mainViewPane)
        return mainViewPane

    @gc.listicate('tags')
    def _selectWidget(self,tags):
        """Objects inside ``self.boxes`` whose ``tags`` contain every tag in ``tags``."""
        _widgets=pn.Row(*self.boxes.values())
        return _widgets.select(lambda obj: np.all([tag in obj.tags for tag in tags]))
    def _alignWidgets(self,fs):
        """Push the values of filter set ``fs`` into the widgets; their watchers update the params."""
        self._selectWidget('user')[0].param.update(clicked=f'{fs.user_filter:d}')
        self._selectWidget('timing')[0].param.update(value=fs.timing_filter)
        self._selectWidget(['week','ww1'])[0].param.update(value=fs.week_filter[0])
        self._selectWidget(['week','ww2'])[0].param.update(value=fs.week_filter[-1])
        # set directly: the week sliders only push weekParam from 'value_throttled' watchers,
        # presumably not triggered by the programmatic 'value' updates above
        self.param.update(weekParam=fs.week_filter)
        self._selectWidget(['position','pwO'])[0].param.update(value=[p for p in fs.position_filter if p in h.pos_list_O])
        self._selectWidget(['position','pwD'])[0].param.update(value=[p for p in fs.position_filter if p in h.pos_list_D])
        self._selectWidget('starter')[0].param.update(value=fs.starter_filter)
        self._selectWidget('level')[0].param.update(value=fs.level_filter)
    def reset(self,event=None,
              fs=gen_filter_set(
                  position_filter=h.pos_list_O,
                  starter_filter=['starter'],
                  timing_filter='post',
              )):
        """Restore the widgets to ``fs``.

        The default is offense positions, starters only and 'post' timing, with
        the other filters at ``gen_filter_set`` defaults.
        """
        self._alignWidgets(fs)

    def _get_resetButton(self):
        """Icon button that calls ``reset`` when clicked."""
        resetButton = pn.widgets.Button(icon='rotate-2',icon_size='2em',align='center',margin=30)
        resetButton.on_click(self.reset)
        return resetButton
    

def add_hover_callback(sankeyWidget):
    """Swap ``sankeyWidget.spane`` between the main and a node-highlight figure on hover.

    Watches the pane's ``hover_data``:
      * Resting on a node for 0.75 s shows ``sankeyWidget.hoverView(node).fig``
        and sets ``fig_state='hover'``.
      * Leaving (``hover_data`` is None), or hovering a link while in hover
        state, restores ``sankeyWidget.prep.fig`` after 0.5 s and sets
        ``fig_state='main'``.
    """
    sankeyWidget.fig_state='main'
    # @pn.io.with_lock
    async def hover_callback(event):
        # event.new is None ==> exit-hover ||| 'group' not in ... ==> hovering on link
        if (event.new is None) or ('group' not in event.new['points'][0].keys()):
            if sankeyWidget.fig_state=='hover':
                await asyncio.sleep(.5) # delay un-hovering .5
                sankeyWidget.spane.param.update(object=sankeyWidget.prep.fig) # grab stored main fig
                sankeyWidget.fig_state='main' # reset fig_state tracker
            return
        # pointNumber indexes the nodes of the figure on screen when the event fired,
        # which may be the hover figure. Capture that figure's node table before
        # awaiting, because the displayed figure can be swapped while we sleep.
        # Falls back to the main node table (the previous behaviour) whenever the
        # hover figure is not the one shown.
        shov = getattr(sankeyWidget, 'shov', None)
        showing_hover = shov is not None and sankeyWidget.spane.object is shov.fig
        shown_df_nf = shov.df_nf if showing_hover else sankeyWidget.prep.df_nf
        await asyncio.sleep(.75) # .75
        if sankeyWidget.spane.hover_data!=event.new:
            return # didn't wait long enough
        hover_node_id = shown_df_nf.iloc[event.new['points'][0]['pointNumber']].node_id
        sankeyWidget.spane.object = sankeyWidget.hoverView(hover_node_id).fig
        sankeyWidget.fig_state='hover'
    sankeyWidget.spane.param.watch(hover_callback,'hover_data')

def xableWidgets(sankeyWidget,en_or_dis=None):
    match en_or_dis:
        case None: return xableWidgets(sankeyWidget,'en' if sankeyWidget.resetButton.disabled else 'dis')
        case 'dis':
            for w in sankeyWidget.mainViewPane.select(pn.widgets.Widget): w.param.update(disabled=True)
        case 'en':
            for w in sankeyWidget.mainViewPane.select(pn.widgets.Widget): w.param.update(disabled=False)
        case _: raise

# Re-running this cell: drop the previous dashboard instance (if any) first.
try:
    del sw
except NameError:  # first run: nothing to delete (previously a bare except hid any error)
    pass
fs_default=\
gen_filter_set(
    timing_filter='post',
    position_filter=h.pos_list_O,
    starter_filter=['starter'],
    user_filter=6, #FatBoyE
    week_filter=[5],
)
pn.extension('plotly','tabulator') # throttled=True
ss.load_stylesheet()
sw=SankeyWidget(**h.get_filter2param(fs=fs_default))()
# pn.io.notebook.load_notebook()
# sw.mainView().servable()

In [33]:
# Toggle enable/disable of every dashboard widget (uncomment to use):
# xableWidgets(sw)
# Display the assembled dashboard layout as this cell's output.
sw.mainViewPane

# testing