In [None]:
# pip install umap-learn numpy scipy scikit-learn pandas matplotlib seaborn plotly
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_digits, fetch_openml
from sklearn.preprocessing import StandardScaler
import umap
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from mpl_toolkits.mplot3d import Axes3D
import warnings
# Silence every warning category for the rest of the session. Note that this
# also hides deprecation notices emitted by the libraries imported above.
warnings.filterwarnings('ignore')

# Seed NumPy's global random number generator for reproducibility
np.random.seed(42)

In [None]:
def _build_category_colors(n_categories, color_scheme):
    """
    Sample one hex colour per category from a matplotlib colormap.

    Parameters:
    -----------
    n_categories : int
        Number of colours to produce
    color_scheme : str
        Name of a matplotlib colormap

    Returns:
    --------
    colors : list of str
        ``n_categories`` colour strings in '#rrggbb' form
    """
    import numpy as np
    import matplotlib.pyplot as plt  # Only for color mapping

    # plt.get_cmap is used instead of plt.cm.get_cmap because the latter was
    # removed from matplotlib in version 3.9.
    cmap = plt.get_cmap(color_scheme)

    if cmap.N < n_categories:
        print(f"Warning: Color scheme '{color_scheme}' has fewer colors ({cmap.N}) than categories ({n_categories})")
        if n_categories > 20:
            # Pool several qualitative palettes to get more distinct colours
            pool = []
            for scheme in ('tab20', 'tab20b', 'tab20c', 'Set3'):
                scheme_cmap = plt.get_cmap(scheme)
                pool.extend(scheme_cmap(np.linspace(0, 1, scheme_cmap.N)))
            # Cycle through the pool so that having more categories than
            # pooled colours reuses colours instead of raising IndexError.
            rgba_colors = [pool[i % len(pool)] for i in range(n_categories)]
        else:
            rgba_colors = cmap(np.linspace(0, 1, n_categories))
    else:
        rgba_colors = cmap(np.linspace(0, 1, n_categories))

    # Convert colors to hex format for plotly
    hex_colors = []
    for rgba in rgba_colors:
        red, green, blue = rgba[0], rgba[1], rgba[2]
        hex_colors.append(f'#{int(red*255):02x}{int(green*255):02x}{int(blue*255):02x}')
    return hex_colors


