In [1]:
import transformers
from datasets import load_dataset
import pandas as pd
import numpy as np
import torch
import time
from tqdm import tqdm

In [2]:
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
import logging
import pathlib, os

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

In [3]:
import pathlib, os
from beir import util

dataset = "scifact"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(os.getcwd(), "datasets")
data_path = util.download_and_unzip(url, out_dir)
print("Dataset downloaded here: {}".format(data_path))

Dataset downloaded here: /home/toghrul/ada/ml/final/datasets/scifact


In [4]:
!ls datasets/scifact/

corpus.jsonl  embeddings.csv  qrels  queries.jsonl  retrieval_results.csv


In [5]:
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

2024-05-23 04:42:55 - Loading Corpus...


  0%|          | 0/5183 [00:00<?, ?it/s]

2024-05-23 04:42:55 - Loaded 5183 TEST Documents.
2024-05-23 04:42:55 - Doc Example: {'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 vers

In [6]:
corpus_idx = list(corpus.keys())
corpus_vals = list(corpus.values())

corpus_df = pd.DataFrame(corpus_vals, index=corpus_idx)

In [7]:
corpus_df

Unnamed: 0,text,title
4983,Alterations of the architecture of cerebral wh...,Microstructural development of human newborn c...
5836,Myelodysplastic syndromes (MDS) are age-depend...,Induction of myelodysplasia by myeloid-derived...
7912,ID elements are short interspersed elements (S...,"BC1 RNA, the transcript from a master gene for..."
18670,DNA methylation plays an important role in bio...,The DNA Methylome of Human Peripheral Blood Mo...
19238,Two human Golli (for gene expressed in the oli...,The human myelin basic protein gene is include...
...,...,...
195689316,BACKGROUND The main associations of body-mass ...,Body-mass index and cause-specific mortality i...
195689757,A key aberrant biological difference between t...,Targeting metabolic remodeling in glioblastoma...
196664003,A signaling pathway transmits information from...,Signaling architectures that transmit unidirec...
198133135,AIMS Trabecular bone score (TBS) is a surrogat...,"Association between pre-diabetes, type 2 diabe..."


In [8]:
queries_df = pd.Series(queries)

In [9]:
from typing import List
import logging
from pydantic import BaseModel

# from rag import insert_document_and_embeddings, find_similar_embeddings, preprocess
from datetime import datetime
import re
from nltk import tokenize
import unicodedata
import string
import logging
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

MAX_WORD_COUNT = 256
MAX_TOKEN_COUNT = 512


model_name = "mixedbread-ai/mxbai-embed-large-v1"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def encoding(text: str) -> str:
    """
    Remove unicoded data
    """
    text = unicodedata.normalize("NFKD", text)

    return text


def remove_URL(text: str) -> str:
    """
    Remove URLs
    """
    return re.sub(r"https?://\S+|www\.\S+", "", text)


def remove_non_ascii(text: str) -> str:
    """
    Remove non-ASCII characters
    """
    return re.sub(r"[^\x00-\x7f]", r"", text)


def remove_html(text: str) -> str:
    """
    Remove the html
    """
    html = re.compile(r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
    return re.sub(html, "", text)


def remove_punct(text: str) -> str:
    """
    Remove the punctuation
    """
    #     return re.sub(r'[]!"$%&\'()*+,./:;=#@?[\\^_`{|}~-]+', "", text)
    return text.translate(str.maketrans("", "", string.punctuation))


def preprocess(text: str) -> str:
    """
    Preprocess the text
    """

    text = encoding(text)
    text = remove_URL(text)
    text = remove_non_ascii(text)
    text = remove_html(text)
    # text = remove_punct(text)
    return text


def ingest_input(user_input):
    user_input = preprocess(user_input)
    logging.info(f"Preprocessed user input")

    # Generate sentence tokens
    sentence_tokens = tokenize.sent_tokenize(user_input)
    model_input = []
    temp_input: str = ""
    if len(user_input.split(" ")) > MAX_WORD_COUNT:
        logging.info(
            f"Input contains more than {MAX_WORD_COUNT} words. Splitting the input into chunks"
        )

        # Split the input into chunks based on the sentence tokens
        for i, sent in enumerate(sentence_tokens):
            num_words_sent = len(sent.split(" "))

            # Check if the new chunk would exceed the maximum word count
            if len(temp_input.split(" ")) + num_words_sent > MAX_WORD_COUNT:

                # Append the chunk to the model input
                model_input.append(temp_input.strip())
                logging.info(
                    f"Number of words in the chunk: {len(temp_input.split(' '))}"
                )
                temp_input = sent
            else:
                temp_input += " " + sent

        # Append the last chunk to the model input
        model_input.append(temp_input)
    else:
        model_input = [user_input]

    return model_input


def read_pdf_doc(filepath):
    doc = fitz.open(filepath)
    text = ""
    for page_index, page in enumerate(doc):
        logging.info(f"page {page_index+1} out of {len(doc)}")
        tp = page.get_textpage()
        words = tp.extractWORDS()

        page_text = " ".join([word[4] for word in words])
        text += page_text + " "
    doc.close()
    return text


def generate_embeddings(text, device):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_TOKEN_COUNT,
        padding="max_length",
    )
    with torch.no_grad():
        inputs = inputs.to(device)
        outputs = model(**inputs)

    outputs = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    # Scale the embeddings to be between 0 and 1
    outputs = (outputs - outputs.min()) / (outputs.max() - outputs.min())
    return outputs


# Embedding Generation

In [2]:
text_chunks_list = []
embeddings_list = []
doc_ids = []

batch_size = 128
batch_no = 0
text_batch = []
ids_batch = []

for idx, doc in tqdm(corpus_df.iterrows(), total=len(corpus_df), desc="Processing documents"):
    text_chunks = ingest_input(doc['text'])

    start = time.time()
    text_batch.extend(text_chunks)
    ids_batch.extend([idx] * len(text_chunks))

    if len(text_batch) == batch_size:
        embeddings = generate_embeddings(text_batch, device)
        embeddings_list.append(embeddings)
        text_chunks_list.extend(text_batch)
        doc_ids.extend(ids_batch)

        logging.info(f">>> Generated embeddings for batch {batch_no}")
        logging.info(f"Shape of embeddings: {embeddings.shape}")
        logging.info(f"Time taken for the batch: {time.time() - start}")

        text_batch = []
        ids_batch = []
        batch_no += 1
        
# Handle any remaining batches
if len(text_batch) > 0:
    embeddings = generate_embeddings(text_batch, device)
    embeddings_list.append(embeddings)
    text_chunks_list.extend(text_batch)
    doc_ids.extend(ids_batch)

    logging.info(f">>> Generated embeddings for final batch {batch_no}")
    logging.info(f"Shape of embeddings: {embeddings.shape}")
    logging.info(f"Time taken for the batch: {time.time() - start}")

embeddings_list = np.concatenate(embeddings_list, axis=0)
doc_ids = np.array(doc_ids)


NameError: name 'tqdm' is not defined

In [None]:
# Convert the embeddings and doc_ids to a DataFrame
embeddings_df = pd.DataFrame({
    'text': text_chunks_list,
    'embedding': embeddings_list.tolist(),
    'doc_id': doc_ids
})


In [None]:
embeddings_df.to_csv(os.path.join(data_path, "embeddings.csv"), index=False)

# Retrieval

In [10]:
from ast import literal_eval

In [11]:
embeddings_df = pd.read_csv(os.path.join(data_path, "embeddings.csv"))
# embeddings_df.loc[:, "embedding"] = embeddings_df["embedding"].apply(literal_eval)

In [12]:
embeddings_df["embedding"] = embeddings_df["embedding"].apply(lambda x: x[1:-1].split(', '))

In [13]:
embeddings_df.dtypes

text         object
embedding    object
doc_id        int64
dtype: object

In [34]:
qrels

{'1': {'31715818': 1},
 '3': {'14717500': 1},
 '5': {'13734012': 1},
 '13': {'1606628': 1},
 '36': {'5152028': 1, '11705328': 1},
 '42': {'18174210': 1},
 '48': {'13734012': 1},
 '49': {'5953485': 1},
 '50': {'12580014': 1},
 '51': {'45638119': 1},
 '53': {'45638119': 1},
 '54': {'49556906': 1},
 '56': {'4709641': 1},
 '57': {'4709641': 1},
 '70': {'5956380': 1, '4414547': 1},
 '72': {'6076903': 1},
 '75': {'4387784': 1},
 '94': {'1215116': 1},
 '99': {'18810195': 1},
 '100': {'4381486': 1},
 '113': {'6157837': 1},
 '115': {'33872649': 1},
 '118': {'6372244': 1},
 '124': {'4883040': 1},
 '127': {'21598000': 1},
 '128': {'8290953': 1},
 '129': {'27768226': 1},
 '130': {'27768226': 1},
 '132': {'7975937': 1},
 '133': {'38485364': 1,
  '6969753': 1,
  '17934082': 1,
  '16280642': 1,
  '12640810': 1},
 '137': {'26016929': 1},
 '141': {'6955746': 1, '14437255': 1},
 '142': {'10582939': 1},
 '143': {'10582939': 1},
 '146': {'10582939': 1},
 '148': {'1084345': 1},
 '163': {'18872233': 1},
 '1

In [35]:
queries_df = pd.DataFrame({
    'query_id': list(queries.keys()),
    'query': list(queries.values())
})


In [36]:
queries_df

Unnamed: 0,query_id,query
0,1,0-dimensional biomaterials show inductive prop...
1,3,"1,000 genomes project enables mapping of genet..."
2,5,1/2000 in UK have abnormal PrP positivity.
3,13,5% of perinatal mortality is due to low birth ...
4,36,A deficiency of vitamin B12 increases blood le...
...,...,...
295,1379,Women with a higher birth weight are more like...
296,1382,aPKCz causes tumour enhancement by affecting g...
297,1385,cSMAC formation enhances weak ligand signalling.
298,1389,mTORC2 regulates intracellular cysteine levels...


In [40]:
retrieved_docs_list = []
retrieved_text_list = []
query_ids = []
batch_size = 128
k=10
query_batch = []
query_embeddings_list = []
query_ids = []
batch_no = 0

for idx, query in tqdm(queries_df.iterrows(), total=len(queries_df), desc="Processing queries"):
    query_batch.append(ingest_input(query['query'])[0])
    
    if len(query_batch) == batch_size:
        query_embeddings = generate_embeddings(query_batch, device)
        query_embeddings_list.extend(query_embeddings.tolist())
        # query_ids.extend([query['query_id']] * batch_size)
        
        query_batch = []
        logging.info(f">>> Generated embeddings for batch {batch_no}")
        logging.info(f"Shape of embeddings: {query_embeddings.shape}")
        batch_no += 1
        
# Handle any remaining queries
if len(query_batch) > 0:
    query_embeddings = generate_embeddings(query_batch, device)
    query_embeddings_list.extend(query_embeddings.tolist())
    # query_ids.extend([query['query_id']] * len(query_batch))
    
    logging.info(f">>> Generated embeddings for final batch {batch_no}")
    logging.info(f"Shape of embeddings: {query_embeddings.shape}")


# Create a DataFrame to save the results
results_df = pd.DataFrame({
    'query_id': query_ids,
    f'retrieved_docs_top{k}': retrieved_docs_list,
    f'retrieved_texts_top{k}': retrieved_text_list
    # 'retrieved_texts': retrieved_text_list
})

                                                                     

                                                                     

2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed

                                                                     

2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed

                                                                      

2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed

Processing queries:  41%|████      | 123/300 [00:00<00:00, 201.61it/s]

2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input
2024-05-23 05:01:31 - Preprocessed user input


Processing queries:  48%|████▊     | 144/300 [00:10<00:23,  6.51it/s] 

2024-05-23 05:01:41 - >>> Generated embeddings for batch 0
2024-05-23 05:01:41 - Shape of embeddings: (128, 1024)
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 0

Processing queries:  57%|█████▋    | 171/300 [00:10<00:11, 10.78it/s]

2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed

Processing queries:  70%|███████   | 211/300 [00:10<00:03, 23.07it/s]

2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed

Processing queries:  78%|███████▊  | 234/300 [00:10<00:01, 33.67it/s]

2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed user input
2024-05-23 05:01:41 - Preprocessed

                                                                     

2024-05-23 05:01:49 - >>> Generated embeddings for batch 1
2024-05-23 05:01:49 - Shape of embeddings: (128, 1024)
2024-05-23 05:01:49 - Preprocessed user input
2024-05-23 05:01:49 - Preprocessed user input
2024-05-23 05:01:49 - Preprocessed user input
2024-05-23 05:01:49 - Preprocessed user input
2024-05-23 05:01:49 - Preprocessed user input
2024-05-23 05:01:49 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 0

Processing queries: 100%|██████████| 300/300 [00:19<00:00, 15.70it/s]


2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:50 - Preprocessed user input
2024-05-23 05:01:52 - >>> Generated embeddings for final batch 2
2024-05-23 05:01:52 - Shape of embeddings: (44, 1024)


In [41]:
results_df = pd.DataFrame({
    'query_id': queries_df['query_id'],
    'query_embedding': query_embeddings_list
})
results_df.dtypes

query_id           object
query_embedding    object
dtype: object

In [45]:
DOC_EMBEDDINGS = np.vstack(embeddings_df['embedding'].apply(lambda x: np.array(x)))

def find_similar_embeddings(
    df, query_embedding, top_k=10, similarity_threshold=0.5
):
    # query_embedding = generate_embeddings([query_text], device=device)
    if isinstance(query_embedding, list):
        query_embedding = np.array(query_embedding).reshape(1, -1)
        # query_embedding = query_embedding.flatten().reshape(1, -1)


    # Calculate cosine similarities
    similarities = cosine_similarity(query_embedding, DOC_EMBEDDINGS).flatten()

    # Add similarities to DataFrame
    df['similarity'] = similarities

    # Filter based on similarity threshold
    results_df = df[df['similarity'] > similarity_threshold].sort_values(by='similarity', ascending=False).head(top_k)


    return results_df['doc_id'].values.tolist(), results_df['similarity'].values.tolist(), results_df['text'].values.tolist()


In [46]:
for idx, row in tqdm(results_df.iterrows(), total=len(results_df), desc="Finding similar embeddings"):
    doc_ids, sim_scores, similar_texts = find_similar_embeddings(embeddings_df, row['query_embedding'], top_k=k)
    # retrieved_docs_list.append(doc_ids)
    # retrieved_sim_list.append(sim_scores)
    # retrieved_text_list.append(similar_texts)
    break


Finding similar embeddings:   0%|          | 0/300 [00:04<?, ?it/s]


In [None]:
from tqdm import tqdm_notebook as tqdm

In [48]:
retrieved_docs_list = []
retrieved_text_list = []
retrieved_sim_list = []
query_ids = []
k=20

for idx, row in tqdm(results_df.iterrows(), total=len(results_df), desc="Finding similar embeddings"):
    doc_ids, sim_scores, similar_texts = find_similar_embeddings(embeddings_df, row['query_embedding'], top_k=k)
    retrieved_docs_list.append(doc_ids)
    retrieved_sim_list.append(sim_scores)
    retrieved_text_list.append(similar_texts)


# Create a DataFrame to save the results
if os.path.exists(os.path.join(data_path, "retrieval_results.csv")):
    retrieval_df = pd.read_csv(os.path.join(data_path, "retrieval_results.csv"))
    
    retrieval_df[f'retrieved_docs_top{k}'] = retrieved_docs_list
    retrieval_df[f'retrieved_texts_top{k}'] = retrieved_text_list
    retrieval_df[f'retrieved_sim_top{k}'] = retrieved_sim_list
else:
    
    retrieval_df = pd.DataFrame({
        'query_id': queries_df['query_id'],
        f'retrieved_docs_top{k}': retrieved_docs_list,
        f'retrieved_texts_top{k}': retrieved_text_list,
        f'retrieved_sim_top{k}': retrieved_sim_list
    })

Finding similar embeddings:  26%|██▋       | 79/300 [05:13<15:09,  4.11s/it]

: 

: 

In [76]:
retrieval_df.to_csv(os.path.join(data_path, "retrieval_results.csv"), index=False)

In [60]:
retrieved_docs_list

[[4346436,
  58050905,
  15337254,
  17388232,
  7583104,
  327319,
  8290953,
  9629682,
  11172205,
  19855358],
 [4414547,
  23389795,
  1388704,
  19058822,
  3662132,
  10145528,
  27408104,
  13914198,
  14717500,
  3444507],
 [76415938,
  9764256,
  13734012,
  9813098,
  13734012,
  9764256,
  10300000,
  3413083,
  9650982,
  9764256],
 [4791384,
  27099731,
  1263446,
  33257464,
  8529693,
  1263446,
  27099731,
  49429882,
  356218,
  4791384],
 [3215494,
  18557974,
  37424881,
  18256197,
  10557471,
  21636085,
  9813098,
  16252863,
  11705328,
  12130200],
 [18174210,
  27889071,
  24042919,
  8083310,
  12240507,
  1145473,
  27162821,
  10128893,
  10526279,
  4449524],
 [13734012,
  11349166,
  27063470,
  13734012,
  18617259,
  253672,
  9813098,
  27063470,
  11349166,
  3716075],
 [5953485,
  5702790,
  23913146,
  6828370,
  7029990,
  86602746,
  84784389,
  22362025,
  2619579,
  2251426],
 [12580014,
  18488986,
  25738896,
  15405204,
  6945691,
  36950726,

In [61]:
retrieval_df.shape

(300, 4)

In [68]:
qrels_df = pd.DataFrame({
    'query_id': list(qrels.keys()),
    'relevant_docs_': [list(docs.keys()) for docs in list(qrels.values())]
})

# Evaluation

In [14]:
retrieval_df = pd.read_csv(os.path.join(data_path, "retrieval_results.csv"))

In [19]:
retrieval_df['relevant_docs'] = retrieval_df['relevant_docs'].str.replace("'", "")

In [20]:
DOCS_IN_CORPUS = list(corpus_df.index)
DOCS_IN_CORPUS

['4983',
 '5836',
 '7912',
 '18670',
 '19238',
 '33370',
 '36474',
 '54440',
 '70115',
 '70490',
 '72159',
 '79447',
 '87758',
 '92308',
 '92499',
 '97884',
 '102662',
 '103007',
 '104130',
 '106301',
 '116792',
 '118568',
 '120626',
 '123859',
 '140874',
 '143251',
 '152245',
 '153744',
 '159469',
 '164189',
 '164985',
 '169264',
 '175735',
 '188911',
 '195352',
 '202259',
 '207972',
 '213017',
 '219475',
 '226488',
 '236204',
 '238409',
 '243694',
 '253672',
 '263364',
 '266641',
 '275294',
 '279052',
 '285794',
 '293661',
 '301838',
 '301866',
 '306006',
 '306311',
 '308862',
 '313394',
 '313403',
 '317204',
 '323030',
 '323335',
 '327319',
 '335029',
 '341324',
 '343052',
 '344240',
 '350542',
 '356218',
 '364522',
 '365896',
 '368506',
 '371289',
 '374902',
 '380526',
 '381602',
 '393001',
 '406733',
 '409280',
 '410286',
 '418246',
 '427082',
 '427865',
 '432261',
 '435529',
 '437924',
 '439670',
 '456304',
 '457630',
 '461550',
 '463309',
 '463533',
 '464511',
 '469066',
 '47062

In [21]:
from sklearn.metrics import precision_score, recall_score

def compute_accuracy(retrieved_docs, relevant_docs, top_k=10):
    retrieved_docs = retrieved_docs[1:-1].split(", ")[:top_k]
    relevant_docs = relevant_docs[1:-1].split(", ")[:top_k]
    retrieved_docs_set = set(retrieved_docs)
    relevant_docs_set = set(relevant_docs)
    logging.info(f"Retrieved documents: {retrieved_docs}")
    logging.info(f"Relevant documents: {relevant_docs}")
    common_docs_set = retrieved_docs_set.intersection(relevant_docs_set)
    logging.info(f"Common documents: {common_docs_set}")
    # Calculate accuracy as the number of common documents divided by the total number of relevant documents
    intersect_accuracy = len(common_docs_set) / len(relevant_docs_set)
    precision_inner = len(common_docs_set) / len(retrieved_docs_set) # TP / (TP + FP)
    
    # Get the index of the relevant documents in the corpus
    relevant_docs_ohe_idx = [DOCS_IN_CORPUS.index(doc) for doc in relevant_docs]
    relevant_docs_ohe = np.zeros(len(DOCS_IN_CORPUS))
    relevant_docs_ohe[relevant_docs_ohe_idx] = 1
        
    # Get the index of the retrieved documents in the corpus
    retrieved_docs_ohe_idx = [DOCS_IN_CORPUS.index(doc) for doc in retrieved_docs]
    retrieved_docs_ohe = np.zeros(len(DOCS_IN_CORPUS))
    retrieved_docs_ohe[retrieved_docs_ohe_idx] = 1
    
    # Calculate precision, recall
    precision_ohe = precision_score(relevant_docs_ohe, retrieved_docs_ohe)
    recall_ohe = recall_score(relevant_docs_ohe, retrieved_docs_ohe)
     
    
    return intersect_accuracy, precision_inner, precision_ohe, recall_ohe

In [30]:
k=3
retrieval_df[[f"accuracy_top{k}", f"precision_inner_top{k}", f"precision_ohe_top{k}", f"recall_ohe_top{k}"]] = retrieval_df.apply(lambda x: compute_accuracy(x[f"retrieved_docs_top10"], x['relevant_docs'], top_k=k), axis=1, result_type="expand")

2024-05-23 04:48:29 - Retrieved documents: ['4346436', '58050905', '15337254']
2024-05-23 04:48:29 - Relevant documents: ['31715818']
2024-05-23 04:48:29 - Common documents: set()
2024-05-23 04:48:29 - Retrieved documents: ['4414547', '23389795', '1388704']
2024-05-23 04:48:29 - Relevant documents: ['14717500']
2024-05-23 04:48:29 - Common documents: set()
2024-05-23 04:48:29 - Retrieved documents: ['76415938', '9764256', '13734012']
2024-05-23 04:48:29 - Relevant documents: ['13734012']
2024-05-23 04:48:29 - Common documents: {'13734012'}
2024-05-23 04:48:29 - Retrieved documents: ['4791384', '27099731', '1263446']
2024-05-23 04:48:29 - Relevant documents: ['1606628']
2024-05-23 04:48:29 - Common documents: set()
2024-05-23 04:48:29 - Retrieved documents: ['3215494', '18557974', '37424881']
2024-05-23 04:48:29 - Relevant documents: ['5152028', '11705328']
2024-05-23 04:48:29 - Common documents: set()
2024-05-23 04:48:29 - Retrieved documents: ['18174210', '27889071', '24042919']
2024-

In [25]:
# Merge the retrieval results with the qrels
# retrieval_df = retrieval_df.merge(qrels_df, how='inner', on='query_id')
# retrieval_df

In [23]:
retrieval_df.head()

Unnamed: 0,query_id,retrieved_docs_top10,retrieved_texts_top10,retrieved_sim_top10,relevant_docs,accuracy_top10,precision_inner_top10,precision_ohe_top10,recall_ohe_top10
0,1,"[4346436, 58050905, 15337254, 17388232, 758310...","['Unlike most synthetic materials, biological ...","[0.9847771606719984, 0.98458850008293, 0.98433...",[31715818],0.0,0.0,0.0,0.0
1,3,"[4414547, 23389795, 1388704, 19058822, 3662132...","['More generally, these data provide new insig...","[0.9881532193663816, 0.9855963347138302, 0.984...",[14717500],1.0,0.1,0.1,1.0
2,5,"[76415938, 9764256, 13734012, 9813098, 1373401...",['Eighty-six percent (102 of 123) of the patie...,"[0.982130139691713, 0.981729473035609, 0.98163...",[13734012],1.0,0.142857,0.142857,1.0
3,13,"[4791384, 27099731, 1263446, 33257464, 8529693...","[""BACKGROUND Historically, the main focus of s...","[0.9849229915319802, 0.9842265099382834, 0.984...",[1606628],0.0,0.0,0.0,0.0
4,36,"[3215494, 18557974, 37424881, 18256197, 105574...",['Hyperhomocysteinemia has recently been ident...,"[0.9864817966627987, 0.9863759639901892, 0.985...","[5152028, 11705328]",0.5,0.1,0.1,0.5


In [25]:
def compute_metrics(retrieval_df, top_k=10):
    # Compute the mean accuracy, precision, and recall
    accuracy = retrieval_df[f"accuracy_top{top_k}"].mean()
    precision_inner = retrieval_df[f"precision_inner_top{top_k}"].mean()
    precision_ohe = retrieval_df[f"precision_ohe_top{top_k}"].mean()
    recall_ohe = retrieval_df[f"recall_ohe_top{top_k}"].mean()
    
    return accuracy, precision_inner, precision_ohe, recall_ohe

In [32]:
results_df = pd.DataFrame({
    'accuracy': [],
    'precision_inner': [],
    'precision_ohe': [],
    'recall_ohe': []
})

In [28]:
# Results for the top-10 retrieved documents
accuracy, precision_inner, precision_ohe, recall_ohe = compute_metrics(retrieval_df, top_k=10)
print(f"Accuracy @top-10: {accuracy}")
print(f"Precision @top-10 (Inner): {precision_inner}")
print(f"Precision @top-10 (One-hot): {precision_ohe}")
print(f"Recall @top-10 (One-hot): {recall_ohe}")

Accuracy @top-10: 0.8112222222222223
Precision @top-10 (Inner): 0.10205291005291005
Precision @top-10 (One-hot): 0.10205291005291005
Recall @top-10 (One-hot): 0.8112222222222223


In [29]:
accuracy, precision_inner, precision_ohe, recall_ohe = compute_metrics(retrieval_df, top_k=5)
print(f"Accuracy @top-5: {accuracy}")
print(f"Precision @top-5 (Inner): {precision_inner}")
print(f"Precision @top-5 (One-hot): {precision_ohe}")
print(f"Recall @top-5 (One-hot): {recall_ohe}")

Accuracy @top-5: 0.7278888888888889
Precision @top-5 (Inner): 0.18255555555555555
Precision @top-5 (One-hot): 0.18255555555555555
Recall @top-5 (One-hot): 0.7278888888888889


In [31]:
accuracy, precision_inner, precision_ohe, recall_ohe = compute_metrics(retrieval_df, top_k=3)
print(f"Accuracy @top-3: {accuracy}")
print(f"Precision @top-3 (Inner): {precision_inner}")
print(f"Precision @top-3 (One-hot): {precision_ohe}")
print(f"Recall @top-3 (One-hot): {recall_ohe}")

Accuracy @top-3: 0.6855555555555555
Precision @top-3 (Inner): 0.2744444444444444
Precision @top-3 (One-hot): 0.2744444444444444
Recall @top-3 (One-hot): 0.6855555555555555


In [39]:
retrieval_df.to_csv(os.path.join(data_path, "retrieval_results.csv"), index=False)

In [None]:
# Apply exact keyword matching using the BM25 algorithm
from rank_bm25 import BM25Okapi
