# Interactive Weight Visualization for wGMM-UAPCA and UAPCA

This notebook provides an interactive interface for exploring how per-class weight configurations affect the 2D projection of a dataset's class distributions.

- **GMM distributions**: Uses wGMM-UAPCA projection
- **Normal distributions**: Uses UAPCA projection

**Note**: All distributions in a dataset must be of the same type (either all GMM or all Normal).

## Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture
from ipywidgets import (
    FloatSlider, Button, Checkbox, Output, VBox, HBox, 
    HTML, Layout, Dropdown
)
from IPython.display import display, clear_output

from uadapy import Distribution
from uadapy.data import load_iris_normal, load_iris_gmm, generate_synthetic_gmm
from uadapy.dr.uapca import compute_uapca
from uadapy.dr.wgmm_uapca import create_gmm_from_components
from uadapy.distributions import MultivariateGMM
from uadapy.plotting.plots_2d import plot_contour, _get_color_palette

## Helper Functions

In [2]:
def is_gmm_distribution(distribution : Distribution) -> bool:
    """
    Tell whether a distribution wraps a Gaussian mixture model.

    Parameters
    ----------
    distribution : Distribution
        Distribution object to inspect.

    Returns
    -------
    bool
        True when ``distribution.model`` is a ``MultivariateGMM`` instance,
        False for any other model.
    """
    wrapped_model = distribution.model
    return isinstance(wrapped_model, MultivariateGMM)


def validate_distributions(distributions: list[Distribution]) -> None:
    """
    Check that a list holds only GMMs or only Normal distributions.

    The first element fixes the expected kind; every later element must be a
    supported kind (GMM or Normal) and match it.

    Parameters
    ----------
    distributions : list[Distribution]
        List of Distribution objects.

    Raises
    ------
    ValueError
        If the list is empty, an element is neither GMM nor Normal, or the
        list mixes GMM and Normal elements.
    """
    if not distributions:
        raise ValueError("Distribution list cannot be empty")

    def _is_gmm_checked(dist):
        # Returns True for a GMM, False for a Normal; raises for anything else.
        gmm_flag = is_gmm_distribution(dist)
        normal_flag = dist.name in ('Normal', 'multivariate_normal_frozen')
        if not (gmm_flag or normal_flag):
            raise ValueError(f"All distributions inside list must be either GMM or Normal. Found type: {dist.name}.")
        return gmm_flag

    expected_gmm = _is_gmm_checked(distributions[0])
    for dist in distributions[1:]:
        if _is_gmm_checked(dist) != expected_gmm:
            raise ValueError("All distributions inside list must be of same type.")


def align_projection_matrix(P: np.ndarray, P_ref: np.ndarray) -> np.ndarray:
    """
    Align projection matrix P to reference projection P_ref for visual stability.

    Eigenvectors are only defined up to sign, so a projection recomputed after a
    small weight change can come out mirrored. This function:

    1. normalizes the columns of both matrices to unit length,
    2. flips every column of ``P`` whose dot product with the matching column
       of ``P_ref`` is negative, and
    3. for two-column projections, flips the second column if ``P`` still has
       the opposite orientation to the reference (negative determinant of
       ``P_ref.T @ P``, e.g. when the two components swapped order).

    Parameters
    ----------
    P : np.ndarray
        Projection matrix to be aligned (n_features x n_dims). Not modified.
    P_ref : np.ndarray
        Reference projection matrix (n_features x n_dims). Not modified.

    Returns
    -------
    P_aligned : np.ndarray
        Aligned float projection matrix with unit-norm columns. All-zero
        columns are returned as zeros instead of NaN.

    Raises
    ------
    ValueError
        If ``P`` and ``P_ref`` do not have the same shape.
    """
    # Float copies: the caller's arrays stay untouched, and integer inputs can
    # be normalized (in-place true division of an int array raises).
    P = np.array(P, dtype=float)
    P_ref = np.array(P_ref, dtype=float)
    if P.shape != P_ref.shape:
        raise ValueError(
            f"P and P_ref must have the same shape, got {P.shape} and {P_ref.shape}."
        )

    def _unit_columns(M: np.ndarray) -> np.ndarray:
        # Scale each column to unit length; zero columns stay zero (avoids 0/0 -> NaN).
        norms = np.linalg.norm(M, axis=0, keepdims=True)
        norms[norms == 0] = 1.0
        return M / norms

    P = _unit_columns(P)
    P_ref = _unit_columns(P_ref)

    # Sign ambiguity: flip each column pointing away from its reference column.
    col_dots = np.sum(P_ref * P, axis=0)
    P[:, col_dots < 0] *= -1

    # Orientation: a negative determinant means P is a mirror image of P_ref
    # (only reachable when the components effectively swapped), so mirror back.
    if P.shape[1] == 2 and np.linalg.det(P_ref.T @ P) < 0:
        P[:, 1] *= -1

    return P


