# Schemas

> Data Schemas

In [None]:
#| default_exp schemas

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *

  from .autonotebook import tqdm as notebook_tqdm


### Data Objects

In [None]:
#| export

class InteralData(BaseModel):
    """Bookkeeping record stored on every ``Item`` and ``Query`` as ``.internal``.

    NOTE: the class name is misspelled ("Interal"); it is kept as-is because
    other code constructs it by this name.

    Every field has a default, so ``InteralData()`` yields a fresh record that
    is not removed; passing all fields explicitly (the existing call style)
    still works unchanged.
    """
    removed: bool = False                 # set when the owning object is dropped
    removal_reason: Optional[str] = None  # human-readable reason accompanying `removed`
    parent_id: Optional[str] = None       # id of the parent query, when one is recorded
    collection_id: Optional[int] = None   # group id propagated from parent to children
    iteration: Optional[int] = None       # presumably the optimisation step number — confirm with callers

In [None]:
#| export

class Item(BaseModel, extra='allow'):
    """A single embedded item, e.g. one result returned for a query.

    Fields:
        id: identifier; when ``None`` an id of the form ``item_<uuid1>`` is
            generated during validation.
        item: optional raw payload associated with the embedding.
        embedding: the embedding vector (required).
        score: optional score; ``None`` until something scores the item.
        data: optional metadata dict; replaced by ``{}`` when ``None``.

    Extra keyword arguments are kept (``extra='allow'``). After validation
    the item always has an ``internal`` attribute holding an ``InteralData``
    record.
    """
    id: Optional[Union[str, int]]
    item: Optional[Any]
    embedding: List[float]
    score: Optional[float]
    data: Optional[dict]

    @model_validator(mode='after')
    def _fill_internal(self):
        """Fill in ``internal``, ``data`` and ``id`` when they are missing."""
        internal = getattr(self, 'internal', None)
        if internal is None:
            self.internal = InteralData(
                                    removed=False,
                                    removal_reason=None,
                                    parent_id=None,
                                    collection_id=None,
                                    iteration=None
                                    )
        elif isinstance(internal, dict):
            # `internal` is an untyped extra field, so an item that was dumped
            # (model_dump / JSON) and validated again brings it back as a plain
            # dict. Restore the typed record so `.internal.removed` etc. work.
            self.internal = InteralData.model_validate(internal)

        if self.data is None:
            self.data = {}

        if self.id is None:
            self.id = f'item_{str(uuid.uuid1())}'

        return self

    def update_internal(self, **kwargs):
        """Overwrite fields of ``self.internal`` in place.

        Values are written straight into the record's ``__dict__``, so they
        are not validated and unknown keys are accepted.
        """
        self.internal.__dict__.update(kwargs)

    @classmethod
    def from_minimal(cls, id=None, item=None, embedding=None, score=None, data=None):
        """Construct an Item with every field defaulting to ``None``.

        ``embedding`` is still required by validation; the others are filled
        in by ``_fill_internal`` where applicable.
        """
        return cls(id=id, item=item, embedding=embedding, score=score, data=data)

In [None]:
# Validation must generate an id when id=None is passed ...
item = Item(id=None, embedding=[0.1], item=None, score=None, data=None)
assert item.id
old_id = item.id
# ... and validating an existing Item again must keep that id, not mint a new one.
item = Item.model_validate(item)
assert item.id == old_id

In [None]:
#| export

