# Schemas

> Data Schemas

In [None]:
#| default_exp schemas

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *

  from .autonotebook import tqdm as notebook_tqdm


### Data Objects

In [None]:
#| export

class InteralData(BaseModel):
    """Bookkeeping metadata attached to an `Item` or `Query`.

    Fields:
        id: identifier; generated with `uuid.uuid1()` when missing or None.
        removed: whether the owning object has been marked as removed.
        removal_reason: free-text reason; defaults to '' when missing or None.
        parent_id: id of the object this one was derived from, if any.
        collection_id: optional integer collection tag.

    `removed`, `parent_id` and `collection_id` are not filled in by the
    validator below, so callers must supply them.
    """
    id: str
    removed: bool
    removal_reason: Optional[str]
    parent_id: Optional[str]
    collection_id: Optional[int]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, inputs: Any) -> Any:
        """Fill in `id` and `removal_reason` defaults before field validation.

        Works on a copy so the caller's mapping is never mutated. Inputs that
        are neither a model nor a mapping are returned untouched so pydantic
        can report a regular validation error for them.
        """
        # Imported locally so this block does not rely on the star-import.
        from collections.abc import Mapping

        if isinstance(inputs, BaseModel):
            inputs = inputs.model_dump()  # already a fresh dict
        elif isinstance(inputs, Mapping):
            inputs = dict(inputs)  # shallow copy: do not write into caller data
        else:
            return inputs

        if inputs.get('id') is None:
            inputs['id'] = str(uuid.uuid1())

        if inputs.get('removal_reason') is None:
            inputs['removal_reason'] = ''

        return inputs

In [None]:
#| export

class Item(BaseModel):
    """A single embedded item, optionally scored.

    Fields:
        item: optional string payload; defaults to None.
        embedding: list of floats (required).
        score: optional float; defaults to None.
        data: optional dict; defaults to an empty dict.
        internal: `InteralData` bookkeeping; a fresh record with a generated
            id and `removed=False` is created when missing or None.
    """
    item: Optional[str]
    embedding: List[float]
    score: Optional[float]
    data: Optional[dict]
    internal: Optional[InteralData]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, inputs: Any) -> Any:
        """Populate optional fields with defaults before field validation.

        Works on a copy so the caller's mapping is never mutated. Inputs that
        are neither a model nor a mapping are returned untouched so pydantic
        can report a regular validation error for them.
        """
        # Imported locally so this block does not rely on the star-import.
        from collections.abc import Mapping

        if isinstance(inputs, BaseModel):
            inputs = inputs.model_dump()  # already a fresh dict
        elif isinstance(inputs, Mapping):
            inputs = dict(inputs)  # shallow copy: do not write into caller data
        else:
            return inputs

        if inputs.get('data') is None:
            inputs['data'] = {}

        if inputs.get('internal') is None:
            # id=None lets InteralData's own validator generate the id.
            inputs['internal'] = InteralData(id=None, removed=False,
                                             parent_id=None, collection_id=None)

        # Only absent keys need a value; an explicit None is already the default.
        inputs.setdefault('score', None)
        inputs.setdefault('item', None)

        return inputs

    def update_internal(self, **kwargs):
        """Overwrite attributes of `self.internal` with the given keyword values.

        The values are written straight into the instance `__dict__`, so they
        are not validated against the `InteralData` field types.
        """
        self.internal.__dict__.update(kwargs)

In [None]:
#| export

