In [1]:
import numpy as np
from flyvec import FlyVec

In [2]:
# Load a FlyVec model (load() is called with no arguments). The functions
# defined in later cells read this object as the global `model`.
model = FlyVec.load()

In [4]:
def sim(a, b):
    """Score how closely two equally-shaped arrays agree.

    The score is the number of element-wise matches between ``a`` and ``b``
    divided by the length of the first axis of ``a``. For 1-D inputs this is
    the fraction of positions at which the two vectors hold the same value.

    Args:
        a: Array-like exposing ``shape`` and supporting element-wise ``==``.
        b: Array-like with the same ``shape`` as ``a``.

    Returns:
        Match count divided by ``a.shape[0]``.

    Raises:
        AssertionError: If the two shapes differ.
    """
    assert a.shape == b.shape
    agreement = (a == b)
    n_matching = np.sum(agreement)
    first_axis_len = a.shape[0]
    return n_matching / first_axis_len

In [5]:
def get_most_similar_word(w):
    """Print the vocabulary entry whose sparse embedding best matches ``w``.

    Walks ``model.token_vocab``, skipping ``w`` itself and every entry the
    tokenizer turns into an empty list, and scores each remaining entry
    against the embedding of ``w`` with ``sim``. Only a strictly higher score
    replaces the current best, so the first entry to reach the top score is
    the one reported.

    Args:
        w: Query word passed to ``model.get_sparse_embedding``.

    Returns:
        None. The best score and the matching word are printed; if no entry
        qualifies, ``-1`` and an empty string are printed.
    """
    query_vec = model.get_sparse_embedding(w)['embedding']
    best_score, best_word = -1, ""
    for candidate in model.token_vocab:
        # Skip the query itself and anything that tokenizes to nothing.
        if candidate == w or model.tokenizer.tokenize(candidate) == []:
            continue
        candidate_vec = model.get_sparse_embedding(candidate)['embedding']
        score = sim(query_vec, candidate_vec)
        if score > best_score:
            best_score, best_word = score, candidate
    print(best_score, best_word)

In [7]:
def get_top_k_similar_words(w, k):
    """Return the ``k`` vocabulary entries most similar to ``w``.

    Every entry of ``model.token_vocab`` other than ``w`` that the tokenizer
    does not map to an empty list is scored against the sparse embedding of
    ``w`` with ``sim``.

    Args:
        w: Query word passed to ``model.get_sparse_embedding``.
        k: Number of neighbours to return. If ``k`` exceeds the number of
            scored entries, all of them are returned.

    Returns:
        A list of ``(word, similarity)`` tuples sorted by ascending
        similarity, so the best match is the last item. Empty when
        ``k <= 0``.
    """
    # A plain `[-k:]` slice would return the *whole* list for k == 0
    # (because -0 == 0), so non-positive k is handled explicitly.
    if k <= 0:
        return []

    # The query embedding does not change between iterations; compute it once
    # instead of once per vocabulary entry.
    query_embedding = model.get_sparse_embedding(w)['embedding']

    scores = dict()
    for word in model.token_vocab:
        if word != w and model.tokenizer.tokenize(word) != []:
            embedding = model.get_sparse_embedding(word)['embedding']
            scores[word] = sim(query_embedding, embedding)

    ranked = sorted(scores.items(), key=lambda item: item[1])
    return ranked[-k:]

In [8]:
# Five nearest vocabulary entries to "car". The returned list is sorted by
# ascending similarity, so the best match is the last item.
get_top_k_similar_words("car", 5)

[('cab', 0.905),
 ('tractor', 0.905),
 ('vehicle', 0.91),
 ('truck', 0.91),
 ('suv', 0.91)]

In [9]:
# Ten nearest vocabulary entries to "chemicals". The returned list is sorted
# by ascending similarity, so the best match is the last item.
get_top_k_similar_words("chemicals", 10)

[('acid', 0.915),
 ('algae', 0.915),
 ('fluids', 0.915),
 ('poisonous', 0.915),
 ('potassium', 0.915),
 ('synthetic', 0.92),
 ('substances', 0.93),
 ('pesticides', 0.935),
 ('toxins', 0.935),
 ('toxic', 0.945)]

In [10]:
# Ten nearest vocabulary entries to "language". The returned list is sorted
# by ascending similarity, so the best match is the last item.
get_top_k_similar_words("language", 10)

[('descriptions', 0.895),
 ('notation', 0.895),
 ('languages', 0.9),
 ('phrases', 0.9),
 ('metaphors', 0.9),
 ('calculus', 0.9),
 ('translation', 0.905),
 ('vocabulary', 0.905),
 ('grammar', 0.905),
 ('concepts', 0.91)]

In [11]:
# Ten nearest vocabulary entries to "electric". The returned list is sorted
# by ascending similarity, so the best match is the last item.
get_top_k_similar_words("electric", 10)

[('detector', 0.89),
 ('coil', 0.89),
 ('diesel', 0.895),
 ('toyota', 0.895),
 ('hydrogen', 0.895),
 ('motors', 0.9),
 ('powered', 0.9),
 ('tesla', 0.9),
 ('hybrid', 0.905),
 ('solar', 0.905)]

In [12]:
# Ten nearest vocabulary entries to "apple". The returned list is sorted by
# ascending similarity, so the best match is the last item.
get_top_k_similar_words("apple", 10)

[('smartphone', 0.905),
 ('iphone', 0.905),
 ('blackberry', 0.905),
 ('dell', 0.905),
 ('asus', 0.905),
 ('apple’s', 0.91),
 ('motorola', 0.91),
 ('apples', 0.91),
 ('htc', 0.91),
 ('samsung', 0.935)]