# filters

> filter functions

In [None]:
#| default_exp filter

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.utils import QueryDataset

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export
class Filter():
    """Wrap a predicate so it can be applied to a `QueryDataset`.

    `filter_func` is forwarded unchanged to `query_dataset.filter`, and
    `filter_kwargs_dict` supplies extra keyword arguments for that call
    (e.g. `{'batched': True}` or `{'num_proc': 2}`).
    """
    def __init__(self, 
                 filter_func: Callable,
                 filter_kwargs_dict: Optional[dict]=None
                ):
        self.filter_func = filter_func
        # Copy so later mutation of the caller's dict does not leak into this filter.
        self.filter_kwargs_dict = dict(filter_kwargs_dict) if filter_kwargs_dict else {}

    def __call__(self, query_dataset: QueryDataset) -> QueryDataset:
        """Apply `filter_func` via `query_dataset.filter` and return the result."""
        # Pass the predicate directly instead of through a one-argument lambda.
        # Assuming `QueryDataset.filter` follows Hugging Face `Dataset.filter`
        # (it accepts `batched`/`num_proc` above), options such as
        # `with_indices=True` or `input_columns=[...]` call the function with
        # extra/different positional arguments, which a `lambda item: ...`
        # wrapper would reject with a TypeError.
        return query_dataset.filter(self.filter_func, **self.filter_kwargs_dict)

In [None]:
# Build a toy query result set: 2 query vectors searched against a 64-vector
# database, keeping each query's 24 nearest database entries.
query_vecs = np.random.randn(2, 256)

vector_database = np.random.randn(64, 256)

# Euclidean distance for every (query, database) pair -> shape (2, 64).
dists = np.linalg.norm(query_vecs[:, None] - vector_database[None], axis=-1)
# argsort is ascending, so the *first* k indices are the closest vectors.
nearest = dists.argsort(-1)[:, :24]

query_results = []

for query_idx in range(query_vecs.shape[0]):
    for db_idx in nearest[query_idx]:
        result = {
            'query_idx' : query_idx,
            'db_idx' : db_idx,
            'embedding' : vector_database[db_idx],
            'distance' : dists[query_idx, db_idx],
            'data' : {'randint': np.random.randint(0,100)}
        }

        query_results.append(result)
        
query_dataset = QueryDataset.from_list(query_results)

# Row-wise predicate: keep items whose random payload is below 20.
def simple_filter(row):
    return row['data']['randint'] < 20

f = Filter(simple_filter)
filtered_dataset = f(query_dataset)
assert len(filtered_dataset) < len(query_dataset)

# Same predicate, with extra kwargs forwarded to `filter` (multiprocessing).
f = Filter(simple_filter, {'num_proc':2})
filtered_dataset = f(query_dataset)
assert len(filtered_dataset) < len(query_dataset)

# Batched predicate: receives a batch of columns and returns one mask entry per row.
def batched_filter(batch):
    randints = np.array([i['randint'] for i in batch['data']])
    return randints < 20

f = Filter(batched_filter, {'batched':True})
filtered_dataset = f(query_dataset)
assert len(filtered_dataset) < len(query_dataset)

                                                                                

In [None]:
#| hide
# Regenerate the library module(s) from this notebook's `#| export` cells (nbdev).
import nbdev; nbdev.nbdev_export()