[WIP GSOC 2018]: Multistream API, Part 1 #2048

Closed

51 commits
2724812  Add wikipedia parsing script (May 14, 2018)
f893487  Track performance metrics in base_any2vec.py (May 14, 2018)
f03d9e6  reset performance metrics in beginning of epoch (May 14, 2018)
55517fd  add tracking CPU load + benchmarking script (May 15, 2018)
8ae3248  Some bug fixes (May 15, 2018)
29d2dba  prettify logging results in benchmark script (May 15, 2018)
5e47dfa  More prettifying in benchmark script (May 15, 2018)
389293f  add SUM cpu load (May 21, 2018)
b1765e7  remove sent2vec from script (May 21, 2018)
4d50cff  First approach to multistream, only for word2vec right now (May 21, 2018)
48f498c  adapted benchmarking script to multistream (May 21, 2018)
a2a6e4f  fix (May 21, 2018)
b9668ee  fix bench script (May 22, 2018)
2765207  Measure vocabulary building time (May 28, 2018)
d110f26  fix (May 28, 2018)
c9e507f  multiprocessing multistream (May 30, 2018)
44bc8f8  add w2v benchmarking script (May 30, 2018)
99d0fc0  multiprocessinng for scan_vocab (May 30, 2018)
ffd5204  fixes (May 30, 2018)
8a0badd  without progress_per at all (May 31, 2018)
f21b3a2  Merge branch 'develop' into feature/gsoc-multistream-api-1 (Jun 15, 2018)
75cac9d  Merge branch 'feature/gsoc-multistream-api-1' of https://github.com/p… (Jun 15, 2018)
2472b2b  get rid of job_producer, make batches in _worker_loop (Jun 15, 2018)
4e0c103  fix (Jun 15, 2018)
3dd8a64  fix (Jun 15, 2018)
d389847  make cythonlinesentence. not working, but at least compiles now (Jun 20, 2018)
4c1d3a6  add operator>> (Jun 21, 2018)
36882a0  change ifstream to ifstream* (Jun 21, 2018)
37b55f3  fastlinesentence in c++ (Jun 21, 2018)
97f834d  almost working version; works on large files, but one bug is to be fixed (Jun 21, 2018)
944e3dc  remove batch iterator from pyx (Jun 21, 2018)
0081f01  working code (Jun 22, 2018)
fe66246  remove build_vocab changes (Jun 23, 2018)
491a087  approaching to fully nogil cython _worker_loop (Jun 27, 2018)
15e07ae  wrapper fix (Jun 27, 2018)
5cad26b  one more fix (Jun 27, 2018)
495c4dc  more fixes (Jun 27, 2018)
8b29df8  upd (Jun 27, 2018)
2119c3a  try to cythonize batch preparation (Jun 27, 2018)
3506ec9  it compiles (Jun 27, 2018)
62f71ee  prepare batch inside nogil section in a while loop (Jun 28, 2018)
8924af5  compiles (Jun 28, 2018)
53fedfa  some bugfixes (Jun 29, 2018)
c679bc6  add cpu_distribution script (Jun 29, 2018)
921ff38  accept CythonLineSentence into _worker_loop, not filename (Jul 4, 2018)
9e4ed0e  make CythonLineSentence iterable (Jul 4, 2018)
f9ea23b  fix (Jul 4, 2018)
cb8bb71  python iterators without gil (Jul 5, 2018)
6162b50  fix (Jul 5, 2018)
c14fca1  fixes (Jul 5, 2018)
440c6df  last changes (Jul 9, 2018)
223 changes: 140 additions & 83 deletions gensim/models/base_any2vec.py

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions gensim/models/linesentence.cpp
@@ -0,0 +1,23 @@
#include <sstream>
#include <stdexcept>

#include "linesentence.h"


FastLineSentence::FastLineSentence(const std::string& filename) : fs_(filename) { }

std::vector<std::string> FastLineSentence::ReadSentence() {
    if (fs_.eof()) {
        throw std::runtime_error("EOF occurred in C++!");
    }
    std::string line, word;
    std::getline(fs_, line);
    std::vector<std::string> res;

    std::istringstream iss(line);
    while (iss >> word) {
        res.push_back(word);
    }

    return res;
}
15 changes: 15 additions & 0 deletions gensim/models/linesentence.h
@@ -0,0 +1,15 @@
#pragma once

#include <fstream>
#include <sstream>
#include <vector>


class FastLineSentence {
public:
    explicit FastLineSentence(const std::string& filename);

    std::vector<std::string> ReadSentence();

private:
    std::ifstream fs_;
};
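
For readers skimming past the C++, the behavior of FastLineSentence amounts to this pure-Python sketch (illustrative only, not part of the PR; PyLineSentence is a hypothetical name):

# Hypothetical Python equivalent of the C++ FastLineSentence, for illustration.
class PyLineSentence(object):
    def __init__(self, filename):
        self._fin = open(filename)  # the C++ class holds an std::ifstream

    def read_sentence(self):
        line = self._fin.readline()
        if not line:  # plays the role of the fs_.eof() check / runtime_error
            raise EOFError('EOF occurred')
        return line.split()  # whitespace tokenization, like `iss >> word`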
24 changes: 17 additions & 7 deletions gensim/models/word2vec.py
@@ -110,6 +110,7 @@
from copy import deepcopy
from collections import defaultdict
import threading
+import multiprocessing as mp
import itertools
import warnings

@@ -423,7 +424,7 @@ class Word2Vec(BaseWordEmbeddingsModel):

"""

def __init__(self, sentences=None, size=100, alpha=0.025, window=5, min_count=5,
def __init__(self, sentences=None, input_streams=None, size=100, alpha=0.025, window=5, min_count=5,
max_vocab_size=None, sample=1e-3, seed=1, workers=3, min_alpha=0.0001,
sg=0, hs=0, negative=5, cbow_mean=1, hashfxn=hash, iter=5, null_word=0,
trim_rule=None, sorted_vocab=1, batch_words=MAX_WORDS_IN_BATCH, compute_loss=False, callbacks=(),
@@ -528,9 +529,9 @@ def __init__(self, sentences=None, size=100, alpha=0.025, window=5, min_count=5,
        self.trainables = Word2VecTrainables(seed=seed, vector_size=size, hashfxn=hashfxn)

        super(Word2Vec, self).__init__(
-            sentences=sentences, workers=workers, vector_size=size, epochs=iter, callbacks=callbacks,
-            batch_words=batch_words, trim_rule=trim_rule, sg=sg, alpha=alpha, window=window, seed=seed,
-            hs=hs, negative=negative, cbow_mean=cbow_mean, min_alpha=min_alpha, compute_loss=compute_loss,
+            sentences=sentences, input_streams=input_streams, workers=workers, vector_size=size, epochs=iter,
+            callbacks=callbacks, batch_words=batch_words, trim_rule=trim_rule, sg=sg, alpha=alpha, window=window,
+            seed=seed, hs=hs, negative=negative, cbow_mean=cbow_mean, min_alpha=min_alpha, compute_loss=compute_loss,
            fast_version=FAST_VERSION)

    def _do_train_job(self, sentences, alpha, inits):
@@ -555,7 +556,7 @@ def _set_train_params(self, **kwargs):
        self.compute_loss = kwargs['compute_loss']
        self.running_training_loss = 0

-    def train(self, sentences, total_examples=None, total_words=None,
+    def train(self, input_streams, total_examples=None, total_words=None,
              epochs=None, start_alpha=None, end_alpha=None, word_count=0,
              queue_factor=2, report_delay=1.0, compute_loss=False, callbacks=()):
        """Update the model's neural weights from a sequence of sentences (can be a once-only generator stream).
@@ -613,7 +614,7 @@ def train(self, sentences, total_examples=None, total_words=None,
"""

return super(Word2Vec, self).train(
sentences, total_examples=total_examples, total_words=total_words,
input_streams, total_examples=total_examples, total_words=total_words,
epochs=epochs, start_alpha=start_alpha, end_alpha=end_alpha, word_count=word_count,
queue_factor=queue_factor, report_delay=report_delay, compute_loss=compute_loss, callbacks=callbacks)

@@ -1156,8 +1157,17 @@ def __init__(self, max_vocab_size=None, min_count=5, sample=1e-3, sorted_vocab=T
        self.raw_vocab = None
        self.max_final_vocab = max_final_vocab

-    def scan_vocab(self, sentences, progress_per=10000, trim_rule=None):
+    def scan_vocab(self, input_streams, progress_per=10000, trim_rule=None):
        """Do an initial scan of all words appearing in sentences."""
+        from itertools import chain
+        line_sentences = []
+        for st in input_streams:
+            if isinstance(st, string_types):
+                line_sentences.append(LineSentence(st))
+            else:
+                raise TypeError("input_streams must contain file paths (strings) in this prototype")
+        sentences = chain(*line_sentences)
+
        logger.info("collecting all words and their counts")
        sentence_no = -1
        total_words = 0
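
Taken together, the word2vec.py changes let one model be driven by several file-backed streams at once; in this prototype, scan_vocab expects file paths and wraps each in a LineSentence itself. A minimal usage sketch (the corpus file names are made up):

from gensim.models.word2vec import Word2Vec

# Hypothetical shard names; each file holds one whitespace-tokenized sentence per line.
streams = ['corpus_part1.txt', 'corpus_part2.txt']

model = Word2Vec(input_streams=streams, size=100, workers=4, iter=1)
print(model.wv.most_similar('word'))  # assumes 'word' survived min_count trimming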
70 changes: 69 additions & 1 deletion gensim/models/word2vec_inner.pyx
@@ -1,4 +1,6 @@
#!/usr/bin/env cython
+# distutils: language = c++
+# distutils: sources = linesentence.cpp
# cython: boundscheck=False
# cython: wraparound=False
# cython: cdivision=True
@@ -13,7 +15,11 @@ cimport numpy as np

from libc.math cimport exp
from libc.math cimport log
-from libc.string cimport memset
+from libc.string cimport memset, strtok
+from libcpp.string cimport string
+from libcpp.vector cimport vector
+from libcpp cimport bool as bool_t


# scipy <= 0.15
try:
@@ -42,6 +48,68 @@ cdef REAL_t[EXP_TABLE_SIZE] LOG_TABLE
cdef int ONE = 1
cdef REAL_t ONEF = <REAL_t>1.0


cdef extern from "linesentence.h":
    cdef cppclass FastLineSentence:
        FastLineSentence(string&) except +
        vector[string] ReadSentence() nogil except +


@cython.final
cdef class CythonLineSentence:
    cdef FastLineSentence* _thisptr
    cdef public string source
    cdef public int max_sentence_length, max_words_in_batch
    cdef vector[string] buf_data

    def __cinit__(self, source, max_sentence_length=MAX_SENTENCE_LEN):
        self._thisptr = new FastLineSentence(source)

    def __init__(self, source, max_sentence_length=MAX_SENTENCE_LEN):
        self.source = source
        self.max_sentence_length = max_sentence_length  # isn't used in this hacky prototype
        self.max_words_in_batch = MAX_SENTENCE_LEN

    def __dealloc__(self):
        if self._thisptr != NULL:
            del self._thisptr

    cpdef vector[string] read_sentence(self) nogil except *:
        return self._thisptr.ReadSentence()

    cpdef vector[vector[string]] next_batch(self) except *:
        with nogil:
            return self._next_batch()

    cpdef vector[vector[string]] _next_batch(self) nogil except *:
        cdef:
            vector[vector[string]] job_batch
            vector[string] data
            int batch_size = 0
            int data_length = 0

        # Try to read data from previous calls which was not returned.
        if self.buf_data.size() > 0:
            data = self.buf_data
            self.buf_data.clear()
        else:
            data = self.read_sentence()

        data_length = data.size()
        while batch_size + data_length <= self.max_words_in_batch:
            job_batch.push_back(data)
            batch_size += data_length

            # TODO: if it raises an exception, we will not return a batch we read up to this moment
            data = self.read_sentence()
            data_length = data.size()

        # Save the sentence which doesn't fit in this batch, in order to return it later.
        self.buf_data = data

        return job_batch


# for when fblas.sdot returns a double
cdef REAL_t our_dot_double(const int *N, const float *X, const int *incX, const float *Y, const int *incY) nogil:
    return <REAL_t>dsdot(N, X, incX, Y, incY)
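
The batch-packing loop in _next_batch is easier to follow in plain Python. A rough functional equivalent (illustrative only; reader stands in for the C++ ReadSentence, and the caller threads the leftover sentence between calls):

# Hypothetical pure-Python rendering of CythonLineSentence._next_batch.
def next_batch(reader, leftover, max_words_in_batch):
    """Greedily pack whole sentences into a batch of <= max_words_in_batch words.

    reader() returns one tokenized sentence per call; leftover is the sentence
    that overflowed the previous batch (or None).
    """
    job_batch, batch_size = [], 0
    data = leftover if leftover else reader()  # consume the saved sentence first

    while batch_size + len(data) <= max_words_in_batch:
        job_batch.append(data)
        batch_size += len(data)
        data = reader()  # may raise at EOF, same caveat as the TODO above

    # The sentence that did not fit becomes the next call's leftover.
    return job_batch, data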
112 changes: 112 additions & 0 deletions gensim/scripts/benchmark_any2vec_speed.py
@@ -0,0 +1,112 @@
from __future__ import unicode_literals
from __future__ import print_function

import logging
import argparse
import json
import copy
# import yappi
import os
import glob

from gensim.models import base_any2vec
from gensim.models.fasttext import FastText
from gensim.models.word2vec import Word2Vec
from gensim.models.doc2vec import Doc2Vec, TaggedLineDocument
from gensim.models.word2vec import LineSentence


logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

logger = logging.getLogger(__name__)


SUPPORTED_MODELS = {
    'fasttext': FastText,
    'word2vec': Word2Vec,
    'doc2vec': Doc2Vec,
}


def print_results(model_str, results):
    logger.info('----- MODEL "{}" RESULTS -----'.format(model_str).center(50))
    logger.info('\t* Vocab time: {} sec.'.format(results['vocab_time']))
    logger.info('\t* Total epoch time: {} sec.'.format(results['total_time']))
    # logger.info('\t* Avg queue size: {} elems.'.format(results['queue_size']))
    logger.info('\t* Processing speed: {} words/sec'.format(results['words_sec']))
    logger.info('\t* Avg CPU loads: {}'.format(results['cpu_load']))
    logger.info('\t* Sum CPU load: {}'.format(results['cpu_load_sum']))


def benchmark_model(input_streams, model, window, workers, vector_size):
    if model == 'doc2vec':
        kwargs = {
            'input_streams': [TaggedLineDocument(inp) for inp in input_streams]
        }
    else:
        kwargs = {
            'input_streams': [inp for inp in input_streams]  # hack for CythonLineSentence
        }

    kwargs['size'] = vector_size

    if model != 'sent2vec':
        kwargs['window'] = window

    kwargs['workers'] = workers
    kwargs['iter'] = 1

    logger.info('Creating model with kwargs={}'.format(kwargs))

    # Training model for 1 epoch.
    # yappi.start()
    SUPPORTED_MODELS[model](**kwargs)
    # yappi.get_func_stats().print_all()
    # yappi.get_thread_stats().print_all()

    return copy.deepcopy(base_any2vec.PERFORMANCE_METRICS)


def do_benchmarks(input_streams, models_grid, vector_size, workers_grid, windows_grid, label):
    full_report = {}

    for model in models_grid:
        for window in windows_grid:
            for workers in workers_grid:
                model_str = '{}-{}-window-{:02d}-workers-{:02d}-size-{}'.format(label, model, window, workers, vector_size)

                logger.info('Start benchmarking {}.'.format(model_str))
                results = benchmark_model(input_streams, model, window, workers, vector_size)

                print_results(model_str, results)

                full_report[model_str] = results

    logger.info('Benchmarking completed. Here are the results:')
    for model_str in sorted(full_report.keys()):
        print_results(model_str, full_report[model_str])

    fout_name = '{}-report.json'.format(label)
    with open(fout_name, 'w') as fout:
        json.dump(full_report, fout)

    logger.info('Saved metrics report to {}.'.format(fout_name))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GSOC Multistream-API: evaluate performance '
                                                 'metrics for any2vec models')
    parser.add_argument('--input', type=str, help='Input file or regexp if `multistream` mode is on.')
    parser.add_argument('--models-grid', nargs='+', type=str, default=SUPPORTED_MODELS.keys())
    parser.add_argument('--size', type=int, default=300)
    parser.add_argument('--workers-grid', nargs='+', type=int, default=[1, 4, 8, 10, 12, 14])
    parser.add_argument('--windows-grid', nargs='+', type=int, default=[10])
    parser.add_argument('--label', type=str, default='untitled')

    args = parser.parse_args()

    input_ = os.path.expanduser(args.input)
    input_streams = glob.glob(input_)
    logger.info('Glob found {} input streams. List: {}'.format(len(input_streams), input_streams))

    do_benchmarks(input_streams, args.models_grid, args.size, args.workers_grid, args.windows_grid, args.label)
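
A usage sketch of the entry point, bypassing argparse (hypothetical shard names; assumes this branch's instrumented base_any2vec, which exposes PERFORMANCE_METRICS):

# Illustrative only; the wiki shard paths are made up.
from gensim.scripts.benchmark_any2vec_speed import do_benchmarks

do_benchmarks(
    input_streams=['wiki-shard-000.txt', 'wiki-shard-001.txt'],
    models_grid=['word2vec'],
    vector_size=100,
    workers_grid=[1, 4],
    windows_grid=[5],
    label='multistream-test',
)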
38 changes: 38 additions & 0 deletions gensim/scripts/benchmark_w2v_vocab.py
@@ -0,0 +1,38 @@
from __future__ import unicode_literals
from __future__ import print_function

import logging
import argparse
# import yappi
import os
import glob

from gensim.models import base_any2vec
from gensim.models.word2vec import Word2Vec, LineSentence


logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

logger = logging.getLogger(__name__)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GSOC Multistream-API: evaluate vocab performance '
                                                 'for word2vec')
    parser.add_argument('--input', type=str, help='Input file or regexp for multistream.')
    parser.add_argument('--size', type=int, default=300)
    parser.add_argument('--workers-grid', nargs='+', type=int, default=[1, 2, 3, 4, 5, 8, 10, 12, 14])
    parser.add_argument('--label', type=str, default='untitled')

    args = parser.parse_args()

    input_ = os.path.expanduser(args.input)
    input_streams = glob.glob(input_)
    logger.info('Glob found {} input streams. List: {}'.format(len(input_streams), input_streams))

    input_streams = [LineSentence(_) for _ in input_streams]
    for workers in args.workers_grid:
        model = Word2Vec()
        model.build_vocab(input_streams, workers=workers)
        logger.info('Workers = {}\tVocab time = {:.2f} secs'.format(
            workers, base_any2vec.PERFORMANCE_METRICS['vocab_time']))
Contributor review comment:
You can simply measure it like:

    start = time.time()
    model.build_vocab(..)
    vocab_time = time.time() - start

Why do you need some "internal" stuff?

2 changes: 1 addition & 1 deletion wikipedia_to_txt.py → gensim/scripts/wikipedia_to_txt.py
@@ -26,4 +26,4 @@

i += 1

-fout.close()
\ No newline at end of file
+fout.close()
25 changes: 25 additions & 0 deletions gensim/utils.py
@@ -1709,6 +1709,31 @@ def prune_vocab(vocab, min_reduce, trim_rule=None):
    return result


def merge_dicts(dict1, dict2):
    """Merge `dict1` of (word, freq1) and `dict2` of (word, freq2) into `dict1` of (word, freq1+freq2).

    Parameters
    ----------
    dict1 : dict
        First dictionary.
    dict2 : dict
        Second dictionary.

    Returns
    -------
    result : dict
        Merged dictionary with sum of frequencies as values.

    """
    # six's iteritems (imported at module level in gensim.utils) keeps this Python 2/3 compatible.
    for word, freq in iteritems(dict2):
        if word in dict1:
            dict1[word] += freq
        else:
            dict1[word] = freq

    return dict1


def qsize(queue):
    """Get the (approximate) queue size where available.

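
merge_dicts looks like the reduction step for the multiprocess scan_vocab work: each worker counts words for its stream, and the partial counts are folded together. A small example (illustrative only):

from gensim.utils import merge_dicts

counts1 = {'the': 10, 'cat': 2}
counts2 = {'the': 7, 'dog': 3}

merged = merge_dicts(counts1, counts2)
# merged == {'the': 17, 'cat': 2, 'dog': 3}; note that dict1 is updated in place.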
3 changes: 2 additions & 1 deletion setup.py
@@ -250,7 +250,8 @@ def finalize_options(self):

    ext_modules=[
        Extension('gensim.models.word2vec_inner',
-                 sources=['./gensim/models/word2vec_inner.c'],
+                 sources=['./gensim/models/word2vec_inner.cpp', './gensim/models/linesentence.cpp'],
+                 language="c++",
                  include_dirs=[model_dir]),
        Extension('gensim.models.doc2vec_inner',
                  sources=['./gensim/models/doc2vec_inner.c'],