In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import warnings
warnings.filterwarnings('ignore')

import re
import json
import sent2vec

import numpy as np
import pandas as pd
import networkx as nx
import tqdm.notebook as tqdm

from rank_bm25 import BM25Okapi

from collections import defaultdict

from nltk import word_tokenize
from spacy.lang.en.stop_words import STOP_WORDS

from sentence_splitter import SentenceSplitter 
from sentence_splitter import split_text_into_sentences

In [None]:
# Load the pretrained BioSentVec sentence-embedding model from the working directory;
# it is used later to embed queries and candidate sentences.
model_file = 'bio_sent_vec.file'
biosentvec = sent2vec.Sent2vecModel()
biosentvec.load_model(model_file)

In [None]:
# Flatten the COVID-QA JSON (categories -> sub_categories -> answers) into one row per answer.
# Each row pairs a query (natural-language and keyword forms) with the CORD-19 paper id and the
# exact answer string that the evaluation later searches for in that paper's sentences.
# `with` closes the file handle (the bare `json.load(open(...))` leaked it); explicit UTF-8 avoids
# depending on the platform's default locale encoding.
with open('dataset/covidqa.json', encoding='utf-8') as fh:
    data = json.load(fh)

records = [
    {
        'natural_language_query': subcategory['nq_name'],
        'keyword_query': subcategory['kq_name'],
        'cord_id': answer['id'],
        'title': answer['title'],
        'answer': answer['exact_answer'],
    }
    for category in data['categories']
    for subcategory in category['sub_categories']
    for answer in subcategory['answers']
]
df = pd.DataFrame(records)
display(df)

In [None]:
# Distinct paper ids referenced by the QA rows; only these papers' full texts are loaded below.
cord_ids = set(df['cord_id'])

In [None]:
# Round-1 CORD-19 metadata; its `sha` / `full_text_file` columns are used below to locate
# each paper's full-text JSON.
metadata_path = 'dataset/cord19-round1/metadata.csv'
metadata = pd.read_csv(metadata_path)
display(metadata)

