# Schemas

> Data Schemas

In [None]:
#| default_exp schemas

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class Item(BaseModel):
    """A single item with its embedding, an optional score and optional metadata.

    None of the fields declares a default. After validation ``data`` is always
    a dict, so callers can write keys into it without checking for ``None``.
    """
    item: Optional[str]
    embedding: list[float]
    score: Optional[float]
    data: Optional[dict]

    @field_validator('data')
    @classmethod
    def set_data(cls, data):
        """Return ``data`` unchanged, or a new empty dict when it is ``None`` or empty."""
        return data or {}

In [None]:
class Query(BaseModel):
    """A query embedding together with the results collected for it.

    Iteration, indexing and item assignment are forwarded to ``query_results``.
    """
    collection_index: int
    item: Optional[str]
    embedding: list[float]
    data: Optional[dict]
    query_results: Optional[list[Item]]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, inputs: Any) -> Any:
        """Default ``data`` to ``{}`` and ``query_results`` to ``[]``.

        The defaults are applied when the key is missing or its value is ``None``.
        """
        from collections.abc import Mapping

        # A 'before' validator receives the raw input, which is not guaranteed
        # to be a mapping; anything else is handed to pydantic untouched.
        if not isinstance(inputs, Mapping):
            return inputs

        # Fill the defaults on a shallow copy so the caller's mapping is never modified.
        inputs = dict(inputs)
        if inputs.get('data') is None:
            inputs['data'] = {}
        if inputs.get('query_results') is None:
            inputs['query_results'] = []
        return inputs

    def __iter__(self):
        """Iterate over the query results in order."""
        return iter(self.query_results)

    def __getitem__(self, idx: int):
        """Return the query result at position ``idx``."""
        return self.query_results[idx]

    def __setitem__(self, idx: int, value):
        """Replace the query result at position ``idx``."""
        self.query_results[idx] = value

In [None]:
class Batch(BaseModel):
    """An ordered collection of queries handled together.

    Iteration, indexing and item assignment act on ``queries``.
    """
    queries: List[Query]

    def __iter__(self):
        """Iterate over the queries in order."""
        return iter(self.queries)

    def __getitem__(self, idx: int):
        """Return the query at position ``idx``."""
        return self.queries[idx]

    def __setitem__(self, idx: int, value):
        """Replace the query at position ``idx``."""
        self.queries[idx] = value

    def get_result(self, query_index, result_index):
        """Return result ``result_index`` of the query at ``query_index``."""
        query = self.queries[query_index]
        return query[result_index]

    def enumerate_queries(self):
        """Yield ``(query_index, query)`` pairs."""
        yield from enumerate(self.queries)

    def enumerate_query_results(self):
        """Yield ``((query_index, result_index), result)`` for every result of every query."""
        for q_idx, query in enumerate(self.queries):
            for r_idx, result in enumerate(query):
                yield (q_idx, r_idx), result

    def cleanup(self):
        """Drop ``None`` queries, then drop ``None`` results from the remaining queries."""
        self.queries = [query for query in self.queries if query is not None]
        for query in self.queries:
            kept = [result for result in query.query_results if result is not None]
            query.query_results = kept

In [None]:
# Three example queries (the first two share collection index 0) gathered into one batch.
q1, q2, q3 = (
    Query(collection_index=collection, item=name, embedding=[value])
    for collection, name, value in ((0, 'q1', 0.1), (0, 'q2', 0.2), (1, 'q3', 0.3))
)

batch = Batch(queries=[q1, q2, q3])

## Data Source

In [None]:
class DataSourceInput(BaseModel):
    """Request payload built from one query by `gather_data_request`.

    Each field may be `None`; `gather_data_request` sets a field to `None`
    when the corresponding `include_*` flag is false.
    """
    item: Optional[str]
    embedding: Optional[list[float]]
    data: Optional[dict]

In [None]:
def gather_data_request(batch, include_item=True, include_embedding=True, include_data=True):
    """Build one `DataSourceInput` per query in `batch`, in query order.

    Each `include_*` flag decides whether the matching query attribute is
    copied into the request or replaced by `None`.
    """
    return [
        DataSourceInput(
            item=query.item if include_item else None,
            embedding=query.embedding if include_embedding else None,
            data=query.data if include_data else None,
        )
        for _, query in batch.enumerate_queries()
    ]

In [None]:
# Build the data-source request payloads, one per query in the example batch.
data_inputs = gather_data_request(batch)

In [None]:
# Five example items; score and data start out unset.
r1, r2, r3, r4, r5 = (
    Item(item=label, embedding=[value], score=None, data=None)
    for label, value in (('1', 0.11), ('2', 0.22), ('3', 0.33), ('4', 0.44), ('5', 0.55))
)

