# Prune

> Prune functions and classes

In [None]:
#| default_exp prune

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.module import Module
from emb_opt.schemas import Item, Query, Batch, PruneFunction, PruneResponse

The Prune step optionally removes queries prior to the Update step. This gives control over the total number of queries when the Update step generates multiple output queries for each input.

The prune step is formalized by the `PruneFunction` schema, which maps inputs `List[Query]` to outputs `List[PruneResponse]`.

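The `PruneModule` manages execution of a `PruneFunction`. The `PruneModule` gathers valid queries, sends them to the `PruneFunction`, and processes the results: any `data` returned in a response is merged into the corresponding query's `data`, and queries whose response has `valid=False` are marked as removed.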

In [None]:
#| export

class PruneModule(Module):
    '''
    PruneModule - runs a `PruneFunction` over the queries of a `Batch`
    
    The flattened queries of the batch are sent to the function. Each 
    `PruneResponse` is then written back to the batch entry it was 
    computed for: response `data` (if any) is merged into the entry's 
    `data`, and entries with a `valid=False` response are flagged as removed
    '''
    def __init__(self,
                 function: PruneFunction # prune function
                ):
        super().__init__(PruneResponse, function)

    def gather_inputs(self, batch: Batch) -> (List[Tuple], List[Query]):
        'returns the index tuples and queries produced by `batch.flatten_queries`'
        positions, queries = batch.flatten_queries()
        return (positions, queries)

    def scatter_results(self, batch: Batch, idxs: List[Tuple], results: List[PruneResponse]):
        'writes each response back to the batch entry addressed by its index tuple'
        for position, response in zip(idxs, results):
            query_idx, result_idx = position
            target = batch.get_item(query_idx, result_idx)

            # merge any extra data returned by the prune function
            if response.data:
                target.data.update(response.data)

            if response.valid:
                continue

            target.update_internal(removed=True, removal_reason='prune response invalid')

In [None]:
# three queries whose single embedding value increases from 0.1 to 0.3
batch = Batch(queries=[Query.from_minimal(embedding=[value]) for value in (0.1, 0.2, 0.3)])

def prune_func(queries):
    # keep only queries whose first embedding value is at least 0.2
    responses = []
    for query in queries:
        keep = query.embedding[0]>=0.2
        responses.append(PruneResponse(valid=keep, data=None))
    return responses

prune_module = PruneModule(prune_func)
batch = prune_module(batch)

# only the first query (0.1 < 0.2) is flagged as removed
removed_flags = [query.internal.removed for query in batch]
assert removed_flags == [True, False, False]

In [None]:
#| export

class PrunePlugin:
    '''
    PrunePlugin - reference documentation for functions plugged in 
    as a `PruneFunction`
    
    Any callable that maps `List[Query]` to `List[PruneResponse]` is a 
    valid `PruneFunction`. Inputs are passed as `Query` objects. Outputs 
    may be returned either as `PruneResponse` objects or as json 
    dictionaries that match the `PruneResponse` schema
    
    The Prune step runs after scoring, so every result `Item` attached 
    to an input query will already have a score assigned
    
    Input schema:
    
    `List[Query]`
    
    Query schema:
    
    `{
        'item' : Optional[Any],
        'embedding' : List[float],
        'data' : Optional[Dict],
        'query_results': List[Item]
    }`
    
    Item schema:
    
    `{
        'id' : Optional[Union[str, int]],
        'item' : Optional[Any],
        'embedding' : List[float],
        'score' : float,
        'data' : Optional[Dict],
    }`
    
    Output schema:
    
    `List[PruneResponse]`
    
    PruneResponse schema:
    
    `{
        'valid' : bool,
        'data' : Optional[Dict],
    }`
    
    '''
    def __call__(self, inputs: List[Query]) -> List[PruneResponse]:
        # documentation stub: there is nothing to compute here, so `None` is returned
        return None

In [None]:
#| export

class TopKPrune():
    '''
    TopKPrune - keeps the top `k` best queries in each group 
    by aggregated score
    
    queries are first grouped by `group_by`
    * if `group_by=None`, all queries are considered the same group (global pruning)
    * if `group_by='parent_id'`, queries are grouped by parent query id
    * if `group_by='collection_id'`, queries are grouped by collection id
    
    queries are then assigned a score based on aggregating query result scores
    * if `score_agg='mean'`, each `Query` is scored by the average score of all `Item` results
    * if `score_agg='max'`, each `Query` is scored by the max scoring `Item` result
    
    a `Query` whose `valid_results()` is empty is scored `-inf`, and `nan` 
    scores are ranked as `-inf`, so neither can displace a scored query
    
    each `PruneResponse` carries the aggregated score in `data` under 
    the key `'{score_agg}_score'`
    '''
    def __init__(self,
                 k: int, # number of queries kept per group
                 score_agg: str='mean', # ['mean', 'max']
                 group_by: Optional[str]='collection_id' # [None, 'collection_id', 'parent_id']
                ):
        self.k = k
        self.score_agg = score_agg
        self.group_by = group_by
        
        assert self.score_agg in ['mean', 'max'], \
            f"score_agg must be 'mean' or 'max', got {self.score_agg!r}"
        assert self.group_by in [None, 'collection_id', 'parent_id'], \
            f"group_by must be None, 'collection_id' or 'parent_id', got {self.group_by!r}"
        
    def agg_scores(self, query: Query) -> float:
        'aggregates the scores of the valid results of `query` into a single float'
        result_scores = [i.score for i in query.valid_results()]
        
        if not result_scores:
            # the mean of an empty array is nan and the max of an empty array raises,
            # so a query without results gets the worst possible score instead
            return float('-inf')
        
        result_scores = np.array(result_scores)
        if self.score_agg == 'mean':
            agg_score = result_scores.mean()
        elif self.score_agg == 'max':
            agg_score = result_scores.max()
        else:
            raise ValueError(f'unknown score_agg {self.score_agg!r}')
            
        # plain python float rather than a numpy scalar
        return float(agg_score)
    
    def get_group_key(self, query: Query):
        'returns the grouping key of `query` (`None` when `group_by=None`)'
        key = None
        if self.group_by == 'collection_id':
            key = query.internal.collection_id
        elif self.group_by == 'parent_id':
            key = query.internal.parent_id
        return key
        
    def prune_queries(self, queries: List[Query]) -> List[PruneResponse]:
        'marks the `k` best scoring queries of a single group as valid'
        scores = np.array([self.agg_scores(query) for query in queries], dtype=float)
        
        # argsort places nan last, so after the reversal below a nan score 
        # would be ranked first - rank nan as the worst score instead
        rank_scores = np.where(np.isnan(scores), -np.inf, scores)
        topk_idxs = set(rank_scores.argsort()[::-1][:self.k].tolist())
        
        outputs = [PruneResponse(valid=(i in topk_idxs), 
                                 data={f'{self.score_agg}_score':float(score)})
                  for i, score in enumerate(scores)]
        
        return outputs
    
    def group_queries(self, queries: List[Query]) -> Tuple[dict, dict]:
        'groups queries by key, returning `(idx_groups, query_groups)`'
        query_groups = defaultdict(list)
        idx_groups = defaultdict(list)
        
        for query_idx, query in enumerate(queries):
            key = self.get_group_key(query)
            query_groups[key].append(query)
            # remember the original position so results can be scattered back in order
            idx_groups[key].append(query_idx)
            
        return idx_groups, query_groups
    
    def scatter_queries(self, 
                        idx_groups: dict, 
                        query_groups: dict, 
                        total_queries: int)  -> List[PruneResponse]:
        'prunes each group and writes the responses back at the original query positions'
        outputs = [None for i in range(total_queries)]
        
        for key, query_list in query_groups.items():
            prune_results = self.prune_queries(query_list)
            
            for scatter_idx, result in zip(idx_groups[key], prune_results):
                outputs[scatter_idx] = result
                
        return outputs
    
    def __call__(self, queries: List[Query]) -> List[PruneResponse]:
        idx_groups, query_groups = self.group_queries(queries)
        results = self.scatter_queries(idx_groups, query_groups, len(queries))
        return results

In [None]:
# (query embedding, collection id, [(result embedding, result score), ...])
query_specs = [
    (0.1, 0, [(0.11, -10), (0.12, 6)]),
    (0.2, 0, [(0.21, 4), (0.22, 5)]),
    (0.3, 1, [(0.31, 7), (0.32, 8)]),
]

queries = []
for query_emb, collection, scored_results in query_specs:
    query = Query.from_minimal(embedding=[query_emb])
    query.update_internal(collection_id=collection)
    query.add_query_results([Item.from_minimal(embedding=[emb], score=score)
                             for emb, score in scored_results])
    queries.append(query)

q1, q2, q3 = queries

# global pruning: the mean scores are -2.0, 4.5 and 7.5, so only the last query is kept
prune_func = TopKPrune(k=1, score_agg='mean', group_by=None)
responses = prune_func(queries)

assert [response.valid for response in responses] == [False, False, True]

In [None]:
# (query embedding, collection id, [(result embedding, result score), ...])
query_specs = [
    (0.1, 0, [(0.11, -10), (0.12, 6)]),
    (0.2, 0, [(0.21, 4), (0.22, 5)]),
    (0.3, 1, [(0.31, 7), (0.32, 8)]),
]

queries = []
for query_emb, collection, scored_results in query_specs:
    query = Query.from_minimal(embedding=[query_emb])
    query.update_internal(collection_id=collection)
    query.add_query_results([Item.from_minimal(embedding=[emb], score=score)
                             for emb, score in scored_results])
    queries.append(query)

q1, q2, q3 = queries

# per-collection pruning by max score: collection 0 keeps the first query (6 > 5),
# collection 1 keeps its only query
prune_func = TopKPrune(k=1, score_agg='max', group_by='collection_id')
max_responses = prune_func(queries)

assert [response.valid for response in max_responses] == [True, False, True]

# per-collection pruning by mean score: collection 0 now keeps the second query (4.5 > -2.0)
prune_func = TopKPrune(k=1, score_agg='mean', group_by='collection_id')
mean_responses = prune_func(queries)

assert [response.valid for response in mean_responses] == [False, True, True]