In [1]:
import mlflow.experiments
import nest_asyncio
nest_asyncio.apply()
import chromadb
import pandas as pd

import llama_index.core
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex, StorageContext
from llama_index.core.node_parser import  SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.llms.ollama import Ollama
from llama_index.core.schema import TextNode
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.retrievers.bm25 import BM25Retriever

from llama_index.core.evaluation import (
    generate_question_context_pairs,
    EmbeddingQAFinetuneDataset,
    RetrieverEvaluator
)

import mlflow

mlflow.llama_index.autolog()

mlflow.set_experiment("rust-book-rag")

  from .autonotebook import tqdm as notebook_tqdm


<Experiment: artifact_location='file:///home/carlos/Documents/repos/rust-programming/rust-rag/notebooks/mlruns/923516027696088727', creation_time=1731166672733, experiment_id='923516027696088727', last_update_time=1731166672733, lifecycle_stage='active', name='rust-book-rag', tags={}>

In [7]:
# import phoenix as px

# # Look for a URL in the output to open the App in a browser.
# px.launch_app()

# llama_index.core.set_global_handler("arize_phoenix")

# Setup

In [2]:
mpnet_embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-mpnet-base-v2", trust_remote_code=True)

llm = Ollama(model="llama3.2:latest", request_timeout=60, temperature=0)
# qwen2 = Ollama(model="qwen2.5:latest", request_timeout=60)

Settings.embed_model = mpnet_embed_model
Settings.llm = llm



In [3]:
QA_GENERATE_PROMPT_TMPL = """\
Context information is below.

---------------------
{context_str}
---------------------

Given the context information and not prior knowledge.
generate only questions based on the below query.

You are a Teacher/ Professor. Your task is to setup \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided. Your response should include \
the questions separated by a newline and nothing else.
"""

In [4]:
def get_nodes_from_index(index):
    """Gets the nodes from the index"""
    retriever = index.as_retriever(similarity_top_k=99999999999)
    all_nodes = retriever.retrieve("dummy")
    all_nodes = [item.node for item in all_nodes]
    return all_nodes


def build_index(documents, embed_model=Settings.embed_model, db_path="../chromadb", collection_name="rust_book", rebuild=False, distance_fn="l2"):
    """Builds the index"""

    db = chromadb.PersistentClient(db_path)

    if rebuild:
        db.delete_collection(name=collection_name)

    collection = db.get_or_create_collection(collection_name, metadata={"hnsw:space": distance_fn})
    vector_store = ChromaVectorStore(chroma_collection=collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    if collection.count() > 0 and not rebuild:
        index = VectorStoreIndex.from_vector_store(vector_store, storage_context=storage_context, embed_model=embed_model)
    else:
        index = VectorStoreIndex.from_documents(documents, storage_context=storage_context, embed_model=embed_model)

    return db, collection, vector_store, index


def display_results(name, eval_results, metrics, return_agg=True):
    """Display results from evaluate."""

    metric_dicts = []
    for eval_result in eval_results:
        metric_dict = eval_result.metric_vals_dict
        metric_dicts.append(metric_dict)

    full_df = pd.DataFrame(metric_dicts)

    columns = {
        "retrievers": [name],
        **{k: [full_df[k].mean()] for k in metrics},
    }

    metric_df = pd.DataFrame(columns)

    if return_agg:
        return metric_df
    else:
        return full_df, metric_df

async def evaluate_retriever(retriever, qa_dataset, metrics = ["hit_rate", "mrr", "precision", "recall", "ap", "ndcg"]):
    retriever_evaluator = RetrieverEvaluator.from_metric_names(
        metrics, retriever=retriever
    )
    eval_results = await retriever_evaluator.aevaluate_dataset(qa_dataset)
    return  display_results("baseline top-2 eval", eval_results, metrics=metrics)


async def log_retriever_eval(retriever, retriever_name, **kwargs):
    with mlflow.start_run():
        results = await evaluate_retriever(retriever, **kwargs)
        mlflow.log_param("retriever", retriever_name)
        mlflow.log_metrics(*results.drop(columns=["retrievers"]).to_dict(orient="records"))
    return results

async def evaluate_embed_model(
    dataset,
    embed_model,
    retriever_name=None,
    top_k=2,
):
    """Evaluates the embedding model on a given dataset."""
    corpus = dataset.corpus
    queries = dataset.queries
    relevant_docs = dataset.relevant_docs

    nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]
    index = VectorStoreIndex(
        nodes, embed_model=embed_model, show_progress=False
    )
    retriever = index.as_retriever(similarity_top_k=top_k)

    retriever_name = retriever_name or embed_model.model_name
    results = await log_retriever_eval(retriever,  qa_dataset=dataset, retriever_name=retriever_name)
    return results