def project_distributions(distributions: list[Distribution], P: np.ndarray) -> list[Distribution]:
    """
    Project distributions linearly using the given projection matrix.

    A Gaussian with mean ``mu`` and covariance ``S`` is mapped to mean
    ``mu @ P`` and covariance ``P.T @ S @ P``. For a GMM, this is applied to
    every component and the component weights are kept unchanged.

    Parameters
    ----------
    distributions : list[Distribution]
        List of Distribution objects (GMM or Normal). Each element is
        dispatched on its own type; an empty list yields an empty result.
    P : np.ndarray
        Projection matrix (n_features x n_dims), typically n_dims = 2.

    Returns
    -------
    list
        List of projected Distribution objects, in input order. GMM inputs
        produce ``name="GMM"`` distributions; all other inputs produce
        ``name="Normal"`` distributions.
    """
    projected_dists = []

    for d in distributions:
        if is_gmm_distribution(d):
            gmm = d.model

            # Project all components.
            # NOTE(review): assumes covariances_ holds one full (n_features x
            # n_features) matrix per component, i.e. covariance_type='full'.
            projected_means = gmm.means_ @ P
            projected_covs = np.array([P.T @ cov @ P for cov in gmm.covariances_])

            projected_gmm = create_gmm_from_components(
                projected_means, projected_covs, gmm.weights_
            )
            projected_dists.append(Distribution(MultivariateGMM(projected_gmm), name="GMM"))
        else:
            # Normal case: project the first two moments directly.
            projected_mean = d.mean() @ P
            projected_cov = P.T @ d.cov() @ P
            projected_dists.append(
                Distribution(multivariate_normal(projected_mean, projected_cov), name="Normal")
            )

    return projected_dists


def project_with_alignment(distributions: list[Distribution], weights: np.ndarray, 
                          P_ref: np.ndarray = None) -> tuple:
    """
    Build a weighted UAPCA projection, optionally align it, and apply it.

    The projection basis is obtained from ``compute_uapca`` on the per-class
    means and covariances (``d.mean()`` / ``d.cov()``), for GMM and Normal
    inputs alike. The first two returned basis vectors form the projection
    matrix.

    Parameters
    ----------
    distributions : list
        List of Distribution objects (all same type).
    weights : np.ndarray
        Weights for each distribution.
    P_ref : np.ndarray, optional
        Reference projection matrix. When given, the new matrix is passed
        through ``align_projection_matrix`` against it.

    Returns
    -------
    tuple
        (projected_distributions, projection_matrix)
    """
    class_means = np.array([dist.mean() for dist in distributions])
    class_covs = np.array([dist.cov() for dist in distributions])

    # NOTE(review): presumably compute_uapca returns eigenvectors as columns,
    # ordered by decreasing eigenvalue — confirm in uadapy.dr.uapca.
    basis, _ = compute_uapca(class_means, class_covs, weights)
    projection = basis[:, :2]

    if P_ref is not None:
        projection = align_projection_matrix(projection, P_ref)

    return project_distributions(distributions, projection), projection

