In [4]:
import pandas as pd
import string
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords, wordnet
from nltk.stem import WordNetLemmatizer, PorterStemmer
from nltk import pos_tag
from spellchecker import SpellChecker
from sklearn.feature_extraction.text import TfidfVectorizer
from typing import List
import re
import numpy as np
import os 
import pickle
# Load dataset and corpus
import ir_datasets
from sklearn.metrics.pairwise import cosine_similarity

dataset = ir_datasets.load("antique/train")

# Number of collection rows to index. The metrics recorded further down are all 0.0 with
# this small cap; set MAX_DOCS = None to index the whole collection.
MAX_DOCS = 8

df = pd.read_csv('collection.tsv', sep='\t', header=None, names=['doc_id', 'text'])

# Build the corpus dictionary (doc_id -> text). Missing texts (read by pandas as NaN floats)
# are stored as empty strings. The isinstance check has to come first: len() of a NaN float
# raises TypeError.
corpus = {}
for row in df.iloc[:MAX_DOCS].itertuples(index=False):  # slicing with None keeps every row
    corpus[row.doc_id] = row.text if isinstance(row.text, str) else ""

# Documents in corpus insertion order, so row i of the TF-IDF matrix built from this list
# corresponds to list(corpus.keys())[i].
documents = list(corpus.values())




In [5]:

# Country names, ISO-style codes and spelling variants mapped to one canonical English name.
# Built once at module level instead of on every call: normalize_country_names runs once per
# token of every document.
# NOTE(review): preprocess_text strips punctuation and works on single tokens before calling
# this, so dotted keys ("u.s.a", "u.k.") and multi-word keys ("south africa") never match
# there. Short codes such as "us", "id" or "co" also match ordinary words; confirm intended.
_COUNTRY_NAMES = {
    "uae": "united arab emirates", "u.a.e": "united arab emirates",
    "cn": "china", "china": "china",
    "sy": "syria", "syria": "syria",
    "usa": "united states of america", "u.s.a": "united states of america",
    "us": "united states of america", "u.s.": "united states of america",
    "uk": "united kingdom", "u.k.": "united kingdom", "united kingdom": "united kingdom",
    "england": "united kingdom", "gb": "united kingdom", "g.b.": "united kingdom",
    "great britain": "united kingdom", "fr": "france", "france": "france",
    "de": "germany", "germany": "germany", "deutschland": "germany",
    "jp": "japan", "japan": "japan", "it": "italy", "italy": "italy",
    "itália": "italy", "es": "spain", "spain": "spain", "españa": "spain",
    "ru": "russia", "russia": "russia", "россия": "russia", "in": "india",
    "india": "india", "br": "brazil", "brazil": "brazil", "brasil": "brazil",
    "au": "australia", "australia": "australia", "ca": "canada", "canada": "canada",
    "mx": "mexico", "mexico": "mexico", "méxico": "mexico", "za": "south africa",
    "south africa": "south africa", "southafrica": "south africa", "kr": "south korea",
    "south korea": "south korea", "southkorea": "south korea", "sa": "saudi arabia",
    "saudi arabia": "saudi arabia", "ksa": "saudi arabia", "kingdom of saudi arabia": "saudi arabia",
    "tr": "turkey", "turkey": "turkey", "trkiye": "turkey", "ch": "switzerland",
    "switzerland": "switzerland", "suisse": "switzerland", "chile": "chile",
    "pt": "portugal", "portugal": "portugal", "pl": "poland", "poland": "poland",
    "polska": "poland", "eg": "egypt", "egypt": "egypt", "egito": "egypt",
    "ng": "nigeria", "nigeria": "nigeria", "nigéria": "nigeria", "ar": "argentina",
    "argentina": "argentina", "gr": "greece", "greece": "greece", "ellada": "greece",
    "se": "sweden", "sweden": "sweden", "sverige": "sweden", "no": "norway",
    "norway": "norway", "norge": "norway", "fi": "finland", "finland": "finland",
    "suomi": "finland", "nl": "netherlands", "netherlands": "netherlands",
    "holland": "netherlands", "vn": "vietnam", "vietnam": "vietnam", "hk": "hong kong",
    "hong kong": "hong kong", "ir": "iran", "iran": "iran", "iq": "iraq",
    "iraq": "iraq", "ph": "philippines", "philippines": "philippines", "pk": "pakistan",
    "pakistan": "pakistan", "th": "thailand", "thailand": "thailand", "my": "malaysia",
    "malaysia": "malaysia", "id": "indonesia", "indonesia": "indonesia",
    "bd": "bangladesh", "bangladesh": "bangladesh", "af": "afghanistan",
    "afghanistan": "afghanistan", "il": "israel", "israel": "israel", "at": "austria",
    "austria": "austria", "be": "belgium", "belgium": "belgium", "cl": "chile",
    "co": "colombia", "colombia": "colombia", "cz": "czech republic",
    "czech republic": "czech republic", "dk": "denmark", "denmark": "denmark",
    "hu": "hungary", "hungary": "hungary", "is": "iceland", "iceland": "iceland",
    "ie": "ireland", "ireland": "ireland", "ke": "kenya", "kenya": "kenya",
    "lt": "lithuania", "lithuania": "lithuania", "lu": "luxembourg",
    "luxembourg": "luxembourg", "mt": "malta", "malta": "malta", "ma": "morocco",
    "morocco": "morocco", "nz": "new zealand", "new zealand": "new zealand",
    "pe": "peru", "peru": "peru", "ro": "romania", "romania": "romania",
    "sg": "singapore", "singapore": "singapore", "sk": "slovakia",
    "slovakia": "slovakia", "tw": "taiwan", "taiwan": "taiwan", "ua": "ukraine",
    "ukraine": "ukraine", "ve": "venezuela", "venezuela": "venezuela"
}


