In [1]:
import heapq

import torch
import torchtext.vocab as vocab

In [2]:
# Load the '6B' GloVe release at 100 dimensions. On the first run torchtext
# downloads the archive into .vector_cache (see the ~862MB progress log below),
# so this cell is slow once and cached afterwards.
glove = vocab.GloVe(name='6B', dim=100)
print(f'There are {len(glove.itos)} words in the GloVe embeddings')

.vector_cache\glove.6B.zip: 862MB [07:18, 1.97MB/s]                               
100%|█████████▉| 399999/400000 [00:30<00:00, 13061.55it/s]


There are 400000 words in the GloVe embeddings


In [3]:
# Embedding matrix shape: (number of words, embedding dim) -> (400000, 100) here.
glove.vectors.shape

torch.Size([400000, 100])

In [4]:
glove.itos[:15] # itos: list mapping row index -> word; show the first 15 entries

['the',
 ',',
 '.',
 'of',
 'to',
 'and',
 'in',
 'a',
 '"',
 "'s",
 'for',
 '-',
 'that',
 'on',
 'is']

In [5]:
glove.stoi['the'] # stoi: word -> row index, the inverse of itos ('the' is itos[0])

0

In [6]:
def get_vector(embeddings, word):
    '''
    Look up the embedding row for a single word.

    Args:
        embeddings: object with a `stoi` word -> index mapping and a
            `vectors` tensor indexed by those indices (e.g. `glove`).
        word: the token to look up.

    Returns:
        The row `embeddings.vectors[embeddings.stoi[word]]`.

    Raises:
        AssertionError: if `word` is not in `embeddings.stoi`. This is a
            plain `assert`, so the check is skipped under `python -O`.
    '''
    lookup = embeddings.stoi
    assert word in lookup, f'{word} not in vocab!'
    row = lookup[word]
    return embeddings.vectors[row]

In [8]:
# The 100-dimensional GloVe vector for 'shirt' (row glove.stoi['shirt'] of glove.vectors).
get_vector(glove, 'shirt')