class Query(BaseModel, extra='allow'):
    """A query embedding together with the Items returned for it.

    Fields:
        item: optional raw payload for the query.
        embedding: the query embedding vector (required).
        data: optional metadata dict; replaced by ``{}`` when ``None``.
        query_results: Items returned for this query; replaced by ``[]``
            when ``None``.

    Extra keyword arguments are kept (``extra='allow'``). After validation
    the query always has an ``id`` (generated as ``query_<uuid1>`` when it is
    missing or ``None``) and an ``internal`` ``InteralData`` record.
    Iteration, indexing and ``len`` all operate on ``query_results``.
    """
    item: Optional[Any]
    embedding: List[float]
    data: Optional[dict]
    query_results: Optional[list[Item]]

    def __iter__(self):
        return iter(self.query_results)

    def __getitem__(self, idx: int):
        return self.query_results[idx]

    def __len__(self):
        return len(self.query_results)

    def valid_results(self):
        """Yield the results whose ``internal.removed`` flag is not set."""
        for result in self.query_results:
            if not result.internal.removed:
                yield result

    def enumerate_query_results(self, skip_removed=True):
        """Yield ``(index, result)`` pairs; indices refer to ``query_results``.

        When ``skip_removed`` is true, removed results are skipped but the
        indices of the remaining ones are unchanged.
        """
        for i, result in enumerate(self.query_results):
            if skip_removed and result.internal.removed:
                continue
            yield (i, result)

    @model_validator(mode='after')
    def _fill_internal(self):
        """Fill in ``internal``, ``id``, ``query_results`` and ``data`` when missing."""
        internal = getattr(self, 'internal', None)
        if internal is None:
            self.internal = InteralData(
                                        removed=False,
                                        removal_reason=None,
                                        parent_id=None,
                                        collection_id=None,
                                        iteration=None
                                    )
        elif isinstance(internal, dict):
            # `internal` is an untyped extra field; after a dump/validate
            # round-trip it arrives as a plain dict, so rebuild the record.
            self.internal = InteralData.model_validate(internal)

        # Like Item, treat an explicit `id=None` the same as a missing id.
        if getattr(self, 'id', None) is None:
            self.id = f'query_{str(uuid.uuid1())}'

        if self.query_results is None:
            self.query_results = []

        if self.data is None:
            self.data = {}

        return self

    @classmethod
    def from_item(cls, item: Item):
        """Build a new query from an existing Item.

        The query copies the item's payload, embedding and data, and records
        the item's id under ``data['_source_item_id']``. Its ``parent_id`` and
        ``collection_id`` are taken from the item's own ``internal`` record,
        so the new query shares the item's parent.
        """
        # Copy the dict so that writing `_source_item_id` below can never leak
        # into the source item's own `data`, whether or not validation copies it.
        query = cls(item=item.item, embedding=item.embedding, data=dict(item.data or {}), query_results=None)
        query.data['_source_item_id'] = item.id
        query.update_internal(parent_id=item.internal.parent_id, collection_id=item.internal.collection_id)
        return query

    @classmethod
    def from_parent_query(cls, embedding: List[float], parent_query):
        """Create a child query with a new embedding.

        The child's ``parent_id`` is ``parent_query.id`` and it inherits the
        parent's ``collection_id``.
        """
        query = cls(item=None, embedding=embedding, data=None, query_results=None)
        query.update_internal(parent_id=parent_query.id, collection_id=parent_query.internal.collection_id)
        return query

    def add_query_results(self, query_results: List[Item]):
        """Append results to this query, stamping their ``internal`` records.

        Each result (mutated in place) gets this query's id as ``parent_id``
        plus this query's ``collection_id`` and ``iteration``.
        """
        parent_id = self.id
        collection_id = self.internal.collection_id
        iteration = self.internal.iteration
        for result in query_results:
            result.update_internal(parent_id=parent_id, collection_id=collection_id, iteration=iteration)
            self.query_results.append(result)

    def update_internal(self, **kwargs):
        """Overwrite fields of ``self.internal``, then re-check removal.

        Values go straight into the record's ``__dict__`` (no validation).
        If the query has at least one result and every result is removed,
        the query itself is marked removed. A query with no results is left
        untouched.
        """
        self.internal.__dict__.update(kwargs)
        if self.query_results and not any(True for _ in self.valid_results()):
            self.internal.__dict__.update({'removed':True, 'removal_reason':'all query results removed'})

    @classmethod
    def from_minimal(cls, item=None, embedding=None, data=None, query_results=None):
        """Construct a Query with every field defaulting to ``None``.

        ``embedding`` is still required by validation.
        """
        return cls(item=item, embedding=embedding, data=data, query_results=query_results)

