In [None]:
def plot_umap_grid(df, n_neighbors_list=(3, 5, 10), min_dist_list=(0.05, 0.15, 0.5)):
    """Display a grid of clustered UMAP embeddings over a hyper-parameter sweep.

    The numeric columns of ``df`` are standardised once. For every
    ``(min_dist, n_neighbors)`` pair a 2-D UMAP embedding is fitted, clustered
    with HDBSCAN, and drawn in its own subplot: rows follow ``min_dist_list``
    and columns follow ``n_neighbors_list``. The figure is shown as a side
    effect.

    Args:
        df: DataFrame whose numeric columns are embedded; other columns are
            ignored.
        n_neighbors_list: ``n_neighbors`` values, one per grid column.
        min_dist_list: ``min_dist`` values, one per grid row.

    Returns:
        A DataFrame with one row per point per run and the columns ``UMAP1``,
        ``UMAP2``, ``cluster`` (label as a string), ``color``,
        ``n_neighbors`` and ``min_dist``.

    Raises:
        ValueError: If either parameter list is empty.
    """
    # Copy into lists so any iterable is accepted and the caller's objects are
    # never shared with (or mutated through) the defaults.
    n_neighbors_list = list(n_neighbors_list)
    min_dist_list = list(min_dist_list)
    if not n_neighbors_list or not min_dist_list:
        raise ValueError("n_neighbors_list and min_dist_list must both be non-empty")

    rows = len(min_dist_list)
    cols = len(n_neighbors_list)

    data = StandardScaler().fit_transform(df.select_dtypes(include=[np.number]))
    base_palette = px.colors.qualitative.Plotly
    noise_color = 'rgba(150,150,150,0.3)'

    # No subplot titles: the grid is labelled with the row/column annotations
    # added further down.
    fig = make_subplots(rows=rows, cols=cols,
                        horizontal_spacing=0.02, vertical_spacing=0.02)

    # Grid positions come from enumerate rather than list.index(), so repeated
    # parameter values still get a subplot each.
    grid = [
        (row, col, min_dist, n_neighbors)
        for row, min_dist in enumerate(min_dist_list, start=1)
        for col, n_neighbors in enumerate(n_neighbors_list, start=1)
    ]

    run_frames = []
    for row, col, min_dist, n_neighbors in tqdm(grid, desc="Generating UMAPs"):
        # NOTE(review): n_jobs=1 is passed together with the fixed random_state,
        # presumably so seeded runs are reproducible - confirm against umap docs.
        reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, random_state=42, n_jobs=1)
        embedding = reducer.fit_transform(data)
        labels = hdbscan.HDBSCAN(min_cluster_size=5).fit_predict(embedding)

        # One colour per distinct label; -1 (HDBSCAN noise) is drawn in
        # translucent grey, clusters cycle through the qualitative palette.
        color_by_label = {
            label: noise_color if label == -1 else base_palette[label % len(base_palette)]
            for label in np.unique(labels)
        }

        run_df = pd.DataFrame({
            'UMAP1': embedding[:, 0].astype(np.float64),
            'UMAP2': embedding[:, 1].astype(np.float64),
            'cluster': [str(label) for label in labels],
            'color': [color_by_label[label] for label in labels],
            'n_neighbors': n_neighbors,
            'min_dist': min_dist,
        })
        run_frames.append(run_df)

        # sort=False keeps clusters in order of first appearance, which fixes
        # the order in which the traces are layered.
        for _, cluster_df in run_df.groupby('cluster', sort=False):
            fig.add_trace(
                go.Scattergl(
                    x=cluster_df['UMAP1'],
                    y=cluster_df['UMAP2'],
                    mode='markers',
                    marker=dict(color=cluster_df['color'].iloc[0], size=3),
                    showlegend=False,
                    hoverinfo='skip'
                ),
                row=row, col=col
            )

    # Embedding coordinates carry no meaning, so hide ticks on every subplot.
    fig.update_xaxes(showticklabels=False, ticks="")
    fig.update_yaxes(showticklabels=False, ticks="")

    fig.update_layout(
        template='simple_white',
        plot_bgcolor='white',
        margin=dict(l=100, r=100, t=80, b=80),
        height=500,
    )

    # Row (min_dist) labels on the left, in paper coordinates.
    for i, min_dist in enumerate(min_dist_list):
        fig.add_annotation(
            text=f"min-dist: {min_dist}",
            x=-0.05,  # just outside the plot grid
            y=1 - (i + 0.5) / rows,
            xref="paper", yref="paper",
            showarrow=False,
            font=dict(size=12),
            xanchor="right",
            yanchor="middle",
            align="right",
            textangle=90,
        )

    # Column (n_neighbors) labels along the bottom.
    for j, n_neighbors in enumerate(n_neighbors_list):
        fig.add_annotation(
            text=f"n-neighbors: {n_neighbors}",
            x=(j + 0.5) / cols,
            y=-0.08,
            xref="paper", yref="paper",
            showarrow=False,
            font=dict(size=12),
            xanchor="center",
            yanchor="top",
        )

    # The mode-bar image download button exports SVG.
    fig.show(config={"toImageButtonOptions": {"format": "svg"}})
    return pd.concat(run_frames, ignore_index=True)

In [None]:
# Run the UMAP sweep on clean_df (defined in an earlier cell) with the default
# parameter grids; the figure is displayed and the returned per-point table is
# kept in `umaps`.
umaps = plot_umap_grid(clean_df)