In [None]:
# File: notebooks/paper1_heuristic_analysis.ipynb
# Cell 1: Setup and W&B Data Download

import wandb
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import aglt
from aglt.pipeline.steps import FaissGraphConstructor, PipelineData, RankSGWTRepresentationBuilder
from pathlib import Path
import networkx as nx

# Configure plotting style and open the W&B public API client.
sns.set_theme(style="whitegrid")
api = wandb.Api()

# --- Download the Ablation Data from W&B ---
# Replace with your W&B entity/project
WANDB_PROJECT_PATH = "YOUR_ENTITY/Jormungandr-Semantica"
# Representation builders that are part of the heuristics ablation.
ABLATION_REPRESENTATIONS = ("wavelet", "acmw", "community", "rank")
# Fixed schema for the results frame. Passing it explicitly keeps the columns
# present even when no run matches, so later cells fail with a clear message
# instead of a KeyError on a column-less DataFrame.
RESULT_COLUMNS = ["representation", "seed", "ARI", "runtime"]

runs = api.runs(WANDB_PROJECT_PATH)

# One record per run that belongs to the heuristics ablation.
summary_list = [
    {
        "representation": run.config.get("representation"),
        "seed": run.config.get("seed"),
        "ARI": run.summary.get("ARI"),
        "runtime": run.summary.get("runtime_seconds"),
    }
    for run in runs
    if run.job_type == "ablation"
    and run.config.get("representation") in ABLATION_REPRESENTATIONS
]

results_df = pd.DataFrame(summary_list, columns=RESULT_COLUMNS)
print("Downloaded data for {} runs.".format(len(results_df)))
if results_df.empty:
    # Surface the most likely causes right here rather than in a later cell.
    print(
        "WARNING: no ablation runs matched. Check WANDB_PROJECT_PATH, "
        "the run job_type, and the 'representation' config values."
    )
results_df.head()

In [None]:
# Cell 2: Generate the Statistical Results Table

# Aggregate ARI per representation builder (mean and sample std dev across
# seeds), print the table, then emit LaTeX for the paper.
rep_order = ["wavelet", "acmw", "community", "rank"]
baseline_label = 'Isotropic SGWT (Baseline)'

agg_df = (
    results_df.groupby("representation")["ARI"]
    .agg(['mean', 'std'])
    .reset_index()
)

# Sort the results for clarity: an ordered categorical makes sort_values
# follow rep_order instead of alphabetical order.
agg_df['representation'] = pd.Categorical(agg_df['representation'], categories=rep_order, ordered=True)
agg_df = agg_df.sort_values('representation')

# Rename 'wavelet' to the baseline name.
# rename_categories is the supported way to relabel a categorical; using
# Series.replace for this is deprecated on categorical dtype (pandas >= 2.2).
agg_df['representation'] = agg_df['representation'].cat.rename_categories({'wavelet': baseline_label})

print("--- Final Results Table ---")
print(agg_df.to_string(index=False))

# --- Generate LaTeX code for the paper ---
print("\n--- LaTeX Code for Table 1 ---")
# Format mean and std into a single "mean ± std" column.
# With a single seed the sample std is NaN and is rendered as "nan".
mean_text = agg_df['mean'].map("{:.2f}".format)
std_text = agg_df['std'].map("{:.2f}".format)
agg_df['ARI (Mean ± Std)'] = mean_text + " ± " + std_text

latex_df = agg_df[['representation', 'ARI (Mean ± Std)']].rename(
    columns={'representation': 'Representation Builder'}
)
print(latex_df.to_latex(index=False, column_format="lc"))

In [None]:
# Cell 3: Generate the Diagnostic UMAP Visualization

# We need to manually run one of the failed builders to get its representation,
# then project it to 2D with UMAP and save a scatter plot coloured by label.
from umap import UMAP

print("Generating representation for a failed heuristic (RankSGWT)...")

# --- Setup data pipeline ---
config = {
    'k': 15, 'seed': 42, 'umap_dims': 2, # Need 2D for plotting
    'rank_quantile': 0.1, 'rank_enhancement': 1.5, 'rank_dampening': 0.5,
    'wavelet_scales': [5, 15, 50], 'n_eigenvectors': 200
}
DATA_DIR = Path("../data")
FIGURE_PATH = Path("../paper/figures/failed_heuristic_visualization.png")

embeddings = np.load(DATA_DIR / "20newsgroups_embeddings.npy")
df_labels = pd.read_csv(DATA_DIR / "20newsgroups_labels.csv")
labels_true = df_labels["label"].to_numpy()

# Fail fast on misaligned inputs: the scatter plot below colours one point per
# label, so a mismatch would otherwise only surface after the expensive
# graph / representation steps have run.
if len(embeddings) != len(labels_true):
    raise ValueError(
        "Embeddings and labels are misaligned: {} embeddings vs {} labels.".format(
            len(embeddings), len(labels_true)
        )
    )

data = PipelineData(docs=[], embeddings=embeddings, labels_true=labels_true)

# --- Run the pipeline steps ---
data = FaissGraphConstructor(config).run(data)
data = RankSGWTRepresentationBuilder(config).run(data)

# --- Perform UMAP on the bad representation ---
# Dimensionality and seed come from `config` so they cannot drift from the
# values declared above.
reducer = UMAP(
    n_components=config['umap_dims'],
    random_state=config['seed'],
    min_dist=0.0,
    metric='cosine',
)
embedding_2d = reducer.fit_transform(data.representation)

# --- Create the plot ---
fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(embedding_2d[:, 0], embedding_2d[:, 1], c=data.labels_true, cmap='turbo', s=5)
# NOTE(review): the ARI value in the title is hard-coded; keep it in sync with
# the results table.
ax.set_title("UMAP of RankSGWT Representation (ARI ≈ 0.17)", fontsize=16)
ax.set_xlabel("UMAP Component 1")
ax.set_ylabel("UMAP Component 2")
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])

# Create the output directory if needed so savefig does not raise
# FileNotFoundError on a fresh checkout.
FIGURE_PATH.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(FIGURE_PATH, dpi=150, bbox_inches='tight')
print("Saved diagnostic UMAP plot.")
plt.show()