def normalize_country_names(text):
    """Return the canonical country name for ``text``, or ``text`` itself if it is unknown.

    The lookup is an exact, case-sensitive key match against ``_COUNTRY_NAMES``.
    """
    return _COUNTRY_NAMES.get(text, text)

In [6]:

# Custom tokenizer
def custom_tokenizer(text: str) -> list[str]:
    """Lowercase ``text`` and split it into tokens with NLTK's ``word_tokenize``."""
    return word_tokenize(text.lower())

def get_wordnet_pos(tag):
    """Map a POS tag (e.g. 'JJ', 'VBD' from nltk.pos_tag) to a WordNet POS constant.

    Only the first letter of ``tag`` is inspected, case-insensitively:
    J -> adjective, V -> verb, R -> adverb, N or anything else -> noun.
    """
    first_letter = tag[0].upper()
    if first_letter == "J":
        return wordnet.ADJ
    if first_letter == "V":
        return wordnet.VERB
    if first_letter == "R":
        return wordnet.ADV
    return wordnet.NOUN

def correct_sentence_spelling(tokens):
    """Return a new list in which each misspelled token is replaced by its best correction.

    A token is kept unchanged when the spell checker does not flag it, or when
    ``correction`` returns None for it. The input list is not modified.

    NOTE(review): a new SpellChecker is built on every call. preprocess_text runs once per
    document, so build the checker once if spelling correction is re-enabled there.
    """
    spell = SpellChecker()
    misspelled = spell.unknown(tokens)
    corrected_tokens = []
    for token in tokens:
        if token in misspelled:
            suggestion = spell.correction(token)
            corrected_tokens.append(suggestion if suggestion is not None else token)
        else:
            corrected_tokens.append(token)
    return corrected_tokens

def remove_punctuation(tokens: List[str]) -> List[str]:
    """Delete every ASCII punctuation character (``string.punctuation``) from each token.

    Tokens made only of punctuation become empty strings. They are kept, so the output has
    the same length as the input.
    """
    table = str.maketrans('', '', string.punctuation)  # built once, not once per token
    return [token.translate(table) for token in tokens]