# Baseline Retriever

Note: If we want to compare different storage / embedding methods, we need to rebuild the index and qa-dataset

In [5]:
documents = SimpleDirectoryReader('../txt').load_data()

db, collection, vector_store, index = build_index(documents)
nodes = get_nodes_from_index(index)

retriever = index.as_retriever(similarity_top_k=2)
query_engine = index.as_query_engine()

Number of requested results 99999999999 is greater than number of elements in index 384, updating n_results = 384


### Generate qa dataset

In [6]:
qa_dataset = generate_question_context_pairs(nodes=nodes, num_questions_per_chunk=2, qa_generate_prompt_tmpl=QA_GENERATE_PROMPT_TMPL)
qa_dataset.save_json("../data/qa_dataset.json"),

qa_dataset = EmbeddingQAFinetuneDataset.from_json("../data/qa_dataset.json")

qa_corpus = qa_dataset.corpus
qa_nodes = [TextNode(id_=id_, text=text) for id_, text in qa_corpus.items()]

100%|██████████| 384/384 [05:36<00:00,  1.14it/s]


In [11]:
await evaluate_embed_model(qa_dataset, embed_model=mpnet_embed_model, retriever_name="mpnet top-2 eval", top_k=2)

Unnamed: 0,retrievers,hit_rate,mrr,precision,recall,ap,ndcg
0,baseline top-2 eval,0.567708,0.499349,0.283854,0.567708,0.499349,0.31715


# Query Fusion Retriever

In [None]:
bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)
query_fusion_retriever = QueryFusionRetriever(
    [index.as_retriever(), bm25_retriever],
    similarity_top_k=2,
    num_queries=1,
    mode="reciprocal_rerank",
    verbose=False,
)

In [None]:
await log_retriever_eval(query_fusion_retriever, retriever_name="baseline query fusion", qa_dataset=qa_dataset)

Unnamed: 0,retrievers,hit_rate,mrr,precision,recall,ap,ndcg
0,baseline top-2 eval,0.716667,0.603333,0.358333,0.716667,0.603333,0.388129


Results are better! Note that the precision will be at most 0.5 because we're always retrieving 2 documents while the qa-dataset has only 1 expected results per question.

# Testing a different embedding model

### All-MiniLM-L12-v2

In [12]:
mini_lm_embed = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L12-v2")



Evaluation on the same dataset

In [None]:
mini_lm_results = await evaluate_embed_model(qa_dataset, embed_model=mini_lm_embed, retriever_name="miniLM top-2 eval", top_k=2)
mini_lm_results

Changing the embedding model makes a big difference! Let's try a couple more:

### Stella EN 400m

In [None]:
stella_small_embed = HuggingFaceEmbedding(model_name="dunzhang/stella_en_400M_v5", trust_remote_code=True)

stella_results = await evaluate_embed_model(qa_dataset, embed_model=stella_small_embed, retriever_name="stella top-2 eval", top_k=2)
stella_results

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: dunzhang/stella_en_400M_v5
Load pretrained SentenceTransformer: dunzhang/stella_en_400M_v5


Some weights of the model checkpoint at dunzhang/stella_en_400M_v5 were not used when initializing NewModel: ['new.pooler.dense.bias', 'new.pooler.dense.weight']
- This IS expected if you are initializing NewModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing NewModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


INFO:sentence_transformers.SentenceTransformer:2 prompts are loaded, with the keys: ['query', 'text']
2 prompts are loaded, with the keys: ['query', 'text']


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,retrievers,hit_rate,mrr,precision,recall,ap,ndcg
0,baseline top-2 eval,0.75,0.686667,0.375,0.75,0.686667,0.431196


Out of curiosity, let's see how stella performs with a fusion query retriever:

In [None]:

index = VectorStoreIndex(
    qa_nodes, embed_model=stella_small_embed, show_progress=False
)

bm25_retriever = BM25Retriever.from_defaults(nodes=qa_nodes, similarity_top_k=2)
stella_query_fusion_retriever = QueryFusionRetriever(
    [index.as_retriever(), bm25_retriever],
    similarity_top_k=2,
    num_queries=1,
    mode="reciprocal_rerank",
    verbose=False,
)

results = await log_retriever_eval(stella_query_fusion_retriever,  qa_dataset=qa_dataset, retriever_name="stella query-fusion top-2")
results

