# Data Source

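> Data source module and plugins (NumPy brute-force search and Hugging Face dataset indexes) that retrieve nearest-neighbor results for query embeddings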

In [None]:
#| default_exp data_source

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.core import Module, build_batch_from_embeddings
from emb_opt.schemas import Item, Query, Batch, DataSourceFunction, DataSourceResponse

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class DataSourceModule(Module):
    """Module that sends every query in a `Batch` through a data source function.

    The wrapped `function` is called with a flat list of `Query` objects and is
    expected to return one `DataSourceResponse` per query. Each response is then
    written back onto the batch item it came from:

    - a truthy `response.data` is merged into the item's `data`;
    - a valid response with results has its `query_results` attached;
    - an invalid response, or a valid one with no results, flags the item for
      removal via `data['_internal']['remove']` plus a reason string.
    """

    def __init__(self, function: DataSourceFunction):
        # NOTE(review): presumably `Module` uses DataSourceResponse to validate/parse
        # what `function` returns — confirm in emb_opt.core.
        super().__init__(DataSourceResponse, function)

    def gather_inputs(self, batch: Batch) -> (List[Tuple], List[Query]):
        """Flatten `batch` into (positions, queries) for the data source call.

        `positions` are later unpacked as `(query_idx, result_idx)` pairs by
        `scatter_results` to locate each query's batch item again.
        """
        positions, queries = batch.flatten_queries()
        return positions, queries

    def scatter_results(self, batch: Batch, idxs: List[Tuple], results: List[DataSourceResponse]):
        """Write each `DataSourceResponse` back onto the batch item at the matching position.

        Args:
            batch: the batch that `gather_inputs` was called on.
            idxs: positions returned by `gather_inputs`, aligned with `results`.
            results: one response per position (pairs beyond the shorter list are
                ignored, since they are combined with `zip`).
        """
        for (query_idx, result_idx), response in zip(idxs, results):
            target = batch.get_item(query_idx, result_idx)

            if response.data:
                target.data.update(response.data)

            # Decide whether the item is kept (results attached) or flagged for removal.
            if not response.valid:
                removal_reason = 'query response invalid'
            elif not response.query_results:
                removal_reason = 'query returned no results'
            else:
                target.add_query_results(response.query_results)
                continue

            internal = target.data['_internal']
            internal['remove'] = True
            internal['remove_details'] = removal_reason

In [None]:
#| export

class NumpyPlugin():
    """Exact k-nearest-neighbor data source over an in-memory array of embeddings.

    Every query is compared against every row of `embeddings` with
    `scipy.spatial.distance.cdist`, and the `k` rows with the smallest distance
    are returned (closest first) as `Item`s.

    Args:
        embeddings: database embeddings, one row per item (indexed as
            `embeddings[j]` and passed to `cdist` as its second argument).
        k: number of neighbors returned per query (fewer if the database has
            fewer than `k` rows).
        embedding_data: optional sequence of mappings aligned with the rows of
            `embeddings`. A shallow copy of entry `j` becomes the `data` of the
            `Item` built for row `j`. If None, items are created with `data=None`.
        distance_metric: metric name forwarded to `cdist` (e.g. 'euclidean', 'cosine').
    """
    def __init__(self, embeddings, k, embedding_data=None, distance_metric='euclidean'):
        self.embeddings = embeddings
        self.distance_metric = distance_metric
        self.k = k
        self.embedding_data = embedding_data

    def __call__(self, inputs: List[Query]) -> List[DataSourceResponse]:
        """Return one valid `DataSourceResponse` per query, holding its nearest items.

        Each returned item records its distance to the query under
        `data['_internal']['query_distance']`. An empty `inputs` list returns an
        empty list; otherwise `np.array([])` would be 1-D and `cdist` would raise.
        """
        if not inputs:
            return []

        queries = np.array([i.embedding for i in inputs])

        # distances[q, j] = distance from query q to database row j
        distances = cdist(queries, self.embeddings, metric=self.distance_metric)
        # sort each query's distances ascending and keep the k closest row indices
        topk = distances.argsort(-1)[:, :self.k]

        # Presence check with `is not None`, not truthiness: bool() of a NumPy
        # array or pandas object raises, and an explicit None is what "no data" means.
        has_data = self.embedding_data is not None

        outputs = []
        for query_distances, neighbor_idxs in zip(distances, topk):
            items = []
            for j in neighbor_idxs:
                item_data = dict(self.embedding_data[j]) if has_data else None
                item = Item(embedding=self.embeddings[j], data=item_data)
                # NOTE(review): relies on `Item` populating data['_internal'] even when
                # data=None — confirm in emb_opt.schemas.
                item.data['_internal']['query_distance'] = query_distances[j]
                items.append(item)
            outputs.append(DataSourceResponse(valid=True, data=None, query_results=items))

        return outputs

In [None]:
# Demo database: 128 random 256-d vectors, each paired with a random integer 'index'.
vectors = np.random.randn(128, 256)
vector_data = [{'index': np.random.randint(0, 1e6)} for _ in range(len(vectors))]

# Cosine-distance NumPy search returning 5 neighbors per query, wrapped as a module.
data_source = NumpyPlugin(embeddings=vectors, k=5, embedding_data=vector_data, distance_metric='cosine')
data_plugin = DataSourceModule(function=data_source)

In [None]:
# Build a batch from 5 random 256-d query vectors (same dimensionality as `vectors`).
query_vecs = np.random.randn(5, 256)
batch = build_batch_from_embeddings(query_vecs)

In [None]:
# Results attached to each query before the data source runs (expected: all 0).
[len(i) for i in batch]

[0, 0, 0, 0, 0]

In [None]:
# Run the NumPy-backed data source over every query in the batch.
batch = data_plugin(batch)

In [None]:
# Each query should now hold k=5 results.
[len(i) for i in batch]

[5, 5, 5, 5, 5]

In [None]:
# Data of the first result of the first query: the copied 'index' plus the cosine
# distance stored under '_internal' -> 'query_distance'.
batch[0][0].data

{'index': 946302,
 '_internal': {'id': '332f5a6f-4c5b-11ee-b64b-7b1d5a84b1d4',
  'query_distance': 0.865745217252099,
  'parent': '332f5a6a-4c5b-11ee-b64b-7b1d5a84b1d4',
  'collection_index': 0}}

In [None]:
#| export

class HugggingfaceDataPlugin():
    """k-nearest-neighbor data source backed by an index attached to a Hugging Face `datasets.Dataset`.

    Args:
        dataset: a `Dataset` that already carries an index named `index_name`
            (for example one created with `dataset.add_faiss_index(column)`).
        index_name: name of that index. It is also treated as the name of the
            embedding column: that column is popped from each returned row and
            used as the `Item` embedding, so the two names must match.
        k: number of neighbors requested per query.
    """
    def __init__(self, dataset, index_name, k):
        self.dataset = dataset
        self.index_name = index_name
        self.k = k
        self.index = self.dataset.get_index(index_name)

    def __call__(self, inputs: List[Query]) -> List[DataSourceResponse]:
        """Return one valid `DataSourceResponse` per query, holding its nearest dataset rows.

        Each item's `data` is the dataset row without the embedding column, and
        the index's raw score is stored under `data['_internal']['query_distance']`
        (NOTE(review): for the default Faiss flat index this appears to be squared
        L2 distance — confirm for other index types).

        Result slots with a negative row index are padding and are skipped, so a
        query may receive fewer than `k` items (possibly none). An empty `inputs`
        list returns an empty list.
        """
        if not inputs:
            return []

        queries = np.array([i.embedding for i in inputs])
        res = self.index.search_batch(queries, k=self.k)

        outputs = []
        # zip over rows works whether the index returns NumPy arrays or nested lists
        for query_scores, query_indices in zip(res.total_scores, res.total_indices):
            items = []
            for score, db_idx in zip(query_scores, query_indices):
                db_idx = int(db_idx)
                # Faiss pads results with -1 when it finds fewer than k neighbors;
                # indexing the dataset with -1 would silently return its last row.
                if db_idx < 0:
                    continue

                data_dict = dict(self.dataset[db_idx])
                embedding = data_dict.pop(self.index_name)

                item = Item(embedding=embedding, data=data_dict)
                item.data['_internal']['query_distance'] = score
                items.append(item)

            outputs.append(DataSourceResponse(valid=True, data=None, query_results=items))

        return outputs

In [None]:
# Same vectors and ids as the NumPy demo, as a Hugging Face dataset with a Faiss
# index over the 'embedding' column (the cell's last expression displays the dataset).
dataset = Dataset.from_list([{'embedding': vec, 'index': meta['index']}
                             for vec, meta in zip(vectors, vector_data)])
dataset.add_faiss_index('embedding')

100%|███████████████████████████████████████████| 1/1 [00:00<00:00, 1659.14it/s]


Dataset({
    features: ['embedding', 'index'],
    num_rows: 128
})

In [None]:
# Index-backed data source over the 'embedding' index, returning 5 neighbors per query.
data_source = HugggingfaceDataPlugin(dataset, 'embedding', 5)

In [None]:
# Wrap it in the same DataSourceModule used for the NumPy plugin.
data_plugin = DataSourceModule(data_source)

In [None]:
# Fresh batch from the same query vectors, so both backends answer identical queries.
batch2 = build_batch_from_embeddings(query_vecs)

In [None]:
# Results per query before the data source runs (expected: all 0).
[len(i) for i in batch2]

[0, 0, 0, 0, 0]

In [None]:
# Run the dataset-index-backed data source over every query in the batch.
batch2 = data_plugin(batch2)

In [None]:
# Each query should now hold k=5 results.
[len(i) for i in batch2]

[5, 5, 5, 5, 5]

In [None]:
# First result of the first query. 'query_distance' here is the index's raw score
# (presumably squared L2 for the default Faiss index), so it is not on the same
# scale as the cosine distances from the NumPy demo.
batch2[0][0].data

{'index': 222512,
 '_internal': {'id': '3829e9a9-4c5b-11ee-b64b-7b1d5a84b1d4',
  'query_distance': 428.08813,
  'parent': '3829e9a4-4c5b-11ee-b64b-7b1d5a84b1d4',
  'collection_index': 0}}

In [None]:
# Compare with the first result from the NumPy-backed batch.
# NOTE(review): the saved output below appears to come from an earlier run (its
# 'index' and ids differ from the earlier display of this same expression) — re-run.
batch[0][0].data

{'index': 576779,
 '_internal': {'id': '300b4b89-4c2c-11ee-b64b-7b1d5a84b1d4',
  'query_distance': 0.8676547269554138,
  'parent': '300b4b84-4c2c-11ee-b64b-7b1d5a84b1d4',
  'collection_index': 0}}