def remove_apostrophe(tokens):
    """Replace every ASCII apostrophe (') in each token with a space.

    Uses plain ``str.replace``. The previous numpy ``np.char.replace`` + ``str()`` round-trip
    produced the same strings with extra overhead.

    NOTE(review): ``remove_apostrophe`` is defined a second time later in this cell with the
    same behaviour, and that later definition is the one in effect when preprocess_text runs.
    One of the two definitions should be deleted.
    """
    return [token.replace("'", " ") for token in tokens]

def remove_markers(tokens):
    """Return the tokens with every registered-trademark sign (U+00AE) stripped out."""
    return [re.sub(r'\u00AE', '', token) for token in tokens]
_URL_PATTERN = re.compile(r"http[s]?://\S+|www\.\S+")


def remove_links(text):
    """Delete http(s):// URLs and bare ``www.`` addresses from ``text``.

    Only the URL itself is removed; the whitespace around it is left in place.
    """
    return _URL_PATTERN.sub("", text)
    

def remove_apostrophe(tokens: List[str]) -> List[str]:
    """Return the tokens with each ASCII apostrophe (') turned into a space."""
    cleaned = []
    for token in tokens:
        cleaned.append(token.replace("'", " "))
    return cleaned

def replace_under_score_with_space(tokens: List[str]) -> List[str]:
    """Return the tokens with every underscore replaced by a single space."""
    result = []
    for token in tokens:
        result.append(re.sub(r'_', ' ', token))
    return result
# Translation table that deletes ASCII punctuation; built once instead of once per word.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)

# Lazily filled cache of the NLTK objects preprocess_text needs (stopword set, stemmer,
# lemmatizer). Without it they are re-created on every call, i.e. for every document.
_PREPROCESS_RESOURCES = {}


def _get_preprocess_resources():
    """Return the cached stopword set, Porter stemmer and WordNet lemmatizer, building them on first use."""
    if not _PREPROCESS_RESOURCES:
        # Build everything before publishing, so a failure cannot leave a half-filled cache.
        resources = {
            'stop_words': frozenset(stopwords.words('english')),
            'stemmer': PorterStemmer(),
            'lemmatizer': WordNetLemmatizer(),
        }
        _PREPROCESS_RESOURCES.update(resources)
    return _PREPROCESS_RESOURCES


def preprocess_text(text: str) -> str:
    """Normalize raw text into a single space-separated string of stemmed/lemmatized terms.

    Steps, in order:
    1. Strip URLs.
    2. Lowercase and tokenize with NLTK ``word_tokenize``.
    3. Delete ASCII punctuation inside tokens.
    4. Drop tokens that became empty, plus English stopwords.
    5. Remove the registered-trademark sign.
    6. Map country names/codes to canonical names.
    7. Porter-stem.
    8. POS-tag and WordNet-lemmatize.

    This function is the TfidfVectorizer ``preprocessor`` for the index.
    """
    resources = _get_preprocess_resources()
    stop_words = resources['stop_words']
    stemmer = resources['stemmer']
    lemmatizer = resources['lemmatizer']

    # Remove links before tokenizing so URL fragments do not become terms.
    text = remove_links(text)

    words = word_tokenize(text.lower())

    # Delete punctuation inside tokens. Tokens that were pure punctuation become "" and are
    # dropped here, together with stopwords, so they never reach the POS tagger.
    words = [word.translate(_PUNCTUATION_TABLE) for word in words]
    words = [word for word in words if word and word not in stop_words]

    # Spelling correction (correct_sentence_spelling) is currently disabled.

    words = remove_markers(words)
    # NOTE: "_" and "'" are both in string.punctuation, so these two steps are no-ops after
    # the punctuation pass above. They are kept so the pipeline stays explicit.
    words = replace_under_score_with_space(words)
    words = remove_apostrophe(words)
    words = [normalize_country_names(word) for word in words]

    words = [stemmer.stem(word) for word in words]

    # NOTE(review): tagging and lemmatizing run on Porter stems (e.g. "studi"), which WordNet
    # often does not recognise. Consider lemmatizing instead of, not after, stemming.
    pos_tags = pos_tag(words)
    words = [lemmatizer.lemmatize(word, pos=get_wordnet_pos(tag)) for word, tag in pos_tags]

    return ' '.join(words)


