# Notebook: Search Engine with Python, NLTK, and Scikit-learn
# Author: [Yasser Barghouth]

## Description: This notebook demonstrates the creation of a simple search engine using a service-oriented approach.
##              It utilizes NLTK for text preprocessing and Scikit-learn for vectorization and search functionality.


# IR Search Engine

This notebook implements an Information Retrieval (IR) search engine over the TREC-TOT and Webis-Touche datasets. It covers text preprocessing, query spell correction, TF-IDF vectorization, and cosine-similarity ranking. Documents are stored through SQLAlchemy, and the engine is exposed as a Flask API.

## Components
1. Preprocessor
2. SpellCorrector
3. Vectorizer
4. SearchEngine
5. Flask API

## Datasets
- TREC-TOT
- Webis-Touche

## Usage
The search engine supports indexing documents from datasets, processing queries, and returning relevant documents based on cosine similarity.



## Import necessary libraries

In [1]:
import ir_datasets
import string
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet
from nltk import pos_tag
from spellchecker import SpellChecker
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import joblib
import re
import os
from flask import Flask, request, jsonify
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import inspect

## Preprocessor Service

In [2]:
# Preprocessor Class: Tokenization, stopword removal, lemmatization
class Preprocessor:
    """Turn raw text into lower-cased, punctuation-free, lemmatized tokens.

    The pipeline is: NLTK word tokenization, punctuation stripping, English
    stop-word removal, POS tagging, and WordNet lemmatization.
    """

    # Translation table deleting every ASCII punctuation character. Built
    # once here instead of twice per token inside ``preprocess``.
    _PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)

    def __init__(self):
        self.stop_words = set(stopwords.words('english'))
        self.lemmatizer = WordNetLemmatizer()

    def get_wordnet_pos(self, tag):
        """Map a POS tag (as produced by ``pos_tag``) to a WordNet POS constant.

        Only the first letter of the tag is inspected; anything that is not
        an adjective, noun, verb or adverb tag falls back to ``wordnet.NOUN``.
        """
        tag_dict = {
            "J": wordnet.ADJ,
            "N": wordnet.NOUN,
            "V": wordnet.VERB,
            "R": wordnet.ADV,
        }
        # ``tag[:1]`` (not ``tag[0]``) so an empty tag falls back to NOUN
        # instead of raising IndexError.
        return tag_dict.get(tag[:1].upper(), wordnet.NOUN)

    def lemmatization(self, tagged_doc_text):
        """Lemmatize ``(word, tag)`` pairs and return the list of lemmas."""
        return [
            self.lemmatizer.lemmatize(word, pos=self.get_wordnet_pos(tag))
            for word, tag in tagged_doc_text
        ]

    def preprocess(self, text):
        """Return the list of lemmatized tokens for ``text``.

        ``None`` or an empty string yields an empty list, so callers can feed
        nullable database fields without guarding each call.
        """
        if not text:
            return []

        stripped_tokens = (
            token.translate(self._PUNCTUATION_TABLE)
            for token in word_tokenize(text.lower())
        )
        # Drop tokens that were pure punctuation (now empty) and stop words.
        tokens = [
            token
            for token in stripped_tokens
            if token and token not in self.stop_words
        ]
        return self.lemmatization(pos_tag(tokens))

## SpellCorrector Service

In [3]:
# SpellCorrector Class: Corrects misspelled words in a query
class SpellCorrector:
    """Correct misspelled words in a free-text query."""

    def __init__(self):
        self.spell_checker = SpellChecker()

    def correct_sentence_spelling(self, query):
        """Return ``query`` lower-cased with unknown words replaced.

        The query is tokenized with NLTK; every token the spell checker does
        not know is replaced by its suggested correction. Tokens are joined
        back with single spaces. A ``None`` or blank query yields ``''``.
        """
        if not query or not query.strip():
            return ''

        query_tokens = word_tokenize(query.lower())
        misspelled = self.spell_checker.unknown(query_tokens)
        corrected_tokens = [
            # ``correction`` may return None when it has no candidate; keep
            # the original token then, otherwise ``join`` raises TypeError.
            (self.spell_checker.correction(token) or token)
            if token in misspelled
            else token
            for token in query_tokens
        ]
        return ' '.join(corrected_tokens)

