# Doc2vec (DBOW) on IMDB dataset

Steps:
* Tokenize punctuations as if they are their own words
* Determine the longest review's word count, then pad other reviews so that they are all as long as the longest review

In [1]:
import glob
import re
import sys
import gensim
import logging
from text_tokenizer import tokenize
from bs4 import BeautifulSoup
from gensim.models import Doc2Vec
from gensim.models.doc2vec import LabeledSentence, TaggedDocument

Using gpu device 0: GeForce GTX 1060 6GB (CNMeM is disabled, cuDNN 5105)


In [2]:
# Easily changable settings
# We will only train from training/unlabeled set
text_corpus_files = ['aclImdb/train/pos/*.txt', 'aclImdb/train/neg/*.txt', 'aclImdb/train/unsup/*.txt']
word_vector_dims = 100
model_save_as = "word2vec/d2v-imdb-dbow-{0}d.modelxx".format(word_vector_dims)

In [3]:
processed_texts = []
file_names = []
file_count = 0
for folder_files in text_corpus_files:
    for text_file in glob.glob(folder_files):
        with(open(text_file, 'r')) as f:
            processed_texts.append(tokenize(f.read()))
            file_names.append(text_file)
            file_count += 1
            if file_count % 100 == 0:
                sys.stdout.write('\rLoading text file {0:d}'.format(file_count))
                sys.stdout.flush()
                
max_processed_text_len = len(max(processed_texts, key=len))
print('\nLongest text list: {0:d}'.format(max_processed_text_len))
# for i, text_list in enumerate(processed_texts):
#     processed_texts[i] = pad_text_list(text_list, pad_width=max_processed_text_len)
#     if (i + 1) % 1000 == 0:
#         sys.stdout.write('\rPadding text list {0:d}'.format(i+1))
#         sys.stdout.flush()

Loading text file 75000
Longest text list: 2773


In [4]:
class LabeledReview(object):
    def __init__(self, docs_list, labels_list):
        self.docs_list = docs_list
        self.labels_list = labels_list
        
    def __iter__(self):
        for idx, doc in enumerate(self.docs_list):
            yield TaggedDocument(words=doc, tags=[self.labels_list[idx]])

In [5]:
%timeit
it = LabeledReview(processed_texts, file_names)

model = Doc2Vec(size=word_vector_dims, window=8, min_count=1, workers=4, alpha=0.025, min_alpha=0.025, dm=0)
model.build_vocab(it)

# TIMER
import time
start_time = time.time()
# END TIMER

for epoch in range(10):
    print("Beginning epoch {0:d}".format(epoch+1))
    model.train(it)
    model.alpha -= 0.002
    model.min_alpha = model.alpha
    model.train(it)
    
print("--- %s seconds ---" % (time.time() - start_time))

Beginning epoch 1
Beginning epoch 2
Beginning epoch 3
Beginning epoch 4
Beginning epoch 5
Beginning epoch 6
Beginning epoch 7
Beginning epoch 8
Beginning epoch 9
Beginning epoch 10
--- 1243.59263802 seconds ---


In [8]:
model.save(model_save_as)

# Test loading from file

In [9]:
# test_model = gensim.models.Doc2Vec.load_word2vec_format('word2vec/d2v-padded.bin', binary=True)
test_model = Doc2Vec.load(model_save_as)

In [10]:
test_model.most_similar('robot')

[(u'electri', 0.44531017541885376),
 (u'guttridge', 0.4025043249130249),
 (u'municipalians', 0.392833411693573),
 (u'unreconstructed', 0.39157259464263916),
 (u'bumpkins', 0.3887981176376343),
 (u'tiedied', 0.3848254084587097),
 (u'murad', 0.3837975263595581),
 (u'dobro', 0.38057032227516174),
 (u'verhopven', 0.3724980652332306),
 (u'operable', 0.37155359983444214)]

In [13]:
def infer_vector(text):
    test_model.infer_vector(tokenize(text))

In [14]:
infer_vector('Apple decides to kill ornage')

In [15]:
test_model.infer_vector('Can also accept string but who knows?')

array([-0.17231174, -0.37082675,  0.28180525,  0.03994866,  0.20616618,
       -0.0208779 ,  0.29593143, -0.11548571, -0.05137589, -0.30965227,
        0.31568965, -0.54743063,  0.50441384,  0.39466131, -0.23167916,
       -0.51313728, -0.23646614, -0.37890327, -0.41118169,  0.07404428,
        0.37155417,  0.15627038, -0.00802261, -0.41750944, -0.17853852,
        0.43775645, -0.52300262,  0.14638248,  0.07199747,  0.59722763,
        0.31002092, -0.29298946, -0.1716281 , -0.27529186, -0.10400043,
       -0.27337196,  0.15527037, -0.1680723 , -0.36725488,  0.22694018,
        0.17879798, -0.06368355,  0.44016534, -0.52956432, -0.1154929 ,
       -0.41443637, -0.0083477 , -0.02827242, -0.19422367,  0.10291863,
       -0.2656284 , -0.5024243 ,  0.09802187,  0.36961812,  0.22301677,
        0.26114619, -0.54922807,  0.1528462 , -0.47032312, -0.14636959,
        0.04742416,  0.49284756,  0.0812405 ,  0.18190728, -0.27619174,
       -0.15752394,  0.15506928,  0.64426523,  0.24821791,  0.99