def plot_umap_visualization_2d_plotly(X, y, n_neighbors=15, min_dist=0.1, random_state=42,
                                    color_scheme='tab20', point_size=10, opacity=0.8,
                                    width=900, height=700, return_fig=True,
                                    title='2D UMAP Projection', legend_title='Device Type',
                                    show_count_in_legend=True, show_plot=True,
                                    hover_data=None, marker_symbol='circle',
                                    grid=True, template='plotly_white'):
    """
    Create interactive 2D UMAP visualization with Plotly.

    The features are standardized with ``StandardScaler``, embedded into two
    dimensions with UMAP, and drawn as one scatter trace per category of ``y``.
    Progress information and warnings are printed to stdout.

    Parameters:
    -----------
    X : array-like
        Feature matrix with shape (n_samples, n_features)
    y : array-like
        Target labels/categories with shape (n_samples,)
    n_neighbors : int, default=15
        UMAP parameter: number of neighbors to consider
    min_dist : float, default=0.1
        UMAP parameter: minimum distance between points in low dimensional space
    random_state : int, default=42
        Random state for reproducibility
    color_scheme : str, default='tab20'
        Name of the matplotlib colormap to sample category colors from,
        e.g. 'tab20', 'tab20b', 'tab20c', 'Set3', 'Dark2', 'Accent',
        'plasma', 'viridis', 'rainbow'
    point_size : int, default=10
        Size of scatter points
    opacity : float, default=0.8
        Transparency of scatter points (0 to 1)
    width : int, default=900
        Width of the plot in pixels
    height : int, default=700
        Height of the plot in pixels
    return_fig : bool, default=True
        Whether to return the figure object
    title : str, default='2D UMAP Projection'
        Title for the UMAP plot. Set to None to hide.
    legend_title : str, default='Device Type'
        Title for the legend
    show_count_in_legend : bool, default=True
        Whether to show sample counts in the legend labels
    show_plot : bool, default=True
        Whether to display the plot
    hover_data : dict or None, default=None
        Additional data to show on hover (dict mapping column names to arrays).
        Entries whose length differs from the number of samples are ignored
        with a printed warning.
    marker_symbol : str, default='circle'
        Marker symbol for the scatter points
    grid : bool, default=True
        Whether to show grid lines
    template : str, default='plotly_white'
        Plotly template to use. Options: 'plotly', 'plotly_white', 'plotly_dark',
        'ggplot2', 'seaborn', 'simple_white', 'none'

    Returns:
    --------
    fig : plotly.graph_objects.Figure, optional
        Plotly figure object (only if return_fig=True)

    Raises:
    -------
    ValueError
        If the length of ``y`` does not match the number of samples in ``X``
    """
    import numpy as np

    # Validate before the heavier imports below so that bad input fails fast.
    # ``X`` may be a plain nested list (no ``.shape``), and ``y`` is converted
    # to an array so that ``y == category`` is always an element-wise mask.
    n_samples = X.shape[0] if hasattr(X, "shape") else len(X)
    y = np.asarray(y)
    if len(y) != n_samples:
        raise ValueError(f"Length of y ({len(y)}) does not match number of samples in X ({n_samples})")

    import plotly.graph_objects as go
    import umap
    from sklearn.preprocessing import StandardScaler

    # Check if there are enough samples for each category
    categories, counts = np.unique(y, return_counts=True)
    print(f"Found {len(categories)} unique categories in y")
    for category, count in zip(categories, counts):
        print(f"Category '{category}': {count} samples")
        if count < 2:
            print(f"  Warning: Category '{category}' has fewer than 2 samples, which may cause issues with UMAP")

    # 1. Scale/normalize the features
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    print(f"X_scaled type: {type(X_scaled)}, shape: {X_scaled.shape}")

    # 2. Create color mapping based on the selected scheme
    category_colors = _build_category_colors(len(categories), color_scheme)

    # 3. Apply 2D UMAP
    reducer_2d = umap.UMAP(n_neighbors=n_neighbors,
                         min_dist=min_dist,
                         n_components=2,
                         random_state=random_state)
    embedding_2d = reducer_2d.fit_transform(X_scaled)

    # 4. Build legend labels, adding sample counts if requested
    if show_count_in_legend:
        category_labels = [f"{category} ({count})" for category, count in zip(categories, counts)]
    else:
        category_labels = [f"{category}" for category in categories]

    # Collect usable hover data. It is kept apart from the embedding so that a
    # hover column name can never overwrite the coordinates or the labels.
    hover_columns = {}
    if hover_data is not None:
        for col_name, data in hover_data.items():
            if len(data) == n_samples:
                hover_columns[col_name] = np.asarray(data)
            else:
                print(f"Warning: Hover data '{col_name}' has incompatible length and will be ignored")
    hover_names = list(hover_columns)

    # 5. Create visualization
    fig = go.Figure()

    # Add traces for each category
    for category, label, color in zip(categories, category_labels, category_colors):
        mask = y == category
        n_points = int(np.sum(mask))

        if n_points == 0:
            continue

        if hover_names:
            # Zip the columns directly so each value keeps its own dtype
            rows = zip(*(hover_columns[col][mask] for col in hover_names))
            hover_text = [
                f"{label}<br>" + "<br>".join(f"{col}: {value}" for col, value in zip(hover_names, row))
                for row in rows
            ]
        else:
            hover_text = [label] * n_points

        fig.add_trace(go.Scatter(
            x=embedding_2d[mask, 0],
            y=embedding_2d[mask, 1],
            mode='markers',
            marker=dict(
                size=point_size,
                color=color,
                symbol=marker_symbol,
                opacity=opacity,
                line=dict(width=0)
            ),
            name=label,
            text=hover_text,
            hoverinfo='text'
        ))

    # Update layout
    fig.update_layout(
        title=dict(
            text=title if title else "",
            font=dict(size=18)
        ),
        width=width,
        height=height,
        xaxis=dict(
            title='UMAP1',
            showgrid=grid,
            zeroline=grid
        ),
        yaxis=dict(
            title='UMAP2',
            showgrid=grid,
            zeroline=grid,
            scaleanchor="x",  # Make sure the aspect ratio is equal
            scaleratio=1
        ),
        legend=dict(
            title=dict(
                text=legend_title if legend_title else "",
                font=dict(size=14)
            )
        ),
        template=template,
        margin=dict(l=10, r=10, b=10, t=50 if title else 10)
    )

    # Show plot if requested
    if show_plot:
        fig.show()

    # Return figure if requested
    if return_fig:
        return fig