## Vectorizer Service

In [4]:
# Vectorizer Class: TF-IDF Vectorization
class Vectorizer:
    """Thin wrapper around a scikit-learn ``TfidfVectorizer``."""

    def __init__(self):
        self.vectorizer = TfidfVectorizer()

    def fit_transform(self, documents):
        """Fit the vectorizer on ``documents`` and return their TF-IDF matrix.

        The matrix is also kept on the instance as ``tfidf_matrix``.
        """
        matrix = self.vectorizer.fit_transform(documents)
        self.tfidf_matrix = matrix
        return matrix

    def transform(self, document):
        """Vectorize a single document string with the fitted vocabulary."""
        # The underlying vectorizer expects an iterable of documents.
        single_document_batch = [document]
        return self.vectorizer.transform(single_document_batch)

## SearchEngine Service

In [5]:
# SearchEngine Class: Main search engine logic
class SearchEngine:
    """Index documents per field with TF-IDF and rank them for a query.

    One vectorizer and one TF-IDF matrix are kept per document field
    ("element"). A query is spell-corrected, preprocessed, vectorized with
    each field's vectorizer, and the per-field cosine similarities are
    combined as a weighted sum.
    """

    # Per-field weights used when the caller supplies none. They are paired
    # positionally with the columns returned by the document service.
    DEFAULT_WEIGHTS = {
        "trec-tot": [0.5, 1.5, 0.3, 0.1, 0.1],
        "webis-touche": [0.5, 0.3],
    }

    def __init__(self, preprocessor, spell_corrector, document_service, model_service):
        self.preprocessor = preprocessor
        self.spell_corrector = spell_corrector
        self.document_service = document_service
        self.model_service = model_service
        self.vectorizers = {}
        self.tfidf_matrices = {}
        self.documents = []
        self.elements = []

    @staticmethod
    def _field_text(doc, element):
        """Return the text of field ``element`` of ``doc`` ('' when missing).

        Documents are model instances (attribute access); plain dicts are
        accepted as well.
        """
        if isinstance(doc, dict):
            value = doc.get(element)
        else:
            value = getattr(doc, element, None)
        return value or ''

    def _indexed_row_count(self, dataset_name):
        """Return how many documents the loaded TF-IDF matrices cover.

        The score vector must have one entry per matrix row. The table row
        count can be larger than the indexed subset, so the matrices are the
        source of truth and the table count is only a fallback.
        """
        for matrix in self.tfidf_matrices.values():
            shape = getattr(matrix, "shape", None)
            if shape:
                return int(shape[0])
        return int(self.document_service.get_document_count(dataset_name))

    def index_documents(self, dataset_name):
        """Build a vectorizer and TF-IDF matrix for every text field.

        A field whose vectorization fails is reported and skipped; the other
        fields are still indexed.
        """
        self.documents = self.document_service.get_documents(dataset_name)
        # The id column is not searchable text, so it is excluded.
        self.elements = self.document_service.get_columns(dataset_name, True)

        # Start from empty dicts so fields of a previously loaded dataset do
        # not leak into this dataset's model.
        vectorizers = {}
        tfidf_matrices = {}
        for element in self.elements:
            try:
                processed_docs = [
                    ' '.join(self.preprocessor.preprocess(self._field_text(doc, element)))
                    for doc in self.documents
                ]
                vectorizer = Vectorizer()
                matrix = vectorizer.fit_transform(processed_docs)
            except Exception as e:
                print(f"An error occurred during vectorization of {element}:", e)
            else:
                # Store both together so a vectorizer never exists without
                # its matrix.
                vectorizers[element] = vectorizer
                tfidf_matrices[element] = matrix
        self.vectorizers = vectorizers
        self.tfidf_matrices = tfidf_matrices

    def save_model(self, name):
        """Persist the current vectorizers and matrices under ``name``.

        Returns a ``(flask response, status code)`` pair: 200 on success,
        400 with the error message on failure.
        """
        try:
            self.model_service.save_vectorizers(self.vectorizers, name)
            self.model_service.save_tfidf_matrices(self.tfidf_matrices, name)
        except (ValueError, OSError) as e:
            return jsonify({'message': str(e)}), 400
        return jsonify({'message': f'{name} vectorizers and tf-idf matrixes saved successfully'}), 200

    def load_model(self, name):
        """Load the vectorizers and matrices stored under ``name``.

        Returns a ``(flask response, status code)`` pair: 200 on success,
        400 with the error message on failure. The engine's current model is
        left untouched when loading fails.
        """
        try:
            vectorizers = self.model_service.load_vectorizers(name)
            tfidf_matrices = self.model_service.load_tfidf_matrices(name)
        except (ValueError, OSError) as e:
            return jsonify({'message': str(e)}), 400

        for loaded in (vectorizers, tfidf_matrices):
            # A model service may report a failure by returning the error
            # message instead of raising; never store that as a model.
            if not isinstance(loaded, dict):
                return jsonify({'message': str(loaded)}), 400

        self.vectorizers = vectorizers
        self.tfidf_matrices = tfidf_matrices
        return jsonify({'message': f'{name} vectorizers and tf-idf matrices loaded successfully'}), 200

    def search(self, query, dataset_name, weights=None):
        """Rank the indexed documents of ``dataset_name`` against ``query``.

        Args:
            query: Raw user query.
            dataset_name: Dataset whose loaded model is searched.
            weights: Optional per-field weights, paired positionally with the
                dataset's columns. Defaults to ``DEFAULT_WEIGHTS`` for known
                datasets and to equal weights otherwise.

        Returns:
            ``(ranked_indices, scores, corrected_query, error)``.
            ``ranked_indices`` are matrix row positions ordered by descending
            score and ``scores`` is indexed by row position. On failure the
            first two items are empty lists and ``error`` is a message;
            otherwise ``error`` is ``None``.
        """
        if not self.vectorizers or not self.tfidf_matrices:
            return [], [], query, "Model not loaded"

        corrected_query = self.spell_corrector.correct_sentence_spelling(query)
        query_processed = self.preprocessor.preprocess(corrected_query)
        if not query_processed:
            return [], [], corrected_query, "Query could not be processed"
        query_text = ' '.join(query_processed)

        self.elements = self.document_service.get_columns(dataset_name, True) or []
        if weights is None:
            weights = self.DEFAULT_WEIGHTS.get(dataset_name)
        if weights is None:
            # Unknown dataset and no explicit weights: weigh fields equally.
            weights = [1.0] * len(self.elements)

        scores = np.zeros(self._indexed_row_count(dataset_name))

        for element, weight in zip(self.elements, weights):
            vectorizer = self.vectorizers.get(element)
            matrix = self.tfidf_matrices.get(element)
            if vectorizer is None or matrix is None:
                # Field was not indexed (e.g. its vectorization failed).
                continue
            try:
                query_vector = vectorizer.transform(query_text)
                if query_vector.shape[1] == 0:
                    continue
                scores += weight * cosine_similarity(query_vector, matrix).flatten()
            except Exception as e:
                print(f"An error occurred during searching in {element}:", e)

        ranked_indices = np.argsort(scores)[::-1]
        return ranked_indices, scores, corrected_query, None