# One list of results per query of the example batch, in query order.
data_results = [[r1, r2], [r3], [r4, r5]]

In [None]:
def scatter_data_response(batch, data_results):
    """Add each list in `data_results` to the query at the same position.

    `data_results[i]` is appended to `batch[i].query_results`; nothing is returned.
    """
    for position, new_items in enumerate(data_results):
        query = batch[position]
        query.query_results += new_items

In [None]:
# Attach the example results to their queries (updates `batch` in place).
scatter_data_response(batch, data_results)

## Filter 

In [None]:
class FilterRequestItem(BaseModel):
    """Request payload built from one query result by `gather_filter_request`.

    Each field may be `None`; `gather_filter_request` sets a field to `None`
    when the corresponding `include_*` flag is false.
    """
    item: Optional[str]
    embedding: Optional[list[float]]
    data: Optional[dict]

In [None]:
def gather_filter_request(batch, include_item=True, include_embedding=True, include_data=True):
    """Flatten every query result in `batch` into a `FilterRequestItem`.

    Returns `(idxs, outputs)`: `idxs[k]` is the index yielded by
    `batch.enumerate_query_results()` for the result behind `outputs[k]`.
    Each `include_*` flag decides whether the matching attribute is copied
    into the request or replaced by `None`.
    """
    entries = list(batch.enumerate_query_results())
    idxs = [position for position, _ in entries]
    outputs = [
        FilterRequestItem(
            item=result.item if include_item else None,
            embedding=result.embedding if include_embedding else None,
            data=result.data if include_data else None,
        )
        for _, result in entries
    ]
    return idxs, outputs

In [None]:
# Flatten the batch into filter requests, keeping the index of each query result.
filter_idxs, filter_inputs = gather_filter_request(batch)

In [None]:
class FilterResult(BaseModel):
    """Outcome of filtering one query result.

    ``result`` is the boolean verdict. ``filter_data`` is optional extra
    information and may be left out of the input entirely.
    """
    result: bool
    filter_data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Insert ``filter_data=None`` when the key is absent from the input."""
        from collections.abc import Mapping

        # A 'before' validator receives the raw input, which is not guaranteed
        # to be a mapping; anything else is handed to pydantic untouched.
        if not isinstance(data, Mapping):
            return data

        # Work on a shallow copy so the caller's mapping is never modified.
        data = dict(data)
        data.setdefault("filter_data", None)
        return data

In [None]:
# Example filter verdicts, one per query result in the order of `filter_idxs`.
filter_results = [
    FilterResult(**payload)
    for payload in (
        {'result':True},
        {'result':False},
        {'result':True, 'filter_data':{'blah':5}},
        {'result':True},
        {'result':False},
    )
]

In [None]:
# Show every query result currently held by the batch (before filtering).
[i[1] for i in batch.enumerate_query_results()]

[Item(item='1', embedding=[0.11], score=None, data={}),
 Item(item='2', embedding=[0.22], score=None, data={}),
 Item(item='3', embedding=[0.33], score=None, data={}),
 Item(item='4', embedding=[0.44], score=None, data={}),
 Item(item='5', embedding=[0.55], score=None, data={})]

In [None]:
def scatter_filter_response(batch, idxs, filter_results):
    """Apply filter verdicts to the query results of `batch`, in place.

    `idxs[k]` is the `(query_index, result_index)` pair that `filter_results[k]`
    refers to. A result whose verdict is false is replaced by `None` and then
    removed by `batch.cleanup()`. A kept result receives the verdict's
    `filter_data` under `data['filter_data']` when that value is non-empty.
    """
    for (q_idx, r_idx), outcome in zip(idxs, filter_results):
        if not outcome.result:
            # Mark the rejected result; cleanup() below drops the marker.
            batch[q_idx][r_idx] = None
            continue

        if outcome.filter_data:
            batch[q_idx][r_idx].data['filter_data'] = outcome.filter_data

    batch.cleanup()

In [None]:
# Apply the example verdicts; results with a false verdict are removed from `batch`.
scatter_filter_response(batch, filter_idxs, filter_results)

In [None]:
# Show the query results that remain after filtering.
[i[1] for i in batch.enumerate_query_results()]

[Item(item='1', embedding=[0.11], score=None, data={}),
 Item(item='3', embedding=[0.33], score=None, data={'filter_data': {'blah': 5}}),
 Item(item='4', embedding=[0.44], score=None, data={})]

In [None]:
class ScoreRequestItem(BaseModel):
    """Request payload built from one query result by `gather_score_request`.

    Each field may be `None`; `gather_score_request` sets a field to `None`
    when the corresponding `include_*` flag is false.
    """
    item: Optional[str]
    embedding: Optional[list[float]]
    data: Optional[dict]

In [None]:
def gather_score_request(batch, include_item=True, include_embedding=True, include_data=True):
    """Flatten every query result in `batch` into a `ScoreRequestItem`.

    Returns `(idxs, outputs)`: `idxs[k]` is the index yielded by
    `batch.enumerate_query_results()` for the result behind `outputs[k]`.
    Each `include_*` flag decides whether the matching attribute is copied
    into the request or replaced by `None`.
    """
    pairs = [
        (
            position,
            ScoreRequestItem(
                item=result.item if include_item else None,
                embedding=result.embedding if include_embedding else None,
                data=result.data if include_data else None,
            ),
        )
        for position, result in batch.enumerate_query_results()
    ]
    idxs = [position for position, _ in pairs]
    outputs = [request for _, request in pairs]
    return idxs, outputs

In [None]:
# Flatten the remaining query results into score requests, keeping each result's index.
score_idxs, score_inputs = gather_score_request(batch)

In [None]:
class ScoreResult(BaseModel):
    """Outcome of scoring one query result.

    ``result`` is the numeric score. ``score_data`` is optional extra
    information and may be left out of the input entirely.
    """
    result: float
    score_data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Insert ``score_data=None`` when the key is absent from the input."""
        from collections.abc import Mapping

        # A 'before' validator receives the raw input, which is not guaranteed
        # to be a mapping; anything else is handed to pydantic untouched.
        if not isinstance(data, Mapping):
            return data

        # Work on a shallow copy so the caller's mapping is never modified.
        data = dict(data)
        data.setdefault("score_data", None)
        return data

