# query update

> Strategies for updating query vectors from their scored query results

In [None]:
#| default_exp query_update

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.utils import pack_dataset, whiten
from emb_opt.backends.hf import HFDatabase
from emb_opt.core import Score

In [None]:
#| export

class QueryUpdate:
    """Base query-update strategy: a no-op that hands the queries back unchanged.

    Calling an instance with the current query vectors and the dataset of
    query results returns the very same `query_vectors` object.
    """

    def __call__(self, query_vectors: np.ndarray, query_dataset: Dataset) -> np.ndarray:
        # Identity update: `query_dataset` is accepted but never read.
        return query_vectors

In [None]:
#| export

class RLUpdate(QueryUpdate):
    """Query update that takes one `lr`-scaled step per query along a score-weighted direction.

    For every query, the offsets `2 * (query - embedding)` to its result
    embeddings are weighted by `whiten(scores)` and averaged; the query is
    then moved by `-lr` times that average.
    """

    def __init__(self, lr: float):
        self.lr = lr

    def _query_grad(self, query_vector, results):
        """Return the advantage-weighted mean offset for a single query's results."""
        neighbor_embs = np.array(results['embedding'])
        advantages = whiten(np.array(results['score']))
        offsets = 2*(query_vector[None] - neighbor_embs)
        # Broadcast one advantage per result row, then average over results.
        return (advantages[:,None] * offsets).mean(0)

    def __call__(self, query_vectors: np.ndarray, query_dataset: Dataset) -> np.ndarray:
        # The packed results are indexed by query index, then by field name.
        grouped = pack_dataset(query_dataset, 'query_idx', ['embedding', 'score'])
        grads = np.array([
            self._query_grad(query_vectors[idx], grouped[idx])
            for idx in range(query_vectors.shape[0])
        ])
        return query_vectors - self.lr*grads

In [None]:
def dummy_score(row):
    """Toy score: the norm of the row's 'embedding' value."""
    return np.linalg.norm(row['embedding'])

# Smoke test for RLUpdate: 128 random 256-dimensional vectors stored in a
# dataset with a FAISS index on the 'embedding' column.
vectors = np.random.randn(128, 256)
vector_dataset = Dataset.from_list([{'embedding' : i} for i in vectors])
vector_dataset.add_faiss_index('embedding')

# NOTE(review): the trailing 10 is presumably the number of results returned
# per query -- confirm against HFDatabase.
db = HFDatabase(vector_dataset, 'embedding', 10)
score = Score(dummy_score)
update_strategy = RLUpdate(0.5)


# Three queries scaled down by 10, so they start with small norms.
query_vectors = np.random.randn(3, 256)/10
query_dataset = db.query(query_vectors)
query_dataset = score(query_dataset)
updated_queries = update_strategy(query_vectors, query_dataset)

# The score is the embedding norm, so the update is expected to increase the
# norm of every query vector.
assert np.all(np.linalg.norm(updated_queries, axis=-1) > np.linalg.norm(query_vectors, axis=-1))

100%|███████████████████████████████████████████| 1/1 [00:00<00:00, 3728.27it/s]
                                                                                

In [None]:
#| export

class KNNUpdate(QueryUpdate):
    """Query update that replaces each query with an average of its `k` best-scoring results.

    When `score_weighting` is true the average is weighted by the selected
    results' scores; otherwise it is an unweighted average.
    """

    def __init__(self, k: int, score_weighting: bool=True):
        self.k = k
        self.score_weighting = score_weighting

    def __call__(self, query_vectors: np.ndarray, query_dataset: Dataset) -> np.ndarray:
        # The packed results are indexed by query index, then by field name.
        grouped = pack_dataset(query_dataset, 'query_idx', ['embedding', 'score'])

        def next_query(idx):
            results = grouped[idx]
            embs = np.array(results['embedding'])
            scores = np.array(results['score'])
            # Indices in descending score order, truncated to the k best.
            best = scores.argsort()[::-1][:self.k]
            # weights=None is np.average's default, i.e. an unweighted average.
            weights = scores[best] if self.score_weighting else None
            return np.average(embs[best], axis=0, weights=weights)

        return np.array([next_query(idx) for idx in range(query_vectors.shape[0])])

In [None]:
def dummy_score(row):
    """Toy score: the norm of the row's 'embedding' value."""
    return np.linalg.norm(row['embedding'])

# Smoke test for KNNUpdate: 128 random 256-dimensional vectors stored in a
# dataset with a FAISS index on the 'embedding' column.
vectors = np.random.randn(128, 256)
vector_dataset = Dataset.from_list([{'embedding' : i} for i in vectors])
vector_dataset.add_faiss_index('embedding')

# NOTE(review): the trailing 10 is presumably the number of results returned
# per query -- confirm against HFDatabase.
db = HFDatabase(vector_dataset, 'embedding', 10)
score = Score(dummy_score)
# Average the 3 best-scoring results per query (score weighting left at its default, True).
update_strategy = KNNUpdate(3)


# Three queries scaled down by 10, so they start with small norms.
query_vectors = np.random.randn(3, 256)/10
query_dataset = db.query(query_vectors)
query_dataset = score(query_dataset)
updated_queries = update_strategy(query_vectors, query_dataset)

# The score is the embedding norm, so the update is expected to increase the
# norm of every query vector.
assert np.all(np.linalg.norm(updated_queries, axis=-1) > np.linalg.norm(query_vectors, axis=-1))

100%|███████████████████████████████████████████| 1/1 [00:00<00:00, 2328.88it/s]
                                                                                

In [None]:
#| hide
# Run nbdev's export step; per the nbdev docs this writes the cells marked
# `#| export` to the module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()