## Database Service

### TrecTot model

In [6]:
# Shared SQLAlchemy handle. It is created unbound here and attached to the
# Flask app later via db.init_app(app), so the models below can be declared
# before the app object exists.
db = SQLAlchemy()

# Database Model for datasets
class TrecTotDocument(db.Model):
    """One stored TREC-TOT document."""

    id = db.Column(db.Integer, primary_key=True)
    page_title = db.Column(db.String(255))
    wikidata_classes = db.Column(db.String(255))
    text = db.Column(db.Text)
    sections = db.Column(db.Text)
    infoboxes = db.Column(db.Text)

    def to_dict(self):
        """Return the row as a plain dict keyed by column name."""
        return {
            'id': self.id,
            'page_title': self.page_title,
            'wikidata_classes': self.wikidata_classes,
            'text': self.text,
            'sections': self.sections,
            'infoboxes': self.infoboxes,
        }

    @classmethod
    def get_columns(cls, exclude_id):
        """Return the table's column names, optionally without ``id``.

        Declared as a classmethod because it is called on the class itself
        (``TrecTotDocument.get_columns(exclude_id)``); as a plain instance
        method that call bound ``exclude_id`` to ``self`` and raised
        ``TypeError``. Calling it on an instance still works.
        """
        names = [column.name for column in cls.__table__.columns]
        if exclude_id:
            return [name for name in names if name != 'id']
        return names