In [None]:
# Example scores, one per query result in the order of `score_idxs`.
score_results = [
    ScoreResult(**payload)
    for payload in (
        {'result':1.},
        {'result':2., 'score_data':{'check':'suspect'}},
        {'result':3.},
    )
]

In [None]:
def scatter_score_response(batch, idxs, score_results):
    """Write scores onto the query results of `batch`, in place.

    `idxs[k]` is the `(query_index, result_index)` pair that `score_results[k]`
    refers to. The result's `score` is set to the score value, and a non-empty
    `score_data` is stored under `data['score_data']`.
    """
    for (q_idx, r_idx), outcome in zip(idxs, score_results):
        target = batch[q_idx][r_idx]
        target.score = outcome.result

        extra = outcome.score_data
        if extra:
            target.data['score_data'] = extra

In [None]:
# Write the example scores onto the remaining query results (updates `batch` in place).
scatter_score_response(batch, score_idxs, score_results)

In [None]:
# Show the query results with their scores filled in.
[i[1] for i in batch.enumerate_query_results()]

[Item(item='1', embedding=[0.11], score=1.0, data={}),
 Item(item='3', embedding=[0.33], score=2.0, data={'filter_data': {'blah': 5}, 'score_data': {'check': 'suspect'}}),
 Item(item='4', embedding=[0.44], score=3.0, data={})]

In [None]:
# Prune request: the batch itself serves as the request (no separate gather step).

In [None]:
class PruneResult(BaseModel):
    """Outcome of the prune step for one query.

    ``result`` is the boolean verdict. ``prune_data`` is optional extra
    information and may be left out of the input entirely.
    """
    result: bool
    prune_data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Insert ``prune_data=None`` when the key is absent from the input."""
        from collections.abc import Mapping

        # A 'before' validator receives the raw input, which is not guaranteed
        # to be a mapping; anything else is handed to pydantic untouched.
        if not isinstance(data, Mapping):
            return data

        # Work on a shallow copy so the caller's mapping is never modified.
        data = dict(data)
        data.setdefault("prune_data", None)
        return data

In [None]:
# Example prune verdicts, one per query of the batch.
prune_results = [
    PruneResult(**payload)
    for payload in ({'result':True}, {'result':False}, {'result':True})
]

In [None]:
def scatter_prune_response(batch, prune_results):
    """Apply prune verdicts to the queries of `batch`, in place.

    `prune_results[i]` refers to `batch[i]`. A query whose verdict is false is
    replaced by `None` and then removed by `batch.cleanup()`. A kept query
    receives the verdict's `prune_data` under `data['prune_data']` when that
    value is non-empty.
    """
    for i, prune_result in enumerate(prune_results):
        if not prune_result.result:
            # Mark the pruned query; cleanup() below drops the marker.
            batch[i] = None
            continue

        if prune_result.prune_data:
            # Stored under its own key so it cannot overwrite the filter
            # step's 'filter_data' entry.
            batch[i].data['prune_data'] = prune_result.prune_data

    batch.cleanup()

In [None]:
# Apply the example prune verdicts; queries with a false verdict are removed from `batch`.
scatter_prune_response(batch, prune_results)