# Update

> Update functions and classes

In [None]:
#| default_exp update

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.utils import whiten
from emb_opt.core import Module
from emb_opt.schemas import (
                            Item, 
                            Query, 
                            Batch, 
                            ContinuousUpdateResponse,
                            UpdateFunction,
                            UpdateResponse,
                            UpdateResponseValidator,
                            )

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class UpdateModule(Module):
    """Module that runs an update function and normalises its output into a ``Batch``.

    The wrapped ``function`` is called with a ``Batch``. Its return value is
    validated against ``UpdateResponseValidator`` (strict mode) and then
    converted to a ``Batch`` by ``build_batch``. Supported return shapes:

    * a ``Batch`` -- returned unchanged;
    * a list of ``Query`` -- wrapped directly in a ``Batch``;
    * a list of ``Item`` -- each item becomes a query via ``Query.from_item``;
    * a list of ``ContinuousUpdateResponse`` -- each response becomes a query
      via ``Query.from_parent(embedding, parent_query)``.
    """

    def __init__(self, function: UpdateFunction):
        super().__init__(UpdateResponseValidator, function)

    def validate_schema(self, results: UpdateResponse) -> UpdateResponse:
        """Validate ``results`` against the output schema and return the validated payload."""
        validated = self.output_schema.model_validate({'results': results}, strict=True)
        return validated.results

    def build_batch(self, results: UpdateResponse) -> Batch:
        """Convert validated update results into a ``Batch`` of queries.

        An empty result list yields an empty ``Batch``. For a list, the type of
        the first element decides how the whole list is converted (presumably
        the validator guarantees a homogeneous list -- TODO confirm).

        Raises:
            TypeError: if the list elements are none of the supported types.
        """
        if isinstance(results, Batch):
            return results

        # Guard before indexing: ``results[0]`` on an empty list raises IndexError.
        if not results:
            return Batch(queries=[])

        first = results[0]
        # Order matters if one of these types subclasses another; Query is checked first.
        if isinstance(first, Query):
            return Batch(queries=results)
        if isinstance(first, Item):
            return Batch(queries=[Query.from_item(item) for item in results])
        if isinstance(first, ContinuousUpdateResponse):
            return Batch(queries=[Query.from_parent(resp.embedding, resp.parent_query)
                                  for resp in results])

        # Fail loudly instead of falling through with no batch built.
        raise TypeError(f"Unsupported update result type: {type(first).__name__}")

    def __call__(self, batch: Batch) -> Batch:
        """Run the update function on ``batch``, validate its output, and return a new ``Batch``."""
        results = self.function(batch)
        results = self.validate_schema(results)
        return self.build_batch(results)

In [None]:
# batch to batch: an update function that hands back the incoming Batch unchanged

def passthrough_update_test(batch):
    """Identity update: return ``batch`` as-is."""
    return batch

batch = Batch(queries=[Query(embedding=[value]) for value in (0.1, 0.2, 0.3)])

update_module = UpdateModule(passthrough_update_test)

batch = update_module(batch)

assert isinstance(batch, Batch)
assert isinstance(batch[0], Query)

In [None]:
# continuous update: every query maps to a new (doubled) embedding, and the
# resulting query must keep its parent's collection index

batch = Batch(queries=[
                        Query(embedding=[0.1]),
                        Query(embedding=[0.2]),
                        Query(embedding=[0.3]),
                    ])

# Plain loop for a side effect; a list comprehension would build and discard a list.
for idx, query in enumerate(batch.queries):
    query.add_collection_index(idx)

def continuous_update_test(queries):
    """Return one ContinuousUpdateResponse per query, with the query's embedding doubled."""
    return [ContinuousUpdateResponse(embedding=[value*2 for value in query.embedding], parent_query=query)
            for query in queries]

update_module = UpdateModule(continuous_update_test)

batch2 = update_module(batch)

assert isinstance(batch2, Batch)
assert isinstance(batch2[0], Query)

# One output query per input query, each inheriting its parent's collection index.
assert len(batch2) == len(batch)
assert all(new_query.data['_internal']['collection_index'] == old_query.data['_internal']['collection_index']
           for new_query, old_query in zip(batch2, batch))

