In [1]:
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np

In [2]:
from src.utils import dim_red

In [3]:
# Load the extracted feature table, then split it into the class labels
# (`target`) and the remaining feature columns (`df`).
# NOTE(review): this is an absolute, machine-specific path; consider a path
# relative to the repository so the notebook runs on other machines.
df = pd.read_csv(
    "/home/cotsios/dsit/2nd-semester/im-anal-and-proc/Img_analysis_assignment_2/extracted_features/features.csv"
)
target = df["label"]
df = df.drop(columns="label")

In [20]:

def _row_aligned_labels(labels, index):
    """Return ``labels`` as a Series indexed like ``index`` (one value per row).

    A Series that carries exactly the same row keys as ``index`` is matched by
    key. Anything else (list, array, or a Series with unrelated keys) is
    matched by position and must have one entry per row.
    """
    if isinstance(labels, pd.Series):
        if labels.index.equals(index):
            return labels
        shares_keys = (
            len(labels) == len(index)
            and labels.index.is_unique
            and index.is_unique
            and labels.index.isin(index).all()
        )
        if shares_keys:
            # Same keys in a different order: honour the index.
            return labels.reindex(index)

    values = np.asarray(labels)
    if values.ndim != 1 or len(values) != len(index):
        raise ValueError(
            "labels must be one-dimensional with one entry per row "
            f"(expected {len(index)}, got shape {values.shape})"
        )
    return pd.Series(values, index=index)


def plot_3d_interactive(df_reduced: pd.DataFrame, 
                       labels: pd.Series = None,
                       color_column: str = None,
                       title: str = "3D Interactive Plot",
                       marker_size: int = 5,
                       opacity: float = 0.7,
                       width: int = 800,
                       height: int = 600,
                       **kwargs) -> go.Figure:
    """
    Create an interactive 3D scatter plot of the first three columns of a DataFrame.

    Parameters
    ----------
    df_reduced : pd.DataFrame
        DataFrame with at least 3 columns; the first three are used as the
        x, y and z coordinates.
    labels : pd.Series, optional
        Values used to color the points. A Series with the same row keys as
        ``df_reduced`` is matched by index; otherwise the values are matched
        by position and must contain one entry per row.
    color_column : str, optional
        Name of a column in ``df_reduced`` to use for coloring. Takes
        precedence over ``labels``. If the column does not exist a warning is
        issued and ``labels`` (when given) is used instead.
    title : str, optional
        Title for the plot. Default is "3D Interactive Plot".
    marker_size : int, optional
        Size of markers. Default is 5.
    opacity : float, optional
        Opacity of markers (0-1). Default is 0.7.
    width : int, optional
        Width of the figure in pixels. Default is 800.
    height : int, optional
        Height of the figure in pixels. Default is 600.
    **kwargs : keyword arguments
        Additional arguments passed to ``plotly.express.scatter_3d``.

    Returns
    -------
    plotly.graph_objects.Figure
        Interactive 3D plotly figure.

    Raises
    ------
    ValueError
        If ``df_reduced`` has fewer than 3 columns, or if ``labels`` cannot be
        matched to the rows of ``df_reduced``.

    Examples
    --------
    >>> # Basic usage with PCA results
    >>> reduced_df = dim_red(data, method='pca', n_components=3)
    >>> fig = plot_3d_interactive(reduced_df, title="PCA 3D Visualization")
    >>> fig.show()

    >>> # With color labels
    >>> fig = plot_3d_interactive(reduced_df, labels=target_labels, 
    ...                          title="PCA with Class Labels")
    >>> fig.show()
    """
    import warnings

    if df_reduced.shape[1] < 3:
        raise ValueError("DataFrame must have at least 3 columns for 3D plotting")

    # Take the first three columns by position and work on a copy so the
    # caller's frame is never modified.
    plot_df = df_reduced.iloc[:, :3].copy()
    x_col, y_col, z_col = plot_df.columns

    # Decide which values (if any) drive the point colors.
    color_values = None
    if color_column:
        if color_column in df_reduced.columns:
            color_values = df_reduced[color_column]
        else:
            fallback = "labels" if labels is not None else "no coloring"
            warnings.warn(
                f"color_column {color_column!r} is not a column of df_reduced; "
                f"using {fallback} instead",
                stacklevel=2,
            )
    if color_values is None and labels is not None:
        color_values = _row_aligned_labels(labels, plot_df.index)

    color = None
    if color_values is not None:
        # Store the colors in a helper column whose name cannot clash with a
        # coordinate column (a coordinate literally named 'color' would
        # otherwise be overwritten).
        color = 'color'
        while color in plot_df.columns:
            color += '_'
        plot_df[color] = color_values

    # Create the 3D scatter plot
    fig = px.scatter_3d(
        plot_df,
        x=x_col,
        y=y_col,
        z=z_col,
        color=color,
        title=title,
        opacity=opacity,
        **kwargs
    )

    # Apply the requested marker size to every trace
    fig.update_traces(marker=dict(size=marker_size))

    # Axis titles, initial camera position and figure size
    fig.update_layout(
        scene=dict(
            xaxis_title=x_col,
            yaxis_title=y_col,
            zaxis_title=z_col,
            camera=dict(
                up=dict(x=0, y=0, z=1),
                center=dict(x=0, y=0, z=0),
                eye=dict(x=1.5, y=1.5, z=1.5)
            )
        ),
        width=width,
        height=height,
        margin=dict(r=20, b=10, l=10, t=40)
    )

    return fig


