In [6]:
import ir_datasets
from nltk.corpus import stopwords, wordnet
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.tokenize import word_tokenize
from nltk import pos_tag
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import string
import pandas as pd
import pickle
import os
from spellchecker import SpellChecker
import re
from typing import List

# Load the dataset
dataset = ir_datasets.load("clinicaltrials/2021/trec-ct-2021")

# Create a corpus from the dataset
corpus = {}

counter = 0

for doc in dataset.docs_iter():
    if counter<3000:
        corpus[doc.doc_id]= doc.title+" " + doc.summary+" " +doc.detailed_description+ " "+doc.eligibility
        counter+=1
    else:
        break
        
documents = list(corpus.values())

def custom_tokenizer(text: str) -> List[str]:
    """Tokenizes and lowercases the text."""
    tokens = word_tokenize(text.lower())
    return tokens

def get_wordnet_pos(tag):
    # """Converts POS tag to a format that WordNetLemmatizer can understand."""
    tag = tag[0].upper()
    tag_dict = {"J": wordnet.ADJ,
                "N": wordnet.NOUN,
                "V": wordnet.VERB,
                "R": wordnet.ADV}
    return tag_dict.get(tag, wordnet.NOUN)

# def correct_sentence_spelling(tokens):
#     # """Corrects spelling of tokens."""
#     spell = SpellChecker()
#     misspelled = spell.unknown(tokens)
#     for i, token in enumerate(tokens):
#         if token in misspelled:
#             corrected = spell.correction(token)
#             if corrected is not None:
#                 tokens[i] = corrected
#     return tokens
def remove_markers(tokens: List[str]) -> List[str]:
    """Removes specific markers from tokens."""
    return [re.sub(r'\u00AE', '', token) for token in tokens]

def remove_punctuation(tokens: List[str]) -> List[str]:
    """Removes punctuation from tokens."""
    return [token.translate(str.maketrans('', '', string.punctuation)) for token in tokens]

def replace_under_score_with_space(tokens: List[str]) -> List[str]:
    """Replaces underscores with spaces in tokens."""
    return [re.sub(r'_', ' ', token) for token in tokens]

def remove_apostrophe(tokens: List[str]) -> List[str]:
    """Removes apostrophes from tokens."""
    return [token.replace("'", " ") for token in tokens]

def preprocess_text(text: str) -> str:
    """Preprocesses the input text by tokenizing, removing punctuation, stopwords, and then stemming and lemmatizing."""
    # Convert text to lowercase and tokenize
    text = text.lower()
    words = word_tokenize(text)
    
    # Remove punctuation
    words = [word.translate(str.maketrans('', '', string.punctuation)) for word in words]
    
    # Remove stopwords
    stop_words = set(stopwords.words('english'))
    words = [word for word in words if word not in stop_words]
    
    # Correct spelling
    # words = correct_sentence_spelling(words)
    
    # Further token cleaning
    words = remove_markers(words)
    words = replace_under_score_with_space(words)
    words = remove_apostrophe(words)
    
    # Stemming and Lemmatization
    stemmer = PorterStemmer()
    words = [stemmer.stem(word) for word in words]
    
    pos_tags = pos_tag(words)
    lemmatizer = WordNetLemmatizer()
    words = [lemmatizer.lemmatize(word, pos=get_wordnet_pos(tag)) for word, tag in pos_tags]
    
    return ' '.join(words)



In [7]:

# Vectorizer setup
vectorizer = TfidfVectorizer(tokenizer=custom_tokenizer, preprocessor=preprocess_text)
tfidf_matrix = vectorizer.fit_transform(documents)
tfidf_model = vectorizer
#df = pd.DataFrame.sparse.from_spmatrix(tfidf_matrix, columns=vectorizer.get_feature_names_out(), index=corpus.keys())






In [8]:

# Save and load functions for TF-IDF data
def save_file(file_location: str, content):
    if os.path.exists(file_location):
        os.remove(file_location)
    with open(file_location, 'wb') as handle:
        pickle.dump(content, handle, protocol=pickle.HIGHEST_PROTOCOL)

def load_file(file_location: str):
    with open(file_location, 'rb') as handle:
        content = pickle.load(handle)
    return content

def save_tfidf_data(tfidf_matrix, tfidf_model):
    save_file("tfidf_matrix.pickle", tfidf_matrix)
    save_file("tfidf_model.pickle", tfidf_model)

save_tfidf_data(tfidf_matrix, tfidf_model)


In [9]:

