# Schemas

> Data schemas (pydantic models) for queries, result items and batches, plus gather/scatter helpers

In [1]:
#| default_exp schemas

In [2]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [3]:
#| export
from emb_opt.imports import *

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
#| export

class Item(BaseModel):
    """An item with optional text, an embedding vector, an optional score and a metadata dict.

    ``data`` is always a dict after validation: an explicit ``None`` (or empty
    value) becomes ``{}`` so pipeline steps can write keys into it directly.
    """
    item: Optional[str]
    embedding: list[float]
    score: Optional[float]
    data: Optional[dict]

    @field_validator('data')
    @classmethod
    def set_data(cls, data):
        # Normalise a missing/empty payload to a fresh empty dict.
        return data if data else {}

In [5]:
#| export

class Query(BaseModel):
    """A query (optional text plus embedding) and the ``Item`` results attached to it.

    ``query_results`` is always a list after validation: an explicit ``None``
    becomes ``[]`` so downstream code can iterate it unconditionally.
    """
    item: Optional[str]
    embedding: list[float]
    query_results: Optional[list[Item]]

    @field_validator('query_results')
    @classmethod
    def set_query_results(cls, query_results):
        # Normalise a missing/empty result list to a fresh empty list.
        return query_results if query_results else []

In [6]:
class DataSourceInput(BaseModel):
    """Request payload for a data source, one per query (see ``gather_data_request``).

    Either field may be ``None`` when the caller chose not to include it.
    """
    item: Optional[str]
    embedding: Optional[list[float]]

In [7]:
class QueryIndex(BaseModel):
    """Hierarchical position of an object inside a ``Batch``.

    ``batch_index`` selects a ``QueryBatch``; adding ``query_index`` selects a
    ``Query`` within it; adding ``result_index`` selects an ``Item`` in that
    query's ``query_results`` (see ``Batch.index``).
    """
    batch_index: int
    query_index: Optional[int]  # position of the Query in its QueryBatch, or None
    result_index: Optional[int]  # position in Query.query_results, or None

In [8]:
class FilterRequestItem(BaseModel):
    """Request payload for filtering one query result (see ``gather_filter_request``).

    Either field may be ``None`` when the caller chose not to include it.
    """
    item: Optional[str]
    embedding: Optional[list[float]]

In [9]:
class FilterResult(BaseModel):
    """Filter decision for one query result.

    ``result`` is ``False`` when the result should be dropped; ``filter_data``
    is optional metadata for a kept result and may be omitted on input
    (it then defaults to ``None``).
    """
    result: bool
    filter_data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        # Make ``filter_data`` optional on input. Only plain dicts are patched:
        # other inputs (e.g. an existing FilterResult passed to model_validate)
        # do not support item assignment and are passed through untouched.
        # A copy is returned so the caller's dict is never mutated.
        if isinstance(data, dict) and "filter_data" not in data:
            data = {**data, "filter_data": None}
        return data

In [10]:
class ScoreRequestItem(BaseModel):
    """Request payload for scoring one query result (see ``gather_score_request``).

    Either field may be ``None`` when the caller chose not to include it.
    """
    item: Optional[str]
    embedding: Optional[list[float]]

class ScoreResult(BaseModel):
    """Score assigned to one query result.

    ``result`` is the numeric score; ``score_data`` is optional metadata and
    may be omitted on input (it then defaults to ``None``).
    """
    result: float
    score_data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        # Make ``score_data`` optional on input. Only plain dicts are patched:
        # other inputs (e.g. an existing ScoreResult passed to model_validate)
        # do not support item assignment and are passed through untouched.
        # A copy is returned so the caller's dict is never mutated.
        if isinstance(data, dict) and "score_data" not in data:
            data = {**data, "score_data": None}
        return data

In [11]:
class QueryBatch(RootModel):
    """An ordered group of ``Query`` objects that iterates and indexes like a list."""
    root: List[Query]

    def __iter__(self):
        # Yield the wrapped queries rather than pydantic's (field, value) pairs.
        yield from self.root

    def __getitem__(self, idx: int):
        queries = self.root
        return queries[idx]

