Load necessary packages

In [56]:
import pandas as pd
import numpy as np
import re
import nltk
from tqdm import tqdm
from nltk.corpus import stopwords
from transformers import BertTokenizer, BertModel
import torch
import os
# no lowercasing here, as we do it ourselves later.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=False)
nltk.download("stopwords")
stop_words = set(stopwords.words("english"))
tqdm.pandas()

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\vitaf\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Data loading and inspection

In [57]:
df = pd.read_json("../data/trec-medline.json", lines=True)

In [58]:
id_rows = df.iloc[::2].reset_index(drop=True)[["index"]]
id_rows["index"] = id_rows["index"].apply(lambda x: int(x["_id"]))
content_rows = df.iloc[1::2].reset_index(drop=True).drop(labels=["index"], axis=1)
combined_df = pd.concat([id_rows, content_rows], axis=1)

combined_df.head()

Unnamed: 0,index,AB,AD,CY,DA,DCOM,DP,EDAT,ID,IP,...,CON,CIN,RPF,RPI,SPIN,RIN,ROF,ORI,UOF,UIN
0,1,We present an evaluation of the accuracy and p...,Department of Molecular Biology and Skaggs Ins...,Netherlands,20011105.0,20020401.0,2001 Sep,2001/11/06 10:00,GM56879/GM/NIGMS,1,...,,,,,,,,,,
1,2,An analysis is presented of experimental versu...,"Department of Medical Biosciences, Medical Bio...",Netherlands,20011105.0,20020401.0,2001 Sep,2001/11/06 10:00,,1,...,,,,,,,,,,
2,3,The global fold of maltose binding protein in ...,Protein Engineering Network Center of Excellen...,Netherlands,20011105.0,20020401.0,2001 Sep,2001/11/06 10:00,,1,...,,,,,,,,,,
3,4,A general method is presented for magnetic fie...,"Molecular Structure Division, National Institu...",Netherlands,20011105.0,20020401.0,2001 Sep,2001/11/06 10:00,,1,...,,,,,,,,,,
4,5,The dependence between the anomeric carbon che...,"Department of Chemistry & Biochemistry, Univer...",Netherlands,20011105.0,20020401.0,2001 Sep,2001/11/06 10:00,,1,...,,,,,,,,,,


In [59]:
docs = combined_df[["index", "AB", "PMID"]]
docs = docs.astype({"index": int, "PMID": int})
docs.head(5)

Unnamed: 0,index,AB,PMID
0,1,We present an evaluation of the accuracy and p...,11693564
1,2,An analysis is presented of experimental versu...,11693565
2,3,The global fold of maltose binding protein in ...,11693566
3,4,A general method is presented for magnetic fie...,11693567
4,5,The dependence between the anomeric carbon che...,11693568


In [60]:
# load queries
queries = pd.DataFrame(columns=["index", "query"])

with open("../data/training-queries-simple.txt", "r") as f:
    lines = f.readlines()

data = []
for line in lines:
    x = line.strip().split("\t")
    if len(x) >= 2:  
        data.append({"index": int(x[0]), "query": x[1]})
    else:
        raise ValueError("wtf")
queries = pd.concat([queries, pd.DataFrame(data)], ignore_index=True)
queries.head(5)
print(queries.isna().sum())

index    0
query    0
dtype: int64


In [61]:
# drop missings
print(docs.isna().sum())
docs = docs.dropna()
def remove_short_strings(df, column_name):
    pattern = re.compile(r'\W+')
    filtered_df = df[df[column_name].apply(
        lambda x: isinstance(x, str) and len(pattern.sub('', x)) >= 20
    )].copy()
    filtered_df.reset_index(drop=True, inplace=True)
    
    return filtered_df
print(docs.shape[0])
docs = remove_short_strings(docs, "AB")
print(docs.shape[0])

index         0
AB       123568
PMID          0
dtype: int64
402369
401929


In [62]:
# find max words
max_words = docs['AB'].apply(lambda x: len(x.split())).max()
print(max_words)

1529


In [63]:
# load query results
query_res = pd.DataFrame(columns=["query_index", "doc_index", "relevant"])

