In [None]:
# Imports
from os import makedirs
from os.path import join
import joblib
import numpy as np
rng_seed = 399
np.random.seed(rng_seed)
from matplotlib import pyplot as plt
import seaborn as sns
sns.set_theme()
from tqdm.auto import tqdm
import pandas as pd
import gudhi as gd
from gudhi.wasserstein import wasserstein_distance

import plotly.offline as pyo
pyo.init_notebook_mode()

# Directory constants (all relative paths)
topological_data_analysis_data_dir = "data"
root_code_dir = ".."  # parent directory; also appended to sys.path further down for local imports
output_dir = join(root_code_dir, "output")
word2vec_training_dir = join(output_dir, "word2vec_training")  # parent of the word2vec training output directory
word2vec_cluster_analysis_dir = join(output_dir, "word2vec_cluster_analysis")

# Extend sys path for importing custom Python files
import sys
sys.path.append(root_code_dir)

from utils import get_model_checkpoint_filepaths, pairwise_cosine_distances, words_to_vectors
from word_embeddings.word2vec import load_model_training_output
from vis_utils import plot_word_vectors
from topological_data_analysis.tda_utils import plot_persistence_diagram

# Prepare data

In [None]:
# Read the output of the word2vec training run on enwiki
w2v_training_output = load_model_training_output(
    model_training_output_dir=join(
        word2vec_training_dir,
        "word2vec_enwiki_sept_2020_word2phrase",
    ),
    model_name="word2vec",
    dataset_name="enwiki",
)

# Pull out the entries used by the rest of the notebook
last_embedding_weights, words, word_to_int = (
    w2v_training_output[key]
    for key in ("last_embedding_weights", "words", "word_to_int")
)

In [None]:
# Limit the analysis to the first `vocab_size` word indices (0 .. vocab_size - 1)
vocab_size = 1000
vocabulary = [*range(vocab_size)]

# Topological polysemy

In [None]:
def punctured_neighbourhood(
    target_word: str,
    word_to_int: dict,
    word_embeddings: np.ndarray,
    word_embeddings_pairwise_dists: np.ndarray,
    neighbourhood_size: int,
) -> np.ndarray:
    """
    Finds the embeddings of the words closest to a target word, leaving out
    the target word itself (the "punctured" neighbourhood).

    Parameters
    ----------
    target_word : str
        Word whose neighbourhood is computed.
    word_to_int : dict
        Maps a word to its integer index. The index of `target_word` is used as
        a row index into `word_embeddings_pairwise_dists`, so it must refer to the
        same ordering as the rows of `word_embeddings`.
    word_embeddings : np.ndarray
        Word embeddings; row i must be the embedding of the word with index i.
    word_embeddings_pairwise_dists : np.ndarray
        Pairwise distances between the word embeddings; row i holds the distances
        from word i to every word.
    neighbourhood_size : int
        Number of neighbouring words to return.

    Returns
    -------
    neighbouring_word_embeddings : np.ndarray
        Embeddings of the (at most) `neighbourhood_size` nearest words, ordered by
        increasing distance to the target word.
    """
    target_word_idx = word_to_int[target_word]
    neighbourhood_distances = word_embeddings_pairwise_dists[target_word_idx]

    # Rank all words by distance to the target word (stable, so ties keep index order)
    neighbourhood_sorted_indices = np.argsort(neighbourhood_distances, kind="stable")

    # Exclude the target word by index instead of assuming it is ranked first:
    # with a zero-distance tie (duplicate vectors) or floating point noise in the
    # self-distance, skipping position 0 could drop a real neighbour and keep the
    # target word itself.
    neighbourhood_sorted_indices = neighbourhood_sorted_indices[
        neighbourhood_sorted_indices != target_word_idx
    ]

    neighbouring_word_embeddings = word_embeddings[neighbourhood_sorted_indices[:neighbourhood_size]]
    return neighbouring_word_embeddings

