In [1]:
import numpy as np
import pandas as pd
import networkx as nx
from sklearn.datasets import make_classification
import plotly.graph_objects as go
import plotly.express as px
from ipywidgets import widgets, interact

# Step 1: Build a synthetic, high-dimensional classification dataset
n_samples = 1000  # rows (observations)
n_features = 50   # columns (features)

# Fixed random_state keeps the generated data reproducible between runs
X, y = make_classification(
    n_samples=n_samples,
    n_features=n_features,
    random_state=41,
)
df = pd.DataFrame(
    X,
    columns=[f'Feature_{idx}' for idx in range(n_features)],
)

# Step 2: Pairwise feature correlations, taken in absolute value so that
# strong negative relationships count the same as strong positive ones
correlation_matrix = df.corr()
correlation_matrix = correlation_matrix.abs()

def create_graph(threshold=0.5):
    """Build an undirected graph of feature pairs whose correlation meets a threshold.

    Reads the module-level ``correlation_matrix`` (absolute pairwise
    correlations). Node labels and the number of features are taken from
    the matrix itself, so the function keeps working if the dataset is
    regenerated with a different width or different column names.

    Args:
        threshold: Minimum correlation (inclusive) a pair must reach to be
            connected by an edge.

    Returns:
        A ``networkx.Graph`` whose edges carry the correlation in the
        ``weight`` attribute. Features with no qualifying partner are not
        added to the graph.
    """
    labels = list(correlation_matrix.columns)
    corr = correlation_matrix.to_numpy()

    # Keep only the strict upper triangle: this skips the diagonal
    # (self-correlation) and the mirrored duplicate of every pair.
    # np.nonzero walks the mask in row-major order, matching a nested
    # ``for i ... for j > i`` traversal.
    rows, cols = np.nonzero(np.triu(corr >= threshold, k=1))

    G = nx.Graph()
    for i, j in zip(rows.tolist(), cols.tolist()):
        G.add_edge(labels[i], labels[j], weight=corr[i, j])

    return G

def update_graph(threshold=0.1):
    """Draw the maximum spanning tree of the thresholded correlation graph.

    Builds the graph with ``create_graph(threshold)``, extracts its maximum
    spanning tree, lays it out with a seeded spring layout and shows it as
    an interactive Plotly figure. Nodes are coloured by their degree in the
    tree.

    Args:
        threshold: Minimum absolute correlation for an edge to be considered;
            forwarded to ``create_graph``.

    Returns:
        None. The figure is displayed via ``fig.show()``.
    """
    G = create_graph(threshold)

    # Step 4: Compute the Maximum Spanning Tree (MST) using correlations directly as weights
    # NOTE(review): per the networkx docs a disconnected input yields a
    # spanning forest rather than a single tree - confirm that is acceptable.
    mst = nx.maximum_spanning_tree(G)

    # Step 5: Visualize the Maximum Spanning Tree
    # Fixed seed so the same tree is always drawn the same way
    pos = nx.spring_layout(mst, seed=42)

    # Flatten edges into coordinate lists; the None entries break the line
    # between consecutive segments so all edges fit in a single trace
    edge_x = []
    edge_y = []
    for source, target in mst.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))

    # Create edge traces for Plotly visualization
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    # Node coordinates, in the same order as mst.nodes
    node_x = []
    node_y = []
    for node in mst.nodes():
        node_px, node_py = pos[node]
        node_x.append(node_px)
        node_y.append(node_py)

    # Degree (number of tree neighbours) of every node, used as marker colour
    node_adjacencies = [len(neighbours) for _, neighbours in mst.adjacency()]

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=list(mst.nodes),
        textposition="top center",
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            size=10,
            color=node_adjacencies,
            colorbar=dict(
                thickness=15,
                # Nested title dict replaces the legacy ``titleside``
                # property, which recent Plotly releases reject
                title=dict(text='Node Connections', side='right'),
                xanchor='left'
            ),
            line_width=2))

    # Create the figure
    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        # Nested title dict replaces the legacy
                        # ``titlefont_size`` property for the same reason
                        title=dict(
                            text=f'Maximum Spanning Tree of Feature Correlations (Threshold: {threshold})',
                            font=dict(size=16)),
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=0, l=0, r=0, t=40),
                        annotations=[dict(
                            text="MST Visualization of Feature Correlations",
                            showarrow=False,
                            xref="paper", yref="paper",
                            x=0.005, y=-0.002)],
                        xaxis=dict(showgrid=False, zeroline=False),
                        yaxis=dict(showgrid=False, zeroline=False)))

    # Show the interactive plot
    fig.show()

# Step 6: Hook a slider up to update_graph so the correlation threshold
# can be adjusted interactively
threshold_slider = widgets.FloatSlider(
    value=0.5,
    min=0.0,
    max=1.0,
    step=0.05,
    description='Threshold:',
)
interact(update_graph, threshold=threshold_slider)


interactive(children=(FloatSlider(value=0.5, description='Threshold:', max=1.0, step=0.05), Output()), _dom_cl…

<function __main__.update_graph(threshold=0.1)>