In [1]:
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import distance

In [2]:
import gensim.downloader

# Print the name of every model listed in the gensim-data catalogue.
# Iterating the 'models' mapping yields its keys, i.e. the model names.
print(list(gensim.downloader.info()['models']))

# Load the 'glove-wiki-gigaword-300' model through gensim's downloader.
# NOTE(review): the cells visible below read a local GloVe text file instead
# and never touch `glove_vectors` - confirm this load is still needed.
glove_vectors = gensim.downloader.load('glove-wiki-gigaword-300')

['fasttext-wiki-news-subwords-300', 'conceptnet-numberbatch-17-06-300', 'word2vec-ruscorpora-300', 'word2vec-google-news-300', 'glove-wiki-gigaword-50', 'glove-wiki-gigaword-100', 'glove-wiki-gigaword-200', 'glove-wiki-gigaword-300', 'glove-twitter-25', 'glove-twitter-50', 'glove-twitter-100', 'glove-twitter-200', '__testing_word2vec-matrix-synopsis']


In [5]:
# Hand-picked example words for each category label.
data = {
    'animal': ['dog', 'elephant', 'snake', 'pig', 'cow', 'fish', 'cat'],
    'meal': ['steak', 'kebab', 'pork', 'salad', 'tomato', 'onion'],
    'vehicle': ['car', 'motorcycle', 'bike', 'plane', 'skateboard', 'helicopter', 'bicycle'],
    'device': ['computer', 'keyboard', 'monitor', 'cpu', 'tv', 'phone']
}
# Inverted lookup: every example word points back to its category label.
# A word listed under two categories would keep the later one.
categories = dict(
    (member, label)
    for label, members in data.items()
    for member in members
)

# Load the whole embedding matrix into {word: float32 vector}.
# Each non-empty line is whitespace-separated: the first token is the word,
# the remaining tokens are the vector coefficients.
# NOTE(review): this reads 'glove.6B.100d.txt', while cell In [2] loads the
# 'glove-wiki-gigaword-300' gensim model into `glove_vectors`, which the
# visible cells never use - confirm which source is intended.
embeddings_index = {}
with open('glove.6B.100d.txt', encoding="utf8") as f:
  for line in f:
    values = line.split()
    if not values:
      # A blank line (e.g. at the end of the file) has no word token;
      # values[0] would raise IndexError, so skip it.
      continue
    word = values[0]
    embed = np.array(values[1:], dtype=np.float32)
    embeddings_index[word] = embed
print('Loaded %s word vectors.' % len(embeddings_index))
# Embeddings for available words: keep only the words that belong to a
# category. Iterating embeddings_index (not categories) keeps the file's word
# order, which determines the key order and summation order in process().
data_embeddings = {key: value for key, value in embeddings_index.items() if key in categories}

# Processing the query
def process(query):
  """Score each category by its mean dot-product similarity to `query`.

  The query's vector is looked up in `embeddings_index` and compared, via a
  raw (un-normalised) dot product, with every vector in `data_embeddings`.
  The similarities are averaged per category.

  Args:
    query: word to classify; must be a key of `embeddings_index`.

  Returns:
    dict mapping category label -> mean dot product between the query vector
    and the vectors of that category's words. Categories appear in the order
    their first word occurs in `data_embeddings`. A category with no word in
    `data_embeddings` is absent from the result.

  Raises:
    KeyError: if `query` has no entry in `embeddings_index`.
  """
  query_embed = embeddings_index[query]

  # Number of words per category that actually have an embedding. Dividing by
  # the declared list length instead would bias the mean downwards whenever a
  # listed word was missing from the vocabulary and dropped from
  # data_embeddings.
  available = {}
  for word in data_embeddings:
    category = categories[word]
    available[category] = available.get(category, 0) + 1

  scores = {}
  for word, embed in data_embeddings.items():
    category = categories[word]
    similarity = query_embed.dot(embed)
    # Divide each term before summing: keeps the arithmetic identical to the
    # per-word normalisation used when every listed word has an embedding.
    similarity /= available[category]
    scores[category] = scores.get(category, 0) + similarity
  return scores

# Smoke test: score a few query words that do not appear in `data` and print
# the per-category scores for each, in this order.
for query in ('bird', 'burger', 'airplane', 'microphone'):
  print(process(query))

Loaded 400000 word vectors.
{'vehicle': 6.22577667236328, 'device': 5.250414292017618, 'animal': 18.646080017089847, 'meal': 6.7603126764297485}
{'vehicle': 1.394062876701355, 'device': 1.9254615505536397, 'animal': 7.955170495169503, 'meal': 12.270964622497559}
{'vehicle': 16.49170017242432, 'device': 8.591289440790812, 'animal': 6.5290389742170065, 'meal': -0.005848323305447906}
{'vehicle': 6.872500487736293, 'device': 13.15525778134664, 'animal': 4.632625784192767, 'meal': 1.0037638545036316}