In [None]:
# discrete update: each query's single result item is promoted to a new query

queries = []
for idx in range(3):
    source_query = Query(embedding=[idx*0.1])
    source_query.add_collection_index(idx)
    source_query.add_query_results([Item(embedding=[idx*2*0.1])])
    queries.append(source_query)

batch = Batch(queries=queries)

def discrete_item_update_test(batch):
    """Return the first query result of every query in ``batch``."""
    return [query.query_results[0] for query in batch]

update_module = UpdateModule(discrete_item_update_test)

batch2 = update_module(batch)

assert isinstance(batch2, Batch)
assert isinstance(batch2[0], Query)

# A promoted query keeps the id of the item it came from, records the source
# query as its parent, and inherits the source query's collection index.
for idx in range(len(batch2)):
    new_internal = batch2[idx].data['_internal']
    new_id = new_internal['id']
    new_parent = new_internal['parent']
    new_collection = new_internal['collection_index']

    source = batch[idx]
    assert new_id == source[0].data['_internal']['id']
    assert new_parent == source.data['_internal']['id']
    assert new_collection == source.data['_internal']['collection_index']


def discrete_query_update_test(batch):
    """Return the batch's list of queries directly."""
    return batch.queries

update_module = UpdateModule(discrete_query_update_test)

batch2 = update_module(batch)

assert isinstance(batch2, Batch)
assert isinstance(batch2[0], Query)

In [None]:
#| export

class TopKDiscreteUpdate():
    """Discrete update that keeps the ``k`` highest-scoring results of every query.

    Returns a flat list of result items: the top-``k`` results of each query,
    best score first, concatenated in batch order.

    Args:
        k: number of results to keep per query.
    """
    def __init__(self, k):
        self.k = k

    def __call__(self, batch: Batch):
        """Return the top-``k`` results of every query in ``batch`` as one list."""
        outputs = []

        for query in batch:
            scores = np.array([item.score for item in query], dtype=float)
            # Rank descending by negating the scores instead of reversing an ascending
            # argsort. argsort puts NaN last, so reversing would rank NaN scores first;
            # negation keeps them last. The stable sort also breaks ties in favour of
            # the earlier result, deterministically.
            topk_idxs = np.argsort(-scores, kind='stable')[:self.k]
            outputs.extend(query[idx] for idx in topk_idxs)
        return outputs

In [None]:
# top-k discrete update: the k best-scoring items of each query become new queries

q1 = Query(embedding=[0.1])
q1.add_query_results([Item(item=name, embedding=[emb], score=score)
                      for name, emb, score in (('1', 0.11, -10), ('2', 0.12, 6), ('3', 0.12, 1))])

q2 = Query(embedding=[0.2])
q2.add_query_results([Item(item=name, embedding=[emb], score=score)
                      for name, emb, score in (('4', 0.21, 4), ('5', 0.22, 5), ('6', 0.12, 2))])

batch = Batch(queries=[q1, q2])

update_func = TopKDiscreteUpdate(k=2)
update_module = UpdateModule(update_func)
batch2 = update_module(batch)

# best score first within each query; queries stay in batch order
assert [new_query.item for new_query in batch2] == ['2', '3', '5', '4']

In [None]:
#| export

class TopKContinuousUpdate():
    """Continuous update that averages the embeddings of each query's top-``k`` results.

    For every query in the batch, the ``k`` highest-scoring results are selected,
    and their embeddings are averaged (unweighted) into a new embedding. That
    embedding is returned as a ``ContinuousUpdateResponse`` whose parent is the query.

    Args:
        k: number of top-scoring results to average per query.
    """
    def __init__(self,
                 k: int,
                ):
        self.k = k

    def __call__(self, batch: Batch):
        """Return one ``ContinuousUpdateResponse`` per query in ``batch``.

        Raises:
            ValueError: if a query has no results to average.
        """
        outputs = []

        for query in batch:
            scores = np.array([item.score for item in query], dtype=float)
            if scores.size == 0:
                raise ValueError("TopKContinuousUpdate: query has no results to average")
            # Rank descending by negation rather than by reversing an ascending argsort,
            # so NaN scores (sorted last by argsort) are never ranked ahead of real
            # scores. The stable sort breaks ties in favour of earlier results.
            topk_idxs = np.argsort(-scores, kind='stable')[:self.k]
            topk_embs = np.array([query[idx].embedding for idx in topk_idxs])

            new_embedding = topk_embs.mean(axis=0)

            # Hand the schema a plain list of floats, as RLUpdate does with ``.tolist()``.
            outputs.append(ContinuousUpdateResponse(embedding=new_embedding.tolist(),
                                                    parent_query=query))
        return outputs

