# Filter

> Filter functions and classes

In [None]:
#| default_exp filter

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.core import Module, Executor
from emb_opt.schemas import Item, Query, Batch, FilterFunction, FilterResponse

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class FilterModule(Module):
    """Module that runs a filter function over every query result in a batch.

    Each query result is sent to ``function``, which returns one
    ``FilterResponse`` per result. Any ``data`` in a response is merged into
    the corresponding item's ``data`` dict. Items whose response has
    ``valid`` falsy are flagged for removal through
    ``item.data['_internal']``. The flag is only set here; actual removal
    happens later (the tests use ``Batch.clean_results``).
    """

    def __init__(self, function: FilterFunction):
        # The base Module receives the expected response type plus the wrapped
        # callable. How it invokes gather/scatter is defined in emb_opt.core.
        super().__init__(FilterResponse, function)

    def gather_inputs(self, batch: Batch) -> Tuple[List[Tuple], List[Item]]:
        """Flatten the batch's query results into filter inputs.

        Returns:
            ``(idxs, inputs)``. ``idxs`` holds index pairs, consumed by
            ``scatter_results`` as ``(query_idx, result_idx)``. ``inputs`` is
            the matching list of result ``Item``s, in the same order.
        """
        idxs, inputs = batch.flatten_query_results()
        return (idxs, inputs)

    def scatter_results(self, batch: Batch, idxs: List[Tuple], results: List[FilterResponse]):
        """Write each filter response back onto its batch item.

        - A truthy ``result.data`` is merged into the item's ``data`` with
          ``dict.update``.
        - A response with falsy ``valid`` marks the item for removal by
          setting ``remove`` and ``remove_details`` on
          ``item.data['_internal']``.
        """
        # NOTE(review): zip() stops at the shorter sequence, so a length
        # mismatch between idxs and results would leave items unprocessed.
        # Presumably the base Module guarantees equal lengths; confirm.
        for (q_idx, r_idx), result in zip(idxs, results):
            batch_item = batch.get_item(q_idx, r_idx)
            if result.data:
                batch_item.data.update(result.data)

            if not result.valid:
                # Assumes every item's data already holds an '_internal' dict
                # (a KeyError is raised otherwise). TODO confirm in schemas.
                batch_item.data['_internal']['remove'] = True
                batch_item.data['_internal']['remove_details'] = 'filter response invalid'

In [None]:
def build_batch():
    """Create a deterministic single-query batch for filter tests.

    Seeds NumPy's global RNG with 42, then builds one 128-d query and 100
    random 128-d result items tagged ``data={'id': i}``.

    Returns:
        ``(batch, failed_ids)``. ``batch`` is a ``Batch`` wrapping the query
        with its results attached. ``failed_ids`` lists, in creation order,
        the ids of items whose embedding L2 norm is >= 10.5.
    """
    np.random.seed(42)
    dim = 128
    query = Query(embedding=np.random.randn(dim))

    # Build the items and record failures in a single pass. The RNG is drawn
    # in the same order as creating all items first.
    items, failed_ids = [], []
    for item_id in range(100):
        item = Item(embedding=np.random.randn(dim), data={'id': item_id})
        items.append(item)
        if np.linalg.norm(item.embedding) >= 10.5:
            failed_ids.append(item.data['id'])

    query.add_query_results(items)
    return Batch(queries=[query]), failed_ids

def test_filter(filter_module):
    """Check that ``filter_module`` removes exactly the ids ``build_batch`` expects.

    Compares the ids of the items returned by ``clean_results()`` against
    ``failed_ids``. The comparison is order-sensitive, so this presumably
    relies on ``clean_results`` preserving item order. TODO confirm.
    """
    fresh_batch, expected_ids = build_batch()
    processed = filter_module(fresh_batch)
    dropped = processed.clean_results()
    assert [entry.data['id'] for entry in dropped] == expected_ids
    
def norm_filter(input: Item):
    """Per-item filter: an item is valid when its embedding L2 norm is < 10.5.

    Returns a ``FilterResponse`` whose ``data`` carries the computed norm
    under the key ``'norm'``.
    """
    # `input` shadows the builtin but is kept: it is part of the call signature.
    vec = np.array(input.embedding)
    l2 = np.linalg.norm(vec)
    keep = l2 < 10.5
    return FilterResponse(valid=keep, data={'norm': l2})

def norm_filter_batched(inputs: List[Item]):
    """Batched filter: an item is valid when its embedding L2 norm is < 10.5.

    Args:
        inputs: items to score. Their embeddings are stacked row-wise, so they
            are assumed to share one dimensionality.

    Returns:
        One ``FilterResponse`` per input, in input order, each with the norm
        stored under ``data['norm']``. An empty ``inputs`` yields ``[]``.
    """
    # np.linalg.norm over an empty (0,) array with axis=-1 returns a 0-d scalar,
    # which is not iterable. Short-circuit the empty batch instead of crashing.
    if not inputs:
        return []
    embeddings = np.array([i.embedding for i in inputs])
    norms = np.linalg.norm(embeddings, axis=-1)
    return [FilterResponse(valid=n < 10.5, data={'norm': n}) for n in norms]

# Run test_filter against FilterModule with three kinds of filter callables.
# Each run must remove exactly the ids that build_batch reports as failing.
# `func` and `filter_module` are deliberately left bound to the last case.

# 1) Per-item function wrapped in an Executor with batched=False.
func = Executor(norm_filter, batched=False)
filter_module = FilterModule(func)
test_filter(filter_module)

# 2) Batched function handed to FilterModule directly, with no Executor.
#    How a bare callable is invoked is decided by Module (not visible here).
filter_module = FilterModule(norm_filter_batched)
test_filter(filter_module)

# 3) Batched function wrapped in an Executor with batch_size=5
#    (presumably the inputs are split into chunks of 5; confirm in emb_opt.core).
func = Executor(norm_filter_batched, batched=True, batch_size=5)
filter_module = FilterModule(func)
test_filter(filter_module)