In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from sklearn.neighbors import NearestNeighbors

In [3]:
import os

In [4]:
# Directory holding the unzipped glove.6B files. Set the GLOVE_DIR environment
# variable to point elsewhere instead of editing this hardcoded default.
glove_dir = os.environ.get("GLOVE_DIR", "/Users/aleverentz/Desktop/glove.6B")

In [5]:
# Show which embedding files are available. os.listdir returns entries in
# arbitrary order, so sort them to keep this output stable across machines/runs.
for fname in sorted(os.listdir(glove_dir)):
    print(fname)

glove.6B.100d.txt
glove.6B.200d.txt
glove.6B.300d.txt
glove.6B.50d.txt


In [6]:
# Embedding width to load; the glove.6B directory ships 50d, 100d, 200d and 300d files.
dim: int = 50

In [7]:
from tqdm import tqdm

In [8]:
%%time
vocab = []
glove_matrix_rows = []
with open(glove_dir + "/glove.6B.{}d.txt".format(dim)) as f:
    for line in tqdm(f):
        word, *rest = line.rstrip().split()
        vocab.append(word)
        rest = [float(x) for x in rest]
        glove_matrix_rows.append(rest)
print("Constructing matrix")
glove_matrix = np.array(glove_matrix_rows)
del glove_matrix_rows

400000it [00:09, 40998.96it/s]

Constructing matrix
CPU times: user 10 s, sys: 634 ms, total: 10.6 s
Wall time: 10.6 s





In [9]:
print(len(vocab))
print(glove_matrix.shape)
# Invariant: exactly one vocabulary entry per embedding row, and every row is `dim` wide.
# Later cells index `glove_matrix` with positions taken from `vocab.index(...)`.
assert glove_matrix.shape == (len(vocab), dim), \
    "vocab/matrix mismatch: {} tokens vs matrix shape {}".format(len(vocab), glove_matrix.shape)

400000
(400000, 50)


In [10]:
# Nearest-neighbor searcher over the embedding rows. n_neighbors=5 is sklearn's
# default, spelled out here because the neighbor listing below depends on it.
n_obj = NearestNeighbors(n_neighbors=5)

In [11]:
%%time
n_obj.fit(glove_matrix)
None

CPU times: user 3.29 s, sys: 60.1 ms, total: 3.35 s
Wall time: 3.42 s


In [12]:
word = "elegant"
word_index = vocab.index(word)

print("Nearest neighbors of '{}' based on {}-d glove embeddings:".format(word, dim))
# kneighbors expects a 2-D batch of queries, so take a one-row slice of the matrix.
dist, ind = n_obj.kneighbors(glove_matrix[word_index:word_index + 1])
ind = ind.ravel()
# The query word normally shows up among its own neighbors; list only the others.
for neighbor in ind:
    if neighbor == word_index:
        continue
    print(vocab[neighbor])

Nearest neighbors of 'elegant' based on 50-d glove embeddings:
stylish
graceful
elegantly
decor


In [13]:
# Analogies demo: solve "a : b :: c : ?" by looking up the word whose embedding
# is nearest (under n_obj's distance metric) to the vector c + b - a.

words = ["paris", "france", "beijing"]
indices = [vocab.index(w) for w in words]

query_vector = glove_matrix[indices[2], :] + glove_matrix[indices[1], :] - glove_matrix[indices[0], :]
# The query vector often lands closest to one of the input words themselves,
# which is never a meaningful answer. Fetching one extra neighbor per input word
# guarantees at least one candidate outside `words`; take the nearest of those.
cand_dist, cand_index = n_obj.kneighbors([query_vector], n_neighbors=len(indices) + 1)
nn_dist, nn_index = min(
    ((d, i) for d, i in zip(cand_dist[0], cand_index[0]) if i not in indices),
    key=lambda pair: pair[0],
)
# Convert numpy scalars to plain Python numbers. `.item()` replaces
# `np.asscalar`, which was removed in NumPy 1.23.
nn_dist, nn_index = nn_dist.item(), nn_index.item()
nn_word = vocab[nn_index]

words.append(nn_word)
indices.append(nn_index)

print("According to {}-d glove embeddings,".format(dim))
print("{} : {} :: {} : {}".format(*words))
print("Or, via indices,")
print("{} : {} :: {} : {}".format(*indices))
print("Distance from query vector to its nearest neighbor: {}".format(nn_dist))

According to 50-d glove embeddings,
paris : france :: beijing : china
Or, via indices,
1035 : 387 :: 942 : 132
Distance from query vector to its nearest neighbor: 3.038782864600168