def process_query(query: str, tfidf_model, tfidf_matrix):
    query_tfidf = tfidf_model.transform([query])
    cosine_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
    ranked_doc_indices = cosine_similarities.argsort()[::-1]
    return ranked_doc_indices, cosine_similarities

tfidf_matrix = load_file("tfidf_matrix.pickle")
tfidf_model = load_file("tfidf_model.pickle")

def getRetrievedQueries(query: str, k=10):
    preprocessed_query = preprocess_text(query)
    ranked_indices, _ = process_query(preprocessed_query, tfidf_model, tfidf_matrix)
    idsList = []
    for idx in ranked_indices[:k]:
        doc_id = list(corpus.keys())[idx]
        idsList.append(doc_id)
    return idsList

def calculate_recall_precision(query_id):
    relevant_docs = []
    for qrel in dataset.qrels_iter():
        if qrel[0] == query_id and qrel[2] > 0:
            relevant_docs.append(qrel[1])

    retrieved_docs = []
    for query in dataset.queries_iter():
        if query[0] == query_id:
            retrieved_docs = getRetrievedQueries(query[1])
            break  

    y_true = [1 if doc_id in relevant_docs else 0 for doc_id in retrieved_docs]
    true_positives = sum(y_true)
    recall_at_10 = true_positives / len(relevant_docs) if relevant_docs else 0
    precision_at_10 = true_positives / 10
    print(f"Query ID: {query_id}, Recall@10: {recall_at_10}")
    print(f"Query ID: {query_id}, Precision@10: {precision_at_10}")    
    return recall_at_10

queries_ids = {qrel[0]: '' for qrel in dataset.qrels_iter()}

for query_id in list(queries_ids.keys()):
    calculate_recall_precision(query_id)


Query ID: 1, Recall@10: 0.047337278106508875
Query ID: 1, Precision@10: 0.8
Query ID: 2, Recall@10: 0.0
Query ID: 2, Precision@10: 0.0
Query ID: 3, Recall@10: 0.011904761904761904
Query ID: 3, Precision@10: 0.1
Query ID: 4, Recall@10: 0.0
Query ID: 4, Precision@10: 0.0
Query ID: 5, Recall@10: 0.004975124378109453
Query ID: 5, Precision@10: 0.1
Query ID: 6, Recall@10: 0.0
Query ID: 6, Precision@10: 0.0
Query ID: 7, Recall@10: 0.0
Query ID: 7, Precision@10: 0.0
Query ID: 8, Recall@10: 0.0
Query ID: 8, Precision@10: 0.0
Query ID: 9, Recall@10: 0.009345794392523364
Query ID: 9, Precision@10: 0.2
Query ID: 10, Recall@10: 0.02127659574468085
Query ID: 10, Precision@10: 0.1
Query ID: 11, Recall@10: 0.015873015873015872
Query ID: 11, Precision@10: 0.2
Query ID: 12, Recall@10: 0.013422818791946308
Query ID: 12, Precision@10: 0.2
Query ID: 13, Recall@10: 0.0
Query ID: 13, Precision@10: 0.0
Query ID: 14, Recall@10: 0.0
Query ID: 14, Precision@10: 0.0
Query ID: 15, Recall@10: 0.0038461538461538464

In [10]:

def calculate_MAP(query_id):
    relevant_docs = []
    for qrel in dataset.qrels_iter():
        if qrel[0] == query_id and qrel[2] > 0:
            relevant_docs.append(qrel[1])

    retrieved_docs = []
    for query in dataset.queries_iter():
        if query[0] == query_id:
            retrieved_docs = getRetrievedQueries(query[1])
            break

    pk_sum = 0
    total_relevant = 0
    for i in range(1, 11):
        relevant_ret = 0
        for j in range(i):
            if j < len(retrieved_docs) and retrieved_docs[j] in relevant_docs:
                relevant_ret += 1
        p_at_k = (relevant_ret / i) * (1 if i - 1 < len(retrieved_docs) and retrieved_docs[i - 1] in relevant_docs else 0)
        pk_sum += p_at_k
        if i - 1 < len(retrieved_docs) and retrieved_docs[i - 1] in relevant_docs:
            total_relevant += 1

    return 0 if total_relevant == 0 else pk_sum / total_relevant

queries_ids = {qrel[0]: '' for qrel in dataset.qrels_iter()}

map_sum = 0
for query_id in list(queries_ids.keys()):
    map_sum += calculate_MAP(query_id)

print(f"Mean Average Precision (MAP@10): {map_sum / len(queries_ids)}")


Mean Average Precision (MAP@10): 0.2962722222222222