In [3]:
def normalize_weights_on_change(sliders: dict, changed_idx: int, new_value: float, 
                                n_classes: int, auto_update: Checkbox, 
                                update_plot_callback: callable, updating_flag: dict) -> None:
    """
    Rebalance slider weights after one slider moved so their total stays 1.0.

    The moved slider keeps ``new_value``. The remaining ``1 - new_value`` is
    shared among the other sliders in proportion to their current values, or
    equally if those values sum to zero. The plot is then redrawn.

    Parameters
    ----------
    sliders : dict
        Maps class index -> slider widget.
    changed_idx : int
        Index of the slider the user moved.
    new_value : float
        New value of that slider.
    n_classes : int
        Number of sliders.
    auto_update : Checkbox
        When unchecked, this handler does nothing.
    update_plot_callback : callable
        Called without arguments after the sliders have been rebalanced.
    updating_flag : dict
        Shared re-entrancy guard (``{"value": bool}``). It is raised while this
        function assigns slider values, so the change events those assignments
        fire return immediately instead of recursing.
    """
    # Skip when auto-update is off, and skip events caused by our own
    # programmatic assignments below (guard already raised).
    if not auto_update.value or updating_flag["value"]:
        return

    vals = np.array([sliders[i].value for i in range(n_classes)], dtype=float)
    vals[changed_idx] = float(new_value)

    # Distribute the remaining weight proportionally over the other sliders.
    remaining = 1.0 - vals[changed_idx]
    other_idx = [i for i in range(n_classes) if i != changed_idx]
    if other_idx:
        sum_other = vals[other_idx].sum()
        if sum_other > 0:
            vals[other_idx] *= remaining / sum_other
        else:
            vals[other_idx] = remaining / len(other_idx)

    updating_flag["value"] = True
    try:
        for i in range(n_classes):
            sliders[i].value = float(vals[i])
    finally:
        # Always release the guard; a stuck flag would silently disable all
        # further slider handling.
        updating_flag["value"] = False

    update_plot_callback()


def create_control_panel(n_classes: int, sample_based_weights: dict, 
                        update_plot_callback: callable, sliders: dict, 
                        updating_flag: dict) -> tuple:
    """
    Create control panel with preset buttons and update options.

    Parameters
    ----------
    n_classes : int
        Number of weight sliders.
    sample_based_weights : dict
        Maps class index -> weight for the "Sample-Based" preset. Missing
        indices fall back to ``1 / n_classes``; the result is renormalized.
    update_plot_callback : callable
        Called without arguments to redraw the plot.
    sliders : dict
        Maps class index -> slider widget.
    updating_flag : dict
        Shared guard (``{"value": bool}``) raised while this panel writes
        slider values in bulk.

    Returns
    -------
    tuple
        ``(control_panel, auto_update)``: the assembled ``VBox`` and the
        "Auto Update Plot" checkbox.
    """
    def make_button(description: str, style: str) -> Button:
        # All action buttons share the same size.
        return Button(
            description=description,
            button_style=style,
            layout=Layout(width="180px", height="36px"),
        )

    auto_update = Checkbox(
        value=True, 
        description="Auto Update Plot", 
        indent=False, 
        layout=Layout(width="180px")
    )
    update_button = make_button("Update Plot", "primary")
    equal_button = make_button("Equal Weights", "info")
    sample_button = make_button("Sample-Based", "info")

    def set_slider_values(values) -> None:
        """Write ``values`` into all sliders as one batch.

        The guard is raised for the whole batch so that slider observers which
        honor ``updating_flag`` do not renormalize or redraw after each
        individual assignment. try/finally guarantees the guard is released.
        """
        updating_flag["value"] = True
        try:
            for i in range(n_classes):
                sliders[i].value = float(values[i])
        finally:
            updating_flag["value"] = False

    def on_update_clicked(_):
        # Rescale all sliders to sum to 1 using one consistent total, then redraw.
        values = np.array([sliders[i].value for i in range(n_classes)], dtype=float)
        total = values.sum()
        if total > 0:
            set_slider_values(values / total)
        update_plot_callback()

    def on_equal_clicked(_):
        set_slider_values([1.0 / n_classes] * n_classes)
        update_plot_callback()

    def on_sample_clicked(_):
        weights = np.array([sample_based_weights.get(i, 1.0/n_classes) for i in range(n_classes)])
        weights /= weights.sum()
        set_slider_values(weights)
        update_plot_callback()

    update_button.on_click(on_update_clicked)
    equal_button.on_click(on_equal_clicked)
    sample_button.on_click(on_sample_clicked)

    # Layout
    preset_label = HTML(
        "<h3 style='color:#007acc; font-size:20px; margin:0 0 10px 5px;'>Weight Presets</h3>"
    )

    preset_row = VBox([
        preset_label, 
        HBox([equal_button, sample_button], layout=Layout(gap='25px'))
    ], layout=Layout(padding="5px 0"))

    update_row = HBox(
        [update_button, auto_update],
        layout=Layout(gap="40px", margin="20px 0 0 0", align_items="center")
    )

    control_panel = VBox([preset_row, update_row], layout=Layout(margin="20px"))

    return control_panel, auto_update

