In [1]:
import transformers
from datasets import load_dataset
import pandas as pd
import numpy as np
import torch
import time
from tqdm import tqdm_notebook as tqdm

In [2]:
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
import logging
import pathlib, os

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

In [3]:
import pathlib, os
from beir import util

dataset = "scifact"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(os.getcwd(), "datasets")
data_path = util.download_and_unzip(url, out_dir)
print("Dataset downloaded here: {}".format(data_path))

Dataset downloaded here: /home/toghrul/ada/ml/final/datasets/scifact


In [4]:
!ls datasets/scifact/

corpus.jsonl  embeddings.csv  qrels	     results.csv
cosine	      hybrid	      queries.jsonl  retrieval_results.csv


In [5]:
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

2024-05-23 14:24:20 - Loading Corpus...


  0%|          | 0/5183 [00:00<?, ?it/s]

2024-05-23 14:24:20 - Loaded 5183 TEST Documents.
2024-05-23 14:24:20 - Doc Example: {'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 vers

In [6]:
corpus_idx = list(corpus.keys())
corpus_vals = list(corpus.values())

corpus_df = pd.DataFrame(corpus_vals, index=corpus_idx)

In [7]:
corpus_df

Unnamed: 0,text,title
4983,Alterations of the architecture of cerebral wh...,Microstructural development of human newborn c...
5836,Myelodysplastic syndromes (MDS) are age-depend...,Induction of myelodysplasia by myeloid-derived...
7912,ID elements are short interspersed elements (S...,"BC1 RNA, the transcript from a master gene for..."
18670,DNA methylation plays an important role in bio...,The DNA Methylome of Human Peripheral Blood Mo...
19238,Two human Golli (for gene expressed in the oli...,The human myelin basic protein gene is include...
...,...,...
195689316,BACKGROUND The main associations of body-mass ...,Body-mass index and cause-specific mortality i...
195689757,A key aberrant biological difference between t...,Targeting metabolic remodeling in glioblastoma...
196664003,A signaling pathway transmits information from...,Signaling architectures that transmit unidirec...
198133135,AIMS Trabecular bone score (TBS) is a surrogat...,"Association between pre-diabetes, type 2 diabe..."


In [8]:
queries_df = pd.Series(queries)

In [16]:
from typing import List
import logging
from pydantic import BaseModel

# from rag import insert_document_and_embeddings, find_similar_embeddings, preprocess
from datetime import datetime
import re
from nltk import tokenize
import unicodedata
import string
import logging
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

MAX_WORD_COUNT = 256
MAX_TOKEN_COUNT = 512


model_name = "mixedbread-ai/mxbai-embed-large-v1"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def encoding(text: str) -> str:
    """
    Remove unicoded data
    """
    text = unicodedata.normalize("NFKD", text)

    return text


def remove_URL(text: str) -> str:
    """
    Remove URLs
    """
    return re.sub(r"https?://\S+|www\.\S+", "", text)


def remove_non_ascii(text: str) -> str:
    """
    Remove non-ASCII characters
    """
    return re.sub(r"[^\x00-\x7f]", r"", text)


def remove_html(text: str) -> str:
    """
    Remove the html
    """
    html = re.compile(r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
    return re.sub(html, "", text)


def remove_punct(text: str) -> str:
    """
    Remove the punctuation
    """
    #     return re.sub(r'[]!"$%&\'()*+,./:;=#@?[\\^_`{|}~-]+', "", text)
    return text.translate(str.maketrans("", "", string.punctuation))


def preprocess(text: str) -> str:
    """
    Preprocess the text
    """

    text = encoding(text)
    text = remove_URL(text)
    text = remove_non_ascii(text)
    text = remove_html(text)
    # text = remove_punct(text)
    return text


def ingest_input(user_input):
    user_input = preprocess(user_input)
    # logging.info(f"Preprocessed user input")

    # Generate sentence tokens
    sentence_tokens = tokenize.sent_tokenize(user_input)
    model_input = []
    temp_input: str = ""
    if len(user_input.split(" ")) > MAX_WORD_COUNT:
        logging.info(
            f"Input contains more than {MAX_WORD_COUNT} words. Splitting the input into chunks"
        )

        # Split the input into chunks based on the sentence tokens
        for i, sent in enumerate(sentence_tokens):
            num_words_sent = len(sent.split(" "))

            # Check if the new chunk would exceed the maximum word count
            if len(temp_input.split(" ")) + num_words_sent > MAX_WORD_COUNT:

                # Append the chunk to the model input
                model_input.append(temp_input.strip())
                logging.info(
                    f"Number of words in the chunk: {len(temp_input.split(' '))}"
                )
                temp_input = sent
            else:
                temp_input += " " + sent

        # Append the last chunk to the model input
        model_input.append(temp_input)
    else:
        model_input = [user_input]

    return model_input


def read_pdf_doc(filepath):
    doc = fitz.open(filepath)
    text = ""
    for page_index, page in enumerate(doc):
        logging.info(f"page {page_index+1} out of {len(doc)}")
        tp = page.get_textpage()
        words = tp.extractWORDS()

        page_text = " ".join([word[4] for word in words])
        text += page_text + " "
    doc.close()
    return text


def generate_embeddings(text, device):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_TOKEN_COUNT,
        padding="max_length",
    )
    with torch.no_grad():
        inputs = inputs.to(device)
        outputs = model(**inputs)

    outputs = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    # Scale the embeddings to be between 0 and 1
    outputs = (outputs - outputs.min()) / (outputs.max() - outputs.min())
    return outputs


# Embedding Generation

In [34]:
text_chunks_list = []
embeddings_list = []
doc_ids = []

batch_size = 128
batch_no = 0
text_batch = []
ids_batch = []

for idx, doc in tqdm(corpus_df.iterrows(), total=len(corpus_df), desc="Processing documents"):
    text_chunks = ingest_input(doc['text'])

    start = time.time()
    text_batch.extend(text_chunks)
    ids_batch.extend([idx] * len(text_chunks))

    if len(text_batch) == batch_size:
        embeddings = generate_embeddings(text_batch, device)
        embeddings_list.append(embeddings)
        text_chunks_list.extend(text_batch)
        doc_ids.extend(ids_batch)

        logging.info(f">>> Generated embeddings for batch {batch_no}")
        logging.info(f"Shape of embeddings: {embeddings.shape}")
        logging.info(f"Time taken for the batch: {time.time() - start}")

        text_batch = []
        ids_batch = []
        batch_no += 1
        
# Handle any remaining batches
if len(text_batch) > 0:
    embeddings = generate_embeddings(text_batch, device)
    embeddings_list.append(embeddings)
    text_chunks_list.extend(text_batch)
    doc_ids.extend(ids_batch)

    logging.info(f">>> Generated embeddings for final batch {batch_no}")
    logging.info(f"Shape of embeddings: {embeddings.shape}")
    logging.info(f"Time taken for the batch: {time.time() - start}")

embeddings_list = np.concatenate(embeddings_list, axis=0)
doc_ids = np.array(doc_ids)


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for idx, doc in tqdm(corpus_df.iterrows(), total=len(corpus_df), desc="Processing documents"):


Processing documents:   0%|          | 0/5183 [00:00<?, ?it/s]

2024-05-23 14:21:31 - Input contains more than 256 words. Splitting the input into chunks
2024-05-23 14:21:31 - Number of words in the chunk: 254
2024-05-23 14:21:31 - Input contains more than 256 words. Splitting the input into chunks
2024-05-23 14:21:31 - Number of words in the chunk: 220
2024-05-23 14:21:31 - Input contains more than 256 words. Splitting the input into chunks
2024-05-23 14:21:31 - Number of words in the chunk: 235
2024-05-23 14:21:31 - Input contains more than 256 words. Splitting the input into chunks
2024-05-23 14:21:31 - Number of words in the chunk: 245
2024-05-23 14:21:31 - Input contains more than 256 words. Splitting the input into chunks
2024-05-23 14:21:31 - Number of words in the chunk: 245
2024-05-23 14:21:31 - Input contains more than 256 words. Splitting the input into chunks
2024-05-23 14:21:31 - Number of words in the chunk: 228
2024-05-23 14:21:31 - Input contains more than 256 words. Splitting the input into chunks
2024-05-23 14:21:31 - Number of wo

KeyboardInterrupt: 

In [None]:
# Convert the embeddings and doc_ids to a DataFrame
embeddings_df = pd.DataFrame({
    'text': text_chunks_list,
    'embedding': embeddings_list.tolist(),
    'doc_id': doc_ids
})


In [None]:
embeddings_df.to_csv(os.path.join(data_path, "embeddings.csv"), index=False)

# Retrieval

In [9]:
embeddings_df = pd.read_csv(os.path.join(data_path, "embeddings.csv"))
# embeddings_df.loc[:, "embedding"] = embeddings_df["embedding"].apply(literal_eval)

In [10]:
embeddings_df["embedding"] = embeddings_df["embedding"].apply(lambda x: x[1:-1].split(', '))

In [11]:
embeddings_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6231 entries, 0 to 6230
Data columns (total 3 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   text       6231 non-null   object
 1   embedding  6231 non-null   object
 2   doc_id     6231 non-null   int64 
dtypes: int64(1), object(2)
memory usage: 146.2+ KB


In [12]:
from sentence_transformers import SentenceTransformer, util


def sentence_similarity(text1, text2):
    embedding_1 = model.encode(text1, convert_to_tensor=True)
    embedding_2 = model.encode(text2, convert_to_tensor=True)
    o = util.pytorch_cos_sim(embedding_1, embedding_2)
    return o.item()


model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

2024-05-23 14:24:30 - Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2
2024-05-23 14:24:32 - Use pytorch device_name: cuda


In [13]:
text_batch = []
sentence_embeddings_list = []
batch_size = 128
batch_no = 0


for idx, row in tqdm(embeddings_df.iterrows(), total=len(embeddings_df), desc="Processing queries"):
    text_batch.append(row['text'])

    if len(text_batch) == batch_size:
        embeddings = model.encode(text_batch)
        sentence_embeddings_list.extend(embeddings.tolist())
        text_batch = []
        logging.info(f">>> Generated embeddings for batch {batch_no}")
        batch_no += 1
        
# Handle any remaining batches
if len(text_batch) > 0:
    embeddings = model.encode(text_batch)
    sentence_embeddings_list.extend(embeddings.tolist())
    text_batch = []
    
embeddings_df["sentence_embedding"] = sentence_embeddings_list

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for idx, row in tqdm(embeddings_df.iterrows(), total=len(embeddings_df), desc="Processing queries"):


Processing queries:   0%|          | 0/6231 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:33 - >>> Generated embeddings for batch 0


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:33 - >>> Generated embeddings for batch 1


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:33 - >>> Generated embeddings for batch 2


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:34 - >>> Generated embeddings for batch 3


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:34 - >>> Generated embeddings for batch 4


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:34 - >>> Generated embeddings for batch 5


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:34 - >>> Generated embeddings for batch 6


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:35 - >>> Generated embeddings for batch 7


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:35 - >>> Generated embeddings for batch 8


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:35 - >>> Generated embeddings for batch 9


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:35 - >>> Generated embeddings for batch 10


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:36 - >>> Generated embeddings for batch 11


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:36 - >>> Generated embeddings for batch 12


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:36 - >>> Generated embeddings for batch 13


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:36 - >>> Generated embeddings for batch 14


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:37 - >>> Generated embeddings for batch 15


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:37 - >>> Generated embeddings for batch 16


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:37 - >>> Generated embeddings for batch 17


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:37 - >>> Generated embeddings for batch 18


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:38 - >>> Generated embeddings for batch 19


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:38 - >>> Generated embeddings for batch 20


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:38 - >>> Generated embeddings for batch 21


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:38 - >>> Generated embeddings for batch 22


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:39 - >>> Generated embeddings for batch 23


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:39 - >>> Generated embeddings for batch 24


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:39 - >>> Generated embeddings for batch 25


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:39 - >>> Generated embeddings for batch 26


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:39 - >>> Generated embeddings for batch 27


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:40 - >>> Generated embeddings for batch 28


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:40 - >>> Generated embeddings for batch 29


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:40 - >>> Generated embeddings for batch 30


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:40 - >>> Generated embeddings for batch 31


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:41 - >>> Generated embeddings for batch 32


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:41 - >>> Generated embeddings for batch 33


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:41 - >>> Generated embeddings for batch 34


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:41 - >>> Generated embeddings for batch 35


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:42 - >>> Generated embeddings for batch 36


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:42 - >>> Generated embeddings for batch 37


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:42 - >>> Generated embeddings for batch 38


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:42 - >>> Generated embeddings for batch 39


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:43 - >>> Generated embeddings for batch 40


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:43 - >>> Generated embeddings for batch 41


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:43 - >>> Generated embeddings for batch 42


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:43 - >>> Generated embeddings for batch 43


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:44 - >>> Generated embeddings for batch 44


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:44 - >>> Generated embeddings for batch 45


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:44 - >>> Generated embeddings for batch 46


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-23 14:24:44 - >>> Generated embeddings for batch 47


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

In [15]:
embeddings_df.dtypes

text                  object
embedding             object
doc_id                 int64
sentence_embedding    object
dtype: object

In [14]:
queries_df = pd.DataFrame({
    'query_id': list(queries.keys()),
    'query': list(queries.values())
})
queries_df

Unnamed: 0,query_id,query
0,1,0-dimensional biomaterials show inductive prop...
1,3,"1,000 genomes project enables mapping of genet..."
2,5,1/2000 in UK have abnormal PrP positivity.
3,13,5% of perinatal mortality is due to low birth ...
4,36,A deficiency of vitamin B12 increases blood le...
...,...,...
295,1379,Women with a higher birth weight are more like...
296,1382,aPKCz causes tumour enhancement by affecting g...
297,1385,cSMAC formation enhances weak ligand signalling.
298,1389,mTORC2 regulates intracellular cysteine levels...


In [18]:
retrieved_docs_list = []
retrieved_text_list = []
query_ids = []
batch_size = 128
k=10
query_batch = []
query_embeddings_list = []
query_ids = []
batch_no = 0

for idx, query in tqdm(queries_df.iterrows(), total=len(queries_df), desc="Processing queries"):
    query_batch.append(ingest_input(query['query'])[0])
    
    if len(query_batch) == batch_size:
        query_embeddings = generate_embeddings(query_batch, device)
        query_embeddings_list.extend(query_embeddings.tolist())
        # query_ids.extend([query['query_id']] * batch_size)
        
        query_batch = []
        logging.info(f">>> Generated embeddings for batch {batch_no}")
        logging.info(f"Shape of embeddings: {query_embeddings.shape}")
        batch_no += 1
        
# Handle any remaining queries
if len(query_batch) > 0:
    query_embeddings = generate_embeddings(query_batch, device)
    query_embeddings_list.extend(query_embeddings.tolist())
    # query_ids.extend([query['query_id']] * len(query_batch))
    
    logging.info(f">>> Generated embeddings for final batch {batch_no}")
    logging.info(f"Shape of embeddings: {query_embeddings.shape}")




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for idx, query in tqdm(queries_df.iterrows(), total=len(queries_df), desc="Processing queries"):


Processing queries:   0%|          | 0/300 [00:00<?, ?it/s]

2024-05-23 14:25:53 - >>> Generated embeddings for batch 0
2024-05-23 14:25:53 - Shape of embeddings: (128, 1024)
2024-05-23 14:25:59 - >>> Generated embeddings for batch 1
2024-05-23 14:25:59 - Shape of embeddings: (128, 1024)
2024-05-23 14:26:01 - >>> Generated embeddings for final batch 2
2024-05-23 14:26:01 - Shape of embeddings: (44, 1024)


In [19]:
queries_df = pd.DataFrame({
    'query_id': queries_df['query_id'],
    'query': queries_df['query'],
    'query_embedding': query_embeddings_list
})
queries_df.shape

(300, 3)

In [20]:
from rank_bm25 import BM25Okapi

TEXT_CORPUS = embeddings_df['text'].values.tolist()
TOKENIZED_CORPUS = [tokenize.word_tokenize(doc) for doc in TEXT_CORPUS]

BM25_TC = BM25Okapi(TOKENIZED_CORPUS)

In [21]:
type(embeddings_df.sentence_embedding[0])

list

In [44]:

DOC_EMBEDDINGS = np.vstack(embeddings_df["embedding"].apply(lambda x: np.array(x)))


def find_similar_embeddings(
    df: pd.DataFrame,
    query: str, 
    query_embedding: List[float], 
    top_k: int=10, 
    alpha: float=0.7,
    similarity_threshold: float=0.5,
    method: str="cosine"
):
    if isinstance(query_embedding, list):
        query_embedding = np.array(query_embedding).reshape(1, -1)

    # Calculate cosine similarities
    similarities = cosine_similarity(query_embedding, DOC_EMBEDDINGS).flatten()
    
    if method == "hybrid":
        query_tokens = tokenize.word_tokenize(query)
        doc_scores = BM25_TC.get_scores(query_tokens)
        similarities = alpha * similarities + (1 - alpha) * doc_scores

    # Add similarities to DataFrame
    df["similarity"] = similarities

    # Filter based on similarity threshold
    results_df = (
        df[df["similarity"] > similarity_threshold]
        .sort_values(by="similarity", ascending=False)
        .head(top_k)
    )
    
    if method == "keyword_rerank":
        tokenized_corpus = [tokenize.word_tokenize(doc) for doc in results_df["text"].values.tolist()]
        bm25 = BM25Okapi(tokenized_corpus)
        query_tokens = tokenize.word_tokenize(query)
        
        doc_scores = bm25.get_scores(query_tokens)
        doc_scores_idx = np.argsort(doc_scores)[::-1]
        
        results_df = results_df.iloc[doc_scores_idx].reset_index(drop=True)
    elif method == "cross_encoder_rerank":
        doc_embeddings = np.vstack(results_df["embedding"].apply(lambda x: np.array(x)))
        cross_encoder_scores = cosine_similarity(query_embedding, doc_embeddings).flatten()
        
        doc_scores_idx = np.argsort(cross_encoder_scores)[::-1]
        results_df = results_df.iloc[doc_scores_idx].reset_index(drop=True)
        
        

    return (
        results_df["doc_id"].values.tolist(),
        results_df["similarity"].values.tolist(),
        results_df["text"].values.tolist(),
    )

## Hybrid Contextual + Keyword Retrieval

In [25]:
def hybrid_retrieval(df, query, query_embedding, top_k=10, alpha=0.5, similarity_threshold=0.5):
    # Tokenize the query
    tokenized_query = tokenize.word_tokenize(query)
    # query_embedding = generate_embeddings([query], device)
    if isinstance(query_embedding, list):
            query_embedding = np.array(query_embedding).reshape(1, -1)
    # Get the Cosine similarities
    cosine_similarities = cosine_similarity(query_embedding, DOC_EMBEDDINGS).flatten()

    # Get the BM25 scores
    bm25_scores = BM25_TC.get_scores(tokenized_query)
    # Normalize the BM25 scores
    bm25_scores = (bm25_scores - bm25_scores.min()) / (
        bm25_scores.max() - bm25_scores.min()
    )

    # Combine the scores
    hybrid_scores = alpha * cosine_similarities + (1 - alpha) * bm25_scores

    # Add the scores to the DataFrame
    df["hybrid_score"] = hybrid_scores
    
    # Filter based on similarity threshold
    temp_df = (
        df[df["hybrid_score"] > similarity_threshold]
        .sort_values(by="hybrid_score", ascending=False)
        .head(top_k)
    )

    return (
        temp_df["text"].values.tolist(),
        temp_df["doc_id"].values.tolist(),
        temp_df["hybrid_score"].values.tolist(),
    )

In [26]:
# for idx, row in results_df.iterrows():
#     doc_ids, sim_scores, similar_texts = find_similar_embeddings(embeddings_df, row['query_embedding'], top_k=k)
#     break

In [24]:
queries_df

1       0-dimensional biomaterials show inductive prop...
3       1,000 genomes project enables mapping of genet...
5              1/2000 in UK have abnormal PrP positivity.
13      5% of perinatal mortality is due to low birth ...
36      A deficiency of vitamin B12 increases blood le...
                              ...                        
1379    Women with a higher birth weight are more like...
1382    aPKCz causes tumour enhancement by affecting g...
1385     cSMAC formation enhances weak ligand signalling.
1389    mTORC2 regulates intracellular cysteine levels...
1395    p16INK4A accumulation is  linked to an abnorma...
Length: 300, dtype: object

## Retrieval of Relevant Documents

In [23]:
k = 50

# methods_list = ["hybrid", "cosine"]
methods_list = ["cross_encoder_rerank", "keyword_rerank"]


for method in methods_list:
    logging.info(f">>> Retrieving documents using {method} similarity <<<")
    retrieved_docs_list = []
    retrieved_text_list = []
    retrieved_sim_list = []
    query_ids = []
    for idx, row in tqdm(
        queries_df.iterrows(), total=len(queries_df), desc="Finding similar embeddings"
    ):
        doc_ids, sim_scores, similar_texts = find_similar_embeddings(
            embeddings_df,
            row["query"],
            row["query_embedding"],
            top_k=k,
            alpha=0.5,
            similarity_threshold=0.5,
            method=method,
        )
        retrieved_docs_list.append(doc_ids)
        retrieved_sim_list.append(sim_scores)
        retrieved_text_list.append(similar_texts)

    retrieval_df = pd.DataFrame(
        {
            "query_id": queries_df["query_id"],
            "query": queries_df["query"],
            f"retrieved_docs_top{k}": retrieved_docs_list,
            f"retrieved_texts_top{k}": retrieved_text_list,
            f"retrieved_sim_top{k}": retrieved_sim_list,
        }
    )

    save_path = os.path.join(data_path, method)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    retrieval_df.to_csv(os.path.join(save_path, "retrieval_results.csv"), index=False)

2024-05-23 14:26:23 - >>> Retrieving documents using cross_encoder_rerank similarity <<<


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for idx, row in tqdm(


Finding similar embeddings:   0%|          | 0/300 [00:00<?, ?it/s]

2024-05-23 14:34:30 - >>> Retrieving documents using keyword_rerank similarity <<<


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for idx, row in tqdm(


Finding similar embeddings:   0%|          | 0/300 [00:00<?, ?it/s]

In [69]:
retrieval_df = pd.read_csv(os.path.join(data_path, "retrieval_results.csv"))

retrieval_df[f"hybrid_retrieved_docs_top{k}"] = retrieved_docs_list
retrieval_df[f"hybrid_retrieved_texts_top{k}"] = retrieved_text_list
retrieval_df[f"hybrid_retrieved_sim_top{k}"] = retrieved_sim_list

In [71]:
retrieval_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 300 entries, 0 to 299
Data columns (total 57 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   query_id                      300 non-null    int64  
 1   retrieved_docs_top10          300 non-null    object 
 2   retrieved_texts_top10         300 non-null    object 
 3   retrieved_sim_top10           300 non-null    object 
 4   relevant_docs                 300 non-null    object 
 5   accuracy_top10                300 non-null    float64
 6   precision_inner_top10         300 non-null    float64
 7   precision_ohe_top10           300 non-null    float64
 8   recall_ohe_top10              300 non-null    float64
 9   accuracy_top5                 300 non-null    float64
 10  precision_inner_top5          300 non-null    float64
 11  precision_ohe_top5            300 non-null    float64
 12  recall_ohe_top5               300 non-null    float64
 13  accur

In [72]:
retrieval_df.to_csv(os.path.join(data_path, "retrieval_results.csv"), index=False)

In [42]:
qrels_df = pd.DataFrame({
    'query_id': list(qrels.keys()),
    'relevant_docs': [list(docs.keys()) for docs in list(qrels.values())]
})

# Evaluation

In [74]:
retrieval_df = pd.read_csv(os.path.join(data_path, "retrieval_results.csv"))

In [75]:
retrieval_df['relevant_docs'] = retrieval_df['relevant_docs'].str.replace("'", "")

In [76]:
retrieval_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 300 entries, 0 to 299
Data columns (total 57 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   query_id                      300 non-null    int64  
 1   retrieved_docs_top10          300 non-null    object 
 2   retrieved_texts_top10         300 non-null    object 
 3   retrieved_sim_top10           300 non-null    object 
 4   relevant_docs                 300 non-null    object 
 5   accuracy_top10                300 non-null    float64
 6   precision_inner_top10         300 non-null    float64
 7   precision_ohe_top10           300 non-null    float64
 8   recall_ohe_top10              300 non-null    float64
 9   accuracy_top5                 300 non-null    float64
 10  precision_inner_top5          300 non-null    float64
 11  precision_ohe_top5            300 non-null    float64
 12  recall_ohe_top5               300 non-null    float64
 13  accur

In [47]:
DOCS_IN_CORPUS = list(corpus_df.index)
DOCS_IN_CORPUS

['4983',
 '5836',
 '7912',
 '18670',
 '19238',
 '33370',
 '36474',
 '54440',
 '70115',
 '70490',
 '72159',
 '79447',
 '87758',
 '92308',
 '92499',
 '97884',
 '102662',
 '103007',
 '104130',
 '106301',
 '116792',
 '118568',
 '120626',
 '123859',
 '140874',
 '143251',
 '152245',
 '153744',
 '159469',
 '164189',
 '164985',
 '169264',
 '175735',
 '188911',
 '195352',
 '202259',
 '207972',
 '213017',
 '219475',
 '226488',
 '236204',
 '238409',
 '243694',
 '253672',
 '263364',
 '266641',
 '275294',
 '279052',
 '285794',
 '293661',
 '301838',
 '301866',
 '306006',
 '306311',
 '308862',
 '313394',
 '313403',
 '317204',
 '323030',
 '323335',
 '327319',
 '335029',
 '341324',
 '343052',
 '344240',
 '350542',
 '356218',
 '364522',
 '365896',
 '368506',
 '371289',
 '374902',
 '380526',
 '381602',
 '393001',
 '406733',
 '409280',
 '410286',
 '418246',
 '427082',
 '427865',
 '432261',
 '435529',
 '437924',
 '439670',
 '456304',
 '457630',
 '461550',
 '463309',
 '463533',
 '464511',
 '469066',
 '47062

In [63]:
corpus_df.to_csv(os.path.join(data_path, "corpus.csv"), index=False)
queries_df.to_csv(os.path.join(data_path, "queries.csv"), index=False)
qrels_df.to_csv(os.path.join(data_path, "qrels.csv"), index=False)

In [48]:
qrels

{'1': {'31715818': 1},
 '3': {'14717500': 1},
 '5': {'13734012': 1},
 '13': {'1606628': 1},
 '36': {'5152028': 1, '11705328': 1},
 '42': {'18174210': 1},
 '48': {'13734012': 1},
 '49': {'5953485': 1},
 '50': {'12580014': 1},
 '51': {'45638119': 1},
 '53': {'45638119': 1},
 '54': {'49556906': 1},
 '56': {'4709641': 1},
 '57': {'4709641': 1},
 '70': {'5956380': 1, '4414547': 1},
 '72': {'6076903': 1},
 '75': {'4387784': 1},
 '94': {'1215116': 1},
 '99': {'18810195': 1},
 '100': {'4381486': 1},
 '113': {'6157837': 1},
 '115': {'33872649': 1},
 '118': {'6372244': 1},
 '124': {'4883040': 1},
 '127': {'21598000': 1},
 '128': {'8290953': 1},
 '129': {'27768226': 1},
 '130': {'27768226': 1},
 '132': {'7975937': 1},
 '133': {'38485364': 1,
  '6969753': 1,
  '17934082': 1,
  '16280642': 1,
  '12640810': 1},
 '137': {'26016929': 1},
 '141': {'6955746': 1, '14437255': 1},
 '142': {'10582939': 1},
 '143': {'10582939': 1},
 '146': {'10582939': 1},
 '148': {'1084345': 1},
 '163': {'18872233': 1},
 '1

In [49]:
from sklearn.metrics import precision_score, recall_score

def compute_accuracy(retrieved_docs, relevant_docs, top_k=10):
    if isinstance(retrieved_docs, str):
        retrieved_docs = retrieved_docs[1:-1].split(", ")
    retrieved_docs = retrieved_docs[:top_k]
    
    if isinstance(relevant_docs, str):
        relevant_docs = relevant_docs[1:-1].split(", ")
    relevant_docs = relevant_docs[:top_k]
    
    retrieved_docs_set = set(retrieved_docs)
    relevant_docs_set = set(relevant_docs)
    logging.info(f"Retrieved documents: {retrieved_docs}")
    logging.info(f"Relevant documents: {relevant_docs}")
    common_docs_set = retrieved_docs_set.intersection(relevant_docs_set)
    logging.info(f"Common documents: {common_docs_set}")
    # Calculate accuracy as the number of common documents divided by the total number of relevant documents
    intersect_accuracy = len(common_docs_set) / len(relevant_docs_set)
    precision_inner = len(common_docs_set) / len(retrieved_docs_set) # TP / (TP + FP)
    
    # Get the index of the relevant documents in the corpus
    relevant_docs_ohe_idx = [DOCS_IN_CORPUS.index(doc) for doc in relevant_docs]
    relevant_docs_ohe = np.zeros(len(DOCS_IN_CORPUS))
    relevant_docs_ohe[relevant_docs_ohe_idx] = 1
        
    # Get the index of the retrieved documents in the corpus
    retrieved_docs_ohe_idx = [DOCS_IN_CORPUS.index(doc) for doc in retrieved_docs]
    retrieved_docs_ohe = np.zeros(len(DOCS_IN_CORPUS))
    retrieved_docs_ohe[retrieved_docs_ohe_idx] = 1
    
    # Calculate precision, recall
    precision_ohe = precision_score(relevant_docs_ohe, retrieved_docs_ohe)
    recall_ohe = recall_score(relevant_docs_ohe, retrieved_docs_ohe)
     
    
    return intersect_accuracy, precision_inner, precision_ohe, recall_ohe

In [50]:
retrieval_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 300 entries, 0 to 299
Data columns (total 6 columns):
 #   Column                 Non-Null Count  Dtype 
---  ------                 --------------  ----- 
 0   query_id               300 non-null    object
 1   query                  300 non-null    object
 2   retrieved_docs_top50   300 non-null    object
 3   retrieved_texts_top50  300 non-null    object
 4   retrieved_sim_top50    300 non-null    object
 5   relevant_docs          300 non-null    object
dtypes: object(6)
memory usage: 14.2+ KB


In [52]:
METHODS_LIST = ["cosine", "hybrid", "cross_encoder_rerank", "keyword_rerank"]
k_list = [1, 3, 5, 10, 20]

for item in os.listdir(data_path):
    if item in METHODS_LIST:
        retrieval_df = pd.read_csv(os.path.join(data_path, item, "retrieval_results.csv"))
        retrieval_df['query_id'] = retrieval_df['query_id'].astype(str)
        retrieval_df = retrieval_df.merge(qrels_df, on="query_id")
        print(retrieval_df.columns)
        
        for k in k_list:
            retrieval_df[f"accuracy_top{k}"], retrieval_df[f"precision_inner_top{k}"], retrieval_df[f"precision_ohe_top{k}"], retrieval_df[f"recall_ohe_top{k}"] = retrieval_df.apply(
                    lambda x: compute_accuracy(x[f"retrieved_docs_top50"], x["relevant_docs"], top_k=k),
                    axis=1, result_type="expand"
                )
        retrieval_df.to_csv(os.path.join(data_path, item, f"retrieval_results.csv"), index=False)

Index(['query_id', 'query', 'retrieved_docs_top50', 'retrieved_texts_top50',
       'retrieved_sim_top50', 'relevant_docs'],
      dtype='object')
2024-05-23 15:31:22 - Retrieved documents: ['4346436']
2024-05-23 15:31:22 - Relevant documents: ['31715818']
2024-05-23 15:31:22 - Common documents: set()
2024-05-23 15:31:22 - Retrieved documents: ['4414547']
2024-05-23 15:31:22 - Relevant documents: ['14717500']
2024-05-23 15:31:22 - Common documents: set()
2024-05-23 15:31:22 - Retrieved documents: ['76415938']
2024-05-23 15:31:22 - Relevant documents: ['13734012']
2024-05-23 15:31:22 - Common documents: set()
2024-05-23 15:31:22 - Retrieved documents: ['4791384']
2024-05-23 15:31:22 - Relevant documents: ['1606628']
2024-05-23 15:31:22 - Common documents: set()
2024-05-23 15:31:22 - Retrieved documents: ['3215494']
2024-05-23 15:31:22 - Relevant documents: ['5152028']
2024-05-23 15:31:22 - Common documents: set()
2024-05-23 15:31:22 - Retrieved documents: ['18174210']
2024-05-23 15:31:2

In [54]:
retrieval_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 300 entries, 0 to 299
Data columns (total 26 columns):
 #   Column                 Non-Null Count  Dtype 
---  ------                 --------------  ----- 
 0   query_id               300 non-null    object
 1   query                  300 non-null    object
 2   retrieved_docs_top50   300 non-null    object
 3   retrieved_texts_top50  300 non-null    object
 4   retrieved_sim_top50    300 non-null    object
 5   relevant_docs          300 non-null    object
 6   accuracy_top1          300 non-null    int64 
 7   precision_inner_top1   300 non-null    int64 
 8   precision_ohe_top1     300 non-null    int64 
 9   recall_ohe_top1        300 non-null    int64 
 10  accuracy_top3          300 non-null    int64 
 11  precision_inner_top3   300 non-null    int64 
 12  precision_ohe_top3     300 non-null    int64 
 13  recall_ohe_top3        300 non-null    int64 
 14  accuracy_top5          300 non-null    int64 
 15  precision_inner_top5   

In [59]:
def compute_metrics(retrieval_df, top_k=10, reranked=False):
    # Compute the mean accuracy, precision, and recall
    accuracy_ft = f"accuracy_top{top_k}"
    precision_inner_ft = f"precision_inner_top{top_k}"
    precision_ohe_ft = f"precision_ohe_top{top_k}"
    recall_ohe_ft = f"recall_ohe_top{top_k}"
    
    accuracy = retrieval_df[accuracy_ft].mean()
    precision_inner = retrieval_df[precision_inner_ft].mean()
    precision_ohe = retrieval_df[precision_ohe_ft].mean()
    recall_ohe = retrieval_df[recall_ohe_ft].mean()
    
    temp_df = pd.DataFrame({
        'method': [f"top{top_k}"],
        'accuracy': [accuracy],
        'precision_inner': [precision_inner],
        'precision_ohe': [precision_ohe],
        'recall_ohe': [recall_ohe]
    })
    
    return temp_df

In [60]:
METHODS_LIST = ["cosine", "hybrid", "cross_encoder_rerank", "keyword_rerank"]
k_list = [1, 3, 5, 10, 20]

for item in os.listdir(data_path):
    if item in METHODS_LIST:
        retrieval_df = pd.read_csv(os.path.join(data_path, item, "retrieval_results.csv"))
        metrics_df = pd.DataFrame({'method': [], 'accuracy': [], 'precision_inner': [], 'precision_ohe': [], 'recall_ohe': []})
        
        for k in k_list:
            temp_df = compute_metrics(retrieval_df, top_k=k)
            metrics_df = pd.concat([metrics_df, temp_df], axis=0)
        metrics_df.to_csv(os.path.join(data_path, item, "metrics.csv"), index=False)

In [61]:
retrieval_df

Unnamed: 0,query_id,query,retrieved_docs_top50,retrieved_texts_top50,retrieved_sim_top50,relevant_docs,accuracy_top1,precision_inner_top1,precision_ohe_top1,recall_ohe_top1,...,precision_ohe_top5,recall_ohe_top5,accuracy_top10,precision_inner_top10,precision_ohe_top10,recall_ohe_top10,accuracy_top20,precision_inner_top20,precision_ohe_top20,recall_ohe_top20
0,1,0-dimensional biomaterials show inductive prop...,"[4346436, 58050905, 15337254, 17388232, 758310...","['Unlike most synthetic materials, biological ...","[0.9847771606719977, 0.98458850008293, 0.98433...",['31715818'],0,1,2,3,...,2,3,0,1,2,3,0,1,2,3
1,3,"1,000 genomes project enables mapping of genet...","[4414547, 23389795, 1388704, 19058822, 3662132...","['More generally, these data provide new insig...","[0.9881532193663813, 0.98559633471383, 0.98443...",['14717500'],0,1,2,3,...,2,3,0,1,2,3,0,1,2,3
2,5,1/2000 in UK have abnormal PrP positivity.,"[76415938, 9764256, 13734012, 9813098, 1373401...",['Eighty-six percent (102 of 123) of the patie...,"[0.9821301396917128, 0.9817294730356083, 0.981...",['13734012'],0,1,2,3,...,2,3,0,1,2,3,0,1,2,3
3,13,5% of perinatal mortality is due to low birth ...,"[4791384, 27099731, 1263446, 33257464, 8529693...","[""BACKGROUND Historically, the main focus of s...","[0.9849229915319799, 0.9842265099382841, 0.984...",['1606628'],0,1,2,3,...,2,3,0,1,2,3,0,1,2,3
4,36,A deficiency of vitamin B12 increases blood le...,"[3215494, 18557974, 37424881, 18256197, 105574...",['Hyperhomocysteinemia has recently been ident...,"[0.9864817966627986, 0.986375963990189, 0.9858...","['5152028', '11705328']",0,1,2,3,...,2,3,0,1,2,3,0,1,2,3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
295,1379,Women with a higher birth weight are more like...,"[27123743, 17450673, 37480103, 16322674, 16322...",['Breast cancer may originate in utero. We rev...,"[0.9909429245296438, 0.9900861242796535, 0.989...","['16322674', '27123743', '23557241', '17450673']",0,1,2,3,...,2,3,0,1,2,3,0,1,2,3
296,1382,aPKCz causes tumour enhancement by affecting g...,"[3831884, 27647593, 17755060, 4959368, 293661,...",['Cancer cells have metabolic dependencies tha...,"[0.9879058433332039, 0.9877483491986919, 0.987...",['17755060'],0,1,2,3,...,2,3,0,1,2,3,0,1,2,3
297,1385,cSMAC formation enhances weak ligand signalling.,"[306006, 9283422, 11195653, 10906636, 4417558,...",['T cell activation is predicated on the inter...,"[0.9878927780087177, 0.9866585879440809, 0.984...",['306006'],0,1,2,3,...,2,3,0,1,2,3,0,1,2,3
298,1389,mTORC2 regulates intracellular cysteine levels...,"[23895668, 8460275, 17195001, 1344498, 1487481...",['Mutations in cancer reprogram amino acid met...,"[0.9907232444856812, 0.9861914732832167, 0.984...",['23895668'],0,1,2,3,...,2,3,0,1,2,3,0,1,2,3


In [92]:
results_df = pd.DataFrame({
    'method': [],
    'accuracy': [],
    'precision_inner': [],
    'precision_ohe': [],
    'recall_ohe': []
})

In [93]:
k_list = [3, 5, 10, 20]
reranked_list = [False, True]

for k in k_list:
    for reranked in reranked_list:
        temp_df = compute_metrics(retrieval_df, top_k=k, reranked=reranked)
        results_df = pd.concat([results_df, temp_df], axis=0, ignore_index=True)
    

In [98]:
k_list = [3, 5, 10, 20]

for k in k_list:
    accuracy_ft = f"hybrid_accuracy_top{k}"
    precision_inner_ft = f"hybrid_precision_inner_top{k}"
    precision_ohe_ft = f"hybrid_precision_ohe_top{k}"
    recall_ohe_ft = f"hybrid_recall_ohe_top{k}"
    
    accuracy = retrieval_df[accuracy_ft].mean()
    precision_inner = retrieval_df[precision_inner_ft].mean()
    precision_ohe = retrieval_df[precision_ohe_ft].mean()
    recall_ohe = retrieval_df[recall_ohe_ft].mean()
    
    temp_df = pd.DataFrame({
        'method': [f"hybrid_top{k}"],
        'accuracy': [accuracy],
        'precision_inner': [precision_inner],
        'precision_ohe': [precision_ohe],
        'recall_ohe': [recall_ohe]
    })
    
results_df = pd.concat([results_df, temp_df], axis=0, ignore_index=True)

In [99]:
results_df

Unnamed: 0,method,accuracy,precision_inner,precision_ohe,recall_ohe
0,top3,0.685556,0.274444,0.274444,0.685556
1,top3_rerank,0.685556,0.274444,0.274444,0.685556
2,top5,0.727889,0.182556,0.182556,0.727889
3,top5_rerank,0.727889,0.182556,0.182556,0.727889
4,top10,0.811222,0.102053,0.102053,0.811222
5,top10_rerank,0.811222,0.102053,0.102053,0.811222
6,top20,0.811222,0.102053,0.102053,0.811222
7,top20_rerank,0.811222,0.102053,0.102053,0.811222
8,hybrid_top20,0.7575,0.042954,0.042954,0.7575
9,hybrid_top20,0.7575,0.042954,0.042954,0.7575


In [100]:
results_df.to_csv(os.path.join(data_path, "results.csv"), index=False)

In [39]:
retrieval_df.to_csv(os.path.join(data_path, "retrieval_results.csv"), index=False)

In [40]:
# Apply exact keyword matching using the BM25 algorithm
from rank_bm25 import BM25Okapi

In [41]:
corpus_text = [doc['text'] for doc in list(corpus.values())]
tokenized_corpus = [tokenize.word_tokenize(doc) for doc in corpus_text]

bm25 = BM25Okapi(tokenized_corpus)

In [42]:
query = retrieval_df['query'][0]
query

'0-dimensional biomaterials show inductive properties.'

In [43]:
tokenized_query = tokenize.word_tokenize(query)

doc_scores = bm25.get_scores(tokenized_query)
doc_scores.shape

(5183,)

In [44]:
eval(retrieval_df['retrieved_texts_top10'].values[0])

['Unlike most synthetic materials, biological materials often stiffen as they are deformed. This nonlinear elastic response, critical for the physiological function of some tissues, has been documented since at least the 19th century, but the molecular structure and the design principles responsible for it are unknown. Current models for this response require geometrically complex ordered structures unique to each material. In this Article we show that a much simpler molecular theory accounts for strain stiffening in a wide range of molecularly distinct biopolymer gels formed from purified cytoskeletal and extracellular proteins. This theory shows that systems of semi-flexible chains such as filamentous proteins arranged in an open crosslinked meshwork invariably stiffen at low strains without the need for a specific architecture or multiple elements with different intrinsic stiffnesses.',
 'The World Health Organisation has declared the period 2000 to 2010 the Bone and Joint Decade. T

In [1]:
from sentence_transformers import SentenceTransformer, util


def sentence_similarity(text1, text2):
    embedding_1 = model.encode(text1, convert_to_tensor=True)
    embedding_2 = model.encode(text2, convert_to_tensor=True)
    o = util.pytorch_cos_sim(embedding_1, embedding_2)
    return o.item()


model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

In [4]:
text1 = "Unlike most synthetic materials, biological materials often stiffen as they are deformed. This nonlinear elastic response, critical for the physiological function of some tissues, has been documented since at least the 19th century, but the molecular structure and the design principles responsible for it are unknown. Current models for this response require geometrically complex ordered structures unique to each material. In this Article we show that a much simpler molecular theory accounts for strain stiffening in a wide range of molecularly distinct biopolymer gels formed from purified cytoskeletal and extracellular proteins."
text2 = "Significant efforts have been directed to understanding the factors that influence the lineage commitment of stem cells. This paper demonstrates that cell shape, independent of soluble factors, has a strong influence on the differentiation of human mesenchymal stem cells (MSCs) from bone marrow. When exposed to competing soluble differentiation signals, cells cultured in rectangles with increasing aspect ratio and in shapes with pentagonal symmetry but with different subcellular curvature-and with each occupying the same area-display different adipogenesis and osteogenesis profiles. The results reveal that geometric features that increase actomyosin contractility promote osteogenesis and are consistent with in vivo characteristics of the microenvironment of the differentiated cells."
sentence_similarity(text1, text2)

0.3817015290260315

In [86]:
def rerank_docs(query, corpus_doc_ids, corpus_docs):
    tokenized_corpus = [tokaenize.word_tokenize(doc) for doc in corpus_docs]
    bm25 = BM25Okapi(tokenized_corpus)

    tokenized_query = tokenize.word_tokenize(query)
    doc_scores = bm25.get_scores(tokenized_query)
    doc_scores_idx = np.argsort(doc_scores)[::-1]
    
    reranked_doc_ids = [corpus_doc_ids[idx] for idx in doc_scores_idx]
    reranked_docs = [corpus_docs[idx] for idx in doc_scores_idx]
    
    return reranked_doc_ids, reranked_docs

In [87]:
retrieval_df[['reranked_docs_top20', 'reranked_text_top20']] = retrieval_df.apply(lambda x: rerank_docs(x['query'], eval(x['retrieved_docs_top20']), eval(x['retrieved_texts_top20'])), axis=1, result_type="expand")

In [88]:
for k in k_list:
    retrieval_df[[f"accuracy_top{k}_rerank", f"precision_inner_top{k}_rerank", f"precision_ohe_top{k}_rerank", f"recall_ohe_top{k}_rerank"]] = retrieval_df.apply(lambda x: compute_accuracy(x[f"reranked_docs_top20"], x['relevant_docs'], top_k=k), axis=1, result_type="expand")

2024-05-23 09:52:41 - Retrieved documents: [17388232, 37437064, 11172205]
2024-05-23 09:52:41 - Relevant documents: ['31715818']
2024-05-23 09:52:41 - Common documents: set()


ValueError: 17388232 is not in list

In [100]:
k = 20
reranked = True

temp_df = compute_metrics(retrieval_df, top_k=k, reranked=reranked)
results_df = pd.concat([results_df, temp_df], axis=0, ignore_index=True)

In [101]:
results_df

Unnamed: 0,method,accuracy,precision_inner,precision_ohe,recall_ohe
0,top3,0.685556,0.274444,0.274444,0.685556
1,top5,0.727889,0.182556,0.182556,0.727889
2,top10,0.811222,0.102053,0.102053,0.811222
3,top20,0.811222,0.102053,0.102053,0.811222
4,top20_rerank,0.811222,0.102053,0.102053,0.811222
