In [1]:
import numpy as np
from sklearn.datasets import fetch_20newsgroups
from collections import Counter
import heapq
import nltk
from nltk.corpus import stopwords
import itertools
%matplotlib inline

In [2]:
import gensim

In [3]:
# Make sure the NLTK stop-word corpus is available locally before it is read below.
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /Users/Roman/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [4]:
# English stop words kept as a set so the per-token membership test during
# preprocessing is a constant-time lookup.
english_stopwords = set(stopwords.words('english'))

In [5]:
# Load the 20 Newsgroups corpus, asking scikit-learn to strip headers, footers
# and quoted replies from each post.
dataset = fetch_20newsgroups(
    remove=('headers', 'footers', 'quotes')
)

No handlers could be found for logger "sklearn.datasets.twenty_newsgroups"


In [6]:
pured_documents = []
for i, doc in enumerate(dataset.data):
    tokens = gensim.utils.lemmatize(doc)
    document = []
    for token in tokens:
        word = token.split('/')[0]
        if word not in english_stopwords:
            document.append(word)
    pured_documents.append(document)    
    if i % 500 == 0:
        print 'Processed: ', i, 'documents from', len(dataset.data)

Processed:  0 documents from 11314
Processed:  500 documents from 11314
Processed:  1000 documents from 11314
Processed:  1500 documents from 11314
Processed:  2000 documents from 11314
Processed:  2500 documents from 11314
Processed:  3000 documents from 11314
Processed:  3500 documents from 11314
Processed:  4000 documents from 11314
Processed:  4500 documents from 11314
Processed:  5000 documents from 11314
Processed:  5500 documents from 11314
Processed:  6000 documents from 11314
Processed:  6500 documents from 11314
Processed:  7000 documents from 11314
Processed:  7500 documents from 11314
Processed:  8000 documents from 11314
Processed:  8500 documents from 11314
Processed:  9000 documents from 11314
Processed:  9500 documents from 11314
Processed:  10000 documents from 11314
Processed:  10500 documents from 11314
Processed:  11000 documents from 11314


In [7]:
# Flatten all cleaned documents into one token stream and wrap it in an
# nltk.Text; NltkHatPlayer reads this module-level object.
text = nltk.Text(list(itertools.chain.from_iterable(pured_documents)))

In [8]:
class NltkHatPlayer(object):
    """Hat player backed by the word-context index of the global ``text``.

    ``explain`` returns the words the NLTK context index considers similar to
    the given word; ``guess`` votes over the explanations of all hint words.
    """

    def __init__(self):
        pass

    def guess(self, words):
        """Guess the hidden word from a list of hint ``words``.

        Every hint contributes its first five similar words as votes. The
        hints themselves are removed from the candidates, and the remaining
        candidates are returned ordered by vote count, most voted first.
        """
        candidates = Counter()
        for word in words:
            candidates.update(self.explain(word)[:5])
        # A hint can never be the answer, so drop hints from the candidates.
        for word in words:
            candidates.pop(word, None)
        return [candidate for candidate, _ in candidates.most_common()]

    def explain(self, word):
        """Return the words the context index of ``text`` lists as similar to ``word``.

        NLTK creates ``text._word_context_index`` lazily inside
        ``Text.similar()``. Reading the attribute directly therefore fails on
        a fresh kernel unless some other cell happened to call
        ``text.similar(...)`` first. The index is built here on demand, with
        the same arguments ``Text.similar()`` uses, and cached on ``text`` so
        it is only constructed once.
        """
        index = getattr(text, '_word_context_index', None)
        if index is None:
            index = nltk.text.ContextIndex(
                text.tokens,
                filter=lambda x: x.isalpha(),
                key=lambda s: s.lower(),
            )
            text._word_context_index = index
        return index.similar_words(word)

    def guess_result(self, position):
        """Feedback hook called with the guess position (or None); unused."""
        pass

    def explanation_result(self, position):
        """Feedback hook called with the guess position (or None); unused."""
        pass

In [45]:
# Smoke test: what does the NLTK context-based player guess from the hint 'bible'?
tmp_nltk = NltkHatPlayer()
tmp_nltk.guess(['bible'] )

['god', 'say', 'damn', 'fair', 'existence']

