In [1]:
import torch
import argparse
import time
import math
import os, sys
import itertools
import pickle
from tqdm import tqdm
import numpy as np
import os.path as osp

In [2]:
torch.cuda.device_count()

1

In [3]:
def loadpkl(path):
    with open(path, 'rb') as f:
        rst=pickle.load(f)
    return rst

In [4]:
path="/share/data/mei-work/kangrui/github/ssref/data/refsum-data/arxiv-aiml/full_data_no_embed.pkl"
full_data=loadpkl(path)

In [5]:
path="/share/data/mei-work/kangrui/github/ssref/result/pretrained_pair_sbert/test_result/eval_key_id2embed.pkl"
key_embedding=loadpkl(path)

In [21]:
path="/share/data/mei-work/kangrui/github/ssref/result/pretrained_pair_sbert/f1_result/eval_test_pred_1000.pkl"
pred_set=loadpkl(path)

In [22]:
path="/share/data/mei-work/kangrui/github/ssref/result/pretrained_pair_sbert/f1_result/eval_test_gold.pkl"
gold_set=loadpkl(path)

In [23]:
query_id2embedding=key_embedding
key_id2embedding=key_embedding

In [24]:
candidate_set={}
candidate_ref_set={}
for k,v in pred_set.items():
    cur=set()
    for paper in v[:200]:
        cur.add(paper)
        refset=set()
        for ref in full_data[paper]["references"]:
            cur.add(ref["paperId"])
            refset.add(ref["paperId"])
        candidate_ref_set[paper]=refset
    candidate_set[k]=list(cur)

In [36]:
def rerank_pop(candidate_set,candidate_ref_set,rerank_num=None):
    query=list(candidate_set.keys())
    rerank_candidate={}
    if rerank_num is None:
        for v in candidate_set.values():
            if rerank_num is None or rerank_num>len(v):
                rerank_num=len(v)
    pbar = tqdm(range(0,len(candidate_set),1), postfix=f"calculating f1")
    for i in pbar:
        query_paper=query[i]
        cite_frequency={}
        for candidate in candidate_set[query_paper]:
            if candidate in candidate_ref_set.keys():
                # this is possibly the anchor paper
                for ref in list(candidate_ref_set[candidate]):
                    # deal with the ref of an anchor paper
                    if ref in candidate_set[query_paper]:
                        if ref not in cite_frequency.keys():
                            cite_frequency[ref]=1
                        else:
                            cite_frequency[ref]+=1
                    else:
                        # this is not anchor paper for this query
                        break
            if candidate not in cite_frequency.keys():
                cite_frequency[candidate]=1
            else:
                cite_frequency[candidate]+=1
                
        top_k = sorted(cite_frequency.items(), key=lambda x:x[1], reverse=True)
        rerank = [k for k,v in top_k[:rerank_num]]
        rerank_candidate[query_paper]=rerank
    print("min rerank num->",rerank_num)
    return rerank_candidate
        
def rerank_og(candidate_set,query_id2embedding,key_id2embedding,rerank_num=None,device=None):
    cnt=0
    pbar = tqdm(range(0,len(candidate_set),1), postfix=f"calculating f1")
    query=list(candidate_set.keys())
    empty_embed=torch.tensor(np.zeros_like(next(iter(key_id2embedding.values()))))
    empty_embed_cnt=0
    if rerank_num is None:
        for v in candidate_set.values():
            if rerank_num is None or rerank_num>len(v):
                rerank_num=len(v)
                
    rerank_candidate={}   
    for i in pbar:
        query_paper=query[i]
        query_embedding=torch.tensor(query_id2embedding[query_paper]).unsqueeze(0)
        key_embedding=[]
        for paper in candidate_set[query_paper]:
            try:
                key_embedding.append(torch.tensor(key_id2embedding[paper]))
            except:
                key_embedding.append(empty_embed)
                empty_embed_cnt+=1
        
        key_embedding=torch.stack(key_embedding)
        query_embedding=query_embedding.to(device)
        key_embedding=key_embedding.to(device)
        pred_logits=torch.mm(query_embedding,key_embedding.T)
        top_k=torch.topk(pred_logits[0],k=rerank_num)[1]
