In [2]:
from spellchecker import SpellChecker
import ir_datasets
import nltk
from nltk.corpus import stopwords, wordnet
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.stem import PorterStemmer, LancasterStemmer, WordNetLemmatizer
from collections import Counter
from num2words import num2words
import os
import string
import numpy as np
import copy
# import pandas as pd
import pickle
import re
import math
import unicodedata
import contractions
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from typing import List

# %load_ext autotime
dataset = ir_datasets.load("beir/webis-touche2020/v2")

In [3]:


def convert_lower_case(data):
    return np.char.lower(data)


def remove_stop_words(data):
    stop_words = stopwords.words('english')
    words = word_tokenize(str(data))
    new_text = ""
    for w in words:
        if w not in stop_words and len(w) > 1:
            new_text = new_text + " " + w
    return new_text


def remove_punctuation(data):
    symbols = "!\"#$%&()*+-./:;<=>?@[\]^_`{|}~\n"
    for i in range(len(symbols)):
        data = np.char.replace(data, symbols[i], ' ')
        data = np.char.replace(data, "  ", " ")
    data = np.char.replace(data, ',', '')
    return data

 
def remove_apostrophe(data):
    return np.char.replace(data, "'", "")

def stemming(data):
    stemmer= PorterStemmer()

    tokens = word_tokenize(str(data))
    new_text = ""
    for w in tokens:
        new_text = new_text + " " + stemmer.stem(w)
    return new_text

def nltk_tag_to_wordnet_tag(nltk_tag):
    if nltk_tag.startswith('J'):
        return wordnet.ADJ
    elif nltk_tag.startswith('V'):
        return wordnet.VERB
    elif nltk_tag.startswith('N'):
        return wordnet.NOUN
    elif nltk_tag.startswith('R'):
        return wordnet.ADV
    else:
        return None

def lemmatize(sentence):
    lemmatizer = WordNetLemmatizer()
    #tokenize the sentence and find the POS tag for each token
    nltk_tagged = nltk.pos_tag(word_tokenize(str(sentence)))
    #tuple of (token, wordnet_tag)
    wordnet_tagged = map(lambda x: (x[0], nltk_tag_to_wordnet_tag(x[1])), nltk_tagged)
    lemmatized_sentence = []
    for word, tag in wordnet_tagged:
        if tag is None:
            #if there is no available tag, append the token as is
            lemmatized_sentence.append(word)
        else:
            #else use the tag to lemmatize the token
            lemmatized_sentence.append(lemmatizer.lemmatize(word, tag))
    return " ".join(lemmatized_sentence)

def convert_numbers(data):
    tokens = word_tokenize(str(data))
    new_text = ""
    for w in tokens:
        try:
            w = num2words(int(w))
        except:
            a = 0
        new_text = new_text + " " + w
    new_text = np.char.replace(new_text, "-", " ")
    return new_text

def remove_urls(data):
    cleaned_text = re.sub(r'/(https:\/\/www\.|http:\/\/www\.|https:\/\/|http:\/\/)?[a-zA-Z]{2,}(\.[a-zA-Z]{2,})(\.[a-zA-Z]{2,})?\/[a-zA-Z0-9]{2,}|((https:\/\/www\.|http:\/\/www\.|https:\/\/|http:\/\/)?[a-zA-Z]{2,}(\.[a-zA-Z]{2,})(\.[a-zA-Z]{2,})?)|(https:\/\/www\.|http:\/\/www\.|https:\/\/|http:\/\/)?[a-zA-Z0-9]{2,}\.[a-zA-Z0-9]{2,}\.[a-zA-Z0-9]{2,}(\.[a-zA-Z0-9]{2,})?/g', '', data)
    return cleaned_text;

def replace_contractions(data):
    return " ".join(contractions.fix(data))

def correct_sentence_spelling(data):
    tokens = word_tokenize(str(data))
    spell = SpellChecker()
    misspelled = spell.unknown(tokens)
    for i, token in enumerate(tokens):
        if token in misspelled:
            corrected = spell.correction(token)
            if corrected is not None:
                tokens[i] = corrected
    return " ".join( tokens)