In [46]:
# Same smoke test with the word2vec-based player.
# NOTE(review): the variable keeps the name `tmp_nltk` although it now holds a
# GensimHatPlayer, and GensimHatPlayer is only defined further down in the
# notebook, so this cell fails on a fresh top-to-bottom run.
tmp_nltk = GensimHatPlayer()
tmp_nltk.guess(['bible'] )

['scripture',
 'interpretation',
 'passage',
 'atheist',
 'verse',
 'biblical',
 'contradiction',
 'christian',
 'atheism',
 'matthew']

In [28]:
# Inspect the explanation list produced for 'hat'.
# NOTE(review): `tmp_nltk` is rebound in an earlier cell, so which player
# answers here depends on the order the cells were executed in.
tmp_nltk.explain('hat')

['bible',
 'christian',
 'face',
 'joseph',
 'shirt',
 'afford',
 'price',
 'topic',
 'range',
 'disappoint',
 'require']

In [9]:
from gensim.models import Word2Vec

In [10]:
# Train word2vec on the cleaned documents. Per the gensim Word2Vec parameters:
# size=100 -> vector dimensionality, window=5 -> context window,
# min_count=5 -> ignore words seen fewer than 5 times, workers=4 -> training threads.
# NOTE(review): no seed is passed and several workers are used, so the trained
# vectors (and every score computed from them) are presumably not reproducible
# between runs -- confirm if exact numbers matter.
model = Word2Vec(pured_documents, size=100, window=5, min_count=5, workers=4)

In [11]:
class GensimHatPlayer(object):
    """Hat player that answers with the nearest neighbours from the global word2vec ``model``."""

    def __init__(self):
        pass

    def guess(self, words):
        """Return the words ``model.most_similar`` ranks closest to all hint ``words`` together."""
        neighbours = model.most_similar(positive=words)
        return [neighbour for neighbour, _similarity in neighbours]

    def explain(self, word):
        """Explain ``word`` with its nearest neighbours (a guess from a single-word hint list)."""
        return self.guess([word])

    def guess_result(self, position):
        """Feedback hook called with the guess position (or None); unused."""
        pass

    def explanation_result(self, position):
        """Feedback hook called with the guess position (or None); unused."""
        pass

In [57]:
class MixHatPlayer(object):
    """Hat player that alternates suggestions from the NLTK and the Gensim player.

    Both underlying players are created once in ``__init__`` and reused for
    every call. Previously they were assigned to throw-away locals in
    ``__init__`` and rebuilt on each ``guess``/``explain`` call.
    """

    def __init__(self):
        self._nltk_player = NltkHatPlayer()
        self._gensim_player = GensimHatPlayer()

    @staticmethod
    def _interleave(first, second):
        """Return ``first[0], second[0], first[1], second[1], ...`` as a list.

        ``zip`` stops at the shorter input, so surplus items of the longer
        list are dropped (same as the original min-length logic).
        """
        merged = []
        for from_first, from_second in zip(first, second):
            merged.append(from_first)
            merged.append(from_second)
        return merged

    def guess(self, words):
        """Guess from hint ``words``: the top five candidates of each player, alternated.

        The NLTK candidate comes first in every pair. A word proposed by both
        players appears twice in the result.
        """
        return self._interleave(
            self._nltk_player.guess(words)[:5],
            self._gensim_player.guess(words)[:5],
        )

    def explain(self, word):
        """Explain ``word`` by alternating both players' explanations, NLTK first."""
        return self._interleave(
            self._nltk_player.explain(word),
            self._gensim_player.explain(word),
        )

    def guess_result(self, position):
        """Feedback hook called with the guess position (or None); unused."""
        pass

    def explanation_result(self, position):
        """Feedback hook called with the guess position (or None); unused."""
        pass

In [12]:
# Vocabulary the game is played over: a sequence for random sampling and a set
# for fast membership tests when filtering explanations.
words_universe = model.vocab.keys()
words_universe_set = set(words_universe)