#         print(pred_logits[0][top_k[-5:]])
        rerank=[candidate_set[query_paper][idx.cpu().item()] for idx in top_k]
        rerank_candidate[query_paper]=rerank
    print("empty embedding cnt->",empty_embed_cnt)
    print("min rerank num->",rerank_num)
    return rerank_candidate
        
def rerank_sum(candidate_set,query_id2embedding,key_id2embedding,candidate_ref_set,rerank_num=None,device=None):
    cnt=0
    pbar = tqdm(range(0,len(candidate_set),1), postfix=f"calculating f1")
    query=list(candidate_set.keys())
    empty_embed=torch.tensor(np.zeros_like(next(iter(key_id2embedding.values()))))
    empty_embed_cnt=0
    if rerank_num is None:
        for v in candidate_set.values():
            if rerank_num is None or rerank_num>len(v):
                rerank_num=len(v)
                
    rerank_candidate={}   
    for i in pbar:
        query_paper=query[i]
        query_embedding=torch.tensor(query_id2embedding[query_paper]).unsqueeze(0)
        
        cited_dict={}
        for candidate in candidate_set[query_paper]:
            
            if candidate in candidate_ref_set.keys():
                # this is possibly the anchor paper
                for ref in list(candidate_ref_set[candidate]):
                    # deal with the ref of an anchor paper
                    if ref in candidate_set[query_paper]:
                        if ref not in cited_dict.keys():
                            cited_dict[ref]=[candidate]
                        else:
                            cited_dict[ref].append(candidate)
                    else:
                        # this is not anchor paper for this query
                        break
            if candidate not in cited_dict.keys():
                cited_dict[candidate]=[candidate]
            else:
                cited_dict[candidate].append(candidate)
                
                
        key_embedding=[]
        for candidate in candidate_set[query_paper]:
            candidate_embedding=[]
            for paper in cited_dict[candidate]:
                if paper in key_id2embedding:
                    candidate_embedding.append(torch.tensor(key_id2embedding[paper]))
            if len(candidate_embedding)==0:
                candidate_embedding.append(empty_embed)
                empty_embed_cnt+=1
            candidate_embedding=torch.sum(torch.stack(candidate_embedding),dim=0)
            
            key_embedding.append(candidate_embedding)
        
        
        key_embedding=torch.stack(key_embedding)
        query_embedding=query_embedding.to(device)
        key_embedding=key_embedding.to(device)
        pred_logits=torch.mm(query_embedding,key_embedding.T)
        top_k=torch.topk(pred_logits[0],k=rerank_num)[1]
#         print(pred_logits[0][top_k[-5:]])
        rerank=[candidate_set[query_paper][idx.cpu().item()] for idx in top_k]
        rerank_candidate[query_paper]=rerank
    print("empty embedding cnt->",empty_embed_cnt)
    print("min rerank num->",rerank_num)
    return rerank_candidate
    
def rerank_mean(candidate_set,query_id2embedding,key_id2embedding,candidate_ref_set,rerank_num=None,device=None):
    cnt=0
    pbar = tqdm(range(0,len(candidate_set),1), postfix=f"calculating f1")
    query=list(candidate_set.keys())
    empty_embed=torch.tensor(np.zeros_like(next(iter(key_id2embedding.values()))))
    empty_embed_cnt=0
    if rerank_num is None:
        for v in candidate_set.values():
            if rerank_num is None or rerank_num>len(v):
                rerank_num=len(v)
                
    rerank_candidate={}   
    for i in pbar:
        query_paper=query[i]
        query_embedding=torch.tensor(query_id2embedding[query_paper]).unsqueeze(0)
        
        cited_dict={}
        for candidate in candidate_set[query_paper]:
            
            if candidate in candidate_ref_set.keys():
                # this is possibly the anchor paper
                for ref in list(candidate_ref_set[candidate]):
                    # deal with the ref of an anchor paper
                    if ref in candidate_set[query_paper]:
                        if ref not in cited_dict.keys():
                            cited_dict[ref]=[candidate]
                        else:
                            cited_dict[ref].append(candidate)
                    else:
                        # this is not anchor paper for this query
                        break
            if candidate not in cited_dict.keys():
                cited_dict[candidate]=[candidate]
            else:
                cited_dict[candidate].append(candidate)
                
                
        key_embedding=[]
        for candidate in candidate_set[query_paper]:
            candidate_embedding=[]
            for paper in cited_dict[candidate]:
                if paper in key_id2embedding:
                    candidate_embedding.append(torch.tensor(key_id2embedding[paper]))
            if len(candidate_embedding)==0:
                candidate_embedding.append(empty_embed)
                empty_embed_cnt+=1
            candidate_embedding=torch.mean(torch.stack(candidate_embedding),dim=0)
            
            key_embedding.append(candidate_embedding)
        
        
        key_embedding=torch.stack(key_embedding)
        query_embedding=query_embedding.to(device)
        key_embedding=key_embedding.to(device)
        pred_logits=torch.mm(query_embedding,key_embedding.T)
        top_k=torch.topk(pred_logits[0],k=rerank_num)[1]
