# Update

> Update module: wraps an update function and converts its output into the next `Batch` of queries

In [None]:
#| default_exp update

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.core import Module
from emb_opt.schemas import (
                            Item, 
                            Query, 
                            Batch, 
                            UpdateFunction,
                            ContinuousUpdateResponse,
                            UpdateResponse,
                            )

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class UpdateModule(Module):
    """Wrap an `UpdateFunction` and turn its output into the next `Batch`.

    The wrapped function receives the current `Batch` and must return
    something that validates as an `UpdateResponse` (a model instance or an
    equivalent dict). Its `results` field is then converted into a `Batch`.
    """
    def __init__(self, function: UpdateFunction):
        super().__init__(UpdateResponse, function)

    def validate_schema(self, results: UpdateResponse) -> UpdateResponse:
        "Validate the raw function output (model or dict) against the `UpdateResponse` schema."
        results = self.output_schema.model_validate(results)
        return results

    def build_batch(self, results: UpdateResponse) -> Batch:
        """Convert a validated `UpdateResponse` into a `Batch` of queries.

        `results.results` may be:
        - a `Batch`, returned unchanged;
        - a list of `Item`, each converted with `Query.from_item`;
        - a list of `ContinuousUpdateResponse`, each converted with
          `Query.from_parent(embedding, parent_query)`.

        Only the first element is inspected to pick the conversion.

        Raises:
            ValueError: if `results.results` is an empty list.
            TypeError: if the elements are of an unsupported type.
        """
        payload = results.results
        if isinstance(payload, Batch):
            return payload

        # Guard explicitly: indexing an empty list would otherwise surface as
        # an unhelpful IndexError.
        if not payload:
            raise ValueError('UpdateModule: update function returned an empty results list')

        first = payload[0]
        if isinstance(first, Item):
            return Batch(queries=[Query.from_item(i) for i in payload])
        if isinstance(first, ContinuousUpdateResponse):
            return Batch(queries=[Query.from_parent(i.embedding, i.parent_query) for i in payload])

        # Previously fell through to `return batch` with `batch` unbound
        # (UnboundLocalError); report the offending type instead.
        raise TypeError(f'UpdateModule: unsupported update result type {type(first).__name__}')

    def __call__(self, batch: Batch) -> Batch:
        "Run the update function on `batch`, validate its output, and build the next `Batch`."
        results = self.function(batch)
        results = self.validate_schema(results)
        batch = self.build_batch(results)
        return batch

In [None]:
# batch to batch

def passthrough_update_test(batch):
    # Return the incoming Batch unchanged; UpdateModule should accept a Batch
    # as `results` and hand it straight back as the next batch.
    return UpdateResponse(results=batch)

batch = Batch(queries=[
                        Query(embedding=[0.1]),
                        Query(embedding=[0.2]),
                        Query(embedding=[0.3]),
                    ])

update_module = UpdateModule(passthrough_update_test)

batch2 = update_module(batch)

# A passthrough update must preserve the number of queries and their embeddings.
assert len(batch2) == len(batch)
assert all(list(new.embedding) == list(old.embedding) for new, old in zip(batch2, batch))

In [None]:
# continuous update

batch = Batch(queries=[
                        Query(embedding=[0.1]),
                        Query(embedding=[0.2]),
                        Query(embedding=[0.3]),
                    ])

# Tag each query with its position so we can check the index survives the
# update. A plain loop: add_collection_index is called only for its side effect.
for idx, query in enumerate(batch.queries):
    query.add_collection_index(idx)

def continuous_update_test(queries):
    # Produce one ContinuousUpdateResponse per query, doubling its embedding
    # and linking back to the query it came from.
    results = {'results':[ContinuousUpdateResponse(embedding=[j*2 for j in i.embedding], parent_query=i) 
                          for i in queries]}
    return results

update_module = UpdateModule(continuous_update_test)

batch2 = update_module(batch)

# Each new query should inherit the collection index of its parent.
assert len(batch2) == len(batch)
assert all(new.data['_internal']['collection_index'] == old.data['_internal']['collection_index']
           for new, old in zip(batch2, batch))

In [None]:
# discrete update

def _make_discrete_query(idx):
    # A query with collection index `idx` and a single result Item whose
    # embedding is twice the query's embedding.
    query = Query(embedding=[idx*0.1])
    query.add_collection_index(idx)
    result_item = Item(embedding=[idx*2*0.1])
    query.add_query_results([result_item])
    return query

batch = Batch(queries=[_make_discrete_query(idx) for idx in range(3)])

def discrete_update_test(queries):
    # Promote the first result of every query to become the next query.
    first_results = [query.query_results[0] for query in queries]
    return {'results': first_results}

update_module = UpdateModule(discrete_update_test)

batch2 = update_module(batch)

# Each promoted query must keep its Item id, point at the originating query
# as parent, and inherit that query's collection index.
for idx in range(len(batch2)):
    child_internal = batch2[idx].data['_internal']
    source_query = batch[idx]
    source_internal = source_query.data['_internal']

    assert child_internal['id'] == source_query[0].data['_internal']['id']
    assert child_internal['parent'] == source_internal['id']
    assert child_internal['collection_index'] == source_internal['collection_index']