# Update

> Update modules and functions that generate the next round of queries from scored query results

In [None]:
#| default_exp update

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.utils import whiten
from emb_opt.module import Module
from emb_opt.schemas import Item, Query, Batch, UpdateFunction

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class UpdateModule(Module):
    """Turn a `Batch` of scored queries into a `Batch` of new queries.

    Wraps an `UpdateFunction` that receives the batch's queries as one flat
    list and returns a list of new `Query` objects. The returned queries are
    passed through `validate_schema` and packed into a fresh `Batch`.

    Args:
        function: callable mapping `List[Query]` -> `List[Query]`.
    """
    def __init__(self, function: UpdateFunction):
        # `Query` is handed to the base class together with the function;
        # presumably it is the schema used by `validate_schema` (defined in
        # `Module`, outside this file) — TODO confirm.
        super().__init__(Query, function)

    def gather_inputs(self, batch: Batch) -> Tuple[List[Tuple], List[Query]]:
        """Flatten `batch` into `(idxs, queries)` using `Batch.flatten_queries`."""
        idxs, inputs = batch.flatten_queries()
        return (idxs, inputs)

    def __call__(self, batch: Batch) -> Batch:
        """Apply the update function to every query in `batch`.

        Returns a new `Batch` holding only the queries produced by the update
        function; the input queries are not carried over.
        """
        # The flattened indices are not needed: update functions may emit more
        # or fewer queries than they receive (e.g. one per learning rate, or
        # the top-k results per query), so outputs do not map 1:1 onto inputs.
        _, inputs = self.gather_inputs(batch)
        new_queries = self.function(inputs)
        new_queries = self.validate_schema(new_queries)
        return Batch(queries=new_queries)

In [None]:
def passthrough_update_test(queries):
    # Identity update: hand back exactly the queries that came in.
    return queries

# Three one-dimensional queries to push through the module.
batch = Batch(queries=[Query.from_minimal(embedding=[value]) for value in (0.1, 0.2, 0.3)])

update_module = UpdateModule(passthrough_update_test)

batch = update_module(batch)

# Whatever the update does, the module must hand back a Batch of Query objects.
assert isinstance(batch, Batch)
assert isinstance(batch[0], Query)

In [None]:
def continuous_update_test(queries):
    # Continuous update: every query spawns one child whose embedding is the
    # parent's embedding scaled by 2.
    outputs = []
    for query in queries:
        new_query = Query.from_parent_query(embedding=[x*2 for x in query.embedding], parent_query=query)
        outputs.append(new_query)
    return outputs

batch = Batch(queries=[
                        Query.from_minimal(embedding=[0.1]),
                        Query.from_minimal(embedding=[0.2]),
                        Query.from_minimal(embedding=[0.3]),
                    ])

# Tag each input query with its position so inheritance can be checked below.
# A plain loop rather than a list comprehension: this runs only for its side effect.
for i, query in enumerate(batch.queries):
    query.update_internal(collection_id=i)

update_module = UpdateModule(continuous_update_test)

batch2 = update_module(batch)

assert isinstance(batch2, Batch)
assert isinstance(batch2[0], Query)

# One child per parent, and each child inherits its parent's collection_id.
assert len(batch2) == len(batch)
assert all(child.internal.collection_id == parent.internal.collection_id
           for child, parent in zip(batch2, batch))

In [None]:
def discrete_update_test(queries):
    # Discrete update: promote the first result of every query to a new query.
    return [Query.from_item(query[0]) for query in queries]

queries = []
for i in range(3):
    # Each query gets a collection id and exactly one result item.
    parent = Query.from_minimal(embedding=[i*0.1])
    parent.update_internal(collection_id=i)
    parent.add_query_results([Item.from_minimal(embedding=[i*2*0.1])])
    queries.append(parent)

batch = Batch(queries=queries)
update_module = UpdateModule(discrete_update_test)
batch2 = update_module(batch)

assert isinstance(batch2, Batch)
assert isinstance(batch2[0], Query)

# Each new query must trace back to its source item and keep the collection id.
for i in range(len(batch2)):
    child = batch2[i]
    source_query = batch[i]
    assert child.internal.parent_id == source_query[0].internal.parent_id
    assert child.data['_source_item_id'] == source_query[0].id
    assert child.internal.collection_id == source_query.internal.collection_id