### WebisTouche model

In [None]:
# Database Model for datasets
class WebisToucheDocument(db.Model):
    """One stored Webis-Touche document."""

    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(255))
    text = db.Column(db.Text)

    def to_dict(self):
        """Return the row as a plain dict keyed by column name."""
        return {
            'id': self.id,
            'title': self.title,
            'text': self.text,
        }

    @classmethod
    def get_columns(cls, exclude_id):
        """Return the table's column names, optionally without ``id``.

        Declared as a classmethod because it is called on the class itself
        (``WebisToucheDocument.get_columns(exclude_id)``); as a plain
        instance method that call bound ``exclude_id`` to ``self`` and raised
        ``TypeError``. Calling it on an instance still works.
        """
        names = [column.name for column in cls.__table__.columns]
        if exclude_id:
            return [name for name in names if name != 'id']
        return names


### Document Service

In [9]:
# DocumentService class: Interacts with the database to manage documents
class DocumentService:
    """Database access for the stored dataset documents.

    Datasets are addressed by name (``"trec-tot"`` or ``"webis-touche"``);
    each name maps to one model class / table.
    """

    def __init__(self, db):
        self.db = db

    def _model_for(self, dataset_name):
        """Return the model class for ``dataset_name``.

        Raises:
            ValueError: if the dataset name is not known.
        """
        if dataset_name == "trec-tot":
            return TrecTotDocument
        if dataset_name == "webis-touche":
            return WebisToucheDocument
        raise ValueError(f'Invalid dataset name: {dataset_name}')

    @staticmethod
    def _column_names(model, exclude_id):
        """Return the column names of ``model``, optionally without ``id``."""
        # Read the table metadata directly so this does not depend on how
        # the model's own ``get_columns`` is bound.
        return [
            column.name
            for column in model.__table__.columns
            if not (exclude_id and column.name == 'id')
        ]

    def save_document(self, doc_data, dataset_name):
        """Insert one document and commit.

        ``doc_data`` must contain a value for every non-id column of the
        dataset's table (``KeyError`` otherwise).

        Raises:
            ValueError: if the dataset name is not known.
        """
        model = self._model_for(dataset_name)
        fields = self._column_names(model, True)
        doc = model(**{field: doc_data[field] for field in fields})
        self.db.session.add(doc)
        self.db.session.commit()

    def get_documents(self, dataset_name, limit=100):
        """Return up to ``limit`` documents ordered by primary key.

        The ordering is explicit because callers treat the position in this
        list as the document's row index. ``limit=None`` returns every row.
        """
        model = self._model_for(dataset_name)
        query = model.query.order_by(model.id)
        if limit is not None:
            query = query.limit(limit)
        return query.all()

    def get_columns(self, dataset_name, exclude_id=True):
        """Return the dataset's column names, by default without ``id``."""
        return self._column_names(self._model_for(dataset_name), exclude_id)

    def check_data_exists(self, dataset_name):
        """Return True when the dataset's table holds at least one row.

        Unknown dataset names yield False.
        """
        try:
            model = self._model_for(dataset_name)
        except ValueError:
            return False
        return self.db.session.query(model.id).first() is not None

    def check_table_exists(self, dataset_name):
        """Return True when the dataset's table exists in the database.

        The dataset name is translated to the model's real table name; the
        dataset name itself ("trec-tot") is never a table name. A name that
        is not a known dataset is looked up as a literal table name.
        """
        try:
            table_name = self._model_for(dataset_name).__table__.name
        except ValueError:
            table_name = dataset_name
        inspector = inspect(self.db.engine)
        return table_name in inspector.get_table_names()

    def get_documents_by_indices(self, indices, dataset_name):
        """Return the documents whose primary key ``id`` is in ``indices``.

        Despite the name, the values are matched against the ``id`` column,
        not against list positions. Values are converted to plain ``int`` so
        numpy integers can be passed.
        """
        model = self._model_for(dataset_name)
        ids = [int(index) for index in indices]
        if not ids:
            return []
        return model.query.filter(model.id.in_(ids)).all()

    def get_document_count(self, dataset_name):
        """Return the number of rows stored for the dataset."""
        return self._model_for(dataset_name).query.count()

