# Hugging Face Plugins

> Functions and classes for working with the Hugging Face `datasets` library

In [None]:
#| default_exp plugins.huggingface

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.schemas import Query, Item, DataSourceResponse
from emb_opt.data_source import DataSourcePlugin, DataSourceModule
from emb_opt.executor import Executor
from emb_opt.utils import build_batch_from_embeddings

import datasets
from datasets import Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class DatasetExecutor(Executor):
    '''
    DatasetExecutor - executes function in parallel 
    using `Dataset.map`

    Inputs are converted to a `datasets.Dataset` (one row per input, built
    from each input's `model_dump()`), `function` is applied with
    `Dataset.map`, and the mapped dataset is converted back to a list of
    row dicts.

    With `batched=False`, `function` receives one row dict; with
    `batched=True` it receives a dict of column lists. In both cases it
    should return a dict of columns in the same layout as its input.
    '''
    def __init__(self,
                 function: Callable,              # function to be wrapped
                 batched: bool,                   # if inputs should be batched
                 batch_size: int=1,               # batch size (set batch_size=0 to pass all inputs)
                 concurrency: Optional[int]=1,    # number of worker processes, forwarded as `Dataset.map(num_proc=...)`
                 map_kwargs: Optional[dict]=None  # kwargs for `Dataset.map`
                ):
        
        self.function = function
        self.batched = batched
        self.concurrency = concurrency
        self.batch_size = batch_size
        # copy so later mutation of the caller's dict cannot change this executor
        self.map_kwargs = dict(map_kwargs) if map_kwargs else {}
        
    def batch_inputs(self, inputs: List[BaseModel]):
        '''Build a `Dataset` with one row per input from `model_dump()`.'''
        dataset = datasets.Dataset.from_list([i.model_dump() for i in inputs])
        return dataset
            
    def unbatch_inputs(self, dataset):
        '''Convert the mapped `Dataset` back into a list of row dicts.'''
        return dataset.to_list()

    def execute(self, dataset):
        '''Apply `self.function` to `dataset` via `Dataset.map` and return the result.'''
        # Pass `self.function` directly instead of wrapping it in a
        # one-argument lambda: `Dataset.map` calls it with extra arguments
        # when `map_kwargs` sets `with_indices`, `with_rank` or `fn_kwargs`,
        # and a single-argument wrapper would reject those with a TypeError.
        return dataset.map(self.function,
                           batched=self.batched,
                           batch_size=self.batch_size,
                           num_proc=self.concurrency,
                           **self.map_kwargs)

In [None]:
class TestInput(BaseModel):
    '''Toy input schema for the `DatasetExecutor` tests: one float `value` per row.'''
    value: float
        
class TestOutput(BaseModel):
    '''Toy output schema: `result` should be True when the input `value` exceeds 0.5.'''
    result: bool

def test_function_hf(input: dict) -> dict:
    return {'result' : input['value']>0.5}

def test_function_hf_batched(input: dict) -> dict:
    return {'result' : [i>0.5 for i in input['value']]}
        
np.random.seed(42)
values = np.random.uniform(size=100).tolist()
inputs = [TestInput(value=v) for v in values]
expected_outputs = [TestOutput(result=v > 0.5) for v in values]


def _validated(results: list) -> list:
    '''Parse raw executor output rows back into `TestOutput` models.'''
    return [TestOutput.model_validate(row) for row in results]

# dataset

# row-wise function, num_proc=None
executor = DatasetExecutor(test_function_hf, batched=False, concurrency=None, batch_size=1)
res11 = executor(inputs)
assert _validated(res11) == expected_outputs

# row-wise function, num_proc=2
executor = DatasetExecutor(test_function_hf, batched=False, concurrency=2, batch_size=1)
res12 = executor(inputs)
assert _validated(res12) == expected_outputs

# batched function (batches of 5), num_proc=2
executor = DatasetExecutor(test_function_hf_batched, batched=True, concurrency=2, batch_size=5)
res13 = executor(inputs)
assert _validated(res13) == expected_outputs

# batched function (batches of 5), num_proc=None
executor = DatasetExecutor(test_function_hf_batched, batched=True, concurrency=None, batch_size=5)
res14 = executor(inputs)
assert _validated(res14) == expected_outputs

                                                                                

In [None]:
#| export 

