In [None]:
# Execute the sibling notebook and merge its top-level names into this namespace.
# NOTE(review): presumably this is where `GensimEpochCallback` (used in the Doc2Vec
# training cell below) is defined -- confirm against ../various/_epoch-callback.ipynb.
%run ./../various/_epoch-callback.ipynb

In [None]:
import os
from pathlib import Path
from tqdm import tqdm
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

In [None]:
# On-disk location of the cached Doc2Vec model, below the processed-data directory.
D2V_MODEL_DIR = '{}/d2v'.format(PROC_DATA_DIR)
D2V_MODEL_PATH = '{}/d2v_dataset.model'.format(D2V_MODEL_DIR)

# Doc2Vec hyper-parameters: paragraph-vector dimensionality (passed as `vector_size`)
# and number of training epochs (passed as `epochs`).
PV_DIM, MODEL_ITERS = 300, 30

In [None]:
# gensim documents that reproducibility across interpreter launches also needs a fixed
# PYTHONHASHSEED. CPython reads that variable only when the interpreter starts, so
# assigning it here reaches child processes but does NOT change hashing in this
# already-running kernel. Warn when it was not fixed beforehand. This check is
# best-effort: os.environ also reflects earlier runs of this very cell.
if RANDOM_SEED is not None:
    hash_seed_at_start = os.environ.get('PYTHONHASHSEED')
    if hash_seed_at_start != str(RANDOM_SEED):
        logger.warning(
            "PYTHONHASHSEED is %r, expected %r; export it before starting the kernel "
            "for fully reproducible Doc2Vec training.",
            hash_seed_at_start, str(RANDOM_SEED),
        )
    os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)

In [None]:
# Tag every document with its positional index, so document vector i belongs to
# row i of df[proc_doc_col].
corpus = [TaggedDocument(doc, [i]) for i, doc in enumerate(df[proc_doc_col])]
# A single worker thread is used when a seed is requested (gensim notes that
# multi-threaded training is not deterministic).
workers = 1 if RANDOM_SEED is not None else 3

if not Path(D2V_MODEL_PATH).is_file():
    logger.info("Training Doc2Vec model...")
    with tqdm(total=MODEL_ITERS, disable=SILENT) as pbar:
        d2v_model = Doc2Vec(min_count=1, epochs=MODEL_ITERS, vector_size=PV_DIM,
                            workers=workers, seed=RANDOM_SEED)
        d2v_model.build_vocab(corpus)

        # End-of-epoch hooks: advance the progress bar and record the running loss.
        pbar_updater = GensimEpochCallback(end_func=pbar.update)
        loss_tracker = GensimEpochCallback(end_func=d2v_model.get_latest_training_loss)

        d2v_model.train(corpus, total_examples=d2v_model.corpus_count,
                        epochs=d2v_model.epochs, compute_loss=True,
                        callbacks=[pbar_updater, loss_tracker])

    # The recorded values are differenced below, i.e. they are treated as cumulative.
    # NOTE(review): gensim documents compute_loss/get_latest_training_loss for
    # Word2Vec; Doc2Vec's training routines reportedly never update the running loss,
    # so these values may all be 0 -- confirm against the installed gensim version.
    logger.debug("Logging training loss for each iteration")
    losses = loss_tracker.end_results
    prev_loss = 0
    for i, loss in enumerate(losses):
        logger.debug("%d: %s", i + 1, loss - prev_loss)
        prev_loss = loss

    Path(D2V_MODEL_DIR).mkdir(parents=True, exist_ok=True)
    logger.info("Storing Doc2Vec model to disk...")
    d2v_model.save(D2V_MODEL_PATH)
else:
    d2v_model = Doc2Vec.load(D2V_MODEL_PATH)
    # The cache is keyed only by file path. A model saved with other hyper-parameters,
    # or for a differently sized corpus, would otherwise be reused silently.
    cached_config_mismatches = {
        attr: (getattr(d2v_model, attr, None), expected)
        for attr, expected in (('vector_size', PV_DIM),
                               ('epochs', MODEL_ITERS),
                               ('corpus_count', len(corpus)))
        if getattr(d2v_model, attr, None) != expected
    }
    if cached_config_mismatches:
        logger.warning(
            "Cached Doc2Vec model %s does not match the current configuration "
            "{attribute: (cached, expected)}: %s -- delete it to retrain.",
            D2V_MODEL_PATH, cached_config_mismatches,
        )