def clean_value(value):
    """Strip wiki-style quote emphasis from ``value`` and tidy whitespace.

    Text wrapped in runs of five, three or two single quotes is unwrapped
    (longest run first), then every whitespace run is collapsed to a single
    space and the result is trimmed.
    """
    emphasis_patterns = (
        r"'''''(.*?)'''''",
        r"'''(.*?)'''",
        r"''(.*?)''",
    )
    unwrapped = value
    for pattern in emphasis_patterns:
        unwrapped = re.sub(pattern, r"\1", unwrapped)
    return re.sub(r"\s+", " ", unwrapped).strip()

## Model Service

In [10]:
# ModelService class: Manages saving/loading of models
class ModelService:
    """Save and load fitted vectorizers and TF-IDF matrices with joblib.

    Every failure is raised as ``ValueError`` carrying a readable message.
    Errors are raised rather than returned so a caller can never mistake an
    error string for a loaded model.
    """

    def __init__(self, model_directory):
        self.model_directory = model_directory

    def _path(self, name, kind):
        """Return the file path for model ``name`` of the given ``kind``."""
        return f'{self.model_directory}/{name}_{kind}.pkl'

    def _dump(self, obj, name, kind):
        """Write ``obj`` to its model file, creating the directory if needed."""
        path = self._path(name, kind)
        try:
            os.makedirs(self.model_directory, exist_ok=True)
            joblib.dump(obj, path)
        except Exception as err:
            # Normalised to ValueError: the one type callers handle.
            raise ValueError(f"{name}_{kind}.pkl can't be saved: {err}") from err

    def _load(self, name, kind):
        """Read and return the object stored in the model file."""
        path = self._path(name, kind)
        try:
            # NOTE(security): joblib.load unpickles the file and can execute
            # arbitrary code; only load files this application wrote itself.
            return joblib.load(path)
        except Exception as err:
            # Missing, truncated and corrupt files all end up here.
            raise ValueError(f"{name}_{kind}.pkl can't be loaded: {err}") from err

    def save_vectorizers(self, vectorizers, name):
        """Persist ``vectorizers``. Raises ``ValueError`` on failure."""
        self._dump(vectorizers, name, 'vectorizers')

    def save_tfidf_matrices(self, tfidf_matrices, name):
        """Persist ``tfidf_matrices``. Raises ``ValueError`` on failure."""
        self._dump(tfidf_matrices, name, 'tfidf_matrices')

    def load_vectorizers(self, name):
        """Return the stored vectorizers. Raises ``ValueError`` on failure."""
        return self._load(name, 'vectorizers')

    def load_tfidf_matrices(self, name):
        """Return the stored matrices. Raises ``ValueError`` on failure."""
        return self._load(name, 'tfidf_matrices')