In [None]:
#| export

class UpdatePlugin():
    """Base interface for update functions used with `UpdateModule`.

    Subclasses implement `__call__`, which receives the flat list of queries
    from a batch and returns the list of new queries to use next.
    """
    def __call__(self, inputs: List[Query]) -> List[Query]:
        # The base class has no update rule of its own. Fail loudly rather than
        # silently returning None, which `UpdateModule` would otherwise pass
        # straight on to schema validation.
        raise NotImplementedError(f"{type(self).__name__} must implement __call__")

In [None]:
#| export

class TopKDiscreteUpdate(UpdatePlugin):
    """Discrete update: turn each query's `k` best-scoring results into new queries.

    For every input query, its valid results are ranked by `score` (highest
    first) and the top `k` are converted with `Query.from_item`. Outputs are
    grouped by input query, in descending score order within each group.

    Args:
        k: maximum number of results promoted per query.
    """
    def __init__(self, k: int):
        self.k = k

    def __call__(self, inputs: List[Query]) -> List[Query]:
        outputs = []

        for query in inputs:
            # Rank and select from the *same* filtered list. `query[i]` may
            # index the unfiltered results, which would misalign scores and
            # items as soon as `valid_results()` drops anything.
            valid = list(query.valid_results())
            result_scores = np.array([item.score for item in valid])
            topk_idxs = result_scores.argsort()[::-1][:self.k]
            outputs.extend(Query.from_item(valid[i]) for i in topk_idxs)

        return outputs

In [None]:
# Two queries with three scored results each; the two best per query should
# be promoted, highest score first.
q1 = Query.from_minimal(embedding=[0.1])
q1.add_query_results([
    Item(id=None, item=label, embedding=[emb], score=score, data=None)
    for label, emb, score in [('1', 0.11, -10), ('2', 0.12, 6), ('3', 0.12, 1)]
])

q2 = Query.from_minimal(embedding=[0.2])
q2.add_query_results([
    Item(id=None, item=label, embedding=[emb], score=score, data=None)
    for label, emb, score in [('4', 0.21, 4), ('5', 0.22, 5), ('6', 0.12, 2)]
])

batch = Batch(queries=[q1, q2])

update_func = TopKDiscreteUpdate(k=2)
update_module = UpdateModule(update_func)
batch2 = update_module(batch)

promoted = [query.item for query in batch2]
assert promoted == ['2', '3', '5', '4']

In [None]:
#| export

class TopKContinuousUpdate(UpdatePlugin):
    """Continuous update: move each query to the mean of its `k` best results.

    For every input query, its valid results are ranked by `score` (highest
    first). The unweighted average of the top `k` result embeddings becomes
    the embedding of one new child query (`Query.from_parent_query`).

    Args:
        k: maximum number of results averaged per query.
    """
    def __init__(self, k: int):
        self.k = k

    def __call__(self, inputs: List[Query]) -> List[Query]:
        outputs = []

        for query in inputs:
            # Rank and select from the *same* filtered list. `query[i]` may
            # index the unfiltered results, which would misalign scores and
            # embeddings as soon as `valid_results()` drops anything.
            valid = list(query.valid_results())
            result_scores = np.array([item.score for item in valid])
            topk_idxs = result_scores.argsort()[::-1][:self.k]
            topk_embs = np.array([valid[i].embedding for i in topk_idxs])

            # NOTE(review): a query with no valid results yields a NaN mean here.
            new_embedding = np.average(topk_embs, 0)
            # Plain Python list, matching how RLUpdate builds its embeddings.
            output = Query.from_parent_query(embedding=new_embedding.tolist(), parent_query=query)
            outputs.append(output)
        return outputs

In [None]:
# Two queries with two scored results each.
q1 = Query.from_minimal(embedding=[0.1])
q1.add_query_results([
    Item(id=None, item=label, embedding=[emb], score=score, data=None)
    for label, emb, score in [('1', 0.1, -10), ('2', 0.2, 6)]
])

q2 = Query.from_minimal(embedding=[0.2])
q2.add_query_results([
    Item(id=None, item=label, embedding=[emb], score=score, data=None)
    for label, emb, score in [('4', 0.2, 4), ('5', 0.3, 5)]
])

