In [5]:
#embedding_visualization.ipynb
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

In [6]:
def load_embeddings(epoch, folder="node_embeddings"):
    """
    Read back the three validation artifacts saved for a single epoch.

    The files are looked up inside ``folder`` under these names:
    ``val_embeddings_epoch_{epoch}.pt``, ``val_sk_ids_epoch_{epoch}.npy``
    and ``val_clusters_epoch_{epoch}.npy``.

    Args:
        epoch (int): Epoch number that is embedded in the file names.
        folder (str): Directory holding the saved .pt and .npy files.

    Returns:
        tuple: ``(embeddings, sk_ids, clusters)`` where
            - embeddings is the object stored in the .pt file, mapped to the
              CPU (expected to be a torch.Tensor of shape (N, D) -- confirm
              against the code that saves it),
            - sk_ids is the array of SK_ID_CURR values read from disk
              (expected shape (N,)),
            - clusters is the array of cluster labels read from disk
              (expected shape (N,)).
    """
    def _artifact_path(stem, extension):
        # Every artifact follows the same "val_<stem>_epoch_<epoch>.<ext>" pattern.
        return os.path.join(folder, f"val_{stem}_epoch_{epoch}.{extension}")

    # NOTE(review): torch.load relies on pickle-based deserialization; only
    # point this at files produced by a trusted training run.
    return (
        torch.load(_artifact_path("embeddings", "pt"), map_location='cpu'),
        np.load(_artifact_path("sk_ids", "npy")),
        np.load(_artifact_path("clusters", "npy")),
    )

In [7]:
import numpy as np

def plot_embeddings_tsne(embeddings, sk_ids=None, clusters=None, annotate=True, title="t-SNE of Node Embeddings", perplexity=30, random_state=42):
    """
    Projects node embeddings to 2-D with t-SNE and draws them as a scatter plot.

    Args:
        embeddings (torch.Tensor or array-like): Node embeddings [N, D].
            Tensors are detached and moved to the CPU first; anything else is
            converted with ``np.asarray``.
        sk_ids (np.ndarray): SK_ID_CURR values for annotation (optional)
        clusters (np.ndarray): Cluster labels for color-coding (optional)
        annotate (bool): Whether to annotate selected points with SK_IDs.
            Roughly every ``len(sk_ids) // 100``-th point is labelled.
        title (str): Plot title
        perplexity (int): t-SNE perplexity. scikit-learn requires it to be
            smaller than the number of samples, so it is lowered to ``N - 1``
            (with a warning) when the requested value is too large.
        random_state (int): t-SNE random seed

    Raises:
        ValueError: If ``embeddings`` is not 2-D or has fewer than two rows.
    """
    import warnings  # local import: only needed for the perplexity warning

    # Accept both torch tensors (possibly on GPU / requiring grad) and
    # plain arrays, so already-converted embeddings can be plotted too.
    if torch.is_tensor(embeddings):
        embeddings_np = embeddings.detach().cpu().numpy()
    else:
        embeddings_np = np.asarray(embeddings)

    if embeddings_np.ndim != 2:
        raise ValueError(
            f"embeddings must be 2-D [N, D], got array with shape {embeddings_np.shape}"
        )

    n_samples = embeddings_np.shape[0]
    if n_samples < 2:
        raise ValueError(
            f"t-SNE needs at least 2 samples to embed, got {n_samples}"
        )

    # TSNE rejects perplexity >= n_samples; clamp so small validation sets
    # still render with the default perplexity instead of raising.
    max_perplexity = n_samples - 1
    if perplexity > max_perplexity:
        warnings.warn(
            f"perplexity={perplexity} is not smaller than n_samples={n_samples}; "
            f"using perplexity={max_perplexity} instead.",
            stacklevel=2,
        )
        perplexity = max_perplexity

    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=random_state, init='pca')
    reduced = tsne.fit_transform(embeddings_np)

    fig, ax = plt.subplots(figsize=(10, 8))

    if clusters is not None:
        ax.scatter(reduced[:, 0], reduced[:, 1], c=clusters, cmap='tab10', alpha=0.6)
    else:
        ax.scatter(reduced[:, 0], reduced[:, 1], alpha=0.6)

    if annotate and sk_ids is not None:
        # Label only every `step`-th point so the plot stays readable.
        step = max(1, len(sk_ids) // 100)
        for i in range(0, len(sk_ids), step):
            ax.text(reduced[i, 0], reduced[i, 1], str(sk_ids[i]), fontsize=6, alpha=0.5)

    ax.set_title(title)
    ax.set_xlabel("t-SNE Dimension 1")
    ax.set_ylabel("t-SNE Dimension 2")
    ax.grid(True)
    fig.tight_layout()
    plt.show()