# Update

> Update functions and classes

In [None]:
#| default_exp update

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.utils import compute_rl_grad, query_to_rl_inputs
from emb_opt.module import Module
from emb_opt.schemas import Item, Query, Batch, UpdateFunction

The Update step uses a set of queries and scored results to generate a new set of queries for the next iteration of the search.

Updates are denoted as `discrete` or `continuous`. `continuous` updates generate new query embeddings purely in embedding space (i.e., by averaging several embeddings). As a result, `continuous` update outputs do not have a specific `item` associated with them. `discrete` updates use a specific query result `Item` as the new query, keeping the `item` associated with it.

The update step is formalized by the `UpdateFunction` schema, which maps inputs `List[Query]` to outputs `List[Query]`. Note that the number of outputs can be different from the number of inputs.

The `UpdateModule` manages execution of an `UpdateFunction`. It gathers the queries in a batch, sends them to the `UpdateFunction`, validates the results, and returns them as a new `Batch`.

In [None]:
#| export

class UpdateModule(Module):
    '''
    UpdateModule - runs an `UpdateFunction` over the queries in a `Batch`

    The batch's queries are flattened into a list, passed to the wrapped
    `UpdateFunction`, run through `validate_schema` and returned as a
    brand new `Batch`. The number of output queries may differ from the
    number of input queries.
    '''
    def __init__(self, function: UpdateFunction):
        # `Query` is handed to `Module` as the schema (presumably the one
        # `validate_schema` checks outputs against)
        super().__init__(Query, function)

    def gather_inputs(self, batch: Batch) -> Tuple[List[Tuple], List[Query]]:
        '''
        Returns the `(idxs, queries)` pair produced by `batch.flatten_queries()`
        '''
        idxs, inputs = batch.flatten_queries()
        return (idxs, inputs)

    def __call__(self, batch: Batch) -> Batch:
        # the indices are not needed: outputs are not written back into the
        # input batch, they form a new batch of their own
        _, inputs = self.gather_inputs(batch)
        new_queries = self.function(inputs)
        new_queries = self.validate_schema(new_queries)
        return Batch(queries=new_queries)

In [None]:
def passthrough_update_test(queries):
    # identity update: the input queries are handed straight back unchanged
    outputs = queries
    return outputs

batch = Batch(queries=[Query.from_minimal(embedding=[value]) for value in (0.1, 0.2, 0.3)])

update_module = UpdateModule(passthrough_update_test)

batch = update_module(batch)

assert isinstance(batch, Batch)
assert isinstance(batch[0], Query)

In [None]:
def continuous_update_test(queries):
    # continuous update: every new query is its parent's embedding scaled by 2
    return [Query.from_parent_query(embedding=[i*2 for i in query.embedding], parent_query=query)
            for query in queries]

batch = Batch(queries=[
                        Query.from_minimal(embedding=[0.1]),
                        Query.from_minimal(embedding=[0.2]),
                        Query.from_minimal(embedding=[0.3]),
                    ])

# tag each input query with a distinct collection id (plain loop: this is a
# side effect, not a list we want to build)
for idx, query in enumerate(batch.queries):
    query.update_internal(collection_id=idx)

update_module = UpdateModule(continuous_update_test)

batch2 = update_module(batch)

assert isinstance(batch2, Batch)
assert isinstance(batch2[0], Query)
# one output per input, so the pairwise comparison below covers every query
assert len(batch2) == len(batch)
# new queries should inherit their parent's collection id
assert all(new.internal.collection_id == old.internal.collection_id
           for new, old in zip(batch2, batch))

In [None]:
def discrete_update_test(queries):
    # discrete update: the first result of every query becomes a new query
    outputs = []
    for query in queries:
        outputs.append(Query.from_item(query[0]))
    return outputs

queries = []
for i in range(3):
    q = Query.from_minimal(embedding=[i*0.1])
    q.update_internal(collection_id=i)
    r = Item.from_minimal(embedding=[i*2*0.1])
    q.add_query_results([r])
    queries.append(q)

batch = Batch(queries=queries)
update_module = UpdateModule(discrete_update_test)
batch2 = update_module(batch)

assert isinstance(batch2, Batch)
assert isinstance(batch2[0], Query)

