## Functionality Summary
In this notebook, we:
* run baseline experiments for doc2vec;
* build the vocabulary over the whole RFA corpus and test on publications belonging to the test split;
* evaluate on the test dataset (train-split metrics are also reported for reference).

In [None]:
import pickle
import os
import sys
import argparse 
import ast
import pandas as pd
import numpy as np
import logging 
#from sklearn.feature_extraction.text import TfidfVectorizer
#from sklearn.metrics.pairwise import linear_kernel
#from gensim import corpora
#from gensim.summarization import bm25
import gensim
import gensim.downloader as gensim_api
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
# from gensim.models.word2vec import Word2Vec
#from gensim.test.utils import get_tmpfile

#local 
import utils_bsl as ut

In [None]:
def _str2bool(value):
    """Convert a command-line flag value to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for any
    non-empty string (including ``"False"``), so textual values are mapped
    explicitly here.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Arguments are declared in the order in which they are used.
parser = argparse.ArgumentParser(description='doc2vec for Grant recommendation')
parser.add_argument('-data_path', type=str, default='newdata/',
                    help='complete path to the training data [default:newdata/]')
parser.add_argument('-load_pretrained', type=_str2bool, default=True,
                    help='whether to load pretrained doc2vec embeddings & corpus & vectorizer [default:True]')
parser.add_argument('-load_path', type=str, default='evalAuto/doc2vec/',
                    help='path where doc2vec embeddings & corpus & vectorizers are saved [default:evalAuto/doc2vec/]')
# doc2vec training parameters
parser.add_argument('-vector_size', type=int, default=200,
                    help='document vector dimension [default:200]')
parser.add_argument('-min_count', type=int, default=2,
                    help='vector minimum count [default:2]')
parser.add_argument('-epochs', type=int, default=100,
                    help='training epochs [default:100]')
parser.add_argument('-workers', type=int, default=4,
                    help='number of workers for model [default:4]')
parser.add_argument('-top', type=int, default=10,
                    help='number of recommendations to take [default:10]')
# Parse an empty list so only the defaults above are used, independent of
# sys.argv (which may hold unrelated flags, e.g. a notebook kernel's own).
args = parser.parse_args([])

In [None]:
_LOG_FORMAT = "%(asctime)-15s %(levelname)-8s %(message)s"


def _setup_logger(load_path):
    """Configure file logging under *load_path* and return the script logger.

    The root logger is configured via ``logging.basicConfig`` (so other
    libraries' ERROR records also reach the logfile). The named script logger
    gets its own formatted FileHandler and does not propagate. Previously it
    both propagated to root and had an extra handler on the same file, so
    every message was written twice; a re-run of the cell added one more copy.
    """
    if load_path:
        # FileHandler cannot create missing directories.
        os.makedirs(load_path, exist_ok=True)
    log_path = os.path.join(load_path, "logfile")
    logging.basicConfig(level=logging.ERROR, filename=log_path, filemode="a+",
                        format=_LOG_FORMAT)

    logger = logging.getLogger('doc2vec for grant')
    logger.propagate = False
    target = os.path.abspath(log_path)
    already_attached = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target
        for h in logger.handlers
    )
    if not already_attached:
        handler = logging.FileHandler(log_path)
        handler.setFormatter(logging.Formatter(_LOG_FORMAT))
        logger.addHandler(handler)
    return logger


def _close_file_handlers(logger):
    """Close and detach every FileHandler attached to *logger*."""
    # Iterate over a copy because handlers are removed while looping.
    for handler in list(logger.handlers):
        if isinstance(handler, logging.FileHandler):
            handler.close()
            logger.removeHandler(handler)


def main(args):
    """Run the doc2vec grant-recommendation baseline end to end.

    Steps:
      * load the data splits from ``args.data_path``;
      * build or load (per ``args.load_pretrained``) the doc2vec model over
        the RFA corpus;
      * score train/valid/test publications against the RFAs;
      * report metrics at k = 1 and 5 for the train and test splits via
        ``ut.print_metrics``.

    Args:
        args: parsed ``argparse.Namespace``. Reads ``data_path``, ``load_path``,
            ``load_pretrained`` and ``top``. The whole namespace is also
            forwarded to ``ut.process_rfa_corpus_d2v``, which presumably reads
            the training options (vector_size, epochs, ...) — confirm in
            utils_bsl.
    """
    seed_val = 1234
    ut.set_seed(seed_val)

    logger = _setup_logger(args.load_path)
    logger.error('doc2vec for grant')

    try:
        # train, valid and test splits
        (rfas, pubs, mix_df,
         train_idx, valid_idx, test_idx,
         train_citation, valid_citation, citation,
         train_mixed, valid_mixed, test_mixed) = ut.load_data(args.data_path)

        model, rfa_ids = ut.process_rfa_corpus_d2v(df=rfas, outpath=args.load_path, args=args,
                                                   load_pretrained=args.load_pretrained)

        # (name prefix used for saved outputs, query indices, mixed dict) per split
        splits = [('train_', train_idx, train_mixed),
                  ('valid_', valid_idx, valid_mixed),
                  ('test_', test_idx, test_mixed)]

        # Score every split first, then rank: same call order as before.
        query_results = {}
        for name, idx, _mixed in splits:
            sims, pmids, _ = ut.process_pub_query_d2v(idx=idx, mix_df=mix_df, pubs=pubs,
                                                      model=model,
                                                      idx_name=name,
                                                      outpath=args.load_path,
                                                      load_pretrained=args.load_pretrained)
            query_results[name] = (sims, pmids)

        # The valid split is still ranked (its outputs are written to
        # outpath) even though only train/test metrics are printed below.
        sim_dicts = {}
        for name, _idx, mixed in splits:
            sims, pmids = query_results[name]
            sim_dicts[name] = ut.sim_recommend_d2v(corpus_ids=rfa_ids,
                                                   sims=sims, query_ids=pmids,
                                                   mix_dict=mixed, mode='strict',
                                                   outpath=args.load_path,
                                                   query_name=name, top=args.top)

        # evaluation on train and test
        logger.error('=======train statistics======')
        ut.print_metrics(citation=train_citation, similarity_dict=sim_dicts['train_'],
                         logger=logger, ks=[1, 5])
        print('=========================================')
        logger.error('=======test statistics======')
        ut.print_metrics(citation=citation, similarity_dict=sim_dicts['test_'],
                         logger=logger, ks=[1, 5])

    except KeyboardInterrupt:
        # `colored` (termcolor) was never imported, so the old handler raised
        # NameError; plain print keeps the same messages without the extra dep.
        print('--' * 70)
        print('Exiting from training early')
    finally:
        # Release log file handles on every exit path, not just on success.
        _close_file_handlers(logger)
        logging.shutdown()

In [None]:
# Script entry point: run the pipeline with the module-level parsed arguments.
if __name__ == "__main__":
    main(args=args)