Unnamed: 0,retrievers,hit_rate,mrr,precision,recall,ap,ndcg
0,baseline top-2 eval,0.85,0.74,0.425,0.85,0.74,0.47139


And let's test one more embedding model

### bge-large

In [None]:
bge_large_embed = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5", trust_remote_code=False)

bge_large_results = await evaluate_embed_model(qa_dataset, embed_model=bge_large_embed, retriever_name="bge-large top-2 eval", top_k=2)
bge_large_results

Unnamed: 0,retrievers,hit_rate,mrr,precision,recall,ap,ndcg
0,baseline top-2 eval,0.756667,0.691667,0.378333,0.756667,0.691667,0.43453


# Embedding fine tuning

Let's fine tune these models and see which one gives better performance. Let's try first with MiniLM as it is a small model and should be 'quick enough'

In [14]:
from sklearn.model_selection import train_test_split

train_nodes, all_test_nodes = train_test_split(qa_nodes, test_size=0.3)
eval_nodes, test_nodes = train_test_split(all_test_nodes, test_size=0.5)

print(f"Train size: ", len(train_nodes))
print(f"Valid size: ", len(eval_nodes))
print(f"Test size: ", len(test_nodes))

Train size:  268
Valid size:  58
Test size:  58


In [5]:
from llama_index.finetuning import generate_qa_embedding_pairs

# train_dataset = generate_qa_embedding_pairs(
#     llm=Settings.llm,
#     nodes=train_nodes,
#     output_path="../data/ft_train_dataset.json",
#     qa_generate_prompt_tmpl=QA_GENERATE_PROMPT_TMPL,
#     num_questions_per_chunk=2,
#     verbose=False,
#     save_every=50
# )

# eval_dataset = generate_qa_embedding_pairs(
#     llm=Settings.llm,
#     nodes=eval_nodes,
#     output_path="../data/ft_eval_dataset.json",
#     qa_generate_prompt_tmpl=QA_GENERATE_PROMPT_TMPL,
#     num_questions_per_chunk=2,
#     verbose=False
# )

# test_dataset = generate_qa_embedding_pairs(
#     llm=Settings.llm,
#     nodes=test_nodes,
#     output_path="../data/ft_test_dataset.json",
#     qa_generate_prompt_tmpl=QA_GENERATE_PROMPT_TMPL,
#     num_questions_per_chunk=2,
#     verbose=False
# )

train_dataset = EmbeddingQAFinetuneDataset.from_json("../data/ft_train_dataset.json")
eval_dataset = EmbeddingQAFinetuneDataset.from_json("../data/ft_eval_dataset.json")
test_dataset = EmbeddingQAFinetuneDataset.from_json("../data/ft_test_dataset.json")

### Using SentenceTransformers directly

In [6]:
from sentence_transformers import InputExample, SentenceTransformer, losses
from torch.utils.data import DataLoader
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from datasets import Dataset, DatasetBuilder, load_dataset

In [7]:
def build_st_dataset(dataset):
    anchors = [q for _, q in dataset.queries.items()]
    positives = [q[0] for _, q in dataset.relevant_docs.items()]
    return Dataset.from_dict({"anchor": anchors, "positive": positives})

In [9]:
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)

from sentence_transformers.training_args import BatchSamplers


train_ds = build_st_dataset(train_dataset)
eval_ds = build_st_dataset(eval_dataset)
test_ds = build_st_dataset(test_dataset)

dev_evaluator = InformationRetrievalEvaluator(queries=eval_dataset.queries, corpus=eval_dataset.corpus, relevant_docs=eval_dataset.relevant_docs)

ft_model = SentenceTransformer("all-MiniLM-L12-v2")
loss = losses.MultipleNegativesRankingLoss(model=ft_model)


args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="../models/ft_mini_lm",
    # Optional training parameters:
    num_train_epochs=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=10,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=50,
    run_name="mini-lm-full-run",  # Will be used in W&B if `wandb` is installed
)


INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cuda
Use pytorch device_name: cuda
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L12-v2
Load pretrained SentenceTransformer: all-MiniLM-L12-v2




