In [1]:
import colorsys
import csv
import itertools
import logging
import os
import sys

import gensim
from gensim.models.callbacks import LossLogger, LossSetter
from IPython.display import HTML
import matplotlib.pyplot as plt
from matplotlib import animation, rc, cm
from matplotlib.collections import LineCollection
import numpy as np
from numpy.linalg import norm
from numpy import dot
from scipy.spatial.distance import pdist, squareform
from sklearn.manifold import TSNE
import pandas as pd
import plotly.plotly as py
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot

# Render plotly figures inline (connected=True loads plotly.js from the CDN).
init_notebook_mode(connected=True)
logging.basicConfig(level=logging.WARNING)

# Root of the thesis checkout; models/, data/ and emnlp_paper_code/ are resolved
# relative to it below. Set the THESIS_ROOT environment variable to run the
# notebook on another machine; the fallback is the original author's local path.
ROOT = os.environ.get("THESIS_ROOT", "/Users/alext/Documents/Master/Thesis/")

In [2]:
# Alternative model (word2vec baseline):
#   "models/word2vec_baseline/w2v_levy_sg_5_2_A025_a0001_n5_w5_c25000_cosine_OPTsgd"
MODEL_SUBDIR = "models/geometric_emb/w2v_levy_nll_3_2_A05_a001_n5_w5_c5000_poincare_OPTwfullrsgd_burnin1"
model_fn = os.path.join(ROOT, MODEL_SUBDIR)

# Load the trained model and keep a handle on its word-vector container;
# the plotting code below works on this KeyedVectors-like object.
model = gensim.models.Word2Vec.load(model_fn)
emnlp_wv = model.wv

In [3]:
def load_minkowski_vectors(fname):
    """
    Load a space-separated Minkowski word-vector text file into a DataFrame.

    Each line is ``<word> <x_1> ... <x_k>``. The word becomes the index
    (always as ``str``); the coordinates become columns ``1..k``.

    Parameters
    ----------
    fname : str or file-like
        Anything accepted by ``pd.read_csv`` (path or open text buffer).

    Returns
    -------
    pd.DataFrame
        One row per line of the file, indexed by word.
    """
    syn0 = pd.read_csv(
        fname, header=None, sep=' ',
        # Read the word column verbatim. Otherwise all-digit tokens such as "007"
        # may be parsed as integers (and come back as "7" after str()).
        dtype={0: str},
        # Disable NA detection; otherwise pandas maps tokens like "null" and "nan" to np.nan.
        na_values=None, keep_default_na=False,
        # Tokens may contain a double quote; treat it as an ordinary character
        # instead of the start of a quoted field.
        quoting=csv.QUOTE_NONE,
    ).set_index(0)
    syn0.index = syn0.index.map(str)
    return syn0

def lorentz2poincare(X):
    """
    Map points from Lorentz (hyperboloid) coordinates to the Poincaré ball.

    The LAST coordinate is treated as the time-like one, x_0: a point
    ``(x_1, ..., x_d, x_0)`` maps to ``(x_1, ..., x_d) / (x_0 + 1)``.

    Parameters
    ----------
    X : array_like, shape (..., d + 1)
        A single point or any stack of points (e.g. one row per word).

    Returns
    -------
    np.ndarray, shape (..., d)
    """
    X = np.asarray(X)
    # Slicing with `-1:` keeps the trailing axis, so the denominator broadcasts
    # across the spatial coordinates for any number of leading dimensions.
    return X[..., :-1] / (X[..., -1:] + 1)

In [4]:
# Minkowski vectors from the EMNLP paper code's output directory, one row per word.
vectors_pd = load_minkowski_vectors(os.path.join(ROOT, "emnlp_paper_code/outputs/output_0.05_1ep_3D.csv"))

# Reorder rows to match the gensim vocabulary, so that row i belongs to
# emnlp_wv.index2word[i]. get_indexer does this in one vectorized pass. It raises
# if the file contains duplicate words, and returns -1 for words without a vector,
# which are reported explicitly below.
row_idx = vectors_pd.index.get_indexer(emnlp_wv.index2word)
missing_words = [w for w, i in zip(emnlp_wv.index2word, row_idx) if i < 0]
if missing_words:
    raise KeyError("{} vocabulary words have no Minkowski vector, e.g. {}".format(
        len(missing_words), missing_words[:5]))
vectors = vectors_pd.values[row_idx]

In [5]:
# Overwrite the model's vectors with the Poincaré projection of the Lorentz
# vectors; their rows were ordered by emnlp_wv.index2word in the previous cell.
poincare_vectors = lorentz2poincare(vectors)
emnlp_wv.vectors = poincare_vectors

