In [1]:
import re
import string
import pickle
from typing import List
import numpy as np
from nltk import tokenize, pos_tag, download
from nltk.corpus import stopwords, wordnet
from nltk.stem import PorterStemmer, WordNetLemmatizer
from sklearn.feature_extraction.text import TfidfVectorizer
import ir_datasets
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
from typing import Callable
import re
import string
from typing import Callable, List
from nltk.corpus import stopwords, wordnet
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.tokenize import word_tokenize
from nltk import pos_tag
import numpy as np
import inflect
import contractions
class LemmatizerWithPOSTagger(WordNetLemmatizer):
    def __init__(self):
        super().__init__()

    def _get_wordnet_pos(self, tag: str) -> str:
        if tag.startswith('J'):
            return wordnet.ADJ
        elif tag.startswith('V'):
            return wordnet.VERB
        elif tag.startswith('N'):
            return wordnet.NOUN
        elif tag.startswith('R'):
            return wordnet.ADV
        else:
            return wordnet.NOUN

    def lemmatize(self, word: str, pos: str = "n") -> str:
        return super().lemmatize(word, self._get_wordnet_pos(pos))

class TextPreprocessor():

    def __init__(self, tokenizer: Callable = None) -> None:
        self.tokenizer = tokenizer

        if self.tokenizer is None:
            self.tokenizer = word_tokenize

        self.stopwords_tokens = stopwords.words('english')
        self.stemmer = PorterStemmer()
        self.lemmatizer = LemmatizerWithPOSTagger()
        self.inflect_engine = inflect.engine()

    def tokenize(self, text: str) -> List[str]:
        tokens = self.tokenizer(text)
        return tokens
    
    def to_lower(self, tokens: List[str]) -> List[str]:
        return [token.lower() for token in tokens]

    def remove_markers(self, tokens: List[str]) -> List[str]:
        return [re.sub(r'\u00AE', '', token) for token in tokens]

    def remove_punctuation(self, tokens: List[str]) -> List[str]:
        return [token.translate(str.maketrans('', '', string.punctuation)) for token in tokens]

    def replace_under_score_with_space(self, tokens: List[str]) -> List[str]:
        return [token.replace('_', ' ') for token in tokens]

    def remove_stop_words(self, tokens: List[str]) -> List[str]:
        return [token for token in tokens if token not in self.stopwords_tokens and len(token) > 1]

    def remove_apostrophe(self, tokens: List[str]) -> List[str]:
        return [token.replace("'", " ") for token in tokens]

    def stemming(self, tokens: List[str]) -> List[str]:
        return [self.stemmer.stem(token) for token in tokens]

    def normalize_abbreviations(self, tokens: List[str]) -> List[str]:
        new_tokens = []
        resolved_terms = {}
        for token in tokens:
            if len(token) >= 2:
                synsets = wordnet.synsets(token)
                if synsets:
                    resolved_term = synsets[0].lemmas()[0].name()
                    resolved_terms[token] = resolved_term

        for token in tokens:
            if token in resolved_terms:
                new_tokens.append(resolved_terms[token])
            else:
                new_tokens.append(token)
        return new_tokens

    def lemmatizing(self, tokens: List[str]) -> List[str]:
        tagged_tokens = pos_tag(tokens)
        lemmatized_tokens = [self.lemmatizer.lemmatize(token, pos) for token, pos in tagged_tokens]
        return lemmatized_tokens

    def process_hashtags_mentions(self, tokens: List[str]) -> List[str]:
        return [token for token in tokens if not token.startswith('#') and not token.startswith('@')]

    def replace_country_symbols(self, tokens: List[str]) -> List[str]:
        country_symbols = {
            'US': 'United States', 'UK': 'United Kingdom', 'IN': 'India', 'CA': 'Canada',
            'AU': 'Australia', 'DE': 'Germany', 'FR': 'France', 'ES': 'Spain', 'IT': 'Italy',
            'JP': 'Japan', 'CN': 'China', 'BR': 'Brazil', 'RU': 'Russia', 'MX': 'Mexico',
            'ZA': 'South Africa', 'KR': 'South Korea', 'AR': 'Argentina', 'SA': 'Saudi Arabia',
            'EG': 'Egypt', 'NG': 'Nigeria', 'TR': 'Turkey', 'NL': 'Netherlands', 'SE': 'Sweden',
            'CH': 'Switzerland', 'BE': 'Belgium', 'AT': 'Austria', 'DK': 'Denmark', 'FI': 'Finland',
            'NO': 'Norway', 'PL': 'Poland', 'IE': 'Ireland', 'NZ': 'New Zealand', 'SG': 'Singapore',
            'MY': 'Malaysia', 'TH': 'Thailand', 'PH': 'Philippines', 'ID': 'Indonesia', 'VN': 'Vietnam',
            'PK': 'Pakistan', 'BD': 'Bangladesh', 'IR': 'Iran', 'IQ': 'Iraq', 'IL': 'Israel', 'GR': 'Greece',
            'PT': 'Portugal', 'CZ': 'Czech Republic', 'HU': 'Hungary', 'RO': 'Romania', 'BG': 'Bulgaria',
            'HR': 'Croatia', 'SI': 'Slovenia', 'SK': 'Slovakia', 'UA': 'Ukraine', 'BY': 'Belarus', 'LT': 'Lithuania',
            'LV': 'Latvia', 'EE': 'Estonia', 'IS': 'Iceland', 'MT': 'Malta', 'CY': 'Cyprus', 'LK': 'Sri Lanka',
            'KE': 'Kenya', 'GH': 'Ghana', 'UG': 'Uganda', 'TZ': 'Tanzania', 'SN': 'Senegal', 'DZ': 'Algeria',
            'MA': 'Morocco', 'TN': 'Tunisia', 'AE': 'United Arab Emirates', 'QA': 'Qatar', 'KW': 'Kuwait',
            'OM': 'Oman', 'BH': 'Bahrain', 'LB': 'Lebanon', 'JO': 'Jordan', 'SY': 'Syria', 'YE': 'Yemen',
            'AF': 'Afghanistan', 'UZ': 'Uzbekistan', 'KZ': 'Kazakhstan', 'KG': 'Kyrgyzstan', 'TJ': 'Tajikistan',
            'TM': 'Turkmenistan', 'MN': 'Mongolia', 'KH': 'Cambodia', 'LA': 'Laos', 'MM': 'Myanmar', 'NP': 'Nepal',
            'BT': 'Bhutan', 'LK': 'Sri Lanka', 'MV': 'Maldives', 'BN': 'Brunei', 'MO': 'Macau', 'HK': 'Hong Kong',
            'TW': 'Taiwan', 'AM': 'Armenia', 'GE': 'Georgia', 'AZ': 'Azerbaijan'
        }
        return [country_symbols.get(token, token) for token in tokens]

    def replace_contractions(self, text: str) -> str:
        return contractions.fix(text)



    def preprocess(self, text: str) -> str:
        
        text_tokens = self.tokenize(text)
        
        operations = [
            self.process_hashtags_mentions,       # Step 5: Process hashtags and mentions
            self.replace_country_symbols,         # Step 6: Replace country symbols
            self.normalize_abbreviations,         # Step 2: Normalize abbreviations
            self.remove_markers,                  # Step 9: Remove markers
            self.replace_under_score_with_space ,  # Step 12: Replace underscores with spaces
            self.to_lower,                        # Step 1: Convert text to lower case
            self.remove_punctuation,              # Step 3: Remove punctuation
            self.remove_apostrophe,               # Step 4: Remove apostrophes
            self.remove_stop_words,               # Step 8: Remove stop words
            self.lemmatizing, 
        ]

        for op in operations:
            text_tokens = op(text_tokens)
        
        new_text = ' '.join(text_tokens)
        return new_text

