In [122]:
import pandas as pd
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import load_iris
from bokeh.plotting import figure
from bokeh.transform import factor_cmap
from bokeh.io import curdoc
from bokeh.io import show
from bokeh.layouts import gridplot, column, row
from bokeh.models import (Slider, ColumnDataSource, Select,)

## Prototype in this file; it will be modularized into other files/classes

## Data

In [123]:
# Load the iris dataset and wrap it in a DataFrame: one column per feature,
# followed by the dataset's target values in a trailing 'species' column.
iris_bunch = load_iris()
df = pd.DataFrame(
    data=iris_bunch.data,
    columns=iris_bunch.feature_names,
).assign(species=iris_bunch.target)

In [132]:
# NOTE: the saved output below already contains the 'cluster' column, which is
# only added by the clustering cell further down. This cell was executed after
# that one (see its execution count), so a fresh top-to-bottom run will show
# the preview without 'cluster'.
df.head()

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),species,cluster
0,5.1,3.5,1.4,0.2,0,1
1,4.9,3.0,1.4,0.2,0,1
2,4.7,3.2,1.3,0.2,0,1
3,4.6,3.1,1.5,0.2,0,1
4,5.0,3.6,1.4,0.2,0,1


In [124]:
# Cluster on the measurement columns only, selected by name. Slicing by
# position (df.columns[:-1]) silently pulls the 'species' label into the
# feature matrix once 'cluster' has been appended, i.e. whenever this cell
# is re-run.
X = df[list(iris_bunch.feature_names)]

# Bandwidth for the mean-shift kernel, estimated from the data itself.
bandwidth = estimate_bandwidth(X, quantile=0.2)
mean_shift = MeanShift(bandwidth=bandwidth, bin_seeding=True)

# Store the labels as strings: factor_cmap treats them as categorical factors.
df['cluster'] = mean_shift.fit_predict(X).astype(str)

In [125]:
# Shared data source for every plot in the scatter-plot matrix.
dataset_source = ColumnDataSource(df)

# Plot every column except the two label columns. They are removed by name so
# the selection does not depend on column order; a missing label column raises
# a KeyError here instead of silently dropping a feature.
numeric_columns = df.columns.drop(['species', 'cluster'])

# Distinct cluster labels, used as the factors for the colour mapping.
clusters = df['cluster'].unique().astype(str)

## Frontend

In [126]:
# Palette used to colour points by cluster: blue, orange, green.
colors = [
    '#1f77b4',
    '#ff7f0e',
    '#2ca02c',
]

# Apply the 'light_minimal' theme to the current Bokeh document.
curdoc().theme = 'light_minimal'

In [127]:
def create_splom(source):
    """Build a scatter-plot matrix (SPLOM) over ``numeric_columns``.

    One scatter plot is created for every (row, column) pair of
    ``numeric_columns``. Plots in the same grid column share one x range and
    plots in the same grid row share one y range, so panning or zooming a
    plot moves its whole row/column with it. Every plot draws from the same
    ``source``.

    Points are coloured by the ``'cluster'`` column. The colour factors are
    read from ``source`` itself rather than from a module-level snapshot, so
    the mapping always matches the labels actually being plotted, including
    after re-clustering. The ``colors`` palette is cycled when there are more
    clusters than colours, so every factor is paired with a palette entry.

    Args:
        source: ColumnDataSource holding every column in ``numeric_columns``
            plus a ``'cluster'`` column of string labels.

    Returns:
        The gridplot layout containing all plots, ``len(numeric_columns)``
        per row.
    """
    columns = list(numeric_columns)
    n_cols = len(columns)
    plot_size = 250

    # Distinct labels in order of first appearance.
    factors = list(dict.fromkeys(str(label) for label in source.data['cluster']))
    # Repeat the base colours as needed so the palette is never shorter than
    # the factor list.
    palette = [colors[k % len(colors)] for k in range(len(factors))]
    # Built once and reused for both fill and outline of every glyph.
    cluster_color = factor_cmap('cluster', palette=palette, factors=factors)

    # First range object seen for each column name; later plots reuse it.
    x_ranges = {}
    y_ranges = {}

    scatter_plots = []
    for i, y_col in enumerate(columns):
        for j, x_col in enumerate(columns):
            p = figure(width=plot_size, height=plot_size,
                       x_axis_label=x_col, y_axis_label=y_col,
                       tools="pan,wheel_zoom,box_select,lasso_select,reset")

            # Link axes: the first plot for a column donates its range and
            # every later plot for that column adopts it.
            p.x_range = x_ranges.setdefault(x_col, p.x_range)
            p.y_range = y_ranges.setdefault(y_col, p.y_range)

            p.scatter(source=source, x=x_col, y=y_col, fill_alpha=0.6, size=6,
                      fill_color=cluster_color,
                      line_color=cluster_color,
                      selection_color="red",
                      nonselection_fill_alpha=0.1,
                      nonselection_line_alpha=0.1)

            # Only the left-most grid column keeps its y axis ...
            if j > 0:
                p.yaxis.axis_label = ""
                p.yaxis.visible = False

            # ... and only the bottom grid row keeps its x axis.
            if i < n_cols - 1:
                p.xaxis.axis_label = ""
                p.xaxis.visible = False

            scatter_plots.append(p)

    return gridplot(scatter_plots, ncols=n_cols)