In [13]:
def calc_score(player1, player2, word):
    """Score one exchange in which ``player1`` explains ``word`` and ``player2`` guesses.

    The explanation is restricted to words from ``words_universe_set`` and
    capped at ten hints. The guesser is then asked ten times, with the first
    1, 2, ..., 10 hints, and each time only its first ten guesses count. A hit
    at position ``p`` adds ``0.9 ** p`` to both scores. The explainer loses
    one point if the hidden word itself appears among the hints. Both players
    are told the hit position, or ``None`` on a miss, after every attempt.

    Returns a tuple ``(exp_score, guess_score)``.
    """
    hints = [hint for hint in player1.explain(word) if hint in words_universe_set]
    hints = hints[:10]

    # Saying the hidden word itself is penalised.
    exp_score = -1.0 if word in hints else 0.
    guess_score = 0.
    if not hints:
        return exp_score, guess_score

    for prefix_len in xrange(1, 11):
        top_guesses = player2.guess(hints[:prefix_len])[:10]
        try:
            rank = top_guesses.index(word)
            player1.explanation_result(rank)
            player2.guess_result(rank)
            reward = 0.9 ** rank
            guess_score += reward
            exp_score += reward
        except ValueError:
            # The word is not among the top guesses for this prefix.
            player1.explanation_result(None)
            player2.guess_result(None)
    return exp_score, guess_score


def play(player1, player2, rounds, seed=42):
    """Play ``rounds`` rounds between two players and return their totals.

    Each round draws one word from ``words_universe`` using a
    ``RandomState(seed)`` generator and scores it twice with ``calc_score``:
    once with ``player1`` explaining to ``player2`` and once with the roles
    swapped.

    Returns ``((p1_exp, p1_guess, p1_total), (p2_exp, p2_guess, p2_total))``
    where ``*_total`` is the sum of the explanation and guess scores.
    """
    random_gen = np.random.RandomState(seed)
    # Convert the vocabulary to an array once; ``choice`` would otherwise
    # redo this conversion on every round.
    universe = np.asarray(list(words_universe))

    player1_exp_score = 0.
    player1_guess_score = 0.
    player2_exp_score = 0.
    player2_guess_score = 0.
    for _ in xrange(rounds):
        word = random_gen.choice(universe)
        player1_exp, player2_guess = calc_score(player1, player2, word)
        player2_exp, player1_guess = calc_score(player2, player1, word)
        # The original also called player1.explain(word) twice here and threw
        # the results away; those calls were dead code and have been removed.

        player1_exp_score += player1_exp
        player2_exp_score += player2_exp
        player1_guess_score += player1_guess
        player2_guess_score += player2_guess

    return (
        (player1_exp_score, player1_guess_score, player1_guess_score + player1_exp_score),
        (player2_exp_score, player2_guess_score, player2_guess_score + player2_exp_score)
    )

In [14]:
%%time
# Baseline: the word2vec player against itself for 1000 rounds.
play(GensimHatPlayer(), GensimHatPlayer(), 1000)

CPU times: user 38.6 s, sys: 239 ms, total: 38.9 s
Wall time: 20 s


((7837.411001571994, 7837.411001571994, 15674.822003143989),
 (7837.411001571994, 7837.411001571994, 15674.822003143989))

In [18]:
# NOTE(review): beyond printing matches, Text.similar() is what makes NLTK build
# text._word_context_index, which NltkHatPlayer.explain reads directly -- so
# this throw-away query presumably has to run before any NLTK-based game below.
text.similar('sdas')

No matches


In [19]:
%%time
# NLTK context player (player 1) against the word2vec player (player 2), 1000 rounds.
play(NltkHatPlayer(), GensimHatPlayer(), 1000)

CPU times: user 33.8 s, sys: 216 ms, total: 34 s
Wall time: 17.1 s


((132.04689594800004, 161.49928565300002, 293.54618160100006),
 (161.49928565300002, 132.04689594800004, 293.54618160100006))

In [21]:
%%time
# NLTK context player against itself, 1000 rounds.
play(NltkHatPlayer(), NltkHatPlayer(), 1000)

CPU times: user 3min 11s, sys: 3.37 s, total: 3min 15s
Wall time: 3min 13s


((2144.5650743320007, 2144.5650743320007, 4289.130148664001),
 (2144.5650743320007, 2144.5650743320007, 4289.130148664001))

In [58]:
# Mixed player (player 1) against the word2vec player (player 2), 1000 rounds.
play(MixHatPlayer(), GensimHatPlayer(), 1000)

((814.5562650130004, 6344.341959767, 7158.89822478),
 (6344.341959767, 814.5562650130004, 7158.89822478))

In [60]:
# Mixed player (player 1) against the NLTK context player (player 2), 1000 rounds.
play(MixHatPlayer(), NltkHatPlayer(), 1000)

((1652.9520942740003, 1688.9459514399991, 3341.8980457139996),
 (1688.9459514399991, 1652.9520942740003, 3341.8980457139996))