In [None]:
# test query to item:
# add_query_results stamps each result with the query's collection_id and
# records the query's id as the result's parent_id.
q1 = Query.from_minimal(embedding=[0.1])
q1.update_internal(collection_id=0)
i1 = Item.from_minimal(embedding=[0.1])
q1.add_query_results([i1])

assert i1.internal.collection_id == q1.internal.collection_id

# test item to query:
# from_item copies the item's own parent_id, so q2's parent is q1 (the query
# that produced i1), not i1 itself.
q2 = Query.from_item(i1)

assert q1.id == i1.internal.parent_id
assert q2.internal.parent_id == q1.id
assert q1.internal.collection_id == q2.internal.collection_id

# test query to query:
# from_parent_query links the child to the parent's id and collection_id.
q1 = Query.from_minimal(embedding=[0.1])
q1.update_internal(collection_id=0)
q2 = Query.from_parent_query(embedding=[0.1], parent_query=q1)

assert q2.internal.parent_id == q1.id
assert q1.internal.collection_id == q2.internal.collection_id

# test removals:
# a query is marked removed (on the next update_internal call) only once every
# one of its results has been removed.

q1 = Query.from_minimal(embedding=[0.1])
results = [Item.from_minimal(embedding=[0.1]), 
           Item.from_minimal(embedding=[0.2])]
q1.add_query_results(results)

assert len(list(q1.valid_results())) == 2

# Remove one of two results: the query must survive.
q1[0].update_internal(removed=True)
q1.update_internal()

assert len(list(q1.valid_results())) == 1
assert not q1.internal.removed

# Remove the last result: re-evaluating now marks the query removed.
q1[1].update_internal(removed=True)
q1.update_internal()

assert len(list(q1.valid_results())) == 0
assert q1.internal.removed

In [None]:
#| export

class Batch(BaseModel):
    """A list of Query objects with helpers to walk queries and their results.

    Iterating, indexing and ``len`` all operate on ``queries``. The
    ``enumerate_*`` / ``flatten_*`` helpers produce index tuples of the form
    ``(query_index, result_index)``. ``result_index`` is ``None`` for
    query-level entries, so a tuple can be unpacked straight into ``get_item``.
    """
    queries: List[Query]

    def __iter__(self):
        return iter(self.queries)

    def __getitem__(self, idx: int):
        return self.queries[idx]

    def __len__(self):
        return len(self.queries)

    def get_item(self, query_index, result_index=None):
        """Return query ``query_index``, or its result ``result_index`` when given."""
        query = self.queries[query_index]
        return query if result_index is None else query[result_index]

    def valid_queries(self):
        """Yield the queries whose ``internal.removed`` flag is not set."""
        yield from (q for q in self.queries if not q.internal.removed)

    def enumerate_queries(self, skip_removed=True):
        """Yield ``((query_index, None), query)``, optionally skipping removed queries."""
        for q_idx, query in enumerate(self.queries):
            if skip_removed and query.internal.removed:
                continue
            yield ((q_idx, None), query)

    def enumerate_query_results(self, skip_removed=True):
        """Yield ``((query_index, result_index), result)`` across all queries.

        ``skip_removed`` applies at both levels: when true, results of removed
        queries and removed results of surviving queries are both skipped.
        """
        for (q_idx, _), query in self.enumerate_queries(skip_removed):
            for r_idx, result in query.enumerate_query_results(skip_removed):
                yield ((q_idx, r_idx), result)

    def flatten_queries(self, skip_removed=True):
        """Return ``(indices, queries)`` as two parallel lists."""
        pairs = list(self.enumerate_queries(skip_removed))
        return [idx for idx, _ in pairs], [query for _, query in pairs]

    def flatten_query_results(self, skip_removed=True):
        """Return ``(indices, results)`` as two parallel lists."""
        pairs = list(self.enumerate_query_results(skip_removed))
        return [idx for idx, _ in pairs], [result for _, result in pairs]