In [None]:
def cosine(a, b):
    """Cosine similarity between two vectors (arrays are treated as flattened).

    Returns 0.0 when either vector has zero norm (e.g. an all-zero sentence embedding for a
    sentence with no known words) instead of the NaN that the plain 0/0 formula produces.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return np.sum(a * b) / denom

In [None]:
# Collect the full text of every paper referenced by the QA dataset and split it into sentences.
#   path: cord_uid -> path of the last PDF-parse JSON read for that paper
#   body: cord_uid -> list of sentences, in order: abstract, body_text, then ref_entries texts
#         (presumably figure/table captions -- confirm against the CORD-19 schema)
path = {}
body = defaultdict(list)

splitter = SentenceSplitter(language = 'en')

for ind, row in tqdm.tqdm(metadata.iterrows(), total=len(metadata)):

    if row['cord_uid'] not in cord_ids:
        continue

    # Each text fragment is prefixed with a space, exactly as the original `text += ' ' + ...`;
    # collecting into a list and joining once avoids quadratic string concatenation.
    chunks = []

    # `sha` may list several parses of the same paper separated by '; '.
    for sha in row['sha'].split('; '):
        pdf_path = os.path.join('dataset', 'cord19-round1', row['full_text_file'], 'pdf_json', sha + '.json')
        path[row['cord_uid']] = pdf_path

        # `with` closes the handle; the bare `json.load(open(...))` leaked one per file.
        with open(pdf_path, encoding='utf-8') as fh:
            paper = json.load(fh)

        chunks.extend(' ' + item['text'] for item in paper['abstract'])
        chunks.extend(' ' + item['text'] for item in paper['body_text'])
        chunks.extend(' ' + item['text'] for item in paper['ref_entries'].values())

    body[row['cord_uid']].extend(splitter.split(text = ''.join(chunks)))

In [None]:
# Stopword vocabulary: spaCy's English stopwords plus corpus-specific boilerplate tokens.
# Everything is lowercased so it matches the lowercased text produced by `clean`.
# Note: 'fig.' and 'al.' never match text passed through `further_preprocess`, because its
# regex strips '.' before stopword removal.
custom_stop_words = [
    'doi', 'preprint', 'copyright', 'peer', 'reviewed', 'org', 'https', 'et', 'al', 'author', 'figure', 
    'rights', 'reserved', 'permission', 'used', 'using', 'biorxiv', 'medrxiv', 'license', 'fig', 'fig.', 
    'al.', 'Elsevier', 'PMC', 'CZI','table'
]

stopwords = {word.lower() for word in set(STOP_WORDS).union(custom_stop_words)}

def clean(text):
    """Collapse every whitespace run into a single space, lowercase, and trim the result."""
    collapsed = ' '.join(text.split())
    return collapsed.lower().strip()

def remove_stopwords(text):
    """Tokenize `text` with NLTK and drop tokens found in the module-level `stopwords` set.

    Returns the surviving tokens joined by single spaces.
    """
    kept_tokens = [token for token in word_tokenize(text) if token not in stopwords]
    return ' '.join(kept_tokens)

def further_preprocess(sent):
    """Normalize a sentence for lexical matching.

    Lowercases and collapses whitespace, deletes every character other than a-z and space,
    removes stopwords, then re-collapses whitespace left behind by the deletions.
    """
    letters_only = re.sub('[^a-z ]+', '', clean(sent))
    without_stopwords = remove_stopwords(clean(letters_only))
    return clean(without_stopwords)

In [None]:
# Sentence-level answer-retrieval evaluation.
# For each QA row, every sentence of the referenced paper is scored by the sum of three
# L2-normalised signals:
#   - BM25 over preprocessed tokens;
#   - BioSentVec cosine similarity to the query;
#   - PageRank over a sentence-similarity graph.
# A sentence is "correct" if it contains the exact answer string. P@1, R@3 and MRR are
# averaged over all n rows.

SIM_THRESHOLD = 0.1  # minimum cosine similarity for graph edges and for query relevance


def _l2_normalize(scores):
    """Return `scores` scaled to unit L2 norm; an all-zero vector is returned unchanged.

    Dividing by a zero norm would turn every score into NaN and make the ranking meaningless,
    e.g. when the query shares no token with any sentence and all BM25 scores are 0.
    """
    scores = np.asarray(scores, dtype=float)
    norm = np.sqrt(np.sum(scores ** 2))
    return scores / norm if norm > 0 else scores


def _unit_rows(vectors):
    """Scale each row to unit length; zero-norm rows (no known words) stay all-zero."""
    vectors = np.atleast_2d(np.asarray(vectors, dtype=float))
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return np.divide(vectors, norms, out=np.zeros_like(vectors), where=norms > 0)


def _rank_sentences(query, sents, tokenized_sents):
    """Return sentence indices ordered from best to worst combined score.

    query: preprocessed query string (BM25 tokens and BioSentVec input).
    sents: cleaned sentences, embedded with BioSentVec.
    tokenized_sents: preprocessed token lists of the same sentences, scored with BM25.
    """
    bm_scores = _l2_normalize(BM25Okapi(tokenized_sents).get_scores(query.split()))

    sent_unit = _unit_rows(biosentvec.embed_sentences(sents))
    query_unit = _unit_rows(biosentvec.embed_sentences([query]))[0]
    query_sims = sent_unit @ query_unit  # cosine to the query; 0 for zero-norm embeddings
    bioemb_scores = _l2_normalize(query_sims)

    # Link two sentences when both have an embedding, both are relevant to the query, and they
    # are similar to each other. Computed once as matrices instead of an O(n^2) Python loop
    # that recomputed the same query cosines for every pair.
    has_emb = np.any(sent_unit, axis=1) & np.any(query_unit)
    relevant = has_emb & (query_sims >= SIM_THRESHOLD)
    linked = (sent_unit @ sent_unit.T >= SIM_THRESHOLD) & relevant[:, None] & relevant[None, :]
    src, dst = np.nonzero(np.triu(linked, k=1))

    graph = nx.Graph()
    graph.add_nodes_from(range(len(sents)))
    graph.add_edges_from(zip(src.tolist(), dst.tolist()))
    pagerank = nx.pagerank(graph)
    pr_scores = _l2_normalize([pagerank[i] for i in range(len(sents))])

    combined = bm_scores + bioemb_scores + pr_scores
    # Descending score; ties go to the larger index, matching a reverse sort of [score, i] pairs.
    return sorted(range(len(sents)), key=lambda i: (combined[i], i), reverse=True)


mrr = 0
pat1 = 0
rat3 = 0
n = df.shape[0]
unanswerable = 0  # rows whose answer string appears in no sentence of the paper

for _, row in tqdm.tqdm(df.head(n).iterrows(), total=n):

    answer = clean(row['answer'])
    query = further_preprocess(clean(row['natural_language_query']))

    sents = [clean(sent) for sent in body[row['cord_id']]]
    correct = {i for i, sent in enumerate(sents) if answer in sent}
    if not correct:
        # Nothing retrievable (answer split across sentences, or paper text missing). Count
        # the row as a miss (it still counts in n) instead of crashing on
        # len(correct) == 0 / max([]).
        unanswerable += 1
        continue

    tokenized_sents = [further_preprocess(sent).split() for sent in sents]
    ranklist = _rank_sentences(query, sents, tokenized_sents)

    if ranklist[0] in correct:
        pat1 += 1
    rat3 += len(set(ranklist[:3]) & correct) / len(correct)
    # Reciprocal rank of the best-ranked correct sentence.
    first_hit = next(rank for rank, idx in enumerate(ranklist) if idx in correct)
    mrr += 1 / (first_hit + 1)

pat1 /= n
rat3 /= n
mrr /= n

if unanswerable:
    print('Rows counted as misses (answer not found in any sentence): %d' % unanswerable)
print('P@1: %.4f' % round(pat1, 4))
print('R@3: %.4f' % round(rat3, 4))
print('MRR: %.4f' % round(mrr, 4))