In [10]:
trainer = SentenceTransformerTrainer(
    model=ft_model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()



INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Information Retrieval Evaluation of the model on the  dataset:
Information Retrieval Evaluation of the model on the  dataset:
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Queries: 116
Queries: 116
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Corpus: 58

Corpus: 58

INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Score-Function: cosine
Score-Function: cosine
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Accuracy@1: 65.52%
Accuracy@1: 65.52%
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Accuracy@3: 82.76%
Accuracy@3: 82.76%
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Accuracy@5: 87.93%
Accuracy@5: 87.93%
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Accuracy@10: 94.83%
Accuracy@10: 94.83%
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Precision@1:

2024/11/10 22:31:33 ERROR mlflow.utils.async_logging.async_logging_queue: Run Id a26a0472671a48a191bda9a1994d6240: Failed to log run data: Exception: Invalid value "eval_cosine_accuracy@1" for parameter 'metrics[1].name' supplied: Names may only contain alphanumerics, underscores (_), dashes (-), periods (.), spaces ( ), colon(:) and slashes (/).
                                                 
 95%|█████████▍| 370/390 [00:39<00:01, 12.48it/s]

{'eval_loss': 3.3589484691619873, 'eval_cosine_accuracy@1': 0.6551724137931034, 'eval_cosine_accuracy@3': 0.8275862068965517, 'eval_cosine_accuracy@5': 0.8793103448275862, 'eval_cosine_accuracy@10': 0.9482758620689655, 'eval_cosine_precision@1': 0.6551724137931034, 'eval_cosine_precision@3': 0.2758620689655172, 'eval_cosine_precision@5': 0.17586206896551723, 'eval_cosine_precision@10': 0.09482758620689653, 'eval_cosine_recall@1': 0.6551724137931034, 'eval_cosine_recall@3': 0.8275862068965517, 'eval_cosine_recall@5': 0.8793103448275862, 'eval_cosine_recall@10': 0.9482758620689655, 'eval_cosine_ndcg@10': 0.7995888861811478, 'eval_cosine_mrr@10': 0.7522475369458129, 'eval_cosine_map@100': 0.7540109869470925, 'eval_dot_accuracy@1': 0.6551724137931034, 'eval_dot_accuracy@3': 0.8275862068965517, 'eval_dot_accuracy@5': 0.8793103448275862, 'eval_dot_accuracy@10': 0.9482758620689655, 'eval_dot_precision@1': 0.6551724137931034, 'eval_dot_precision@3': 0.2758620689655172, 'eval_dot_precision@5': 

 97%|█████████▋| 379/390 [00:40<00:00, 13.33it/s]

INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Information Retrieval Evaluation of the model on the  dataset:
Information Retrieval Evaluation of the model on the  dataset:
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Queries: 116
Queries: 116
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Corpus: 58

Corpus: 58

INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Score-Function: cosine
Score-Function: cosine
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Accuracy@1: 65.52%
Accuracy@1: 65.52%
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Accuracy@3: 82.76%
Accuracy@3: 82.76%
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Accuracy@5: 87.93%
Accuracy@5: 87.93%
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Accuracy@10: 94.83%
Accuracy@10: 94.83%
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Precision@1:

2024/11/10 22:31:34 ERROR mlflow.utils.async_logging.async_logging_queue: Run Id a26a0472671a48a191bda9a1994d6240: Failed to log run data: Exception: Invalid value "eval_cosine_accuracy@1" for parameter 'metrics[1].name' supplied: Names may only contain alphanumerics, underscores (_), dashes (-), periods (.), spaces ( ), colon(:) and slashes (/).
                                                 
 97%|█████████▋| 380/390 [00:40<00:00, 13.33it/s]

{'eval_loss': 3.3618862628936768, 'eval_cosine_accuracy@1': 0.6551724137931034, 'eval_cosine_accuracy@3': 0.8275862068965517, 'eval_cosine_accuracy@5': 0.8793103448275862, 'eval_cosine_accuracy@10': 0.9482758620689655, 'eval_cosine_precision@1': 0.6551724137931034, 'eval_cosine_precision@3': 0.2758620689655172, 'eval_cosine_precision@5': 0.17586206896551723, 'eval_cosine_precision@10': 0.09482758620689653, 'eval_cosine_recall@1': 0.6551724137931034, 'eval_cosine_recall@3': 0.8275862068965517, 'eval_cosine_recall@5': 0.8793103448275862, 'eval_cosine_recall@10': 0.9482758620689655, 'eval_cosine_ndcg@10': 0.7997133248115077, 'eval_cosine_mrr@10': 0.7523672687465792, 'eval_cosine_map@100': 0.7541307187478586, 'eval_dot_accuracy@1': 0.6551724137931034, 'eval_dot_accuracy@3': 0.8275862068965517, 'eval_dot_accuracy@5': 0.8793103448275862, 'eval_dot_accuracy@10': 0.9482758620689655, 'eval_dot_precision@1': 0.6551724137931034, 'eval_dot_precision@3': 0.2758620689655172, 'eval_dot_precision@5': 

100%|█████████▉| 389/390 [00:41<00:00, 12.84it/s]

INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Information Retrieval Evaluation of the model on the  dataset:
Information Retrieval Evaluation of the model on the  dataset:
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Queries: 116
Queries: 116
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Corpus: 58

Corpus: 58

INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Score-Function: cosine
Score-Function: cosine
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Accuracy@1: 65.52%
Accuracy@1: 65.52%
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Accuracy@3: 82.76%
Accuracy@3: 82.76%
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Accuracy@5: 87.93%
Accuracy@5: 87.93%
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Accuracy@10: 94.83%
Accuracy@10: 94.83%
INFO:sentence_transformers.evaluation.InformationRetrievalEvaluator:Precision@1:

2024/11/10 22:31:35 ERROR mlflow.utils.async_logging.async_logging_queue: Run Id a26a0472671a48a191bda9a1994d6240: Failed to log run data: Exception: Invalid value "eval_cosine_accuracy@1" for parameter 'metrics[1].name' supplied: Names may only contain alphanumerics, underscores (_), dashes (-), periods (.), spaces ( ), colon(:) and slashes (/).
                                                 
100%|██████████| 390/390 [00:41<00:00, 12.84it/s]

{'eval_loss': 3.3631532192230225, 'eval_cosine_accuracy@1': 0.6551724137931034, 'eval_cosine_accuracy@3': 0.8275862068965517, 'eval_cosine_accuracy@5': 0.8793103448275862, 'eval_cosine_accuracy@10': 0.9482758620689655, 'eval_cosine_precision@1': 0.6551724137931034, 'eval_cosine_precision@3': 0.2758620689655172, 'eval_cosine_precision@5': 0.17586206896551723, 'eval_cosine_precision@10': 0.09482758620689653, 'eval_cosine_recall@1': 0.6551724137931034, 'eval_cosine_recall@3': 0.8275862068965517, 'eval_cosine_recall@5': 0.8793103448275862, 'eval_cosine_recall@10': 0.9482758620689655, 'eval_cosine_ndcg@10': 0.7997133248115077, 'eval_cosine_mrr@10': 0.7523672687465792, 'eval_cosine_map@100': 0.7541307187478586, 'eval_dot_accuracy@1': 0.6551724137931034, 'eval_dot_accuracy@3': 0.8275862068965517, 'eval_dot_accuracy@5': 0.8793103448275862, 'eval_dot_accuracy@10': 0.9482758620689655, 'eval_dot_precision@1': 0.6551724137931034, 'eval_dot_precision@3': 0.2758620689655172, 'eval_dot_precision@5': 

100%|██████████| 390/390 [00:42<00:00, 12.84it/s]

{'train_runtime': 42.6311, 'train_samples_per_second': 144.026, 'train_steps_per_second': 9.148, 'train_loss': 1.5817057634011293, 'epoch': 9.77}


100%|██████████| 390/390 [00:43<00:00,  8.95it/s]


TrainOutput(global_step=390, training_loss=1.5817057634011293, metrics={'train_runtime': 42.6311, 'train_samples_per_second': 144.026, 'train_steps_per_second': 9.148, 'total_flos': 0.0, 'train_loss': 1.5817057634011293, 'epoch': 9.76923076923077})

In [11]:
ft_model.save_pretrained("../models/ft_mini_lm/final")

INFO:sentence_transformers.SentenceTransformer:Save model to ../models/ft_mini_lm/final
Save model to ../models/ft_mini_lm/final


In [12]:
loaded_ft_model = HuggingFaceEmbedding("../models/ft_mini_lm/final")

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: ../models/ft_mini_lm/final
Load pretrained SentenceTransformer: ../models/ft_mini_lm/final


INFO:sentence_transformers.SentenceTransformer:2 prompts are loaded, with the keys: ['query', 'text']
2 prompts are loaded, with the keys: ['query', 'text']


In [None]:
ft_mini_lm_results = await evaluate_embed_model(test_dataset, embed_model=loaded_ft_model, retriever_name="ft miniLM top-2 eval", top_k=2)
ft_mini_lm_results

In [15]:
ft_mini_lm_results

Unnamed: 0,retrievers,hit_rate,mrr,precision,recall,ap,ndcg
0,baseline top-2 eval,0.655844,0.574675,0.327922,0.655844,0.574675,0.365393