In [128]:
scatterplot_matrix = create_splom(dataset_source)

# Widgets for bandwidth selection: a dropdown to choose between the estimated
# and a custom bandwidth, and a slider for the custom value. The slider starts
# hidden because the dropdown defaults to "Estimate".
select = Select(
    title="Bandwidth Option",
    value="Estimate",
    options=["Estimate", "Custom"],
)
slider = Slider(
    start=0.1,
    end=2.0,
    value=1.0,
    step=0.1,
    title="Custom Bandwidth",
    visible=False,
)

In [129]:
def update_clusters(attr, old, new):
    """Re-run mean-shift clustering and refresh the scatter-plot matrix.

    Used as a Bokeh ``on_change`` callback. The bandwidth is either
    re-estimated from the data ("Estimate") or taken from the slider
    (any other option). The new labels are written to ``df['cluster']``,
    pushed into ``dataset_source``, and the matrix in ``layout`` is rebuilt
    so its colour mapping matches the new labels.

    Args:
        attr: Name of the changed property (unused).
        old: Previous property value (unused).
        new: New property value (unused); the current widget values are read
            directly from ``select`` and ``slider``.
    """
    global clusters

    features = df[numeric_columns]
    if select.value == "Estimate":
        bandwidth = estimate_bandwidth(features, quantile=0.2)
    else:
        bandwidth = slider.value

    # Perform clustering; labels are stored as strings for factor_cmap.
    mean_shift = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    df['cluster'] = mean_shift.fit_predict(features).astype(str)

    # Keep the module-level factor list in step with the new labels so the
    # rebuilt plots do not colour by a stale set of clusters.
    clusters = df['cluster'].unique().astype(str)

    # Update ColumnDataSource
    dataset_source.data = dict(ColumnDataSource(df).data)

    # The layout is column(select, slider, scatterplot_matrix): the matrix is
    # the LAST child. Index 1 is the slider, so assigning there would replace
    # the slider and leave the outdated matrix on the page.
    layout.children[-1] = create_splom(dataset_source)

In [130]:
def select_bandwidth(attr, old, new):
    """Show or hide the custom-bandwidth slider, then re-cluster.

    Callback for the bandwidth dropdown: the slider is visible only while
    "Custom" is selected. The callback arguments are forwarded unchanged to
    ``update_clusters``.
    """
    slider.visible = (select.value == "Custom")
    update_clusters(attr, old, new)

In [131]:
# Assemble the page: both widgets stacked above the scatter-plot matrix.
layout = column(select, slider, scatterplot_matrix)

# Re-cluster whenever the bandwidth mode or the custom bandwidth changes.
select.on_change('value', select_bandwidth)
slider.on_change('value', update_clusters)

# Attach the layout to the current document.
curdoc().add_root(layout)