In [None]:
import pickle
import re
import unicodedata

import faiss
import numpy as np
import pandas as pd
import spacy
from sentence_transformers import SentenceTransformer

# sentence-transformers model name loaded by embedding_model() to encode queries.
# NOTE(review): presumably must be the same model that produced the stored passage
# embeddings (data/clean_embeddings.pkl), otherwise distances are meaningless — confirm.
EMBEDDING_MODEL = 'all-MiniLM-L6-v2'
# Optional lemmatization: load a spaCy pipeline (e.g. spacy.load("en_core_web_sm"))
# and pass it to clean_text(..., nlp=...). Not loaded by default.

In [None]:
# Sample questions for exercising retrieval; retrieve() reads this list.
queries = [
    "What is the Sorcerer’s Stone and what does it do?",
    "How does Harry get into Gryffindor?",
    "What does the Mirror of Erised show Harry?",
    "Who helps Harry get past Fluffy?",
    "What happens during Harry’s first Quidditch match?",
]

# Harder probes: a multi-part question, an all-lower-case question, and a
# keyword-style phrase rather than a full question.
complex_queries = [
    "How do Harry, Ron, and Hermione manage to bypass each of the obstacles guarding the Sorcerer’s Stone?",
    "what role does Fluffy play in the protection system?",
    "Snape’s and Quirrell’s interference Harrys first quidditch match",
]

In [None]:
def embedding_model(model_name=None):
    """Load the SentenceTransformer used to embed queries.

    Args:
        model_name: Name or local path of a sentence-transformers model. When
            omitted (None), the module-level ``EMBEDDING_MODEL`` is used; it is
            looked up at call time, so reassigning it in the config cell takes
            effect without redefining this function.

    Returns:
        A ``SentenceTransformer`` instance.
    """
    # NOTE(review): the query encoder should match the model that produced the
    # stored passage embeddings — confirm when overriding model_name.
    name = EMBEDDING_MODEL if model_name is None else model_name
    return SentenceTransformer(name)

In [None]:
def load_df(dir_path: str, file_name: str, file_fmt: str = "csv"):
    """Load ``<dir_path>/<file_name>.<file_fmt>`` into a DataFrame.

    The pandas reader is chosen from ``file_fmt`` (case-insensitive; a leading
    dot such as ".csv" is tolerated):

    - ``csv``     -> ``pd.read_csv``
    - ``tsv``     -> ``pd.read_csv(sep="\\t")``
    - ``json``    -> ``pd.read_json``
    - ``parquet`` -> ``pd.read_parquet`` (needs a parquet engine installed)
    - anything else falls back to ``pd.read_csv``

    Args:
        dir_path: Directory containing the file.
        file_name: File name without extension.
        file_fmt: Extension / format of the file. Defaults to "csv".

    Returns:
        The loaded ``pd.DataFrame``.
    """
    ext = file_fmt.lstrip(".")
    file_path = f"{dir_path}/{file_name}.{ext}"
    readers = {
        "csv": pd.read_csv,
        "tsv": lambda path: pd.read_csv(path, sep="\t"),
        "json": pd.read_json,
        "parquet": pd.read_parquet,
    }
    # Unknown formats keep the historical behaviour of parsing as CSV.
    reader = readers.get(ext.lower(), pd.read_csv)
    return reader(file_path)


def load_embeddings(file_path="data/embeddings.pkl"):
    """Unpickle and return the object stored at ``file_path``.

    NOTE(security): ``pickle.load`` can execute arbitrary code embedded in the
    file — only load embedding files this project produced itself.

    Args:
        file_path: Path to the pickle file.

    Returns:
        Whatever object was pickled (the notebook later reads ``.shape`` on it,
        so presumably a 2-D numpy array — TODO confirm).
    """
    with open(file_path, mode="rb") as handle:
        return pickle.load(handle)


