In [None]:
from sklearn.metrics.pairwise import linear_kernel, cosine_similarity
import pickle
from collections import defaultdict
import numpy as np
import scipy
import os
import time       
import os 

from emb.doc2vec import Doc2Vec
from utils.find_ranking_citation import CitationRanking
from utils.write_result import WriteResult
from utils.utils import sort_this
from utils.configuraiton import Rec_configuration 
from utils.eval import Metrics
from utils.data_loading import DataLoading

In [None]:
class Master:
    """Prepare the Doc2Vec model and the GEO query texts for the recommendation run.

    Construction loads the article texts, GEO texts and citation data through
    ``DataLoading().get_all_details()`` and creates the helper objects.
    ``process`` then trains Doc2Vec on the article texts and pickles the GEO
    texts that are used as queries later in the notebook.
    """

    def __init__(self, rerank = False):
        """Load all data and resolve the model / result locations.

        Args:
            rerank: stored as ``self.rerank``; not read anywhere else in this class.
        """
        self.rec_conf = Rec_configuration()
        (self.article_title, self.article_abstract, self.geo_title,
         self.geo_summary, self.citation_data) = DataLoading().get_all_details()
        self.auto_rank = CitationRanking(self.citation_data)
        self.write_res = WriteResult()
        # Keep only the GEO ids whose citation entry is not an empty list
        # (original note: 76,064 ids).
        self.geo_test_data = [
            geo_id for geo_id, cited in self.citation_data.items() if cited != []
        ]
        # TODO: title re-ranking (BERT based) is not wired in yet.
        self.top_threshold = 10  # not read anywhere else in this class
        self.rerank = rerank
        self.model_addr = self.rec_conf.model_path_doc2vec  # subject to change
        self.res_addr = self.rec_conf.result_address_doc2vec  # subject to change

    @staticmethod
    def _join_texts(ids, titles, bodies):
        """Return ``{id: title + ' ' + body}`` for every id whose combined text is not blank.

        A missing (or ``None``) title/body counts as an empty string, so an id
        that lacks one of the two parts no longer raises ``KeyError``.
        """
        joined = {}
        for item_id in ids:
            title = titles.get(item_id) or ''
            body = bodies.get(item_id) or ''
            if (title + body).strip() != '':
                joined[item_id] = title + ' ' + body
        return joined

    def process(self):
        """Train Doc2Vec on the article texts and pickle the GEO query texts.

        Side effects:
            * calls ``Doc2Vec(self.model_addr).training`` with the joined
              title/abstract text of every article that has non-blank text;
            * pickles the joined title/summary text of every id in
              ``self.geo_test_data`` that has non-blank text to
              ``self.res_addr + 'smallbase/joined_dict_geo'``, creating the
              target directory when it does not exist yet.
        """
        d2v = Doc2Vec(self.model_addr)
        # Placeholder: nothing is loaded from disk yet, so the "load" branch
        # below is never taken and the model is always rebuilt.
        pmids, d2v_vecs = None, None

        if d2v_vecs is not None and pmids is not None:
            print('load trained models and encoded publications vectors')
        else:
            print('Building models')
            # Iterating the title mapping yields the pmids to encode.
            joined_dict_article = self._join_texts(
                self.article_title, self.article_title, self.article_abstract
            )
            d2v.training(joined_dict_article)

        # Only GEO ids that actually carry text are kept (maximum: 76,064).
        joined_dict_geo = self._join_texts(
            self.geo_test_data, self.geo_title, self.geo_summary
        )

        out_path = self.res_addr + 'smallbase/joined_dict_geo'
        out_dir = os.path.dirname(out_path)
        if out_dir:
            # Without this the first run fails with FileNotFoundError.
            os.makedirs(out_dir, exist_ok=True)
        with open(out_path, 'wb') as handle:
            pickle.dump(joined_dict_geo, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [None]:
def main():
    """Build a ``Master`` with re-ranking switched off and run its ``process`` step."""
    Master(rerank=False).process()


# Entry point: only fires when this code is the top-level module.
if __name__ == '__main__':
    main()

In [None]:
# Locations used by the rest of the notebook: the Doc2Vec model folder and the
# folder holding this run's cached results.
model_addr = 'resources/doc2vec/'
res_addr = 'results/doc2vec_plain/smallbase/'

# Restore the pickled GEO text mapping from the results folder.
# NOTE(review): presumably the file written by Master.process(); that holds only
# if rec_conf.result_address_doc2vec is 'results/doc2vec_plain/' - confirm.
# pickle.load can execute arbitrary code, so only load files this pipeline produced.
with open('{}joined_dict_geo'.format(res_addr), 'rb') as fp:
    joined_dict_geo = pickle.load(fp)

In [None]:
# Rebuild the citation data and helper objects at notebook scope: the Master
# instance created inside main() is a local there and cannot be reused here.
#do numpy only 
# get_all_details() is unpacked into five values; only the last one (the
# citation data) is kept.
# NOTE(review): binding `_` in a notebook shadows IPython's last-output
# variable - consider indexing or descriptive throwaway names instead.
_, _, _, _, citation_data = DataLoading().get_all_details() # TODO (original note): we should get actual citations
auto_rank = CitationRanking(citation_data)  # used by the ranking loop via find_citations() / get_values()
write_res = WriteResult()  # used by the ranking loop to write one result file per query
# Leftover settings from an earlier (TF-IDF) variant, kept commented out:
#b_size = 64
#tfidf_vecs = tfidf_vecs.toarray()

In [None]:
# Instantiate the project's Doc2Vec wrapper on the model folder and call its
# load_model().
# NOTE(review): presumably this picks up the model trained by Master.process();
# that holds only if rec_conf.model_path_doc2vec equals model_addr - confirm.
d2v = Doc2Vec(model_addr)
d2v.load_model()

In [None]:
# Ids of all GEO entries that have query text, in the mapping's iteration order.
geo_ids = [geo_id for geo_id in joined_dict_geo]

In [None]:
# Rank publications for every GEO query, write each ranking to its own text
# file, tally citation hits, and finally pickle all rankings together.
filename = res_addr + 'new_similarity_dict.pickle'
# Create the output folder *before* the loop: the per-query .txt files below
# are written into the same folder as the final pickle.
output_dir = os.path.dirname(filename)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

final_similarity_dict = {}  # geo_id -> result of d2v.similar_vec for that query
for geo_id, query_doc in joined_dict_geo.items():
    query_vec = d2v.train_new_vec(query_doc)
    new_similarity_dict = d2v.similar_vec(query_vec)
    final_similarity_dict[geo_id] = new_similarity_dict
    print(geo_id)  # progress indicator: one line per processed query
    write_res.write(res_addr + geo_id + '.txt', new_similarity_dict)
    # The keys of the similarity result are passed on as the ranked candidates.
    auto_rank.find_citations(geo_id, list(new_similarity_dict.keys()))

# The four counters are reported in the order given by the message below.
a, b, c, d = auto_rank.get_values() #need to modify this
print('good geo recommendations = {}, top1 hit geo recommendations = {}, bad geo recommendations = {}, '
      'geo without citations = {}'.format(a, b, c, d))
# TODO (original note): we need many more metrics than just MRR.
with open(filename, 'wb') as handle:
    pickle.dump(final_similarity_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [None]:
# Reload the rankings pickled by the ranking cell.
# NOTE: from here on `new_similarity_dict` is the full mapping that was dumped
# from `final_similarity_dict`, not the result of a single query.
# pickle.load can execute arbitrary code, so only load files this notebook produced.
with open(res_addr +'new_similarity_dict.pickle', 'rb') as fp:
    new_similarity_dict = pickle.load(fp)

In [None]:
# Report the evaluation metrics for the reloaded rankings. Each value is
# computed on a fresh Metrics(citation_data) instance, as in the original cell.

# Mean reciprocal rank
print('MRR:')
print(Metrics(citation_data).calculate_mrr(new_similarity_dict))

# Recall at cut-offs 1 and 10
print('recall@1, recall@10:')
print(Metrics(citation_data).calculate_recall_at_k(new_similarity_dict, 1))
print(Metrics(citation_data).calculate_recall_at_k(new_similarity_dict, 10))

# Precision at cut-offs 1 and 10
print('precision@1, precision@10:')
print(Metrics(citation_data).calculate_precision_at_k(new_similarity_dict, 1))
print(Metrics(citation_data).calculate_precision_at_k(new_similarity_dict, 10))

# Mean average precision. The label was missing, so the value used to read as
# part of the precision group. No k is passed, so the cut-off is whatever
# calculate_MAP_at_k defaults to - TODO confirm it is 10, as the original
# "#MAP@10" comment implies.
print('MAP@10:')
print(Metrics(citation_data).calculate_MAP_at_k(new_similarity_dict))