# ingest.search

> full-text search engine

In [None]:
# | default_exp ingest.search

In [None]:
# | hide
from nbdev.showdoc import *

In [None]:
# | export

import json
import os
import warnings
from typing import Dict, List, Optional, Sequence
import math

from whoosh import index
from whoosh.analysis import StemmingAnalyzer
from whoosh.fields import *
from whoosh.filedb.filestore import RamStorage
from whoosh.qparser import MultifieldParser
from langchain_core.documents import Document
import uuid
from tqdm import tqdm

# ------------------------------------------------------------------------------
# IMPORTANT: Metadata keys of the langchain_core.documents.Document objects
#            passed to SearchEngine.index_documents should ideally match the
#            schema field names below, but this is not strictly required.
#
#            page_content is the only field a Document must supply. Every other
#            field, including the dynamic (suffix-matched) fields declared at
#            the end of this cell, is optional.
# ------------------------------------------------------------------------------

DEFAULT_SCHEMA = Schema(
    # required
    page_content=TEXT(stored=True),
    # identity / provenance
    id=ID(stored=True, unique=True),
    source=KEYWORD(stored=True, commas=True),
    source_search=TEXT(stored=True),
    filepath=KEYWORD(stored=True, commas=True),
    filepath_search=TEXT(stored=True),
    filename=KEYWORD(stored=True),
    # content flags
    ocr=BOOLEAN(stored=True),
    table=BOOLEAN(stored=True),
    markdown=BOOLEAN(stored=True),
    # document / file properties
    page=NUMERIC(stored=True),
    document_title=TEXT(stored=True),
    md5=KEYWORD(stored=True),
    mimetype=KEYWORD(stored=True),
    extension=KEYWORD(stored=True),
    filesize=NUMERIC(stored=True),
    createdate=DATETIME(stored=True),
    modifydate=DATETIME(stored=True),
    # user annotations
    tags=KEYWORD(stored=True, commas=True),
    notes=TEXT(stored=True),
    msg=TEXT(stored=True),
)

# Dynamic fields: any field name matching one of these glob patterns is
# accepted and indexed with the paired field type (added in this order).
for _glob_pattern, _glob_fieldtype in (
    ("*_t", TEXT(stored=True)),                   # free text
    ("*_k", KEYWORD(stored=True, commas=True)),   # comma-separated keywords
    ("*_b", BOOLEAN(stored=True)),                # booleans
    ("*_n", NUMERIC(stored=True)),                # numbers
    ("*_d", DATETIME(stored=True)),               # dates/times
):
    DEFAULT_SCHEMA.add(_glob_pattern, _glob_fieldtype, glob=True)
# Don't leak the loop variables into the module namespace.
del _glob_pattern, _glob_fieldtype


def default_schema():
    """
    Return a fresh copy of `DEFAULT_SCHEMA`.

    A deep copy is returned so that a caller may modify the schema it receives
    (e.g. add fields with `Schema.add`) without silently changing the shared
    module-level `DEFAULT_SCHEMA` used by every index created afterwards.
    """
    import copy

    return copy.deepcopy(DEFAULT_SCHEMA)


