# Core

> Core functionality for indexing and searching.

In [None]:
# | default_exp index.core


In [None]:
# | export

from dreamai_ray.imports import *
from dreamai_ray.utils import *
from dreamai_ray.mapper import *
from dreamai_ray.index.utils import *
from dreamai_ray.index.df import *


In [None]:
#| hide

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

In [None]:
# | export


class write_index_cb(Callback):
    """A `Callback` that persists the current index and its batch dataframe to disk.

    After each batch the index stored in `cls.udf_kwargs["index"]` is written to
    `<index_folder>/<block_counter>_<ntotal>.faiss` and the batch dataframe
    (passed as the `df` keyword) to `<index_folder>/<block_counter>.csv`.
    """

    def after_batch(self, cls, *args, **kwargs):
        # Sync the mapper's `index` attribute with the one the udf has been filling.
        cls.index = cls.udf_kwargs["index"]
        out_dir = Path(cls.index_folder)
        os.makedirs(out_dir, exist_ok=True)
        block_no = cls.block_counter
        faiss_file = str(out_dir / f"{block_no}_{cls.index.ntotal}.faiss")
        csv_file = str(out_dir / f"{block_no}.csv")
        # Only report when both the callback and the mapper are verbose.
        if self.verbose and cls.verbose:
            for line in (
                f"Writing Index to {faiss_file}",
                f"Index Size: {cls.index.ntotal}",
                f"Writing DF to {csv_file}",
            ):
                msg.info(line)
        faiss.write_index(cls.index, faiss_file)
        # Row labels are dropped so the csv rows are numbered from 0 within the block.
        batch_df = kwargs["df"].reset_index(drop=True)
        batch_df.to_csv(csv_file, index=False)


class reset_index_cb(Callback):
    """A `Callback` that resets the index after each batch and re-binds it to the udf."""

    def after_batch(self, cls, **kwargs):
        idx = cls.index
        idx.reset()
        should_log = self.verbose and cls.verbose
        if should_log:
            msg.info(f"Index Size Post Reset: {idx.ntotal}")
        # Hand the reset index back to the udf so the next batch starts from it.
        cls.udf_kwargs.update(index=idx)
        cls.udf = partial(cls.udf, **cls.udf_kwargs)


class IndexCreator(Mapper):
    """Creates indexes from embeddings.

    Sets up the index (a new one from `create_index(index_dim)` unless one is
    supplied), injects `index`, `ems_col` and `verbose` into the udf kwargs,
    prepends a `block_counter_cb` to the callbacks and then defers to
    `Mapper.__init__` with every local of this constructor.
    """

    def __init__(
        self,
        index=None,  # An existing index to use. A new one is created from `index_dim` when None.
        index_dim=3,  # The dimension of the index.
        index_folder="indexes",  # The folder to write the index to.
        ems_col="embedding",  # The column to use to create the index.
        udf=df_to_index,  # The function to use to create the index.
        cbs=[write_index_cb, reset_index_cb],  # The `Callback`s to use.
        block_counter=0,  # The starting block counter.
        verbose=True,  # Whether to print out information.
        udf_verbose=False,  # Whether to print out information in the udf.
        udf_kwargs={},  # Additional kwargs to pass to the udf.
        **kwargs,
    ):
        self.index_folder = index_folder
        self.index = create_index(index_dim) if index is None else index
        # Build a fresh dict instead of writing into the argument: the `{}` default is
        # shared by every call, and a caller-supplied dict must not silently gain
        # `index`/`ems_col`/`verbose` entries. The name is rebound (no new local is
        # introduced) because `locals()` is forwarded to the parent below.
        udf_kwargs = {
            **udf_kwargs,
            "index": self.index,
            "ems_col": ems_col,
            "verbose": udf_verbose,
        }
        self.verbose = verbose
        # `list(cbs)` copies the (shared default) list and also accepts tuples.
        cbs = [block_counter_cb(block_counter)] + list(cbs)
        super().__init__(**locals_to_params(locals()))


