In [1]:
%%writefile ../src/retriever.py
"""
retriever.py
- Retrieval functions: filter by product and return the top-k chunks.
"""
# comment to test overwriting

from __future__ import annotations
import numpy as np
import pandas as pd
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from .embed_index import VectorIndex, build_embeddings, load_embedding_model
from sklearn.metrics.pairwise import cosine_similarity

class Retriever:
    """
    Embedding-based retriever over a corpus of text chunks.

    - Receives a corpus (DataFrame with columns 'chunk' and 'product').
    - Builds one embedding per chunk and a VectorIndex over all of them
      (described as a cosine index by the original author; see embed_index).
    - query() returns the top-k chunks, optionally filtered by product.
    """

    def __init__(self, corpus: pd.DataFrame, model_name: str = "sentence-transformers/all-MiniLM-L6-v2",) -> None:
        """
        Embed every chunk of `corpus` and index the resulting vectors.

        Args:
            corpus: DataFrame that must contain 'chunk' and 'product' columns.
            model_name: name handed to `load_embedding_model`.

        Raises:
            ValueError: if `corpus` lacks the 'chunk' or 'product' column.
        """
        if not {"chunk", "product"}.issubset(corpus.columns):
            raise ValueError("Corpus must contain 'chunk' and 'product' columns")
        # Reset to a 0..n-1 index so row positions line up with embedding rows.
        self.corpus = corpus.reset_index(drop=True)
        self.model: SentenceTransformer = load_embedding_model(model_name)

        # Corpus embeddings: one row per chunk, in corpus order.
        chunks: List[str] = self.corpus["chunk"].astype(str).tolist()  # list of size n
        # Presumably (n, d) float32 unit vectors -- TODO confirm in embed_index.
        self.embeddings = build_embeddings(chunks, self.model)
        # Index over the whole corpus, used by unfiltered queries.
        self.index = VectorIndex(self.embeddings)

    def query(self, question: str, product: Optional[str] = None, top_k: int = 6, return_scores: bool = False):
        """
        Return the chunks most similar to `question`, best match first.

        Args:
            question: free-text query to embed.
            product: if truthy, rank only the rows whose 'product' equals this
                value. When no row matches, the search silently falls back to
                the full corpus.
            top_k: maximum number of chunks to return; it is clamped to the
                number of candidate rows. A non-positive value yields an
                empty result.
            return_scores: when True, return `(chunks, scores)` instead of
                just `chunks`.

        Returns:
            A list of chunk strings, or a `(chunks, scores)` tuple of two
            lists when `return_scores` is True.
        """
        # Nothing to rank: avoid embedding the question, and avoid the
        # negative-slice surprise (`[:k]` with k < 0 would drop the tail).
        if top_k <= 0 or self.corpus.empty:
            return ([], []) if return_scores else []

        # Embedding for the single question (one row).
        qv = build_embeddings([question], self.model)

        # Optional filter by product.
        if product:
            mask = self.corpus["product"].eq(product).to_numpy()  # shape (n,)
            if mask.any():
                vecs = self.embeddings[mask]                      # (m, d)
                sims = cosine_similarity(qv, vecs)[0]             # (m,)
                k = min(top_k, vecs.shape[0])
                order = np.argsort(-sims)[:k]                     # local positions 0..m-1
                chunks = self.corpus.loc[mask, "chunk"].iloc[order].tolist()
                return (chunks, sims[order].tolist()) if return_scores else chunks
            # No rows for that product: fall back to the full corpus.

        # No filter: search the index built over the entire corpus. Pass the
        # clamped k so the index is never asked for more rows than exist.
        k = min(top_k, len(self.corpus))
        idxs, scores = self.index.search(qv, top_k=k)
        chunks = self.corpus["chunk"].iloc[idxs].tolist()
        return (chunks, scores.tolist()) if return_scores else chunks

Overwriting ../src/retriever.py


Testing the code

In [29]:
import pandas as pd

# Toy corpus: two chunks per product, as (product, chunk) pairs.
toy_rows = [
    ("A", "El gato se sentó en la alfombra."),
    ("A", "Un felino descansando sobre una alfombra suave."),
    ("B", "Hoy lloverá en Monterrey según el pronóstico."),
    ("B", "Recetas fáciles de pollo a la parrilla."),
]
df = pd.DataFrame(toy_rows, columns=["product", "chunk"])

# NOTE(review): `Retriever` must already be available in the kernel; no import
# of it is visible in this notebook.
ret = Retriever(df)

rain_question = "lloviendo el dia de hoy en Nuevo León"

# Global search over the whole corpus.
global_hits = ret.query(rain_question, top_k=4, return_scores=True)
print(global_hits)

# Same question, restricted to product "A".
product_a_hits = ret.query(rain_question, product="A", top_k=2, return_scores=True)
print(product_a_hits)

Batches: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 16.67it/s]
Batches: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 41.06it/s]


(['Hoy lloverá en Monterrey según el pronóstico.', 'El gato se sentó en la alfombra.', 'Un felino descansando sobre una alfombra suave.', 'Recetas fáciles de pollo a la parrilla.'], [0.5791796445846558, 0.5183272957801819, 0.4466252028942108, 0.38428571820259094])


Batches: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 37.94it/s]

(['El gato se sentó en la alfombra.', 'Un felino descansando sobre una alfombra suave.'], [0.5183272957801819, 0.4466252028942108])





In [13]:
# Preview the first rows of the toy corpus to confirm its 'product' and 'chunk' columns.
df.head()

Unnamed: 0,product,chunk
0,A,El gato se sentó en la alfombra.
1,A,Un felino descansando sobre una alfombra suave.
2,B,Hoy lloverá en Monterrey según el pronóstico.
3,B,Recetas fáciles de pollo a la parrilla.


In [28]:
# Ask through the public API first (returns chunks and their scores).
monterrey_question = "Lloviendo el día de hoy en Monterrey"
chunks, scores = ret.query(monterrey_question, top_k=3, return_scores=True)

# Repeat the search directly against the index so the row positions are visible.
# NOTE(review): `embed_index` must already be imported in the kernel; no import
# of it is visible in this notebook.
qv = embed_index.build_embeddings([monterrey_question], ret.model)
idxs, scs = ret.index.search(qv, top_k=3)

# One line per hit: rank, corpus row position, chunk text and score.
ranked_hits = enumerate(zip(idxs, scs), start=1)
for rank, (i, s) in ranked_hits:
    print(f"{rank}. idx={i} | {df.loc[i, 'chunk']} | score={s:.3f}")

Batches: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 46.06it/s]
Batches: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 35.89it/s]

1. idx=2 | Hoy lloverá en Monterrey según el pronóstico. | score=0.780
2. idx=0 | El gato se sentó en la alfombra. | score=0.509
3. idx=1 | Un felino descansando sobre una alfombra suave. | score=0.467