In [12]:
class Batch(RootModel):
    """A list of ``QueryBatch`` objects with helpers to address nested objects.

    Positions are hierarchical (see ``QueryIndex``): ``batch_index`` selects a
    ``QueryBatch``, ``query_index`` a ``Query`` within it, and ``result_index``
    an ``Item`` in that query's ``query_results``.
    """
    root: List[QueryBatch]

    def __iter__(self):
        return iter(self.root)

    def __getitem__(self, idx: int):
        return self.root[idx]

    def index(self, batch_index, query_index=None, result_index=None):
        """Return the object at the given position.

        Returns a ``QueryBatch`` when only ``batch_index`` is given, a ``Query``
        when ``query_index`` is also given, and an ``Item`` when
        ``result_index`` is given as well.

        Raises:
            ValueError: if ``result_index`` is given without ``query_index``
                (a result position is meaningless without its query).
            IndexError: if any index is out of range.
        """
        if result_index is not None and query_index is None:
            raise ValueError(
                "result_index requires query_index "
                f"(got batch_index={batch_index}, result_index={result_index})"
            )
        result = self.root[batch_index]
        if query_index is not None:
            result = result[query_index]
            if result_index is not None:
                result = result.query_results[result_index]

        return result

    def object_index(self, query_idx: QueryIndex):
        """Return the object addressed by a ``QueryIndex`` (see ``index``)."""
        return self.index(**query_idx.model_dump())

    def enumerate_batches(self):
        """Yield ``(QueryIndex, QueryBatch)`` for every query batch."""
        for i, batch in enumerate(self.root):
            idx = QueryIndex(batch_index=i, query_index=None, result_index=None)
            yield (idx, batch)

    def enumerate_queries(self):
        """Yield ``(QueryIndex, Query)`` for every query, batch by batch."""
        for i, batch in enumerate(self.root):
            for j, query in enumerate(batch):
                idx = QueryIndex(batch_index=i, query_index=j, result_index=None)
                yield (idx, query)

    def enumerate_query_results(self):
        """Yield ``(QueryIndex, Item)`` for every result of every query."""
        for i, batch in enumerate(self.root):
            for j, query in enumerate(batch):
                for k, result in enumerate(query.query_results):
                    idx = QueryIndex(batch_index=i, query_index=j, result_index=k)
                    yield (idx, result)

In [47]:
# Toy fixture: three one-dimensional queries split across two query batches.
q1, q2, q3 = (
    Query(item='q1', embedding=[0.1], query_results=[]),
    Query(item='q2', embedding=[0.2], query_results=[]),
    Query(item='q3', embedding=[0.3], query_results=[]),
)

qb1, qb2 = QueryBatch([q1, q2]), QueryBatch([q3])
batch = Batch([qb1, qb2])

In [14]:
def gather_data_request(batch, include_item=True, include_embedding=True):
    """Build one ``DataSourceInput`` per query in ``batch``.

    Returns ``(idxs, outputs)``: parallel lists holding each query's
    ``QueryIndex`` and its request payload. When ``include_item`` /
    ``include_embedding`` is false the corresponding payload field is ``None``.
    """
    idxs, outputs = [], []
    for position, query in batch.enumerate_queries():
        idxs.append(position)
        outputs.append(DataSourceInput(
            item=query.item if include_item else None,
            embedding=query.embedding if include_embedding else None,
        ))
    return idxs, outputs

In [48]:
# One DataSourceInput per query, plus the QueryIndex each one came from.
data_idxs, data_inputs = gather_data_request(batch)

In [49]:
# Fake data-source response: one list of result Items per query, in query order.
r1, r2, r3, r4, r5 = (
    Item(item=name, embedding=[emb], score=None, data=None)
    for name, emb in (('1', 0.11), ('2', 0.22), ('3', 0.33), ('4', 0.44), ('5', 0.55))
)

data_results = [[r1, r2], [r3], [r4, r5]]

In [17]:
def scatter_data_response(batch, idxs, data_results):
    """Attach data-source results to the queries they were requested for.

    Args:
        batch: the ``Batch`` the requests were gathered from.
        idxs: ``QueryIndex`` positions, as returned by ``gather_data_request``.
        data_results: one list of ``Item`` results per entry in ``idxs``.

    Each addressed query's ``query_results`` is replaced (not extended) by the
    corresponding list. Mutates ``batch`` in place and returns ``None``.

    Raises:
        ValueError: if ``idxs`` and ``data_results`` differ in length; checked
            before any query is modified.
    """
    idxs = list(idxs)
    data_results = list(data_results)
    # ``zip`` would silently drop the unmatched tail, hiding a response that
    # lost or gained entries; fail loudly instead.
    if len(idxs) != len(data_results):
        raise ValueError(
            f"got {len(data_results)} data results for {len(idxs)} queries"
        )
    for query_idx, item_list in zip(idxs, data_results):
        batch.object_index(query_idx).query_results = item_list