batch = Batch(queries=[q1, q2])

# k=2 averages both results; k=1 keeps only the best-scoring result.
for k, expected in [(2, [[0.15], [0.25]]), (1, [[0.2], [0.3]])]:
    update_func = TopKContinuousUpdate(k=k)
    update_module = UpdateModule(update_func)
    batch2 = update_module(batch)
    assert np.allclose([query.embedding for query in batch2], expected)

In [None]:
#| export

class RLUpdate(UpdatePlugin):
    """Gradient-style update that pulls each query toward its high-scoring results.

    For a query with embedding ``q`` and valid results ``r_j`` with scores ``s_j``:

    * ``a = whiten(s)``
    * advantage gradient ``g_a = mean_j(a_j * 2 * (q - r_j))``
    * distance gradient ``g_d = 2 * (q - mean_j(r_j))``
    * ``grad = g_a + distance_penalty * g_d``

    One child query with embedding ``q - lr * grad`` is emitted for every
    learning rate. The output therefore has ``len(queries) * len(lrs)``
    queries, grouped by input query. Each child records
    ``parent_embedding``, ``lr`` and ``grad`` under
    ``data['rl_update_details']``.

    Args:
        lrs: one or more learning rates (a scalar is treated as a single rate).
        distance_penalty: weight of the pull toward the mean result embedding.
    """
    def __init__(self,
                 lrs: Union[float, List[float], np.ndarray],
                 distance_penalty: float
                ):
        # atleast_1d lets a single scalar rate be passed; the broadcasting in
        # __call__ indexes `self.lrs[None, :, None]`, which needs >= 1 dim.
        self.lrs = np.atleast_1d(np.asarray(lrs, dtype=float))
        self.distance_penalty = distance_penalty

    def __call__(self, queries: List[Query]) -> List[Query]:
        # Nothing to update; the array math below cannot handle zero queries.
        if not queries:
            return []

        # NOTE(review): assumes every query has >= 1 valid result and that all
        # embeddings share one dimension d — TODO confirm upstream guarantees.
        query_embeddings = np.array([q.embedding for q in queries])  # (n, d)

        # Fetch each query's valid results once so embeddings and scores
        # are guaranteed to come from the same list.
        valid_results = [list(q.valid_results()) for q in queries]

        result_embeddings = [np.array([r.embedding for r in res])   # each (k_i, d)
                             for res in valid_results]

        advantages = [whiten(np.array([r.score for r in res]))       # each (k_i,)
                      for res in valid_results]

        advantage_grad = np.array(
                        [(advantages[i][:,None] * (2*(query_embeddings[i,None] - result_embeddings[i]))).mean(0)
                        for i in range(len(queries))])

        distance_grad = 2*(query_embeddings - np.array([emb.mean(0) for emb in result_embeddings]))

        grads = advantage_grad + (self.distance_penalty * distance_grad)

        new_embeddings = query_embeddings[:,None] - (grads[:,None,:] * self.lrs[None,:,None])
        # (n,m,-1) = (n,1,-1) - ((n,1,-1) * (1,m,1))

        results = []

        for i, parent in enumerate(queries):        # one group per input query
            for j, lr in enumerate(self.lrs):       # one child per learning rate

                new_query = Query.from_parent_query(embedding=new_embeddings[i][j].tolist(),
                                                    parent_query=parent)
                new_query.data['rl_update_details'] = {
                                                        'parent_embedding' : query_embeddings[i].tolist(),
                                                        'lr' : float(lr),
                                                        'grad' : grads[i].tolist(),
                                                    }

                results.append(new_query)

        return results

In [None]:
# Four learning rates => four child queries per input query.
lrs = np.array([1e-2, 1e-1, 1e0, 1e1])
dp = 0.1

update_function = RLUpdate(lrs, dp)
update_module = UpdateModule(update_function)

queries = []
for i in range(1, 4):
    # One scored result per query, at twice the query's embedding.
    parent = Query.from_minimal(embedding=[i*0.1])
    parent.update_internal(collection_id=i)
    parent.add_query_results([Item.from_minimal(embedding=[i*2*0.1], score=i*1.5)])
    queries.append(parent)

batch = Batch(queries=queries)

batch2 = update_module(batch)

expected_count = len(batch) * len(lrs)
assert len(batch2) == expected_count