class HugggingfaceDataPlugin(DataSourcePlugin):
    '''
    HugggingfaceDataPlugin - data plugin for working with 
    the Hugging Face `datasets` library.
    
    The input `dataset` should have a faiss embedding index 
    denoted by `index_name`.
    
    The data query will run `k` nearest neighbors against the 
    dataset index based on the metric used to create the index.
    
    Optionally, `item_key` denotes the column in `dataset` defining a 
    specific item, and `id_key` denotes the column defining an item's ID.
    `embedding_key` names the column holding each row's embedding; it 
    defaults to `index_name`, which matches the default index name used 
    by `Dataset.add_faiss_index(column)`. All other columns are passed 
    through as the item's `data`.

    Each query produces one `DataSourceResponse` whose 
    `data['query_distance']` lists the index scores of the returned items, 
    in result order. Neighbor slots that faiss leaves unfilled (index -1, 
    e.g. when `k` exceeds the number of indexed rows) are skipped, so a 
    response may contain fewer than `k` items.
    '''
    def __init__(self,
                 k: int,                            # k nearest neighbors to return
                 dataset: datasets.Dataset,         # input dataset
                 index_name: str,                   # name of the faiss index in `dataset`
                 item_key: Optional[str]=None,      # dataset column denoting item value
                 id_key: Optional[str]=None,        # dataset column denoting item id
                 embedding_key: Optional[str]=None  # dataset column holding embeddings (defaults to `index_name`)
                ):
        
        self.k = k
        self.dataset = dataset
        self.index_name = index_name
        self.index = self.dataset.get_index(index_name)
        self.item_key = item_key
        self.id_key = id_key
        self.embedding_key = embedding_key if embedding_key is not None else index_name

    def _build_item(self, dataset_index: int) -> Item:
        '''Convert row `dataset_index` of `self.dataset` into an `Item`.'''
        item_data = dict(self.dataset[dataset_index])
        embedding = item_data.pop(self.embedding_key)
        item = item_data.pop(self.item_key) if self.item_key else None
        item_id = item_data.pop(self.id_key) if self.id_key else None
        # whatever columns remain are kept as item metadata
        return Item(id=item_id,
                    item=item,
                    embedding=embedding,
                    data=item_data,
                    score=None)
        
    def __call__(self, inputs: List[Query]) -> List[DataSourceResponse]:
        # an empty query list would give a 1-D array that `search_batch` rejects
        if not inputs:
            return []

        queries = np.array([q.embedding for q in inputs])
        res = self.index.search_batch(queries, k=self.k)
        distances = res.total_scores
        indices = res.total_indices
        
        outputs = []
        for query_distances, query_indices in zip(distances, indices):
            query_data = {'query_distance': []}
            items = []
            for distance, dataset_index in zip(query_distances, query_indices):
                # faiss pads missing neighbors with index -1; `Dataset` wraps
                # negative indices, so without this check the padding would
                # silently return the last row of the dataset
                if dataset_index < 0:
                    continue
                query_data['query_distance'].append(distance)
                items.append(self._build_item(int(dataset_index)))
                
            outputs.append(DataSourceResponse(valid=True, data=query_data, query_results=items))
            
        return outputs

# Correctly-spelled alias; the original (misspelled) name is kept for backward compatibility.
HuggingfaceDataPlugin = HugggingfaceDataPlugin

In [None]:
n_vectors = 256
d_vectors = 64
k = 10
n_queries = 5

vectors = np.random.randn(n_vectors, d_vectors)

# one row per vector: a string id, an extra metadata column, an item value,
# and the embedding the faiss index is built on
vector_data = [{'index': str(np.random.randint(0, 1_000_000)),
                'other': np.random.randint(0, 1_000),
                'item': str(np.random.randint(0, 10_000)),
                'embedding': vector}
               for vector in vectors]

dataset = Dataset.from_list(vector_data)
dataset.add_faiss_index('embedding')

data_function = HugggingfaceDataPlugin(k, dataset, index_name='embedding', item_key='item', id_key='index')
data_module = DataSourceModule(data_function)

batch = build_batch_from_embeddings(np.random.randn(n_queries, d_vectors))
batch2 = data_module(batch)

# every returned result must point back to the query that produced it
for q in batch2:
    for r in q:
        assert r.internal.parent_id == q.id

100%|███████████████████████████████████████████| 1/1 [00:00<00:00, 1687.17it/s]
