# Schemas

> Data schemas for items, queries, batches, and the data source / filter / score / prune responses

In [None]:
#| default_exp schemas

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *

  from .autonotebook import tqdm as notebook_tqdm


### Data Objects

In [None]:
#| export

# Type alias: an embedding is represented as a plain list of floats.
Embedding = List[float]

In [None]:
#| export

class Item(BaseModel):
    """A single item with its embedding, an optional score and metadata.

    When constructed from a mapping, missing `item`, `score` and `data`
    entries are filled in by `_fill_data`, and `data['_internal']['id']`
    is set to a freshly generated UUID string unless `data` already
    contains an `_internal` entry.
    """
    # Optional string payload for the item.
    item: Optional[str]
    # Embedding vector of the item.
    embedding: Embedding
    # Optional numeric score.
    score: Optional[float]
    # Free-form metadata; bookkeeping values live under `data['_internal']`.
    data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, inputs: Any) -> Any:
        """Fill default entries on raw mapping input.

        Args:
            inputs: The raw value handed to the model. Anything that is not
                a `dict` (or subclass) is returned untouched.

        Returns:
            `inputs`, updated in place when it is a mapping.
        """
        # isinstance (not an exact type comparison) so that dict subclasses
        # receive the same defaults as plain dicts.
        if not isinstance(inputs, dict):
            return inputs

        if inputs.get('data') is None:
            inputs['data'] = {}

        # Keep an existing `_internal` section (and its id) if one was given.
        if '_internal' not in inputs['data']:
            inputs['data']['_internal'] = {'id': str(uuid.uuid1())}

        # Make absent optional keys explicit.
        inputs.setdefault('score', None)
        inputs.setdefault('item', None)

        return inputs

In [None]:
#| export

class Query(BaseModel):
    """A query together with the result items attached to it.

    When constructed from a mapping, missing `item`, `data` and
    `query_results` entries are filled in by `_fill_data`, and
    `data['_internal']['id']` is set to a freshly generated UUID string
    unless `data` already contains an `_internal` entry. `embedding` is not
    defaulted by the validator.

    Iterating, indexing and `len()` operate on `query_results`.
    """
    # Optional string payload for the query.
    item: Optional[str]
    # Embedding vector of the query (may be None).
    embedding: Optional[Embedding]
    # Free-form metadata; bookkeeping values live under `data['_internal']`.
    data: Optional[dict]
    # Result items attached to this query.
    query_results: Optional[list[Item]]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, inputs: Any) -> Any:
        """Fill default entries on raw mapping input.

        Args:
            inputs: The raw value handed to the model. Anything that is not
                a `dict` (or subclass) is returned untouched, mirroring
                `Item._fill_data`, instead of failing on `inputs.get`.

        Returns:
            `inputs`, updated in place when it is a mapping.
        """
        if not isinstance(inputs, dict):
            return inputs

        if inputs.get('data') is None:
            inputs['data'] = {}

        # Keep an existing `_internal` section (and its id) if one was given.
        if '_internal' not in inputs['data']:
            inputs['data']['_internal'] = {'id': str(uuid.uuid1())}

        if inputs.get('query_results') is None:
            inputs['query_results'] = []

        # Make the absent optional key explicit.
        inputs.setdefault('item', None)

        return inputs

    def __iter__(self):
        """Iterate over the attached result items."""
        return iter(self.query_results)

    def __getitem__(self, idx: int):
        """Return the result item at position `idx`."""
        return self.query_results[idx]

    def __len__(self):
        """Return the number of attached result items."""
        return len(self.query_results)

    def add_query_results(self, query_results: List[Item]):
        """Attach result items to this query.

        Each result's `data['_internal']` is stamped with this query's id
        (as `parent`) and this query's `collection_index` (None when the
        query has none) before the result is appended to `query_results`.

        Args:
            query_results: Items to attach; their `data` is modified in place.
        """
        query_id = self.data['_internal']['id']
        collection_idx = self.data['_internal'].get('collection_index', None)
        for result in query_results:
            result.data['_internal']['parent'] = query_id
            result.data['_internal']['collection_index'] = collection_idx
            self.query_results.append(result)

In [None]:
#| export