for i, new_query in enumerate(batch2):
    old_query = batch[i]
    source_item = old_query[0]
    assert new_query.internal.parent_id == source_item.internal.parent_id
    assert new_query.data['_source_item_id'] == source_item.id
    assert new_query.internal.collection_id == old_query.internal.collection_id

In [None]:
#| export

class UpdatePlugin():
    '''
    UpdatePlugin - documentation for plugin functions to `UpdateFunction`

    Any callable that maps `List[Query]` to `List[Query]` is a valid
    `UpdateFunction`. It receives `Query` objects as input and may return
    either `Query` objects or valid json dictionaries matching the `Query`
    schema. The output list does not need to be the same length as the
    input list.

    Item schema:

    `{
        'id' : Optional[Union[str, int]]
        'item' : Optional[Any],
        'embedding' : List[float],
        'score' : float,
        'data' : Optional[Dict],
    }`


    Query schema:

    `{
        'item' : Optional[Any],
        'embedding' : List[float],
        'data' : Optional[Dict],
        'query_results': List[Item]
    }`

    Input schema:

    `List[Query]`

    Output schema:

    `List[Query]`

    '''
    def __call__(self, inputs: List[Query]) -> List[Query]:
        # interface placeholder only; concrete plugins override this
        return None

In [None]:
#| export

class UpdatePluginGradientWrapper():
    '''
    UpdatePluginGradientWrapper - wraps an `UpdateFunction` and attaches a
    gradient estimate to every query it produces.

    For each output query, the input query whose `id` equals the output's
    `internal.parent_id` is looked up. When it is found (and truthy), the
    gradient is computed by `compute_rl_grad` from the new query's embedding
    and the parent's result embeddings and scores (as extracted by
    `query_to_rl_inputs`). Otherwise an all-zero gradient is used. The
    result is stored in `query.data['_score_grad']`.

    This wrapper integrates with `DataPluginGradWrapper`, which
    allows us to create new query vectors based on the gradient
    '''
    def __init__(self, 
                 function: UpdateFunction,                          # `UpdateFunction` to wrap
                 distance_penalty: float=0,                         # RL grad distance penalty
                 max_norm: Optional[float] = None,                  # max grad norm
                 norm_type: Optional[Union[float, int, str]] = 2.0  # grad norm type
                ):
        self.function, self.distance_penalty = function, distance_penalty
        self.max_norm, self.norm_type = max_norm, norm_type

    def _grad_for(self, query: Query, parent: Optional[Query]) -> np.ndarray:
        # gradient of `query`'s embedding w.r.t. its parent's scored results;
        # zeros when the parent is missing (or evaluates as falsy)
        if not parent:
            return np.zeros(np.array(query.embedding).shape)
        _, result_embeddings, scores = query_to_rl_inputs(parent)
        return compute_rl_grad(np.array(query.embedding),
                               result_embeddings,
                               scores,
                               distance_penalty=self.distance_penalty,
                               max_norm=self.max_norm,
                               norm_type=self.norm_type,
                               score_grad=True)

    def __call__(self, inputs: List[Query]) -> List[Query]:
        outputs = self.function(inputs)

        parents_by_id = {parent.id: parent for parent in inputs}

        for new_query in outputs:
            parent = parents_by_id.get(new_query.internal.parent_id)
            new_query.data['_score_grad'] = self._grad_for(new_query, parent)

        return outputs

In [None]:
#| export

class TopKDiscreteUpdate(UpdatePlugin):
    '''
    TopKDiscreteUpdate - discrete update that 
    generates `k` new queries from the top `k` 
    scoring items in each input query

    Only the results returned by `query.valid_results()` are considered, so
    a query with fewer than `k` valid results yields fewer than `k` new
    queries. Each new query is built with `Query.from_item`, ordered from
    highest to lowest score within each input query.
    '''
    def __init__(self, 
                 k: int # top k items to return as new queries
                ):
        self.k = k

    def __call__(self, inputs: List[Query]) -> List[Query]:
        top_items = []

        for query in inputs:
            # materialize once: the top-k indices below refer to positions in
            # THIS list, which can differ from positions in the full result
            # list when some results are invalid
            valid_results = list(query.valid_results())
            result_scores = np.array([item.score for item in valid_results])
            # argsort is ascending: reverse for best-first, then keep k
            topk_idxs = result_scores.argsort()[::-1][:self.k]
            top_items += [valid_results[idx] for idx in topk_idxs]

        return [Query.from_item(item) for item in top_items]

