In [None]:
from sklearn.metrics.pairwise import linear_kernel, cosine_similarity
import pickle
from collections import defaultdict
import numpy as np
import scipy
from scipy.spatial.distance import cdist
import os
import time       
import os 
import operator 

## torch 
from sentence_transformers import SentenceTransformer, models
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

## self defined
from utils.find_ranking_citation import CitationRanking
from utils.write_result import WriteResult
from utils.utils import sort_this
from utils.configuraiton import Rec_configuration 
from utils.eval import Metrics
from utils.data_loading import DataLoadingg

In [None]:
class Master:
    """Encode PubMed articles and GEO datasets with DistilBERT and rank articles per dataset.

    ``process`` encodes article title+abstract and GEO title+summary texts with a
    sentence-transformers DistilBERT model and pickles the vectors and their ids
    under ``<result_address_distilbert>base/``.  ``sorting`` turns per-GEO score rows
    into top-``top_threshold`` recommendations; ``re_rank_bert`` optionally adds a
    title-to-title similarity before sorting.
    """

    def __init__(self, rerank = False):
        """Load all text/citation data and the DistilBERT sentence encoder.

        Args:
            rerank: stored on ``self.rerank``; not read by any method of this class.
        """
        self.rec_conf = Rec_configuration()
        # NOTE(review): the import cell imports ``DataLoadingg`` (double "g") but
        # ``DataLoading`` is referenced here -- one of the two is a typo; confirm
        # against utils/data_loading.py.
        self.article_title, self.article_abstract, self.geo_title, \
        self.geo_summary, self.citation_data = DataLoading().get_all_details()
        self.auto_rank = CitationRanking(self.citation_data)
        self.write_res = WriteResult()
        # Only GEO datasets with at least one known citation are used for testing
        # (original note: 76,064 of them).
        self.geo_test_data = [key for key in self.citation_data if self.citation_data[key] != []]
        self.top_threshold = 10  # number of recommendations kept per GEO dataset
        self.rerank = rerank
        self.model_distilbert = SentenceTransformer('distilbert-base-nli-stsb-mean-tokens')
        self.res_addr = self.rec_conf.result_address_distilbert

    def re_rank_bert(self, geo_sim_dict, article_title, geo_title, pmid_ls):
        """Add DistilBERT title-to-title similarity to existing scores, then keep the top results.

        Args:
            geo_sim_dict: maps geo_id -> sequence of scores, one per entry of ``pmid_ls``.
                Its values are replaced by the re-ranked score arrays.
            article_title: maps pmid -> title text.  NOTE(review): assumed to iterate in
                the same order as ``pmid_ls`` -- confirm with the caller.
            geo_title: maps geo_id -> title text.  NOTE(review): assumed to iterate in
                the same order as ``geo_sim_dict`` -- confirm with the caller.
            pmid_ls: pmids labelling the score columns.

        Returns:
            The result of ``self.sorting`` on the updated scores.
        """
        # Every score row must have one column per pmid.
        assert np.asarray(list(geo_sim_dict.values())).shape[1] == len(pmid_ls)
        geo_embeddings = np.asarray(self.model_distilbert.encode(list(geo_title.values())))
        title_embeddings = np.asarray(self.model_distilbert.encode(list(article_title.values())))
        # Equivalent to sklearn's linear_kernel: pairwise dot products, (n_geo, n_articles).
        sim_value = geo_embeddings @ title_embeddings.T
        for i, geo_id in enumerate(geo_sim_dict):
            # Convert first: ``list += ndarray`` would *extend* the list, not add element-wise.
            geo_sim_dict[geo_id] = np.asarray(geo_sim_dict[geo_id]) + sim_value[i, :]
        return self.sorting(geo_sim_dict, pmid_ls)

    def sorting(self, geo_sim_dict, pmid_ls):
        """Keep the ``self.top_threshold`` highest-scoring pmids for every geo_id.

        Works for both plain and re-ranked scores.

        Args:
            geo_sim_dict: maps geo_id -> sequence of scores aligned with ``pmid_ls``;
                higher means more relevant.
            pmid_ls: pmids labelling the score columns.

        Returns:
            defaultdict mapping geo_id -> {pmid: score}, ordered from highest score down.
            Empty when ``geo_sim_dict`` is empty.
        """
        similarity_dict = defaultdict()
        if not geo_sim_dict:
            return similarity_dict
        sim_np = np.asarray(list(geo_sim_dict.values()))
        # Negate so the ascending argsort puts the largest scores first, then truncate.
        idx_np = np.argsort(-sim_np, axis=1)[:, :self.top_threshold]
        sim_np_taketop = np.take_along_axis(sim_np, idx_np, axis=1)
        for i, geo_id in enumerate(geo_sim_dict):
            pmid_selected = list(np.take(pmid_ls, idx_np[i]))
            similarity_dict[geo_id] = dict(zip(pmid_selected, sim_np_taketop[i]))
        return similarity_dict

    def _save_pickle(self, obj, name):
        """Pickle ``obj`` to ``<self.res_addr>base/<name>``, creating the directory if needed."""
        path = self.res_addr + 'base/' + name
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        with open(path, 'wb') as handle:
            pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def process(self):
        """Encode articles and test GEO datasets and pickle the vectors and their ids.

        Writes ``distilbert_vecs`` / ``pmids`` (articles whose title+abstract is not
        blank) and ``geo_distilbert_vecs`` / ``geo_ids`` (test GEO datasets whose
        title+summary is not blank) under ``<self.res_addr>base/``.
        """
        # TODO: load previously encoded vectors here; until then the encode branch always runs.
        pmids, distilbert_vecs = None, None

        if distilbert_vecs is not None and pmids is not None:
            print('load trained models and encoded publications vectors')
        else:
            print('Building models')
            joined_dict_article = {
                pmid: self.article_title[pmid] + ' ' + self.article_abstract[pmid]
                for pmid in self.article_title
                if (self.article_title[pmid] + self.article_abstract[pmid]).strip() != ''
            }
            pmids = list(joined_dict_article.keys())
            distilbert_vecs = self.model_distilbert.encode(list(joined_dict_article.values()))
            self._save_pickle(distilbert_vecs, 'distilbert_vecs')
            self._save_pickle(pmids, 'pmids')

        # Only test GEO datasets (those with citations) that actually have text to encode.
        joined_dict_geo = {
            geo_id: self.geo_title[geo_id] + ' ' + self.geo_summary[geo_id]
            for geo_id in self.geo_test_data
            if (self.geo_title[geo_id] + self.geo_summary[geo_id]).strip() != ''
        }
        geo_ids = list(joined_dict_geo.keys())
        geo_distilbert_vecs = self.model_distilbert.encode(list(joined_dict_geo.values()))
        self._save_pickle(geo_distilbert_vecs, 'geo_distilbert_vecs')
        self._save_pickle(geo_ids, 'geo_ids')

