In [1]:
%%capture
import json
from gensim.summarization import bm25
from gensim.utils import tokenize
from tqdm.notebook import tqdm

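# Source: BM25 implementation from gensim — https://github.com/RaRe-Technologies/gensim/blob/develop/gensim/summarization/bm25.py
Note: the `gensim.summarization` module was removed in gensim 4.0, so this notebook requires gensim 3.x (the `develop`-branch link above may no longer resolve).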

In [9]:
# AAN dataset: input collection / test split, and where the two BM25 variants write
# their rankings. NOTE: the DBLP config cell below reassigns all four names, so only
# whichever config cell runs last takes effect.
FULL_SET_PATH, TEST_SET_PATH = './data/aan_full.json', './data/aan_test.json'
CONCAT_OUTPUT_PATH, SUM_OUTPUT_PATH = (
    './results/base_bm25_concat_aan.json',
    './results/base_bm25_sum_aan.json',
)

In [2]:
# DBLP dataset: input collection / test split, and where the two BM25 variants write
# their rankings. NOTE: the AAN config cell above assigns the same four names; being
# the later cell, this one wins on a top-to-bottom run.
FULL_SET_PATH, TEST_SET_PATH = './data/dblp_full.json', './data/dblp_test.json'
CONCAT_OUTPUT_PATH, SUM_OUTPUT_PATH = (
    './results/base_bm25_concat_dblp.json',
    './results/base_bm25_sum_dblp.json',
)

In [3]:
# Load the full paper collection that becomes the BM25 corpus; later cells read
# 'id', 'title' and 'abstract' from each entry. JSON text is UTF-8, so decode it
# explicitly instead of relying on the platform's locale-dependent default.
with open(FULL_SET_PATH, encoding='utf-8') as f:
    full_set = json.load(f)

In [4]:
# Build the BM25 index: each paper becomes one document made of its title and
# abstract, tokenized with gensim's `tokenize`. The scoring cells zip
# `full_set_ids` with `model.get_scores(...)`, so both lists are built in the
# same loop to keep them index-aligned with the corpus.
corpus = []
full_set_ids = []
for paper in full_set:
    paper_text = paper['title'] + ' ' + paper['abstract']
    corpus.append(list(tokenize(paper_text)))
    full_set_ids.append(paper['id'])
model = bm25.BM25(corpus)

In [5]:
# Load the test queries; each entry is a list of input papers (later cells read
# 'id', 'title' and 'abstract' from each paper). Decode explicitly as UTF-8, the
# JSON text encoding, rather than the platform's locale-dependent default.
with open(TEST_SET_PATH, encoding='utf-8') as f:
    test_set = json.load(f)

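# Concat variant: score the corpus against a single query built by concatenating the titles and abstracts of all input papers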

In [6]:
# For each test query (a list of input papers): build ONE BM25 query from the
# concatenated title + abstract of every input paper, score the whole corpus,
# drop the input papers themselves, and keep the ids of the 100 best candidates.
results = []

for input_papers in tqdm(test_set):
    result = {}
    result['input'] = [input_paper['id'] for input_paper in input_papers]
    input_paper_ids_set = set(result['input'])

    # Join the papers with a space so the last word of one abstract does not fuse
    # with the first word of the next title.
    input_text = ' '.join(
        input_paper['title'] + ' ' + input_paper['abstract'] for input_paper in input_papers
    )
    # Tokenize exactly like the corpus cell (gensim `tokenize`, materialized as a
    # list). Plain str.split() would keep punctuation attached ("retrieval."), and
    # such tokens never match the alphabetic-only corpus terms.
    input_document = list(tokenize(input_text))

    # Pair every corpus id with its score and drop the query papers. Filtering before
    # the (stable) sort yields the same order as filtering afterwards, ties included.
    candidate_scores = [
        cs for cs in zip(full_set_ids, model.get_scores(input_document))
        if cs[0] not in input_paper_ids_set
    ]
    candidate_scores.sort(key=lambda cs: cs[1], reverse=True)

    result['output'] = [paper_id for paper_id, _ in candidate_scores[:100]]
    results.append(result)

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))




In [7]:
# Persist the concat-variant rankings: one {'input': [ids], 'output': [up to 100 ids]}
# record per test query.
with open(CONCAT_OUTPUT_PATH, mode='w') as out_file:
    json.dump(results, fp=out_file)

# Sum variant: score the corpus against each input paper separately and sum the per-paper BM25 scores

In [8]:
# For each test query (a list of input papers): score the corpus against each input
# paper separately, sum the per-paper BM25 scores element-wise, drop the input papers
# themselves, and keep the ids of the 100 best candidates.
results = []

for input_papers in tqdm(test_set):
    result = {}
    result['input'] = [input_paper['id'] for input_paper in input_papers]
    input_paper_ids_set = set(result['input'])

    # One score list per input paper, each aligned with full_set_ids. Queries are
    # tokenized exactly like the corpus (gensim `tokenize`, as a list) rather than
    # with str.split(), whose punctuation-bearing tokens never match corpus terms.
    per_paper_scores = [
        model.get_scores(list(tokenize(input_paper['title'] + ' ' + input_paper['abstract'])))
        for input_paper in input_papers
    ]
    # Sum over however many input papers the query has instead of assuming a fixed
    # count; for three papers this adds in the same order as s1 + s2 + s3.
    complete_scores = [sum(paper_scores) for paper_scores in zip(*per_paper_scores)]

    # Drop the query papers, then rank; the stable sort keeps corpus order on ties.
    candidate_scores = [
        cs for cs in zip(full_set_ids, complete_scores)
        if cs[0] not in input_paper_ids_set
    ]
    candidate_scores.sort(key=lambda cs: cs[1], reverse=True)

    result['output'] = [paper_id for paper_id, _ in candidate_scores[:100]]
    results.append(result)

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))




In [9]:
# Persist the sum-variant rankings: one {'input': [ids], 'output': [up to 100 ids]}
# record per test query.
with open(SUM_OUTPUT_PATH, mode='w') as out_file:
    json.dump(results, fp=out_file)