# Sanity check: shapes before/after and the first word in both coordinate systems.
lorentz_first, poincare_first = vectors[0], emnlp_wv.vectors[0]
print(vectors.shape, emnlp_wv.vectors.shape, lorentz_first, poincare_first)

(15984, 3) (15984, 2) [ -8324.29893807 -14884.99431709  17054.53047229] [-0.48807036 -0.87273711]


In [6]:
# Get vocabulary of the Google word analogy benchmark + labels for each section
def read_word_dict(filename=None):
    """
    Read the labeled Google-analogy vocabulary and group words by label.

    Each non-blank line has the form ``<label> <word>``. Blank lines are
    skipped. Words keep their file order within a label; labels appear in
    first-seen order (dict insertion order).

    Parameters
    ----------
    filename : str, optional
        Vocabulary file. Defaults to ``data/google_analogy_vocab_labeled.txt``
        under ``ROOT``.

    Returns
    -------
    dict[str, list[str]]
        label -> words.

    Raises
    ------
    ValueError
        If a non-blank line does not have exactly two whitespace-separated fields.
    """
    if filename is None:
        filename = os.path.join(ROOT, "data/google_analogy_vocab_labeled.txt")

    result_dict = {}
    with open(filename, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines, e.g. a trailing empty line
            if len(fields) != 2:
                raise ValueError("{}:{}: expected '<label> <word>', got {!r}".format(
                    filename, lineno, line.rstrip("\n")))
            label, word = fields
            result_dict.setdefault(label, []).append(word)
    return result_dict

def HSVToRGB(h, s, v):
    """
    Convert an HSV colour (components as expected by ``colorsys``) to an
    ``(r, g, b)`` tuple of ints in 0..255.

    Each channel is scaled by 255 and truncated with ``int`` (not rounded).
    """
    return tuple(int(255 * channel) for channel in colorsys.hsv_to_rgb(h, s, v))

def get_colors(word_dict):
    """
    Give every label (key) of ``word_dict`` its own fully saturated colour.

    Hues are spaced in steps of ``1 / (number of labels + 1)``, starting at 0,
    in the dict's iteration order.

    Returns
    -------
    dict
        label -> ``(r, g, b)`` tuple of ints.
    """
    labels = list(word_dict)
    hue_step = 1.0 / (len(labels) + 1)
    return {label: HSVToRGB(hue_step * i, 1.0, 1.0) for i, label in enumerate(labels)}


# Labels plotted by default. The remaining "gram*" sections of the vocabulary
# (gram1-adjective-to-adverb ... gram9-plural-verbs) can be passed via `labels=`.
DEFAULT_PLOT_LABELS = (
    'capital-world', 'days', 'currency', 'seasons', 'family', 'city-in-state',
    'capital-common-countries', 'digits',
)


def plot_embeddings(wv, word_dict, ratio_words=0.1, labels=None):
    """
    Scatter-plot word embeddings in 2-D, one coloured plotly trace per label.

    Parameters
    ----------
    wv : gensim KeyedVectors-like
        Must provide ``vocab`` and ``word_vec``; only the first two
        coordinates of each vector are plotted.
    word_dict : dict[str, list[str]]
        label -> words, e.g. the output of ``read_word_dict``. Colours are
        assigned over ALL of its labels, so a label keeps its colour no
        matter which subset is plotted.
    ratio_words : float
        Fraction of each label's in-vocabulary words to plot; the first
        ``int(n * ratio_words)`` words are used, in ``word_dict`` order.
    labels : iterable of str, optional
        Labels to plot. Defaults to ``DEFAULT_PLOT_LABELS``.

    Returns
    -------
    The return value of ``plotly.offline.iplot`` (the figure renders inline).
    """
    wanted = set(DEFAULT_PLOT_LABELS if labels is None else labels)
    # Colour from the dict we were given, not from a notebook global.
    colors = get_colors(word_dict)
    traces = []

    for label, words in word_dict.items():
        if label not in wanted:
            continue
        in_vocab = [w for w in words if w in wv.vocab]
        shown = in_vocab[:int(len(in_vocab) * ratio_words)]
        # Look up only the plotted words instead of materialising the whole vocabulary.
        vecs = [wv.word_vec(w) for w in shown]

        traces.append(
            go.Scatter(
                x=[v[0] for v in vecs],
                y=[v[1] for v in vecs],
                text=shown,
                textposition='top right',
                name=label,
                mode="markers+text",
                marker=dict(color="rgb" + str(colors[label]))))

    return iplot(dict(data=traces))

# Load the labeled analogy vocabulary; the cell displays how many labels it has.
labeled_word_dict = read_word_dict()
n_labels = len(labeled_word_dict)
n_labels

17

In [7]:
# Plot every in-vocabulary word of the selected labels on the Poincaré disk.
plot_embeddings(wv=emnlp_wv, word_dict=labeled_word_dict, ratio_words=1)