In [None]:
def main():
    """Encode all articles and test GEO datasets and persist the vectors to disk."""
    Master(rerank=False).process()


if __name__ == '__main__':
    main()

In [None]:
# Read back the vectors and ids pickled by Master.process().
# NOTE(review): should match Rec_configuration().result_address_distilbert + 'base/' -- confirm.
res_addr = 'results/distilbert_plain/base/'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_pickle(name):
    """Unpickle ``res_addr + name`` (a locally produced, trusted file)."""
    with open(res_addr + name, 'rb') as fp:
        return pickle.load(fp)


pmids = _load_pickle('pmids')
geo_ids = _load_pickle('geo_ids')
# The article matrix is compared against every GEO batch, so it lives on `device`
# (GPU when available, CPU otherwise); GEO vectors stay on CPU and are batched later.
distilbert_vecs = torch.tensor(np.array(_load_pickle('distilbert_vecs'))).to(device)
geo_distilbert_vecs = torch.tensor(np.array(_load_pickle('geo_distilbert_vecs')))

# process() encodes exactly one vector per id, so rows and ids must line up.
assert distilbert_vecs.shape[0] == len(pmids)
assert geo_distilbert_vecs.shape[0] == len(geo_ids)

In [None]:
# Iterate over the GEO vectors in their stored order so batch `step` covers
# geo_ids[step * b_size:(step + 1) * b_size].
b_size = 64
geo_distilbert_vecs_data = TensorDataset(geo_distilbert_vecs)
geo_distilbert_vecs_sampler = SequentialSampler(geo_distilbert_vecs_data)
geo_distilbert_vecs_dataloader = DataLoader(
    geo_distilbert_vecs_data, sampler=geo_distilbert_vecs_sampler, batch_size=b_size
)

