In [1]:
import os
import re
import string

import ir_datasets
import joblib
import numpy as np
from flask import Flask, request, jsonify
from flask_sqlalchemy import SQLAlchemy
from nltk import pos_tag
from nltk.corpus import stopwords
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from spellchecker import SpellChecker

## Preprocessor Service

In [2]:
class Preprocessor:
    """Turn raw text into lowercase, punctuation-free, lemmatized tokens.

    The pipeline is: tokenize -> strip ASCII punctuation -> drop English stop
    words -> POS-tag -> lemmatize using the POS tag.
    """

    # Deletes every character in string.punctuation. Built once here instead
    # of twice per token inside preprocess().
    _PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)

    def __init__(self):
        self.stop_words = set(stopwords.words('english'))
        self.lemmatizer = WordNetLemmatizer()

    def get_wordnet_pos(self, tag):
        """Map a POS tag to a WordNet POS constant.

        Only the first letter of ``tag`` is inspected (J, N, V or R); any
        other tag, including an empty one, falls back to ``wordnet.NOUN``.
        """
        tag_dict = {
            "J": wordnet.ADJ,
            "N": wordnet.NOUN,
            "V": wordnet.VERB,
            "R": wordnet.ADV
        }
        # tag[:1] instead of tag[0] so an empty tag yields the default
        # rather than raising IndexError.
        return tag_dict.get(tag[:1].upper(), wordnet.NOUN)

    def lemmatization(self, tagged_doc_text):
        """Lemmatize ``(word, tag)`` pairs and return the list of lemmas."""
        return [
            self.lemmatizer.lemmatize(word, pos=self.get_wordnet_pos(tag))
            for word, tag in tagged_doc_text
        ]

    def preprocess(self, text):
        """Return the list of processed tokens for ``text``.

        ``None`` and the empty string both produce an empty list, so callers
        can pass nullable values without special-casing them.
        """
        if not text:
            return []

        raw_tokens = word_tokenize(text.lower())
        stripped = (token.translate(self._PUNCTUATION_TABLE) for token in raw_tokens)
        # Tokens that were pure punctuation become '' and are dropped here,
        # together with stop words.
        tokens = [
            token for token in stripped
            if token and token not in self.stop_words
        ]
        tagged_tokens = pos_tag(tokens)
        return self.lemmatization(tagged_tokens)

In [3]:
# Spell Corrector Service
class SpellCorrector:
    """Correct misspelled words in a query with pyspellchecker."""

    def __init__(self):
        self.spell_checker = SpellChecker()

    def correct_sentence_spelling(self, query):
        """Return ``query`` lowercased, tokenized and re-joined with spaces,
        with every unknown token replaced by its best correction.

        A token is kept unchanged when the spell checker knows it or has no
        correction to offer.
        """
        query_tokens = word_tokenize(query.lower())
        misspelled = self.spell_checker.unknown(query_tokens)

        corrected_tokens = []
        for token in query_tokens:
            if token in misspelled:
                # correction() returns None when no candidate exists; fall
                # back to the original token so ' '.join() cannot fail.
                suggestion = self.spell_checker.correction(token)
                corrected_tokens.append(suggestion or token)
            else:
                corrected_tokens.append(token)
        return ' '.join(corrected_tokens)

In [4]:
# Vectorizer Service
class Vectorizer:
    """Small wrapper around ``TfidfVectorizer`` that keeps the fitted matrix."""

    def __init__(self):
        self.vectorizer = TfidfVectorizer()
        # Set by fit_transform(); None until then so the attribute always
        # exists on the instance.
        self.tfidf_matrix = None

    def fit_transform(self, documents):
        """Fit on ``documents`` (an iterable of strings), store and return
        the resulting matrix."""
        self.tfidf_matrix = self.vectorizer.fit_transform(documents)
        return self.tfidf_matrix

    def transform(self, document):
        """Vectorize a single ``document`` string with the fitted vocabulary.

        The string is wrapped in a one-element list, so the result has one row.
        """
        return self.vectorizer.transform([document])