In [None]:
# top-k continuous update: the new embedding is the mean of the k best-scoring result embeddings

q1 = Query(embedding=[0.1])
q1.add_query_results([Item(item='1', embedding=[0.1], score=-10),
                      Item(item='2', embedding=[0.2], score=6)])

q2 = Query(embedding=[0.2])
q2.add_query_results([Item(item='4', embedding=[0.2], score=4),
                      Item(item='5', embedding=[0.3], score=5)])

batch = Batch(queries=[q1, q2])

# k=2 averages both results; k=1 just copies the best one
for k, expected in ((2, [[0.15], [0.25]]), (1, [[0.2], [0.3]])):
    update_func = TopKContinuousUpdate(k=k)
    update_module = UpdateModule(update_func)
    batch2 = update_module(batch)
    assert np.allclose([new_query.embedding for new_query in batch2], expected)

In [None]:
#| export

class RLUpdate():
    """Continuous update that takes a gradient step on each query embedding.

    For each query, the scores of its results are passed through ``whiten`` to
    form advantages. The step direction for a query embedding ``q`` with result
    embeddings ``r_j`` combines:

    * ``advantage_grad``: the gradient w.r.t. ``q`` of ``mean_j adv_j * ||q - r_j||^2``;
    * ``distance_grad``: the gradient of ``||q - mean_j r_j||^2``, weighted by
      ``distance_penalty``.

    One updated embedding ``q - lr * grad`` is emitted per learning rate in ``lrs``.

    Args:
        lrs: step sizes. A list, tuple, numpy array or a single scalar is accepted;
            each query yields one new embedding per entry.
        distance_penalty: weight on the distance-to-results gradient term.
    """
    def __init__(self,
                 lrs: List[float],
                 distance_penalty: float
                ):
        self.lrs = lrs
        self.distance_penalty = distance_penalty

    def __call__(self, batch: Batch) -> List[ContinuousUpdateResponse]:
        """Return ``len(batch) * len(lrs)`` responses, ordered by query and then by learning rate.

        NOTE(review): assumes every query has at least one result and that all
        embeddings share one dimensionality -- confirm with callers.
        """
        # ``lrs`` is annotated List[float], but the broadcast below uses NumPy
        # indexing (``[None, :, None]``), which raises TypeError on a plain list.
        lrs = np.atleast_1d(np.asarray(self.lrs, dtype=float))

        query_embeddings = np.array([query.embedding for query in batch])
        result_embeddings = [np.array([item.embedding for item in query]) for query in batch]
        advantages = [whiten(np.array([item.score for item in query])) for query in batch]

        # Per query: mean over results of adv_j * d/dq ||q - r_j||^2 = adv_j * 2 (q - r_j)
        advantage_grad = np.array(
                        [(advantages[i][:,None] * (2*(query_embeddings[i,None] - result_embeddings[i]))).mean(0)
                        for i in range(len(batch))])

        # d/dq ||q - mean(r)||^2
        distance_grad = 2*(query_embeddings - np.array([embs.mean(0) for embs in result_embeddings]))

        grads = advantage_grad + (self.distance_penalty * distance_grad)

        # Broadcast one step per learning rate:
        # (n_queries, 1, dim) - (n_queries, 1, dim) * (1, n_lrs, 1) -> (n_queries, n_lrs, dim)
        new_embeddings = query_embeddings[:,None] - (grads[:,None,:] * lrs[None,:,None])

        return [
            ContinuousUpdateResponse(embedding=embedding.tolist(), parent_query=batch[i])
            for i in range(new_embeddings.shape[0])
            for embedding in new_embeddings[i]
        ]