class TfidfEngine:
    def __init__(self, text_preprocessor):
        self.text_preprocessor = text_preprocessor
        self.tfidf_matrix = None
        self.tfidf_model = None
        self.document_id_mapping = {}

    def train_model(self, documents):
        document_texts = [doc['text'] for doc in documents]
        vectorizer = TfidfVectorizer( tokenizer=self.text_preprocessor.tokenizer ,preprocessor=self.text_preprocessor.preprocess)
        tfidf_matrix = vectorizer.fit_transform(document_texts)
        self.tfidf_matrix = tfidf_matrix
        self.tfidf_model = vectorizer
        self.save_model(documents)

    def save_model(self, documents):
        with open('tfidf_model_antique.pickle', 'wb') as f_model:
            pickle.dump(self.tfidf_model, f_model)
        with open('tfidf_matrix_antique.pickle', 'wb') as f_matrix:
            pickle.dump(self.tfidf_matrix, f_matrix)
        with open('document_id_mapping_antique.pickle', 'wb') as f_mapping:
            pickle.dump({doc['id']: doc['text'] for doc in documents}, f_mapping)

    def load_model(self):
        with open('tfidf_model_antique.pickle', 'rb') as f_model:
            self.tfidf_model = pickle.load(f_model)
        with open('tfidf_matrix_antique.pickle', 'rb') as f_matrix:
            self.tfidf_matrix = pickle.load(f_matrix)
        with open('document_id_mapping_antique.pickle', 'rb') as f_mapping:
            self.document_id_mapping = pickle.load(f_mapping)

    def query(self, query_text):
        preprocessed_query = self.text_preprocessor.preprocess(query_text)
        query_vector = self.tfidf_model.transform([preprocessed_query])
        return query_vector
    
    def rank_documents(self, query_vector):
        cosine_similarities = cosine_similarity(query_vector, self.tfidf_matrix).flatten()
        ranked_indices = np.argsort(-cosine_similarities)
        return ranked_indices, cosine_similarities

    def get_results(self, query_text):
        query_vector = self.query(query_text)
        ranked_indices, similarities = self.rank_documents(query_vector)
        result_ids = []
        for idx in ranked_indices[:10]:  # Top 10 results
            if similarities[idx] >= 0.35:
                result_ids.append(list(self.document_id_mapping.keys())[idx])
        unordered_results = [{'_id': doc_id, 'text': self.document_id_mapping[doc_id]} for doc_id in result_ids]
        return unordered_results