#         print(pred_logits[0][top_k[-5:]])
        rerank=[candidate_set[query_paper][idx.cpu().item()] for idx in top_k]
        rerank_candidate[query_paper]=rerank
    print("empty embedding cnt->",empty_embed_cnt)
    print("min rerank num->",rerank_num)
    return rerank_candidate
        
        


In [26]:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [27]:
rerank_candidate=rerank_mean(candidate_set,query_id2embedding,key_id2embedding,candidate_ref_set,device=device)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 944/944 [11:24<00:00,  1.38it/s, calculating f1]

empty embedding cnt-> 0
min rerank num-> 1891





In [32]:
sys.path.append("/share/data/mei-work/kangrui/github/ssref")
from utils.test_utils import *
from utils.utils import *
def show_rerank_rst(rerank_candidate,gold_set):
    pool_length=len(next(iter(rerank_candidate.values())))
    i=2
    while 1:
        all_pred_topN=[]
        all_gold=[]
        for query_paper,paperlist in rerank_candidate.items():
            topN=paperlist[:i]
            all_pred_topN.append(topN)
            all_gold.append(gold_set[query_paper])
        all_pred_topN=np.array(all_pred_topN)
        metric = get_precision_recall_f1(all_pred_topN, all_gold)
        print(f"top{i}:\n{metric}")
        if i>pool_length:
            break
        i*=2

In [33]:
rerank_candidate=rerank_mean(candidate_set,query_id2embedding,key_id2embedding,candidate_ref_set,device=device)
show_rerank_rst(rerank_candidate,gold_set)

