# utils

> Utility functions for batching lists, whitening scores, and building `Batch` objects

In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.schemas import Batch, Query, Item

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

def batch_list(inputs: list, # items to split into batches
               batch_size: int # max items per batch; 0 puts everything in one batch
              ) -> list[list]: # list of batches, in the original item order
    'Splits `inputs` into consecutive batches of at most `batch_size` items'
    if batch_size < 0:
        # a negative step makes `range` empty, which would silently drop every item
        raise ValueError(f'batch_size must be >= 0, got {batch_size}')
    if batch_size == 0:
        # 0 means "no batching": the single batch is `inputs` itself, not a copy
        return [inputs]
    return [inputs[i:i+batch_size] for i in range(0, len(inputs), batch_size)]

def unbatch_list(inputs: list[list] # list of batches to flatten
                ) -> list: # every item from every batch, in batch order
    'Flattens a list of batches into a single flat list'
    flat = []
    for batch in inputs:
        flat.extend(batch)
    return flat

In [None]:
# Round-trip check: splitting into batches of 3 and flattening again must give back the original list
inputs = list(range(10))
assert unbatch_list(batch_list(inputs, 3)) == inputs

In [None]:
#| export

def whiten(scores: np.ndarray # vector shape (n,) of scores to whiten
          ) -> np.ndarray: # vector shape (n,) whitened scores
    'Centers scores on their mean and divides by the square root of their variance'
    centered = scores - scores.mean()
    # the 1e-8 added to the variance keeps the divisor non-zero when all scores are equal
    scale = np.sqrt(scores.var() + 1e-8)
    return centered / scale

In [None]:
#| export

def build_batch_from_embeddings(embeddings: List[List[float]]) -> Batch:
    'Builds a `Batch` holding one `Query` per embedding, each tagged with its index as `collection_id`'
    def _indexed_query(position, vector):
        # create the query first, then record which input embedding it came from
        new_query = Query.from_minimal(embedding=vector)
        new_query.update_internal(collection_id=position)
        return new_query

    return Batch(queries=[_indexed_query(position, vector)
                          for position, vector in enumerate(embeddings)])

In [None]:
# Example: each embedding becomes a Query whose collection_id is its position in the input list
build_batch_from_embeddings([[0.1], [0.2]])

Batch(queries=[Query(item=None, embedding=[0.1], data={}, query_results=[], internal=InteralData(removed=False, removal_reason=None, parent_id=None, collection_id=0, iteration=None), id='query_b992c17c-5053-11ee-b64b-7b1d5a84b1d4'), Query(item=None, embedding=[0.2], data={}, query_results=[], internal=InteralData(removed=False, removal_reason=None, parent_id=None, collection_id=1, iteration=None), id='query_b992c17d-5053-11ee-b64b-7b1d5a84b1d4')])

In [None]:
#| export

def build_batch_from_items(items: List[Item], remap_collections=False) -> Batch:
    'Builds a `Batch` holding one `Query` per item, created with `Query.from_item`'
    converted = []
    for position, source_item in enumerate(items):
        converted.append(Query.from_item(source_item))
        if remap_collections:
            # overwrite the collection id with the item's index in `items`
            converted[-1].update_internal(collection_id=position)
    return Batch(queries=converted)

In [None]:
# Example: with remap_collections=True the resulting Query gets collection_id equal to the item's index
build_batch_from_items([Item.from_minimal(embedding=[0.1])], remap_collections=True)

Batch(queries=[Query(item=None, embedding=[0.1], data={}, query_results=[], internal=InteralData(removed=False, removal_reason=None, parent_id=None, collection_id=0, iteration=None), id='query_bbc191d1-5053-11ee-b64b-7b1d5a84b1d4')])