# Schemas

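> Pydantic data schemas (`Item`, `Query`, `Batch`) and the response models and scatter helpers used to merge data-source, filter, score and prune results back into a batch.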

In [1]:
#| default_exp schemas

In [2]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [3]:
#| export
from emb_opt.imports import *

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class Item(BaseModel):
    """A single result: an embedding plus optional identifier, score and metadata.

    `item`, `score` and `data` may be omitted (or, for `data`, passed as `None`)
    when building from keyword arguments or a dict; `_fill_data` supplies
    `None`, `None` and `{}` respectively before field validation runs.
    """
    item: Optional[str]       # filled with `None` when omitted
    embedding: list[float]    # always required
    score: Optional[float]    # filled with `None` when omitted
    data: Optional[dict]      # replaced by `{}` when omitted or `None`

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, inputs: Any) -> Any:
        """Insert defaults for omitted optional fields.

        Only dict-like input (including dict subclasses) is adjusted; any other
        input is returned unchanged for pydantic to validate or reject.
        """
        if isinstance(inputs, dict):
            # Shallow copy so the caller's dict is never mutated.
            inputs = dict(inputs)

            if inputs.get('data') is None:
                inputs['data'] = {}

            # The fields are declared without defaults, so absent keys have to
            # be added explicitly; keys already present are left as they are.
            inputs.setdefault('score', None)
            inputs.setdefault('item', None)

        return inputs

In [5]:
class Query(BaseModel):
    """A query (optional item and embedding) together with its result items.

    Iteration and integer indexing operate on `query_results`.
    `data`, `query_results` and `item` may be omitted when building from
    keyword arguments or a dict; `_fill_data` supplies `{}`, `[]` and `None`.
    `collection_index` and `embedding` receive no default here.
    """
    collection_index: Optional[int]
    item: Optional[str]
    embedding: Optional[list[float]]
    data: Optional[dict]                  # replaced by `{}` when omitted or `None`
    query_results: Optional[list[Item]]   # replaced by `[]` when omitted or `None`

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, inputs: Any) -> Any:
        """Insert defaults for omitted optional fields.

        Only dict-like input is adjusted; any other input is returned unchanged
        instead of failing on `.get`, mirroring the guard used by `Item`.
        """
        if isinstance(inputs, dict):
            # Shallow copy so the caller's dict is never mutated.
            inputs = dict(inputs)

            if inputs.get('data') is None:
                inputs['data'] = {}

            if inputs.get('query_results') is None:
                inputs['query_results'] = []

            # Declared without a default, so an absent key must be added.
            inputs.setdefault('item', None)

        return inputs

    def __iter__(self):
        """Iterate over the result items (not over the model's fields)."""
        return iter(self.query_results)

    def __getitem__(self, idx: int):
        """Return the result item at position `idx`."""
        return self.query_results[idx]

    def __setitem__(self, idx: int, value):
        """Replace the result item at position `idx` in place."""
        self.query_results[idx] = value

In [6]:
class Batch(BaseModel):
    """An ordered collection of `Query` objects.

    Iteration and integer indexing operate on `queries`; indexing a returned
    query in turn reaches its individual results.
    """
    queries: List[Query]

    def __iter__(self):
        """Iterate over the queries (not over the model's fields)."""
        return iter(self.queries)

    def __getitem__(self, idx: int):
        """Return the query at position `idx`."""
        return self.queries[idx]

    def __setitem__(self, idx: int, value):
        """Replace the query at position `idx` in place."""
        self.queries[idx] = value

    def get_result(self, query_index, result_index):
        """Return result `result_index` of query `query_index`."""
        query = self.queries[query_index]
        return query[result_index]

    def enumerate_queries(self):
        """Yield `(query_index, query)` pairs."""
        yield from enumerate(self.queries)

    def enumerate_query_results(self):
        """Yield `((query_index, result_index), result)` for every result of every query."""
        for q_idx, query in enumerate(self.queries):
            for r_idx, result in enumerate(query):
                yield ((q_idx, r_idx), result)

    def flatten_query_results(self):
        """Return two parallel lists: the index pairs and the matching results."""
        pairs = list(self.enumerate_query_results())
        idxs = [index_pair for index_pair, _ in pairs]
        outputs = [result for _, result in pairs]
        return idxs, outputs

    def cleanup(self):
        """Drop every query, and every result within a query, that is `None`."""
        self.queries = [query for query in self.queries if query is not None]
        for query in self.queries:
            surviving = [result for result in query.query_results if result is not None]
            query.query_results = surviving

In [7]:
# Four toy queries (two per collection index) with 1-d embeddings.
q1, q2, q3, q4 = (
    Query(collection_index=collection, item=name, embedding=[value])
    for collection, name, value in [
        (0, 'q1', 0.1),
        (0, 'q2', 0.2),
        (1, 'q3', 0.3),
        (1, 'q4', 0.4),
    ]
)

batch = Batch(queries=[q1, q2, q3, q4])

## Data Source

In [8]:
# class DataSourceRequest(BaseModel):
#     item: Optional[str]
#     embedding: Optional[list[float]]
#     data: Optional[dict]
        
class DataSourceResponse(BaseModel):
    """Result of running one query against a data source."""
    valid: bool                    # False marks the whole query for removal
    # Defaults to `None` so the key may be omitted, matching the other
    # response models, which also treat a missing `data` as `None`.
    data: Optional[dict] = None
    query_results: list[Item]      # an empty list also causes the query to be dropped

In [9]:
# The queries themselves, in batch order, are the inputs handed to the data source.
data_inputs = [query for _, query in batch.enumerate_queries()]

In [10]:
# Five toy result items with no score or payload yet.
r1, r2, r3, r4, r5 = (
    Item(item=name, embedding=[value], score=None, data=None)
    for name, value in [('1', 0.11), ('2', 0.22), ('3', 0.33), ('4', 0.44), ('5', 0.55)]
)

# One entry per query: the second response is flagged invalid and the
# fourth returns no results.
data_valid = [True, False, True, True]
data_results = [[r1, r2], [r3], [r4, r5], []]

data_results = [
    DataSourceResponse(valid=is_valid, data={'data_test': True}, query_results=results)
    for is_valid, results in zip(data_valid, data_results)
]

In [11]:
def scatter_data_response(batch, data_results):
    """Merge data-source responses into `batch`, position by position.

    For response `i`: when it is valid and has results, the results are
    appended to `batch[i].query_results` and any non-empty `data` is merged
    into `batch[i].data`. Otherwise `batch[i]` is set to `None`.
    `batch.cleanup()` is called at the end.
    """
    for idx, response in enumerate(data_results):
        if not (response.valid and response.query_results):
            # Invalid or empty response: mark the query for removal.
            batch[idx] = None
            continue

        query = batch[idx]
        query.query_results += response.query_results
        if response.data:
            query.data.update(response.data)

    batch.cleanup()

In [12]:
# Merge the data-source responses into `batch`; queries whose response is
# invalid or has no results are removed.
scatter_data_response(batch, data_results)

## Filter 

In [13]:
# Flatten the batch into parallel lists: (query_idx, result_idx) pairs and
# the corresponding result items.
filter_idxs, filter_inputs = batch.flatten_query_results()

In [14]:
class FilterResponse(BaseModel):
    """Outcome of filtering a single result item."""
    valid: bool
    data: Optional[dict]   # filled with `None` when the key is omitted

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Add `data=None` when the key is missing from dict input.

        Non-dict input is returned unchanged for pydantic to validate or
        reject, instead of failing on item assignment here.
        """
        if isinstance(data, dict) and "data" not in data:
            # Build a new dict rather than mutating the caller's.
            data = {**data, "data": None}
        return data

In [15]:
# One filter outcome per flattened result; the second is rejected and the
# third carries extra data.
filter_results = [
    FilterResponse(**fields)
    for fields in (
        {'valid': True},
        {'valid': False},
        {'valid': True, 'data': {'blah': 5}},
        {'valid': True},
    )
]

In [16]:
# Display every result item currently held in the batch.
[result for _, result in batch.enumerate_query_results()]

[Item(item='1', embedding=[0.11], score=None, data={}),
 Item(item='2', embedding=[0.22], score=None, data={}),
 Item(item='4', embedding=[0.44], score=None, data={}),
 Item(item='5', embedding=[0.55], score=None, data={})]

In [17]:
def scatter_filter_response(batch, idxs, filter_results):
    """Apply filter outcomes to the results addressed by `idxs`.

    `idxs` holds `(query_idx, result_idx)` pairs, matched positionally with
    `filter_results`. A valid outcome merges any non-empty `data` into the
    result's `data`; an invalid one sets the result to `None`.
    `batch.cleanup()` is called at the end.
    """
    for (q_idx, r_idx), outcome in zip(idxs, filter_results):
        if not outcome.valid:
            # Rejected by the filter: mark the result for removal.
            batch[q_idx][r_idx] = None
        elif outcome.data:
            batch[q_idx][r_idx].data.update(outcome.data)

    batch.cleanup()

In [18]:
# Apply the filter outcomes: rejected results are removed, extra data is merged in.
scatter_filter_response(batch, filter_idxs, filter_results)

In [19]:
# Display the result items that remain after filtering.
[result for _, result in batch.enumerate_query_results()]

[Item(item='1', embedding=[0.11], score=None, data={}),
 Item(item='4', embedding=[0.44], score=None, data={'blah': 5}),
 Item(item='5', embedding=[0.55], score=None, data={})]

## Score

In [20]:
# Flatten the remaining results into parallel lists of index pairs and items.
score_idxs, score_inputs = batch.flatten_query_results()

In [21]:
class ScoreResponse(BaseModel):
    """Outcome of scoring a single result item."""
    valid: bool
    score: float           # required even when `valid` is False
    data: Optional[dict]   # filled with `None` when the key is omitted

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Add `data=None` when the key is missing from dict input.

        Non-dict input is returned unchanged for pydantic to validate or
        reject, instead of failing on item assignment here.
        """
        if isinstance(data, dict) and "data" not in data:
            # Build a new dict rather than mutating the caller's.
            data = {**data, "data": None}
        return data

In [22]:
# One score outcome per flattened result; the second also carries extra data.
score_results = [
    ScoreResponse(**fields)
    for fields in (
        {'score': 1., 'valid': True},
        {'score': 2., 'valid': True, 'data': {'check': 'suspect'}},
        {'score': 3., 'valid': True},
    )
]

In [23]:
def scatter_score_response(batch, idxs, score_results):
    """Write scores onto the results addressed by `idxs`.

    `idxs` holds `(query_idx, result_idx)` pairs, matched positionally with
    `score_results`. A valid outcome sets the result's `score` and merges any
    non-empty `data` into the result's `data`; an invalid one sets the result
    to `None`. `batch.cleanup()` is called at the end.
    """
    for (q_idx, r_idx), outcome in zip(idxs, score_results):
        if not outcome.valid:
            # Scoring failed: mark the result for removal.
            batch[q_idx][r_idx] = None
            continue

        target = batch[q_idx][r_idx]
        target.score = outcome.score
        if outcome.data:
            target.data.update(outcome.data)

    batch.cleanup()

In [24]:
# Write the scores (and any extra data) onto the matching result items.
scatter_score_response(batch, score_idxs, score_results)

In [25]:
# Display the result items with their scores filled in.
[result for _, result in batch.enumerate_query_results()]

[Item(item='1', embedding=[0.11], score=1.0, data={}),
 Item(item='4', embedding=[0.44], score=2.0, data={'blah': 5, 'check': 'suspect'}),
 Item(item='5', embedding=[0.55], score=3.0, data={})]

In [None]:
# prune response = FilterResponse

In [None]:
class PruneResult(BaseModel):
    """Outcome of a prune decision for one query."""
    result: bool                 # False marks the query for removal
    prune_data: Optional[dict]   # filled with `None` when the key is omitted

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Add `prune_data=None` when the key is missing from dict input.

        Non-dict input is returned unchanged for pydantic to validate or
        reject, instead of failing on item assignment here.
        """
        if isinstance(data, dict) and "prune_data" not in data:
            # Build a new dict rather than mutating the caller's.
            data = {**data, "prune_data": None}
        return data

In [None]:
# One prune decision per query, in batch order.
# NOTE(review): three decisions are supplied, but judging by the outputs above
# the batch appears to hold only two queries at this point -- confirm.
prune_results = [
    PruneResult(**fields)
    for fields in ({'result': True}, {'result': False}, {'result': True})
]

In [None]:
def scatter_prune_response(batch, prune_results):
    """Apply prune decisions to the queries of `batch`, position by position.

    A truthy `result` keeps the query and, when `prune_data` is non-empty,
    stores it in the query's `data`; a falsy `result` sets the query to
    `None`. `batch.cleanup()` is called at the end.
    """
    for idx, decision in enumerate(prune_results):
        if not decision.result:
            # Pruned: mark the query for removal.
            batch[idx] = None
        elif decision.prune_data:
            # NOTE(review): stored under 'filter_data' although this is prune
            # data -- confirm the key name is intended.
            batch[idx].data['filter_data'] = decision.prune_data

    batch.cleanup()

In [None]:
# Apply the prune decisions; queries whose decision is False are removed.
scatter_prune_response(batch, prune_results)