# Filter

> Filter functions and classes for validating query results and removing those that fail

In [None]:
#| default_exp filter

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.module import Module
from emb_opt.schemas import Item, Query, Batch, FilterFunction, FilterResponse

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class FilterModule(Module):
    """Run a `FilterFunction` over every query result in a `Batch`.

    Each query result (`Item`) is passed to the wrapped function, which is
    expected to return one `FilterResponse` per item. Items whose response has
    `valid=False` are marked as removed, and any `data` in a response is merged
    into the corresponding item's `data` dict.
    """
    def __init__(self,
                 function: FilterFunction
                ):
        # The base `Module` receives the expected response schema and the
        # user-supplied callable; presumably it calls the function on the
        # output of `gather_inputs` — TODO confirm against `Module`.
        super().__init__(FilterResponse, function)

    def gather_inputs(self, batch: Batch) -> Tuple[List[Tuple], List[Item]]:
        """Flatten `batch` into its query results.

        Returns:
            `(idxs, inputs)` where `idxs[i]` is the `(query_idx, result_idx)`
            position of `inputs[i]` within `batch`.
        """
        idxs, inputs = batch.flatten_query_results()
        return idxs, inputs

    def scatter_results(self, batch: Batch, idxs: List[Tuple], results: List[FilterResponse]):
        """Write filter responses back onto the items of `batch`.

        Args:
            batch: the batch the inputs were gathered from.
            idxs: `(query_idx, result_idx)` positions, as returned by `gather_inputs`.
            results: one `FilterResponse` per entry of `idxs`.
        """
        for (q_idx, r_idx), result in zip(idxs, results):
            batch_item = batch.get_item(q_idx, r_idx)

            # merge any extra information the filter attached to this item
            if result.data:
                batch_item.data.update(result.data)

            if not result.valid:
                batch_item.update_internal(removed=True, removal_reason='filter response invalid')

        # Re-sync every query after item removals; presumably this marks a
        # query as removed once all of its results are removed — confirm in
        # `Query.update_internal`.
        for query in batch:
            query.update_internal()

In [None]:
def build_batch(cutoff=10.5, d_emb=128, n_emb=100, seed=42):
    """Build a single-query `Batch` of random results for testing filters.

    Args:
        cutoff: norm threshold; results whose embedding norm is >= `cutoff`
            are reported as expected failures.
        d_emb: embedding dimension.
        n_emb: number of query results attached to the query.
        seed: seed for the random embeddings.

    Returns:
        `(batch, expected_failures)`, where `expected_failures` lists the
        `data['id']` of every result whose embedding norm is >= `cutoff`.
    """
    # A local RandomState produces exactly the same values as
    # `np.random.seed(seed); np.random.randn(...)` but does not clobber the
    # global NumPy RNG state as a side effect.
    rng = np.random.RandomState(seed)
    embeddings = rng.randn(n_emb + 1, d_emb)

    # the last row is the query embedding; the first n_emb rows are its results
    query = Query(embedding=embeddings[-1])
    results = [Item(embedding=embeddings[i], data={'id': i}) for i in range(n_emb)]
    query.add_query_results(results)
    batch = Batch(queries=[query])

    expected_failures = [item.data['id'] for item in results
                         if np.linalg.norm(item.embedding) >= cutoff]
    return batch, expected_failures

class NormFilter:
    """Example filter: an item is valid iff its embedding's L2 norm is below `cutoff`.

    The computed norm is attached to each response as `data['norm']`.
    """
    def __init__(self, cutoff=10.5):
        self.cutoff = cutoff

    def __call__(self, inputs: List[Item]) -> List[FilterResponse]:
        # `np.linalg.norm(..., axis=-1)` on an empty 1-D array returns a 0-d
        # scalar, which cannot be iterated, so handle empty input explicitly.
        if not inputs:
            return []

        embeddings = np.array([item.embedding for item in inputs])
        norms = np.linalg.norm(embeddings, axis=-1)
        return [FilterResponse(valid=norm < self.cutoff, data={'norm': norm}) for norm in norms]
    
filter_func = NormFilter()
filter_module = FilterModule(filter_func)

batch, fails = build_batch()
batch2 = filter_module(batch)

# only the items that passed the filter should remain after skipping removed ones
assert len(batch2.flatten_query_results(skip_removed=True)[1]) == len(batch[0])-len(fails)

# build_batch assigns data={'id': i} to the result at position i, so position i
# should be removed exactly when id i is an expected failure
fail_ids = set(fails)
for i in range(len(batch[0])):
    result = batch[0][i]
    if i in fail_ids:
        assert result.internal.removed
    else:
        assert not result.internal.removed

    assert result.internal.parent_id == batch[0].internal.id

# norms are never negative, so with cutoff=-1 every result fails the filter
# and the query itself is expected to end up removed
filter_func = NormFilter(cutoff=-1)
filter_module = FilterModule(filter_func)
batch, fails = build_batch(cutoff=-1)
batch2 = filter_module(batch)

assert batch2[0].internal.removed

In [None]:
#| export

class FilterPlugin:
    """Base class for user-defined filter callables.

    Subclasses implement `__call__`, which receives the query-result `Item`s
    gathered by `FilterModule` and must return one `FilterResponse` per input,
    in the same order.
    """
    def __call__(self, inputs: List[Item]) -> List[FilterResponse]:
        # The base class has no filtering logic; returning None here would only
        # fail later when the responses are zipped with the inputs.
        raise NotImplementedError