# Utils


In [None]:
# | default_exp indexing.utils

In [None]:
# | export

from dreamai.imports import *
from langchain_ray.imports import *
from langchain_ray.utils import *
from langchain_ray.chains import *

In [None]:
#| hide

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

In [None]:
# | export


def index_names(index_folder, index_name):
    """Return the stems of `*.faiss` files in `index_folder` that start with `index_name`.

    Prefix matching returns `index_name` itself plus any longer names sharing
    that prefix (e.g. numbered variants such as `index_1`). Directory listing
    order is filesystem-dependent, so the stems are sorted to give callers a
    stable order to pair against (e.g. queries in `search_faiss`).
    """
    return sorted(
        f.stem for f in Path(index_folder).glob("*.faiss") if f.stem.startswith(index_name)
    )


def docs_to_ems(docs, ems_model):
    """Embed the `page_content` of every document with `ems_model`.

    The texts are encoded in a single `ems_model.encode` call and the result is
    converted with `.tolist()` (assumes `encode` returns an array-like such as a
    numpy array — TODO confirm for each model used).
    """
    texts = [document.page_content for document in docs]
    encoded = ems_model.encode(texts)
    return encoded.tolist()


def add_ems_to_docs(docs, ems_model, key="embeddings"):
    """Compute embeddings for `docs` and attach them via `add_docs_metadata`.

    `docs_to_ems` is bound to `ems_model` and handed to `add_docs_metadata`
    together with `key`; the helper's return value is passed straight through.
    (Presumably the embeddings end up in each document's metadata under `key` —
    verify against `add_docs_metadata`.)
    """
    return add_docs_metadata(
        docs,
        partial(docs_to_ems, ems_model=ems_model),
        key,
    )


def docs_to_faiss(
    docs,
    ems_model,
    index_folder="/media/hamza/data2/faiss_data/saved_indexes/",
    index_name="index",
):
    """Create a `FAISS` database from `Documents` and save it under an unused name.

    `docs` is flattened with `flatten_list` before being embedded with
    `ems_model`. If `<index_folder>/<index_name>.faiss` already exists,
    `find_alternate_path` chooses another path and the database is saved under
    that path's stem instead.

    Returns:
        `docs`, unchanged — the database is only persisted to disk (presumably
        so this can sit in a pipeline of document transforms).
    """
    db = FAISS.from_documents(flatten_list(docs), ems_model)
    # Append the suffix rather than using `with_suffix`, which would replace
    # anything after a dot in `index_name` (e.g. "idx.v2" -> "idx.faiss") and
    # so check for, and save under, the wrong name.
    index_path = Path(index_folder) / f"{index_name}.faiss"
    if index_path.exists():
        # Never overwrite an existing index; pick an unused path instead.
        index_path = find_alternate_path(index_path, first_idx=1, verbose=False)
    db.save_local(index_folder, index_path.stem)
    return docs


def _broadcast_search_args(*seqs):
    """Stretch length-1 lists so all `seqs` share one length (all empty if any is empty)."""
    lengths = [len(s) for s in seqs]
    if 0 in lengths:
        # Nothing to pair with — same outcome as zipping with an empty list.
        return [[] for _ in seqs]
    n = max(lengths)
    out = []
    for s in seqs:
        if len(s) == n:
            out.append(s)
        elif len(s) == 1:
            out.append(s * n)
        else:
            raise ValueError(
                f"Cannot pair search arguments of lengths {lengths}; "
                "each must have the same length or length 1."
            )
    return out


def search_faiss(index_folder, index_names, query, ems_model, filter=None, k=2):
    """Run similarity searches against one or more saved `FAISS` indexes.

    `index_folder`, `index_names` and `query` may each be a single str/Path or a
    sequence. They are paired element-wise; a single value (or one-element
    sequence) is broadcast to match the others, so one folder can be searched
    across several index names, or one query across several indexes. Sequences
    of any other mismatched lengths raise `ValueError` instead of being
    silently truncated.

    Args:
        index_folder: Folder(s) passed to `FAISS.load_local`.
        index_names: Index name(s) to load from the folder(s).
        query: Query text(s).
        ems_model: Embedding model passed to `FAISS.load_local` as `embeddings`.
        filter: Optional filter forwarded to `similarity_search_with_score`.
        k: Hits per search; if iterable, its first element is used (as an int).

    Returns:
        `[[q_sims]]`, where `q_sims` has one `similarity_search_with_score`
        result per (folder, index, query) triple. The extra nesting is kept
        unchanged for existing callers.

    Raises:
        ValueError: if the argument lengths cannot be broadcast together.
    """
    # Wrap scalar (str / Path) arguments so everything below works on lists.
    folders = [index_folder] if path_or_str(index_folder) else list(index_folder)
    names = [index_names] if path_or_str(index_names) else list(index_names)
    queries = [query] if path_or_str(query) else list(query)
    if is_iter(k):
        # Batched callers may pass k as a sequence; use its first value.
        k = int(k[0])

    folders, names, queries = _broadcast_search_args(folders, names, queries)

    q_sims = []
    for folder, name, q in zip(folders, names, queries):
        db = FAISS.load_local(folder, embeddings=ems_model, index_name=name)
        q_sims.append(db.similarity_search_with_score(q, filter=filter, k=k))
    return [[q_sims]]

In [None]:
# | hide
from nbdev import nbdev_export

# Run nbdev's export step so the `# | export` cells above are written out to
# the library module declared by `# | default_exp`.
nbdev_export()