In [106]:
import numpy as np
import pandas as pd
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import load_iris
from bokeh.layouts import gridplot
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure, show
from bokeh.transform import factor_cmap
from itertools import product
from bokeh.themes import Theme
from bokeh.io import curdoc
from bokeh.io import show
from bokeh.layouts import gridplot
from bokeh.models import (BasicTicker, ColumnDataSource, DataRange1d,
                          Grid, LassoSelectTool, LinearAxis, PanTool,
                          Plot, ResetTool, Scatter, WheelZoomTool)

In [107]:
# Load the Iris measurements into a DataFrame: one column per feature
# (named by iris_bunch.feature_names), plus the integer class label appended
# as the last column, 'species'.
iris_bunch = load_iris()
df = pd.DataFrame(
    data=iris_bunch.data,
    columns=iris_bunch.feature_names,
).assign(species=iris_bunch.target)

In [108]:
# Cluster on the measurement columns only. The features are selected by name
# rather than by position (df.columns[:-1]). After a first run df also has a
# 'cluster' column, so a positional slice would then include the 'species'
# ground-truth labels as a clustering feature. Selecting by name keeps the cell
# idempotent under re-runs.
X = df[df.columns.drop(['species', 'cluster'], errors='ignore')]

# MeanShift needs a kernel bandwidth. estimate_bandwidth derives one from
# nearest-neighbour distances; a larger quantile gives a larger bandwidth and
# therefore fewer clusters.
bandwidth = estimate_bandwidth(X, quantile=0.2)
mean_shift = MeanShift(bandwidth=bandwidth, bin_seeding=True)

# Store the labels as strings so Bokeh's factor_cmap treats them as
# categorical factors.
df['cluster'] = mean_shift.fit_predict(X).astype(str)

In [109]:
# One ColumnDataSource shared by every panel, so selections are linked across the grid.
dataset_source = ColumnDataSource(df)
# Plot every column except the label columns. They are excluded by name, so
# this no longer depends on 'species' and 'cluster' being the last two columns.
numeric_columns = df.columns.drop(['species', 'cluster'])

In [110]:
# Cluster labels are strings (see the clustering cell). Sort them numerically
# so each label always maps to the same colour; plain string order would put
# "10" before "2".
clusters = sorted(df['cluster'].unique(), key=int)

# MeanShift chooses the number of clusters itself, so the palette must have one
# entry per cluster actually found. factor_cmap does not give a palette colour
# to factors beyond the end of the palette (Bokeh falls back to the mapper's
# nan_color), which would make extra clusters indistinguishable.
# BASE_COLORS is Bokeh's Category10 set; its first three entries are the
# original hand-picked colours. The list is cycled only if more than ten
# clusters appear.
BASE_COLORS = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
               '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
colors = [BASE_COLORS[k % len(BASE_COLORS)] for k in range(len(clusters))]

curdoc().theme = 'light_minimal'

In [111]:
# Scatter-plot matrix: row i puts numeric_columns[i] on the y axis and column j
# puts numeric_columns[j] on the x axis. Every panel draws from the same
# dataset_source, so a box/lasso selection in one panel highlights the same
# rows in all of them.
plot_size = 250
y_max = len(numeric_columns) - 1  # index of the bottom row, the only row that keeps its x axis

# A single colour mapping keyed on the 'cluster' column, shared by the fill and
# line colour of every panel instead of being rebuilt twice per panel.
cluster_cmap = factor_cmap('cluster', palette=colors, factors=clusters)

# Linked ranges: the first panel created for a column (row) donates its
# x (y) range, and every later panel in that column (row) reuses it, so
# pan/zoom stay in sync.
x_ranges = {}
y_ranges = {}

scatter_plots = []
for i, y_col in enumerate(numeric_columns):
    for j, x_col in enumerate(numeric_columns):
        p = figure(width=plot_size, height=plot_size, x_axis_label=x_col, y_axis_label=y_col,
                   tools="pan,wheel_zoom,box_select,lasso_select,reset")
        p.x_range = x_ranges.setdefault(x_col, p.x_range)
        p.y_range = y_ranges.setdefault(y_col, p.y_range)

        # scatter(marker="circle") replaces circle(size=...). Bokeh 3.4+
        # deprecates screen-sized circles via circle() in favour of scatter().
        p.scatter(source=dataset_source, x=x_col, y=y_col, marker="circle", size=6,
                  fill_alpha=0.6, fill_color=cluster_cmap, line_color=cluster_cmap,
                  selection_color="red",
                  nonselection_fill_alpha=0.1,
                  nonselection_line_alpha=0.1)

        # Only the left column keeps y axes and only the bottom row keeps x axes.
        if j > 0:
            p.yaxis.axis_label = ""
            p.yaxis.visible = False
        if i < y_max:
            p.xaxis.axis_label = ""
            p.xaxis.visible = False

        scatter_plots.append(p)

grid = gridplot(scatter_plots, ncols=len(numeric_columns))
show(grid)

