In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

In [None]:
def load_embeddings(epoch, folder="node_embeddings"):
    """
    Load saved validation embeddings, SK_ID_CURR values, and cluster labels for an epoch.

    Reads three files from ``folder``:
    ``val_embeddings_epoch_{epoch}.pt``, ``val_sk_ids_epoch_{epoch}.npy`` and
    ``val_clusters_epoch_{epoch}.npy``.

    Args:
        epoch (int): Epoch number used during saving
        folder (str): Folder where intermediate .pt and .npy files are stored

    Returns:
        torch.Tensor: node embeddings of shape (N, D), placed on CPU
        np.ndarray: corresponding SK_ID_CURR values of shape (N,)
        np.ndarray: corresponding cluster labels of shape (N,)

    Raises:
        FileNotFoundError: if any of the three files is missing.
        ValueError: if the three artifacts do not have the same number of rows,
            e.g. because one file was written by a different run.
    """
    embedding_path = os.path.join(folder, f"val_embeddings_epoch_{epoch}.pt")
    sk_id_path = os.path.join(folder, f"val_sk_ids_epoch_{epoch}.npy")
    cluster_path = os.path.join(folder, f"val_clusters_epoch_{epoch}.npy")

    # map_location='cpu' lets embeddings saved from a GPU run load on a CPU-only
    # machine. NOTE: torch.load unpickles its input, so only load trusted files.
    embeddings = torch.load(embedding_path, map_location='cpu')
    sk_ids = np.load(sk_id_path)
    clusters = np.load(cluster_path)

    # Row i of each artifact must describe the same node; a length mismatch
    # would silently misalign IDs/labels with embeddings downstream.
    n_emb, n_ids, n_clusters = len(embeddings), len(sk_ids), len(clusters)
    if not (n_emb == n_ids == n_clusters):
        raise ValueError(
            f"Mismatched artifact lengths for epoch {epoch} in {folder!r}: "
            f"embeddings={n_emb}, sk_ids={n_ids}, clusters={n_clusters}"
        )

    return embeddings, sk_ids, clusters