In [10]:
import json
import csv
import numpy as np
from sklearn.decomposition import PCA
import plotly.graph_objs as go
from matplotlib import colormaps
from scipy.spatial import ConvexHull
import os
from src import PROJECT_ROOT, DATA_DIR
from src.utils.io import load_json

In [11]:
# === Load Clustering CSV ===
def load_clustering(file_path):
    """Read a clustering CSV into a ``method -> node_id -> cluster`` mapping.

    The file is parsed with ``csv.DictReader``. Each data row describes one
    clustering method, named by its ``"method"`` column. Every header after
    the first column is treated as a cluster name, and the matching cell
    holds a comma-separated list of the node ids assigned to that cluster.

    Args:
        file_path: Path of the CSV file to read (opened as UTF-8).

    Returns:
        dict: ``{method: {node_id: cluster_name}}``. Node ids are stripped of
        surrounding whitespace and empty entries are dropped. When a method
        appears on several rows, the last row replaces the earlier ones.
        A file with no header row yields an empty dict.

    Raises:
        KeyError: If a data row is read but the header has no ``"method"``
            column.
    """
    clustering = {}  # method -> node_id -> cluster_label
    with open(file_path, newline='', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        # fieldnames is None when the file has no header row; slicing None
        # would raise TypeError, so treat that case as "no cluster columns".
        header = reader.fieldnames or []
        cluster_names = header[1:]
        for row in reader:
            method = row["method"]
            assignments = {}
            for cluster_name in cluster_names:
                # DictReader pads short rows with None, hence the `or ""`.
                cell = row[cluster_name] or ""
                for node_id in cell.split(","):
                    node_id = node_id.strip()
                    if node_id:
                        assignments[node_id] = cluster_name
            clustering[method] = assignments
    return clustering


In [12]:
# === Convex Hull Generation ===
def cluster_hull_trace(xs, ys, cluster_name, color):
    """Build a filled convex-hull outline for one cluster of 2-D points.

    Args:
        xs: Sequence of x coordinates.
        ys: Sequence of y coordinates, parallel to ``xs``.
        cluster_name: Label used in the trace name (``"<name> boundary"``).
        color: Line colour string. An ``"rgb(r,g,b)"`` value also produces a
            matching fill at alpha 0.15. An ``"rgba(...)"`` value is reused
            unchanged as the fill. Any other format (named or hex colours)
            cannot be given an alpha by string rewriting, so the fill colour
            is left unset and plotly applies its own default fill.

    Returns:
        A ``go.Scatter`` line trace following the closed hull polygon,
        created with ``visible=False`` so the caller controls when it is
        shown, or ``None`` when no hull can be computed: fewer than three
        points, or input that Qhull rejects (for example collinear points).
    """
    if len(xs) < 3:
        return None  # Not enough points for a hull

    # Local import: QhullError is not among the names imported at file level.
    from scipy.spatial import QhullError

    points = np.column_stack((xs, ys))
    try:
        hull = ConvexHull(points)
    except QhullError:
        # Degenerate cluster (no 2-D extent): skip the outline rather than
        # abort the whole figure.
        return None
    polygon = points[hull.vertices]

    if color.startswith("rgba("):
        # Already carries an alpha channel; rewriting "rgb" -> "rgba" here
        # would produce the invalid "rgbaa(...)".
        fillcolor = color
    elif color.startswith("rgb("):
        channels = color[len("rgb("):].rstrip().rstrip(")")
        fillcolor = f"rgba({channels},0.15)"
    else:
        # Passing a named colour through would give an opaque fill.
        fillcolor = None

    return go.Scatter(
        # Repeat the first vertex so the outline is closed.
        x=polygon[:, 0].tolist() + [polygon[0, 0]],
        y=polygon[:, 1].tolist() + [polygon[0, 1]],
        fill='toself',
        mode='lines',
        name=f"{cluster_name} boundary",
        line=dict(color=color, width=1),
        fillcolor=fillcolor,
        hoverinfo='skip',
        showlegend=False,
        visible=False
    )

In [16]:
# === File Paths ===

# Input locations for the NeurIPS 2024 run. The dated directory names are
# presumably run timestamps; adjust them to point at another run.
embedding_path = os.path.join(
    DATA_DIR, "emb", "NeurIPS", "2024", "20250602_2031", "emb.json"
)
metadata_path = os.path.join(
    DATA_DIR, "unified_text", "NeurIPS", "NeurIPS_2024.json"
)
clustering_csv = os.path.join(
    PROJECT_ROOT, "result", "NeurIPS", "2024", "20250602_2158", "clusters.csv"
)

# Embeddings keyed by node id, plus the list of metadata records.
embedding_dict = load_json(embedding_path)
metadata_list = load_json(metadata_path)

# Index the metadata records by their 'id' field for direct lookup.
id_to_meta = dict((entry['id'], entry) for entry in metadata_list)

clustering_dict = load_clustering(clustering_csv)

# Freeze one node order so row i of `embeddings` belongs to node_ids[i].
node_ids = list(embedding_dict)
embeddings = np.array([embedding_dict[key] for key in node_ids])


In [19]:
# === PCA Projection ===
# Project every embedding to 2-D. random_state is pinned so the layout is
# reproducible when scikit-learn selects a solver that uses randomness.
proj = PCA(n_components=2, random_state=0).fit_transform(embeddings)
id_to_proj = {nid: proj[i] for i, nid in enumerate(node_ids)}

# === Baseline Coloring ===
# Marker colours always come from the "Baseline" clustering, whichever method
# is selected, so the hulls of other methods can be compared against it.
baseline_clusters = clustering_dict.get("Baseline", {})
unique_baseline_labels = sorted(set(baseline_clusters.get(nid, "Unassigned") for nid in node_ids))
cmap = colormaps.get_cmap("Set2").resampled(len(unique_baseline_labels))
label_to_color = {}
for i, label in enumerate(unique_baseline_labels):
    r, g, b = (int(channel * 255) for channel in cmap(i)[:3])
    label_to_color[label] = f"rgb({r},{g},{b})"

# === Per-point data (identical for every method, so built once) ===
x_vals, y_vals, hover_texts, colors = [], [], [], []
for nid in node_ids:
    x, y = id_to_proj[nid]
    x_vals.append(x)
    y_vals.append(y)
    baseline_label = baseline_clusters.get(nid, "Unassigned")
    colors.append(label_to_color.get(baseline_label, "gray"))

    meta = id_to_meta.get(nid, {})
    # NOTE(review): plotly hover labels render only a small subset of HTML
    # tags; the styled <div> below may not scroll or size as intended - confirm.
    hover = f"""
        <b>{meta.get('title', '')}</b><br>
        <i>{meta.get('authors', '')}</i><br><br>
        <div style='max-height:100px; max-width:200px; overflow-y:auto; font-size:0.9em; line-height:1.2em;'>
        {meta.get('abstract', '')}
        </div>
        """
    hover_texts.append(hover)

# === Plotly Traces per Method ===
data = []
buttons = []
trace_methods = []  # trace_methods[i] is the method that owns data[i]
method_names = list(clustering_dict.keys())

# Method shown when the figure first opens. "Baseline" is kept whenever it
# exists (or when there are no methods at all); otherwise use the first method
# so the figure does not open with every trace hidden.
if "Baseline" in clustering_dict or not method_names:
    default_method = "Baseline"
else:
    default_method = method_names[0]

# Hull outlines are neutral gray. The colour is given in rgb() form because
# cluster_hull_trace derives its translucent fill from an "rgb(...)" string.
hull_color = "rgb(128,128,128)"

for method in method_names:
    cluster_map = clustering_dict[method]
    is_default = (method == default_method)

    # Group the projected points by this method's cluster label.
    cluster_xys = {}  # cluster_label -> list of (x, y)
    for nid, x, y in zip(node_ids, x_vals, y_vals):
        cluster_label = cluster_map.get(nid, "Unassigned")
        cluster_xys.setdefault(cluster_label, []).append((x, y))

    trace = go.Scatter(
        x=x_vals,
        y=y_vals,
        mode='markers',
        marker=dict(size=8, color=colors, line=dict(width=0.5, color='DarkSlateGrey')),
        text=hover_texts,
        hoverinfo='text',
        name=method,
        visible=is_default
    )
    data.append(trace)
    trace_methods.append(method)

    # Add cluster hulls
    for cluster_label, points in cluster_xys.items():
        xs, ys = zip(*points)
        hull_trace = cluster_hull_trace(xs, ys, cluster_label, color=hull_color)
        if hull_trace is not None:
            hull_trace.visible = is_default
            data.append(hull_trace)
            trace_methods.append(method)

# Dropdown buttons. The visibility masks are built only after every trace
# exists, so each mask has one entry per trace, and they are keyed on trace
# ownership so a method's hulls are shown together with its scatter trace.
for method in method_names:
    visible_flags = [owner == method for owner in trace_methods]
    buttons.append(dict(
        label=method,
        method="update",
        args=[{"visible": visible_flags}, {"title": f"Clustering Visualization — {method}"}]
    ))

# === Layout ===
layout = go.Layout(
    title=f"Clustering Visualization — {default_method}",
    updatemenus=[dict(
        buttons=buttons,
        # Highlight the button of the method that is visible initially.
        active=method_names.index(default_method) if default_method in method_names else 0,
        direction="down",
        showactive=True,
        x=0.1,
        xanchor="left",
        y=1.15,
        yanchor="top"
    )],
    xaxis=dict(title="PCA 1", showgrid=False, zeroline=False),
    yaxis=dict(title="PCA 2", showgrid=False, zeroline=False),
    hovermode='closest',
    hoverlabel=dict(
        bgcolor="white",
        font_size=12,
        font_family="Arial",
        align="left"
    ),
    height=700
)

fig = go.Figure(data=data, layout=layout)
fig.write_html("interactive_clustering.html")