class Batch(BaseModel):
    """An ordered collection of queries with index-based access helpers.

    Positions are expressed as `(query_index, result_index)` tuples, where
    `result_index` is None when the position refers to a query itself.
    """
    queries: List[Query]

    def __iter__(self):
        """Iterate over the queries in order."""
        return iter(self.queries)

    def __getitem__(self, idx: int):
        """Return the query at position `idx`."""
        return self.queries[idx]

    def __len__(self):
        """Return the number of queries."""
        return len(self.queries)

    def get_item(self, query_index, result_index=None):
        """Return a query, or one of its results when `result_index` is given."""
        query = self.queries[query_index]
        if result_index is None:
            return query
        return query[result_index]

    def enumerate_queries(self):
        """Yield `((query_index, None), query)` for every query."""
        yield from (
            ((position, None), query)
            for position, query in enumerate(self.queries)
        )

    def enumerate_query_results(self):
        """Yield `((query_index, result_index), result)` for every result."""
        for query_pos, query in enumerate(self.queries):
            yield from (
                ((query_pos, result_pos), result)
                for result_pos, result in enumerate(query)
            )

    def flatten_queries(self):
        """Return parallel lists of query positions and queries."""
        pairs = list(self.enumerate_queries())
        return [position for position, _ in pairs], [query for _, query in pairs]

    def flatten_query_results(self):
        """Return parallel lists of result positions and results."""
        pairs = list(self.enumerate_query_results())
        return [position for position, _ in pairs], [result for _, result in pairs]

    def clean_queries(self):
        """Drop queries flagged with `_internal['remove']` and return them."""
        def flagged(query):
            return query.data['_internal'].get('remove', False)

        dropped = [query for query in self.queries if flagged(query)]
        self.queries = [query for query in self.queries if not flagged(query)]
        return dropped

    def clean_results(self):
        """Drop results flagged with `_internal['remove']` and return them.

        Every query's `query_results` is replaced by a new list holding
        only the results that were kept.
        """
        dropped = []
        for query in self.queries:
            kept = []
            for result in query:
                # Route each result to the removed or the kept list.
                bucket = dropped if result.data['_internal'].get('remove', False) else kept
                bucket.append(result)

            query.query_results = kept

        return dropped

### Data Source

In [None]:
#| export

class DataSourceResponse(BaseModel):
    """Response of a data source for a single query."""
    # Whether the data source considers the query response valid.
    valid: bool
    # Optional extra metadata. Defaults to None so it may be omitted, as is
    # already possible for the filter/score/prune response models.
    data: Optional[Dict] = None
    # Items returned for the query.
    query_results: List[Item]

In [None]:
#| export

# A data source takes a list of queries and returns one response per query.
# `Callable` expects its argument types wrapped in a list.
DataSourceFunction = Callable[[List[Query]], List[DataSourceResponse]]

In [None]:
# Build four example queries, each with a one-element embedding.
q1 = Query(item='q1', embedding=[0.1])
q2 = Query(item='q2', embedding=[0.2])
q3 = Query(item='q3', embedding=[0.3])
q4 = Query(item='q4', embedding=[0.4])

In [None]:
# NOTE: this should be done by the init or update step.
# q1/q2 are assigned to collection 0, q3/q4 to collection 1.
for _query, _collection in ((q1, 0), (q2, 0), (q3, 1), (q4, 1)):
    _query.data['_internal']['collection_index'] = _collection

In [None]:
# Group the example queries into a single batch.
batch = Batch(queries=[q1, q2, q3, q4])

## Data Source

In [None]:
class DataSourceResponse(BaseModel):
    """Response of a data source for a single query."""
    # Whether the data source considers the query response valid.
    valid: bool
    # Optional extra metadata. Defaults to None so it may be omitted, as is
    # already possible for the filter/score/prune response models.
    data: Optional[Dict] = None
    # Items returned for the query.
    query_results: List[Item]

In [None]:
# Single-query and batched data source signatures.
# `Callable` expects its argument types wrapped in a list.
DataSourceFunction = Callable[[Query], DataSourceResponse]
DataSourceFunctionBatched = Callable[[List[Query]], List[DataSourceResponse]]

In [None]:
# Parallel lists: (query_index, None) positions and the queries themselves.
data_idxs, data_inputs = batch.flatten_queries()

In [None]:
# Five example result items, each with a one-element embedding.
r1 = Item(item='1', embedding=[0.11])
r2 = Item(item='2', embedding=[0.22])
r3 = Item(item='3', embedding=[0.33])
r4 = Item(item='4', embedding=[0.44])
r5 = Item(item='5', embedding=[0.55])

# Raw results per query (in batch order) and the validity flag of each response.
data_results = [[r1, r2], [r3,], [r4, r5], []]
data_valid = [True, False, True, True]

# Pair every validity flag with its results; `data_results` is rebound from
# the raw lists to the response objects.
data_results = [
    DataSourceResponse(valid=valid, data={'data_test': True}, query_results=results)
    for valid, results in zip(data_valid, data_results)
]

