# Mappers

> Mapper classes for indexing and searching.

In [None]:
# | default_exp index.mappers


In [None]:
# | export

from dreamai_ray.imports import *
from dreamai_ray.utils import *
from dreamai_ray.mapper import *
from dreamai_ray.index.utils import *
from dreamai_ray.index.df import *


In [None]:
#| hide

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

In [None]:
# | export


from dreamai_ray.imports import noop


class write_index_cb(Callback):
    "A `Callback` that saves the mapper's index to `<index_folder>/<block_counter>.faiss`."

    def __init__(self, verbose=False) -> None:
        # Logging happens only when both this flag and the mapper's `verbose` are truthy.
        self.verbose = verbose

    def after_batch(self, cls, **kwargs):
        # `cls` is the object this callback is attached to (it must expose
        # `udf_kwargs`, `index_folder`, `block_counter` and `verbose`).
        # Re-sync its `index` attribute with the one stored in the udf kwargs.
        cls.index = cls.udf_kwargs["index"]

        out_dir = Path(cls.index_folder)
        os.makedirs(out_dir, exist_ok=True)

        # One file per block, named after the current block counter.
        index_path = str(out_dir / f"{cls.block_counter}.faiss")

        should_log = self.verbose and cls.verbose
        if should_log:
            msg.info(f"Writing Index to {index_path}")
            msg.info(f"Index Size: {cls.index.ntotal}")

        faiss.write_index(cls.index, index_path)


class reset_index_cb(Callback):
    "A `Callback` that resets the mapper's index and re-binds it into the udf."

    def __init__(self, verbose=True) -> None:
        # Logging happens only when both this flag and the mapper's `verbose` are truthy.
        self.verbose = verbose

    def after_batch(self, cls, **kwargs):
        # `cls` is the object this callback is attached to (it must expose
        # `index`, `udf`, `udf_kwargs` and `verbose`).
        idx = cls.index
        idx.reset()

        if self.verbose and cls.verbose:
            msg.info(f"Index Size Post Reset: {idx.ntotal}")

        # Store the reset index back in the kwargs and rebuild the partial so the
        # udf is bound to the current kwargs.
        cls.udf_kwargs["index"] = idx
        cls.udf = partial(cls.udf, **cls.udf_kwargs)


class CreateIndex(Mapper):
    """
    Creates an index from embeddings.

    On construction an index is built with `create_index(index_dim)` and handed
    to `udf` through `udf_kwargs` (together with `ems_col` and `verbose`). The
    constructor's arguments are then forwarded to `Mapper.__init__` via
    `locals_to_params(locals())`.

    With the default callbacks, `write_index_cb` saves the index to
    `index_folder` and `reset_index_cb` resets it afterwards.
    """

    def __init__(
        self,
        index_dim=3,  # The dimension of the index.
        index_folder="indexes",  # The folder to write the index to.
        ems_col="embedding",  # The column to use to create the index.
        udf=df_to_index,  # The function to use to create the index.
        cbs=[write_index_cb, reset_index_cb],  # The `Callback`s to use.
        verbose=True,  # Whether to print out information.
        udf_kwargs=None,  # Additional kwargs to pass to the udf (copied, never mutated).
        **kwargs,
    ):
        # NOTE: `locals()` is forwarded to the parent constructor below, so this
        # method deliberately rebinds existing parameter names instead of
        # introducing new local variables.
        self.index_folder = index_folder
        self.index = create_index(index_dim)
        self.verbose = verbose

        # Shallow-copy a list of callbacks so the shared default list (or the
        # caller's list) cannot be modified by downstream code.
        if isinstance(cbs, list):
            cbs = list(cbs)

        # Build a fresh dict per instance. Previously the `{}` default was
        # mutated in place, so all instances created without `udf_kwargs`
        # shared (and overwrote) the same "index" entry, and a caller-supplied
        # dict was modified behind the caller's back. As before, the three
        # keys below take precedence over caller-supplied values.
        udf_kwargs = {
            **(udf_kwargs or {}),
            "index": self.index,
            "ems_col": ems_col,
            "verbose": verbose,
        }
        super().__init__(**locals_to_params(locals()))


## Usage Example

In [None]:
# | eval: false

# Dimension passed to `CreateIndex`; the example embeddings below use the same length.
index_dim = 768
index_folder = "indexes"
# Remove any existing folder first so the later `get_files(index_folder)` call
# only sees index files written by this run.
shutil.rmtree(index_folder, ignore_errors=True)
m = CreateIndex(index_dim=index_dim, index_folder=index_folder, verbose=False)



ℹ BLOCK COUNTER: 0



In [None]:
# | eval: false

# Fixed seed so the random embeddings are reproducible.
np.random.seed(42)
num_ems = 100
block_size = 25
# `num_ems` random vectors of length `index_dim`, each stored as a plain Python list.
ems = [np.random.random((1, index_dim))[0].tolist() for i in range(num_ems)]
df = pd.DataFrame({"embedding": ems})

In [None]:
# | eval: false