In [4]:
def update_visualization(distributions: list[Distribution], sliders: dict, n_classes: int, 
                        class_names: list, output: Output, P_ref_container: dict, 
                        dataset_name: str) -> None:
    """
    Redraw the projection plot inside ``output`` for the current slider weights.

    Parameters
    ----------
    distributions : list[Distribution]
        Class distributions to project (all GMM or all Normal).
    sliders : dict
        Maps class index -> slider; slider values are taken as raw weights.
    n_classes : int
        Number of classes / sliders.
    class_names : list
        Legend labels, one per class.
    output : Output
        Widget receiving the figure; it is cleared before drawing.
    P_ref_container : dict
        Holds the previous projection matrix under ``"P"`` (or None). It is
        used as the alignment reference and then replaced by the new matrix.
    dataset_name : str
        Shown in the plot title.
    """
    with output:
        clear_output(wait=True)

        # Slider values are relative weights; rescale them to sum to one.
        raw_weights = np.array([sliders[idx].value for idx in range(n_classes)], dtype=float)
        weights = raw_weights / raw_weights.sum()

        previous_P = P_ref_container.get("P")
        projected_dists, P = project_with_alignment(distributions, weights, previous_P)
        # Keep this matrix so the next redraw can be aligned against it.
        P_ref_container["P"] = P

        fig, ax = plt.subplots(figsize=(8, 6))
        palette = _get_color_palette(n_classes, colorblind_safe=False)

        plot_contour(
            projected_dists,
            resolution=128,
            quantiles=[25, 75, 95],
            fig=fig,
            axs=ax,
            distrib_colors=palette,
            show_plot=False
        )

        # Class names are not passed to plot_contour, so legend entries are
        # created with empty proxy lines in the matching colors.
        for idx, label in enumerate(class_names):
            ax.plot([], [], color=palette[idx], linewidth=2, label=label)

        ax.set_xlabel('Component 1', fontsize=11)
        ax.set_ylabel('Component 2', fontsize=11)

        method_name = 'wGMM-UAPCA' if is_gmm_distribution(distributions[0]) else 'UAPCA'
        ax.set_title(f'{method_name} Projection - {dataset_name}', 
                     fontsize=13, fontweight='bold')
        ax.legend(loc='best', fontsize=10)
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