In [7]:
# Vectorizer setup. preprocess_text cleans and stems each document, and custom_tokenizer
# splits the result. token_pattern=None states explicitly that the default regex tokenizer
# is unused; otherwise scikit-learn warns that token_pattern is ignored when a tokenizer is given.
vectorizer = TfidfVectorizer(
    tokenizer=custom_tokenizer,
    preprocessor=preprocess_text,
    token_pattern=None,
)
tfidf_matrix = vectorizer.fit_transform(documents)
tfidf_model = vectorizer
# fit_transform returns a sparse (n_documents x n_terms) matrix, not a DataFrame.
print(f"TF-IDF matrix created successfully: {tfidf_matrix.shape[0]} documents x {tfidf_matrix.shape[1]} terms.")



TF-IDF DataFrame created successfully.


In [8]:

# Save and load functions for TF-IDF data
def save_file(file_location: str, content):
    """Pickle ``content`` to ``file_location``, replacing any existing file.

    The data is first written to ``<file_location>.tmp`` and then moved into place with
    ``os.replace``. A failure while pickling therefore leaves any previous file intact
    rather than having deleted it up front. The temporary file is cleaned up on failure
    and the error is re-raised.
    """
    tmp_location = f"{file_location}.tmp"
    try:
        with open(tmp_location, 'wb') as handle:
            pickle.dump(content, handle, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_location, file_location)
    except BaseException:
        if os.path.exists(tmp_location):
            os.remove(tmp_location)
        raise

def load_file(file_location: str):
    """Unpickle and return the object stored at ``file_location``.

    Security: ``pickle.load`` can execute arbitrary code. Only load files this notebook
    wrote itself, never files from an untrusted source.
    """
    with open(file_location, 'rb') as stream:
        return pickle.load(stream)

def save_tfidf_data(tfidf_matrix, tfidf_model):
    """Pickle the TF-IDF matrix and the fitted vectorizer to fixed files in the working directory."""
    artefacts = (
        ("tfidf_matrix.pickle", tfidf_matrix),
        ("tfidf_model.pickle", tfidf_model),
    )
    for file_location, content in artefacts:
        save_file(file_location, content)


save_tfidf_data(tfidf_matrix, tfidf_model)

In [9]:
def process_query(query: str, tfidf_model, tfidf_matrix):
    """Score every row of ``tfidf_matrix`` against ``query`` by cosine similarity.

    ``query`` is passed through ``tfidf_model.transform``, and therefore through whatever
    preprocessor/tokenizer the model was fitted with.

    Returns ``(ranked_doc_indices, cosine_similarities)``:
    - row indices ordered from most to least similar;
    - the 1-D score array in original row order.
    """
    query_vector = tfidf_model.transform([query])
    scores = cosine_similarity(query_vector, tfidf_matrix).flatten()
    order = np.argsort(scores)[::-1]
    return order, scores

# Reload the pickled TF-IDF matrix and fitted vectorizer written by save_tfidf_data, so the
# retrieval functions below use the persisted artefacts rather than the in-memory objects.
# NOTE: unpickling can run arbitrary code; only load pickle files this notebook produced.
tfidf_matrix = load_file("tfidf_matrix.pickle")
tfidf_model = load_file("tfidf_model.pickle")

def getRetrievedQueries(query: str, k=10):
    """Return the doc ids of the ``k`` corpus documents most similar to ``query``.

    The raw query goes straight to ``process_query``. ``tfidf_model`` was fitted with
    ``preprocessor=preprocess_text``, so ``transform`` already preprocesses the query exactly
    as the documents were. Calling ``preprocess_text`` here as well would preprocess the
    query twice (stopword removal, stemming and country-name expansion applied again). That
    can produce terms that never occur in the document index.

    Ids are looked up by position in ``corpus``. This relies on the rows of ``tfidf_matrix``
    being in ``corpus`` insertion order, which holds when the matrix was built from
    ``list(corpus.values())``.
    """
    ranked_indices, _ = process_query(query, tfidf_model, tfidf_matrix)
    doc_ids = list(corpus.keys())  # built once, not once per retrieved document
    return [doc_ids[idx] for idx in ranked_indices[:k]]