# Feed the DataFrame to the mapper in consecutive slices of `block_size` rows.
for i in range(0, num_ems, block_size):
    df_block = df.iloc[i : i + block_size]
    m(df_block)



ℹ DF BATCH SIZE: 25


ℹ BLOCK COUNTER: 1


ℹ DF BATCH SIZE: 25


ℹ BLOCK COUNTER: 2


ℹ DF BATCH SIZE: 25


ℹ BLOCK COUNTER: 3


ℹ DF BATCH SIZE: 25


ℹ BLOCK COUNTER: 4



In [None]:
# | eval: false

# Query with one of the indexed vectors (row 60 of `df`).
qems = ems[60]
# Sort the files in `index_folder` by the integer at the start of their stem.
indexes = sorted(get_files(index_folder), key=lambda x: int(x.stem.split(".")[0]))
# One row per index file, each paired with the same query embedding.
qdf = pd.DataFrame(
    {
        "index": indexes,
        "embedding": [qems] * len(indexes),
    }
)

# Run `df_index_search` row by row with k=2.
qdf = qdf.apply(lambda x: df_index_search(x, k=2, verbose=False), axis=1)
qdf[:5]


Unnamed: 0,index,embedding,index_size,distances,ids
0,indexes/1.faiss,"[0.8567253358051695, 0.9884348472741084, 0.12180743223842172, 0.6510805836569036, 0.0024137544009764644, 0.04902948617458236, 0.2850010200860553, 0.8464040844196341, 0.6525009014377374, 0.44582093197232686, 0.4385984736023928, 0.46605357435312667, 0.3110331078010038, 0.8777145176255143, 0.5678691311997973, 0.48394996352585284, 0.46507660315710386, 0.9358264668979426, 0.7236619282196818, 0.14481455471658566, 0.4369701699432651, 0.2883722104145233, 0.9868686577301626, 0.4153957063131808, 0.06787010302978236, 0.41121200146280434, 0.3213502806059353, 0.5995826470782478, 0.22607817522902984, 0....",25,"[[114.96378, 121.50348]]","[[9, 21]]"
1,indexes/2.faiss,"[0.8567253358051695, 0.9884348472741084, 0.12180743223842172, 0.6510805836569036, 0.0024137544009764644, 0.04902948617458236, 0.2850010200860553, 0.8464040844196341, 0.6525009014377374, 0.44582093197232686, 0.4385984736023928, 0.46605357435312667, 0.3110331078010038, 0.8777145176255143, 0.5678691311997973, 0.48394996352585284, 0.46507660315710386, 0.9358264668979426, 0.7236619282196818, 0.14481455471658566, 0.4369701699432651, 0.2883722104145233, 0.9868686577301626, 0.4153957063131808, 0.06787010302978236, 0.41121200146280434, 0.3213502806059353, 0.5995826470782478, 0.22607817522902984, 0....",25,"[[121.56019, 123.57507]]","[[17, 7]]"
2,indexes/3.faiss,"[0.8567253358051695, 0.9884348472741084, 0.12180743223842172, 0.6510805836569036, 0.0024137544009764644, 0.04902948617458236, 0.2850010200860553, 0.8464040844196341, 0.6525009014377374, 0.44582093197232686, 0.4385984736023928, 0.46605357435312667, 0.3110331078010038, 0.8777145176255143, 0.5678691311997973, 0.48394996352585284, 0.46507660315710386, 0.9358264668979426, 0.7236619282196818, 0.14481455471658566, 0.4369701699432651, 0.2883722104145233, 0.9868686577301626, 0.4153957063131808, 0.06787010302978236, 0.41121200146280434, 0.3213502806059353, 0.5995826470782478, 0.22607817522902984, 0....",25,"[[0.0, 120.62366]]","[[10, 14]]"
3,indexes/4.faiss,"[0.8567253358051695, 0.9884348472741084, 0.12180743223842172, 0.6510805836569036, 0.0024137544009764644, 0.04902948617458236, 0.2850010200860553, 0.8464040844196341, 0.6525009014377374, 0.44582093197232686, 0.4385984736023928, 0.46605357435312667, 0.3110331078010038, 0.8777145176255143, 0.5678691311997973, 0.48394996352585284, 0.46507660315710386, 0.9358264668979426, 0.7236619282196818, 0.14481455471658566, 0.4369701699432651, 0.2883722104145233, 0.9868686577301626, 0.4153957063131808, 0.06787010302978236, 0.41121200146280434, 0.3213502806059353, 0.5995826470782478, 0.22607817522902984, 0....",25,"[[124.59711, 125.72464]]","[[13, 15]]"


In [None]:
# | eval: false

# Combine the per-index rows of `qdf` into one result with k=2
# (presumably the overall nearest neighbours — confirm against `index_heap`).
res = index_heap(qdf, k=2, verbose=False)

print(f'\n\nFinal Results:\n\tDistances: {res["distances"]}\n\tIDs: {res["ids"]}')




Final Results:
	Distances: [[  0.      114.96378]]
	IDs: [[60  9]]


In [None]:
# | hide

import nbdev

# Export this notebook's `# | export` cells to the module named by `default_exp`.
nbdev.nbdev_export()