In [None]:
# Build two queries (collection_id 1 and 2), each holding two items whose
# first item is marked removed. With the default skip_removed=True,
# flatten_query_results should return only the surviving item of each
# query, i.e. 2 results in total.
queries = []
for i in range(1,3):
    items = [Item.from_minimal(item=str(j), embedding=[0.1*j]) for j in range(2)]
    items[0].update_internal(removed=True)
    query = Query.from_minimal(embedding=[0.1])
    # Set before adding results, while query_results is still empty.
    query.update_internal(collection_id=i)
    query.add_query_results(items)
    queries.append(query)
    
batch = Batch(queries=queries)
assert len(batch.flatten_query_results()[1])==2

### Data Source

In [None]:
#| export

class DataSourceResponse(BaseModel):
    """Response returned by a data source for a query.

    Fields:
        valid: validity flag reported by the data source.
        data: optional metadata; defaults to ``None`` when omitted, like the
            other ``*Response`` models in this module.
        query_results: Items returned by the data source.
    """
    valid: bool
    data: Optional[Dict] = None
    query_results: List[Item]

In [None]:
#| export

# Signature of a data-source step: called with a list of queries, returns a
# list of DataSourceResponse (presumably one per query, in order — confirm).
# The argument types must be wrapped in a list: Callable[[Args...], Return].
DataSourceFunction = Callable[[List[Query]], List[DataSourceResponse]]

### Filter

In [None]:
#| export

class FilterResponse(BaseModel):
    """Response produced by a filter function.

    Fields:
        valid: validity flag reported by the filter.
        data: optional metadata; filled with ``None`` when omitted.
    """
    valid: bool
    data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Default ``data`` to ``None`` when a dict input omits it."""
        # A before-validator receives the raw input, which need not be a dict
        # (e.g. an existing instance). Only dicts are touched, and they are
        # copied so the caller's dict is never mutated.
        if isinstance(data, dict) and "data" not in data:
            data = {**data, "data": None}
        return data

In [None]:
#| export

# Signature of a filter step: called with a list of items, returns a list of
# FilterResponse (presumably one per item, in order — confirm).
FilterFunction = Callable[[List[Item]], List[FilterResponse]]

### Score

In [None]:
#| export

class ScoreResponse(BaseModel):
    """Response produced by a score function.

    Fields:
        valid: validity flag reported by the scorer.
        score: the computed score, or ``None``.
        data: optional metadata; filled with ``None`` when omitted.
    """
    valid: bool
    score: Optional[float]
    data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Default ``data`` to ``None`` when a dict input omits it."""
        # Raw input may be a non-dict (e.g. an existing instance); only dicts
        # are handled, and they are copied rather than mutated in place.
        if isinstance(data, dict) and "data" not in data:
            data = {**data, "data": None}
        return data

In [None]:
#| export

# Signature of a scoring step: called with a list of items, returns a list of
# ScoreResponse (presumably one per item, in order — confirm).
ScoreFunction = Callable[[List[Item]], List[ScoreResponse]]

### Prune

In [None]:
#| export

class PruneResponse(BaseModel):
    """Response produced by a prune function.

    Fields:
        valid: validity flag reported by the pruner.
        data: optional metadata; filled with ``None`` when omitted.
    """
    valid: bool
    data: Optional[dict]

    @model_validator(mode='before')
    @classmethod
    def _fill_data(cls, data: Any) -> Any:
        """Default ``data`` to ``None`` when a dict input omits it."""
        # Raw input may be a non-dict (e.g. an existing instance); only dicts
        # are handled, and they are copied rather than mutated in place.
        if isinstance(data, dict) and "data" not in data:
            data = {**data, "data": None}
        return data

In [None]:
#| export

# Signature of a pruning step: called with a list of queries, returns a list
# of PruneResponse (presumably one per query, in order — confirm).
PruneFunction = Callable[[List[Query]], List[PruneResponse]]

### Update

In [None]:
#| export

# Signature of an update step: called with a list of queries, returns the
# list of queries to use next.
UpdateFunction = Callable[[List[Query]], List[Query]]