In [5]:
# SearchEngine Service
class SearchEngine:
    """Field-weighted TF-IDF search over the rows of a document model.

    One vectorizer and one TF-IDF matrix are kept per searchable field
    ("element"); a query is scored against every field and the per-field
    cosine similarities are combined with caller-supplied weights.
    """

    def __init__(self, preprocessor, spell_corrector, document_model, app):
        """Load all documents and the list of searchable fields.

        ``document_model`` must expose ``query.all()`` (rows providing
        ``to_dict()``) and ``get_columns(exclude_id)``; both are called inside
        ``app.app_context()``.
        """
        self.preprocessor = preprocessor
        self.spell_corrector = spell_corrector
        self.vectorizers = {}
        self.tfidf_matrices = {}

        with app.app_context():
            self.documents = [doc.to_dict() for doc in document_model.query.all()]
            self.elements = document_model.get_columns(True)

    def index_documents(self):
        """Fit one TF-IDF index per field over ``self.documents``.

        A field whose indexing fails is reported and skipped (best effort);
        the other fields are still indexed.
        """
        for element in self.elements:
            try:
                # A missing (None) field value is indexed as empty text, so a
                # single bad row does not discard the whole field.
                processed_docs = [
                    ' '.join(self.preprocessor.preprocess(doc[element] or ''))
                    for doc in self.documents
                ]
                vectorizer = Vectorizer()
                matrix = vectorizer.fit_transform(processed_docs)
                # Store both only after a successful fit, so an unfitted
                # vectorizer is never left behind for a failed field.
                self.vectorizers[element] = vectorizer
                self.tfidf_matrices[element] = matrix
            except Exception as e:
                print(f"An error occurred during vectorization of {element}:", e)

    def save_model(self, name):
        """Persist the vectorizers and matrices using ``name`` as file prefix.

        The second file name has no underscore after the prefix; it is kept
        as-is so files written by earlier runs still load.
        """
        joblib.dump(self.vectorizers, f'{name}_vectorizers.pkl')
        joblib.dump(self.tfidf_matrices, f'{name}tfidf_matrices.pkl')

    def load_model(self, name):
        """Replace the in-memory index with the one saved under ``name``.

        SECURITY: joblib.load unpickles the files, which can execute arbitrary
        code; only load files produced by save_model() from a trusted source.
        ``self.documents`` is not reloaded, so the loaded matrices must have
        been built from the same documents for search() to score them.
        """
        self.vectorizers = joblib.load(f'{name}_vectorizers.pkl')
        self.tfidf_matrices = joblib.load(f'{name}tfidf_matrices.pkl')

    def search(self, query, weights):
        """Rank all documents against ``query``.

        ``weights`` is paired positionally with ``self.elements``; surplus
        entries on either side are ignored by ``zip``.

        Returns ``(ranked_indices, scores, corrected_query)`` where
        ``ranked_indices`` are positions into ``self.documents`` ordered by
        descending score. When nothing is left of the query after
        preprocessing, the first two items are empty lists.
        """
        corrected_query = self.spell_corrector.correct_sentence_spelling(query)
        query_processed = self.preprocessor.preprocess(corrected_query)
        if not query_processed:
            return [], [], corrected_query

        query = ' '.join(query_processed)
        scores = np.zeros(len(self.documents))

        for element, weight in zip(self.elements, weights):
            try:
                query_vector = self.vectorizers[element].transform(query)
                # Defensive: nothing to compare against an empty vocabulary.
                if query_vector.shape[1] == 0:
                    continue
                cosine_similarities = cosine_similarity(query_vector, self.tfidf_matrices[element]).flatten()
                scores += weight * cosine_similarities
            except Exception as e:
                # Best effort: a field that is missing from the index or whose
                # matrix does not match the documents is reported and skipped.
                print(f"An error occurred during searching in {element}:", e)

        ranked_indices = np.argsort(scores)[::-1]
        return ranked_indices, scores, corrected_query


