# COFFI

A visual analytics system for exploring black-box classifiers, focusing on decision boundaries and counterfactuals.

To start the app, run:
```
panel serve --show embedding_tool2.1.ipynb --port 8080 --dev
```

## Set style parameters (custom CSS)

In [None]:
# print events
from __future__ import print_function

# Custom CSS handed to pn.extension(raw_css=[css]) further below.
# CSS has no `//` line comments (a `//` line is parsed as an invalid
# declaration and dropped), so disabled rules are written as /* ... */.
css = '''
.bk.header-box {
  background: #f0f0f0;
  /* border-radius: 5px; */
}
  
.bk.header {
  top: 0px !important;
}

.bk.featureList {
    /* background: #ffdd00; */
    
}

.history-box {
  overflow: auto;
}

.bk.panel-test-box {
  border-bottom: 1px #f0f0f0 solid;
}

'''


## Load necessary libraries

In [None]:
# third-party libraries: numerics, machine learning, plotting and GUI
import numpy as np
import pandas as pd

from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.inspection import permutation_importance
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import BallTree
from sklearn.svm import LinearSVC

import seaborn as sns
import time
from matplotlib.colors import LinearSegmentedColormap, ListedColormap

from bokeh.plotting import figure
from bokeh.palettes import Category10, PiYG, linear_palette
from bokeh.models import ColumnDataSource, TapTool, CDSView, IndexFilter, BooleanFilter, CustomJS
from bokeh.transform import linear_cmap
from bokeh.events import Event, Tap, LODEnd
from bokeh.models.widgets.tables import NumberFormatter, DataTable
from bokeh.models.widgets import Div
from bokeh.embed import file_html, json_item
from bokeh.io.export import get_screenshot_as_png

import param
import panel as pn
from panel.io import save

import json

# Initialise Panel's front-end extension and inject the custom CSS string `css`
# defined in the previous cell.
pn.extension(raw_css=[css])


# local files
from embedding_util2_1 import load_dataset, partial_dependence, plot_horizons, \
            embedding_view, update_embedding, update_horizons, update_table,   \
            update_df, rgb_to_hex, Shifter, SVD, update_horizon_lines,         \
            non_linear_view, update_non_linear_view, colorbar_view

In [None]:
class Params():
    """Layout sizes, model defaults and colour palettes of the COFFI app.

    The class attributes hold static settings. ``__init__`` builds the
    per-class colour palettes and the matching matplotlib colormaps.
    """
    # layout sizes in pixels
    total_width = 1400
    total_height = 500
    bot_height = 300
    header_height = 32
    emb_width = 500
    fea_width = 160
    tbl_width = 100
    hor_width = total_width-(emb_width+fea_width+tbl_width)  # width left for the PDP horizons
    # sampling resolutions, presumably for the embedding background and the
    # PDP curves (consumed by embedding_util2_1) -- TODO confirm
    red_res = 500
    pdp_res = 50
    dot_size = 6

    # start-up choices
    dataset_name = 'diabetes'
    neighbors = 100
    classifier = 'RandomForest'

    # colour state; replaced per instance in __init__
    palette = []
    cm = None
    num_colors = 8
    sat_palette = []
    sat_cm = None

    def __init__(self):
        """Build the per-class palettes and colormaps."""
        # saturated colour per class (used e.g. for outlines)
        sat_color_list = ['#d95f02','#1b9e77','#7570b3','#e6ab02','#66a61e','#e7298a','#a6761d','#666666']
        # lighter base colour per class; each one is expanded into 5 tints below
        color_list = ['#fc8d62','#66c2a5','#8da0cb','#ffd92f','#a6d854','#e78ac3','#e5c494','#b3b3b3']
        # Keep num_colors equal to the number of classes actually in the
        # palette; colour values are divided by it before indexing `cm`.
        self.num_colors = len(color_list)
        # Fresh list per instance. `self.palette += ...` on the inherited
        # class-level list would extend that shared list in place, so every
        # further Params() would append another copy of the palette.
        self.palette = []
        for base in color_list:
            tints = sns.light_palette(base, n_colors=6)
            # drop the lightest tint, keep 5 per class
            self.palette += [rgb_to_hex(c) for c in tints[1:]]
        self.sat_palette = sat_color_list
        self.cm = ListedColormap(colors=self.palette)
        self.sat_cm = ListedColormap(colors=self.sat_palette)

params = Params()

## Data Management

In [None]:
class Dataset():
    """The loaded dataset together with its Bokeh data source, filters and view.

    The attributes are declared on the class and (re)assigned by ``update``.
    The Bokeh objects ``datasource``, ``view_filter`` and ``wrong_filter``
    are created once at class level and reused across dataset changes.
    """
    name = ''
    data = None
    datasource = ColumnDataSource()
    view_filter = IndexFilter()
    wrong_filter = IndexFilter()
    view = CDSView()
    labels = None
    features = []
    selected_features = []
    classes = None
    categories = None
    trees = []          # per class: (row positions, search tree); filled by the predictor

    bounds = []         # [per-feature minimum, per-feature maximum]
    point = []          # focused point as a pandas Series; the feature mean after update()

    def update(self, name):
        """Load dataset `name` via ``load_dataset`` and reset all derived state.

        Adds placeholder display columns (colours, size, coordinates), pushes
        the frame into ``datasource``, makes every row visible in both filters,
        and resets ``bounds`` (feature min/max) and the focus ``point``
        (feature mean).
        """
        self.name = name
        (self.data, self.labels, self.features, self.classes, _,
         self.categories) = load_dataset(self.name)
        self.selected_features = self.features

        n_rows = len(self.data.index)
        for column in ('color', 'sat_color', 'line_color'):
            self.data[column] = ['#444444'] * n_rows
        self.data['size'] = [params.dot_size] * n_rows
        for column in ('x', 'y', 'x1', 'y1'):
            self.data[column] = [0] * n_rows

        # Select one row and then clear the selection. This is presumably so
        # that onSelectionChange, which restores a multi-point selection that
        # is cleared to [], lets the reset through.
        self.datasource.selected.indices = [0]
        self.datasource.selected.indices = []
        self.datasource.data = self.data
        self.view_filter.indices = list(self.data.index)
        self.wrong_filter.indices = list(self.data.index)
        self.view = CDSView(source=self.datasource, filters=[self.view_filter])

        # Read from `self`, not the module-level `dataset`, so that update()
        # is correct for any Dataset instance.
        feature_data = self.data[self.features]
        self.bounds = [feature_data.min().to_numpy(), feature_data.max().to_numpy()]
        self.point = feature_data.mean()
        print("--- data loaded", self.data.shape)
    
dataset = Dataset()
dataset.update(params.dataset_name)

class Point():
    """Three Bokeh sources that together draw one point of interest.

    `point` holds the marker itself. `line` and `text` hold its per-feature
    marker lines and value labels, as filled by ``update_line_source``.
    """

    def __init__(self):
        self.point, self.line, self.text = (ColumnDataSource({}) for _ in range(3))

    def reset(self):
        """Empty all three sources."""
        for source in (self.point, self.line, self.text):
            source.data = {}

# cf: counterfactual placed by tapping the views; pt: the focused point
cf, pt = Point(), Point()

class Predictor():
    """Owns the classifier pipeline and writes its predictions into a Dataset."""
    pip = None

    def update_pip(self, classifier_type):
        """(Re)create the unfitted pipeline (StandardScaler + classifier).

        Supported types are "RandomForest" and "NeuralNetwork". For any other
        value a message is printed and the current pipeline is left unchanged.
        """
        if classifier_type == "RandomForest":
            classifier = RandomForestClassifier(
                random_state=1, min_samples_split=4, min_samples_leaf=3)
        elif classifier_type == "NeuralNetwork":
            classifier = MLPClassifier(hidden_layer_sizes=(100, 100, 100),
                                       max_iter=5000, random_state=1)
        else:
            print("unkown classifier type:", classifier_type)
            return
        self.pip = Pipeline([('scaler', StandardScaler()), ('classificator', classifier)])

    def update_data(self, dataset, params):
        """Fit the pipeline on `dataset` and store predictions and derived state.

        The method writes the following:
        - columns `prob`, `maxprob`, `most_prob_class`, `color`, `sat_color`
          and `line_color` in ``dataset.data``, mirrored into the data source;
        - `target_color` (the true label's colour) in the data source;
        - ``dataset.wrong_filter``, set to the misclassified rows;
        - ``dataset.trees``, one ``(row positions, BallTree)`` pair per class,
          built over the scaled features of the rows predicted as that class.
          The tree is ``None`` when no row is predicted as that class.
        """
        X = dataset.data[dataset.features]
        self.pip.fit(X, dataset.labels)

        # predictions, stored in the dataset and mirrored into the data source
        probs = self.pip.predict_proba(X)
        dataset.data['prob'] = probs.tolist()
        dataset.data['maxprob'] = probs.max(axis=1)
        dataset.datasource.data['prob'] = dataset.data['prob']
        dataset.datasource.data['maxprob'] = dataset.data['maxprob']

        # Colours: value = class index + certainty in (0, 1). Each class owns
        # 1/num_colors of the colormap, and certainty picks the tint inside it.
        most_prob = probs.argmax(axis=1)
        dataset.data['most_prob_class'] = most_prob
        colorvalues = (np.clip(dataset.data['maxprob'].to_numpy(), 0.51, 0.99) - 0.5) * 2 + most_prob
        dataset.data['color'] = [rgb_to_hex(params.cm(v / params.num_colors)) for v in colorvalues]
        dataset.data['sat_color'] = [params.sat_palette[c] for c in most_prob]
        dataset.data['line_color'] = dataset.data['sat_color'].copy()
        dataset.datasource.data['color'] = dataset.data['color']
        dataset.datasource.data['sat_color'] = dataset.data['sat_color']
        dataset.datasource.data['line_color'] = dataset.data['sat_color'].copy()

        # wrongly predicted rows (positions)
        dataset.wrong_filter.indices = np.where(np.not_equal(most_prob, dataset.labels))[0].tolist()
        dataset.datasource.data['target_color'] = [params.sat_palette[i] for i in dataset.labels]

        # nearest-neighbour search structure per predicted class
        scaler = self.pip['scaler']
        dataset.trees = []
        for i in range(len(dataset.classes)):
            index_map = np.where(most_prob == i)[0]
            if len(index_map) == 0:
                # No row is predicted as class i. StandardScaler.transform and
                # BallTree both reject 0-sample input, so store no tree.
                dataset.trees.append((index_map, None))
                continue
            # index_map holds row *positions*, hence iloc
            class_data = X.iloc[index_map]
            dataset.trees.append((index_map, BallTree(scaler.transform(class_data))))

predictor = Predictor()
predictor.update_pip(params.classifier)
predictor.update_data(dataset, params)

# class LL():
#     svm = None
#     svd = None
#     components_ = None
#     V = np.zeros(0)
    
#     def __init__(self):
#         self.svm = LinearSVC(random_state=0)
#         self.svd = SVD()
    
#     def fit(self, X, y, sample_weight=None):
#         self.svm.fit(X, y)
#         self.svd.fit(X)
#         self.components_ = np.append(self.svm.coef_/np.linalg.norm(self.svm.coef_), [[0.01]*len(self.svm.coef_[0])], axis=0)
# #                                      [self.svd.components_[0]], axis=0)
        
#     def transform(self, X):
#         X = np.dot(X, self.components_.T)
#         return X
    
#     def inverse_transform(self, X):
#         X = np.dot(X, self.components_)
#         return X
        
    
class Embedding():
    """Linear 2-D projection of the dataset shown in the embedding view.

    `emb` chains the project's ``Shifter``, a ``StandardScaler`` and the
    project's ``SVD``. `zoom` and `focus` are the labels shown on the
    "Embed"/"Shift" buttons. `nn_ind` lists the rows the projection was last
    fitted on.
    """
    emb = Pipeline([('shifter', Shifter()), ('scaler', StandardScaler()), ('pca', SVD())])
    zoom = "Global"
    focus = "Mean"
    nn_ind = []

    def compute(self, dataset):
        """Fit the projection and store the 2-D coordinates in `dataset`.

        The pipeline is fitted on the current selection when it holds more
        than one row, otherwise on every row, always restricted to
        ``dataset.selected_features``. All rows are then projected into the
        `x`/`y` columns of ``dataset.data`` and ``dataset.datasource``.

        When a concrete focus point is set (its Series has a name), the shifter
        is told to shift by that point's reconstruction residual, so that the
        plane is moved to the focus point.
        """
        selection = dataset.datasource.selected.indices
        if len(selection) > 1:
            self.nn_ind = selection
        else:
            self.nn_ind = dataset.data.index.tolist()
        nn = len(self.nn_ind)

        features = dataset.selected_features
        self.emb.fit(dataset.data.loc[self.nn_ind][features])
        coords = self.emb.transform(dataset.data[features])

        dataset.data['x'] = coords[:, 0]
        dataset.data['y'] = coords[:, 1]
        dataset.datasource.data['x'] = coords[:, 0]
        dataset.datasource.data['y'] = coords[:, 1]

        # min/max bounds restricted to the selected features, in dataset.features order
        selected = set(features)
        shift_bounds = [[bound[i] for i, f in enumerate(dataset.features) if f in selected]
                        for bound in (dataset.bounds[0], dataset.bounds[1])]
        self.emb['shifter'].set_bounds(shift_bounds)

        focus_name = dataset.point.name
        if focus_name is not None:
            # shift plane to focus point
            point = dataset.point[features]
            reconstructed = self.emb.inverse_transform(self.emb.transform([point]))[0]
            self.emb['shifter'].set_by(point.values.astype('float64') - reconstructed)

        self.focus = "Mean" if focus_name is None else str(focus_name)
        self.zoom = str(nn) + " samples"


embedding = Embedding()
embedding.compute(dataset)

## Input widgets

In [None]:
# GUI parameters & widgets
# Header controls: dataset, classifier and neighbourhood size. The neighbour
# count is typed as text and parsed with int() in on_select_nn.
dataset_selector = pn.widgets.Select(name='', options=["iris","income","diabetes","breast","heart-failure","shuttle","robot24","robot4"], value=params.dataset_name, width=100)
predictor_selector = pn.widgets.Select(name='', value=params.classifier, options=["RandomForest","NeuralNetwork"], width=120)
neighbors = pn.widgets.TextInput(name='', placeholder=str(params.neighbors), value=str(params.neighbors), width=50)
# Embedding-view buttons; the Embed/Shift labels reflect embedding.zoom / embedding.focus.
zoom = pn.widgets.Button(name='Embed: '+embedding.zoom, width=120)
focus = pn.widgets.Button(name='Shift: '+embedding.focus, width=90)
select_features = pn.widgets.Button(name='Filter Features', width=90)
reset = pn.widgets.Button(name='Reset', width=60)
average = pn.widgets.Toggle(name='Average', width=60, value=False)
# Controls of the non-linear (topology) view.
nl_select = pn.widgets.Select(name='', value='UMAP', options=['UMAP', 't-SNE'], width=180)
nl_neighbors = pn.widgets.IntSlider(name='Neighborhood size', start=5, end=200, step=5, value=20, width=180)
nl_feats = pn.widgets.Checkbox(name='use only selected features', value=False, width=180, height=30)
# One checkbox per feature; all features are checked initially.
feature_selector = pn.widgets.CheckBoxGroup(name='', value=dataset.features, options=dataset.features, 
                                            inline=False, width=120, css_classes=["featureList"], 
                                            max_height=params.total_height-10)
# Inline CSS: label line height derived from total_height and the number of
# features (the same snippet is rebuilt in onDatasetChanged).
css_code = pn.pane.HTML(f'''
            <style>
            .bk label {{
                line-height: {params.total_height / (13*len(feature_selector.options))} !important
            }}
            </style>
            ''')
# Header text pane; stays empty because the legend code below is commented out.
emb_header = pn.pane.Markdown("", style={
        'width': "700px", 'height': str(params.header_height)+"px", "text-align": "left"})
#for i, c in enumerate(dataset.classes):
#    emb_header.object += "![color](https://placehold.it/15/" + params.palette[i*5+4][1:] \
#        + "/000000?text=+) " + c + "  " 
    
# projection view (also draws the counterfactual source cf.point)
detail_p = embedding_view(params, dataset, embedding, predictor.pip, average.value, cf.point)

# non-linear view
nonlinear = non_linear_view(params, dataset, predictor.pip)

# importance table (filled by update_table)
table_source = ColumnDataSource()
table = DataTable(source=table_source, editable=False, index_position=None, 
                  height=params.total_height+17, width=params.tbl_width)
update_table(table, table_source, dataset, predictor.pip, embedding, params)

# PDP horizons, drawn with the focused point (pt) and counterfactual (cf) sources
horizons = plot_horizons(predictor.pip.predict_proba, dataset, params, pt, cf)

def update_line_source(src, point, opposite):
    """Fill ``src.line`` and ``src.text`` with one marker per feature.

    For every feature in ``dataset.features``, ``src.line`` gets a two-point
    line with x fixed at the point's value and y = [0.0, 0.25].
    ``src.text`` gets the value at y = 0.25 with a one-decimal label.

    A label is right-aligned with x-offset -2 when the value is below the
    corresponding value of `opposite`. Otherwise it is left-aligned with
    offset +2.
    """
    names = dataset.features

    line_data = {}
    for name, value in zip(names, point):
        line_data[name] = [value, value]
    line_data["y"] = [0.0, 0.25]
    src.line.data = line_data

    text_data = {}
    for name, value in zip(names, point):
        text_data[name] = [value]
    text_data["y"] = [0.25]
    for name, value in zip(names, point):
        text_data[name + "_label"] = ['{:3.1f}'.format(value)]
    below = [(name, p < o) for name, p, o in zip(names, point, opposite)]
    for name, is_below in below:
        text_data[name + "_align"] = ['right'] if is_below else ['left']
    for name, is_below in below:
        text_data[name + "_x_offset"] = [-2] if is_below else [2]
    src.text.data = text_data

# Initial horizon markers: the focused point (the feature mean right after
# Dataset.update) drawn with itself as reference; no counterfactual line yet.
update_line_source(pt, dataset.point[dataset.features], dataset.point[dataset.features])
cf.line.data = {}
# update_horizon_lines(horizons, dataset.features, dataset.point, name="point")

# history panel; add_history would prepend snapshots (it currently returns immediately)
history = pn.Row(width=params.emb_width, max_width=params.emb_width+30, 
                 height=params.bot_height, max_height=params.bot_height, 
                 css_classes=['history-box'])

# data table over the shared data source, filtered through dataset.view
df_widget = DataTable(source=dataset.datasource, autosize_mode='fit_columns', 
                      height=params.bot_height+20, width=params.total_width-params.emb_width+10,
                      view=dataset.view)
update_df(df_widget, dataset, params)

# color bar
colorbar = colorbar_view(params, dataset)

## Callbacks

In [None]:
#-----------------------------------------------------------------------------------------------
# when choosing new dataset from dropdown

def onDatasetChanged(event):
    """Load the dataset chosen in `dataset_selector` and rebuild every view.

    Registered with ``dataset_selector.param.watch``. Only events whose
    ``type`` is 'changed' are handled.
    """
    import types

    if event.type != 'changed':
        return
    dataset.update(event.new)

    # reset counterfactual lines & point
    cf.reset()

    # Re-train the predictor and refresh the views that depend on it.
    # onPredictorChanged only reads `event.type`, so a plain namespace is
    # enough. A bare bokeh.events.Event() is avoided because its constructor
    # expects a `model` argument (at least in Bokeh 2.x).
    onPredictorChanged(types.SimpleNamespace(type="changed"))

    # update non-linear embedding
    update_non_linear_view(nonlinear.object, dataset, predictor.pip, nl_select.value, nl_neighbors.value)

    # update feature selector and the label line height that depends on its size
    feature_selector.options = dataset.features
    feature_selector.value = dataset.features
    label_height = params.total_height / (13*len(feature_selector.options))
    css_code.object = f'''
            <style>
            .bk label {{
                line-height: {label_height} !important
            }}
            </style>
            '''

    # update table
    update_table(table, table_source, dataset, predictor.pip, embedding, params)

    # update colorbar
    colorbar.object = colorbar_view(params, dataset).object

    # Select every row, then clear the selection.
    # NOTE(review): onSelectionChange restores a multi-point selection that is
    # cleared to [], so this pair may end with all rows selected -- confirm intent.
    dataset.datasource.selected.indices = list(range(dataset.data.shape[0]))
    dataset.datasource.selected.indices = []

    history.clear()
        
dataset_selector.param.watch(onDatasetChanged, 'value')

#-----------------------------------------------------------------------------------------------
# when choosing new predictor from dropdown

def onPredictorChanged(event):
    """Retrain with the classifier chosen in `predictor_selector` and refresh the views.

    Registered with ``predictor_selector.param.watch`` and also called
    directly by onDatasetChanged. Only the event's ``type`` is read, and only
    'changed' events are handled.
    """
    if event.type != 'changed':
        return

    # new pipeline, trained on the current dataset
    predictor.update_pip(predictor_selector.value)
    predictor.update_data(dataset, params)
    update_df(df_widget, dataset, params)

    # re-project and redraw
    embedding.compute(dataset)
    update_embedding(detail_p.object, params, dataset, embedding, predictor.pip, average.value)
    update_table(table, table_source, dataset, predictor.pip, embedding, params)

    # rebuild the horizons; label sides are chosen relative to the counterfactual when one exists
    horizons.objects = plot_horizons(predictor.pip.predict_proba, dataset, params, pt, cf).objects
    focus_values = dataset.point[dataset.features]
    if cf.point.data:
        reference = [cf.point.data[f][0] for f in dataset.features]
    else:
        reference = focus_values
    update_line_source(pt, focus_values, reference)

    update_horizons(horizons, dataset, params)
        
predictor_selector.param.watch(onPredictorChanged, 'value')

#-----------------------------------------------------------------------------------------------
# tapping in the embedding view

def onTap(event):
    """Place the counterfactual at the tapped location and refresh the horizon lines.

    The tap position (event.x, event.y) is mapped back to feature space with
    the embedding's inverse transform. Features not used by the embedding
    take the focused point's values. If the tap also selected a data row,
    that row becomes the focused point and is highlighted.
    """
    # compute tapped point through inverse
    cf_pt = embedding.emb.inverse_transform([[event.x, event.y]])[0]
    # re-insert the features excluded from the embedding, keeping dataset.features order
    for i, f in enumerate(dataset.features):
        if f not in dataset.selected_features:
            cf_pt = np.insert(cf_pt, i, dataset.point[f])
    # update datasource accordingly
    update_line_source(cf, cf_pt, dataset.point[dataset.features])
    cf.point.data = {key: [val] for (key, val) in zip(dataset.features + ["x", "y"], np.append(cf_pt, [event.x, event.y]))}

    # tapped on data point
    if len(dataset.datasource.selected.indices) != 0:
        p = dataset.datasource.selected.indices[0]
        dataset.point = dataset.data.loc[p].T
        # highlight tapped point in size & border (take first one in selection)
        new_size = [params.dot_size] * len(dataset.data)
        new_size[p] = 13
        dataset.datasource.data['size'] = new_size
        # Copy first. dataset.data['sat_color'] is the frame's own column
        # (without pandas copy-on-write), so writing 'black' into it would
        # permanently recolour sat_color. Previously tapped rows would then
        # keep a black outline.
        new_line = dataset.data['sat_color'].copy()
        new_line[p] = 'black'
        dataset.datasource.data['line_color'] = new_line
    # update horizon lines
    cf_temp = [cf.point.data[f][0] for f in dataset.features] if cf.point.data != {} else dataset.point[dataset.features]
    update_line_source(pt, dataset.point[dataset.features], cf_temp)
    
detail_p.object.on_event(Tap, onTap)
nonlinear.object.on_event(Tap, onTap)

#-----------------------------------------------------------------------------------------------
# Selecting one point in the embedding

def onSelectionChange(attr, old, new):
    """Python-side guard on changes of ``selected.indices``.

    Every change is printed. If the selection was emptied, it is restored to
    `old` unless exactly one index was selected before. This presumably stops
    a multi-point selection from being lost to a stray empty selection.
    """
    print('selection changed', 'old', old, 'new', new, 'indices', dataset.datasource.selected.indices)
    emptied = len(new) == 0
    was_single = len(old) == 1
    if emptied and not was_single:
        dataset.datasource.selected.indices = old
    
dataset.datasource.selected.on_change('indices', onSelectionChange)


#-----------------------------------------------------------------------------------------------
# change the feature selection

def onFeatureSelectionChanged(event):
    """Re-embed with the features checked in `feature_selector`.

    Changes are ignored when fewer than two features are checked or when the
    previous value was empty. Otherwise the selected features are stored in
    ``dataset.features`` order, and the embedding, horizons, table and
    (optionally) the non-linear view are refreshed.
    """
    if len(event.new) < 2 or len(event.old) < 1:
        return
    add_history()

    order = dataset.features
    dataset.selected_features = sorted((f for f in event.new if f in order), key=order.index)

    # recompute embedding and drop the counterfactual
    embedding.compute(dataset)
    cf.reset()

    update_embedding(detail_p.object, params, dataset, embedding, predictor.pip, average.value)
    update_horizons(horizons, dataset, params)
    update_table(table, table_source, dataset, predictor.pip, embedding, params, full=False)

    # the non-linear view only depends on the feature set when restricted to it
    if nl_feats.value:
        change_non_linear()
    
feature_selector.param.watch(onFeatureSelectionChanged, 'value')


#-----------------------------------------------------------------------------------------------
# change the feature selection

def onSelectImportantFeatures(event=None):
    """Check only the features whose permutation importance is above average.

    A feature is kept when its "Permutation Importance" entry in the table
    exceeds 100 / n_features. That is the share every feature would get if
    importance were spread evenly; this assumes the column holds percentages
    in ``dataset.features`` order (TODO confirm against update_table).
    """
    # np.asarray: a plain list in the column would make `>` raise a TypeError
    importance = np.asarray(table_source.data["Permutation Importance"])
    important_indices = np.where(importance > 100.0 / len(dataset.features))[0]
    important = np.array(dataset.features)[important_indices].tolist()
    # onFeatureSelectionChanged ignores selections with fewer than two
    # features, so applying one would only leave the checkboxes out of sync
    # with the embedding.
    if len(important) < 2:
        return
    feature_selector.value = important

select_features.on_click(onSelectImportantFeatures)


#------------------------------------------------------------------------------------------------
# add current embedding & parameters to history

def add_history(flip_avg=False):
    """Snapshot the embedding view into the history panel -- currently disabled.

    The bare ``return`` below makes this function a no-op. The unreachable
    code after it would render ``detail_p`` with ``get_screenshot_as_png``
    and insert the image at the front of ``history``, captioned with the
    current "Embed"/"Shift" labels.

    Args:
        flip_avg: only referenced by the commented-out caption line; it
            currently has no effect.
    """
    # NOTE(review): get_screenshot_as_png presumably needs a headless browser /
    # webdriver, which may be why snapshots are disabled -- confirm.
    return
    im = get_screenshot_as_png(detail_p.object)
    history.insert(0, pn.Column(pn.pane.PNG(im, width=int(params.total_height/2.5)),
                                pn.pane.Markdown("Embed: "+embedding.zoom \
                                                 +"<br />Shift: "+embedding.focus))) \
#                                                  +"<br />Plain: "+("Averaged" if average.value != flip_avg else "Exact"))))
    
#-----------------------------------------------------------------------------------------------
# Shift base of embedding to last selected point
def onShift(event=None, add_hist=True):
    """Re-project around the current focus point ("Shift" button) and refresh the views.

    Args:
        event: the button event (unused).
        add_hist: whether to record a history snapshot first (reset_view passes False).
    """
    if add_hist:
        add_history()

    # project data and update the button labels
    embedding.compute(dataset)
    zoom.name = f"Embed: {embedding.zoom}"
    focus.name = f"Shift: {embedding.focus}"
    cf.reset()

    update_embedding(detail_p.object, params, dataset, embedding, predictor.pip, average.value)
    update_table(table, table_source, dataset, predictor.pip, embedding, params, full=False)

    # horizons: focused point drawn with itself as reference
    anchor = dataset.point[dataset.features]
    update_line_source(pt, anchor, anchor)
    horizons.objects = plot_horizons(predictor.pip.predict_proba, dataset, params, pt, cf).objects
    
focus.on_click(onShift)


#-----------------------------------------------------------------------------------------------
# Rotate embedding according to PCA of nearest neighbors (same class + other class) of last selected point

def zoom_view(event):
    """Re-fit the embedding around the current selection ("Embed" button).

    - Nothing selected: focus the mean of all rows.
    - One or two rows selected: focus the first selected row, highlight it
      and replace the selection by its neighbourhood (on_select_nn). That
      neighbourhood then becomes the projection's training set.
    - Larger selection: focus the selection's mean, unless a concrete point
      is already focused; then that focus is kept.

    Afterwards the horizons, the projection, the button labels, the
    embedding view and the importance table are refreshed.
    """
    add_history()

    selection = dataset.datasource.selected.indices
    if len(selection) == 0:
        # zoom without selection -> reset zoom / zoom all
        dataset.point = dataset.data[dataset.features].mean()
    elif len(selection) < 3:
        # Focus the picked row *before* searching its neighbourhood.
        # on_select_nn searches around dataset.point and afterwards the
        # selection holds the neighbours rather than the picked row.
        p = selection[0]
        dataset.point = dataset.data.loc[p].T
        on_select_nn()
        # highlight the focused row in size & border
        new_size = [params.dot_size] * len(dataset.data)
        new_size[p] = 13
        dataset.datasource.data['size'] = new_size
        # copy, so 'black' is not written into the sat_color column itself
        new_line = dataset.data['sat_color'].copy()
        new_line[p] = 'black'
        dataset.datasource.data['line_color'] = new_line
    elif dataset.point.name is None:
        # group selected, no point focused -> focus the group's average
        dataset.point = dataset.data.loc[selection][dataset.features].mean()
    # else: keep the focused point and embed it relative to the selected group

    # update horizons
    update_line_source(pt, dataset.point[dataset.features], dataset.point[dataset.features])
    horizons.objects = plot_horizons(predictor.pip.predict_proba, dataset, params, pt, cf).objects

    # project data
    embedding.compute(dataset)
    zoom.name = "Embed: "+embedding.zoom
    focus.name = "Shift: "+embedding.focus
    cf.reset()

    # update embedding view
    update_embedding(detail_p.object, params, dataset, embedding, predictor.pip, average.value)

    # update table
    update_table(table, table_source, dataset, predictor.pip, embedding, params)
    
zoom.on_click(zoom_view)

#-----------------------------------------------------------------------------------------------
# Switch between average and exact embedding
def switch_average(event):
    """Redraw the embedding view when the 'Average' toggle changes.

    `event.new` (the toggle's new value) is forwarded as the averaging flag.
    """
    add_history(flip_avg=True)

    # Pass the Embedding object, as every other update_embedding call does,
    # not the bare sklearn Pipeline `embedding.emb`.
    update_embedding(detail_p.object, params, dataset, embedding, predictor.pip, event.new)
    
average.param.watch(switch_average, 'value')

#-----------------------------------------------------------------------------------------------
# Reset embedding to default (Global + Mean)

def reset_view(event):
    """Return to the default view: no selection, focus on the data mean, default dot sizes."""
    add_history()
    selected = dataset.datasource.selected
    # Shrink a group selection to one row before clearing it.
    # onSelectionChange restores a multi-point selection that is cleared
    # directly to [].
    if len(selected.indices) > 1:
        selected.indices = selected.indices[:1]
    selected.indices = []
    dataset.point = dataset.data[dataset.features].mean()
    dataset.datasource.data['size'] = [params.dot_size] * len(dataset.data)
    onShift(add_hist=False)
    
reset.on_click(reset_view)

def change_non_linear(event=None):
    """Recompute the non-linear view from the current method, neighbourhood and feature settings."""
    method = nl_select.value
    n_neighbors = nl_neighbors.value
    only_selected = nl_feats.value
    update_non_linear_view(nonlinear.object, dataset, predictor.pip, method, n_neighbors, only_selected)

# any change of the three topology-view controls triggers a recomputation
for _nl_widget in (nl_select, nl_neighbors, nl_feats):
    _nl_widget.param.watch(change_non_linear, 'value')

def resample_emb(event):
    """Redraw the embedding view for the currently visible x/y range (after an LOD interaction ends)."""
    fig = detail_p.object
    visible_range = [fig.x_range.start, fig.x_range.end, fig.y_range.start, fig.y_range.end]
    update_embedding(fig, params, dataset, embedding, predictor.pip, average.value, visible_range)
    
detail_p.object.on_event(LODEnd, resample_emb)

def on_select_nn(event=None):
    """Select the neighbourhood of the focused point (``dataset.point``).

    Half of the requested neighbours (``neighbors`` text input) are the
    nearest rows predicted as the focused point's own class. The other half
    are the nearest rows predicted as any other class, merged across classes
    by distance. Distances are computed on the pipeline scaler's output via
    the per-class trees in ``dataset.trees``. The resulting row positions
    become the data source's selection.
    """
    p = dataset.point
    try:
        num_nn = int(neighbors.value)
    except (TypeError, ValueError):
        # unparsable text in the input box: fall back to the configured default
        num_nn = params.neighbors
    # at least one neighbour per half, so no tree is queried with k=0
    half = max(num_nn // 2, 1)

    features = p[dataset.features]
    if p.name is not None:
        # a data row: reuse its stored class probabilities
        own_class = np.argmax(p['prob'])
    else:
        own_class = np.argmax(predictor.pip.predict_proba([features])[0])
    query = predictor.pip['scaler'].transform([features])

    own_nn = np.array([])
    n_classes = len(dataset.classes)
    other_dist = np.full((n_classes, half), np.inf)
    other_nn = np.full((n_classes, half), -1)
    for i in range(n_classes):
        index_map, tree = dataset.trees[i]
        if tree is None or len(index_map) == 0:
            # no row is predicted as class i -> nothing to query
            continue
        dist, ind = tree.query(query, k=min(half, len(index_map)))
        if i == own_class:
            own_nn = np.append(own_nn, index_map[ind])
        else:
            other_dist[i, :dist.shape[1]] = dist[0]
            other_nn[i, :dist.shape[1]] = index_map[ind[0]]

    # Merge the other classes' candidates by distance. Unused slots (-1 / inf)
    # sort last and are dropped.
    other_nn = other_nn.ravel()[np.argsort(other_dist, axis=None)]
    other_nn = other_nn[other_nn != -1][:half]
    dataset.datasource.selected.indices = np.append(own_nn, other_nn).astype(int).tolist()


# Client-side (JS) link from the selection to the data-table view filter:
# show the selected rows, or fall back to `nn` when nothing is selected.
# NOTE(review): `nn` is the list embedding.nn_ind refers to right now. Later
# Embedding.compute calls rebind nn_ind to a new list, so the JS side keeps
# this initial list -- confirm this is intended.
_filter_args = dict(src=dataset.datasource, filter=dataset.view_filter, nn=embedding.nn_ind)
filter_callback = CustomJS(args=_filter_args, code='''
    if (src.selected.indices.length > 0) {
        filter.indices = src.selected.indices
    } else {
        filter.indices = nn
    }
    src.change.emit()
''')

dataset.datasource.selected.js_on_change('indices', filter_callback)

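TODO: react to clicks in the data table. For the Bokeh widget events and attributes that can trigger callbacks, see:
https://stackoverflow.com/questions/55403853/how-to-get-a-list-of-bokeh-widget-events-and-attributes-which-can-be-used-to-tr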

## Set up the GUI

In [None]:
# Dashboard layout: header bar, main analysis row, and bottom row (tabs + data table).
_header_px = f"{params.header_height}px"

# header bar: title, dataset / classifier / neighbour controls, reset
_header_row = pn.Row(
    pn.pane.Markdown('''## CoFFi''', style={'height': '31px'}, css_classes=['header']),
    pn.Spacer(width=50),
    pn.pane.Markdown("#### Dataset", style={'height': '31px'}),
    dataset_selector,
    pn.pane.Markdown("#### Classifier", style={'height': '31px'}),
    predictor_selector,
    pn.pane.Markdown("#### Neighbors", style={'height': '31px'}),
    neighbors,
    emb_header,
    reset,
    css_classes=['header-box'],
)

# main row: embedding view, importance table, feature list, PDP horizons
_embedding_column = pn.Column(
    pn.Row(pn.pane.Markdown("#### Embedding View", style={'height': _header_px}),
           pn.Spacer(width=2), focus, zoom, pn.Spacer(width=3), select_features),
    detail_p,
)
_table_column = pn.Column(pn.Spacer(height=7), table)
_feature_column = pn.Column(
    pn.pane.Markdown("#### Features", style={'width': f"{params.fea_width}px", 'height': _header_px}),
    feature_selector,
)
_horizon_column = pn.Column(
    pn.pane.Markdown("#### Partial Dependence", style={'width': f"{params.hor_width}px", 'height': _header_px}),
    horizons,
)
_main_row = pn.Row(_embedding_column, _table_column, _feature_column, _horizon_column,
                   css_classes=['panel-test-box'])

# bottom row: topology view / inspection history tabs, data table, dynamic CSS
_topology_view = pn.Row(pn.Column(nl_select, nl_neighbors, nl_feats, pn.Spacer(height=5), colorbar), nonlinear)
_bottom_row = pn.Row(pn.Tabs(('Topology View', _topology_view), ('Inspection History', history)),
                     df_widget, css_code)

coffi = pn.Column(_header_row, _main_row, _bottom_row)

In [None]:
# Register the assembled layout as the app served by `panel serve`.
# The argument is presumably used as the browser page title.
coffi.servable("COFFI")