def replace_country_symbols(data):
   
    country_symbols = {
    'AF': 'Afghanistan',
    'AX': 'Åland Islands',
    'AL': 'Albania',
    'DZ': 'Algeria',
    'AS': 'American Samoa',
    'AD': 'Andorra',
    'AO': 'Angola',
    'AI': 'Anguilla',
    'AQ': 'Antarctica',
    'AG': 'Antigua and Barbuda',
    'AR': 'Argentina',
    'AM': 'Armenia',
    'AW': 'Aruba',
    'AU': 'Australia',
    'AT': 'Austria',
    'AZ': 'Azerbaijan',
    'BS': 'Bahamas',
    'BH': 'Bahrain',
    'BD': 'Bangladesh',
    'BB': 'Barbados',
    'BY': 'Belarus',
    'BE': 'Belgium',
    'BZ': 'Belize',
    'BJ': 'Benin',
    'BM': 'Bermuda',
    'BT': 'Bhutan',
    'BO': 'Bolivia',
    'BQ': 'Bonaire, Sint Eustatius and Saba',
    'BA': 'Bosnia and Herzegovina',
    'BW': 'Botswana',
    'BV': 'Bouvet Island',
    'BR': 'Brazil',
    'IO': 'British Indian Ocean Territory',
    'BN': 'Brunei Darussalam',
    'BG': 'Bulgaria',
    'BF': 'Burkina Faso',
    'BI': 'Burundi',
    'CV': 'Cabo Verde',
    'KH': 'Cambodia',
    'CM': 'Cameroon',
    'CA': 'Canada',
    'KY': 'Cayman Islands',
    'CF': 'Central African Republic',
    'TD': 'Chad',
    'CL': 'Chile',
    'CN': 'China',
    'CX': 'Christmas Island',
    'CC': 'Cocos (Keeling) Islands',
    'CO': 'Colombia',
    'KM': 'Comoros',
    'CG': 'Congo',
    'CD': 'Congo, Democratic Republic of the',
    'CK': 'Cook Islands',
    'CR': 'Costa Rica',
    'CI': 'Côte d\'Ivoire',
    'HR': 'Croatia',
    'CU': 'Cuba',
    'CW': 'Curaçao',
    'CY': 'Cyprus',
    'CZ': 'Czech Republic',
    'DK': 'Denmark',
    'DJ': 'Djibouti',
    'DM': 'Dominica',
    'DO': 'Dominican Republic',
    'EC': 'Ecuador',
    'EG': 'Egypt',
    'SV': 'El Salvador',
    'GQ': 'Equatorial Guinea',
    'ER': 'Eritrea',
    'EE': 'Estonia',
    'SZ': 'Eswatini',
    'ET': 'Ethiopia',
    'FK': 'Falkland Islands (Malvinas)',
    'FO': 'Faroe Islands',
    'FJ': 'Fiji',
    'FI': 'Finland',
    'FR': 'France',
    'GF': 'French Guiana',
    'PF': 'French Polynesia',
    'TF': 'French Southern Territories',
    'GA': 'Gabon',
    'GM': 'Gambia',
    'GE': 'Georgia',
    'DE': 'Germany',
    'GH': 'Ghana',
    'GI': 'Gibraltar',
    'GR': 'Greece',
    'GL': 'Greenland',
    'GD': 'Grenada',
    'GP': 'Guadeloupe',
    'GU': 'Guam',
    'GT': 'Guatemala',
    'GG': 'Guernsey',
    'GN': 'Guinea',
    'GW': 'Guinea-Bissau',
    'GY': 'Guyana',
    'HT': 'Haiti',
    'HM': 'Heard Island and McDonald Islands',
    'VA': 'Holy See',
    'HN': 'Honduras',
    'HK': 'Hong Kong',
    'HU': 'Hungary',
    'IS': 'Iceland',
    'IN': 'India',
    'ID': 'Indonesia',
    'IR': 'Iran, Islamic Republic of',
    'IQ': 'Iraq',
    'IE': 'Ireland',
    'IM': 'Isle of Man',
    'IL': 'Israel',
    'IT': 'Italy',
    'JM': 'Jamaica',
    'JP': 'Japan',
    'JE': 'Jersey',
    'JO': 'Jordan',
    'KZ': 'Kazakhstan',
    'KE': 'Kenya',
    'KI': 'Kiribati',
    'KP': "Korea, Democratic People's Republic of",
    'KR': 'Korea, Republic of',
    'KW': 'Kuwait',
    'KG': 'Kyrgyzstan',
    'LA': "Lao People's Democratic Republic",
    'LV': 'Latvia',
    'LB': 'Lebanon',
    'LS': 'Lesotho',
    'LR': 'Liberia',
    'LY': 'Libya',
    'LI': 'Liechtenstein',
    'LT': 'Lithuania',
    'LU': 'Luxembourg',
    'MO': 'Macao',
    'MG': 'Madagascar',
    'MW': 'Malawi',
    'MY': 'Malaysia',
    'MV': 'Maldives',
    'ML': 'Mali',
    'MT': 'Malta',
    'MH': 'Marshall Islands',
    'MQ': 'Martinique',
    'MR': 'Mauritania',
    'MU': 'Mauritius',
    'YT': 'Mayotte',
    'MX': 'Mexico',
    'FM': 'Micronesia, Federated States of',
    'MD': 'Moldova, Republic of',
    'MC': 'Monaco',
    'MN': 'Mongolia',
    'ME': 'Montenegro',
    'MS': 'Montserrat',
    'MA': 'Morocco',
    'MZ': 'Mozambique',
    'MM': 'Myanmar',
    'NA': 'Namibia',
    'NR': 'Nauru',
    'NP': 'Nepal',
    'NL': 'Netherlands',
    'NC': 'New Caledonia',
    'NZ': 'New Zealand',
    'NI': 'Nicaragua',
    'NE': 'Niger',
    'NG': 'Nigeria',
    'NU': 'Niue',
    'NF': 'Norfolk Island',
    'MK': 'North Macedonia',
    'MP': 'Northern Mariana Islands',
    'NO': 'Norway',
    'OM': 'Oman',
    'PK': 'Pakistan',
    'PW': 'Palau',
    'PS': 'Palestine, State of',
    'PA': 'Panama',
    'PG': 'Papua New Guinea',
    'PY': 'Paraguay',
    'PE': 'Peru',
    'PH': 'Philippines',
    'PN': 'Pitcairn',
    'PL': 'Poland',
    'PT': 'Portugal',
    'PR': 'Puerto Rico',
    'QA': 'Qatar',
    'RE': 'Réunion',
    'RO': 'Romania',
    'RU': 'Russian Federation',
    'RW': 'Rwanda',
    'BL': 'Saint Barthélemy',
    'SH': 'Saint Helena, Ascension and Tristan da Cunha',
    'KN': 'Saint Kitts and Nevis',
    'LC': 'Saint Lucia',
    'MF': 'Saint Martin (French part)',
    'PM': 'Saint Pierre and Miquelon',
    'VC': 'Saint Vincent and the Grenadines',
    'WS': 'Samoa',
    'SM': 'San Marino',
    'ST': 'Sao Tome and Principe',
    'SA': 'Saudi Arabia',
    'SN': 'Senegal',
    'RS': 'Serbia',
    'SC': 'Seychelles',
    'SL': 'Sierra Leone',
    'SG': 'Singapore',
    'SX': 'Sint Maarten (Dutch part)',
    'SK': 'Slovakia',
    'SI': 'Slovenia',
    'SB': 'Solomon Islands',
    'SO': 'Somalia',
    'ZA': 'South Africa',
    'GS': 'South Georgia and the South Sandwich Islands',
    'SS': 'South Sudan',
    'ES': 'Spain',
    'LK': 'Sri Lanka',
    'SD': 'Sudan',
    'SR': 'Suriname',
    'SJ': 'Svalbard and Jan Mayen',
    'SE': 'Sweden',
    'CH': 'Switzerland',
    'SY': 'Syrian Arab Republic',
    'TW': 'Taiwan, Province of China',
    'TJ': 'Tajikistan',
    'TZ': 'Tanzania, United Republic of',
    'TH': 'Thailand',
    'TL': 'Timor-Leste',
    'TG': 'Togo',
    'TK': 'Tokelau',
    'TO': 'Tonga',
    'TT': 'Trinidad and Tobago',
    'TN': 'Tunisia',
    'TR': 'Turkey',
    'TM': 'Turkmenistan',
    'TC': 'Turks and Caicos Islands',
    'TV': 'Tuvalu',
    'UG': 'Uganda',
    'UA': 'Ukraine',
    'AE': 'United Arab Emirates',
    'GB': 'United Kingdom',
    'US': 'United States',
    'UM': 'United States Minor Outlying Islands',
    'UY': 'Uruguay',
    'UZ': 'Uzbekistan',
    'VU': 'Vanuatu',
    'VE': 'Venezuela, Bolivarian Republic of',
    'VN': 'Viet Nam',
    'VG': 'Virgin Islands, British',
    'VI': 'Virgin Islands, U.S.',
    'WF': 'Wallis and Futuna',
    'EH': 'Western Sahara',
    'YE': 'Yemen',
    'ZM': 'Zambia',
    'ZW': 'Zimbabwe'
    }
    words = word_tokenize(data)
    replaced_words = [country_symbols[word] if word in country_symbols else word for word in words]
    return " ".join(replaced_words)