In [None]:
q1 = Query.from_minimal(embedding=[0.1])
q1.add_query_results([
    Item(id=None, item='1', embedding=[0.11], score=-10, data=None),
    Item(id=None, item='2', embedding=[0.12], score=6, data=None),
    Item(id=None, item='3', embedding=[0.12], score=1, data=None),
])

q2 = Query.from_minimal(embedding=[0.2])
q2.add_query_results([
    Item(id=None, item='4', embedding=[0.21], score=4, data=None),
    Item(id=None, item='5', embedding=[0.22], score=5, data=None),
    Item(id=None, item='6', embedding=[0.12], score=2, data=None),
])

batch = Batch(queries=[q1, q2])

update_func = TopKDiscreteUpdate(k=2)
update_module = UpdateModule(update_func)
batch2 = update_module(batch)

# top 2 per query, best first: q1 -> '2' (6), '3' (1); q2 -> '5' (5), '4' (4)
assert [i.item for i in batch2] == ['2', '3', '5', '4']

update_func2 = UpdatePluginGradientWrapper(update_func)
update_module2 = UpdateModule(update_func2)
batch3 = update_module2(batch)

# the gradient wrapper must not change which queries are produced...
assert [i.item for i in batch3] == ['2', '3', '5', '4']
# ...it only attaches a gradient estimate to each of them
assert all('_score_grad' in i.data for i in batch3)

In [None]:
#| export

class TopKContinuousUpdate(UpdatePlugin):
    '''
    TopKContinuousUpdate - continuous update that 
    generates 1 new query by averaging the top `k` 
    scoring item embeddings for each input query

    Only the results returned by `query.valid_results()` are considered.
    An input query with no valid results produces no new query (there is
    nothing to average).
    '''
    def __init__(self, 
                 k: int # top k items to average
                ):
        self.k = k

    def __call__(self, inputs: List[Query]) -> List[Query]:
        outputs = []

        for query in inputs:
            # materialize once: the top-k indices below refer to positions in
            # THIS list, not in the full (possibly partly invalid) result list
            valid_results = list(query.valid_results())
            if not valid_results:
                # averaging an empty set would give a NaN embedding; skip it
                continue

            result_scores = np.array([item.score for item in valid_results])
            # argsort is ascending: reverse for best-first, then keep k
            topk_idxs = result_scores.argsort()[::-1][:self.k]
            topk_embs = np.array([valid_results[idx].embedding for idx in topk_idxs])

            new_embedding = np.average(topk_embs, 0)
            # plain list of floats, matching the `List[float]` embedding schema
            output = Query.from_parent_query(embedding=new_embedding.tolist(), parent_query=query)
            outputs.append(output)
        return outputs

In [None]:
q1 = Query.from_minimal(embedding=[0.1])
q1.add_query_results([
    Item(id=None, item='1', embedding=[0.1], score=-10, data=None),
    Item(id=None, item='2', embedding=[0.2], score=6, data=None),
])

q2 = Query.from_minimal(embedding=[0.2])
q2.add_query_results([
    Item(id=None, item='4', embedding=[0.2], score=4, data=None),
    Item(id=None, item='5', embedding=[0.3], score=5, data=None),
])

batch = Batch(queries=[q1, q2])

# k=2 averages both results of each query
update_func = TopKContinuousUpdate(k=2)
update_module = UpdateModule(update_func)
batch2 = update_module(batch)

assert np.allclose([i.embedding for i in batch2], [[0.15], [0.25]])

# k=1 just copies the best-scoring result's embedding
update_func = TopKContinuousUpdate(k=1)
update_module = UpdateModule(update_func)
batch2 = update_module(batch)

assert np.allclose([i.embedding for i in batch2], [[0.2], [0.3]])

update_func2 = UpdatePluginGradientWrapper(update_func)
update_module2 = UpdateModule(update_func2)
batch3 = update_module2(batch)

# the gradient wrapper keeps the wrapped update's embeddings unchanged...
assert np.allclose([i.embedding for i in batch3], [[0.2], [0.3]])
# ...and only attaches a gradient estimate to each new query
assert all('_score_grad' in i.data for i in batch3)

