In [None]:
# Imports
import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neural_network import MLPClassifier
import pandas as pd

In [None]:
# Set Network Parameters
window_size = 2
dim = 100

In [None]:
# Retrieve Data
ng = fetch_20newsgroups(categories=['rec.sport.baseball']).data[0:50]

In [None]:
# Establish Integer indices for each unique term
cv = CountVectorizer()
ng_vecs = cv.fit_transform(ng)
# Store those indices here
vocab = cv.vocabulary_
# And the reverse mapping
id2word = {v:k for (k,v) in vocab.items()}
# The total unique words, aka size of vocabulary
V = len(vocab)

In [None]:
# Use CountVectorizer to turn our list of documents into a Series of lists of terms
tokenizer = cv.build_tokenizer()
tokenized_docs = pd.Series(ng).map(tokenizer).map(lambda x: [a.lower() for a in x])

In [32]:
# Generate the X data matrix and y vector for MLP
# X: A matrix of one-hot encoded vectors (dimension V) for each center word over all context windows (size 2+2+1=5)
# y: A matrix over all context windows where the outputs are the 4 "labels", aka the indices of the 4 "other" context words
X = []
y = []
# Step thru tokenized document list
for doc in tokenized_docs:
    # For each document, step thru the words
    for index, word in enumerate(doc): 
        # Skip if at the edge of a document (can handle differently)
        if index < 2 or index > (len(doc)-3):
            continue
        # Retrieve the one-hot V-dimensional input vector and add it to X
        one_hot_input = [0]*V 
        one_hot_input[vocab[word]] = 1
        # HACK: Had to do the window cooccurrences separately as MLP won't support multilabel tho it says it does
        X.append(one_hot_input)
        X.append(one_hot_input)
        X.append(one_hot_input)
        X.append(one_hot_input)
        # Retrieve the 4 outputs for the context window and add them to y
        # Same HACK here
        context1 = doc[index-2]
        y.append(vocab[context1])
        context2 = doc[index-1]
        y.append(vocab[context2])
        context3 = doc[index+1]
        y.append(vocab[context3])
        context4 = doc[index+2]
        y.append(vocab[context4])

In [33]:
# Convert to Numpy arrays
X = np.array(X)
y = np.array(y)

In [34]:
# Fit the MLP Classifier
mlp = MLPClassifier(hidden_layer_sizes=(dim,))
mlp.fit(X, y)

MLPClassifier(activation='relu', alpha=0.0001, batch_size='auto', beta_1=0.9,
       beta_2=0.999, early_stopping=False, epsilon=1e-08,
       hidden_layer_sizes=(100,), learning_rate='constant',
       learning_rate_init=0.001, max_iter=200, momentum=0.9,
       nesterovs_momentum=True, power_t=0.5, random_state=None,
       shuffle=True, solver='adam', tol=0.0001, validation_fraction=0.1,
       verbose=False, warm_start=False)

In [38]:
# Here are the word vectors!!
word_vecs = mlp.coefs_[0]

In [54]:
# Look at the closest words to a query using a K-Nearest Neighbors search with cosine
from sklearn.neighbors import NearestNeighbors
nn = NearestNeighbors(metric='cosine', algorithm='brute')
nn.fit(word_vecs)

NearestNeighbors(algorithm='brute', leaf_size=30, metric='cosine',
         metric_params=None, n_jobs=1, n_neighbors=5, p=2, radius=1.0)

In [84]:
# Look at the closest words to a query using a K-Nearest Neighbors search with cosine
def get_similar(query, n=10):
    query_index = vocab[query]
    if query_index:
        dist, index = nn.kneighbors(word_vecs[query_index, :], n_neighbors=n)
        return ([(id2word[i[0]], d[0]) for (d, i) in zip(dist.transpose(), index.transpose())])
    else:
        return "Query not in the dataset!"

In [86]:
# Try it out!
get_similar("bat", 20)



[('bat', 0.0),
 ('pivot', 0.56249454553458655),
 ('estimate', 0.57451018888023431),
 ('hand', 0.57471723085951176),
 ('speed', 0.58302086584558155),
 ('lesat', 0.5952954103121918),
 ('moved', 0.60489928762136724),
 ('remember', 0.60652218493149701),
 ('day', 0.60773438404801505),
 ('shutout', 0.61164096971210746),
 ('result', 0.61224321277217053),
 ('nicely', 0.61524505916367644),
 ('holding', 0.61683997165445881),
 ('timestamps', 0.62353061497667694),
 ('along', 0.6244347863014339),
 ('floor', 0.62646560203163504),
 ('plants', 0.62679990677947361),
 ('keep', 0.63238050219870612),
 ('qualifications', 0.63387099268602454),
 ('absolutely', 0.63415533869202678)]