In [6]:
def clean_value(value):
    """Strip wiki-style quote markup from ``value`` and normalise whitespace.

    Runs of five, three and two apostrophes wrapping a span are removed, in
    that order, keeping the wrapped text. Every whitespace run is then
    collapsed to one space and the result is trimmed.
    """
    # Longest delimiter first, so ''''' is not consumed as ''' followed by ''.
    quote_patterns = (
        r"'''''(.*?)'''''",
        r"'''(.*?)'''",
        r"''(.*?)''",
    )
    cleaned = value
    for pattern in quote_patterns:
        cleaned = re.sub(pattern, r"\1", cleaned)
    collapsed = re.sub(r"\s+", " ", cleaned)
    return collapsed.strip()

In [26]:
# Flask application and its SQLAlchemy binding.
app = Flask(__name__)
# The connection string can be supplied through the
# IR_SEARCH_ENGINE_DATABASE_URI environment variable so that credentials do
# not have to be written into the notebook. When the variable is unset, the
# original local development database is used.
app.config['SQLALCHEMY_DATABASE_URI'] = os.environ.get(
    'IR_SEARCH_ENGINE_DATABASE_URI',
    'mysql://root:@localhost/ir_search_engine',
)
db = SQLAlchemy(app)

class Document(db.Model):
    """ORM model for one stored page and its searchable text fields."""

    id = db.Column(db.Integer, primary_key=True)
    page_title = db.Column(db.String(255))
    wikidata_classes = db.Column(db.String(255))
    text = db.Column(db.Text)
    sections = db.Column(db.Text)
    infoboxes = db.Column(db.Text)

    def to_dict(self):
        """Return the row as a plain dict keyed by column name."""
        return {
            'id': self.id,
            'page_title': self.page_title,
            'wikidata_classes': self.wikidata_classes,
            'text': self.text,
            'sections': self.sections,
            'infoboxes': self.infoboxes,
        }

    @classmethod
    def get_columns(cls, exclude_id=False):
        """Return the table's column names in declaration order.

        This reads table metadata only, so it is a classmethod; it can be
        called on the class or on an instance. With ``exclude_id=True`` the
        ``id`` column is left out.
        """
        return [
            column.name
            for column in cls.__table__.columns
            if not (exclude_id and column.name == 'id')
        ]

# Create the tables for the models declared on `db` (here: Document).
# The call is made inside an application context so that `db` is bound to `app`.
with app.app_context():
    db.create_all()

In [12]:
def save_document(doc_data):
    """Insert one document row and commit it.

    ``doc_data`` must provide the keys ``page_title``, ``wikidata_classes``,
    ``text``, ``sections`` and ``infoboxes``. Must be called inside an
    application context.

    If the commit fails the session is rolled back, so it stays usable for
    later calls, and the original exception is re-raised.
    """
    doc = Document(
        page_title=doc_data['page_title'],
        wikidata_classes=doc_data['wikidata_classes'],
        text=doc_data['text'],
        sections=doc_data['sections'],
        infoboxes=doc_data['infoboxes']
    )
    db.session.add(doc)
    try:
        db.session.commit()
    except Exception:
        # Without a rollback the session would reject every later operation.
        db.session.rollback()
        raise

In [27]:
# Main Application
if __name__ == "__main__":
    # Load the corpus and store every document in the database.
    # NOTE(review): every run of this cell inserts the full corpus again;
    # run it once per database, or empty the table first.
    dataset = ir_datasets.load("trec-tot/2023/train")

    for doc in dataset.docs_iter():
        # Each entry of wikidata_classes is indexed as a pair and the second
        # item is stored; presumably (id, label) - TODO confirm.
        # An empty list is stored as "" instead of raising IndexError.
        wikidata_classes = doc.wikidata_classes
        first_class = wikidata_classes[0][1] if wikidata_classes else ""

        # All section bodies, each prefixed with "\n ".
        sections_text = "".join(
            f"\n {section_text}" for section_text in doc.sections.values()
        )

        # All infobox parameter values, cleaned, each prefixed with "\n ".
        infoboxes_text = "".join(
            f"\n {clean_value(value)}"
            for infobox in doc.infoboxes
            for value in infobox['params'].values()
        )

        doc_data = {
            "page_title": doc.page_title,
            "wikidata_classes": first_class,
            "text": doc.text,
            "sections": sections_text,
            "infoboxes": infoboxes_text,
        }

        with app.app_context():
            save_document(doc_data)