def custom_tokenizer(text: str) -> List[str]:
    tokens = word_tokenize(text.lower())
    return tokens

def preprocess(data):
    data = replace_country_symbols(data)
    data = remove_urls(data)
    data = convert_numbers(data)
    data = convert_lower_case(data)
    data = remove_punctuation(data) 
    data = remove_apostrophe(data)
    data = remove_stop_words(data)
    return data

In [6]:
corpus = {}
for doc in dataset.docs_iter():
    corpus[doc.doc_id] =doc.title + " " + doc.text + " " +doc.stance  +" "+ doc.url
  
documents = list(corpus.values())



In [9]:
doc_ids = []
titles = []
texts = []
stances = []
urls = []

for doc in dataset.docs_iter()[:100]:

        doc_ids.append(doc.doc_id)
        titles.append(doc.title)
        texts.append(doc.text)
        stances.append(doc.stance)
        urls.append(doc.url)
 

# Define the corpus as a dictionary with doc_id as the key
corpus = {
    doc_ids[i]: {
        'title': titles[i],
        'text': texts[i],
        'stance': stances[i],
        'url': urls[i]
    } for i in range(len(doc_ids))
}

In [4]:
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer

class MockQuery:
    def __init__(self, query_id, text, description, narrative):
        self.query_id = query_id
        self.text = text
        self.description = description
        self.narrative = narrative