In [50]:
# Write the fake results back onto the queries in `batch` (in place).
scatter_data_response(batch, data_idxs, data_results)

In [19]:
def gather_filter_request(batch, include_item=True, include_embedding=True):
    """Build one ``FilterRequestItem`` per query result in ``batch``.

    Returns ``(idxs, outputs)``: parallel lists holding each result's
    ``QueryIndex`` and its request payload. When ``include_item`` /
    ``include_embedding`` is false the corresponding payload field is ``None``.
    """
    idxs, outputs = [], []
    for position, candidate in batch.enumerate_query_results():
        idxs.append(position)
        outputs.append(FilterRequestItem(
            item=candidate.item if include_item else None,
            embedding=candidate.embedding if include_embedding else None,
        ))
    return idxs, outputs

In [51]:
# One FilterRequestItem per query result, plus the QueryIndex each one came from.
filter_idxs, filter_inputs = gather_filter_request(batch)

In [52]:
# Fake filter response, one FilterResult per query result (same order as
# filter_idxs): results '2' and '5' are rejected, result '3' carries filter data.
filter_results = [
    FilterResult(**raw)
    for raw in (
        {'result': True},
        {'result': False},
        {'result': True, 'filter_data': {'blah': 5}},
        {'result': True},
        {'result': False},
    )
]

In [53]:
# Display every query result currently attached to the batch.
[result for _, result in batch.enumerate_query_results()]

[Item(item='1', embedding=[0.11], score=None, data={}),
 Item(item='2', embedding=[0.22], score=None, data={}),
 Item(item='3', embedding=[0.33], score=None, data={}),
 Item(item='4', embedding=[0.44], score=None, data={}),
 Item(item='5', embedding=[0.55], score=None, data={})]

In [23]:
def scatter_filter_response(batch, idxs, filter_results):
    """Apply filter decisions to the query results they were requested for.

    Args:
        batch: the ``Batch`` the requests were gathered from.
        idxs: ``QueryIndex`` positions, as returned by ``gather_filter_request``.
        filter_results: one ``FilterResult`` per entry in ``idxs``.

    A result whose ``FilterResult.result`` is false is removed from its query's
    ``query_results``. A kept result with non-empty ``filter_data`` gets that
    dict stored under ``item.data['filter_data']``. Mutates ``batch`` in place.

    Raises:
        ValueError: if ``idxs`` and ``filter_results`` differ in length; checked
            before anything is modified.
    """
    idxs = list(idxs)
    filter_results = list(filter_results)
    # ``zip`` would silently drop the unmatched tail and leave some results
    # unfiltered; fail loudly before touching the batch instead.
    if len(idxs) != len(filter_results):
        raise ValueError(
            f"got {len(filter_results)} filter results for {len(idxs)} query results"
        )

    for result_idx, filter_result in zip(idxs, filter_results):
        query = batch.index(result_idx.batch_index, result_idx.query_index)
        if not filter_result.result:
            # Mark for removal instead of deleting now, so the result_index
            # positions of the remaining decisions stay valid.
            query.query_results[result_idx.result_index] = None
        elif filter_result.filter_data:
            query.query_results[result_idx.result_index].data['filter_data'] = filter_result.filter_data

    # Drop the removal markers in a second pass.
    for _, query in batch.enumerate_queries():
        query.query_results = [item for item in query.query_results if item is not None]

In [54]:
# Drop rejected results and attach filter_data to kept ones (in place).
scatter_filter_response(batch, filter_idxs, filter_results)

In [55]:
# Display the query results that survived filtering.
[result for _, result in batch.enumerate_query_results()]

[Item(item='1', embedding=[0.11], score=None, data={}),
 Item(item='3', embedding=[0.33], score=None, data={'filter_data': {'blah': 5}}),
 Item(item='4', embedding=[0.44], score=None, data={})]

In [26]:
def gather_score_request(batch, include_item=True, include_embedding=True):
    """Build one ``ScoreRequestItem`` per query result in ``batch``.

    Returns ``(idxs, outputs)``: parallel lists holding each result's
    ``QueryIndex`` and its request payload. When ``include_item`` /
    ``include_embedding`` is false the corresponding payload field is ``None``.
    """
    positions = []
    requests = []
    for position, candidate in batch.enumerate_query_results():
        positions.append(position)
        text = candidate.item if include_item else None
        vector = candidate.embedding if include_embedding else None
        requests.append(ScoreRequestItem(item=text, embedding=vector))
    return positions, requests