In [14]:
# Wire the services together. SearchEngine calls `query.all()` and
# `get_columns(True)` on the object passed as `document_model`, and loads all
# stored documents while it is being constructed.
document_model = Document()
preprocessor = Preprocessor()
spell_corrector = SpellCorrector()
search_engine = SearchEngine(preprocessor, spell_corrector, document_model, app)

In [None]:
# Build one TF-IDF index per document field, then persist the vectorizers and
# matrices to disk using "trec-tot" as the file-name prefix.
search_engine.index_documents()
search_engine.save_model("trec-tot")

In [None]:
query = "film television Chldren"
# One weight per searchable field, in column order:
# page_title, wikidata_classes, text, sections, infoboxes.
weights = [0.5, 1.5, 0.3, 0.1, 0.1]
search_engine.load_model("trec-tot")
ranked_indices, scores, corrected_query = search_engine.search(query, weights)

# The corrector returns the lowercased tokens joined by single spaces, so the
# original query is normalised the same way before comparing. Otherwise
# punctuation or extra whitespace alone would look like a correction.
if corrected_query != ' '.join(word_tokenize(query.lower())):
    print("Did you mean: " + str(corrected_query))

print("\n\nRanked Documents:\n")
for index in ranked_indices[:10]:
    # ranked_indices are positions into search_engine.documents.
    document = search_engine.documents[index]
    print(f"Document {index} - Score: {scores[index]:.4f}")
    print(f"Title: {document['page_title']}")
    print(f"Classes: {document['wikidata_classes']}")
    print("-" * 50)

In [16]:
# Per-field weights used by the /search endpoint, in the order of
# search_engine.elements: page_title, wikidata_classes, text, sections,
# infoboxes. Defined here so the endpoint does not depend on a notebook
# global set by another cell.
SEARCH_FIELD_WEIGHTS = [0.5, 1.5, 0.3, 0.1, 0.1]

# Maps the `type_dataset` request parameter to a saved model prefix.
DATASET_MODELS = {
    "1": "trec-tot",
    "2": "webis-touche2020",
}


@app.route('/search', methods=['GET'])
def search():
    """Handle ``GET /search?query=...&type_dataset=1|2``.

    Responds with HTTP 400 and a JSON ``message`` when a parameter is missing
    or ``type_dataset`` is unknown. Otherwise returns JSON with the ten
    best-scoring documents and, if the spell corrector changed the query, the
    corrected query (``null`` when unchanged).
    """
    query = request.args.get('query', '')
    type_dataset = request.args.get('type_dataset', '')

    if not query:
        return jsonify({
            'message': 'Query parameter is required'
        }), 400

    if not type_dataset:
        return jsonify({
            'message': 'Type dataset parameter is required'
        }), 400

    model_name = DATASET_MODELS.get(type_dataset)
    if model_name is None:
        return jsonify({
            'message': "Type dataset not valid"
        }), 400

    # NOTE(review): the model is re-read from disk on every request, and
    # search_engine.documents always holds the rows loaded at construction
    # time, whichever model is selected here.
    search_engine.load_model(model_name)

    ranked_indices, scores, corrected_query = search_engine.search(query, SEARCH_FIELD_WEIGHTS)

    results = []
    for index in ranked_indices[:10]:
        doc = search_engine.documents[index]
        results.append({
            'title': doc['page_title'],
            'wikidata_classes': doc['wikidata_classes'],
            # `text` is a nullable column; treat a missing value as empty.
            'text_snippet': ' '.join((doc['text'] or '').split()[:30]),
            'similarity_score': f'{scores[index]:.4f}'
        })

    # The corrector returns lowercased tokens joined by single spaces, so the
    # incoming query is normalised the same way. Otherwise punctuation or
    # extra spaces alone would be reported as a correction.
    normalized_query = ' '.join(word_tokenize(query.lower()))
    response = {
        'corrected_query': corrected_query if corrected_query.lower() != normalized_query else None,
        'results': results
    }

    return jsonify(response)


In [None]:
if __name__ == '__main__':
    # Start the Flask server with debug mode off. The call blocks this cell
    # until the server is stopped.
    app.run(debug=False)
    # %tb  (IPython traceback magic, left disabled)
    