top2:
{'precision': 0.1228813559322034, 'recall': 0.0060353798126951096, 'f1': 0.011505653640150763, 'true_pos': 232, 'false_pos': 1656, 'false_neg': 38208, 'num_sample': 944}
top4:
{'precision': 0.10699152542372882, 'recall': 0.010509885535900104, 'f1': 0.01913966268713284, 'true_pos': 404, 'false_pos': 3372, 'false_neg': 38036, 'num_sample': 944}
top8:
{'precision': 0.08990995762711865, 'recall': 0.01766389177939646, 'f1': 0.029526874238998083, 'true_pos': 679, 'false_pos': 6873, 'false_neg': 37761, 'num_sample': 944}
top16:
{'precision': 0.07660222457627118, 'recall': 0.03009885535900104, 'f1': 0.04321679366502316, 'true_pos': 1157, 'false_pos': 13947, 'false_neg': 37283, 'num_sample': 944}
top32:
{'precision': 0.06620762711864407, 'recall': 0.05202913631633715, 'f1': 0.058268267101736396, 'true_pos': 2000, 'false_pos': 28208, 'false_neg': 36440, 'num_sample': 944}
top64:
{'precision': 0.05950410487288135, 'recall': 0.09352237252861602, 'f1': 0.07273205470583473, 'true_pos': 3595, '

In [34]:
rerank_candidate=rerank_sum(candidate_set,query_id2embedding,key_id2embedding,candidate_ref_set,device=device)
show_rerank_rst(rerank_candidate,gold_set)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 944/944 [10:04<00:00,  1.56it/s, calculating f1]


empty embedding cnt-> 0
min rerank num-> 1891
top2:
{'precision': 0.21716101694915255, 'recall': 0.010665972944849115, 'f1': 0.02033326720888712, 'true_pos': 410, 'false_pos': 1478, 'false_neg': 38030, 'num_sample': 944}
top4:
{'precision': 0.2640360169491525, 'recall': 0.025936524453694067, 'f1': 0.047233276482850105, 'true_pos': 997, 'false_pos': 2779, 'false_neg': 37443, 'num_sample': 944}
top8:
{'precision': 0.2680084745762712, 'recall': 0.05265348595213319, 'f1': 0.08801530700991476, 'true_pos': 2024, 'false_pos': 5528, 'false_neg': 36416, 'num_sample': 944}
top16:
{'precision': 0.23232256355932204, 'recall': 0.09128511966701353, 'f1': 0.13106977439115491, 'true_pos': 3509, 'false_pos': 11595, 'false_neg': 34931, 'num_sample': 944}
top32:
{'precision': 0.1828654661016949, 'recall': 0.1437044745057232, 'f1': 0.16093695373499592, 'true_pos': 5524, 'false_pos': 24684, 'false_neg': 32916, 'num_sample': 944}
top64:
{'precision': 0.13423596398305085, 'recall': 0.21097814776274715, 'f1':

In [37]:
rerank_candidate=rerank_pop(candidate_set,candidate_ref_set,rerank_num=None)
show_rerank_rst(rerank_candidate,gold_set)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 944/944 [06:32<00:00,  2.40it/s, calculating f1]


min rerank num-> 1891
top2:
{'precision': 0.18273305084745764, 'recall': 0.008975026014568158, 'f1': 0.01710970045625868, 'true_pos': 345, 'false_pos': 1543, 'false_neg': 38095, 'num_sample': 944}
top4:
{'precision': 0.24152542372881355, 'recall': 0.02372528616024974, 'f1': 0.04320636725412166, 'true_pos': 912, 'false_pos': 2864, 'false_neg': 37528, 'num_sample': 944}
top8:
{'precision': 0.2551641949152542, 'recall': 0.050130072840790844, 'f1': 0.08379718211862933, 'true_pos': 1927, 'false_pos': 5625, 'false_neg': 36513, 'num_sample': 944}
top16:
{'precision': 0.2232521186440678, 'recall': 0.08772112382934444, 'f1': 0.12595248767368894, 'true_pos': 3372, 'false_pos': 11732, 'false_neg': 35068, 'num_sample': 944}
top32:
{'precision': 0.1771054025423729, 'recall': 0.13917793964620187, 'f1': 0.15586761449714484, 'true_pos': 5350, 'false_pos': 24858, 'false_neg': 33090, 'num_sample': 944}
top64:
{'precision': 0.1299159163135593, 'recall': 0.20418834547346515, 'f1': 0.15879663348709233, 'tr

In [38]:
rerank_candidate=rerank_og(candidate_set,query_id2embedding,key_id2embedding,device=device)
show_rerank_rst(rerank_candidate,gold_set)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 944/944 [00:53<00:00, 17.49it/s, calculating f1]


empty embedding cnt-> 576703
min rerank num-> 1891
top2:
{'precision': 0.2854872881355932, 'recall': 0.014021852237252861, 'f1': 0.026730807379488193, 'true_pos': 539, 'false_pos': 1349, 'false_neg': 37901, 'num_sample': 944}
top4:
{'precision': 0.23622881355932204, 'recall': 0.02320499479708637, 'f1': 0.0422588592003032, 'true_pos': 892, 'false_pos': 2884, 'false_neg': 37548, 'num_sample': 944}
top8:
{'precision': 0.1906779661016949, 'recall': 0.037460978147762745, 'f1': 0.06261958601495912, 'true_pos': 1440, 'false_pos': 6112, 'false_neg': 37000, 'num_sample': 944}
top16:
{'precision': 0.1428760593220339, 'recall': 0.056139438085327786, 'f1': 0.08060660391453758, 'true_pos': 2158, 'false_pos': 12946, 'false_neg': 36282, 'num_sample': 944}
top32:
{'precision': 0.10391287076271187, 'recall': 0.08165972944849116, 'f1': 0.09145204521617527, 'true_pos': 3139, 'false_pos': 27069, 'false_neg': 35301, 'num_sample': 944}
top64:
{'precision': 0.07363943326271187, 'recall': 0.11573881373569199,