# Prune

> Functions and classes for pruning items in a batch

In [None]:
#| default_exp prune

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.core import Module, build_batch_from_embeddings
from emb_opt.schemas import Item, Query, Batch, PruneFunction, PruneResponse

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class PruneModule(Module):
    """Module that applies a prune function to the queries of a batch.

    The wrapped ``function`` is called on the queries gathered from a batch.
    Its responses are written back onto the corresponding batch items. Any
    ``data`` a response carries is merged into the item. Items whose response
    is not ``valid`` are flagged for removal in their ``_internal`` metadata.
    """
    def __init__(self,
                 function: PruneFunction,
                ):
        # NOTE(review): the base Module is given PruneResponse as the response
        # type, presumably to validate/convert `function`'s outputs — confirm in Module.
        super().__init__(PruneResponse, function)

    def gather_inputs(self, batch: Batch) -> Tuple[List[Tuple], List[Query]]:
        """Collect the queries to pass to the prune function.

        Args:
            batch: the batch to read queries from.

        Returns:
            ``(idxs, inputs)`` from ``batch.flatten_queries()``. ``idxs`` holds
            the ``(q_idx, r_idx)`` positions later used by ``scatter_results``
            to locate each item, and ``inputs`` holds the queries in the same order.
        """
        idxs, inputs = batch.flatten_queries()
        return idxs, inputs

    def scatter_results(self, batch: Batch, idxs: List[Tuple], results: List[PruneResponse]):
        """Write prune responses back onto the batch items they came from.

        Args:
            batch: the batch the inputs were gathered from.
            idxs: ``(q_idx, r_idx)`` positions, as returned by ``gather_inputs``.
            results: prune responses matched to ``idxs`` by position. Because
                pairing uses ``zip``, surplus entries on either side are ignored.

        Non-empty ``result.data`` is merged into the item's ``data``. An invalid
        result sets ``_internal['remove'] = True`` and a reason string. A valid
        result leaves any existing removal flags untouched.
        """
        for (q_idx, r_idx), result in zip(idxs, results):
            batch_item = batch.get_item(q_idx, r_idx)
            if result.data:
                batch_item.data.update(result.data)

            if not result.valid:
                # Items are only flagged here, not deleted; presumably a later
                # stage drops flagged items — confirm against the pipeline.
                internal = batch_item.data['_internal']
                internal['remove'] = True
                internal['remove_details'] = 'prune result invalid'

In [None]:
# Three single-value queries; only the first falls below the prune threshold.
batch = Batch(queries=[Query(embedding=[value]) for value in (0.1, 0.2, 0.3)])

def prune_func(queries):
    """Mark a query valid when the first coordinate of its embedding is >= 0.2."""
    threshold = 0.2
    return [PruneResponse(valid=query.embedding[0] >= threshold) for query in queries]

prune_module = PruneModule(prune_func)

batch = prune_module(batch)
# Only the 0.1 query should be flagged for removal.
assert [item.data['_internal'].get('remove', False) for item in batch] == [True, False, False]