class Query(BaseModel):
    """A query embedding together with the `Item` results retrieved for it.

    Fields:
        item: optional string payload; defaults to None.
        embedding: optional list of floats (the key must still be supplied).
        data: optional dict; defaults to an empty dict.
        query_results: list of `Item`; defaults to an empty list.
        internal: `InteralData` bookkeeping; created with a fresh id when
            missing or None.

    Iteration, indexing and `len()` operate on `query_results`.
    """
    item: Optional[str]
    embedding: Optional[List[float]]
    data: Optional[dict]
    query_results: Optional[list[Item]]
    internal: Optional[InteralData]

    def __iter__(self):
        """Iterate over all query results, removed ones included."""
        return iter(self.query_results)

    def __getitem__(self, idx: int):
        """Return the query result at position `idx`."""
        return self.query_results[idx]

    def __len__(self):
        """Return the number of query results, removed ones included."""
        return len(self.query_results)

    def valid_results(self):
        """Yield the query results that are not marked as removed."""
        for result in self.query_results:
            if not result.internal.removed:
                yield result

    def enumerate_query_results(self, skip_removed=True):
        """Yield `(index, result)` pairs; indices refer to `query_results`.

        When `skip_removed` is true, results marked as removed are omitted
        while the remaining indices keep their original values.
        """
        for i, result in enumerate(self.query_results):
            if skip_removed and result.internal.removed:
                continue
            yield (i, result)

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, inputs: Any) -> Any:
        """Populate optional fields with defaults before field validation.

        Works on a copy so the caller's mapping is never mutated. Inputs that
        are neither a model nor a mapping are returned untouched so pydantic
        can report a regular validation error for them.
        """
        # Imported locally so this block does not rely on the star-import.
        from collections.abc import Mapping

        if isinstance(inputs, BaseModel):
            inputs = inputs.model_dump()  # already a fresh dict
        elif isinstance(inputs, Mapping):
            inputs = dict(inputs)  # shallow copy: do not write into caller data
        else:
            return inputs

        if inputs.get('data') is None:
            inputs['data'] = {}

        if inputs.get('internal') is None:
            inputs['internal'] = InteralData(id=str(uuid.uuid1()), removed=False,
                                             parent_id=None, collection_id=None)

        if inputs.get('query_results') is None:
            inputs['query_results'] = []

        # Only an absent key needs a value; an explicit None is already the default.
        inputs.setdefault('item', None)

        return inputs

    @classmethod
    def from_item(cls, item: Item):
        """Build a query from `item`, recording the item as its parent.

        The new query copies the item's payload, embedding and data, starts
        with no results, and inherits the item's collection id.
        """
        inputs = {
            'item' : item.item,
            'embedding' : item.embedding,
            'data' : item.data,
            'query_results' : [],
            'internal' : {
                'id' : None,  # generated by InteralData's validator
                'removed' : False,
                'parent_id' : item.internal.id,
                'collection_id' : item.internal.collection_id
            }
        }
        return cls(**inputs)

    @classmethod
    def from_parent_query(cls, embedding: List[float], parent_query):
        """Build a query for `embedding` that descends from `parent_query`.

        The new query records the parent's id as `parent_id` and inherits the
        parent's collection id.
        """
        inputs = {
            'embedding' : embedding,
            'internal' : {
                'id' : None,  # generated by InteralData's validator
                'removed' : False,
                'parent_id' : parent_query.internal.id,
                'collection_id' : parent_query.internal.collection_id
            }
        }

        return cls(**inputs)

    def add_query_results(self, query_results: List[Item]):
        """Append results, stamping each with this query's id and collection id.

        The passed `Item` objects are modified in place (their `parent_id`
        and `collection_id` are overwritten) before being appended.
        """
        parent_id = self.internal.id
        collection_id = self.internal.collection_id
        for result in query_results:
            result.update_internal(parent_id=parent_id, collection_id=collection_id)
            self.query_results.append(result)

    def update_internal(self, **kwargs):
        """Overwrite attributes of `self.internal`, then refresh removal state.

        Values are written straight into the instance `__dict__`, so they are
        not validated. Afterwards, if the query has results and every one of
        them is marked as removed, the query itself is marked as removed.
        """
        self.internal.__dict__.update(kwargs)
        has_results = len(self.query_results) > 0
        if has_results and all(result.internal.removed for result in self.query_results):
            self.internal.__dict__.update({'removed':True, 'removal_reason':'all query results removed'})