def clean_text(text, nlp=None, normalize_unicode=False):
    """
    Clean text for embedding / retrieval by:
    - Optionally (``normalize_unicode=True``) applying Unicode NFKD normalization,
      which splits accented letters into an ASCII base letter plus a combining mark
      that the next step drops ("Café" -> "cafe"). By default no normalization is
      done and non-ASCII letters are removed outright ("Café" -> "caf").
    - Removing every character except ASCII letters, digits, whitespace and ``. , ? !``
      (this also strips apostrophes, including typographic ones: "Harry’s" -> "harrys")
    - Collapsing runs of whitespace into a single space and trimming both ends
    - Lowercasing
    - Optionally lemmatizing with a spaCy pipeline (lemmas re-joined with single spaces)

    Args:
        text: The string to clean.
        nlp: Optional loaded spaCy pipeline; when given, each token is replaced
            by its ``lemma_``.
        normalize_unicode: Apply NFKD normalization before stripping characters.

    Returns:
        The cleaned string.
    """
    if normalize_unicode:
        text = unicodedata.normalize("NFKD", text)
    text = re.sub(r"[^a-zA-Z0-9\s.,?!]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    text = text.lower()

    if nlp is not None:
        doc = nlp(text)
        text = " ".join(token.lemma_ for token in doc)
    return text


def build_faiss_flatl2_index(embeddings: np.ndarray, dim=None):
    """
    Build an exact (brute-force) FAISS index over ``embeddings`` using L2 distance.

    Args:
        embeddings: Array-like of shape (n_vectors, d) holding the vectors to index.
            It is converted to a C-contiguous float32 array, the layout FAISS uses.
        dim: Vector dimensionality ``d``. Accepts the full shape tuple
            ``(n_vectors, d)`` (e.g. ``embeddings.shape``), a plain int ``d``,
            or None to take it from ``embeddings``.
    Returns:
        A ``faiss.IndexFlatL2`` containing every row of ``embeddings``; FAISS
        assigns sequential ids, so row ``i`` gets id ``i``.
    Raises:
        ValueError: if ``embeddings`` is not 2-D or ``dim`` disagrees with its width.
    """
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(
            f"embeddings must be 2-D (n_vectors, dim); got shape {vectors.shape}"
        )

    if dim is None:
        d = vectors.shape[1]
    elif isinstance(dim, (int, np.integer)):
        d = int(dim)
    else:
        # Shape tuple such as embeddings.shape: the vector width is the second entry.
        d = int(dim[1])
    if d != vectors.shape[1]:
        raise ValueError(
            f"dim={d} does not match embedding width {vectors.shape[1]}"
        )

    index = faiss.IndexFlatL2(d)
    index.add(vectors)
    return index


def build_IVFIndex(embeddings, dim=None, _nprob=None, nlist: int = 18):
    """
    Build, train and fill an approximate FAISS IVF-Flat index (L2 distance).

    Args:
        embeddings: Array-like of shape (n_vectors, d); converted to C-contiguous
            float32 and used both to train the coarse quantizer and as index content.
        dim: Vector dimensionality ``d``. Accepts the full shape tuple
            ``(n_vectors, d)`` (e.g. ``embeddings.shape``), a plain int ``d``,
            or None to take it from ``embeddings``.
        _nprob: Number of inverted lists probed per search (``index.nprobe``).
            When None, FAISS's default (1) is left in place, which trades
            recall for speed.
        nlist: Requested number of inverted lists / k-means centroids. It is
            clamped to the number of vectors, because k-means training needs
            at least one point per centroid.
    Returns:
        A trained ``faiss.IndexIVFFlat`` containing every row of ``embeddings``.
    Raises:
        ValueError: if ``embeddings`` is not 2-D, is empty, or ``dim`` disagrees
            with its width.
    """
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(
            f"embeddings must be 2-D (n_vectors, dim); got shape {vectors.shape}"
        )
    n_vectors = vectors.shape[0]
    if n_vectors == 0:
        raise ValueError("cannot train an IVF index on zero vectors")

    if dim is None:
        d = vectors.shape[1]
    elif isinstance(dim, (int, np.integer)):
        d = int(dim)
    else:
        # Shape tuple such as embeddings.shape: the vector width is the second entry.
        d = int(dim[1])
    if d != vectors.shape[1]:
        raise ValueError(
            f"dim={d} does not match embedding width {vectors.shape[1]}"
        )

    # Small corpora would otherwise make FAISS k-means training fail.
    n_lists = max(1, min(int(nlist), n_vectors))
    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, n_lists)
    index.train(vectors)
    index.add(vectors)
    if _nprob is not None:
        index.nprobe = _nprob
    return index


def retrieve_top_passages(query, model, index, chunks, top_n=5):
    """Return the passages whose embeddings are nearest to ``query``.

    Args:
        query: Question text to embed.
        model: Encoder exposing ``encode(list_of_texts, convert_to_numpy=True)``,
            e.g. a SentenceTransformer.
        index: FAISS-style index with ``search(queries, k) -> (distances, ids)``.
            Index id ``i`` is treated as row position ``i`` of ``chunks``
            (NOTE(review): assumes chunks and embeddings are row-aligned).
        chunks: DataFrame with a ``'chunk'`` text column.
        top_n: Maximum number of passages to return.

    Returns:
        A Series of passage texts in the order FAISS returns them (nearest
        first), keeping ``chunks``' index labels. It can hold fewer than
        ``top_n`` entries when the index yields fewer hits.
    """
    query_embedding = model.encode([query], convert_to_numpy=True)
    _distances, ids = index.search(query_embedding, top_n)
    # search returns one row per query vector; we sent exactly one query.
    hit_rows = ids[0]
    # FAISS pads with -1 when it has fewer than top_n results. iloc would map
    # -1 to the last chunk, so drop those ids.
    hit_rows = hit_rows[hit_rows >= 0]
    return chunks['chunk'].iloc[hit_rows]


def retrieve(model, index, chunks, query_list=None, top_n=5):
    """Run ``retrieve_top_passages`` for each query and collect the results.

    Args:
        model: Encoder passed through to ``retrieve_top_passages``.
        index: FAISS-style index passed through to ``retrieve_top_passages``.
        chunks: DataFrame with a ``'chunk'`` column, passed through.
        query_list: Iterable of query strings. Defaults to the notebook-level
            ``queries`` list, looked up at call time.
        top_n: Number of passages to retrieve per query.

    Returns:
        dict mapping each query string to its retrieved passages, in query
        order. Duplicate queries collapse into a single key.
    """
    if query_list is None:
        query_list = queries
    return {
        query: retrieve_top_passages(query, model, index, chunks, top_n=top_n)
        for query in query_list
    }

In [None]:
# Load the cleaned passages and their precomputed embeddings, build an exact L2
# index, and retrieve the top passages for each sample query.
# `model` is created here: no earlier cell defines it, so a fresh
# "Restart & Run All" would otherwise fail with NameError.
model = embedding_model()
chunks_df = load_df('data', 'clean_chunks', 'csv')
embeddings = load_embeddings('data/clean_embeddings.pkl')
# Retrieval maps FAISS ids to chunk row positions, so the two must line up.
assert len(chunks_df) == embeddings.shape[0], (
    f"{len(chunks_df)} chunks vs {embeddings.shape[0]} embeddings"
)
dim = embeddings.shape
index = build_faiss_flatl2_index(embeddings, dim)
passages_retrieved = retrieve(model, index, chunks_df)