# utils

> Utility functions for batching lists, whitening scores, and building `Batch` objects

In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.schemas import Batch, Query, Item

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

def batch_list(inputs: List[Any],   # input list to be batched
               batch_size: int      # batch size (`0` puts everything in one batch)
              ) -> List[List[Any]]: # batched output list
    '''
    Splits `inputs` into consecutive chunks of at most `batch_size` items. 
    The last chunk is shorter (ragged) when `len(inputs)` is not a multiple 
    of `batch_size`.
    
    If `batch_size=0`, returns `[inputs]`, i.e. a single batch holding all inputs.
    
    Raises `ValueError` if `batch_size` is negative.
    '''
    # A negative step would make `range` below empty, silently dropping every
    # input, so reject it up front instead of returning `[]`.
    if batch_size < 0:
        raise ValueError(f'batch_size must be non-negative, got {batch_size}')

    if batch_size == 0:
        return [inputs]

    return [inputs[i:i+batch_size] for i in range(0, len(inputs), batch_size)]

def unbatch_list(inputs: List[List[Any]] # input batched list
                ) -> List[Any]:          # flattened output list
    'concatenates the sublists of a batched list into one flat list, preserving order'
    flattened = []
    for batch in inputs:
        # one level of nesting is removed; items themselves are left untouched
        flattened.extend(batch)
    return flattened

In [None]:
# round-trip check: batching (with a ragged last batch) then unbatching
# must give back the original list
inputs = list(range(10))
assert unbatch_list(batch_list(inputs, 3)) == inputs

In [None]:
#| export

def whiten(scores: np.ndarray # vector shape (n,) of scores to whiten
          ) -> np.ndarray:    # vector shape (n,) whitened scores
    '''
    Standardizes `scores` by subtracting their mean and dividing by the 
    square root of their variance. A small constant (`1e-8`) is added to 
    the variance so constant inputs do not divide by zero
    '''
    centered = scores - scores.mean()
    scale = np.sqrt(scores.var() + 1e-8)
    return centered / scale

In [None]:
#| export

def build_batch_from_embeddings(embeddings: List[List[float]] # input embeddings
                               ) -> Batch:                    # output batch
    '''
    Wraps each embedding in `embeddings` in a `Query` (via `Query.from_minimal`) 
    and collects the queries into a `Batch`. Every `Query` gets its position 
    in `embeddings` as its `collection_id`, so the ids are unique within the batch
    '''
    def _make_query(position, vector):
        # build the query, then tag it with its index in the input list
        new_query = Query.from_minimal(embedding=vector)
        new_query.update_internal(collection_id=position)
        return new_query

    return Batch(queries=[_make_query(idx, vec) for idx, vec in enumerate(embeddings)])

In [None]:
# example: two 1-d embeddings become two queries with `collection_id` 0 and 1
build_batch_from_embeddings([[0.1], [0.2]])

Batch(queries=[Query(item=None, embedding=[0.1], data={}, query_results=[], internal=InteralData(removed=False, removal_reason=None, parent_id=None, collection_id=0, iteration=None), id='query_54271b66-50d1-11ee-b64b-7b1d5a84b1d4'), Query(item=None, embedding=[0.2], data={}, query_results=[], internal=InteralData(removed=False, removal_reason=None, parent_id=None, collection_id=1, iteration=None), id='query_54271b67-50d1-11ee-b64b-7b1d5a84b1d4')])

In [None]:
#| export

def build_batch_from_items(items: List[Item],      # input items
                           remap_collections=False # if collection ID should be remapped
                          ) -> Batch:              # output batch
    '''
    Converts each `Item` in `items` to a `Query` (via `Query.from_item`) 
    and collects the queries into a `Batch`. 
    
    With `remap_collections=True`, each `Query` has its `collection_id` 
    overwritten with its position in `items`, so the ids are unique within 
    the batch. Otherwise the `Query` is left exactly as `Query.from_item` 
    produced it, keeping the `collection_id` of its source `Item`
    '''
    def _as_query(position, source_item):
        converted = Query.from_item(source_item)
        if remap_collections:
            # replace the inherited collection id with the item's index
            converted.update_internal(collection_id=position)
        return converted

    return Batch(queries=[_as_query(idx, itm) for idx, itm in enumerate(items)])

In [None]:
# example: a single item, remapped so the resulting query gets `collection_id` 0
build_batch_from_items([Item.from_minimal(embedding=[0.1])], remap_collections=True)

Batch(queries=[Query(item=None, embedding=[0.1], data={'_source_item_id': 'item_95b496da-50d1-11ee-b64b-7b1d5a84b1d4'}, query_results=[], internal=InteralData(removed=False, removal_reason=None, parent_id=None, collection_id=0, iteration=None), id='query_95b496db-50d1-11ee-b64b-7b1d5a84b1d4')])