def create_indexes_setup(
    task_id,  # Id of this run; scratch files live under `/tmp/{task_id}`.
    ems_folder,  # The folder containing the embedding `.json` files.
    index_folder="indexes",  # The folder to write the index to.
    local_index_folder="/media/hamza/data2/faiss_data/saved_indexes",  # Parent of the local index copy.
    ems_col="embedding",  # Name of the dataframe column holding the embedding file paths.
    num_blocks=1,  # The number of blocks to make.
    block_size=40000,  # The size of each block. Ignored if `num_blocks` is not None.
    block_counter=0,  # The starting block counter.
    verbose=True,  # Whether to print out information.
    *args,
    **kwargs,
):
    """Prepare everything needed to build indexes block by block.

    Resolves the embeddings folder, lists its `.json` files ordered by the
    integer after the last `_` in each file stem, and wraps the paths in a
    one-column dataframe. The block counter is advanced by the number of
    blocks that already exist, and the block size is derived from
    `num_blocks` when that is given.

    Returns the tuple `(df, block_size, block_counter, task_folder,
    index_folder, index_bucket, local_index_folder)`.

    Raises `ValueError` if `num_blocks` is given but is not positive.
    """
    task_folder = f"/tmp/{task_id}"
    t1 = time()
    ems_folder, _ = handle_input_path(ems_folder, local_path=task_folder)
    em_files = sorted(
        get_files(ems_folder, extensions=[".json"], make_str=True),
        key=lambda x: int(Path(x).stem.split("_")[-1]),
    )

    t2 = time()
    msg.info(
        f"Embeddings download time: {t2-t1:.2f} seconds.",
        spaced=True,
        show=verbose,
    )
    df = pd.DataFrame({ems_col: em_files})
    msg.info(f"Embeddings DF created of length: {len(df)}", spaced=True, show=verbose)

    index_folder, index_bucket = get_local_path(index_folder, local_path=task_folder)
    local_index_folder = Path(local_index_folder) / Path(index_folder).name
    if local_index_folder.exists():
        msg.info(f"Local Index Folder Exists: {local_index_folder}", spaced=True, show=verbose)
        # One `.faiss` file per block that was already written.
        bucket_size = len(get_files(local_index_folder, extensions=[".faiss"]))
    else:
        # Halved presumably because each block stores two objects (.faiss + .csv) —
        # NOTE(review): confirm against `bucket_count`.
        bucket_size = max(bucket_count(index_bucket), 0) // 2
    block_counter += bucket_size
    if num_blocks is not None:
        if num_blocks < 1:
            raise ValueError(f"`num_blocks` must be a positive integer, got {num_blocks}.")
        # Floor division gives 0 when there are fewer files than blocks (or none at
        # all); a zero block size would make the `range(0, len(df), block_size)`
        # that drives block creation raise ValueError, so never go below 1.
        block_size = max(1, len(df) // num_blocks)
    msg.info(f"Bucket Size: {bucket_size}", spaced=True, show=verbose)
    msg.info(f"Block Size: {block_size}", spaced=True, show=verbose)
    msg.info(f"Block Counter: {block_counter}", spaced=True, show=verbose)

    return (
        df,
        block_size,
        block_counter,
        task_folder,
        index_folder,
        index_bucket,
        local_index_folder,
    )


def create_indexes_iter(
    ems_folder="embeddings",  # The folder containing the embeddings.
    ems_col="embedding",  # The column to use to create the index.
    index_dim=768,  # The dimension of the index.
    index_folder="indexes",  # The folder to write the index to.
    local_index_folder="/media/hamza/data2/faiss_data/saved_indexes",  # The local folder to write the index to.
    num_blocks=1,  # The number of blocks to make.
    block_size=40000,  # The size of each block. Will be ignored if `num_blocks` is not None.
    block_counter=0,  # The starting block counter.
    verbose=True,  # Whether to print out information.
    udf_verbose=False,  # Whether to print out information in the udf.
    udf_kwargs={},  # Additional kwargs to pass to the udf.
    task_id=None,  # The task id to use. A random 16-character id is generated when None.
    *args,
    **kwargs,
):
    """Yield one work item per block of embedding files.

    This is a generator: nothing (including the setup) runs until it is
    iterated. Each yielded dict has the keys `df` (the slice of the
    embeddings dataframe for the block), `cls_kwargs` (keyword arguments for
    `IndexCreator`) and `fn_kwargs` (folder information consumed by
    `create_indexes_combine`).
    """
    # A default of `gen_random_string(16)` in the signature is evaluated only once,
    # when the function is defined, so every call would share one task folder.
    if task_id is None:
        task_id = gen_random_string(16)
    t1 = time()
    (
        df,
        block_size,
        block_counter,
        task_folder,
        index_folder,
        index_bucket,
        local_index_folder,
    ) = create_indexes_setup(
        task_id=task_id,
        ems_folder=ems_folder,
        index_folder=index_folder,
        local_index_folder=local_index_folder,
        ems_col=ems_col,
        num_blocks=num_blocks,
        # Forward the caller's values; without these the setup silently fell back
        # to its own defaults and both arguments of this function had no effect.
        block_size=block_size,
        block_counter=block_counter,
        verbose=verbose,
    )
    t2 = time()
    msg.info(f"Setup time: {t2-t1:.2f} seconds.", spaced=True, show=verbose)

    for start in range(0, len(df), block_size):
        cls_kwargs = dict(
            index_dim=index_dim,
            index_folder=index_folder,
            ems_col=ems_col,
            block_counter=block_counter,
            verbose=verbose,
            udf_verbose=udf_verbose,
            udf_kwargs=udf_kwargs,
        )
        fn_kwargs = dict(
            task_folder=task_folder,
            index_folder=index_folder,
            index_bucket=index_bucket,
            local_index_folder=local_index_folder,
        )
        # The next block starts one counter further on.
        block_counter += 1
        yield dict(
            df=df.iloc[start : start + block_size],
            fn_kwargs=fn_kwargs,
            cls_kwargs=cls_kwargs,
        )


def create_index_(
    data_dict,
    **kwargs,
):
    """Build the index for one work item produced by `create_indexes_iter`.

    Instantiates an `IndexCreator` from `data_dict["cls_kwargs"]`, runs it on
    `data_dict["df"]` and returns `data_dict["fn_kwargs"]` unchanged.
    """
    block_df, creator_kwargs = data_dict["df"], data_dict["cls_kwargs"]
    IndexCreator(**creator_kwargs)(block_df)
    return data_dict["fn_kwargs"]


def create_indexes_combine(res_list, *args, **kwargs):
    """Publish the freshly written indexes and clean up the task folder.

    Only the first entry of `res_list` is read for the folder information.
    The index folder is copied into the local index folder, uploaded with
    `bucket_up(..., only_new=True)`, and the task folder is then removed.
    Returns a dict with the bucket and local folder as strings.
    """
    first = res_list[0]
    task_folder, index_folder, index_bucket, local_index_folder = (
        first[key]
        for key in ("task_folder", "index_folder", "index_bucket", "local_index_folder")
    )
    # Keep a local copy first, then push to the bucket.
    shutil.copytree(index_folder, local_index_folder, dirs_exist_ok=True)
    bucket_up(index_folder, index_bucket, only_new=True)
    shutil.rmtree(task_folder, ignore_errors=True)
    summary = {}
    summary["index_folder"] = str(index_bucket)
    summary["local_index_folder"] = str(local_index_folder)
    return summary


def create_indexes(
    ems_folder="embeddings",  # The folder containing the embeddings.
    ems_col="embedding",  # The column to use to create the index.
    index_dim=768,  # The dimension of the index.
    index_folder="indexes",  # The folder to write the index to.
    local_index_folder="/media/hamza/data2/faiss_data/saved_indexes",  # The local folder to write the index to.
    num_blocks=1,  # The number of blocks to make.
    block_size=40000,  # The size of each block. Will be ignored if `num_blocks` is not None.
    block_counter=0,  # The starting block counter.
    verbose=True,  # Whether to print out information.
    udf_verbose=False,  # Whether to print out information in the udf.
    udf_kwargs={},  # Additional kwargs to pass to the udf.
    task_id=None,  # The task id to use. A random 16-character id is generated when None.
    *args,
    **kwargs,
):
    """Function to create indexes from embeddings.

    Runs three stages: build the block iterator, create an index per block,
    then combine the results with `create_indexes_combine`. Each stage logs
    its duration, and any exception is logged with `msg.fail` and re-raised.
    Returns the dict produced by `create_indexes_combine`.
    """
    # A `gen_random_string(16)` default in the signature is evaluated once at
    # definition time, which made every call reuse the same task folder.
    if task_id is None:
        task_id = gen_random_string(16)

    t_1 = time()
    try:
        t1 = time()
        # `create_indexes_iter` is a generator, so this only creates the generator
        # object; its setup work runs when the list below starts consuming it.
        iterator = create_indexes_iter(**locals_to_params(locals()))
        t2 = time()
        msg.info(f"Iterator creation time: {t2-t1:.2f} seconds.", spaced=True, show=verbose)
    except Exception as e:
        msg.fail(f"Error creating iterator: {e}", spaced=True, show=verbose)
        raise

    try:
        t1 = time()
        create_res = [create_index_(data_dict) for data_dict in iterator]
        t2 = time()
        msg.info(f"Index creation time: {t2-t1:.2f} seconds.", spaced=True, show=verbose)
    except Exception as e:
        msg.fail(f"Error creating indexes: {e}", spaced=True, show=verbose)
        raise

    try:
        t1 = time()
        combine_res = create_indexes_combine(create_res)
        t2 = time()
        msg.info(f"Index combination time: {t2-t1:.2f} seconds.", spaced=True, show=verbose)
    except Exception as e:
        msg.fail(f"Error combining indexes: {e}", spaced=True, show=verbose)
        raise
    t_2 = time()
    msg.good(f"Total creation time: {t_2-t_1:.2f} seconds.", spaced=True, show=verbose)
    return combine_res


def search_indexes_setup(
    task_id,  # Id of this run; the query embedding is downloaded to `/tmp/{task_id}`.
    ems,  # Location of the query embedding json file, fetched with `bucket_dl`.
    index_folder,  # The folder containing the indexes.
    local_index_folder="/media/hamza/data2/faiss_data/saved_indexes",  # Where a cached local copy is looked for.
    verbose=True,  # Whether to print out information.
):
    """Resolve the index folder and load the query embedding.

    A cached copy of the indexes under `local_index_folder` is preferred when
    it exists; otherwise the folder is resolved through `handle_input_path`.
    The query file is downloaded into the task folder and its `"embedding"`
    entry is read. The `.faiss` files are returned ordered by the integer
    before the first `_` in their stem.

    Returns `(ems, indexes, index_folder, task_folder)`.

    Raises `FileNotFoundError` if no embedding json ends up in the task folder
    and `Exception` if the index folder is missing or has no `.faiss` files.
    """
    task_folder = f"/tmp/{task_id}"
    pre_index_folder, _ = get_local_path(index_folder, local_path=local_index_folder)
    if os.path.exists(pre_index_folder):
        msg.info(f"Cached Index Folder: {pre_index_folder}", spaced=True, show=verbose)
        index_folder = pre_index_folder
    else:
        index_folder, _ = handle_input_path(
            index_folder, local_path=local_index_folder, task_id=task_id
        )
    bucket_dl(ems, task_folder)
    json_files = get_files(task_folder, extensions=[".json"])
    if len(json_files) == 0:
        # Fail with a readable message instead of an IndexError on `[0]`.
        raise FileNotFoundError(
            f"No embedding json file found in '{task_folder}' after downloading '{ems}'."
        )
    with open(json_files[0], encoding="utf-8") as f:
        ems = json.load(f)["embedding"]

    no_indexes_msg = (
        f"No indexes found in '{index_folder}' folder. Please create indexes first."
    )
    # Check the folder before listing it, and list it only once.
    if not os.path.exists(index_folder):
        raise Exception(no_indexes_msg)
    indexes = sorted(
        get_files(index_folder, extensions=[".faiss"]),
        key=lambda x: int(x.stem.split("_")[0]),
    )
    if len(indexes) == 0:
        raise Exception(no_indexes_msg)
    return ems, indexes, index_folder, task_folder


def search_indexes_iter(
    ems,  # Location of the query embedding json file; it is fetched with `bucket_dl` during setup.
    index_folder,  # The remote folder containing the indexes.
    local_index_folder="/media/hamza/data2/faiss_data/saved_indexes",  # Not required if `index_folder` is local.
    k=2,  # The number of nearest neighbors to return.
    verbose=True,  # Whether to print out information.
    task_id=None,  # The task id to use. A random 16-character id is generated when None.
    *args,
    **kwargs,
):
    """Yield one search work item per index file.

    This is a generator: the setup only runs once iteration starts. Each
    yielded dict holds a one-row dataframe `qdf` (columns `index` and
    `embedding`) together with `k`, `index_folder`, `task_folder` and
    `verbose`.
    """
    # A `gen_random_string(16)` default in the signature is evaluated once at
    # definition time, which made every call reuse the same task folder.
    if task_id is None:
        task_id = gen_random_string(16)
    ems, indexes, index_folder, task_folder = search_indexes_setup(
        task_id, ems, index_folder, local_index_folder, verbose
    )
    for index in indexes:
        yield {
            "qdf": pd.DataFrame({"index": [index], "embedding": [ems]}),
            "k": k,
            "index_folder": index_folder,
            "task_folder": task_folder,
            "verbose": verbose,
        }


def search_index_(data_dict):
    """Run the search for one work item and store the result back in it.

    Every row of `data_dict["qdf"]` is passed to `df_index_search` with the
    item's `k` and `verbose`; the resulting frame replaces `data_dict["qdf"]`
    and the same dict is returned.
    """
    query_df = data_dict["qdf"]
    k, verbose = data_dict["k"], data_dict["verbose"]

    def _search_row(row):
        return df_index_search(row, k=k, verbose=verbose)

    data_dict["qdf"] = query_df.apply(_search_row, axis=1)
    return data_dict


def search_indexes_combine(res_list, *args, **kwargs):
    """Merge the per-index search results and attach the matching metadata rows.

    `k`, `verbose` and the folders are read from the first entry of
    `res_list`. All `qdf` frames are concatenated and reduced with
    `index_heap(..., with_offset=True)`. The per-block csv files of the index
    folder are concatenated in numeric stem order, and the rows at the
    positions `res["ids"][0]` are stored under `res["meta_data"]` as a list
    of records. The task folder is removed before returning `res`.
    """
    r0 = res_list[0]
    k = r0["k"]
    verbose = r0["verbose"]
    index_folder = r0["index_folder"]
    task_folder = r0["task_folder"]
    qdf = pd.concat([d["qdf"] for d in res_list]).reset_index(drop=True)
    res = index_heap(qdf, k=k, verbose=verbose, with_offset=True)
    csv_files = sorted(get_files(index_folder, extensions=[".csv"]), key=lambda x: int(x.stem))
    meta_df = pd.concat([pd.read_csv(csv_file) for csv_file in csv_files]).reset_index(drop=True)
    # NOTE(review): if `index_heap` can return -1 ids (fewer than `k` vectors in
    # total), `iloc` would pick the last row for them — confirm this cannot happen.
    res["meta_data"] = meta_df.iloc[res["ids"][0]].to_dict(orient="records")
    # Best-effort cleanup, matching `create_indexes_combine`: a missing or
    # half-removed task folder must not discard an otherwise finished search.
    shutil.rmtree(task_folder, ignore_errors=True)
    return res


def search_indexes(
    ems,  # Location of the query embedding json file; it is fetched with `bucket_dl` during setup.
    index_folder,  # The remote folder containing the indexes.
    local_index_folder="/media/hamza/data2/faiss_data/saved_indexes",  # Not required if `index_folder` is local.
    k=2,  # The number of nearest neighbors to return.
    verbose=True,  # Whether to print out information.
    task_id=None,  # The task id to use. A random 16-character id is generated when None.
):
    """Function to search an embedding against indexes.

    Runs three stages: build the per-index iterator, search every index, then
    merge the results with `search_indexes_combine`. Each stage logs its
    duration, and any exception is logged with `msg.fail` and re-raised.
    Returns the dict produced by `search_indexes_combine`.
    """
    # A `gen_random_string(16)` default in the signature is evaluated once at
    # definition time, which made every call reuse the same task folder.
    if task_id is None:
        task_id = gen_random_string(16)

    t_1 = time()
    try:
        t1 = time()
        # `search_indexes_iter` is a generator, so this only creates the generator
        # object; its setup work runs when the list below starts consuming it.
        iterator = search_indexes_iter(**locals_to_params(locals()))
        t2 = time()
        msg.info(f"Iterator creation time: {t2-t1:.2f} seconds.", spaced=True, show=verbose)
    except Exception as e:
        msg.fail(f"Error creating iterator: {e}", spaced=True, show=verbose)
        raise

    try:
        t1 = time()
        search_res = [search_index_(data_dict) for data_dict in iterator]
        t2 = time()
        # These messages previously said "creation"/"creating", copied from `create_indexes`.
        msg.info(f"Index search time: {t2-t1:.2f} seconds.", spaced=True, show=verbose)
    except Exception as e:
        msg.fail(f"Error searching indexes: {e}", spaced=True, show=verbose)
        raise

    try:
        t1 = time()
        combine_res = search_indexes_combine(search_res)
        t2 = time()
        msg.info(f"Search combination time: {t2-t1:.2f} seconds.", spaced=True, show=verbose)
    except Exception as e:
        msg.fail(f"Error combining searches: {e}", spaced=True, show=verbose)
        raise
    t_2 = time()
    msg.good(f"Total search time: {t_2-t_1:.2f} seconds.", spaced=True, show=verbose)
    return combine_res

## Usage Example

In [None]:
# | hide

# if ray.is_initialized():
#     ray.shutdown()
# ray.init()
# ray.data.DataContext.get_current().execution_options.verbose_progress = True


In [None]:
# | hide

# Fix NumPy's global random seed so randomly generated data in this notebook is reproducible.
np.random.seed(42)
# Local data root used by the examples below (machine-specific path).
data_path = Path("/media/hamza/data2/faiss_data")

In [None]:
# # | eval: false

# data_path = Path("")
# ems_folder = data_path / "ems"
# index_folder = data_path / "indexes"
# num_ems = 50
# block_size = 10
# ems_dim = 768
# random_ems(num_ems=num_ems, ems_dim=ems_dim, ems_folder=ems_folder)


In [None]:
# | eval: false

# Bucket that holds both the example embeddings and the indexes built from them.
bucket = "gs://gcsfuse-talentnet-dev"

# Source embeddings, destination for the indexes, and the local folder that keeps a copy.
ems_folder = f"{bucket}/ems_1"
index_folder = f"{bucket}/indexes_1"
local_index_folder = "/media/hamza/data2/faiss_data/saved_indexes"
# Passed as `index_dim` when creating the indexes below.
ems_dim = 768

In [None]:
# | hide
# | eval: false

# Delete the remote index folder so the example starts without previously uploaded indexes.
bucket_del(index_folder)



[38;5;4mℹ Deleting gs://gcsfuse-talentnet-dev/indexes_1.[0m



CommandException: 1 files/objects could not be removed.


In [None]:
# | eval: false

# Build the indexes from every embedding in `ems_folder`.
# `num_blocks=1` puts all embeddings into a single block.
create_res = create_indexes(
    ems_folder=ems_folder,
    index_folder=index_folder,
    local_index_folder=local_index_folder,
    index_dim=ems_dim,
    verbose=True,
    num_blocks=1,
)

# Display the returned dict (bucket folder and local folder of the indexes).
create_res


[38;5;4mℹ Iterator creation time: 0.00 seconds.[0m


[38;5;4mℹ Downloading gs://gcsfuse-talentnet-dev/ems_1 to
/tmp/229c58383dd54f61/ems_1.[0m



Copying gs://gcsfuse-talentnet-dev/ems_1/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_1.json...
Copying gs://gcsfuse-talentnet-dev/ems_1/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_10.json...
Copying gs://gcsfuse-talentnet-dev/ems_1/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_11.json...
Copying gs://gcsfuse-talentnet-dev/ems_1/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_12.json...
Copying gs://gcsfuse-talentnet-dev/ems_1/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_13.json...
Copying gs://gcsfuse-talentnet-dev/ems_1/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_15.json...
Copying gs://gcsfuse-talentnet-dev/ems_1/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_14.json...
Copying gs://gcsfuse-talentnet-dev/ems_1/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_16.json...
Copying gs://gcsfuse-talentnet-dev/ems_1/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_2.json...
Copying gs://gcsfuse-talentnet-dev/ems_1/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_4.json...
Copying gs://gcsfuse-talentnet-de


[38;5;4mℹ Embeddings download time: 2.11 seconds.[0m


[38;5;4mℹ Embeddings DF created of length: 16[0m


[38;5;4mℹ Local Index Folder Exists:
/media/hamza/data2/faiss_data/saved_indexes/indexes_1[0m


[38;5;4mℹ Bucket Size: 4[0m


[38;5;4mℹ Block Size: 16[0m


[38;5;4mℹ Block Counter: 4[0m


[38;5;4mℹ Setup time: 2.11 seconds.[0m


[38;5;4mℹ DF BATCH SIZE: 16[0m


[38;5;4mℹ Index creation time: 2.12 seconds.[0m


[38;5;4mℹ Uploading /tmp/229c58383dd54f61/indexes_1 to
gs://gcsfuse-talentnet-dev/indexes_1.[0m



Copying file:///tmp/229c58383dd54f61/indexes_1/5_16.faiss [Content-Type=application/octet-stream]...
Copying file:///tmp/229c58383dd54f61/indexes_1/5.csv [Content-Type=text/csv]... 
/ [1/2 files][ 49.3 KiB/ 49.3 KiB]  99% Done                                    


[38;5;4mℹ Index combination time: 2.16 seconds.[0m


[38;5;2m✔ Total creation time: 4.28 seconds.[0m



- [2/2 files][ 49.3 KiB/ 49.3 KiB] 100% Done                                    
Operation completed over 2 objects/49.3 KiB.                                     


{'index_folder': 'gs://gcsfuse-talentnet-dev/indexes_1',
 'local_index_folder': Path('/media/hamza/data2/faiss_data/saved_indexes/indexes_1')}

In [None]:
# | eval: false

# Query with one of the embedding files from `ems_folder` and ask for the 5 nearest neighbours.
qems = f"{ems_folder}/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_10.json"
res = search_indexes(
    qems,
    index_folder=index_folder,
    local_index_folder=local_index_folder,
    k=5,
    verbose=True,
)
# Print the distances and ids, followed by one metadata record per result.
print(f'\n\nFinal Results:\n\tDistances: {res["distances"]}\n\tIDs: {res["ids"]}')
print("\tMeta Data:")
for m in res["meta_data"]:
    print(f"\t\t{m}")



[38;5;4mℹ Iterator creation time: 0.00 seconds.[0m


[38;5;4mℹ Cached Index Folder:
/media/hamza/data2/faiss_data/saved_indexes/indexes_1[0m


[38;5;4mℹ Downloading
gs://gcsfuse-talentnet-dev/ems_1/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_10.json
to /tmp/d8f359817d4c4f1e.[0m



Copying gs://gcsfuse-talentnet-dev/ems_1/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_10.json...
/ [0/1 files][    0.0 B/ 16.7 KiB]   0% Done                                    

[38;5;4mℹ Index Col:
/media/hamza/data2/faiss_data/saved_indexes/indexes_1/1_4.faiss[0m
[38;5;4mℹ Index Size: 4[0m
[38;5;4mℹ Ems Shape: (1, 768)[0m
[38;5;2m✔ IDs: [[ 1  3  0  2 -1]], Distances: [[1.1570426e+00 1.2090571e+00
1.2641366e+00 1.5258880e+00 3.4028235e+38]][0m
[38;5;4mℹ Index Col:
/media/hamza/data2/faiss_data/saved_indexes/indexes_1/2_4.faiss[0m
[38;5;4mℹ Index Size: 4[0m
[38;5;4mℹ Ems Shape: (1, 768)[0m
[38;5;2m✔ IDs: [[ 3  0  1  2 -1]], Distances: [[1.2113628e+00 1.2323084e+00
1.2558432e+00 1.2952166e+00 3.4028235e+38]][0m
[38;5;4mℹ Index Col:
/media/hamza/data2/faiss_data/saved_indexes/indexes_1/3_4.faiss[0m
[38;5;4mℹ Index Size: 4[0m
[38;5;4mℹ Ems Shape: (1, 768)[0m
[38;5;2m✔ IDs: [[ 1  3  0  2 -1]], Distances: [[0.0000000e+00 9.2401117e-01
1.0372934e+00 1.2498804e+00 3.4028235e+38]][0m
[38;5;4mℹ Index Col:
/media/hamza/data2/faiss_data/saved_indexes/indexes_1/4_4.faiss[0m
[38;5;4mℹ Index Size: 4[0m
[38;5;4mℹ Ems Shape: (1, 768)[0m
[38;5;2m

/ [1/1 files][ 16.7 KiB/ 16.7 KiB] 100% Done                                    
Operation completed over 1 objects/16.7 KiB.                                     


In [None]:
# # | eval: false

# indexes_iter = create_indexes_iter(
#     ems_folder=ems_folder,
#     index_folder=index_folder,
#     index_dim=ems_dim,
#     verbose=True,
#     block_size=4,
# )

# res = [create_index_(d) for d in indexes_iter]

# indexes_up(res)


In [None]:
# #| eval: false

# qems = f"{ems_folder}/resumes-4e2cdbeb-1e20-45ff-bded-a0a510350167_10.json"
# data_dict = dict(ems=qems, index_folder=index_folder, k=5, verbose=True)
# search_iter = search_indexes_iter(data_dict=data_dict)
# res = [search_index_(d) for d in search_iter]
# res = combine_searches(res, data_dict)

# print(f'\n\nFinal Results:\n\tDistances: {res["distances"]}\n\tIDs: {res["ids"]}')
# print("\tMeta Data:")
# for m in res["meta_data"]:
#     print(f"\t\t{m}")

In [None]:
# | hide
# | eval: false

# Remove the local index copy created by the example; errors (e.g. a missing folder) are ignored.
shutil.rmtree(local_index_folder, ignore_errors=True)


In [None]:
# | hide

import nbdev

# Run nbdev's export step (writes the exported notebook cells to the library modules).
nbdev.nbdev_export()