### This file extracts a list of DOIs of papers that are similar to a given seed paper.
- Similarity is measured using sentence embeddings. The similarity threshold can be set manually.
- Make sure to execute every cell sequentially. `get_similar_papers()` is the function that returns the list of relevant papers' DOIs (together with the total paper count).

In [1]:
import json
import random

import numpy as np
import torch
from sentence_transformers import SentenceTransformer, util
# NOTE(review): this line was indented in the original export (it looks like
# the tail of a tqdm notebook warning); as written it was an IndentationError.
# It is kept as a regular import so the cell runs and the names stay available.
from tqdm.autonotebook import tqdm, trange

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed Python's `random` module for reproducibility.
# NOTE(review): nothing visible in this section draws from `random`; later
# cells may rely on it.
random.seed(42)


In [2]:
# Load the pretrained 'all-mpnet-base-v2' sentence-embedding model and move it
# to the device selected above; `is_similar()` reads this module-level object.
model = SentenceTransformer('all-mpnet-base-v2').to(device)




In [63]:
# Take seed paper 10 as the example: load its reference tree from disk.
# The encoding is given explicitly so the platform's default text encoding is
# never used to decode the JSON file.
with open("seed_paper_10.json", encoding="utf-8") as file:
    refs_10 = json.load(file)
    

In [82]:
def is_similar(target, ref, threshold):
    """Compare two texts through their sentence embeddings.

    Both texts are encoded with the module-level ``model`` on ``device`` and
    compared with ``model.similarity``.

    Args:
        target: Text of the paper to compare against.
        ref: Text of the candidate paper.
        threshold: Minimum similarity score for the pair to count as similar.

    Returns:
        A pair ``(passes, scores)`` where ``scores`` is the first row of the
        similarity result and ``passes`` is ``scores >= threshold``.
        NOTE(review): for a single ``target`` string and a single ``ref``
        string this is presumably a one-element tensor -- confirm.
    """
    encode_options = {"convert_to_tensor": True, "device": device}
    target_vector = model.encode(target, **encode_options)
    ref_vector = model.encode(ref, **encode_options)

    first_row = model.similarity(target_vector, ref_vector)[0]
    return first_row >= threshold, first_row


In [39]:
def _field_to_text(value):
    """Convert one metadata field (title or abstract) to plain text."""
    if value is None:
        # A missing field contributes no text, rather than the word "None".
        return ""
    if isinstance(value, list):
        # Every fragment is followed by a single space, so list input keeps
        # the trailing space the previous implementation produced.
        return "".join(f"{fragment} " for fragment in value)
    return str(value)


def concatenate_title_abs(title, abs):  # noqa: A002 - `abs` kept for callers
    """Join a title and an abstract into a single ``"<title>: <abstract>"`` string.

    Either field may be a string, a list of text fragments, ``None``, or any
    other value:

    * a string is used as is;
    * a list has each fragment followed by one space and is concatenated;
    * ``None`` becomes the empty string;
    * anything else is converted with ``str()``.

    Args:
        title: Title field of a paper's metadata.
        abs: Abstract field of a paper's metadata. The name shadows the
            builtin ``abs`` inside this function; it is kept so existing
            keyword callers continue to work.

    Returns:
        The combined text ``"<title>: <abstract>"``.
    """
    return f"{_field_to_text(title)}: {_field_to_text(abs)}"

In [88]:
# Return (1) a list of DOIs from similar papers to a seed paper and (2) total number of papers
# Sample usage: similar_papers_8, paper_count_8 = get_similar_papers(seed_paper_refs=refs_8)

def get_similar_papers(seed_paper_refs, similarity_threshold=0.64):
    """Collect DOIs of references that are similar to a seed paper.

    The seed paper is ``seed_paper_refs[0]``. Each of its references
    (level 1) is compared with the seed's "title: abstract" text. For every
    level-1 reference that is similar and has a DOI, its own references
    (level 2) are compared with the seed text and the level-1 text joined by
    a space.

    Args:
        seed_paper_refs: Sequence whose first element is the seed paper, a
            dict with ``'metadata'`` (``'title'``, ``'abstract'``) and
            ``'references'``. Every level-1 reference has ``'metadata'``
            (``'title'``, ``'abstract'``, ``'doi'``) and ``'references'``;
            every level-2 reference has ``'metadata'`` with the same three
            keys.
        similarity_threshold: Minimum similarity score passed to
            ``is_similar`` for a paper to be kept.

    Returns:
        A pair ``(similar_papers, paper_count)``. ``similar_papers`` lists
        the DOIs of the similar papers in traversal order; duplicates are not
        removed. ``paper_count`` is the seed plus every level-1 reference
        plus every level-2 reference listed under it, whether similar or not.
    """
    seed_paper = seed_paper_refs[0]
    seed_meta = seed_paper['metadata']
    seed_title_abs = concatenate_title_abs(seed_meta['title'], seed_meta['abstract'])

    similar_papers = []
    paper_count = 1  # the seed paper itself

    for lev_1_ref in seed_paper['references']:
        lev_2_refs = lev_1_ref['references']
        paper_count += 1 + len(lev_2_refs)

        lev_1_meta = lev_1_ref['metadata']
        lev_1_doi = lev_1_meta['doi']
        # A paper without a DOI can never be reported and its references are
        # never visited, so skip it before paying for the embedding call.
        # Truthiness also covers a DOI of None, which len() would reject.
        if not lev_1_doi:
            continue

        lev_1_title_abs = concatenate_title_abs(lev_1_meta['title'], lev_1_meta['abstract'])
        # is_similar also returns the score; it is not needed here.
        similar_lev_1, _ = is_similar(seed_title_abs, lev_1_title_abs, similarity_threshold)
        if not similar_lev_1:
            continue
        similar_papers.append(lev_1_doi)

        # Level-2 papers are compared with the seed and its level-1 reference
        # together; this text is the same for every level-2 paper below.
        target_lev_2 = seed_title_abs + " " + lev_1_title_abs

        for lev_2_ref in lev_2_refs:
            lev_2_meta = lev_2_ref['metadata']
            lev_2_doi = lev_2_meta['doi']
            if not lev_2_doi:
                continue

            lev_2_title_abs = concatenate_title_abs(lev_2_meta['title'], lev_2_meta['abstract'])
            similar_lev_2, _ = is_similar(target_lev_2, lev_2_title_abs, similarity_threshold)
            if similar_lev_2:
                similar_papers.append(lev_2_doi)

    return similar_papers, paper_count



In [92]:
# Example run on seed paper 10: print the number of similar DOIs found, the
# total paper count, and the first ten DOIs, one value per line.
similar_papers_10, paper_count_10 = get_similar_papers(refs_10)
print(len(similar_papers_10), paper_count_10, similar_papers_10[:10], sep="\n")