## Functionality Summary
In this notebook, we:
* Run baseline experiments with TF-IDF.
* Build the whole vocabulary on the RFA corpus and test on publications belonging to the test data.
* Report evaluation metrics on the test dataset (metrics on the training split are also logged for reference).

In [1]:
import pickle
import os
import sys
import argparse 
import ast
import pandas as pd
import numpy as np
import logging 
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import linear_kernel

#local 
import utils_bsl as ut

In [2]:
# Command-line style configuration, declared in the order the options are used.
def _str2bool(value):
    """Convert a command-line string such as 'True' / 'false' / '0' to a bool.

    argparse's ``type=bool`` simply calls ``bool(value)``, which is True for
    every non-empty string -- including 'False'. This converter interprets the
    text instead, and rejects values it does not recognise.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching what bool('') returned before.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description = 'TFIDF for Grant recommendation')
parser.add_argument('-data_path', type = str, default = 'newdata/', 
                    help = 'complete path to the training data [default:newdata/]')
parser.add_argument('-load_pretrained', type = _str2bool, default = False,
                    help = 'whether to load pretrained TFIDF embeddings & corpus & vectorizer [default:False]')
parser.add_argument('-load_path', type = str, default = 'evalAuto/tfidf/', 
                    help = 'path where TFIDF embeddings & corpus & vectorizers are saved [default:evalAuto/tfidf/]')
# some training parameters regarding TFIDF
# ngram_range is kept as a string here; main() parses it with ast.literal_eval.
parser.add_argument('-ngram_range', type = str, default = '(1,2)', help = 'see sklearn TFIDF params')
parser.add_argument('-min_df', type = int, default = 2, help = 'see sklearn TFIDF params')
parser.add_argument('-max_features', type = int, default = 2000, help = 'see sklearn TFIDF params')
parser.add_argument('-top', type = int, default = 10, 
                    help = 'number of recommendations to take [default:10]')
# An empty argv is passed so that, inside a notebook, the kernel's own
# command-line arguments are ignored and only the defaults above are used.
args = parser.parse_args([])

In [3]:
def main(args):
    """Run the TF-IDF baseline end to end and log retrieval metrics.

    Steps, in order: seed via ``ut.set_seed``; configure file logging under
    ``args.load_path``; build a ``TfidfVectorizer``; load the data with
    ``ut.load_data``; vectorise the RFA corpus and the train/valid/test
    publication queries; produce recommendations with ``ut.sim_recommend``;
    report metrics for the train and test splits with ``ut.print_metrics``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``data_path``, ``load_path``, ``load_pretrained``,
        ``ngram_range`` (a string holding a Python tuple literal),
        ``min_df``, ``max_features`` and ``top``.

    Notes
    -----
    A ``KeyboardInterrupt`` raised during the run is caught and reported on
    stdout; any other exception propagates. Logging handlers are released in
    both cases.
    """
    seed_val = 1234
    ut.set_seed(seed_val)

    # get logger started
    log_path = args.load_path + "logfile"
    # Make sure the target directory exists, otherwise opening the logfile
    # fails on a fresh checkout. Skipped when load_path has no directory part.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(level=logging.ERROR, filename=log_path, filemode="a+",
                        format="%(asctime)-15s %(levelname)-8s %(message)s")
    logger = logging.getLogger('TFIDF for grant')
    # NOTE(review): the root logger (basicConfig above) and this handler both
    # write to the same file, so each message appears twice -- once formatted,
    # once bare. Kept as-is to preserve the existing logfile layout; confirm
    # whether that is intended.
    handler = logging.FileHandler(log_path)
    logger.addHandler(handler)
    logger.error('TFIDF for grant')

    try:
        # model
        tfidf = TfidfVectorizer(ngram_range=ast.literal_eval(args.ngram_range),
                                min_df=args.min_df,
                                max_features=args.max_features)
        # train, valid and test
        (rfas, pubs, mix_df,
         train_idx, valid_idx, test_idx,
         train_citation, valid_citation, citation,
         train_mixed, valid_mixed, test_mixed) = ut.load_data(args.data_path)
        # process_rfa_corpus returns the vectorizer again; that returned
        # instance is the one used for all publication queries below.
        rfa_tfidf, rfa_ids, tfidf = ut.process_rfa_corpus(df=rfas, vectorizer=tfidf,
                                                          outpath=args.load_path,
                                                          load_pretrained=args.load_pretrained)
        train_pubs_tfidf, train_pmids = ut.process_pub_query(idx=train_idx, mix_df=mix_df, pubs=pubs,
                                                             vectorizer=tfidf, idx_name='train_',
                                                             outpath=args.load_path,
                                                             load_pretrained=args.load_pretrained)
        valid_pubs_tfidf, valid_pmids = ut.process_pub_query(idx=valid_idx, mix_df=mix_df, pubs=pubs,
                                                             vectorizer=tfidf, idx_name='valid_',
                                                             outpath=args.load_path,
                                                             load_pretrained=args.load_pretrained)
        test_pubs_tfidf, test_pmids = ut.process_pub_query(idx=test_idx, mix_df=mix_df, pubs=pubs,
                                                           vectorizer=tfidf, idx_name='test_',
                                                           outpath=args.load_path,
                                                           load_pretrained=args.load_pretrained)
        # prediction
        train_dict = ut.sim_recommend(corpus_vecs=rfa_tfidf, corpus_ids=rfa_ids,
                                      query_vecs=train_pubs_tfidf, query_ids=train_pmids,
                                      mix_dict=train_mixed, mode='strict', outpath=args.load_path,
                                      query_name='train_', top=args.top)
        # The validation result is not evaluated below; the call is kept
        # because sim_recommend receives an outpath and presumably writes
        # its output there -- TODO confirm in utils_bsl.
        valid_dict = ut.sim_recommend(corpus_vecs=rfa_tfidf, corpus_ids=rfa_ids,
                                      query_vecs=valid_pubs_tfidf, query_ids=valid_pmids,
                                      mix_dict=valid_mixed, mode='strict', outpath=args.load_path,
                                      query_name='valid_', top=args.top)
        test_dict = ut.sim_recommend(corpus_vecs=rfa_tfidf, corpus_ids=rfa_ids,
                                     query_vecs=test_pubs_tfidf, query_ids=test_pmids,
                                     mix_dict=test_mixed, mode='strict', outpath=args.load_path,
                                     query_name='test_', top=args.top)
        # evaluation on train and test
        logger.error('=======train statistics======')
        ut.print_metrics(citation=train_citation, similarity_dict=train_dict, logger=logger, ks=[1, 5])
        print('=========================================')
        logger.error('=======test statistics======')
        ut.print_metrics(citation=citation, similarity_dict=test_dict, logger=logger, ks=[1, 5])

    except KeyboardInterrupt:
        # Plain prints: the previous `colored(...)` calls referenced a name
        # that is never imported and raised NameError instead of reporting.
        print('--' * 70)
        print('Exiting from training early')
    finally:
        # Always detach and close our handler so that re-running main() in the
        # same kernel does not stack handlers, and so the file is released even
        # when the run fails.
        logger.removeHandler(handler)
        handler.close()
        logging.shutdown()

In [4]:
# Entry point: run the experiment with the default arguments parsed above,
# but only when this code is executed as the main program (not on import).
if __name__ == "__main__":
    main(args)

MRR:
0.8648628804436893
recall@1, recall@5:
0.6871173657727649
0.8842527913131382
precision@1, precision@5:
0.7454027101009757
0.6933972050900526
MAP:
0.8648628804436893
MRR:
0.8646486758507894
recall@1, recall@5:
0.6878090205699189
0.8853745989809398
precision@1, precision@5:
0.7454991507831666
0.6932515568975278
MAP:
0.8646486758507894