In [None]:
# Notebook checks for Item / Query parent and collection bookkeeping.

# A query built from an item records that item as its parent.
item = Item(embedding=[0.1])
query = Query.from_item(item)
query.update_internal(collection_id=0)

assert query.internal.parent_id == item.internal.id

# A query derived from a parent query inherits the parent's id and collection id.
query2 = Query.from_parent_query([.2], parent_query=query)
assert query2.internal.parent_id == query.internal.id
assert query2.internal.collection_id == query.internal.collection_id

# Results added to a query are stamped with the query's id and collection id.
query = Query(embedding=[0.1])
query.update_internal(collection_id=0)
result = Item(embedding=[0.1])
query.add_query_results([result])
assert query[0].internal.parent_id == query.internal.id
assert query[0].internal.collection_id == query.internal.collection_id


# valid_results() skips results that were marked as removed.
query = Query(embedding=[0.1])
results = [Item(embedding=[0.1]), Item(embedding=[0.2])]
results[0].update_internal(removed=True)
query.add_query_results(results)
assert len(list(query.valid_results())) == 1

# When every result is removed, update_internal() marks the query as removed.
query = Query(embedding=[0.1])
results = [Item(embedding=[0.1])]
results[0].update_internal(removed=True)
query.add_query_results(results)
query.update_internal()
assert query.internal.removed

In [None]:
#| export

class Batch(BaseModel):
    """An ordered collection of `Query` objects.

    Iteration, indexing and `len()` operate on `queries`. The enumerate and
    flatten helpers address elements with `(query_index, result_index)`
    tuples, where `result_index` is None for a query itself.
    """
    queries: List[Query]

    def __iter__(self):
        """Iterate over all queries, removed ones included."""
        return iter(self.queries)

    def __getitem__(self, idx: int):
        """Return the query at position `idx`."""
        return self.queries[idx]

    def __len__(self):
        """Return the number of queries, removed ones included."""
        return len(self.queries)

    def get_item(self, query_index, result_index=None):
        """Return a query, or one of its results when `result_index` is given."""
        selected = self.queries[query_index]
        if result_index is None:
            return selected
        return selected[result_index]

    def enumerate_queries(self, skip_removed=True):
        """Yield `((query_index, None), query)` pairs.

        Queries marked as removed are omitted when `skip_removed` is true.
        """
        for position, query in enumerate(self.queries):
            if skip_removed and query.internal.removed:
                continue
            yield ((position, None), query)

    def enumerate_query_results(self, skip_removed=True):
        """Yield `((query_index, result_index), result)` over all queries.

        `skip_removed` is applied both to the queries and to their results.
        """
        for (query_pos, _), query in self.enumerate_queries(skip_removed):
            yield from (
                ((query_pos, result_pos), result)
                for result_pos, result in query.enumerate_query_results(skip_removed)
            )

    def flatten_queries(self, skip_removed=True):
        """Return parallel lists `(indices, queries)` from `enumerate_queries`."""
        pairs = list(self.enumerate_queries(skip_removed))
        return [index for index, _ in pairs], [query for _, query in pairs]

    def flatten_query_results(self, skip_removed=True):
        """Return parallel lists `(indices, results)` from `enumerate_query_results`."""
        pairs = list(self.enumerate_query_results(skip_removed))
        return [index for index, _ in pairs], [result for _, result in pairs]

In [None]:
# Notebook check for Batch flattening: two queries, each holding two results
# of which the first is marked as removed.
queries = []
for i in range(1,3):
    items = [Item(item=str(j), embedding=[0.1*j]) for j in range(2)]
    items[0].update_internal(removed=True)
    query = Query(embedding=[0.1])
    query.update_internal(collection_id=i)
    query.add_query_results(items)
    queries.append(query)
    
batch = Batch(queries=queries)

# With the default skip_removed=True only one result per query remains.
assert len(batch.flatten_query_results()[1])==2

### Data Source

In [None]:
#| export

