# Data Source

> Data source module and plugins that look up candidate items for each query

In [None]:
#| default_exp data_source

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.utils import build_batch_from_embeddings
from emb_opt.module import Module
from emb_opt.schemas import Item, Query, Batch, DataSourceFunction, DataSourceResponse

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class DataSourceModule(Module):
    """Module that runs a data source function over every query in a `Batch`.

    The queries are flattened out of the batch and passed to `function`, which
    must return one `DataSourceResponse` per query (in the same order). Each
    response is then written back onto its query by `scatter_results`.
    """
    def __init__(self, function: DataSourceFunction):
        # NOTE(review): `DataSourceResponse` is presumably the output schema the
        # base Module validates `function`'s results against — confirm in emb_opt.module
        super().__init__(DataSourceResponse, function)

    def gather_inputs(self, batch: Batch) -> Tuple[List[Tuple], List[Query]]:
        """Flatten `batch` into `(idxs, queries)`.

        `idxs` holds the `(q_idx, r_idx)` position of each query so results can
        be scattered back with `batch.get_item`.
        """
        idxs, inputs = batch.flatten_queries()
        return (idxs, inputs)

    def scatter_results(self, batch: Batch, idxs: List[Tuple], results: List[DataSourceResponse]) -> None:
        """Apply each `DataSourceResponse` to the batch query it was produced for.

        - `result.data` (if any) is merged into the query's `data`, even for
          invalid responses.
        - An invalid response marks the query removed ('invalid query').
        - A valid response with no results marks the query removed
          ('query returned no results').
        - Otherwise the returned items are attached with `add_query_results`.
        """
        for (q_idx, r_idx), result in zip(idxs, results):
            batch_item = batch.get_item(q_idx, r_idx)
            if result.data:
                batch_item.data.update(result.data)

            if not result.valid:
                batch_item.update_internal(removed=True, removal_reason='invalid query')

            elif len(result.query_results)==0:
                batch_item.update_internal(removed=True, removal_reason='query returned no results')

            else:
                batch_item.add_query_results(result.query_results)

In [None]:
def build_batch():
    """Return a small `Batch` built from three 1-d query embeddings, for the tests below."""
    return build_batch_from_embeddings([[0.1], [0.2], [0.3]])

def data_source_test(queries: List[Query]) -> List[DataSourceResponse]:
    """Fake data source that exercises each branch of `DataSourceModule.scatter_results`.

    - query 0: invalid response (has a result, but `valid=False`)
    - query 1: valid response with no results
    - query 2 and any later query: valid response with two items

    A fresh response (and fresh `Item`s) is built for every query, so no
    response object is shared between queries.
    """
    results = []
    for i, _query in enumerate(queries):
        if i == 0:
            response = DataSourceResponse(valid=False, data={'test':'test false response'},
                                          query_results=[Item.from_minimal(item='', embedding=[0.1])])
        elif i == 1:
            response = DataSourceResponse(valid=True, data={'test':'test empty response'},
                                          query_results=[])
        else:
            response = DataSourceResponse(valid=True, data={'test':'test normal response'},
                                          query_results=[Item.from_minimal(item='1', embedding=[0.1]),
                                                         Item.from_minimal(item='2', embedding=[0.2])])
        results.append(response)
    return results

batch = build_batch()
data_module = DataSourceModule(data_source_test)
batch2 = data_module(batch)

# query 0 is invalid, query 1 returns nothing, query 2 returns two items
removed_flags = [query.internal.removed for query in batch2]
assert removed_flags == [True, True, False]

# every returned item must point back at the query that produced it
for query in batch2:
    for result_item in query:
        assert result_item.internal.parent_id == query.internal.id

In [None]:
#| export

class DataSourcePlugin():
    """Base class for data source plugins.

    A plugin is a callable that takes a list of `Query` objects and returns one
    `DataSourceResponse` per query, in the same order, so it can be passed
    directly to `DataSourceModule`. Subclasses must override `__call__`.
    """
    def __call__(self, inputs: List[Query]) -> List[DataSourceResponse]:
        # Fail loudly instead of silently returning None for an unimplemented plugin.
        raise NotImplementedError(f"{type(self).__name__} must implement __call__")

In [None]:
#| export