In [None]:
#| export

class RLUpdate(UpdatePlugin):
    '''
    RLUpdate - uses reinforcement learning to update queries

    To compute the gradient with RL:
    1. compute advantages by whitening scores
        1. `advantage[i] = (scores[i] - scores.mean()) / scores.std()`
    2. compute advantage loss
        1. `advantage_loss[i] = advantage[i] * (query_embedding - result_embedding[i])**2`
    3. compute distance loss
        1. `distance_loss[i] = distance_penalty * (query_embedding - result_embedding[i])**2`
    4. sum loss terms
        1. `loss[i] = advantage_loss[i] + distance_loss[i]`
    5. compute the gradient

    This gives a closed-form calculation of the gradient as:

    `grad[i] = 2 * (advantage[i] + distance_penalty) * (query_embedding - result_embedding[i])`

    The gradient itself is computed by `compute_rl_grad`. For every input
    query, one new query is produced per learning rate `lr` in `lrs`, with
    embedding `query_embedding - lr * grad`. The parent embedding, the
    learning rate and the gradient are recorded in
    `data['rl_update_details']` of each new query.
    '''
    def __init__(self,
                 lrs: Union[float, List[float], np.ndarray],     # learning rate(s); one new query per lr
                 distance_penalty: float,                        # distance penalty coefficient
                 max_norm: Optional[float]=None,                 # optional max grad norm for clipping
                 norm_type: Optional[Union[float, int, str]]=2.0 # norm type
                ):
        # atleast_1d so a single scalar learning rate also works with the
        # `self.lrs[:, None]` broadcast in `__call__`
        self.lrs = np.atleast_1d(np.array(lrs))
        self.distance_penalty = distance_penalty
        self.max_norm = max_norm
        self.norm_type = norm_type

    def compute_grad(self, query: Query):
        '''
        Gradient for `query` from its results and scores (via `compute_rl_grad`)
        '''
        query_embedding, result_embeddings, scores = query_to_rl_inputs(query)

        # keyword arguments guard against silently mismatching positional order
        grad = compute_rl_grad(query_embedding, result_embeddings, scores,
                               distance_penalty=self.distance_penalty,
                               max_norm=self.max_norm,
                               norm_type=self.norm_type)

        return grad

    def __call__(self, queries: List[Query]) -> List[Query]:

        results = []

        for query in queries:
            grad = self.compute_grad(query)
            query_embedding = np.array(query.embedding)
            # (1, d) - (1, d) * (k, 1) -> (k, d): one row per learning rate
            new_embeddings = query_embedding[None] - (grad[None] * self.lrs[:, None])

            for lr, new_embedding in zip(self.lrs, new_embeddings):
                # internal invariant: broadcasting must not change the embedding shape
                assert new_embedding.shape == query_embedding.shape

                new_query = Query.from_parent_query(embedding=new_embedding.tolist(), 
                                                    parent_query=query)

                # plain Python values, like the `.tolist()` entries alongside
                new_query.data['rl_update_details'] = {
                                                        'parent_embedding' : query_embedding.tolist(),
                                                        'lr' : float(lr),
                                                        'grad' : grad.tolist(),
                                                    }

                results.append(new_query)

        return results

In [None]:
lrs = np.array([1e-2, 1e-1, 1e0, 1e1])
dp = 0.1

update_function = RLUpdate(lrs, dp, max_norm=1.)
update_module = UpdateModule(update_function)

# three queries, each with one high- and one low-scoring result
queries = []
for i in range(1,4):
    q = Query.from_minimal(embedding=[i*0.1, i*0.2, i*0.3, 0.1, 0.1])
    q.update_internal(collection_id=i)
    r1 = Item.from_minimal(embedding=[2*i*0.1, 2*i*0.2, 2*i*0.3, -0.1, -0.1], score=i*1.5)
    r2 = Item.from_minimal(embedding=[-2*i*0.1, 2*i*0.2, -2*i*0.3, -0.1, -0.1], score=i*.5)
    q.add_query_results([r1, r2])
    queries.append(q)

batch = Batch(queries=queries)
batch2 = update_module(batch)

# one new query per (input query, learning rate) pair
assert len(batch2) == len(batch) * len(lrs)