## Data Loading Service

In [13]:

def _trec_tot_doc_data(doc):
    """Build the column dict for one TREC-TOT document."""
    classes = doc.wikidata_classes
    # Some documents may carry no class at all; store '' instead of raising
    # IndexError on ``classes[0]``.
    first_class = classes[0][1] if classes else ""

    # Each section / infobox value is appended on its own "\n "-prefixed line.
    section_lines = [f"\n {section_text}" for section_text in doc.sections.values()]

    infobox_lines = []
    for infobox in doc.infoboxes:
        for value in infobox['params'].values():
            infobox_lines.append(f"\n {clean_value(value)}")

    return {
        "page_title": doc.page_title,
        "wikidata_classes": first_class,
        "text": doc.text,
        "sections": "".join(section_lines),
        "infoboxes": "".join(infobox_lines),
    }


def _webis_touche_doc_data(doc):
    """Build the column dict for one Webis-Touche document."""
    return {
        "title": doc.title,
        "text": doc.text,
    }


def load_dataset(dataset_name):
    """Download a dataset with ``ir_datasets`` and store it in the database.

    Tables are created when missing. Every document of the dataset is
    converted to a column dict and saved through the document service.

    Raises:
        ValueError: if ``dataset_name`` is unknown, or if the dataset's
            table already contains data.
    """
    if dataset_name == "trec-tot":
        dataset = ir_datasets.load("trec-tot/2023/train")
        build_doc_data = _trec_tot_doc_data
    elif dataset_name == "webis-touche":
        dataset = ir_datasets.load("beir/webis-touche2020")
        build_doc_data = _webis_touche_doc_data
    else:
        raise ValueError('Invalid dataset name')

    with app.app_context():
        if not document_service.check_table_exists(dataset_name):
            db.create_all()   # Create tables if they don't exist

        if document_service.check_data_exists(dataset_name):
            raise ValueError(f'{dataset_name} dataset already exists')

        for doc in dataset.docs_iter():
            document_service.save_document(build_doc_data(doc), dataset_name)


## Flask App

In [14]:
# Flask app setup
app = Flask(__name__)
# The connection string can be overridden through the environment so that
# credentials do not have to be edited into the code. The fallback is the
# original local-development database.
app.config['SQLALCHEMY_DATABASE_URI'] = os.environ.get(
    'SQLALCHEMY_DATABASE_URI',
    'mysql://root:@localhost/ir_search_engine',
)
db.init_app(app)

# Instantiate services (shared by the route handlers below)
document_service = DocumentService(db)
preprocessor = Preprocessor()
spell_corrector = SpellCorrector()
model_service = ModelService(model_directory='models')
search_engine = SearchEngine(preprocessor, spell_corrector, document_service, model_service)

@app.route('/load_index_save_dataset', methods=['POST'])
def load_index_save_dataset_route():
    """Load a dataset into the database.

    Expects a JSON object body with a non-empty string ``dataset_name``.
    Responds 200 on success and 400 for a malformed request, an unknown
    dataset, or a dataset that is already loaded.
    """
    if not request.is_json:
        return jsonify({'message': 'Request must be JSON'}), 400

    # silent=True: a body that fails to parse yields None instead of an
    # unhandled error, so the client gets the JSON error below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'message': 'Request body must be a JSON object'}), 400

    dataset_name = data.get('dataset_name', '')
    if not isinstance(dataset_name, str) or not dataset_name.strip():
        return jsonify({'message': 'Dataset name is required'}), 400
    dataset_name = dataset_name.strip()

    try:
        load_dataset(dataset_name)
    except ValueError as e:
        return jsonify({'message': str(e)}), 400

    # TODO: indexing (search_engine.index_documents) and model persistence
    # (search_engine.save_model) are not run by this route, so the message
    # only reports what actually happened.
    return jsonify({'message': f'{dataset_name} dataset loaded successfully'}), 200