class NumpyDataPlugin(DataSourcePlugin):
    """Exact k-nearest-neighbour data source over an in-memory embedding matrix.

    For every query, the distance to every row of `item_embeddings` is computed
    with `scipy.spatial.distance.cdist`, and the `k` closest rows are returned
    as `Item`s.

    Args:
        k: number of results per query (fewer if there are fewer items).
        item_embeddings: matrix whose rows are the item embeddings.
        item_list: optional item values, aligned with the rows of `item_embeddings`.
        item_data: optional per-item data dicts, aligned with the rows of `item_embeddings`.
        distance_metric: metric name passed to `cdist` (e.g. 'euclidean', 'cosine').
    """
    def __init__(self,
                 k: int,
                 item_embeddings: np.ndarray,
                 item_list: Optional[List[str]]=None,
                 item_data: Optional[List[Dict]]=None,
                 distance_metric: str='euclidean'
                ):

        self.k = k
        self.item_embeddings = item_embeddings
        self.item_list = item_list
        self.item_data = item_data
        self.distance_metric = distance_metric

    def __call__(self, inputs: List[Query]) -> List[DataSourceResponse]:
        """Return one `DataSourceResponse` per query holding its `k` nearest items.

        Each response's `data['query_distance']` lists the distances of the
        returned items, closest first.
        """
        # cdist needs a 2-D query matrix; an empty batch has nothing to look up.
        if not inputs:
            return []

        queries = np.array([query.embedding for query in inputs])
        distances = cdist(queries, self.item_embeddings, metric=self.distance_metric)
        # indices of the k smallest distances per query, in ascending distance order
        topk = distances.argsort(-1)[:, :self.k]

        outputs = []
        for query_distances, neighbour_idxs in zip(distances, topk):
            items = []
            query_data = {'query_distance' : []}
            for j in neighbour_idxs:
                query_data['query_distance'].append(query_distances[j])

                # copy so mutating an Item's data cannot alter self.item_data
                data = dict(self.item_data[j]) if self.item_data is not None else {}
                item_value = self.item_list[j] if self.item_list is not None else None

                items.append(Item(embedding=self.item_embeddings[j], data=data, score=None, item=item_value))

            outputs.append(DataSourceResponse(valid=True, data=query_data, query_results=items))

        return outputs

In [None]:
n_vectors = 256
d_vectors = 64
k = 10
n_queries = 5

vectors = np.random.randn(n_vectors, d_vectors)
vector_data = [{'index':np.random.randint(0,1e6)} for i in range(vectors.shape[0])]
item_values = [str(i['index']) for i in vector_data]

data_function = NumpyDataPlugin(k, vectors, item_values, vector_data, distance_metric='cosine')
data_module = DataSourceModule(data_function)

batch = build_batch_from_embeddings(np.random.randn(n_queries, d_vectors))
batch2 = data_module(batch)

for q in batch2:
    # scatter_results merges the response data onto the query: k distances, closest first
    query_distances = q.data['query_distance']
    assert len(query_distances) == k
    assert query_distances == sorted(query_distances)
    for r in q:
        assert r.internal.parent_id == q.internal.id

In [None]:
#| export 

class HugggingfaceDataPlugin(DataSourcePlugin):
    """k-nearest-neighbour data source backed by an index on a Hugging Face `datasets.Dataset`.

    Args:
        k: number of neighbours requested per query.
        dataset: dataset holding the items; it must already have an index named
            `index_name` (for example one added with `Dataset.add_faiss_index`).
        index_name: name of the index to search. The same name is popped from
            each returned row as the item embedding, which assumes the index
            was built on a column of that name — TODO confirm for custom index names.
        item_name: optional column whose value becomes `Item.item`; it is
            removed from the row before the rest is stored as `Item.data`.
    """
    def __init__(self,
                 k: int,
                 dataset: datasets.Dataset,
                 index_name: str,
                 item_name: Optional[str]=None
                ):

        self.k = k
        self.dataset = dataset
        self.index_name = index_name
        self.index = self.dataset.get_index(index_name)
        self.item_name = item_name

    def __call__(self, inputs: List[Query]) -> List[DataSourceResponse]:
        """Return one `DataSourceResponse` per query holding its nearest items.

        Each response's `data['query_distance']` lists the index scores of the
        returned items, in the order the index returned them.
        """
        # search_batch requires a 2-D query matrix; an empty batch has nothing to look up.
        if not inputs:
            return []

        queries = np.array([query.embedding for query in inputs])

        res = self.index.search_batch(queries, k=self.k)
        distances = res.total_scores
        indices = res.total_indices

        outputs = []
        for query_distances, query_indices in zip(distances, indices):
            items = []
            query_data = {'query_distance' : []}
            for distance, dataset_index in zip(query_distances, query_indices):
                # faiss pads with index -1 when it finds fewer than k neighbours;
                # indexing the dataset with -1 would silently return the last row.
                if dataset_index < 0:
                    continue
                query_data['query_distance'].append(distance)

                row = dict(self.dataset[int(dataset_index)])
                embedding = row.pop(self.index_name)
                item_value = row.pop(self.item_name) if self.item_name else None

                items.append(Item(embedding=embedding, data=row, item=item_value, score=None))

            outputs.append(DataSourceResponse(valid=True, data=query_data, query_results=items))

        return outputs


# Correctly spelled alias; the original (misspelled) name is kept for backward compatibility.
HuggingfaceDataPlugin = HugggingfaceDataPlugin

In [None]:
n_vectors, d_vectors = 256, 64
k = 10
n_queries = 5

vectors = np.random.randn(n_vectors, d_vectors)
# each row stores a random integer id alongside its embedding
vector_data = [
    {'index': np.random.randint(0, 1e6), 'embedding': vector}
    for vector in vectors
]

dataset = Dataset.from_list(vector_data)
dataset.add_faiss_index('embedding')

data_function = HugggingfaceDataPlugin(k, dataset, index_name='embedding', item_name='index')
data_module = DataSourceModule(data_function)

query_embeddings = np.random.randn(n_queries, d_vectors)
batch = build_batch_from_embeddings(query_embeddings)
batch2 = data_module(batch)

# every returned item must point back at the query that produced it
for query in batch2:
    assert all(result.internal.parent_id == query.internal.id for result in query)

100%|███████████████████████████████████████████| 1/1 [00:00<00:00, 1957.21it/s]