def calculate_recall_precision(query_id, k: int = 10, relevance_threshold: int = 0):
    """Print Recall@k and Precision@k for one query and return the recall.

    Relevant documents are the judged documents for ``query_id`` in ``dataset.qrels_iter()``
    whose relevance label is strictly greater than ``relevance_threshold``. The default of 0
    matches the previous ``relevance > 0`` rule.
    NOTE(review): ANTIQUE labels are graded; confirm whether the lower grades should count
    as relevant for these binary metrics.

    Retrieved documents are the top-``k`` ids returned by ``getRetrievedQueries`` for the
    query text found in ``dataset.queries_iter()``. If the query id is not found, nothing
    is retrieved.

    Precision always divides by ``k``, even when fewer than ``k`` documents come back
    (e.g. when the corpus holds fewer than ``k`` documents).

    Returns:
        Recall@k (0 when the query has no relevant documents).
    """
    # A set both de-duplicates repeated judgements and gives O(1) membership tests.
    relevant_docs = {qrel[1] for qrel in dataset.qrels_iter()
                     if qrel[0] == query_id and qrel[2] > relevance_threshold}

    query_text = next((query[1] for query in dataset.queries_iter() if query[0] == query_id), None)
    retrieved_docs = getRetrievedQueries(query_text, k) if query_text is not None else []

    true_positives = sum(1 for doc_id in retrieved_docs if doc_id in relevant_docs)
    recall_at_k = true_positives / len(relevant_docs) if relevant_docs else 0
    precision_at_k = true_positives / k
    print(f"Query ID: {query_id}, Recall@{k}: {recall_at_k}")
    print(f"Query ID: {query_id}, Precision@{k}: {precision_at_k}")
    return recall_at_k
# Distinct query ids that have at least one relevance judgement, in first-seen order.
queries_ids = {qrel[0]: '' for qrel in dataset.qrels_iter()}

for query_id in queries_ids:
    calculate_recall_precision(query_id)

[INFO] Please confirm you agree to the authors' data usage agreement found at <https://ciir.cs.umass.edu/downloads/Antique/readme.txt>
[INFO] [starting] https://ciir.cs.umass.edu/downloads/Antique/antique-test.qrel
[INFO] [finished] https://ciir.cs.umass.edu/downloads/Antique/antique-test.qrel: [00:05] [150kB] [26.2kB/s]
[INFO] [starting] https://ciir.cs.umass.edu/downloads/Antique/antique-test-queries.txt   
[INFO] [finished] https://ciir.cs.umass.edu/downloads/Antique/antique-test-queries.txt: [00:00] [11.4kB] [32.7kB/s]
                                                                                                 

Query ID: 1964316, Recall@10: 0.0
Query ID: 1964316, Precision@10: 0.0
Query ID: 2418598, Recall@10: 0.0
Query ID: 2418598, Precision@10: 0.0
Query ID: 1167882, Recall@10: 0.0
Query ID: 1167882, Precision@10: 0.0
Query ID: 1880028, Recall@10: 0.0
Query ID: 1880028, Precision@10: 0.0
Query ID: 2192891, Recall@10: 0.0
Query ID: 2192891, Precision@10: 0.0
Query ID: 949154, Recall@10: 0.0
Query ID: 949154, Precision@10: 0.0
Query ID: 1844896, Recall@10: 0.0
Query ID: 1844896, Precision@10: 0.0
Query ID: 2634143, Recall@10: 0.0
Query ID: 2634143, Precision@10: 0.0
Query ID: 2382487, Recall@10: 0.0
Query ID: 2382487, Precision@10: 0.0
Query ID: 229303, Recall@10: 0.0
Query ID: 229303, Precision@10: 0.0
Query ID: 1015624, Recall@10: 0.0
Query ID: 1015624, Precision@10: 0.0
Query ID: 2785579, Recall@10: 0.0
Query ID: 2785579, Precision@10: 0.0
Query ID: 4003223, Recall@10: 0.0
Query ID: 4003223, Precision@10: 0.0
Query ID: 481173, Recall@10: 0.0
Query ID: 481173, Precision@10: 0.0
Query ID: 33

