# Faiss Plugins

> Faiss functions and classes

In [None]:
#| default_exp plugins.faiss

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.schemas import Query, Item, DataSourceResponse
from emb_opt.data_source import DataSourcePlugin, DataSourceModule
from emb_opt.utils import build_batch_from_embeddings

try:
    import faiss
except Exception:
    # faiss is treated as optional at import time: warn instead of failing.
    # `except Exception` (rather than a bare `except:`) keeps that best-effort
    # behavior without swallowing KeyboardInterrupt / SystemExit.
    warnings.warn('Failed to import faiss library - check package install')

### Faiss Data Plugin

The `FaissDataPlugin` integrates with a [faiss](https://github.com/facebookresearch/faiss) index.

`search_params` can be any params object compatible with faiss `search`

In [None]:
#| export

class FaissDataPlugin(DataSourcePlugin):
    '''
    FaissDataPlugin - data plugin for working with 
    a faiss vector index
    
    The data query will run `k` nearest neighbors against 
    `faiss_index` and reconstruct the embedding of every 
    returned neighbor from the index
    
    Optionally, `item_data` can be provided as a list of dicts, where 
    `item_data[i]` corresponds to the data for embedding `i` in the 
    faiss index
    
    If `item_data` and `item_key` are provided, `item_data[i][item_key]` 
    is used as the item value for item `i` (and is removed from the 
    copy of the dict stored as the item's data)
    
    `search_params` is an optional params object passed unchanged as 
    `params` to the index `search` call
    
    Each query yields one `DataSourceResponse`. Its `data` holds the 
    distances of the returned items under `query_distance`, and it is 
    marked invalid when the query produced no items
    '''
    def __init__(self, 
                 k: int,                                                 # k nearest neighbors to return
                 faiss_index: 'faiss.Index',                             # faiss index
                 item_data: Optional[List[Dict]]=None,                   # Optional list of item data dicts
                 item_key: Optional[str]=None,                           # Optional key for item value (should be in `item_data` dict)
                 search_params: Optional['faiss.SearchParameters']=None  # faiss search params
                ):
        # NOTE: the faiss annotations above are strings so that defining the
        # class does not raise NameError when the guarded faiss import failed.
        
        self.k = k
        self.faiss_index = faiss_index
        self.item_data = item_data
        self.item_key = item_key
        self.search_params = search_params
        
    def __call__(self, inputs: List[Query]) -> List[DataSourceResponse]:
        # No queries -> no responses. An empty list would otherwise become a
        # 1-D array, which is not a valid (n, d) query matrix.
        if len(inputs) == 0:
            return []
        
        query_vectors = np.array([query.embedding for query in inputs])
        
        distances, indices = self.faiss_index.search(query_vectors, self.k, params=self.search_params)
        
        outputs = []
        for row_ids, row_distances in zip(indices, distances):
            # -1 marks result slots with no neighbor. Drop them BEFORE
            # reconstructing so the index is never asked for id -1.
            found = row_ids != -1
            hit_ids = row_ids[found]
            
            items = []
            if hit_ids.size > 0:
                hit_embeddings = self.faiss_index.reconstruct_batch(hit_ids)
                
                for item_id, item_embedding in zip(hit_ids, hit_embeddings):
                    item_data = None
                    item_value = None
                    
                    if self.item_data:
                        # Copy so popping the item key does not mutate the
                        # caller's `item_data` entry.
                        item_data = dict(self.item_data[item_id])
                        if self.item_key:
                            item_value = item_data.pop(self.item_key)
                            
                    items.append(Item(id=item_id,
                                      item=item_value,
                                      embedding=item_embedding,
                                      data=item_data,
                                      score=None
                                     ))
                    
            # Distances are kept parallel to `items` (same filter, same order).
            query_data = {'query_distance' : list(row_distances[found])}
            
            result = DataSourceResponse(valid=bool(items), data=query_data, query_results=items)
            outputs.append(result)
        return outputs

In [None]:
# Smoke test: index random vectors in a flat L2 faiss index and run a batch
# of random queries through the plugin wrapped in a `DataSourceModule`.
n_vectors, d_vectors = 256, 64
k = 10  # NOTE(review): unused below - the plugin is built with 5; confirm intent
n_queries = 5

vectors = np.random.randn(n_vectors, d_vectors)

# One metadata dict per indexed vector; the 'item' entry is the item key.
vector_data = [
    dict(other=np.random.randint(0,1e3), item=str(np.random.randint(0,1e4)))
    for _ in range(len(vectors))
]

index = faiss.IndexFlatL2(d_vectors)
index.add(vectors)

data_function = FaissDataPlugin(k=5,
                                faiss_index=index,
                                item_data=vector_data,
                                item_key='item')
data_module = DataSourceModule(data_function)

query_embeddings = np.random.randn(n_queries, d_vectors)
batch = build_batch_from_embeddings(query_embeddings)
batch2 = data_module(batch)