In [5]:
def create_interactive_plot(datasets: dict) -> None:
    """
    Create interactive weight visualization with support for multiple datasets.

    Displays a layout with a dataset dropdown, one weight slider per class,
    preset/update controls, a warning area and the projection plot. Switching
    datasets rebuilds the sliders and controls and resets the projection
    alignment reference.

    Parameters
    ----------
    datasets : dict
        Dictionary mapping dataset names to dataset configurations. Each
        configuration must contain ``'distributions'`` (list of Distribution
        objects, all GMM or all Normal) and may contain ``'class_names'`` and
        ``'sample_counts'`` (one entry per distribution).

    Raises
    ------
    ValueError
        If ``datasets`` is empty or any dataset fails validation.
    """
    from html import escape

    dataset_names = list(datasets.keys())
    if not dataset_names:
        raise ValueError("At least one dataset must be provided.")

    # Validate all datasets before building any widget.
    for dataset_name, config in datasets.items():
        try:
            validate_distributions(config['distributions'])
        except ValueError as e:
            raise ValueError(f"Error in dataset '{dataset_name}': {str(e)}") from e

    # State shared by the nested callbacks
    current_state = {
        'dataset_name': dataset_names[0],
        'distributions': None,
        'class_names': None,
        'sample_counts': None,
        'n_classes': 0,
        'sample_based': None,
    }

    sliders_container = {'sliders': {}}
    output = Output()
    updating_flag = {"value": False}
    P_ref_container = {"P": None}

    # Warning label for dataset issues
    warning_label = HTML(
        value="",
        layout=Layout(margin='0 0 10px 0', width='100%', display='flex', justify_content='center')
    )

    dataset_dropdown = Dropdown(
        options=dataset_names,
        value=dataset_names[0],
        description='Dataset:',
        style={'description_width': '80px'},
        layout=Layout(width='300px', height='36px')
    )

    # Current widgets, replaced whenever the interface is rebuilt
    sliders_box_container = {'widget': VBox()}
    controls_container = {'widget': VBox(), 'auto_update': None}
    header_container = {'widget': HTML()}

    def load_dataset(dataset_name: str) -> None:
        """Load a dataset into ``current_state`` and refresh the warning label."""
        warning_msgs = []

        config = datasets[dataset_name]
        distributions = config['distributions']
        n_classes = len(distributions)

        validate_distributions(distributions)  # also guarantees n_classes >= 1
        equal_weights = {i: 1.0 / n_classes for i in range(n_classes)}

        class_names = config.get('class_names')
        if class_names is None:
            class_names = [f"Class {i}" for i in range(n_classes)]
        elif len(class_names) != n_classes:
            warning_msgs.append(f"<b>Warning:</b> class_names length ({len(class_names)}) doesn't match number of distributions ({n_classes}). Using auto-generated names.")
            class_names = [f"Class {i}" for i in range(n_classes)]

        sample_counts = config.get('sample_counts')
        if sample_counts is None:
            sample_based = equal_weights
        elif len(sample_counts) != n_classes:
            warning_msgs.append(f"<b>Warning:</b> sample_counts length ({len(sample_counts)}) doesn't match number of distributions ({n_classes}). Using equal weights.")
            sample_based = equal_weights
        elif sum(sample_counts) <= 0 or any(c < 0 for c in sample_counts):
            # A zero total would divide by zero; negative counts are meaningless.
            warning_msgs.append("<b>Warning:</b> sample_counts must be non-negative with a positive total. Using equal weights.")
            sample_based = equal_weights
        else:
            total = sum(sample_counts)
            sample_based = {i: sample_counts[i] / total for i in range(n_classes)}

        if warning_msgs:
            msgs_html = "<br>".join([f"<span style='color:#e67e22; font-family:monospace;'>{msg}</span>" for msg in warning_msgs])
            warning_label.value = f"<div style='text-align:center'>{msgs_html}</div>"
        else:
            warning_label.value = ""

        current_state['dataset_name'] = dataset_name
        current_state['distributions'] = distributions
        current_state['class_names'] = class_names
        current_state['sample_counts'] = sample_counts
        current_state['n_classes'] = n_classes
        current_state['sample_based'] = sample_based

        # A projection of another dataset is not a meaningful alignment reference.
        P_ref_container["P"] = None

    def create_sliders() -> dict:
        """Create one weight slider per class of the current dataset."""
        n_classes = current_state['n_classes']
        class_names = current_state['class_names']

        sliders = {
            i: FloatSlider(
                value=1.0 / n_classes,
                min=0.01, max=0.98, step=0.01,
                description=f"{class_names[i]}",
                continuous_update=False,
                readout_format=".2f",
                layout=Layout(width="420px", height="32px"),
                style={'description_width': '120px'}
            )
            for i in range(n_classes)
        }

        sliders_container['sliders'] = sliders
        return sliders

    def update_plot() -> None:
        """Redraw the plot for the current dataset and slider values."""
        update_visualization(
            current_state['distributions'],
            sliders_container['sliders'],
            current_state['n_classes'],
            current_state['class_names'],
            output,
            P_ref_container,
            current_state['dataset_name']
        )

    def connect_slider_observers(sliders: dict) -> None:
        """Route every slider change through normalize_weights_on_change."""
        for i in range(current_state['n_classes']):
            sliders[i].observe(
                lambda change, i=i: normalize_weights_on_change(
                    sliders_container['sliders'], i, change["new"],
                    current_state['n_classes'],
                    controls_container['auto_update'],
                    update_plot, updating_flag
                ),
                names="value"
            )

    def build_sliders_box(sliders: dict, n_classes: int) -> VBox:
        """Lay out the sliders in one column, or in two columns for 8+ classes."""
        slider_label = HTML(
            "<h3 style='color:#007acc; font-size:20px; margin: 0 0 10px 50px;'>Weight Adjustment</h3>"
        )
        if n_classes >= 8:
            half = (n_classes + 1) // 2
            left_col = VBox([sliders[i] for i in range(half)], layout=Layout(gap="15px"))
            right_col = VBox([sliders[i] for i in range(half, n_classes)], layout=Layout(gap="15px"))
            return VBox(
                [slider_label, HBox([left_col, right_col], layout=Layout(gap="30px"))],
                layout=Layout(margin="25px 0 0 0")
            )
        sliders_col = VBox([sliders[i] for i in range(n_classes)], layout=Layout(gap="15px"))
        return VBox([slider_label, sliders_col], layout=Layout(margin="25px 0 0 0"))

    def build_header(dataset_name) -> HTML:
        """Title banner; the dataset name is HTML-escaped so it renders literally."""
        return HTML(f"""
            <h2 style='font-size:30px; font-weight:800;
                        margin:10px 0 10px 0; letter-spacing:0.5px;
                        text-align:center; line-height:1.3;'>
                Interactive Weight Visualization<br>
                <span style='color:#007acc;'>{escape(str(dataset_name))}</span>
            </h2>
        """)

    # Dataset selector widget
    dataset_selector_label = HTML(
        "<h3 style='color:#007acc; font-size:20px; margin:0 0 10px 5px;'>Dataset Selection</h3>"
    )
    dataset_selector = VBox(
        [dataset_selector_label, dataset_dropdown],
        layout=Layout(margin="20px 0 0 0")
    )

    # Long-lived containers; only their children are swapped on rebuild.
    right_panel = VBox(layout=Layout(align_items="flex-start"))
    main_content = VBox(layout=Layout(align_items="center", justify_content="center", width="100%"))

    def build_interface() -> None:
        """(Re)build sliders, controls, header and layout for the current dataset."""
        sliders = create_sliders()

        controls, auto_update = create_control_panel(
            current_state['n_classes'],
            current_state['sample_based'],
            update_plot,
            sliders,
            updating_flag
        )
        controls_container['widget'] = controls
        controls_container['auto_update'] = auto_update

        connect_slider_observers(sliders)

        sliders_box_container['widget'] = build_sliders_box(sliders, current_state['n_classes'])
        header_container['widget'] = build_header(current_state['dataset_name'])

        right_panel.children = [dataset_selector, controls_container['widget']]
        main_content.children = [
            header_container['widget'],
            HBox(
                [sliders_box_container['widget'], right_panel],
                layout=Layout(justify_content="center", align_items="flex-start", width="100%")
            ),
            warning_label,
            output
        ]

    def on_dataset_change(change):
        """Handle dataset dropdown change."""
        load_dataset(change['new'])
        build_interface()
        update_plot()

    dataset_dropdown.observe(on_dataset_change, names='value')

    # Initialize with the first dataset
    load_dataset(dataset_names[0])
    build_interface()

    display(main_content)
    update_plot()