# Creating corpus_query dictionary
corpus_query = {}
for doc in dataset.queries_iter():
    corpus_query[doc.query_id] = doc.text + " " + doc.description + " " + doc.narrative

# Ensure the documents_query is a list of strings
documents_query = list(corpus_query.values())

# Vectorizer setup
vectorizer_query = TfidfVectorizer(tokenizer=custom_tokenizer, preprocessor=preprocess)
tfidf_matrix_query = vectorizer_query.fit_transform(documents_query)

df = pd.DataFrame.sparse.from_spmatrix(
    tfidf_matrix_query, 
    columns=vectorizer_query.get_feature_names_out(), 
    index=corpus_query.keys()
)

tfidf_model_query = vectorizer_query





In [8]:
def process_query(query: str, tfidf_model, tfidf_matrix):
    query_tfidf = tfidf_model.transform([query])
    cosine_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
    ranked_doc_indices = cosine_similarities.argsort()[::-1]
    return ranked_doc_indices, cosine_similarities

def getRetrievedQueries(query: str, k=10):
    preprocessed_query = preprocess(query)
    ranked_indices, _ = process_query(preprocessed_query, tfidf_model_query, tfidf_matrix_query)
    idsList = []
    for idx in ranked_indices[:k]:
        doc_id = list(corpus_query.keys())[idx]
        idsList.append(doc_id)
    return idsList


query="should"
suggestion_query_id=getRetrievedQueries(query, k=5)

soso={}
for doc in dataset.queries_iter():
    soso[doc.query_id] = doc.text 
# Retrieve the query texts for the given suggestion_query_id list
retrieved_queries = [soso[query_id] for query_id in suggestion_query_id if query_id in soso]
for query_id, query_text in zip(suggestion_query_id, retrieved_queries):
    print(query_text) 
    


Should everyone get a universal basic income?
Does lowering the federal corporate income tax rate create jobs?
Is a two-state solution an acceptable solution to the Israeli-Palestinian conflict?
Is human activity primarily responsible for global climate change?
Is drinking milk healthy for humans?


In [51]:

for doc in dataset.queries_iter():
    print(doc)