class DataSourceResponse(BaseModel):
    """Response of a data source for a single query.

    Fields:
        valid: boolean flag set by the data source.
        data: optional dict; defaults to None when the key is omitted.
        query_results: list of `Item` objects returned for the query.
    """
    valid: bool
    data: Optional[Dict]
    query_results: List[Item]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Default the `data` key to None, as the other response models do.

        Only mapping inputs are touched, and on a copy. Anything else
        (including existing model instances) is passed through unchanged so
        pydantic handles it exactly as it would without this validator.
        """
        # Imported locally so this block does not rely on the star-import.
        from collections.abc import Mapping

        if not isinstance(data, Mapping):
            return data

        data = dict(data)  # shallow copy: do not write into caller data
        data.setdefault("data", None)
        return data

In [None]:
#| export

# A data source maps a list of queries to one response per query.
# Callable's first argument must be a *list* of parameter types.
DataSourceFunction = Callable[[List[Query]], List[DataSourceResponse]]

### Filter

In [None]:
#| export

class FilterResponse(BaseModel):
    """Response of a filter function for a single item.

    Fields:
        valid: boolean flag set by the filter.
        data: optional dict; defaults to None when the key is omitted.
    """
    valid: bool
    data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Default the `data` key to None before field validation.

        Only mapping inputs are touched, and on a copy. Anything else
        (for example an existing model instance) is passed through unchanged
        so pydantic can accept or reject it itself.
        """
        # Imported locally so this block does not rely on the star-import.
        from collections.abc import Mapping

        if not isinstance(data, Mapping):
            return data

        data = dict(data)  # shallow copy: do not write into caller data
        data.setdefault("data", None)
        return data

In [None]:
#| export

# A filter maps a list of items to one response per item.
# Callable's first argument must be a *list* of parameter types.
FilterFunction = Callable[[List[Item]], List[FilterResponse]]

### Score

In [None]:
#| export

class ScoreResponse(BaseModel):
    """Response of a score function for a single item.

    Fields:
        valid: boolean flag set by the scorer.
        score: optional float; not defaulted, so the key must be supplied.
        data: optional dict; defaults to None when the key is omitted.
    """
    valid: bool
    score: Optional[float]
    data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Default the `data` key to None before field validation.

        Only mapping inputs are touched, and on a copy. Anything else
        (for example an existing model instance) is passed through unchanged
        so pydantic can accept or reject it itself.
        """
        # Imported locally so this block does not rely on the star-import.
        from collections.abc import Mapping

        if not isinstance(data, Mapping):
            return data

        data = dict(data)  # shallow copy: do not write into caller data
        data.setdefault("data", None)
        return data

In [None]:
#| export

# A scorer maps a list of items to one response per item.
# Callable's first argument must be a *list* of parameter types.
ScoreFunction = Callable[[List[Item]], List[ScoreResponse]]

### Prune

In [None]:
#| export

class PruneResponse(BaseModel):
    """Response of a prune function for a single query.

    Fields:
        valid: boolean flag set by the prune function.
        data: optional dict; defaults to None when the key is omitted.
    """
    valid: bool
    data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Default the `data` key to None before field validation.

        Only mapping inputs are touched, and on a copy. Anything else
        (for example an existing model instance) is passed through unchanged
        so pydantic can accept or reject it itself.
        """
        # Imported locally so this block does not rely on the star-import.
        from collections.abc import Mapping

        if not isinstance(data, Mapping):
            return data

        data = dict(data)  # shallow copy: do not write into caller data
        data.setdefault("data", None)
        return data

In [None]:
#| export

# A prune function maps a list of queries to one response per query.
# Callable's first argument must be a *list* of parameter types.
PruneFunction = Callable[[List[Query]], List[PruneResponse]]

### Update

In [None]:
#| export

# An update function maps a list of queries to a list of queries.
# Callable's first argument must be a *list* of parameter types.
UpdateFunction = Callable[[List[Query]], List[Query]]