Query ID: 4012558, Recall@10: 0.0
Query ID: 4012558, Precision@10: 0.0
Query ID: 3301173, Recall@10: 0.0
Query ID: 3301173, Precision@10: 0.0
Query ID: 654124, Recall@10: 0.0
Query ID: 654124, Precision@10: 0.0
Query ID: 3278654, Recall@10: 0.0
Query ID: 3278654, Precision@10: 0.0
Query ID: 2528767, Recall@10: 0.0
Query ID: 2528767, Precision@10: 0.0
Query ID: 1977054, Recall@10: 0.0
Query ID: 1977054, Precision@10: 0.0
Query ID: 1623623, Recall@10: 0.0
Query ID: 1623623, Precision@10: 0.0
Query ID: 1290612, Recall@10: 0.0
Query ID: 1290612, Precision@10: 0.0
Query ID: 3990512, Recall@10: 0.0
Query ID: 3990512, Precision@10: 0.0
Query ID: 204963, Recall@10: 0.0
Query ID: 204963, Precision@10: 0.0
Query ID: 2892478, Recall@10: 0.0
Query ID: 2892478, Precision@10: 0.0
Query ID: 3507491, Recall@10: 0.0
Query ID: 3507491, Precision@10: 0.0
Query ID: 953489, Recall@10: 0.0
Query ID: 953489, Precision@10: 0.0
Query ID: 1152934, Recall@10: 0.0
Query ID: 1152934, Precision@10: 0.0
Query ID: 28

In [10]:

def calculate_MAP(query_id, k: int = 10, relevance_threshold: int = 0):
    """Return the average precision over the top-``k`` retrieved documents for one query.

    Relevant documents are the judged documents for ``query_id`` whose relevance label is
    strictly greater than ``relevance_threshold`` (default 0). The ranked list comes from
    ``getRetrievedQueries`` applied to the query text found in ``dataset.queries_iter()``.

    For every rank r in 1..k that holds a relevant document, precision@r is added to a
    running sum. The sum is divided by the number of relevant documents retrieved in the
    top ``k``.
    NOTE(review): this normalisation ignores relevant documents that were not retrieved, so
    it is never lower than AP@k normalised by min(k, |relevant|). Confirm which definition
    is intended.

    Returns 0 when no relevant document appears in the top ``k``.
    """
    relevant_docs = {qrel[1] for qrel in dataset.qrels_iter()
                     if qrel[0] == query_id and qrel[2] > relevance_threshold}

    query_text = next((query[1] for query in dataset.queries_iter() if query[0] == query_id), None)
    retrieved_docs = getRetrievedQueries(query_text, k) if query_text is not None else []

    # A running hit count replaces re-counting the relevant docs in the prefix at every rank.
    hits = 0
    precision_sum = 0
    for rank, doc_id in enumerate(retrieved_docs[:k], start=1):
        if doc_id in relevant_docs:
            hits += 1
            precision_sum += hits / rank  # precision@rank, accumulated only at relevant ranks

    return 0 if hits == 0 else precision_sum / hits

# Distinct judged query ids (dict keeps first-seen order; the values are unused).
queries_ids = {qrel[0]: '' for qrel in dataset.qrels_iter()}

# Sum the per-query average precision over every judged query, then take the mean.
map_sum = sum(calculate_MAP(query_id) for query_id in queries_ids)

print(f"Mean Average Precision (MAP@10): {map_sum / len(queries_ids)}")


Mean Average Precision (MAP@10): 0.0