## Running the Interactive Visualization

#### Using Datasets from UADAPy

In [6]:
# Register the example datasets one by one; insertion order is the order in
# which they appear in the dataset dropdown.
datasets = {}
datasets['Iris GMM'] = {
    'distributions': load_iris_gmm(n_components=2, random_state=0),
    'class_names': ['setosa', 'versicolor', 'virginica'],
    'sample_counts': [50, 50, 50],
}
datasets['Iris Normal'] = {
    'distributions': load_iris_normal(),
    'class_names': ['setosa', 'versicolor', 'virginica'],
    'sample_counts': [50, 50, 50],
}
datasets['Synthetic GMM'] = {
    'distributions': generate_synthetic_gmm(n_classes=3, n_dims=5, random_state=0),
    'class_names': ['Component 1', 'Component 2', 'Component 3'],
    'sample_counts': [100, 100, 100],
}

# Launch the interactive visualization
create_interactive_plot(datasets)

VBox(children=(HTML(value="\n        <h2 style='font-size:30px; font-weight:800;\n                    margin:1…

#### Using Your Own Datasets

In [7]:
# Generate normally distributed data with varying sample sizes
np.random.seed(0)
synthetic_normal = []
sample_counts_normal = [50, 100, 150, 500]
n_classes_normal = len(sample_counts_normal)
n_dims_normal = 5

for i, n_samples in enumerate(sample_counts_normal):
    # Random class mean, shifted by class index so the classes separate
    mean = np.random.randn(n_dims_normal) * 3 + i * 4
    # A @ A.T is positive semi-definite; the added ridge makes it positive definite
    A = np.random.randn(n_dims_normal, n_dims_normal)
    cov = A @ A.T + np.eye(n_dims_normal) * 0.5

    # Draw samples and wrap them in a Distribution
    # (presumably uadapy fits a Normal to the samples when name="Normal" — confirm)
    data = np.random.multivariate_normal(mean, cov, n_samples)
    synthetic_normal.append(Distribution(data, name="Normal"))


# Generate GMM data with varying sample sizes and exactly 2 components per class
np.random.seed(0)
synthetic_gmm = []
sample_counts_gmm = [80, 10, 100, 200, 120]
n_classes_gmm = len(sample_counts_gmm)
n_dims_gmm = 5

for i in range(n_classes_gmm):
    class_data = []
    # Random split of the class's samples between the two components;
    # randint's upper bound is exclusive, so each component gets >= 1 sample.
    n_first = np.random.randint(1, sample_counts_gmm[i])
    counts = [n_first, sample_counts_gmm[i] - n_first]

    for comp in range(2):
        mean = np.random.randn(n_dims_gmm) * 2 + i * 3
        A = np.random.randn(n_dims_gmm, n_dims_gmm)
        cov = A @ A.T + np.eye(n_dims_gmm) * 0.3
        samples = np.random.multivariate_normal(mean, cov, counts[comp])
        class_data.append(samples)

    class_data = np.vstack(class_data)

    # Fit a 2-component GMM to the pooled class samples
    gmm = GaussianMixture(n_components=2, covariance_type='full', random_state=0)
    gmm.fit(class_data)

    synthetic_gmm.append(Distribution(MultivariateGMM(gmm), name="GMM"))


# Create datasets dictionary. Class names are derived from the sample-count
# lists so the two can never drift out of sync when classes are added/removed.
datasets = {
    'Synthetic Normal': {
        'distributions': synthetic_normal,
        'class_names': [f'Class {k + 1}' for k in range(n_classes_normal)],
        'sample_counts': sample_counts_normal,
    },
    'Synthetic GMM': {
        'distributions': synthetic_gmm,
        'class_names': [f'Class {k + 1}' for k in range(n_classes_gmm)],
        'sample_counts': sample_counts_gmm,
    },
}

# Run the interactive visualization
create_interactive_plot(datasets)

VBox(children=(HTML(value="\n        <h2 style='font-size:30px; font-weight:800;\n                    margin:1…