In [None]:
def tps(
    target_word: str,
    word_embeddings: np.ndarray,
    words_vocabulary: list,
    word_to_int: dict,
    neighbourhood_size: int
) -> np.ndarray:
    """
    Computes the punctured neighbourhood of a target word, projected onto the
    unit sphere centred at the target word vector.

    The degree zero persistence diagram of the returned points is not computed
    here yet (see the TODO at the end of the function).

    Parameters
    ----------
    target_word : str
        Word to compute the projected punctured neighbourhood of.
    word_embeddings : np.ndarray
        Word embeddings
    words_vocabulary : list
        List of either words (str) or word integer representations (int), indicating
        which part of the vocabulary we want to use.
    word_to_int : dict
        Maps a word to its integer representation.
    neighbourhood_size : int
        Number of neighbouring words in the punctured neighbourhood.

    Returns
    -------
    target_word_punctured_neighbourhood_sphere : np.ndarray
        One row per neighbouring word: the unit-length direction from the
        normalized target word vector to the normalized neighbouring word vector.
    """
    # Create word vectors from given words/vocabulary
    word_vectors = words_to_vectors(
        words_vocabulary=words_vocabulary,
        word_to_int=word_to_int,
        word_embeddings=word_embeddings,
    )

    # Compute pairwise distances between each word vector
    pairwise_word_vector_distances = pairwise_cosine_distances(word_vectors)

    # Normalize every word vector to unit L2-norm. The norm has to be taken per
    # row: np.linalg.norm without an axis returns a single norm of the whole
    # matrix, which only rescales all vectors by the same constant.
    word_vectors_norm = word_vectors / np.linalg.norm(word_vectors, axis=1, keepdims=True)

    # Compute punctured neighbourhood
    # NOTE(review): word_to_int[target_word] is used as a row index into
    # word_vectors, which presumably only lines up when words_vocabulary is
    # list(range(k)) and the target word's index is < k -- confirm against
    # words_to_vectors.
    target_word_punctured_neighbourhood = punctured_neighbourhood(
        target_word=target_word,
        word_to_int=word_to_int,
        word_embeddings=word_vectors_norm,
        word_embeddings_pairwise_dists=pairwise_word_vector_distances,
        neighbourhood_size=neighbourhood_size
    )

    # Project word vectors in punctured neighbourhood to the unit sphere around
    # the target word: (v - w) / ||v - w|| for every neighbour v (float64 result).
    target_word_vector_w = word_vectors_norm[word_to_int[target_word]]
    w_v_diffs = np.asarray(
        target_word_punctured_neighbourhood - target_word_vector_w, dtype=np.float64
    )
    target_word_punctured_neighbourhood_sphere = w_v_diffs / np.linalg.norm(
        w_v_diffs, axis=1, keepdims=True
    )

    # TODO: Compute the degree zero persistence diagram of punctured neighbourhood (projected to the unit sphere)
    return target_word_punctured_neighbourhood_sphere

In [None]:
# Sphere-projected punctured neighbourhood of "summer" (5 neighbouring words),
# computed within the restricted vocabulary
target_word_punctured_neighbourhood_sphere = tps(
    target_word="summer",
    word_embeddings=last_embedding_weights,
    words_vocabulary=vocabulary,
    word_to_int=word_to_int,
    neighbourhood_size=5,
)

In [None]:
# Build a Rips complex on the sphere-projected neighbourhood points
rips_complex = gd.RipsComplex(points=target_word_punctured_neighbourhood_sphere)

In [None]:
# Create the simplex tree of the Rips complex; persistence is computed on it below.
# NOTE(review): degree zero persistence needs the edges (1-simplices) to merge
# components -- confirm that max_dimension=0 still inserts them, otherwise use
# max_dimension=1.
simplex_tree = rips_complex.create_simplex_tree(max_dimension=0)

In [None]:
# Compute the persistence of the simplex tree and plot the result as a persistence diagram.
# Each entry of `barcodes` unpacks as (dimension, (birth, death)) further down.
barcodes = simplex_tree.persistence()
gd.plot_persistence_diagram(barcodes)

In [None]:
# Collect the degree zero persistence diagram as an (n, 2) array of (birth, death) pairs.
# Only dimension-0 features are kept, and bars that never die (death == inf) are
# dropped: such an essential bar has no counterpart in the diagram it is compared
# against below and would make the Wasserstein distance infinite/undefined.
# The reshape keeps the (n, 2) shape even when no finite bars are left.
zero_degree_diagram_points = np.array(
    [
        [birth, death]
        for dim, (birth, death) in barcodes
        if dim == 0 and np.isfinite(death)
    ],
    dtype=float,
).reshape(-1, 2)
zero_degree_diagram_points

In [None]:
# Presumably a stand-in for the empty diagram: an array with the same shape as the
# degree zero diagram in which every (birth, death) pair is (0, 0), i.e. birth == death.
empty_degree_diagram_points = np.zeros(zero_degree_diagram_points.shape)
empty_degree_diagram_points

In [None]:
# Wasserstein distance between the degree zero diagram and the all-zero diagram
wasserstein_distance(X=zero_degree_diagram_points, Y=empty_degree_diagram_points)