# Schemas

> Data Schemas

Standardized data schemas that define a common interface for plugins.

In [None]:
#| default_exp schemas

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *

## High Level Overview

`emb_opt` is designed to run hill climbing algorithms in embedding spaces. In practice, this means searching through either an explicit vector database or the implicit embedding space of a generative model. We refer to either one as a `DataSource`. We use `continuous` to refer to the embedding space itself, and `discrete` to refer to the discrete things represented by embeddings.

The `DataSource` is queried with a `Query`. The `Query` contains a query embedding and, optionally, an `item` (some discrete thing represented by the embedding). The `DataSource` uses the `Query` to return a list of `Item` objects. An `Item` represents a discrete thing returned by the `DataSource`.

The `Item` results are optionally sent to a `Filter`, which removes results based on some True/False criteria.

The `Item` results are then sent to a `Score` which assigns some numeric score value to each `Item`.

The `Query` and scored `Item` results are sent to an `Update`, which uses the scored items to generate a new `Query`. `Update` methods are denoted as `discrete` or `continuous`. `continuous` updates generate new queries purely in embedding space (i.e. by averaging `Item` embeddings). `discrete` updates create new queries directly from `Item` results, so each query can have a specific `item` associated with it (which is not possible with continuous updates). `continuous` updates generally converge faster, but some types of `DataSource` require a discrete item in the query and are therefore incompatible with `continuous` updates.
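To make the `continuous`/`discrete` distinction concrete, here is a minimal sketch that uses plain dataclasses as stand-ins for `Item` and `Query`. `ToyItem`, `ToyQuery`, `top_k`, `continuous_update` and `discrete_update` are hypothetical names for illustration only, not part of `emb_opt`. Averaging the top-scoring embeddings yields one new point with no discrete item attached, while the discrete variant turns each top result into its own query that keeps its `item`.

```python
from dataclasses import dataclass
from typing import Any, List, Optional

# Simplified stand-ins for this notebook's `Item` / `Query` (only the fields needed here)
@dataclass
class ToyItem:
    item: Any
    embedding: List[float]
    score: float

@dataclass
class ToyQuery:
    embedding: List[float]
    item: Optional[Any] = None

def top_k(items: List[ToyItem], k: int) -> List[ToyItem]:
    # hill climbing: higher scores are better
    return sorted(items, key=lambda x: x.score, reverse=True)[:k]

def continuous_update(items: List[ToyItem], k: int) -> List[ToyQuery]:
    # average the top-k embeddings into a single new point; no discrete item exists there
    best = top_k(items, k)
    dim = len(best[0].embedding)
    mean = [sum(it.embedding[d] for it in best) / len(best) for d in range(dim)]
    return [ToyQuery(embedding=mean)]

def discrete_update(items: List[ToyItem], k: int) -> List[ToyQuery]:
    # each top-k result becomes its own query and carries its discrete item along
    return [ToyQuery(embedding=list(it.embedding), item=it.item) for it in top_k(items, k)]

results = [
    ToyItem('a', [0.0, 0.0], score=0.1),
    ToyItem('b', [1.0, 0.0], score=0.9),
    ToyItem('c', [0.0, 1.0], score=0.7),
]
print(continuous_update(results, k=2))  # one query at [0.5, 0.5] with item=None
print(discrete_update(results, k=2))    # two queries, for items 'b' and 'c'
```

The continuous query lands between the two best items, at a point where no discrete item exists. This is why a `DataSource` that must be queried with a discrete item cannot accept it.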

Some `Update` methods generate multiple new queries. To control the total number of queries, a `Prune` step is optionally added before the `Update` step.

The general flow is:
1. Start with a `Batch` of `Query` objects
2. Query the `DataSource`
3. (optional) Send results to the `Filter`
4. Send results to the `Score`
5. (optional) `Prune` queries
6. Use the scored results to `Update` to a new set of queries

The schemas defined here specify the required input/output structure for each step, which allows fully flexible plugins at every stage of the process.
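The loop below is a rough sketch of one way these steps can chain together. It assumes a tiny grid of 2D points as the `DataSource` with a nearest-neighbour lookup, an arbitrary filter, a distance-based score, and a top-2 averaging update. All names are hypothetical, plain lists stand in for the schemas, and because there is only a single query per iteration the optional `Prune` step is omitted.

```python
import math
from typing import List, Tuple

# Hypothetical toy setup; none of these names are part of `emb_opt`
TOY_DATABASE = [[x / 10, y / 10] for x in range(11) for y in range(11)]  # 121 points on a 2D grid
TOY_TARGET = [0.7, 0.3]  # hidden optimum that the search should approach

def grid_source(query: List[float], k: int = 5) -> List[List[float]]:
    # DataSource: the k grid points nearest to the query embedding
    return sorted(TOY_DATABASE, key=lambda e: math.dist(e, query))[:k]

def grid_filter(emb: List[float]) -> bool:
    # Filter: arbitrary example criterion, rejects points on the x = 0 edge
    return emb[0] > 0.0

def grid_score(emb: List[float]) -> float:
    # Score: higher is better, so use the negative distance to the target
    return -math.dist(emb, TOY_TARGET)

def grid_update(scored: List[Tuple[List[float], float]]) -> List[float]:
    # Update (continuous): mean embedding of the two best-scoring results
    best = sorted(scored, key=lambda s: s[1], reverse=True)[:2]
    return [sum(emb[d] for emb, _ in best) / len(best) for d in range(len(best[0][0]))]

start = [0.0, 1.0]
toy_query = start
for _ in range(10):
    results = grid_source(toy_query)                   # query the DataSource
    results = [e for e in results if grid_filter(e)]   # (optional) Filter
    scored = [(e, grid_score(e)) for e in results]     # Score
    if not scored:
        break  # every result was filtered out, nothing to update from
    toy_query = grid_update(scored)                    # Update -> next query

print(toy_query, grid_score(toy_query))
```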

### Data Objects

#### Internal Data

`InternalData` tracks internal information as part of the embedding search. This data is managed internally, but may be useful for certain `Prune` or `Update` configurations.

`InternalData.removed` denotes whether the related `Item` or `Query` has been removed or invalidated by some step (see `DataSourceResponse`, `FilterResponse`, `ScoreResponse`, `PruneResponse`).

`InternalData.removal_reason` details the reason for removal.

`InternalData.parent_id` is the ID string of the parent `Query` of the related `Item` or `Query` object. `InternalData.parent_id` always points to a `Query`, never an `Item`.

`InternalData.collection_id` groups `Item` and `Query` objects that come from the same initial `Query`. This is useful when an `Update` step generates multiple new queries from a single input.

`InternalData.iteration` denotes which iteration of the search created the related `Item` or `Query`.
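As an illustration of how `parent_id` links objects across iterations, the sketch below walks parent links back to the initial query. A plain dict (`parents`) stands in for the `InternalData.parent_id` values, and `lineage` is a hypothetical helper, not an `emb_opt` function. Note that `item_b` resolves to `query_1` and then to `query_0`: every link points to a `Query`.

```python
from typing import Dict, List, Optional

# Hypothetical lineage table: object id -> parent query id (None for an initial query)
parents: Dict[str, Optional[str]] = {
    'query_0': None,        # initial query
    'item_a': 'query_0',    # result returned by query_0
    'query_1': 'query_0',   # new query produced from query_0 by an update
    'item_b': 'query_1',    # result returned by query_1
}

def lineage(obj_id: str, parents: Dict[str, Optional[str]]) -> List[str]:
    'walks parent links back to the initial query, returning ids from oldest to newest'
    chain = [obj_id]
    while parents[chain[-1]] is not None:
        chain.append(parents[chain[-1]])
    return chain[::-1]

print(lineage('item_b', parents))
```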

In [None]:
#| export

class InternalData(BaseModel):
    'Internal Data Tracking'
    removed: bool                 # if item/query has been removed by some step
    removal_reason: Optional[str] # reason for removal
    parent_id: Optional[str]      # parent query of item/query
    collection_id: Optional[int]  # collection id of item/query
    iteration: Optional[int]      # current iteration

InteralData = InternalData  # alias for the original misspelled name, kept for backward compatibility

#### Item 

The `Item` schema is the basic "object" or "thing" we are looking for. The goal of `emb_opt` is to discover an `Item` with a high `score`.

`Item.id` is the index/ID of the item (for example, the database index). If no ID is provided, one is generated from a UUID. `emb_opt` assumes `Item.id` is unique to the item.

`Item.item` is the discrete thing itself.

`Item.embedding` is the embedding representing the item.

`Item.score` is the score of the item. `emb_opt` assumes a hill climbing scenario where higher scores are better than lower scores.

`Item.data` is a dictionary container for any other information associated with the item (e.g. other fields returned from a database query).

In [None]:
#| export

class Item(BaseModel, extra='allow'):
    id: Optional[Union[str, int]] # id/index of item
    item: Optional[Any]           # the item itself
    embedding: List[float]        # embedding representing the item
    score: Optional[float]        # item score
    data: Optional[dict]          # any other associated data
    
    @model_validator(mode='after')
    def _fill_internal(self):
        if not hasattr(self, 'internal'):
            self.internal = InteralData(
                                    removed=False, 
                                    removal_reason=None,
                                    parent_id=None, 
                                    collection_id=None,
                                    iteration=None
                                    )
        
        if self.data is None:
            self.data = {}
            
        if self.id is None:
            self.id = f'item_{str(uuid.uuid1())}'
            
        return self
    
    def update_internal(self, **kwargs):
        self.internal.__dict__.update(kwargs)
        
    @classmethod
    def from_minimal(cls, 
                     id: Optional[Union[str, int]]=None, 
                     item: Optional[Any]=None, 
                     embedding: List[float]=None, 
                     score: Optional[float]=None, 
                     data: Optional[dict]=None):
        'convenience function for creating an `Item` with default `None` values'
        return cls(id=id, item=item, embedding=embedding, score=score, data=data)

In [None]:
item = Item(id=None, embedding=[0.1], item=None, score=None, data=None)
assert item.id
old_id = item.id
item = Item.model_validate(item)
assert item.id == old_id

#### Query

A `Query` is the basic object for searching a `DataSource` and holding `Item` results returned by the search.

`Query.item` is an (optional) discrete item associated with the `Query`. It is populated automatically when the query is created from an `Item` via `Query.from_item`.

`Query.embedding` is the embedding associated with the `Query`.

`Query.data` is a dictionary container for any other information associated with the query.

`Query.query_results` is a list of `Item` objects returned from a query.

In [None]:
#| export

class Query(BaseModel, extra='allow'):
    item: Optional[Any]
    embedding: List[float]
    data: Optional[dict]
    query_results: Optional[list[Item]]
        
    def __iter__(self):
        return iter(self.query_results)

    def __getitem__(self, idx: int):
        return self.query_results[idx]
    
    def __len__(self):
        return len(self.query_results)
    
    def valid_results(self) -> Item:
        '''
        iterates over `self.query_results`, skipping results 
        with `internal.removed=True`
        '''
        for result in self.query_results:
            if not result.internal.removed:
                yield result
    
    def enumerate_query_results(self, skip_removed: bool=True) -> (int, Item):
        '''
        enumerates over `self.query_results`. if `skip_removed=True`,
        results with `internal.removed=True` are ignored
        '''
        for i, result in enumerate(self.query_results):
            if skip_removed:
                if not result.internal.removed:
                    yield (i, result)
            else:
                yield (i, result)
    
    @model_validator(mode='after')
    def _fill_internal(self):
        if not hasattr(self, 'internal'):
            self.internal = InteralData(
                                        removed=False, 
                                        removal_reason=None,      
                                        parent_id=None, 
                                        collection_id=None,
                                        iteration=None
                                    )
            
        if not hasattr(self, 'id'):
            self.id = f'query_{str(uuid.uuid1())}'
        
        if self.query_results is None:
            self.query_results = []
            
        if self.data is None:
            self.data = {}
            
        return self
                
    @classmethod
    def from_item(cls, item: Item):
        '''
        creates a `Query` from an input `Item`. The `item`, `embedding`, and `data` 
        attributes from the `Item` are propagated to the `Query`, as well as the 
        item's parent query ID
        '''
        query = cls(item=item.item, embedding=item.embedding, data=item.data, query_results=None)
        query.data['_source_item_id'] = item.id
        query.update_internal(parent_id=item.internal.parent_id, 
                              collection_id=item.internal.collection_id)
        return query
    
    @classmethod
    def from_parent_query(cls, embedding: List[float], parent_query):
        '''
        creates a `Query` from an input `embedding` and a parent `Query`. The new
        `Query` is created from the `embedding` and assigned the `parent_query` ID as 
        the parent ID
        '''
        query = cls(item=None, embedding=embedding, data={}, query_results=None)
        query.update_internal(parent_id=parent_query.id, 
                              collection_id=parent_query.internal.collection_id)
        return query
    
    def add_query_results(self, query_results: List[Item]) -> None:
        '''
        Adds query results and propagates query parent information to them
        '''
        for result in query_results:
            result.update_internal(parent_id=self.id, 
                                   collection_id=self.internal.collection_id, 
                                   iteration=self.internal.iteration)
            self.query_results.append(result)
    
    def update_internal(self, **kwargs):
        self.internal.__dict__.update(kwargs)
        if (len(self.query_results)>0) and (len(list(self.valid_results()))==0):
            self.internal.__dict__.update({'removed':True, 'removal_reason':'all query results removed'})
            
    @classmethod
    def from_minimal(cls, 
                     item: Optional[Any]=None, 
                     embedding: List[float]=None, 
                     data: Optional[dict]=None,
                     query_results: Optional[List[Item]]=None
                    ):
        'convenience function for creating an `Query` with default `None` values'
        return cls(item=item, embedding=embedding, data=data, query_results=query_results)

A `Query` holds `Item` results, tracks parent/child relationships, and supports convenient iteration.

In [None]:
query = Query.from_minimal(embedding=[0.1])
query.update_internal(collection_id=0) # add collection ID

query_results = [
    Item.from_minimal(item='item1', embedding=[0.1]),
    Item.from_minimal(item='item2', embedding=[0.1]),
]

query.add_query_results(query_results)

# iteration over query results
assert len([i for i in query]) == 2

# propagation of query parent data

for query_result in query:
    assert query_result.internal.parent_id == query.id
    assert query_result.internal.collection_id == query.internal.collection_id

Items may be removed by various steps. Removed items are kept within the `Query` for logging purposes. `Query.valid_results` and `Query.enumerate_query_results` automatically skip removed items during iteration.

In [None]:
assert len(list(query.valid_results())) == 2

query.query_results[0].update_internal(removed=True) # set first result to removed

assert len(list(query.valid_results())) == 1

assert len(list(query.enumerate_query_results())) == 1
assert len(list(query.enumerate_query_results(skip_removed=False))) == 2

query.query_results[1].update_internal(removed=True) # set second result to removed
query.update_internal() # update query internal
assert query.internal.removed # query sets itself to removed when all query results are removed

Queries can be created from a parent `Query` or from an `Item`, with data automatically propagated from the source object.

In [None]:
# create query from item
item = Item.from_minimal(item='test_item', embedding=[0.1])
query = Query.from_item(item)
assert query.item == item.item

# create query from query
query = Query.from_minimal(embedding=[0.1])
new_query = Query.from_parent_query(embedding=[0.2], parent_query=query)
assert new_query.internal.parent_id == query.id

#### Batch

The `Batch` object holds a list of `Query` objects and provides convenience functions for iterating over queries and query results.

In [None]:
#| export

class Batch(BaseModel):
    queries: List[Query]
        
    def __iter__(self):
        return iter(self.queries)

    def __getitem__(self, idx: int):
        return self.queries[idx]
    
    def __len__(self):
        return len(self.queries)
    
    def get_item(self, 
                 query_index: int, 
                 result_index: Optional[int]=None) -> Union[Query, Item]:
        '''
        Selects an item from the batch.
        
        If `result_index=None`, returns the `Query` found at `self.queries[query_index]`
        
        Otherwise, returns the `Item` found at `self.queries[query_index][result_index]`
        '''
        if result_index is not None:
            return self.queries[query_index][result_index]
        else:
            return self.queries[query_index]
        
    def valid_queries(self) -> Query:
        'Iterates over valid queries'
        for query in self.queries:
            if not query.internal.removed:
                yield query
        
    def enumerate_queries(self, skip_removed=True) -> (Tuple[int, None], Query):
        '''
        enumerates over `self.queries`. if `skip_removed=True`,
        queries with `internal.removed=True` are ignored
        '''
        for i, query in enumerate(self.queries):
            if skip_removed:
                if not query.internal.removed:
                    yield ((i,None), query)
            else:
                yield ((i,None), query)
                
    def enumerate_query_results(self, skip_removed=True) -> (Tuple[int, int], Item):
        '''
        enumerates over results contained in `self.queries`. 
        if `skip_removed=True`, results with `internal.removed=True` 
        are ignored
        '''
        for (i,_), query in self.enumerate_queries(skip_removed):
            for j, result in query.enumerate_query_results(skip_removed):
                yield ((i,j), result)

                
    def flatten_queries(self, skip_removed=True) -> (List[Tuple[int, None]], List[Query]):
        '''
        flattens `self.queries`, returning a list of index values and
        a list of queries.
        
        if `skip_removed=True`, queries with `internal.removed=True` 
        are ignored
        '''
        idxs = []
        outputs = []
        for i, q in self.enumerate_queries(skip_removed):
            idxs.append(i)
            outputs.append(q)
        return idxs, outputs
                
    def flatten_query_results(self, skip_removed=True) -> (List[Tuple[int, int]], List[Item]):
        '''
        flattens results contained in `self.queries`, 
        returning a list of index values and a list of items.
        
        if `skip_removed=True`, results with `internal.removed=True` 
        are ignored
        '''
        idxs = []
        outputs = []
        for i, r in self.enumerate_query_results(skip_removed):
            idxs.append(i)
            outputs.append(r)
        return idxs, outputs

A `Batch` allows us to iterate over its queries and items in several ways.

In [None]:
def build_test_batch(n_queries, n_items):
    queries = []
    
    for i in range(n_queries):
        query = Query.from_minimal(item=f'query_{i}', embedding=[0.1])
        for j in range(n_items):
            item = Item.from_minimal(item=f'item_{j}', embedding=[0.1])
            query.add_query_results([item])
        queries.append(query)
    return Batch(queries=queries)

n_queries = 3
n_items = 4
batch = build_test_batch(n_queries, n_items)

assert len(list(batch.valid_queries())) == n_queries

idxs, results = batch.flatten_query_results()
assert len(results) == n_queries*n_items
assert batch.get_item(*idxs[0]) == batch[idxs[0][0]][idxs[0][1]]

Removed queries and items are skipped during iteration when `skip_removed=True`, and all results under a removed query are skipped along with it.

In [None]:
batch = build_test_batch(n_queries, n_items)

batch[1].update_internal(removed=True) # invalidate query
batch[0][0].update_internal(removed=True) # invalidate item
batch[0][1].update_internal(removed=True) # invalidate item

assert len(list(batch.valid_queries())) == n_queries-1 # 1 query removed

idxs, results = batch.flatten_query_results(skip_removed=False) # return all results, including removed ones
assert len(results) == n_queries*n_items

# skips results where `removed=True`, and all results under a query with `removed=True`
idxs, results = batch.flatten_query_results(skip_removed=True)

# n_items removed from invalid query 1, 2 items invalidated
assert len(results) == n_queries*n_items - n_items - 2
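
The `(query_index, result_index)` pairs returned by `flatten_query_results` make it possible to send every result in a batch to a plugin in a single call and then map the responses back to their positions. Below is a minimal sketch of that scatter-back pattern, assuming nested lists in place of a `Batch`. `toy_batch`, `flatten_toy_batch` and `toy_score` are hypothetical names.

```python
from typing import List, Tuple

# Hypothetical stand-in for a `Batch`: one inner list of result embeddings per query
toy_batch = [
    [[0.1], [0.4]],  # results of query 0
    [[0.9]],         # results of query 1
]

def flatten_toy_batch(batch: List[List[List[float]]]) -> Tuple[List[Tuple[int, int]], List[List[float]]]:
    'mirrors the shape of `Batch.flatten_query_results`: index pairs plus a flat list of results'
    idxs, outputs = [], []
    for i, results in enumerate(batch):
        for j, result in enumerate(results):
            idxs.append((i, j))
            outputs.append(result)
    return idxs, outputs

def toy_score(embeddings: List[List[float]]) -> List[float]:
    # hypothetical batched scorer: one score per input, in input order
    return [-abs(e[0] - 1.0) for e in embeddings]

idxs, flat_results = flatten_toy_batch(toy_batch)
scores = toy_score(flat_results)  # a single call covers every result in the batch

# scatter the scores back to their (query_index, result_index) positions
scores_by_position = dict(zip(idxs, scores))
print(scores_by_position)
```

With the notebook's own classes, each index pair can be passed to `Batch.get_item(*idx)` to reach the corresponding `Item`.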

### Data Source

The `DataSourceFunction` schema defines the interface for data source queries. The function takes a list of `Query` objects and returns a list of `DataSourceResponse` objects.
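A minimal sketch of a data source that follows this interface. It assumes an in-memory dict as the vector store and uses a simplified dataclass in place of the pydantic `DataSourceResponse`. `STORE`, `ToySourceResponse` and `toy_data_source` are hypothetical. The sketch also assumes one response per input query, in input order, and marks a query `valid=False` when it cannot be served.

```python
import math
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class ToySourceResponse:
    # simplified stand-in for `DataSourceResponse` (same three fields)
    valid: bool
    data: Optional[Dict]
    query_results: List[dict]

# Hypothetical in-memory vector store: id -> embedding
STORE = {'a': [0.0, 0.0], 'b': [1.0, 0.0], 'c': [0.0, 1.0], 'd': [1.0, 1.0]}

def toy_data_source(query_embeddings: List[List[float]], k: int = 2) -> List[ToySourceResponse]:
    responses = []
    for q in query_embeddings:
        if len(q) != 2:
            # wrong dimensionality: mark the query invalid so the pipeline can remove it
            responses.append(ToySourceResponse(valid=False, data={'error': 'bad dimension'}, query_results=[]))
            continue
        # k nearest stored embeddings by Euclidean distance
        ranked = sorted(STORE.items(), key=lambda kv: math.dist(kv[1], q))[:k]
        results = [{'id': key, 'embedding': emb} for key, emb in ranked]
        responses.append(ToySourceResponse(valid=True, data=None, query_results=results))
    return responses  # one response per input query, in input order

print(toy_data_source([[0.9, 0.1], [1.0]]))
```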

In [None]:
#| export

class DataSourceResponse(BaseModel):
    valid: bool                # if input `Query` was valid (if False, associated `Query` is removed)
    data: Optional[Dict]       # optional dict of data associated with the query
    query_results: List[Item]  # list of `Item` results

In [None]:
#| export

DataSourceFunction = Callable[[List[Query]], List[DataSourceResponse]]

### Filter

The `FilterFunction` schema defines the interface for filtering result items. The function takes a list of `Item` objects and returns a list of `FilterResponse` objects.
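Below is a hypothetical filter that rejects items already seen in earlier iterations (or earlier in the same batch) and returns one `FilterResponse`-like object per input item. `ToyFilterResponse` and `dedup_filter` are illustrative stand-ins, and items are represented as plain dicts with an `id` key.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class ToyFilterResponse:
    # simplified stand-in for `FilterResponse`
    valid: bool
    data: Optional[Dict]

def dedup_filter(items: List[dict], seen_ids: set) -> List[ToyFilterResponse]:
    '''hypothetical filter: rejects items whose id was already seen,
    returning one response per input item and updating `seen_ids` in place'''
    responses = []
    for item in items:
        is_new = item['id'] not in seen_ids
        seen_ids.add(item['id'])
        responses.append(ToyFilterResponse(valid=is_new, data=None if is_new else {'reason': 'already seen'}))
    return responses

seen = {'a'}
print(dedup_filter([{'id': 'a'}, {'id': 'b'}, {'id': 'b'}], seen))
```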

In [None]:
#| export

class FilterResponse(BaseModel):
    valid: bool           # if the input `Item` is valid (if False, associated `Item` is removed)
    data: Optional[Dict]  # optional dict of data associated with the filter response

In [None]:
#| export

FilterFunction = Callable[[List[Item]], List[FilterResponse]]

### Score

The `ScoreFunction` schema defines the interface for scoring result items. The function takes a list of `Item` objects and returns a list of `ScoreResponse` objects.
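Below is a sketch of a score function, assuming the goal is to get close to a known target point. `SCORE_TARGET`, `ToyScoreResponse` and `distance_score` are hypothetical. Because `emb_opt` treats higher scores as better, the distance is negated. Items that cannot be scored are returned with `valid=False` and `score=None`.

```python
import math
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class ToyScoreResponse:
    # simplified stand-in for `ScoreResponse`
    valid: bool
    score: Optional[float]
    data: Optional[Dict]

SCORE_TARGET = [0.5, 0.5]  # hypothetical optimum

def distance_score(items: List[dict]) -> List[ToyScoreResponse]:
    responses = []
    for item in items:
        emb = item['embedding']
        if len(emb) != len(SCORE_TARGET):
            # cannot be scored: invalid, with the score left as None
            responses.append(ToyScoreResponse(valid=False, score=None, data={'error': 'bad dimension'}))
        else:
            # higher is better, so negate the distance to the target
            responses.append(ToyScoreResponse(valid=True, score=-math.dist(emb, SCORE_TARGET), data=None))
    return responses

print(distance_score([{'embedding': [0.5, 0.5]}, {'embedding': [1.0, 0.5]}, {'embedding': [1.0]}]))
```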

In [None]:
#| export

class ScoreResponse(BaseModel):
    valid: bool             # if the input `Item` is valid (if False, associated `Item` is removed)
    score: Optional[float]  # the score of the input `Item`. Can be `None` if `valid=False`
    data: Optional[Dict]    # optional dict of data associated with the score response

In [None]:
#| export

ScoreFunction = Callable[[List[Item]], List[ScoreResponse]]

### Prune

The `PruneFunction` schema defines the interface for pruning queries. The function takes a list of `Query` objects and returns a list of `PruneResponse` objects.
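Below is a sketch of one possible pruning rule. It assumes each query's results have already been scored, keeps the `k` queries whose best result scores highest, and invalidates the rest. `ToyPruneResponse` and `top_k_prune` are hypothetical, and plain lists of scores stand in for `Query` objects.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class ToyPruneResponse:
    # simplified stand-in for `PruneResponse`
    valid: bool
    data: Optional[Dict]

def top_k_prune(query_scores: List[List[float]], k: int) -> List[ToyPruneResponse]:
    '''hypothetical prune: keeps the k queries whose best result score is highest.
    `query_scores[i]` holds the result scores of query i'''
    best = [max(scores) if scores else float('-inf') for scores in query_scores]
    keep = set(sorted(range(len(best)), key=lambda i: best[i], reverse=True)[:k])
    return [ToyPruneResponse(valid=i in keep, data={'best_score': best[i]}) for i in range(len(best))]

print(top_k_prune([[0.1, 0.3], [0.9], [0.5, 0.2], []], k=2))
```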

In [None]:
#| export

class PruneResponse(BaseModel):
    valid: bool           # if the input `Query` is valid (if False, the associated `Query` is removed)
    data: Optional[Dict]  # optional dict of data associated with the prune response

In [None]:
#| export

PruneFunction = Callable[[List[Query]], List[PruneResponse]]

### Update

The `UpdateFunction` schema defines the interface for updating queries. The function takes a list of `Query` objects and returns a list of new `Query` objects.
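Below is a sketch of a continuous update that moves each query part of the way toward the mean of its top-scoring results instead of jumping straight to the mean. `StepQuery`, `step_update`, `lr` and `k` are hypothetical. In `emb_opt` itself, the parent link that is set by hand here corresponds to what `Query.from_parent_query` records.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

@dataclass
class StepQuery:
    # simplified stand-in for `Query`: id, embedding, (embedding, score) results, parent link
    id: str
    embedding: List[float]
    results: List[Tuple[List[float], float]] = field(default_factory=list)
    parent_id: Optional[str] = None

def step_update(queries: List[StepQuery], lr: float = 0.5, k: int = 2) -> List[StepQuery]:
    '''hypothetical continuous update: moves each query a fraction `lr` of the way
    toward the mean embedding of its top-k scored results'''
    new_queries = []
    for q in queries:
        if not q.results:
            continue  # no scored results to learn from, so no new query is produced
        best = sorted(q.results, key=lambda r: r[1], reverse=True)[:k]
        mean = [sum(emb[d] for emb, _ in best) / len(best) for d in range(len(q.embedding))]
        new_embedding = [qe + lr * (m - qe) for qe, m in zip(q.embedding, mean)]
        new_queries.append(StepQuery(id=f'{q.id}_next', embedding=new_embedding, parent_id=q.id))
    return new_queries

q0 = StepQuery(id='q0', embedding=[0.0, 0.0],
               results=[([1.0, 0.0], 0.9), ([0.0, 1.0], 0.8), ([-1.0, 0.0], 0.1)])
print(step_update([q0]))
```

With `lr=1.0` this reduces to plain averaging of the top-k results. Smaller values trade convergence speed for stability.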

In [None]:
#| export

UpdateFunction = Callable[[List[Query]], List[Query]]