tensor([ 0.1130,  0.5180, -0.6156, -1.0676,  0.0733,  1.3043,  0.4839,  0.2787,
        -0.7293,  0.6816,  0.0643,  0.1023,  0.5013,  0.2144,  0.5447,  0.5980,
         0.5183, -0.4394,  0.9571, -1.2551, -0.0096,  0.2459,  0.2231, -0.4758,
         0.7532,  1.1250, -0.4245, -0.9636, -0.0158, -1.1530,  0.4089,  0.4540,
        -0.0423, -0.2053, -0.0396,  0.6109, -0.6322, -0.1387, -0.0756,  0.3693,
         0.6952, -0.5582,  0.9985, -0.3462, -0.9853, -0.3707, -0.1073,  1.0643,
         0.1181, -0.5740, -0.2617,  0.0103, -0.1591,  0.6767, -0.4428, -1.5283,
        -0.7698,  0.0069,  1.1274,  0.4353, -0.0525,  0.4147, -0.9185,  0.5398,
         0.2016,  0.0181,  0.6732,  0.0315, -0.1037,  0.2568,  0.5192,  0.2458,
        -0.1006, -0.8492,  0.2326,  0.9118,  0.5317,  0.3360,  0.0194,  0.0987,
         0.0801, -1.1883,  0.2410, -0.6147, -0.2439, -0.9996, -0.9850,  0.1256,
         0.0151,  0.9764, -0.0357, -0.1035,  0.6610, -0.5702, -0.1392,  0.0119,
        -0.6207,  0.3472, -0.2152, -0.13

In [9]:
def closest(embeddings, vector, n=6):
    '''
    Find the `n` words whose embeddings are nearest to `vector`.

    Distances are Euclidean (L2), the same metric as ``torch.dist``'s default
    p=2. They are computed for the whole vocabulary in one vectorized
    operation rather than one ``torch.dist`` call per word, so values may
    differ from per-pair ``torch.dist`` only in floating-point rounding.

    Args:
        embeddings: object exposing ``itos`` (list of words), ``stoi``
            (word -> row index) and a 2-D ``vectors`` tensor with one row per
            index, e.g. the ``vocab.GloVe`` instance loaded above.
        vector: query vector; must broadcast against one row of
            ``embeddings.vectors``.
        n: number of neighbours to return (default 6); expected to be >= 0.

    Returns:
        List of up to ``n`` ``(word, distance)`` tuples, closest first, where
        ``distance`` is a 0-d tensor. Ties keep ``itos`` order.
    '''
    # One L2 distance per row. Note: allocates a temporary the size of
    # embeddings.vectors for the broadcast subtraction.
    row_distances = (embeddings.vectors - vector).norm(p=2, dim=1)
    # Plain floats make the selection key cheap; ordering matches the tensors.
    row_keys = row_distances.tolist()
    # Resolve each word's row through stoi, exactly as get_vector does.
    rows = [embeddings.stoi[word] for word in embeddings.itos]
    # heapq.nsmallest is documented as equivalent to sorted(...)[:n] (stable),
    # but only does O(V log n) work instead of a full sort.
    best = heapq.nsmallest(n, range(len(rows)), key=lambda i: row_keys[rows[i]])
    return [(embeddings.itos[i], row_distances[rows[i]]) for i in best]

In [10]:
# Nearest neighbours of 'shirt' by Euclidean (L2) distance; the query word
# itself is always returned first, at distance 0.
closest(glove, get_vector(glove, 'shirt'))

[('shirt', tensor(0.)),
 ('shirts', tensor(3.2476)),
 ('jacket', tensor(3.3803)),
 ('jeans', tensor(3.8244)),
 ('pants', tensor(3.8720)),
 ('sweater', tensor(4.0266))]

In [12]:
def print_tuples(tuples):
    '''
    Print (word, distance) pairs one per line as "(distance) word".

    The distance is shown with four decimal places; it may be a plain float
    or a 0-d tensor such as those returned by `closest`.
    '''
    for entry in tuples:
        word, score = entry[0], entry[1]
        print(f'({score:.4f}) {word}')

In [14]:
# Same nearest-neighbour search, pretty-printed as "(distance) word".
print_tuples(closest(glove, get_vector(glove, 'stupendous'), n=6))

(0.0000) stupendous
(2.5795) marvellous
(2.7539) frightful
(2.8506) stupefying
(2.8561) awe-inspiring
(2.9179) mind-blowing


In [15]:
def analogy(embeddings, w1, w2, w3, n=6):
    '''
    Complete the analogy "w1 is to w2 as w3 is to ?".

    Prints a "[w1 : w2 :: w3 : ?]" header, forms the query vector
    vec(w2) - vec(w1) + vec(w3), and returns up to `n` (word, distance)
    tuples from `closest`, nearest first, with the three input words removed.
    '''
    print(f'\n[{w1} : {w2} :: {w3} : ?]')

    # Lookups happen in the order w2, w1, w3, so the first out-of-vocabulary
    # word in that order is the one reported.
    target = get_vector(embeddings, w2)
    target = target - get_vector(embeddings, w1)
    target = target + get_vector(embeddings, w3)

    # Over-fetch by 3: up to three of the neighbours can be the query words.
    candidates = closest(embeddings, target, n + 3)
    query_words = (w1, w2, w3)
    kept = [pair for pair in candidates if pair[0] not in query_words]
    return kept[:n]

In [16]:
# Analogy by vector arithmetic: night - moon + sun, i.e. "what sun is to night's role".
print_tuples(analogy(glove, 'moon', 'night', 'sun'))


[moon : night :: sun : ?]
(5.7069) morning
(5.7276) afternoon
(5.8023) evening
(6.1410) hours
(6.2797) saturday
(6.3056) sunday


In [17]:
# Query vector: queen - king + man; 'woman' ranks first in the output below.
print_tuples(analogy(glove, 'king', 'queen', 'man'))


[king : queen :: man : ?]
(4.0811) woman
(4.6916) girl
(5.2703) she
(5.2788) teenager
(5.3084) boy
(5.3352) mother


In [18]:
# Query vector: bird - fly + swim. NOTE(review): the output below lists
# 'swimming' and sea animals; the presumably intended answer 'fish' is not in
# the top 6, so this relation is not captured cleanly.
print_tuples(analogy(glove, 'fly', 'bird', 'swim'))


[fly : bird :: swim : ?]
(5.9754) swimming
(6.2409) shark
(6.4822) dolphin
(6.5421) whale
(6.6276) cat
(6.6457) gorilla


In [19]:
# Reverse direction: fly - bird + fish. NOTE(review): top hits ('sail',
# 'catch', ...) do not include the presumably intended answer 'swim'.
print_tuples(analogy(glove, 'bird', 'fly', 'fish'))


[bird : fly :: fish : ?]
(6.0675) sail
(6.2088) catch
(6.2194) bound
(6.3329) safely
(6.3517) eat
(6.3662) loaded


In [20]:
import gc
# NOTE(review): gc.collect() only reclaims objects that are already
# unreachable. The GloVe tensors stay alive while `glove` is bound (and
# IPython's Out[...] history may hold views into glove.vectors, e.g. the
# get_vector result), so `del glove` first if the goal is to free that memory.
# The displayed number is the count of objects the collector found.
gc.collect()

110