In [2]:
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize

import faiss  
import pandas as pd
import numpy as np
import dill as pickle
import sqlite3

In [3]:
con = sqlite3.connect("../cord.db")
df = pd.read_sql_query("SELECT title, abstract, authors, body_text FROM cord19", con)

In [6]:
a = pd.read_sql_query("SELECT * FROM cord19", con)
a.iloc[[3835, 421], :]

Unnamed: 0,cord_uid,title,doi,pubmed_id,license,abstract,publish_time,authors,journal,pdf_json_files,url,body_text
3835,s8ohts1j,The metaRbolomics Toolbox in Bioconductor and ...,10.3390/metabo9100200,31548506,cc-by,Metabolomics aims to measure and characterise ...,2019-09-23,"Stanstrup, Jan; Broeckling, Corey D.; Helmus, ...",Metabolites,document_parses/pdf_json/8c44e37bccf0c2a493e1b...,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6...,"Metabolomics aims to measure, identify and (se..."
421,5z3s5pjl,The calculation of information and organismal ...,10.1186/1745-6150-5-59,20937149,cc-by,BACKGROUND: It is difficult to measure precise...,2010-10-12,"Jiang, Yun; Xu, Cunshuan",Biol Direct,document_parses/pdf_json/6daca390ad9b7e330f843...,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2...,Organismal complexity is difficult to define a...


In [4]:
model = SentenceTransformer('monologg/biobert_v1.1_pubmed')

Some weights of the model checkpoint at /home/arnav/.cache/torch/sentence_transformers/monologg_biobert_v1.1_pubmed were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
if torch.cuda.is_available():
   model = model.to(torch.device("cuda"))
print(model.device)

cuda:0


In [7]:
with open('../data/title_embeddings.pickle', 'rb') as file:
    title_embeddings = pickle.load(file)
    
with open('../data/document_embeddings.pickle', 'rb') as file:
    document_embeddings = pickle.load(file)

In [159]:
num_vectors = document_embeddings.shape[0]
dimension = document_embeddings.shape[1]
num_neighbours = 100
title_weight = 0.6
document_weight = 1 - title_weight

In [172]:
title_index = faiss.IndexFlatIP(dimension)

title_index.add(normalize(title_embeddings, norm='l2'))

document_index = faiss.IndexFlatIP(dimension)
document_index.add(normalize(document_embeddings, norm='l2'))

In [173]:
query = ['what is the cause of diseases']
query_embedding = model.encode(query)
query_embedding_normalized = normalize(query_embedding, norm='l2')

In [182]:
papers_dict = {}

In [183]:
title_distances, title_indices = title_index.search(query_embedding_normalized, num_neighbours)
document_distances, document_indices = document_index.search(query_embedding_normalized, num_neighbours)

In [184]:
papers = list((set(title_indices[0]) & set(document_indices[0])) | set(title_indices[0, -5:]) | set(document_indices[0, -5:]))
print(papers)

[3585, 1922, 483, 421, 3877, 2824, 2472, 3307, 2764, 1741, 887, 19, 277, 2775, 1496, 1371, 4407, 1822]


In [207]:
for paper in papers :
    title_distance = 0
    document_distance = 0
    
    if paper in title_indices[0]:
        index = np.where(title_indices[0] == paper)
        title_distance = title_distances[0][index][0]
    else : title_distance = cosine_similarity(title_embeddings[paper].reshape(1, -1), query_embedding)[0][0]
    
    if paper in document_indices[0]:
        index = np.where(document_indices[0] == paper)
        document_distance = document_distances[0][index][0]
    else : document_distance = cosine_similarity(document_embeddings[paper].reshape(1, -1), query_embedding)[0][0]
        

    print(paper, title_distance, document_distance)
        
    papers_dict[paper] = title_weight * title_distance + document_weight * document_distance

3585 0.8229039 0.7957651
1922 0.8096517 0.77606976
483 0.69900274 0.79075366
421 0.8636091 0.79422915
3877 0.8097015 0.75759286
2824 0.82425416 0.79487664
2472 0.80035007 0.7908313
3307 0.778925 0.79065186
2764 0.79390657 0.79047155
1741 0.8097362 0.78029627
887 0.81014943 0.7928622
19 0.7935066 0.7907141
277 0.8121701 0.82344
2775 0.8253701 0.7916711
1496 0.8096403 0.777135
1371 0.8157666 0.79170644
4407 0.809805 0.7973352
1822 0.829394 0.7914342


In [208]:
papers_dict

{3585: 0.812048363685608,
 1922: 0.7962189078330993,
 483: 0.7357031106948853,
 421: 0.8358571052551269,
 3877: 0.7888580441474915,
 2824: 0.8125031471252442,
 2472: 0.7965425729751587,
 3307: 0.7836157441139222,
 2764: 0.7925325632095337,
 1741: 0.7979602217674255,
 887: 0.8032345294952392,
 19: 0.7923896074295044,
 277: 0.8166780591011047,
 2775: 0.8118904829025269,
 1496: 0.7966381788253785,
 1371: 0.806142520904541,
 4407: 0.8048170685768128,
 1822: 0.8142100811004638}