In [None]:

from sklearn.metrics.pairwise import linear_kernel, cosine_similarity
import pickle
from collections import defaultdict
import numpy as np
import scipy
import os
import time       
import os 

from emb.tf_idf import TFIDF
from utils.find_ranking_citation import CitationRanking
from utils.write_result import WriteResult
from utils.utils import sort_this
from utils.configuraiton import Rec_configuration 
from utils.eval import Metrics
from utils.data_loading import DataLoading

In [None]:
class Master:
    """Encode PubMed articles and cited GEO datasets with TF-IDF and pickle the results.

    ``__init__`` loads article titles/abstracts, GEO titles/summaries and the
    citation data through ``DataLoading``.

    ``process`` does three things:

    * trains TF-IDF on the article texts;
    * encodes every GEO dataset that has a non-empty citation list and some text;
    * writes ``tfidf_vecs``, ``pmids``, ``geo_tfidf_vecs`` and ``geo_ids`` as
      pickles under ``<rec_conf.result_address_tfidf>smallbase/``.
    """

    def __init__(self, rerank = False):
        self.rec_conf = Rec_configuration()
        # TODO (original note): these should be the actual citations.
        self.article_title, self.article_abstract, self.geo_title, \
        self.geo_summary, self.citation_data = DataLoading().get_all_details()
        self.auto_rank = CitationRanking(self.citation_data)
        self.write_res = WriteResult()
        # GEO ids whose citation list is non-empty; the original note records
        # 76,064 of them for this dataset.
        self.geo_test_data = [key for key, cited in self.citation_data.items() if cited != []]
        # TODO: re-rank titles (e.g. with BERT) -- self.rr = ReRankingTitle()
        self.top_threshold = 10  # not used by the methods in this class
        self.rerank = rerank
        self.res_addr = self.rec_conf.result_address_tfidf  # subject to change

    @staticmethod
    def _join_title_text(ids, titles, texts):
        """Return ``{id: title + ' ' + text}`` for every id whose combined text is non-blank.

        Args:
            ids: Iterable of ids. Its iteration order is preserved in the result.
            titles: Mapping from id to title string.
            texts: Mapping from id to abstract/summary string.

        An id missing from either mapping is treated as having empty text there,
        instead of raising ``KeyError``. An id with no text at all is skipped,
        just like one whose text is blank.
        """
        joined = {}
        for key in ids:
            title = titles.get(key, '')
            text = texts.get(key, '')
            if (title + text).strip() != '':
                joined[key] = title + ' ' + text
        return joined

    def _dump(self, obj, name):
        """Pickle ``obj`` to ``<res_addr>smallbase/<name>``, creating the directory first if needed."""
        out_dir = self.res_addr + 'smallbase/'
        os.makedirs(out_dir, exist_ok=True)
        with open(out_dir + name, 'wb') as handle:
            pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def process(self):
        """Train TF-IDF on the articles, encode the cited GEO datasets, and pickle all vectors and ids.

        TODO: loading previously saved ``tfidf_vecs``/``pmids`` is not
        implemented. The former placeholder branch for it could never run,
        because both values were hard-coded to None. The model is therefore
        always rebuilt.
        """
        tf_idf = TFIDF(self.rec_conf.model_path_tfidf)

        print('Building models')
        joined_dict_article = self._join_title_text(
            self.article_title, self.article_title, self.article_abstract)
        pmids, tfidf_vecs = tf_idf.train_TFIDF(joined_dict_article)
        self._dump(tfidf_vecs, 'tfidf_vecs')
        self._dump(pmids, 'pmids')

        # Only GEO datasets that have citations and some actual text are encoded.
        joined_dict_geo = self._join_title_text(
            self.geo_test_data, self.geo_title, self.geo_summary)
        geo_ids, geo_tfidf_vecs = tf_idf.test_TFIDF(joined_dict_geo)
        self._dump(geo_tfidf_vecs, 'geo_tfidf_vecs')
        self._dump(geo_ids, 'geo_ids')

In [None]:
def main():
    """Entry point: run the TF-IDF encoding pipeline once, with re-ranking disabled."""
    Master(rerank=False).process()


if __name__ == '__main__':
    main()