def calculate_MAP(query_id, tfidf_engine, dataset):
    relevant_docs = [qrel.doc_id for qrel in dataset.qrels_iter() if qrel.query_id == query_id]
    
    ordered_results = []
    for query in dataset.queries_iter():
        if query.query_id == query_id:
            ordered_results = tfidf_engine.get_results(query.text)
            break

    pk_sum = 0
    total_relevant = 0
    for i in range(1, 11):
        relevant_ret = 0
        for j in range(i):
            if j < len(ordered_results) and ordered_results[j]['_id'] in relevant_docs:
                relevant_ret += 1
        p_at_k = (relevant_ret / i) * (1 if i-1 < len(ordered_results) and ordered_results[i-1]['_id'] in relevant_docs else 0)
        pk_sum += p_at_k
        if i-1 < len(ordered_results) and ordered_results[i-1]['_id'] in relevant_docs:
            total_relevant += 1

    return 0 if total_relevant == 0 else pk_sum / total_relevant
def calculate_precision_at_10(query_id, tfidf_engine, dataset):
    relevant_docs = [qrel.doc_id for qrel in dataset.qrels_iter() if qrel.query_id == query_id]
    
    ordered_results = []
    for query in dataset.queries_iter():
        if query.query_id == query_id:
            ordered_results = tfidf_engine.get_results(query.text)
            break

    retrieved_docs = [result['_id'] for result in ordered_results[:10]]
    relevant_retrieved_docs = [doc for doc in retrieved_docs if doc in relevant_docs]

    precision_at_10 = len(relevant_retrieved_docs) / 10
    return precision_at_10

def calculate_recall(query_id, tfidf_engine, dataset):
    relevant_docs = [qrel.doc_id for qrel in dataset.qrels_iter() if qrel.query_id == query_id]
    
    ordered_results = []
    for query in dataset.queries_iter():
        if query.query_id == query_id:
            ordered_results = tfidf_engine.get_results(query.text)
            break

    retrieved_docs = [result['_id'] for result in ordered_results]
    relevant_retrieved_docs = [doc for doc in retrieved_docs if doc in relevant_docs]

    recall = len(relevant_retrieved_docs) / len(relevant_docs) if relevant_docs else 0
    return recall

# # Initialize TextPreprocessor
text_preprocessor = TextPreprocessor()

# Load documents
documents = []
dataset = ir_datasets.load("antique/test")
documents = [{'id': doc.doc_id, 'text': doc.text} for doc in dataset.docs_iter()]

# # Initialize TfidfEngine with the TextPreprocessor
tfidf_engine = TfidfEngine(text_preprocessor)

# Train the TF-IDF model (uncomment to train and save the model)
tfidf_engine.train_model(documents)

# Load the trained model
tfidf_engine.load_model()

# Calculate MAP, precision@10, and recall for all queries in the dataset
dataset = ir_datasets.load("antique/test")

queries_ids = {qrel.query_id for qrel in dataset.qrels_iter()}
map_sum = 0
precision_at_10_sum = 0
recall_sum = 0

for query_id in queries_ids:
    map_sum += calculate_MAP(query_id, tfidf_engine, dataset)
    precision_at_10_sum += calculate_precision_at_10(query_id, tfidf_engine, dataset)
    recall_sum += calculate_recall(query_id, tfidf_engine, dataset)

mean_average_precision = map_sum / len(queries_ids)
mean_precision_at_10 = precision_at_10_sum / len(queries_ids)
mean_recall = recall_sum / len(queries_ids)

print(f"Mean Average Precision (MAP): {mean_average_precision}")
print(f"Mean Precision@10: {mean_precision_at_10}")
print(f"Mean Recall: {mean_recall}")



Mean Average Precision (MAP): 0.6567764518770468
Mean Precision@10: 0.3434999999999999
Mean Recall: 0.1128125178448454