@app.route('/search', methods=['GET'])
def search():
    """Search a dataset and return one page of ranked results.

    Query parameters:
        query: search text (required).
        type_dataset: ``1`` for trec-tot, ``2`` for webis-touche (required).
        weights: optional repeated float parameter, one weight per field.
        page: 1-based page number (default 1).
        limit: results per page (default 10).

    Responds 400 with a ``message`` for invalid input, a dataset that is not
    loaded, a model that cannot be loaded, or a query that cannot be
    processed.
    """
    query = request.args.get('query', '').strip()
    type_dataset = request.args.get('type_dataset', '').strip()
    weights = request.args.getlist('weights', type=float)
    page = request.args.get('page', default=1, type=int)
    limit = request.args.get('limit', default=10, type=int)

    if not query:
        return jsonify({
            'message': 'Query parameter is required'
        }), 400

    if not type_dataset:
        return jsonify({
            'message': 'Type dataset parameter is required'
        }), 400

    dataset_names = {'1': 'trec-tot', '2': 'webis-touche'}
    dataset_name = dataset_names.get(type_dataset)
    if dataset_name is None:
        return jsonify({'message': 'Type dataset not valid'}), 400

    # Non-positive values would turn the slice below into a negative or
    # empty range and silently return the wrong page.
    if page < 1 or limit < 1:
        return jsonify({'message': 'page and limit must be positive integers'}), 400

    if not document_service.check_table_exists(dataset_name) or not document_service.check_data_exists(dataset_name):
        return jsonify({
            'message': f'{dataset_name} Dataset is not loaded. Please load the dataset first.'
        }), 400

    # load_model returns a (response, status) pair, which is always truthy;
    # the status code is what tells whether loading worked.
    load_result = search_engine.load_model(dataset_name)
    load_failed = not load_result or (
        isinstance(load_result, tuple) and load_result[1] != 200
    )
    if load_failed:
        return jsonify({'message': 'Model is not loaded. Please index the documents first.'}), 400

    ranked_indices, scores, corrected_query, error = search_engine.search(query, dataset_name, weights if weights else None)
    if error:
        return jsonify({'message': error}), 400

    start_index = (page - 1) * limit
    end_index = start_index + limit
    page_indices = [int(index) for index in ranked_indices[start_index:end_index]]

    # Ranked indices are row positions in the TF-IDF matrices, i.e. positions
    # in the document list used for indexing - not primary keys. The same
    # list is fetched again so each position resolves to the document it was
    # built from.
    # TODO: persist the indexed document ids with the model so only the
    # current page has to be fetched.
    indexed_docs = document_service.get_documents(dataset_name)
    title_field = 'page_title' if dataset_name == "trec-tot" else 'title'

    results = []
    for index in page_indices:
        if index >= len(indexed_docs):
            continue
        doc = indexed_docs[index].to_dict()
        result = {'title': doc[title_field]}
        if dataset_name == "trec-tot":
            result['wikidata_classes'] = doc['wikidata_classes']
        result['text_snippet'] = ' '.join((doc['text'] or '').split()[:30])
        result['similarity_score'] = f'{scores[index]:.4f}'
        results.append(result)

    response = {
        'corrected_query': corrected_query if corrected_query.lower() != query.lower() else None,
        'results': results
    }

    return jsonify(response)

# Start Flask's built-in server only when this file is executed directly,
# with debug mode off.
if __name__ == '__main__':
    app.run(debug=False)

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [01/Jun/2024 14:16:04] "POST /load_index_save_dataset HTTP/1.1" 200 -