In [None]:
def scatter_data_response(batch, idxs, results):
    """Write data source responses back onto the batch.

    For each `(query_index, result_index)` position and its response, the
    response's `data` (when non-empty) is merged into the target's `data`.
    A valid response with results has them attached via
    `add_query_results`; otherwise the target is flagged for removal under
    `data['_internal']` together with the reason.

    Args:
        batch: Object exposing `get_item(query_index, result_index)`.
        idxs: Positions, paired with `results` via `zip`.
        results: Responses with `valid`, `data` and `query_results`.
    """
    for (query_pos, result_pos), response in zip(idxs, results):
        target = batch.get_item(query_pos, result_pos)
        if response.data:
            target.data.update(response.data)

        # Decide whether the target is kept, and if not, why.
        if not response.valid:
            reason = 'query response invalid'
        elif not response.query_results:
            reason = 'query returned no results'
        else:
            target.add_query_results(response.query_results)
            continue

        internal = target.data['_internal']
        internal['remove'] = True
        internal['remove_details'] = reason

In [None]:
# Apply the example data source responses to the batch.
scatter_data_response(batch, data_idxs, data_results)

In [None]:
# Drop queries flagged for removal; the dropped queries are returned.
removed_queries = batch.clean_queries()

In [None]:
# Show the queries that were removed.
removed_queries

[Query(item='q2', embedding=[0.2], data={'_internal': {'id': '069ef497-4c20-11ee-b64b-7b1d5a84b1d4', 'collection_index': 0, 'remove': True, 'remove_details': 'query response invalid'}, 'data_test': True}, query_results=[]),
 Query(item='q4', embedding=[0.4], data={'_internal': {'id': '069ef499-4c20-11ee-b64b-7b1d5a84b1d4', 'collection_index': 1, 'remove': True, 'remove_details': 'query returned no results'}, 'data_test': True}, query_results=[])]

In [None]:
# `item` values of the queries still in the batch.
[i.item for i in batch]

['q1', 'q3']

## Filter

In [None]:
class FilterResponse(BaseModel):
    """Response of a filter for a single item."""
    # Whether the item passed the filter.
    valid: bool
    # Optional extra metadata; filled with None when not supplied.
    data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Default the `data` entry to None on raw mapping input.

        Args:
            data: The raw value handed to the model. Anything that is not a
                `dict` (or subclass) is returned untouched rather than
                being indexed.

        Returns:
            `data`, updated in place when it is a mapping.
        """
        if isinstance(data, dict):
            data.setdefault("data", None)
        return data

In [None]:
# Single-item and batched filter signatures.
# `Callable` expects its argument types wrapped in a list.
FilterFunction = Callable[[Item], FilterResponse]
FilterFunctionBatched = Callable[[List[Item]], List[FilterResponse]]

In [None]:
# Parallel lists: (query_index, result_index) positions and the result items.
filter_idxs, filter_inputs = batch.flatten_query_results()

In [None]:
# Example filter outputs as plain dicts; `data` is omitted in most of them.
filter_results = [{'valid':True}, 
                  {'valid':False}, 
                  {'valid':True, 'data':{'blah':5}},
                  {'valid':True},
                 ]
# Convert the dicts into FilterResponse objects.
filter_results = [FilterResponse(**i) for i in filter_results]

In [None]:
def scatter_filter_response(batch, idxs, results):
    """Write filter responses back onto the batch.

    For each `(query_index, result_index)` position and its response, the
    response's `data` (when non-empty) is merged into the target's `data`;
    an invalid response flags the target for removal under
    `data['_internal']`.

    Args:
        batch: Object exposing `get_item(query_index, result_index)`.
        idxs: Positions, paired with `results` via `zip`.
        results: Responses with `valid` and `data`.
    """
    for (query_pos, result_pos), response in zip(idxs, results):
        target = batch.get_item(query_pos, result_pos)
        if response.data:
            target.data.update(response.data)

        if response.valid:
            continue

        internal = target.data['_internal']
        internal['remove'] = True
        internal['remove_details'] = 'filter response invalid'

In [None]:
# Apply the example filter responses to the batch's result items.
scatter_filter_response(batch, filter_idxs, filter_results)

In [None]:
# Drop result items flagged for removal; the dropped items are returned.
removed_results = batch.clean_results()

In [None]:
# Show the result items that were removed.
removed_results

[Item(item='2', embedding=[0.22], score=None, data={'_internal': {'id': '092a869f-4c20-11ee-b64b-7b1d5a84b1d4', 'parent': '069ef496-4c20-11ee-b64b-7b1d5a84b1d4', 'collection_index': 0, 'remove': True, 'remove_details': 'filter response invalid'}})]

## Score

In [None]:
class ScoreResponse(BaseModel):
    """Response of a score function for a single item."""
    # Whether the scoring result is valid.
    valid: bool
    # Score assigned to the item.
    score: float
    # Optional extra metadata; filled with None when not supplied.
    data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Default the `data` entry to None on raw mapping input.

        Args:
            data: The raw value handed to the model. Anything that is not a
                `dict` (or subclass) is returned untouched rather than
                being indexed.

        Returns:
            `data`, updated in place when it is a mapping.
        """
        if isinstance(data, dict):
            data.setdefault("data", None)
        return data