In [None]:
# Reload the vectors and id lists pickled into smallbase/. These are presumably
# the files Master.process() wrote, assuming rec_conf.result_address_tfidf is
# 'results/tf-idf_plain/' -- confirm.
# Only unpickle files this notebook produced: pickle.load can execute code.
res_addr = 'results/tf-idf_plain/smallbase/'

def _read_pickle(name):
    """Unpickle and return the object stored at ``res_addr + name``."""
    with open(res_addr + name, 'rb') as fp:
        return pickle.load(fp)

tfidf_vecs = _read_pickle('tfidf_vecs')
geo_tfidf_vecs = _read_pickle('geo_tfidf_vecs')
pmids = _read_pickle('pmids')
geo_ids = _read_pickle('geo_ids')

In [None]:
# Load the GEO texts and citation data used by the cells below. Also create the
# shared citation ranker and result writer, and set the batch size.
# The article title/abstract dicts are not used below. They are unpacked into
# named throwaways and deleted, instead of being bound to `_`. In IPython, `_`
# holds the last output and stops updating once the user assigns it; the old
# code also kept the abstracts dict alive in `_` for the whole session.
_article_title, _article_abstract, geo_title, geo_summary, citation_data = DataLoading().get_all_details()
del _article_title, _article_abstract
auto_rank = CitationRanking(citation_data)
write_res = WriteResult()
b_size = 64  # GEO rows per cosine_similarity batch (the timing cell uses it too)