In [56]:
# One ScoreRequestItem per remaining query result, plus its QueryIndex.
score_idxs, score_inputs = gather_score_request(batch)

In [57]:
# Fake scorer response, one ScoreResult per remaining query result (same order
# as score_idxs); the second result also carries extra score data.
score_results = [
    ScoreResult(**raw)
    for raw in (
        {'result': 1.},
        {'result': 2., 'score_data': {'check': 'suspect'}},
        {'result': 3.},
    )
]

In [29]:
def scatter_score_response(batch, idxs, score_results):
    """Write scores back onto the query results they were requested for.

    Args:
        batch: the ``Batch`` the requests were gathered from.
        idxs: ``QueryIndex`` positions, as returned by ``gather_score_request``.
        score_results: one ``ScoreResult`` per entry in ``idxs``.

    Each addressed item's ``score`` is set to ``ScoreResult.result``. A
    non-empty ``score_data`` is stored under ``item.data['score_data']``.
    Mutates ``batch`` in place.

    Raises:
        ValueError: if ``idxs`` and ``score_results`` differ in length; checked
            before any item is modified.
    """
    idxs = list(idxs)
    score_results = list(score_results)
    # ``zip`` would silently leave the unmatched tail unscored; fail loudly.
    if len(idxs) != len(score_results):
        raise ValueError(
            f"got {len(score_results)} score results for {len(idxs)} query results"
        )
    for result_idx, score_result in zip(idxs, score_results):
        query_result = batch.object_index(result_idx)
        query_result.score = score_result.result

        if score_result.score_data:
            query_result.data['score_data'] = score_result.score_data

In [58]:
# Write the fake scores (and any score_data) onto the batch's results (in place).
scatter_score_response(batch, score_idxs, score_results)

In [59]:
# Prune request: the queries of each query batch, as plain nested lists.
prune_inputs = [list(query_batch) for query_batch in batch.root]

In [45]:
class PruneResultItem(BaseModel):
    """Prune decision for one query; ``scatter_prune_response`` removes the query when ``result`` is false."""
    result: bool
        
class PruneResult(RootModel):
    """Prune decisions nested like a ``Batch``: one list of ``PruneResultItem`` per query batch."""
    root: List[List[PruneResultItem]]

    def __iter__(self):
        # Yield the per-batch decision lists rather than pydantic's (field, value) pairs.
        yield from self.root

In [60]:
# Fake prune response mirroring the batch layout, one decision per query:
# the first query of batch 0 and the only query of batch 1 are pruned.
prune_results = PruneResult([
    [{'result' : False}, {'result' : True}],
    [{'result' : False}],
])

In [34]:
def scatter_prune_response(batch, prune_results):
    """Remove pruned queries from ``batch`` and drop query batches left empty.

    Args:
        batch: the ``Batch`` to prune; modified in place.
        prune_results: an iterable (e.g. ``PruneResult``) of per-query-batch
            lists of decisions, each with a boolean ``result``. A false
            ``result`` removes the corresponding query. Queries without a
            decision (shorter lists) are kept.

    Raises:
        IndexError: if a decision addresses a query batch or query that does
            not exist. This is checked before ``batch`` is modified, so a bad
            response cannot leave placeholder entries behind.
    """
    # Materialise all decisions first so shape errors surface before mutation.
    decisions = [[prune_result.result for prune_result in row] for row in prune_results]
    if len(decisions) > len(batch.root):
        raise IndexError(
            f"prune results cover {len(decisions)} query batches, batch has {len(batch.root)}"
        )
    for batch_idx, row in enumerate(decisions):
        n_queries = len(batch.root[batch_idx].root)
        if len(row) > n_queries:
            raise IndexError(
                f"prune results for query batch {batch_idx} cover {len(row)} queries, "
                f"it has {n_queries}"
            )

    for query_batch, row in zip(batch.root, decisions):
        query_batch.root = [
            query for j, query in enumerate(query_batch.root)
            if j >= len(row) or row[j]
        ]

    batch.root = [query_batch for query_batch in batch.root if query_batch.root]

In [61]:
# Remove pruned queries and any query batches left empty (in place).
scatter_prune_response(batch, prune_results)