BeirToucheQuery(query_id='1', text='Should teachers get tenure?', description="A user has heard that some countries do give teachers tenure and others don't. Interested in the reasoning for or against tenure, the user searches for positive and negative arguments. The situation of school teachers vs. university professors is of interest.", narrative="Highly relevant arguments make a clear statement about tenure for teachers in schools or universities. Relevant arguments consider tenure more generally, not specifically for teachers, or, instead of talking about tenure, consider the situation of teachers' financial independence.")
BeirToucheQuery(query_id='2', text='Is vaping with e-cigarettes safe?', description='When considering to switch from smoking to vaping, a user wonders to what extent vaping is safer and what new risks may be involved. Compared to smoking, where the risks are clear, vaping is marketed to have only benefits, and this raises doubts.', narrative='Highly relevant arg

In [10]:
# Preprocess the corpus
preprocessed_titles = [preprocess(title) for title in titles]
preprocessed_texts = [preprocess(text) for text in texts]
preprocessed_stances = [preprocess(description) for description in stances]
preprocessed_urls = [preprocess(url) for url in urls]

# Save preprocessed data to a file
preprocessed_corpus = {
    'titles': preprocessed_titles,
    'texts': preprocessed_texts,
    'stances': preprocessed_stances,
    'urls': preprocessed_urls
}

# Combine all texts into a single list for fitting the vectorizer
all_texts = preprocessed_corpus['titles'] + preprocessed_corpus['texts'] + preprocessed_corpus['stances'] + preprocessed_corpus['urls']

# Fit the vectorizer on the combined corpus
vectorizer = TfidfVectorizer(tokenizer=custom_tokenizer)
vectorizer.fit(all_texts)

# Transform each column separately
tfidf_titles = vectorizer.transform(preprocessed_corpus['titles'])
tfidf_texts = vectorizer.transform(preprocessed_corpus['texts'])
tfidf_stances = vectorizer.transform(preprocessed_corpus['stances'])
tfidf_urls = vectorizer.transform(preprocessed_corpus['urls'])

# Define weights for each column
weights = {
    'title': 4,
    'text': 5,
    'stance': 3,
    'url': 2
}

# Combine vectors with weights
tfidf_matrix = (
    weights['title'] * tfidf_titles +
    weights['text'] * tfidf_texts +
    weights['stance'] * tfidf_stances +
    weights['url'] * tfidf_urls
)/( weights['title'] + weights['text'] + weights['stance'] +  weights['url'] )
tfidf_model = vectorizer



In [63]:
# # Vectorizer setup
# vectorizer = TfidfVectorizer(tokenizer=custom_tokenizer, preprocessor=preprocess)
# tfidf_matrix = vectorizer.fit_transform(documents)
# df = pd.DataFrame.sparse.from_spmatrix(tfidf_matrix, columns=vectorizer.get_feature_names_out(), index=corpus.keys())
# tfidf_model = vectorizer


# Save and load functions for TF-IDF data
def save_file(file_location: str, content):
    if os.path.exists(file_location):
        os.remove(file_location)
    with open(file_location, 'wb') as handle:
        pickle.dump(content, handle, protocol=pickle.HIGHEST_PROTOCOL)



def load_file(file_location: str):
    with open(file_location, 'rb') as handle:
        content = pickle.load(handle)
    return content




def save_tfidf_data(tfidf_matrix, tfidf_model):
    save_file(os.path.join("D:\ir-search-engine\storage", f"beir_tfidf_matrix.pickle"), tfidf_matrix)
    save_file(os.path.join("D:\ir-search-engine\storage", f"beir_tfidf_model.pickle"), tfidf_model)


save_tfidf_data(tfidf_matrix, tfidf_model)



def process_query(query: str, tfidf_model, tfidf_matrix):
    query_tfidf = tfidf_model.transform([query])
    cosine_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
    ranked_doc_indices = cosine_similarities.argsort()[::-1]
    return ranked_doc_indices, cosine_similarities



tfidf_matrix = load_file("D:\ir-search-engine\storage\\beir_tfidf_matrix.pickle")
tfidf_model = load_file("D:\ir-search-engine\storage\\beir_tfidf_model.pickle")





def getRetrievedQueries(query: str, k=10):
    preprocessed_query = preprocess(query)
    ranked_indices, _ = process_query(preprocessed_query, tfidf_model, tfidf_matrix)
    idsList = []
    for idx in ranked_indices[:k]:
        doc_id = list(corpus.keys())[idx]
        idsList.append(doc_id)
    return idsList




def calculate_recall_precision(query_id):
    relevant_docs = []
    for qrel in dataset.qrels_iter():
        if qrel[0] == query_id and qrel[2] > 0:
            relevant_docs.append(qrel[1])

    retrieved_docs = []
    for query in dataset.queries_iter():
        if query[0] == query_id:
            retrieved_docs = getRetrievedQueries(query[1])
            break  

    y_true = [1 if doc_id in relevant_docs else 0 for doc_id in retrieved_docs]
    true_positives = sum(y_true)
    recall_at_10 = true_positives / len(relevant_docs) if relevant_docs else 0
    precision_at_10 = true_positives / 10
    print(f"Query ID:  {query_id}, Recall@10: {recall_at_10}")
    print(f"Query ID: {query_id}, Precision@10: {precision_at_10}")    
    return recall_at_10



queries_ids = {qrel[0]: '' for qrel in dataset.qrels_iter()}



for query_id in list(queries_ids.keys()):
    calculate_recall_precision(query_id)



def calculate_MAP(query_id):
    relevant_docs = []
    for qrel in dataset.qrels_iter():
        if qrel[0] == query_id and qrel[2] > 0:
            relevant_docs.append(qrel[1])

    retrieved_docs = []
    for query in dataset.queries_iter():
        if query[0] == query_id:
            retrieved_docs = getRetrievedQueries(query[1])
            break

    pk_sum = 0

    total_relevant = 0

    for i in range(1, 11):
        relevant_ret = 0

        for j in range(i):
            if j < len(retrieved_docs) and retrieved_docs[j] in relevant_docs:
                relevant_ret += 1
        p_at_k = (relevant_ret / i) * (1 if i - 1 < len(retrieved_docs) and retrieved_docs[i - 1] in relevant_docs else 0)

        pk_sum += p_at_k

        if i - 1 < len(retrieved_docs) and retrieved_docs[i - 1] in relevant_docs:
            total_relevant += 1

    return 0 if total_relevant == 0 else pk_sum / total_relevant


queries_ids = {qrel[0]: '' for qrel in dataset.qrels_iter()}

map_sum = 0

for query_id in list(queries_ids.keys()):
    map_sum += calculate_MAP(query_id)



print(f"Mean Average Precision (MAP@10): {map_sum / len(queries_ids)}")



Query ID:  1, Recall@10: 0.5454545454545454
Query ID: 1, Precision@10: 0.6
Query ID:  2, Recall@10: 0.0
Query ID: 2, Precision@10: 0.0
Query ID:  3, Recall@10: 0.0
Query ID: 3, Precision@10: 0.0
Query ID:  4, Recall@10: 0.15
Query ID: 4, Precision@10: 0.3
Query ID:  5, Recall@10: 0.0625
Query ID: 5, Precision@10: 0.1
Query ID:  6, Recall@10: 0.07692307692307693
Query ID: 6, Precision@10: 0.1
Query ID:  7, Recall@10: 0.043478260869565216
Query ID: 7, Precision@10: 0.1
Query ID:  8, Recall@10: 0.0
Query ID: 8, Precision@10: 0.0
Query ID:  9, Recall@10: 0.0625
Query ID: 9, Precision@10: 0.2
Query ID:  10, Recall@10: 0.0
Query ID: 10, Precision@10: 0.0
Query ID:  11, Recall@10: 0.13636363636363635
Query ID: 11, Precision@10: 0.3
Query ID:  12, Recall@10: 0.23076923076923078
Query ID: 12, Precision@10: 0.3
Query ID:  13, Recall@10: 0.1111111111111111
Query ID: 13, Precision@10: 0.2
Query ID:  14, Recall@10: 0.10526315789473684
Query ID: 14, Precision@10: 0.2
Query ID:  15, Recall@10: 0.0740

In [64]:
def calculate_MRR(query_id):
    relevant_docs = []
    for qrel in dataset.qrels_iter():
        if qrel[0] == query_id and qrel[2] > 0:
            relevant_docs.append(qrel[1])

    retrieved_docs = []
    for query in dataset.queries_iter():
        if query[0] == query_id:
            retrieved_docs = getRetrievedQueries(query[1])
            break

    for i in range(1, 11):
        if retrieved_docs[i-1] in relevant_docs:
            return 1 / i
      

    return 0


queries_list = list(queries_ids.keys())
mrr_sum = 0
for query_id in queries_list:
    mrr_sum += calculate_MRR(query_id)
print(f"Mean Reciprocal Rank (MRR): {(1 / len(queries_list)) * mrr_sum}")


Mean Reciprocal Rank (MRR): 0.47358276643990926