In [None]:
# Single-item and batched score function signatures.
# `Callable` expects its argument types wrapped in a list.
ScoreFunction = Callable[[Item], ScoreResponse]
ScoreFunctionBatched = Callable[[List[Item]], List[ScoreResponse]]

In [None]:
# Parallel lists: (query_index, result_index) positions and the result items.
score_idxs, score_inputs = batch.flatten_query_results()

In [None]:
# Example score outputs as plain dicts; `data` is omitted in most of them.
score_results = [{'score':1., 'valid':True}, 
                  {'score':2., 'valid':True, 'data':{'check':'suspect'}}, 
                  {'score':3., 'valid':True},
                 ]
# Convert the dicts into ScoreResponse objects.
score_results = [ScoreResponse(**i) for i in score_results]

In [None]:
def scatter_score_response(batch, idxs, results):
    """Write score responses back onto the batch.

    For each `(query_index, result_index)` position and its response, the
    target's `score` is set from the response (regardless of validity),
    the response's `data` (when non-empty) is merged into the target's
    `data`, and an invalid response flags the target for removal under
    `data['_internal']`.

    Args:
        batch: Object exposing `get_item(query_index, result_index)`.
        idxs: Positions, paired with `results` via `zip`.
        results: Responses with `valid`, `score` and `data`.
    """
    for (query_pos, result_pos), response in zip(idxs, results):
        target = batch.get_item(query_pos, result_pos)
        target.score = response.score

        if response.data:
            target.data.update(response.data)

        if response.valid:
            continue

        internal = target.data['_internal']
        internal['remove'] = True
        internal['remove_details'] = 'score response invalid'

In [None]:
# Apply the example score responses to the batch's result items.
scatter_score_response(batch, score_idxs, score_results)

## Prune

In [None]:
class PruneResponse(BaseModel):
    """Response of a prune function for a single input."""
    # False when the input should be pruned.
    valid: bool
    # Optional extra metadata; filled with None when not supplied.
    data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Default the `data` entry to None on raw mapping input.

        Args:
            data: The raw value handed to the model. Anything that is not a
                `dict` (or subclass) is returned untouched rather than
                being indexed.

        Returns:
            `data`, updated in place when it is a mapping.
        """
        if isinstance(data, dict):
            data.setdefault("data", None)
        return data

In [None]:
# Single-item and batched prune function signatures.
# `Callable` expects its argument types wrapped in a list.
PruneFunction = Callable[[Item], PruneResponse]
PruneFunctionBatched = Callable[[List[Item]], List[PruneResponse]]

In [None]:
# Example prune outputs as plain dicts, then converted to PruneResponse objects.
prune_results = [{'valid':True}, {'valid':False}, {'valid':True}]
prune_results = [PruneResponse(**i) for i in prune_results]

In [None]:
# Parallel lists: (query_index, None) positions and the queries themselves.
prune_idxs, prune_inputs = batch.flatten_queries()

In [None]:
def scatter_prune_response(batch, idxs, results):
    """Write prune responses back onto the batch.

    For each `(query_index, result_index)` position and its response, the
    response's `data` (when non-empty) is merged into the target's `data`;
    an invalid response flags the target for removal under
    `data['_internal']` with the detail 'pruned'.

    Args:
        batch: Object exposing `get_item(query_index, result_index)`.
        idxs: Positions, paired with `results` via `zip`.
        results: Responses with `valid` and `data`.
    """
    for (query_pos, result_pos), response in zip(idxs, results):
        target = batch.get_item(query_pos, result_pos)
        if response.data:
            target.data.update(response.data)

        if response.valid:
            continue

        internal = target.data['_internal']
        internal['remove'] = True
        internal['remove_details'] = 'pruned'

In [None]:
# Apply the example prune responses to the batch's queries.
scatter_prune_response(batch, prune_idxs, prune_results)

In [None]:
# Drop queries flagged for removal; the dropped queries are returned.
removed_queries = batch.clean_queries()

In [None]:
# Inspect the metadata of the first removed query.
removed_queries[0].data

{'_internal': {'id': '069ef498-4c20-11ee-b64b-7b1d5a84b1d4',
  'collection_index': 1,
  'remove': True,
  'remove_details': 'pruned'},
 'data_test': True}