def _labels_as_row_array(labels, index):
    """Return ``labels`` as a 1-D NumPy array ordered like the rows of ``index``.

    A Series that carries exactly the same row keys as ``index`` is matched by
    key. Anything else is matched by position and must have one entry per row.
    """
    if isinstance(labels, pd.Series):
        shares_keys = (
            len(labels) == len(index)
            and labels.index.is_unique
            and index.is_unique
            and labels.index.isin(index).all()
        )
        if shares_keys:
            return labels.reindex(index).to_numpy()

    values = np.asarray(labels)
    if values.ndim != 1 or len(values) != len(index):
        raise ValueError(
            "labels must be one-dimensional with one entry per row "
            f"(expected {len(index)}, got shape {values.shape})"
        )
    return values


def plot_3d_comparison(df_reduced_list: list, 
                      method_names: list,
                      labels: pd.Series = None,
                      width: int = None,
                      height: int = 500,
                      subplot_width: int = 300,
                      **kwargs) -> go.Figure:
    """
    Create side-by-side 3D subplots comparing different dimensionality reduction methods.

    Parameters
    ----------
    df_reduced_list : list of pd.DataFrame
        Reduced DataFrames, one per method. The first three columns of each
        are used as the x, y and z coordinates. A DataFrame with fewer than
        three columns is skipped with a warning and its subplot stays empty.
    method_names : list of str
        Names of the methods corresponding to each DataFrame (subplot titles).
    labels : pd.Series, optional
        Labels for coloring points across all subplots. One trace per unique
        label is drawn in every subplot, with the same color for a given label
        everywhere. Labels are matched to rows by index when the row keys
        agree, otherwise by position (one entry per row is then required).
    width : int, optional
        Total width of the figure in pixels. If None, calculated as subplot_width * n_methods.
    height : int, optional
        Height of the figure in pixels. Default is 500.
    subplot_width : int, optional
        Width per subplot in pixels. Default is 300. Only used if width is None.
    **kwargs : keyword arguments
        Additional arguments passed to every ``go.Scatter3d`` trace. A
        ``marker`` dict is merged over the default marker settings.

    Returns
    -------
    plotly.graph_objects.Figure
        Interactive figure with multiple 3D subplots.

    Raises
    ------
    ValueError
        If the number of DataFrames and method names differ, if no DataFrame
        is given, or if ``labels`` cannot be matched to the rows of a DataFrame.

    Examples
    --------
    >>> pca_df = dim_red(data, method='pca', n_components=3)
    >>> tsne_df = dim_red(data, method='tsne', n_components=3)
    >>> umap_df = dim_red(data, method='umap', n_components=3)
    >>> 
    >>> fig = plot_3d_comparison([pca_df, tsne_df, umap_df], 
    ...                         ['PCA', 't-SNE', 'UMAP'],
    ...                         labels=target_labels)
    >>> fig.show()
    """
    import warnings

    from plotly.subplots import make_subplots

    n_methods = len(df_reduced_list)
    if n_methods != len(method_names):
        raise ValueError("Number of DataFrames must match number of method names")
    if n_methods == 0:
        raise ValueError("At least one reduced DataFrame is required")

    # Calculate figure width if not specified
    if width is None:
        width = subplot_width * n_methods

    # One 3D scene per method, laid out in a single row
    fig = make_subplots(
        rows=1, cols=n_methods,
        specs=[[{'type': 'scatter3d'} for _ in range(n_methods)]],
        subplot_titles=list(method_names),
        horizontal_spacing=0.05
    )

    # Caller-supplied trace options; 'marker' is merged rather than replaced so
    # the defaults below still apply to keys the caller did not set.
    trace_kwargs = dict(kwargs)
    marker_overrides = trace_kwargs.pop('marker', None) or {}

    palette = px.colors.qualitative.Set1
    unique_labels = None
    if labels is not None:
        # Computed once so a label gets the same color in every subplot.
        unique_labels = pd.Series(np.asarray(labels)).unique()

    # Legend entries are emitted by the first subplot that is actually drawn.
    legend_pending = True

    for i, (df_reduced, method_name) in enumerate(zip(df_reduced_list, method_names)):
        if df_reduced.shape[1] < 3:
            warnings.warn(
                f"Skipping {method_name!r}: expected at least 3 columns, "
                f"got {df_reduced.shape[1]}",
                stacklevel=2,
            )
            continue

        coords = df_reduced.iloc[:, :3]
        x_col, y_col, z_col = coords.columns

        if unique_labels is not None:
            row_labels = _labels_as_row_array(labels, df_reduced.index)
            for j, label in enumerate(unique_labels):
                # Positional boolean mask: independent of the frame's index.
                subset = coords.iloc[row_labels == label]

                marker = dict(size=2, color=palette[j % len(palette)], opacity=0.7)
                marker.update(marker_overrides)

                trace_opts = dict(
                    mode='markers',
                    name=str(label),
                    # Shared group: clicking a legend entry toggles this label
                    # in every subplot, not only the one owning the entry.
                    legendgroup=str(label),
                    showlegend=legend_pending,
                )
                trace_opts.update(trace_kwargs)

                fig.add_trace(
                    go.Scatter3d(
                        x=subset.iloc[:, 0],
                        y=subset.iloc[:, 1],
                        z=subset.iloc[:, 2],
                        marker=marker,
                        **trace_opts
                    ),
                    row=1, col=i+1
                )
            legend_pending = False
        else:
            marker = dict(size=5, opacity=0.7)
            marker.update(marker_overrides)

            trace_opts = dict(mode='markers', showlegend=False)
            trace_opts.update(trace_kwargs)

            fig.add_trace(
                go.Scatter3d(
                    x=coords.iloc[:, 0],
                    y=coords.iloc[:, 1],
                    z=coords.iloc[:, 2],
                    marker=marker,
                    **trace_opts
                ),
                row=1, col=i+1
            )

        # Plotly names the scenes 'scene', 'scene2', 'scene3', ...
        scene_key = 'scene' if i == 0 else f'scene{i+1}'
        fig.update_layout(**{
            scene_key: dict(
                xaxis_title=x_col,
                yaxis_title=y_col,
                zaxis_title=z_col
            )
        })

    fig.update_layout(
        height=height,
        width=width,
        title_text="Dimensionality Reduction Comparison"
    )

    return fig

In [5]:
# Project the feature table down to three components with each method,
# in the order PCA, t-SNE, UMAP.
pca_df, tsne_df, umap_df = (
    dim_red(df, method=method, n_components=3)
    for method in ('pca', 'tsne', 'umap')
)



In [21]:
# Show the three embeddings side by side, colored by class label.
fig = plot_3d_comparison(
    df_reduced_list=[pca_df, tsne_df, umap_df],
    method_names=['PCA', 't-SNE', 'UMAP'],
    labels=target,
    width=1800,
    height=600,
)
fig.show()