# In [None]:  (Jupyter cell marker left over from the notebook export)
"""
spectral_clustering_from_wasserstein_no_outliers.py
---------------------------------------------------
Performs spectral clustering on Wasserstein distance matrices
and removes a given proportion of outliers before clustering.

Steps:
1. Load Wasserstein distance matrix
2. Compute average distance per language
3. Remove top X% outliers (most distant languages)
4. Perform spectral clustering (Gaussian kernel + KMeans)
5. Save cluster results and 2D spectral embedding plot
---------------------------------------------------
Author: Samuel Jiang (2025)
"""

from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.preprocessing import normalize
from numpy.linalg import eigh

# Input distance matrices, keyed by group name. main() reads each file with
# pd.read_csv(path, index_col=0) and uses the row index as the language labels.
GROUPS = {
    "objects": Path("/home/njian29/Desktop/wasserstein_matrix_objects.csv"),
    "ideologies": Path("/home/njian29/Desktop/wasserstein_matrix_ideologies.csv"),
    "sports": Path("/home/njian29/Desktop/wasserstein_matrix_sports.csv"),
}

# Directory that receives the per-group cluster CSVs and embedding plots.
# NOTE: the directory is created at module level, i.e. as soon as this file
# is imported or executed, not only when main() runs.
OUT_DIR = Path("/home/njian29/Desktop/spectral_clusters_cleaned")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Passed as `k` to spectral_clustering(): the KMeans cluster count and the
# number of eigenvectors requested for the spectral embedding.
K = 5
# Gaussian kernel bandwidth. None lets similarity_from_dist() fall back to the
# median of the positive off-diagonal distances.
SIGMA = None
# Fraction of languages handed to remove_outliers() as `pct`.
OUTLIER_PCT = 0.05

def similarity_from_dist(D: np.ndarray, sigma: float | None = None) -> np.ndarray:
    if sigma is None:
        vals = D[np.triu_indices_from(D, k=1)]
        sigma = np.median(vals[vals > 0])
        if sigma <= 0 or np.isnan(sigma):
            sigma = 1.0
    A = np.exp(-(D ** 2) / (2.0 * sigma ** 2))
    np.fill_diagonal(A, 0.0)
    return A


def normalized_graph_laplacian(A: np.ndarray) -> np.ndarray:
    """Build the symmetric normalized graph Laplacian ``I - D^-1/2 A D^-1/2``.

    ``D`` here is the diagonal degree matrix whose entries are the row sums
    of ``A``. Nodes with non-positive degree get a scaling factor of 0, so
    their row and column in the normalized affinity are zero and their
    diagonal Laplacian entry is 1.

    Args:
        A: Square 2-D affinity (similarity) matrix.

    Returns:
        A new float array with the same shape as ``A``.
    """
    A = np.asarray(A, dtype=float)
    degree = A.sum(axis=1)
    # Only divide where the degree is positive. Computing 1/sqrt(d) for every
    # node and masking afterwards triggers divide-by-zero warnings for
    # isolated nodes even though the result is discarded.
    d_inv_sqrt = np.zeros_like(degree)
    connected = degree > 0
    d_inv_sqrt[connected] = 1.0 / np.sqrt(degree[connected])
    # Row/column scaling by broadcasting is the same product as
    # diag(d) @ A @ diag(d) without materialising two dense diagonal matrices.
    L = np.eye(A.shape[0]) - d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]
    return L


def spectral_embed(L: np.ndarray, k: int) -> np.ndarray:
    """Embed the nodes of a graph using eigenvectors of its Laplacian.

    The eigenvectors are ordered by ascending eigenvalue, those whose
    eigenvalue is at most ``1e-9`` are skipped as trivial, and the next ``k``
    are used as coordinates. Each row of the result is then scaled to unit
    L2 norm.

    Args:
        L: Symmetric Laplacian matrix (decomposed with ``eigh``).
        k: Number of eigenvectors to keep.

    Returns:
        Array with one row per node and up to ``k`` columns. Fewer columns
        are returned when fewer than ``k`` eigenvectors remain after the
        trivial ones are skipped.
    """
    evals, evecs = eigh(L)
    # Apply the SAME permutation to eigenvalues and eigenvectors so that the
    # index found on `evals` below is valid for the columns of `evecs`.
    order = np.argsort(evals)
    evals = evals[order]
    evecs = evecs[:, order]
    nontrivial_idx = np.where(evals > 1e-9)[0]
    # If every eigenvalue is (numerically) zero, fall back to the first column.
    start = nontrivial_idx[0] if nontrivial_idx.size > 0 else 0
    U = evecs[:, start:start + k]
    U = normalize(U, norm="l2")
    return U


