# Update

> Update modules: wrap update functions and turn their results into new batches of queries

In [None]:
#| default_exp update

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.core import Module
from emb_opt.schemas import (
                            Item, 
                            Query, 
                            Batch, 
                            UpdateFunction,
                            DiscreteUpdateFunction,
                            ContinuousUpdateFunction,
                            ContinuousUpdateResponse
                            )

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class UpdateModule(Module):
    """Run an update function over the queries of a `Batch` and return a new `Batch`.

    The wrapped `function` receives the flattened queries of the incoming batch.
    Its outputs are checked with `validate_schema` and packed into a fresh
    `Batch` by `build_batch`. The subclasses in this module override
    `build_batch` and set `output_schema` to accept other response types.
    """
    def __init__(self,
                 function: UpdateFunction,
                ):
        # Query is handed to the Module base; presumably it becomes the
        # expected output schema for `validate_schema` (confirm in emb_opt.core)
        super().__init__(Query, function)

    def gather_inputs(self, batch: Batch) -> Tuple[List[Tuple], List[Query]]:
        """Return the `(idxs, queries)` pair produced by `batch.flatten_queries()`."""
        idxs, inputs = batch.flatten_queries()
        return (idxs, inputs)

    def build_batch(self, results: List[Query]) -> Batch:
        """Pack validated update results (queries) into a new `Batch`."""
        return Batch(queries=results)

    def __call__(self, batch: Batch) -> Batch:
        """Apply the update function to `batch` and return the resulting new batch."""
        # The index half of `gather_inputs` is not needed: the update step
        # builds a brand-new batch instead of writing results back into `batch`.
        _, inputs = self.gather_inputs(batch)
        results = self.function(inputs)
        results = self.validate_schema(results)
        return self.build_batch(results)

In [None]:
def passthrough_update(queries):
    # identity update: the incoming queries are returned unchanged
    return queries

# three single-value queries, built in the same order as listed
batch = Batch(queries=[Query(embedding=[value]) for value in (0.1, 0.2, 0.3)])

update_module = UpdateModule(passthrough_update)

batch = update_module(batch)

In [None]:
#| export

class ContinuousUpdateModule(UpdateModule):
    """Update module whose function returns `ContinuousUpdateResponse` objects.

    Each response carries a new embedding plus the query it was derived from.
    `build_batch` turns every response into a child query via `Query.from_parent`.
    """
    def __init__(self, function: ContinuousUpdateFunction):
        super().__init__(function)
        # replace the schema chosen by the base class so responses validate
        self.output_schema = ContinuousUpdateResponse

    def build_batch(self, results: List[ContinuousUpdateResponse]) -> Batch:
        """Create one child query per response and wrap them in a new `Batch`."""
        children = []
        for response in results:
            child = Query.from_parent(response.embedding, response.parent_query)
            children.append(child)
        return Batch(queries=children)

In [None]:
batch = Batch(queries=[
                        Query(embedding=[0.1]),
                        Query(embedding=[0.2]),
                        Query(embedding=[0.3]),
                    ])

# Tag each query with its position so we can check the index survives the update.
# A plain loop is used because this is done for its side effect only.
for idx, query in enumerate(batch.queries):
    query.add_collection_index(idx)

def continuous_update_test(queries):
    """Double every embedding value, recording the source query as the parent."""
    results = [ContinuousUpdateResponse(embedding=[value*2 for value in query.embedding], parent_query=query)
               for query in queries]
    return results

update_module = ContinuousUpdateModule(continuous_update_test)

batch2 = update_module(batch)

# Expect one new query per input query (otherwise the `all` below passes vacuously),
# each carrying the same collection index as the query it was derived from.
assert len(batch2) == len(batch)
assert all(batch2[i].data['_internal']['collection_index'] == batch[i].data['_internal']['collection_index']
           for i in range(len(batch2)))

In [None]:
#| export

class DiscreteUpdateModule(UpdateModule):
    """Update module whose function picks existing `Item` objects as the next queries.

    `build_batch` converts each selected item into a query with `Query.from_item`.
    """
    def __init__(self, function: DiscreteUpdateFunction):
        super().__init__(function)
        # replace the schema chosen by the base class so items validate
        self.output_schema = Item

    def build_batch(self, results: List[Item]) -> Batch:
        """Convert each selected item into a query and wrap them in a new `Batch`."""
        converted = list(map(Query.from_item, results))
        return Batch(queries=converted)

In [None]:
def build_query(position):
    # Query at `position*0.1`, tagged with its collection index, holding one Item result
    query = Query(embedding=[position*0.1])
    query.add_collection_index(position)
    query.add_query_results([Item(embedding=[position*2*0.1])])
    return query

queries = [build_query(position) for position in range(3)]

batch = Batch(queries=queries)

def discrete_update_test(queries):
    # choose the first result of every query as the next query
    return [query.query_results[0] for query in queries]

update_module = DiscreteUpdateModule(discrete_update_test)

batch2 = update_module(batch)

# Each new query should reuse its source item's id, point back to the source
# query as parent, and inherit the source query's collection index.
for idx in range(len(batch2)):
    internal = batch2[idx].data['_internal']
    query_id, parent_id, collection_id = internal['id'], internal['parent'], internal['collection_index']

    source = batch[idx]
    assert query_id == source[0].data['_internal']['id']
    assert parent_id == source.data['_internal']['id']
    assert collection_id == source.data['_internal']['collection_index']