In [None]:
import pandas as pd
import re
from os import listdir, makedirs
from os.path import isfile, join, exists
import gc
import gensim
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
import logging
from constants import *
from itertools import islice

import multiprocessing
from gensim.models import Word2Vec
import logging
import argparse

from gensim.models.callbacks import CallbackAny2Vec

# Optional command-line interface for running this notebook as a script.
# If re-enabled, the hard-coded values below should come from ``args``;
# note that argparse yields strings (e.g. ``args.retrain == 'True'``).
#parser = argparse.ArgumentParser()
#parser.add_argument("retrain")     # 'True' to continue training model_file
#parser.add_argument("cores")       # value passed to Word2Vec(workers=...)
#parser.add_argument("epochs")      # number of training epochs
#parser.add_argument("input_dir")
#parser.add_argument("output_dir")
#parser.add_argument("model_dir")
#args = parser.parse_args()

retrain = False  # True: load model_file and keep training it
cores = 7        # passed to Word2Vec(workers=...)
epochs = 5       # number of training epochs

print('cpu count:', multiprocessing.cpu_count())
print('worker count:', cores)

# Directory of pickled sentence shards, and where models/checkpoints go.
input_dir = join(SMPL_PATH, 'dewiki/cache')
out_dir = join(ETL_PATH, 'NETL/trained_models/w2v_lemma')

# Only ``makedirs`` is imported from ``os`` (the module itself is not), so
# call it directly. ``exist_ok`` avoids a separate exists()/create race.
makedirs(out_dir, exist_ok=True)

model_file = join(out_dir, 'w2v_lemma')


class EpochLogger(CallbackAny2Vec):
    """gensim callback that prints a line when each epoch starts and ends.

    Epoch numbers shown to the user start at 1; the counter advances only
    after an epoch has finished.
    """

    def __init__(self):
        self.epoch = 1

    def on_epoch_begin(self, model):
        message = "Epoch #{} start".format(self.epoch)
        print(message)

    def on_epoch_end(self, model):
        message = "Epoch #{} end".format(self.epoch)
        print(message)
        self.epoch = self.epoch + 1


class EpochSaver(CallbackAny2Vec):
    """gensim callback that checkpoints the model after every epoch.

    Each checkpoint is written to ``<path_prefix>_epoch<N>.model``, where
    ``N`` counts finished epochs starting from 0.
    """

    def __init__(self, path_prefix):
        self.path_prefix = path_prefix
        self.epoch = 0

    def on_epoch_end(self, model):
        checkpoint = '{}_epoch{}.model'.format(self.path_prefix, self.epoch)
        model.save(checkpoint)
        self.epoch = self.epoch + 1

# Every regular file in input_dir, in lexicographic order, so each pass over
# the corpus visits the shards in the same sequence.
files = sorted(
    entry for entry in listdir(input_dir)
    if isfile(join(input_dir, entry))
)

class Sentences(object):
    """Restartable iterable over the sentences stored in pickled pandas files.

    This is a class with ``__iter__`` rather than a plain generator because
    the corpus is iterated more than once: ``build_vocab`` and ``train`` both
    consume it.

    Args:
        directory: folder holding the pickles. ``None`` (default) uses the
            module-level ``input_dir``, looked up at iteration time.
        file_names: file names inside ``directory``, read in this order.
            ``None`` (default) uses the module-level ``files``, looked up at
            iteration time.
    """

    def __init__(self, directory=None, file_names=None):
        self.directory = directory
        self.file_names = file_names

    def __iter__(self):
        directory = input_dir if self.directory is None else self.directory
        file_names = files if self.file_names is None else self.file_names
        for name in list(file_names):
            # Release the previous shard before loading the next one.
            gc.collect()
            # NOTE: read_pickle unpickles arbitrary objects; only point this
            # at trusted, locally produced cache files.
            ser = pd.read_pickle(join(directory, name))
            # Each element of the pickled object is yielded as one sentence.
            yield from ser

sentences = Sentences()

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

epoch_logger = EpochLogger()

# Model initialization.
# ``retrain`` is a bool in this notebook but a string when taken from the
# argparse CLI, so accept both forms (comparing a bool to 'True' never matches).
if retrain in (True, 'True'):
    print('load existing model from', model_file)
    model = Word2Vec.load(model_file)
    out_file = model_file + 'retrained{:d}epochs'.format(epochs)
else:
    print('construct new model')
    # No corpus is given to the constructor; the vocabulary is built
    # explicitly below, and callbacks are passed to ``train`` instead.
    model = Word2Vec(
        size=300,
        window=5,
        min_count=20,
        workers=cores,
        sample=0.00001,
        negative=5,
        sg=1,
        iter=epochs,
    )
    model.build_vocab(sentences)
    out_file = model_file

# Checkpoints are named after the final output path. That path is only known
# after the branch above, so the saver must be created here.
epoch_saver = EpochSaver(out_file)

# Model Training
print('retrain {:d} epochs'.format(epochs))
model.train(
    sentences,
    total_examples=model.corpus_count,
    epochs=epochs,
    report_delay=60.0,
    callbacks=[epoch_logger, epoch_saver],
)

print('write model to', out_file)
model.save(out_file)
