# Prune

> Prune functions and classes

In [None]:
#| default_exp prune

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.module import Module
from emb_opt.schemas import Item, Query, Batch, PruneFunction, PruneResponse

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class PruneModule(Module):
    """`Module` wrapper around a `PruneFunction`.

    The wrapped function receives the flattened queries of a `Batch` and
    returns one `PruneResponse` per query. Any response data is merged into
    the matching batch item, and items whose response is not valid are
    flagged as removed.
    """
    def __init__(self,
                 function: PruneFunction,
                ):
        # `PruneResponse` is handed to the base `Module` together with the function
        # (presumably as the expected output schema - confirm against `Module`).
        super().__init__(PruneResponse, function)

    def gather_inputs(self, batch: Batch) -> (List[Tuple], List[Query]):
        """Return the index tuples and inputs produced by `batch.flatten_queries()`."""
        flat_idxs, flat_queries = batch.flatten_queries()
        return flat_idxs, flat_queries

    def scatter_results(self, batch: Batch, idxs: List[Tuple], results: List[PruneResponse]):
        """Write each prune response back onto the batch item it was computed for.

        `idxs` and `results` are consumed in lockstep; every entry of `idxs`
        is a two-element index passed to `batch.get_item`.
        """
        for item_idx, response in zip(idxs, results):
            query_idx, result_idx = item_idx
            target = batch.get_item(query_idx, result_idx)

            # Merge any extra data first, regardless of validity
            if response.data:
                target.data.update(response.data)

            if response.valid:
                continue

            target.update_internal(removed=True, removal_reason='prune response invalid')

In [None]:
# Three single-value queries; only embeddings >= 0.2 should survive pruning
batch = Batch(queries=[Query(embedding=[value]) for value in (0.1, 0.2, 0.3)])

def prune_func(queries):
    """Mark a query valid when the first value of its embedding is at least 0.2."""
    responses = []
    for query in queries:
        responses.append(PruneResponse(valid=query.embedding[0]>=0.2))
    return responses

prune_module = PruneModule(prune_func)
batch = prune_module(batch)

# The first query (0.1) is below the threshold, so it is the only one flagged as removed
removed_flags = [query.internal.removed for query in batch]
assert removed_flags == [True, False, False]

In [None]:
#| export

class PrunePlugin():
    """Call-signature template for prune functions.

    A prune plugin is called with a list of `Query` objects and is annotated
    as returning a list of `PruneResponse` objects. This base implementation
    does nothing.
    """
    def __call__(self, inputs: List[Query]) -> List[PruneResponse]:
        """Placeholder body: performs no pruning and returns `None`."""
        return None

In [None]:
#| export

class TopKGlobalPrune():
    """Prune function that keeps the `k` best-scoring queries of the input list.

    Each query is reduced to a single score by aggregating the scores of its
    valid results (`query.valid_results()`), using either the mean or the max.
    The `k` queries with the highest aggregate are marked valid and all others
    are marked invalid.

    Parameters
    ----------
    k : number of queries to keep
    agg : aggregation applied to the result scores, `'mean'` or `'max'`
    """
    def __init__(self,
                 k: int,
                 agg: str='mean'
                ):
        self.k = k
        self.agg = agg
        assert self.agg in ['mean', 'max'], f"agg must be 'mean' or 'max', got {self.agg!r}"

    def prune_queries(self, queries: List[Query]) -> List[PruneResponse]:
        """Score `queries` and return one `PruneResponse` per query, in input order.

        Each response stores the aggregate score under the key `'{agg}_score'`.
        A query with no valid results is given a score of `-inf` so it ranks
        below every scored query.
        """
        scores = []
        for query in queries:
            result_scores = np.array([i.score for i in query.valid_results()])
            if result_scores.size == 0:
                # Aggregating an empty array yields NaN for mean (which argsort places
                # last, i.e. first after the reversal below) and raises for max.
                # Use -inf so a query without results can never outrank a scored one.
                scores.append(-np.inf)
            elif self.agg == 'mean':
                scores.append(result_scores.mean())
            else:
                scores.append(result_scores.max())

        scores = np.array(scores)
        # argsort is ascending; reverse it to take the k highest scores
        topk_idxs = set(scores.argsort()[::-1][:self.k])

        outputs = [PruneResponse(valid=(i in topk_idxs), data={f'{self.agg}_score':score})
                  for i, score in enumerate(scores)]

        return outputs

    def __call__(self, queries: List[Query]) -> List[PruneResponse]:
        """Apply the top-k prune across all `queries` at once."""
        return self.prune_queries(queries)

In [None]:
# Fixture spec: (query embedding, collection id, [(result embedding, result score), ...])
query_specs = [
    (0.1, 0, [(0.11, -10), (0.12, 6)]),
    (0.2, 0, [(0.21, 4), (0.22, 5)]),
    (0.3, 1, [(0.31, 7), (0.32, 8)]),
]

queries = []
for query_emb, collection_id, result_specs in query_specs:
    query = Query(embedding=[query_emb])
    query.update_internal(collection_id=collection_id)
    query.add_query_results([Item(embedding=[emb], score=score) for emb, score in result_specs])
    queries.append(query)

q1, q2, q3 = queries

prune_func = TopKGlobalPrune(k=1, agg='mean')

# Mean scores are -2, 4.5 and 7.5, so a global top-1 keeps only the third query
valid_flags = [response.valid for response in prune_func(queries)]
assert valid_flags == [False, False, True]

In [None]:
#| export

class TopKPruneLocal(TopKGlobalPrune):
    """Top-k prune applied separately within each collection.

    Queries are grouped by `query.internal.collection_id`, and the inherited
    `prune_queries` is run on each group independently, so up to `k` queries
    are kept per collection rather than `k` overall.
    """
    def __call__(self, queries: List[Query]) -> List[PruneResponse]:
        """Return one `PruneResponse` per query, aligned with the order of `queries`."""
        query_groups = defaultdict(list)
        idx_groups = defaultdict(list)

        # Remember each query's original position so results can be scattered back
        for i, query in enumerate(queries):
            collection_id = query.internal.collection_id
            query_groups[collection_id].append(query)
            idx_groups[collection_id].append(i)

        outputs = [None] * len(queries)

        for collection_id, query_list in query_groups.items():
            prune_results = self.prune_queries(query_list)

            # Results come back in group order; map them to their input positions
            for output_idx, result in zip(idx_groups[collection_id], prune_results):
                outputs[output_idx] = result

        return outputs

In [None]:
# Fixture spec: (query embedding, collection id, [(result embedding, result score), ...])
query_specs = [
    (0.1, 0, [(0.11, -10), (0.12, 6)]),
    (0.2, 0, [(0.21, 4), (0.22, 5)]),
    (0.3, 1, [(0.31, 7), (0.32, 8)]),
]

queries = []
for query_emb, collection_id, result_specs in query_specs:
    query = Query(embedding=[query_emb])
    query.update_internal(collection_id=collection_id)
    query.add_query_results([Item(embedding=[emb], score=score) for emb, score in result_specs])
    queries.append(query)

q1, q2, q3 = queries

# Max scores: collection 0 -> 6 vs 5 (first query wins), collection 1 -> 8
prune_func = TopKPruneLocal(k=1, agg='max')
max_flags = [response.valid for response in prune_func(queries)]
assert max_flags == [True, False, True]

# Mean scores: collection 0 -> -2 vs 4.5 (second query wins), collection 1 -> 7.5
prune_func = TopKPruneLocal(k=1, agg='mean')
mean_flags = [response.valid for response in prune_func(queries)]
assert mean_flags == [False, True, True]