In [None]:

# Only the citation mapping is needed for evaluation; the four text fields are discarded.
# NOTE(review): the import cell imports ``DataLoadingg`` but ``DataLoading`` is used here -- confirm.
(_article_title, _article_abstract, _geo_title, _geo_summary,
 citation_data) = DataLoading().get_all_details()
auto_rank = CitationRanking(citation_data)  # fed via find_citations(); summarized by get_values()
write_res = WriteResult()
     

In [None]:
# Rank all articles for every GEO dataset, one batch of GEO vectors at a time, and
# keep the TOP_K nearest articles (by Euclidean distance between DistilBERT vectors).
TOP_K = 10  # recommendations kept per GEO dataset (same as Master.top_threshold)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Keep the article matrix on the same device as the GEO batches (no-op if already there).
distilbert_vecs = distilbert_vecs.to(device)
top_k = min(TOP_K, distilbert_vecs.shape[0])
os.makedirs(res_addr, exist_ok=True)

offset = 0  # position in geo_ids of the first row of the current batch
for step, batch in enumerate(geo_distilbert_vecs_dataloader):
    batch = batch[0].to(device)
    # Slice by the real batch length so ids stay aligned whatever batch_size the loader uses.
    geo_ids_batch = geo_ids[offset: offset + batch.shape[0]]
    offset += batch.shape[0]
    assert len(geo_ids_batch) == batch.shape[0], 'geo_ids and GEO vectors are misaligned'

    # torch.cdist returns *distances*; negate them so larger means more similar and
    # the top-k below selects the nearest articles rather than the farthest.
    similarity_scores = -torch.cdist(batch, distilbert_vecs)
    np.save(res_addr + 'similarity_scores_batch_' + str(step), similarity_scores.cpu().numpy())

    top_scores, top_idx = torch.topk(similarity_scores, k=top_k, dim=1)  # sorted, best first
    top_scores, top_idx = top_scores.cpu().tolist(), top_idx.cpu().tolist()
    new_similarity_dict = {
        geo_id: {pmids[j]: score for j, score in zip(idx_row, score_row)}
        for geo_id, idx_row, score_row in zip(geo_ids_batch, top_idx, top_scores)
    }
    with open(res_addr + 'new_similarity_dict_batch_' + str(step), 'wb') as handle:
        pickle.dump(new_similarity_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    for geo_id in geo_ids_batch:
        selected = new_similarity_dict[geo_id]
        write_res.write(res_addr + geo_id + '.txt', selected)
        auto_rank.find_citations(geo_id, list(selected.keys()))  # pmids in rank order

good, top1_hit, bad, no_citation = auto_rank.get_values()
print('good geo recommendations = {}, top1 hit geo recommendations = {}, bad geo recommendations = {}, '
      'geo without citations = {}'.format(good, top1_hit, bad, no_citation))

In [None]:
# Merge the per-batch top-k dictionaries written by the ranking loop into one file.
# The batch count comes from the DataLoader rather than the loop variable `step`, so an
# interrupted ranking run fails loudly here instead of silently merging a partial set.
n_batches = len(geo_distilbert_vecs_dataloader)
new_similarity_dict = {}
for batch_index in range(n_batches):
    with open(res_addr + "new_similarity_dict_batch_" + str(batch_index), 'rb') as fp:
        new_similarity_dict.update(pickle.load(fp))
with open(res_addr + 'new_similarity_dict.pickle', 'wb') as handle:
    pickle.dump(new_similarity_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Evaluate the merged top-k recommendations against the known citations.
# NOTE(review): one shared Metrics instance assumes its calculate_* methods do not
# mutate internal state -- confirm in utils/eval.py.
metrics = Metrics(citation_data)

print('MRR:')
print(metrics.calculate_mrr(new_similarity_dict))

print('recall@1, recall@10:')
for cutoff in (1, 10):
    print(metrics.calculate_recall_at_k(new_similarity_dict, cutoff))

print('precision@1, precision@10:')
for cutoff in (1, 10):
    print(metrics.calculate_precision_at_k(new_similarity_dict, cutoff))

print('MAP@10:')
print(metrics.calculate_MAP_at_k(new_similarity_dict))