In [None]:
# Score every GEO dataset against all article vectors in batches of b_size rows.
# For each batch, save the raw scores and the ranked dicts. Then write each
# GEO ranking to <res_addr><geo_id>.txt and pass it to the citation ranker.
# Ceiling division: the previous `n // b_size + 1` added an empty trailing batch
# whenever n was a multiple of b_size, and cosine_similarity raises ValueError
# for an input with 0 rows.
chunks = (geo_tfidf_vecs.shape[0] + b_size - 1) // b_size
# Start from a fresh ranker so re-running this cell does not add to an earlier
# run's totals. find_citations presumably accumulates counts inside auto_rank --
# confirm in utils.find_ranking_citation.
auto_rank = CitationRanking(citation_data)
for step in range(chunks):
    lo, hi = step * b_size, (step + 1) * b_size
    batch = geo_tfidf_vecs[lo:hi].toarray()
    print(batch.shape)
    # Assumes geo_ids[k] labels row k of geo_tfidf_vecs, as the slicing implies.
    geo_ids_batch = geo_ids[lo:hi]
    similarity_scores = cosine_similarity(batch, tfidf_vecs)
    # Keep the raw score matrix for this batch; reload it later with np.load.
    np.save(res_addr + 'similarity_scores_batch_' + str(step), similarity_scores)
    print(similarity_scores.shape)
    similarity_value_dict = {
        geo_id: list(similarity_scores[i]) for i, geo_id in enumerate(geo_ids_batch)
    }

    new_similarity_dict = dict(sort_this(geo_sim_dict=similarity_value_dict, pmid_ls=pmids))
    batch_path = res_addr + 'new_similarity_dict_batch_' + str(step)
    os.makedirs(os.path.dirname(batch_path), exist_ok=True)
    with open(batch_path, 'wb') as handle:
        pickle.dump(new_similarity_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print('export dict!!!!!!!!')

    for geo_id in geo_ids_batch:
        temp_selected = new_similarity_dict[geo_id]
        txt_path = res_addr + geo_id + '.txt'
        os.makedirs(os.path.dirname(txt_path), exist_ok=True)
        write_res.write(txt_path, temp_selected)
        auto_rank.find_citations(geo_id, list(temp_selected.keys()))  # keys are pmids

good_recs, top1_hits, bad_recs, geo_without_citations = auto_rank.get_values()
print('good geo recommendations = {}, top1 hit geo recommendations = {}, bad geo recommendations = {}, '
      'geo without citations = {}'.format(good_recs, top1_hits, bad_recs, geo_without_citations))

In [None]:
# Merge the per-batch ranking dicts written by the scoring loop into one
# {geo_id: ranking} dict, and pickle it as new_similarity_dict.pickle.
# Batch files are counted from index 0 up to the first missing one. This
# replaces the hard-coded last index (1137), which went stale whenever the
# number of GEO datasets or b_size changed. Caveat: higher-numbered files left
# over from an earlier, larger run would also be merged, so clear res_addr
# before re-scoring.
res_addr = 'results/tf-idf_plain/smallbase/'
new_similarity_dict = {}
n_batches = 0
while os.path.exists(res_addr + 'new_similarity_dict_batch_' + str(n_batches)):
    with open(res_addr + 'new_similarity_dict_batch_' + str(n_batches), 'rb') as fp:
        new_similarity_dict.update(pickle.load(fp))
    n_batches += 1
assert n_batches > 0, 'no new_similarity_dict_batch_* files found in ' + res_addr

with open(res_addr + 'new_similarity_dict.pickle', 'wb') as handle:
    pickle.dump(new_similarity_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Evaluate the merged rankings against the citation data. A fresh Metrics
# object is created for every value, so no state is shared between metrics.
print('MRR:')
print(Metrics(citation_data).calculate_mrr(new_similarity_dict))

print('recall@1, recall@10:')
for k in (1, 10):
    print(Metrics(citation_data).calculate_recall_at_k(new_similarity_dict, k))

print('precision@1, precision@10:')
for k in (1, 10):
    print(Metrics(citation_data).calculate_precision_at_k(new_similarity_dict, k))

# MAP: the original note calls this MAP@10; k is whatever calculate_MAP_at_k
# uses by default -- confirm in utils.eval.
print(Metrics(citation_data).calculate_MAP_at_k(new_similarity_dict))

## 2. Timing: per-batch TF-IDF encoding and retrieval cost

In [None]:
# Join title and summary for every GEO dataset in geo_title (cited or not)
# that has any non-blank text.
joined_dict_geo = {
    gid: geo_title[gid] + ' ' + geo_summary[gid]
    for gid in geo_title
    if (geo_title[gid] + geo_summary[gid]).strip()
}

In [None]:
# Persist the joined GEO texts to ./joined_dict_geo (relative to the working directory).
with open('joined_dict_geo', 'wb') as geo_text_file:
    pickle.dump(joined_dict_geo, geo_text_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# TF-IDF wrapper built from the configured model path; the timing cell uses it.
tfidf_model_path = Rec_configuration().model_path_tfidf
tf_idf = TFIDF(tfidf_model_path)

In [None]:
# Time the online retrieval path for batches of raw GEO texts: TF-IDF
# transform, cosine similarity against every article vector, and ranking with
# sort_this. Only the first N_TIMED_BATCHES batches are timed.
# Batch-local names are used on purpose. The old version overwrote the globals
# geo_ids, geo_tfidf_vecs and new_similarity_dict, which silently corrupted the
# scoring and metrics cells whenever they were re-run afterwards.
N_TIMED_BATCHES = 6  # same as the previous `if step > 4: break` (steps 0-5)
geo_items = list(joined_dict_geo.items())  # materialized once, not once per batch
chunks = (len(geo_items) + b_size - 1) // b_size  # ceiling division: no empty trailing batch
timeit_ls = []
n_timed_items = 0
for step in range(min(chunks, N_TIMED_BATCHES)):
    batch_texts = dict(geo_items[step * b_size:(step + 1) * b_size])
    start_time = time.perf_counter()  # monotonic, high-resolution clock for intervals
    batch_geo_ids, batch_vecs = tf_idf.test_TFIDF(batch_texts)
    batch_scores = cosine_similarity(batch_vecs.toarray(), tfidf_vecs)
    batch_sim_dict = {
        geo_id: list(batch_scores[i]) for i, geo_id in enumerate(batch_geo_ids)
    }
    sort_this(geo_sim_dict=batch_sim_dict, pmid_ls=pmids)
    timeit_ls.append(time.perf_counter() - start_time)
    n_timed_items += len(batch_texts)

if timeit_ls:
    print(np.mean(np.array(timeit_ls)))       # mean seconds per batch
    # Mean seconds per GEO dataset. This divides by the number of items actually
    # timed, so a final partial batch does not skew it (the old mean/b_size did).
    print(np.sum(timeit_ls) / n_timed_items)