class SearchEngine:
    """
    Full-text search engine backed by a whoosh index.

    The index is stored on disk under *index_path* when one is supplied;
    otherwise an in-memory index built from `default_schema()` is used.
    """

    def __init__(
        self,
        index_path: Optional[str] = None,  # path to folder where search index is stored
        index_name: Optional[str] = None,  # name of index
    ):
        """
        Initializes full-text search engine

        Opens the index *index_name* in *index_path*, creating it (with the
        default schema) if it does not exist yet. Without *index_path*, an
        in-memory only index is created and a warning is emitted.

        Raises ValueError if *index_path* is supplied without *index_name*.
        """
        self.index_path = index_path
        self.index_name = index_name
        if index_path and not index_name:
            raise ValueError('index_name is required if index_path is supplied')
        if index_path:
            if not index.exists_in(index_path, indexname=index_name):
                self.ix = __class__.initialize_index(index_path, index_name)
            else:
                self.ix = index.open_dir(index_path, indexname=index_name)
        else:
            warnings.warn(
                "No index_path was supplied, so an in-memory only index "
                "was created using DEFAULT_SCHEMA"
            )
            self.ix = RamStorage().create_index(default_schema())

    @classmethod
    def index_exists_in(cls, index_path: str, index_name: Optional[str] = None):
        """
        Returns True if index exists with name, *indexname*, and path, *index_path*.
        """
        return index.exists_in(index_path, indexname=index_name)

    @classmethod
    def initialize_index(
        cls, index_path: str, index_name: str, schema: Optional[Schema] = None
    ):
        """
        Initialize index

        **Args**

        - *index_path*: path to folder storing search index (created if missing)
        - *index_name*: name of index
        - *schema*: optional whoosh.fields.Schema object.
                    If None, DEFAULT_SCHEMA is used

        Raises ValueError if an index named *index_name* already exists in
        *index_path*.
        """
        # Explicit None check: a Schema object can be falsy (it has a length),
        # which must not be mistaken for "no schema supplied".
        if schema is None:
            schema = default_schema()

        if index.exists_in(index_path, indexname=index_name):
            raise ValueError(
                f"There is already an existing index named {index_name}  with path {index_path} \n"
                + f"Delete {index_path} manually and try again."
            )
        os.makedirs(index_path, exist_ok=True)
        return index.create_in(index_path, indexname=index_name, schema=schema)

    def doc2dict(self, doc: Document):
        """
        Convert LangChain Document to expected format

        Returns a dict of field values for whoosh:

        - metadata keys that are stored schema fields are kept as-is;
        - other keys are suffixed by value type (bool -> `_b`, str -> `_k`,
          or `_d` when the key ends in `_date`, int/float -> `_n`), unless the
          key already carries that suffix; values of other types are dropped;
        - `id` is always a fresh random hex UUID (any metadata `id` is replaced);
        - `page_content` comes from `doc.page_content`;
        - `source`/`filepath` are also copied to `source_search`/`filepath_search`.
        """
        return self._doc2dict(doc, set(self.ix.schema.stored_names()))

    def _doc2dict(self, doc: Document, stored_names: set) -> dict:
        """Implementation of `doc2dict` using a precomputed set of stored field names."""
        d = {}
        for key, value in doc.metadata.items():
            suffix = None
            if key in stored_names:
                suffix = ''
            elif isinstance(value, bool):  # checked before int: bool is an int subclass
                suffix = '' if key.endswith('_b') else '_b'
            elif isinstance(value, str):
                if key.endswith('_date'):
                    suffix = '_d'
                else:
                    suffix = '' if key.endswith('_k') else '_k'
            elif isinstance(value, (int, float)):
                suffix = '' if key.endswith('_n') else '_n'
            if suffix is not None:
                d[key + suffix] = value
        d['id'] = uuid.uuid4().hex
        d['page_content'] = doc.page_content
        # Tokenized (TEXT) copies of the KEYWORD fields so they are searchable by word.
        if 'source' in d:
            d['source_search'] = d['source']
        if 'filepath' in d:
            d['filepath_search'] = d['filepath']
        return d

    def index_documents(
        self,
        docs: Sequence[Document],  # list of LangChain Documents
        verbose: bool = True,  # Set to False to disable progress bar
    ):
        """
        Indexes documents.

        All documents go through a single writer that is committed (with
        optimization) once at the end. If anything fails part-way, the writer
        is cancelled, discarding the pending batch, and the exception is re-raised.
        """
        # Look the schema's stored field names up once, not once per document.
        stored_names = set(self.ix.schema.stored_names())
        writer = self.ix.writer()
        try:
            for doc in tqdm(docs, total=len(docs), disable=not verbose):
                writer.update_document(**self._doc2dict(doc, stored_names))
        except BaseException:
            writer.cancel()
            raise
        writer.commit(optimize=True)

    def clear_index(self, confirm=True):
        """
        Clears index

        Removes all documents by recreating the index, empty, with
        `default_schema()`. Original documents on the file system are untouched.

        **Args**

        - *confirm*: if True, ask for interactive confirmation first; only an
                     answer of "y" or "Y" proceeds.

        Returns True if the index was cleared, False otherwise.
        """
        if confirm:
            msg = (
                "You are about to remove all documents from the search index. "
                "(Original documents on file system will remain.) Are you sure?"
            )
            if input("%s (y/N) " % msg).strip().lower() != "y":
                return False
        if not self.index_path:
            # In-memory index (created in __init__): there is no directory to
            # pass to index.exists_in(), so swap in a fresh empty in-memory index.
            self.ix = RamStorage().create_index(default_schema())
            return True
        if not index.exists_in(self.index_path, indexname=self.index_name):
            return False
        # Point self.ix at the recreated index.
        self.ix = index.create_in(
            self.index_path,
            indexname=self.index_name,
            schema=default_schema(),
        )
        return True

    def get_index_size(self) -> int:
        """
        Gets size of index (total number of documents)
        """
        return self.ix.doc_count_all()

    def get_all_docs(self):
        """
        Returns a generator to iterate through all indexed documents

        Each item is the dict of stored fields of one document. The searcher
        used for iteration is closed when the generator is exhausted or closed.
        """
        with self.ix.searcher() as searcher:
            yield from searcher.documents()

    def get_doc(self, id: str):
        """
        Get an indexed record by ID, or None if no record has that ID
        """
        r = self.query(f'id:{id}')
        return r['hits'][0] if len(r['hits']) > 0 else None

    def query(
        self,
        q: str,
        fields: Sequence = ("page_content",),
        highlight: bool = True,
        limit: int = 10,
        page: int = 1,
    ) -> Dict:
        """
        Queries the index

        **Args**

        - *q*: the query string
        - *fields*: a list of fields to search (a single field name is also accepted)
        - *highlight*: If True, highlight hits
        - *limit*: results per page (must be >= 1)
        - *page*: page of hits to return (1-based)

        **Returns**

        A dict with keys:

        - *hits*: list of dicts, one per hit on the requested page, holding the
          hit's stored fields plus, when *highlight* is True, an `hl_<field>`
          entry for each searched field whose value is a non-empty string
        - *total_hits*: total number of documents matching *q* across all pages
        """
        if isinstance(fields, str):
            fields = [fields]
        fields = list(fields)
        if limit < 1:
            raise ValueError("limit must be >= 1")

        search_results = []
        with self.ix.searcher() as searcher:
            parsed_query = MultifieldParser(fields, schema=self.ix.schema).parse(q)
            if page == 1:
                results = searcher.search(parsed_query, limit=limit)
                # len() counts every match; scored_length() is capped at *limit*.
                total_hits = len(results)
            else:
                results = searcher.search_page(parsed_query, page, limit)
                total_hits = results.total
            # Requests past the last page yield no hits.
            if page > math.ceil(total_hits / limit):
                results = []
            for hit in results:
                d = dict(hit)
                if highlight:
                    for f in fields:
                        try:
                            value = hit[f]
                        except KeyError:  # field not present on this hit
                            continue
                        if value and isinstance(value, str):
                            d['hl_' + f] = hit.highlights(f) or value
                search_results.append(d)

        return {'hits': search_results, 'total_hits': total_hits}



In [None]:
# | hide
import nbdev

# Export the notebook's `# | export` cells to the Python module named by `default_exp` (hidden from the rendered docs).
nbdev.nbdev_export()