with open("../data/training-qrels.txt", "r") as f:
    lines = f.readlines()

data = []
for line in lines:
    x = line.strip().split("\t")
    if len(x) >= 4:  
        data.append({"query_index": int(x[0]), "doc_index": int(x[2]), "relevant": int(x[3])})
    else:
        raise ValueError("wtf")
query_res = pd.concat([query_res, pd.DataFrame(data)], ignore_index=True)
print(query_res.head(5))
print(query_res.isna().sum())

  query_index doc_index relevant
0           1  11642719        1
1           1  11695244        1
2           1  11700040        1
3           1  11733969        1
4           1  11741909        1
query_index    0
doc_index      0
relevant       0
dtype: int64


In [64]:
# combine queries and results

filtered_df = query_res[query_res["relevant"] == 1]
grouped_df = filtered_df.groupby('query_index')['doc_index'].apply(list).reset_index()
grouped_df = grouped_df.rename(columns={'doc_index': 'relevant_docs'})
queries_training = pd.concat([queries, grouped_df], axis=1)
queries_training = queries_training.drop(columns=["query_index"])
queries_training.head(5)

Unnamed: 0,index,query,relevant_docs
0,1,"""cyclin-dependent kinase inhibitor 1A (p21, Ci...","[11642719, 11695244, 11700040, 11733969, 11741..."
1,2,"""DEAD/H (Asp-Glu-Ala-Asp/His) box polypeptide ...","[12101238, 12527917]"
2,3,ets variant gene 6 (TEL oncogene) in Homo sapiens,"[11731410, 11861293, 11861295, 12080468, 12091..."
3,4,fibroblast growth factor 7 (keratinocyte growt...,"[11937263, 11943656, 11973338, 12008951, 12016..."
4,5,"""glycine receptor, alpha 1 (startle disease/hy...","[11580237, 11781706, 11973623, 11981020, 11981..."


In [65]:
# inspect final datasetsts
queries_training.head(5)

Unnamed: 0,index,query,relevant_docs
0,1,"""cyclin-dependent kinase inhibitor 1A (p21, Ci...","[11642719, 11695244, 11700040, 11733969, 11741..."
1,2,"""DEAD/H (Asp-Glu-Ala-Asp/His) box polypeptide ...","[12101238, 12527917]"
2,3,ets variant gene 6 (TEL oncogene) in Homo sapiens,"[11731410, 11861293, 11861295, 12080468, 12091..."
3,4,fibroblast growth factor 7 (keratinocyte growt...,"[11937263, 11943656, 11973338, 12008951, 12016..."
4,5,"""glycine receptor, alpha 1 (startle disease/hy...","[11580237, 11781706, 11973623, 11981020, 11981..."


In [66]:
docs.head(5)

Unnamed: 0,index,AB,PMID
0,1,We present an evaluation of the accuracy and p...,11693564
1,2,An analysis is presented of experimental versu...,11693565
2,3,The global fold of maltose binding protein in ...,11693566
3,4,A general method is presented for magnetic fie...,11693567
4,5,The dependence between the anomeric carbon che...,11693568


In [67]:
# check if all doc ids from queries are in the dataset (after removing missings)
unique_relevant_docs = set(queries_training['relevant_docs'].explode())
existing_docs = unique_relevant_docs.intersection(docs.PMID)
missing_docs = unique_relevant_docs.difference(docs.PMID)

print(missing_docs)
print(f"Number of relevant docs: {len(unique_relevant_docs)}")
print(f"Number of existing docs in 'docs' DataFrame: {len(existing_docs)}")
print(f"Number of missing docs: {len(missing_docs)}")

{12147208, 12147209, 11861518, 11688978, 11822867, 11714840, 12027934, 11374883, 11406125, 11042116, 11717190, 11700040, 11781193, 11781706, 11564874, 11580237, 11882578, 11846485, 11642719, 11685227, 11466351, 11841916, 11752574, 11752575, 11779460, 11740559, 11727760, 12412576, 11686318, 11441070, 11809712, 11743158, 11701948, 11749055, 11842244, 11748297, 11733969, 11731410, 11741909, 11752172, 11751405, 12161015}
Number of relevant docs: 327
Number of existing docs in 'docs' DataFrame: 285
Number of missing docs: 42


In [90]:
# a smaller dataset
existing_docs = set(unique_relevant_docs).intersection(set(docs['PMID']))

# Select the rows from 'docs' where 'PMID' is in 'existing_docs'
selected_docs = docs[docs['PMID'].isin(existing_docs)]

In [68]:
# ofc there are missings relevant texts with no abstract. nice dataset:)
def filter_missing_docs(doc_list):
    return [doc for doc in doc_list if doc in existing_docs]
    
queries_training['relevant_docs'] = queries_training['relevant_docs'].apply(filter_missing_docs)
# remove quries with no docs


In [69]:
queries_training = queries_training[queries_training['relevant_docs'].apply(lambda x: len(x) > 0)]

In [None]:
# final results
queries_training = queries_training.drop(columns=["index"])
docs = docs.drop(columns=["index"])

In [71]:
def preprocess_text(text, remove_stopwords=True):
    copy = text
    copy = copy.lower()
    # remove punctuation
    copy = re.sub(r"[^\w\s]", '', copy)
    # remove double whitespaces.
    copy = re.sub(r'\s+', ' ', copy).strip()
    
    if remove_stopwords:
        copy = ' '.join(w for w in copy.split() if w not in stop_words)
    return copy

def tokenize_text(text):
    copy = tokenizer.tokenize(text)
    return copy

# copy the dataframes to not break anything when re-running some cells in the notebook.
docs_copy = docs.copy(deep=True)
queries_copy = queries_training.copy(deep=True)
print("Preprocessing docs")
docs_copy["tokens"] = docs_copy["AB"].progress_apply(preprocess_text)
print("Tokenizing docs")
docs_copy["tokens"] = docs_copy["tokens"].progress_apply(tokenize_text)
print("Preprocessing queries")
queries_copy["tokens"] = queries_copy["query"].progress_apply(preprocess_text)
print("Tokenizing queriess")
queries_copy["tokens"] = queries_copy["tokens"].progress_apply(tokenize_text)

docs_copy.head(3)


Preprocessing docs


100%|██████████| 401929/401929 [00:25<00:00, 15784.20it/s]


Tokenizing docs


100%|██████████| 401929/401929 [09:54<00:00, 675.87it/s] 


Preprocessing queries


100%|██████████| 47/47 [00:00<00:00, 46913.92it/s]


Tokenizing queriess


100%|██████████| 47/47 [00:00<00:00, 9401.13it/s]


Unnamed: 0,AB,PMID,tokens
0,We present an evaluation of the accuracy and p...,11693564,"[present, evaluation, accuracy, precision, rel..."
1,An analysis is presented of experimental versu...,11693565,"[analysis, presented, experimental, versus, ca..."
2,The global fold of maltose binding protein in ...,11693566,"[global, fold, mal, ##tose, binding, protein, ..."


In [72]:
# save the results into a df
def save_parquet(df, path):
    if not os.path.exists(path):
        df.to_parquet(path, engine="fastparquet", index=False)

path_docs = "../data/docs.parquet"
path_queries = "../data/queries.parquet"

save_parquet(docs_copy, path_docs)
save_parquet(queries_copy, path_queries)

In [46]:
# load parquet
path_docs = "../data/docs.parquet"
path_queries = "../data/queries.parquet"


docs_copy = pd.read_parquet(path_docs, engine="fastparquet")
queries_copy = pd.read_parquet(path_queries, engine="fastparquet")
print(docs_copy.shape[0])

402369


In [47]:
# now, the fun part - bert.
model = BertModel.from_pretrained('bert-base-uncased')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def get_embeddings(tokens, window=512, overlap=50):
    window_size = window - 2
    step = window_size - overlap
    embeddings = []
    for i in range(0, len(tokens), step):
        subset = tokens[i:i+window_size]
        subset = [tokenizer.cls_token] + subset + [tokenizer.sep_token]
        ids = tokenizer.convert_tokens_to_ids(subset)
        masks = [1] * len(ids)
        ids_tensor = torch.tensor([ids]).to(device)
        mask_tensor = torch.tensor([masks]).to(device)
        
        with torch.no_grad():
            outputs = model(ids_tensor, attention_mask=mask_tensor)
            cls = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy() 
            embeddings.append(cls)
    return np.mean(embeddings, axis=0)
        
    

In [48]:
docs_copy["embeddings"] = docs_copy["tokens"].progress_apply(get_embeddings)

  return _methods._mean(a, axis=axis, dtype=dtype,
100%|██████████| 402369/402369 [36:44<00:00, 182.56it/s]


In [54]:

docs_embeddings = docs_copy[["PMID", "embeddings"]]
path_docs = "../data/doc_embeddings.npz"
invalid_rows = docs_copy[docs_copy['embeddings'].apply(lambda x: x.size != 768)]
invalid_pmids = set(invalid_rows['PMID'].tolist())

# Filter out rows in `docs_copy` where the PMID is in `invalid_pmids`
docs_embeddings = docs_embeddings[~docs_embeddings['PMID'].isin(invalid_pmids)].reset_index(drop=True)

ids_list = docs_embeddings['PMID'].values.astype(np.int64)

embeddings_list = docs_embeddings['embeddings'].values
embeddings_array = np.vstack(embeddings_list)
embeddings_array = embeddings_array.astype(np.float32)

np.savez_compressed(path_docs, ids=ids_list, embeddings=embeddings_array)

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
# testing
queries_copy["embedding"] = queries_copy["tokens"].progress_apply(get_embeddings)

def retrieve_docs(comparison_vector, top_n=10, filepath="../data/doc_embeddings.npz"):
    data = np.load(filepath)
    vectors = data["embeddings"]
    pmids = data["ids"]
    
    comparison_vector = comparison_vector.reshape(1, -1)
    cos_similarities = cosine_similarity(comparison_vector, vectors).flatten()
    
    top_indices = cos_similarities.argsort()[-top_n:][::-1]
    #top_docs = [(pmids[index], cos_similarities[index]) for index in top_indices]
    top_docs = [pmids[index] for index in top_indices]
    
    return top_docs

result = queries_copy.copy(deep=True)

result["retrieved_docs"] = result["embedding"].progress_apply(retrieve_docs)
result.head(2)

100%|██████████| 47/47 [00:00<00:00, 92.34it/s] 
100%|██████████| 47/47 [03:56<00:00,  5.02s/it]


Unnamed: 0,query,relevant_docs,tokens,embedding,retrieved_docs
0,"""cyclin-dependent kinase inhibitor 1A (p21, Ci...","[11695244, 11751903, 11756412, 11762751, 11872...","[cy, ##cl, ##ind, ##ep, ##end, ##ent, kinase, ...","[-0.7932618, -0.1177296, -0.5675259, 0.0377362...","[12010809, 12445194, 11931770, 12027893, 11953..."
1,"""DEAD/H (Asp-Glu-Ala-Asp/His) box polypeptide ...","[12101238, 12527917]","[dead, ##h, as, ##pg, ##lu, ##ala, ##as, ##phi...","[-0.6105295, 0.026878148, -0.1451509, -0.00151...","[12509243, 12423350, 12036072, 12110569, 12435..."


In [89]:
def compute_match_percentage(row):
    retrieved = set(row['retrieved_docs'])
    relevant = set(row['relevant_docs'])
    if not relevant:
        return 0.0  # Avoid division by zero; adjust as needed
    num_matches = len(retrieved & relevant)
    percentage = (num_matches / len(relevant)) * 100
    return percentage

# Apply the function to each row and assign the result to a new column
result['match_percentage'] = result.apply(compute_match_percentage, axis=1)
print(result.head(3))

                                               query  \
0  "cyclin-dependent kinase inhibitor 1A (p21, Ci...   
1  "DEAD/H (Asp-Glu-Ala-Asp/His) box polypeptide ...   
2  ets variant gene 6 (TEL oncogene) in Homo sapiens   

                                       relevant_docs  \
0  [11695244, 11751903, 11756412, 11762751, 11872...   
1                               [12101238, 12527917]   
2  [11861293, 11861295, 12080468, 12091359, 12127...   

                                              tokens  \
0  [cy, ##cl, ##ind, ##ep, ##end, ##ent, kinase, ...   
1  [dead, ##h, as, ##pg, ##lu, ##ala, ##as, ##phi...   
2  [et, ##s, variant, gene, 6, tel, on, ##co, ##g...   

                                           embedding  \
0  [-0.7932618, -0.1177296, -0.5675259, 0.0377362...   
1  [-0.6105295, 0.026878148, -0.1451509, -0.00151...   
2  [-0.47980258, -0.14554822, -0.5667131, 0.12695...   

                                      retrieved_docs  match_percentage  
0  [12010809, 12445194, 119

In [91]:
selected_docs_copy = selected_docs.copy(deep=True)
print("Preprocessing docs")
selected_docs_copy["tokens"] = selected_docs_copy["AB"].progress_apply(preprocess_text)
print("Tokenizing docs")
selected_docs_copy["tokens"] = selected_docs_copy["tokens"].progress_apply(tokenize_text)

Preprocessing docs


100%|██████████| 285/285 [00:00<00:00, 14250.18it/s]


Tokenizing docs


100%|██████████| 285/285 [00:00<00:00, 626.37it/s]


In [92]:
selected_docs_copy["embeddings"] = selected_docs_copy["tokens"].progress_apply(get_embeddings)

100%|██████████| 285/285 [00:01<00:00, 151.69it/s]


In [None]:
selected_queries = queries_copy.copy(deep=True)
def retrieve_docs(comparison_vector, top_n=10, dataset=selected_docs_copy):
    vectors = dataset["embeddings"].values
    pmids = dataset["PMID"].values.astype(np.int64)
    vectors = np.vstack(vectors)
    vectors = vectors.astype(np.float32)

    comparison_vector = comparison_vector.reshape(1, -1)
    cos_similarities = cosine_similarity(comparison_vector, vectors).flatten()
    top_indices = cos_similarities.argsort()[-top_n:][::-1]
    #top_docs = [(pmids[index], cos_similarities[index]) for index in top_indices]
    top_docs = [pmids[index] for index in top_indices]
    
    return top_docs
selected_queries["retrieved_docs"] = selected_queries["embedding"].progress_apply(retrieve_docs)
print(result["match_percentage"].mean())

100%|██████████| 47/47 [00:00<00:00, 824.53it/s]

0.0





In [103]:
def retrieve_all_docs_df(comparison_vector, dataset):
    vectors = dataset["embeddings"].values
    pmids = dataset["PMID"].values.astype(np.int64)
    
    # Stack embeddings into a 2D NumPy array
    vectors = np.vstack(vectors).astype(np.float32)

    # Ensure the comparison vector has the correct shape and type
    comparison_vector = comparison_vector.reshape(1, -1).astype(np.float32)
    
    # Compute cosine similarities between the comparison vector and all document vectors
    cos_similarities = cosine_similarity(comparison_vector, vectors).flatten()
    
    # Create a DataFrame with 'doc_id' and 'cosine_similarity' columns
    df = pd.DataFrame({
        'doc_id': pmids,
        'cosine_similarity': cos_similarities
    })
    
    # Sort the DataFrame by 'cosine_similarity' in descending order
    df_sorted = df.sort_values(by='cosine_similarity', ascending=False).reset_index(drop=True)
    
    return df_sorted

relevant_docs = set(selected_queries["relevant_docs"].iloc[0])
df_retrieved = retrieve_all_docs_df(selected_queries["embedding"].iloc[0], selected_docs_copy)
df_retrieved['relevant?'] = df_retrieved['doc_id'].apply(lambda x: 1 if x in relevant_docs else 0)

print(df_retrieved)
df_retrieved.to_csv("thisisbad.csv")

       doc_id  cosine_similarity  relevant?
0    11905808           0.882718          0
1    12052963           0.874792          0
2    12370803           0.870549          0
3    12172548           0.868151          0
4    12429910           0.867177          1
..        ...                ...        ...
280  12039527           0.763055          0
281  12021067           0.757628          0
282  12393285           0.750171          0
283  12210105           0.743329          0
284  11906328           0.721113          0

[285 rows x 3 columns]