def spectral_clustering(D: np.ndarray, k: int, sigma: float | None = None):
    """Cluster the items of a distance matrix with spectral clustering.

    Pipeline: distances -> Gaussian affinity -> normalized Laplacian ->
    spectral embedding -> KMeans on the embedded rows.

    Args:
        D: Square pairwise distance matrix.
        k: Number of clusters; also the number of eigenvectors requested
            for the embedding.
        sigma: Gaussian kernel bandwidth forwarded to
            ``similarity_from_dist`` (``None`` selects its default).

    Returns:
        Tuple ``(labels, embedding)``: the KMeans cluster label of every row
        and the spectral embedding those labels were fitted on.
    """
    affinity = similarity_from_dist(D, sigma)
    embedding = spectral_embed(normalized_graph_laplacian(affinity), k)
    # Fixed seed and 20 restarts keep the assignment reproducible across runs.
    cluster_ids = KMeans(n_clusters=k, n_init=20, random_state=42).fit_predict(embedding)
    return cluster_ids, embedding

def remove_outliers(D: np.ndarray, langs: list[str], pct: float):
    """Drop the languages with the largest average distance to all others.

    The number of rows removed is ``int(len(langs) * pct)`` (truncated). The
    rows are chosen by rank, so exactly that many are removed even when
    several languages share the same average distance.

    Args:
        D: Square pairwise distance matrix whose rows/columns follow ``langs``.
        langs: Language labels, one per row of ``D``.
        pct: Fraction of languages to remove, in ``[0, 1)``.

    Returns:
        Tuple ``(D_clean, langs_clean)`` with the outlier rows and columns
        removed; the surviving languages keep their original relative order.
        When nothing is to be removed, the inputs are returned unchanged.

    Raises:
        ValueError: If ``pct`` is outside ``[0, 1)``.
    """
    if not 0.0 <= pct < 1.0:
        raise ValueError(f"pct must be in [0, 1), got {pct!r}")

    avg_dist = D.mean(axis=1)
    n_remove = int(len(langs) * pct)
    if n_remove == 0:
        return D, langs

    # Select by rank rather than by comparing against a cutoff value: with a
    # cutoff equal to the n-th largest average, "<= cutoff" keeps that n-th
    # row (and every tie), so fewer than n_remove rows would be dropped.
    # A stable sort makes the choice among tied rows deterministic.
    ranked = np.argsort(avg_dist, kind="stable")
    removed_idx = np.sort(ranked[-n_remove:])
    keep_idx = np.sort(ranked[:-n_remove])

    print(f"[INFO] 去除 {n_remove} 个 outlier (>{pct*100:.1f}%)：")
    for i in removed_idx:
        print(f"    - {langs[i]} (avg dist = {avg_dist[i]:.3f})")

    D_clean = D[np.ix_(keep_idx, keep_idx)]
    langs_clean = [langs[i] for i in keep_idx]
    return D_clean, langs_clean

def main():
    """Cluster every configured group and write its table and plot.

    For each entry of ``GROUPS`` this loads the distance matrix CSV, removes
    outlier languages, runs spectral clustering, saves a Language/Cluster
    table sorted by cluster as CSV, and saves a labelled scatter plot of the
    first two embedding dimensions as PNG. Both outputs go to ``OUT_DIR``.
    """
    for name, path in GROUPS.items():
        print(f"\n[INFO] 处理 {name}: {path}")

        # First CSV column holds the language labels; the rest is the matrix.
        matrix_df = pd.read_csv(path, index_col=0)
        D_clean, langs_clean = remove_outliers(
            matrix_df.values, matrix_df.index.tolist(), OUTLIER_PCT
        )
        labels, U = spectral_clustering(D_clean, k=K, sigma=SIGMA)

        # --- cluster assignment table -------------------------------------
        out_csv = OUT_DIR / f"clusters_{name}_cleaned.csv"
        cluster_table = pd.DataFrame({"Language": langs_clean, "Cluster": labels})
        cluster_table.sort_values("Cluster").to_csv(out_csv, index=False)
        print(f"[OK] 已保存聚类表: {out_csv}")

        # --- 2-D embedding plot -------------------------------------------
        # The pyplot calls stay in this order: the figure is opened before
        # the style sheet is applied.
        plt.figure(figsize=(10, 8))
        plt.style.use("seaborn-v0_8-whitegrid")
        xs, ys = U[:, 0], U[:, 1]
        scatter = plt.scatter(xs, ys, c=labels, cmap="tab10", s=60, edgecolor="k", alpha=0.9)
        # Annotate each point with its language name, minus any ".txt".
        for lang, x, y in zip(langs_clean, xs, ys):
            plt.text(x, y, lang.replace(".txt", ""), fontsize=7, ha="center", va="center", alpha=0.8)

        plt.title(f"Spectral Clusters of {name.capitalize()} (Cleaned, k={K})", fontsize=14, pad=15)
        plt.xlabel("Spectral Dimension 1")
        plt.ylabel("Spectral Dimension 2")
        plt.colorbar(scatter, label="Cluster ID")
        plt.tight_layout()

        out_fig = OUT_DIR / f"spectral_clusters_{name}_cleaned.png"
        plt.savefig(out_fig, dpi=300)
        plt.close()
        print(f"[OK] 已保存图像: {out_fig}")


if __name__ == "__main__":
    main()