# In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy.spatial import procrustes
import seaborn as sns

# === SPoSE model embeddings: each text file is loaded and transposed.
# NOTE(review): the result is presumably (n_items, dim) -- confirm against
# the layout of the exported embedding files.
W1 = np.transpose(np.loadtxt('sparse_embed_epoch0500_1b.txt'))
W2 = np.transpose(np.loadtxt('sparse_embed_epoch0500_12b.txt'))


def _parse_embedding(text):
    """Parse one space-separated embedding string into a 1-D float array."""
    return np.fromstring(text, sep=" ")


# === Human embeddings: read from the "item" and "embedding" CSV columns.
df_human = pd.read_csv("../result/embeddings/human/embeddings_human_out.csv")
human_items = df_human["item"].astype(str).values
human_embeds = df_human["embedding"].apply(_parse_embedding)
W_human = np.vstack(human_embeds.values)  # one row per CSV row

# === Bring all three matrices to a common shape.
# Rows: keep at most as many model rows as there are human rows. This assumes
# the row order of W1 and W2 already matches the human item list; nothing
# here checks that.
# Columns: keep only the leading dimensions that all three matrices have.
n_items = W_human.shape[0]
min_dim = min(W1.shape[1], W2.shape[1], W_human.shape[1])
W1 = W1[:n_items, :min_dim]
W2 = W2[:n_items, :min_dim]
W_human = W_human[:, :min_dim]

# === Row-wise removal of non-finite entries
def clean_embeddings(W):
    """Drop every row of ``W`` that contains a NaN or an infinity.

    Args:
        W: 2-D numeric array; rows are tested along axis 1.

    Returns:
        A ``(kept, n_dropped)`` tuple: ``kept`` holds the rows of ``W`` whose
        entries are all finite, in their original order, and ``n_dropped`` is
        the number of rows that were removed.
    """
    keep = np.isfinite(W).all(axis=1)
    n_dropped = np.logical_not(keep).sum()
    return W[keep], n_dropped

# === Drop invalid rows JOINTLY so that row i stays the same item everywhere.
# Filtering each matrix on its own and then truncating to a common length
# shifts rows against each other whenever the non-finite rows differ between
# matrices; the Procrustes comparison below relies on row correspondence, so
# a row is kept only if it is finite in all three matrices.
n_common = min(W1.shape[0], W2.shape[0], W_human.shape[0])
valid_rows = (
    np.all(np.isfinite(W1[:n_common]), axis=1)
    & np.all(np.isfinite(W2[:n_common]), axis=1)
    & np.all(np.isfinite(W_human[:n_common]), axis=1)
)
# Number of rows removed from the comparison (kept for inspection).
n_dropped = int(n_common - valid_rows.sum())

W1_clean = W1[:n_common][valid_rows]
W2_clean = W2[:n_common][valid_rows]
W_human_clean = W_human[:n_common][valid_rows]

# === All three cleaned matrices now share the same rows
n = W1_clean.shape[0]

# === 2D projections: a separate 2-component PCA is fitted per embedding
W1_pca = PCA(n_components=2).fit_transform(W1_clean)
W2_pca = PCA(n_components=2).fit_transform(W2_clean)
W_human_pca = PCA(n_components=2).fit_transform(W_human_clean)


def _draw_pca_panel(ax, coords, title):
    """Scatter the two columns of ``coords`` on ``ax`` and label the panel."""
    ax.scatter(coords[:, 0], coords[:, 1], alpha=0.7, edgecolors='k', s=30)
    ax.set_title(title)
    ax.set_xlabel("PC1")
    ax.grid(True)


# === One panel per embedding, sharing both axes
fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True)

_draw_pca_panel(axes[0], W1_pca, "SPoSE 1b")
_draw_pca_panel(axes[1], W2_pca, "SPoSE 12b")
_draw_pca_panel(axes[2], W_human_pca, "SPoSE Human")
# Only the leftmost panel carries the y-axis label.
axes[0].set_ylabel("PC2")

plt.suptitle("2D PCA Projection of SPoSE Embeddings (1b vs 12b vs Human)")
plt.tight_layout()
plt.show()

# === Procrustes-based similarity between two point sets
def procrustes_r2(X, Y):
    """Return ``1 - disparity`` from a Procrustes analysis of ``X`` and ``Y``.

    Args:
        X: 2-D array; passed to ``scipy.spatial.procrustes`` as the reference.
        Y: 2-D array with the same shape as ``X``.

    Returns:
        One minus the disparity reported by ``scipy.spatial.procrustes``;
        the two transformed matrices it also returns are discarded.
    """
    disparity = procrustes(X, Y)[2]
    return 1 - disparity

# === Embedding spaces to compare, keyed by the label shown on the heatmap
embeddings = {
    "1b": W1_clean,
    "12b": W2_clean,
    "human": W_human_clean
}

labels = list(embeddings.keys())

# Entry [i, j] is procrustes_r2(space_i, space_j); every ordered pair is
# evaluated, including each space against itself on the diagonal.
r2_matrix = np.array([
    [procrustes_r2(embeddings[row], embeddings[col]) for col in labels]
    for row in labels
])

# === Annotated heatmap of the pairwise scores
plt.figure(figsize=(6, 5))
sns.heatmap(
    r2_matrix,
    annot=True,
    fmt=".3f",
    xticklabels=labels,
    yticklabels=labels,
    cmap="viridis",
    square=True,
)
plt.title("Procrustes r² Between Embedding Spaces")
plt.tight_layout()
plt.show()
