In [1]:
from pymilvus import MilvusClient
from sentence_transformers import SentenceTransformer
from tqdm import trange
from langchain_huggingface import HuggingFaceEndpoint
from dotenv import load_dotenv
from langchain_core.prompts import PromptTemplate
from typing import Any
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import pandas as pd

load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
# Load the evaluation split. Each row provides a question, its reference
# answer(s), the source context and a `need_retrieval` label.
dataset = pd.read_csv("../data/test.csv")

# Distinct, non-null contexts: these are the passages that get indexed.
contexts = dataset["context"].dropna().unique().tolist()


def _row_to_qa(row):
    """Convert one dataset row into a QA record (`answers` is exposed as `answer`)."""
    return {
        "question": row["question"],
        "answer": row["answers"],
        "context": row["context"],
        "need_retrieval": row["need_retrieval"],
    }


qa_pairs = dataset.apply(_row_to_qa, axis=1).tolist()

In [3]:
# Collection name and the dimensionality of its vector field.
collection_name = "spbrag"
embedding_dim = 768

# Client pointed at the database file under ../data.
milvus_client = MilvusClient(uri="../data/milvus_demo.db")

# Start from a clean slate: remove a collection left over from a previous run.
collection_exists = milvus_client.has_collection(collection_name)
if collection_exists:
    milvus_client.drop_collection(collection_name)


In [4]:
# Index settings requested for the vector field (IVF_FLAT, L2, nlist=128).
# NOTE(review): confirm that MilvusClient.create_collection honours a plain
# dict here when no schema is passed; the pymilvus docs describe an
# IndexParams object built via prepare_index_params() -- verify.
vector_index_params = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}

# Rows get an auto-generated "id" primary key and an "embedding" vector field.
# enable_dynamic_field is presumably what lets inserted rows carry the extra
# "text" key -- TODO confirm.
milvus_client.create_collection(
    collection_name=collection_name,
    dimension=embedding_dim,
    primary_field_name="id",
    auto_id=True,
    vector_field_name="embedding",
    metric_type="L2",
    enable_dynamic_field=True,
    index_params=vector_index_params,
)

In [5]:
# Sentence-embedding model used both to index contexts and to encode queries.
# The collection's vector dimension (`embedding_dim`) must match its output size.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

In [6]:
def insert_documents(contexts, batch_size: int = 100):
    """Embed every context and insert the results into the Milvus collection.

    Each context is encoded with the notebook-level `embedding_model` and
    stored as a ``{"text": ..., "embedding": ...}`` row in `collection_name`
    through `milvus_client`. A running total is printed after each batch.

    Args:
        contexts: Iterable of text passages to index.
        batch_size: Number of rows sent per insert call (default 100).

    Raises:
        ValueError: If `batch_size` is not positive. A zero step would make
            the range call fail with an unrelated message, and a negative
            one would silently insert nothing.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    documents = [
        {"text": context, "embedding": embedding_model.encode(context).tolist()}
        for context in contexts
    ]

    # Insert in batches so a single request never carries the whole corpus.
    for start in trange(0, len(documents), batch_size):
        batch = documents[start : start + batch_size]
        milvus_client.insert(collection_name, batch)
        print(f"Inserted {start + len(batch)} documents")


# Index every unique context extracted from the dataset.
insert_documents(contexts=contexts)

100%|██████████| 4/4 [00:00<00:00, 39.09it/s]

Inserted 100 documents
Inserted 200 documents
Inserted 300 documents
Inserted 330 documents





In [7]:
# Hosted text-generation endpoint: sampling disabled, at most 100 new tokens.
llm_settings = {
    "repo_id": "mistralai/Mistral-7B-Instruct-v0.2",
    "task": "text-generation",
    "max_new_tokens": 100,
    "do_sample": False,
}
llm = HuggingFaceEndpoint(**llm_settings)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [8]:
# The sequence classifier and its tokenizer are loaded from the same local directory.
model_path = "./bert-text-classification-model"
tokenizer_path = model_path
num_labels = 2

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

classificator = BertForSequenceClassification.from_pretrained(
    model_path, num_labels=num_labels
)
classificator.to(device)

tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

In [9]:
import torch


def predict_class(
    model,
    text: str,
    tokenizer: Any,
    device: torch.device,
    max_length: int = 512,
) -> int:
    """Classify a single text and return the index of its highest logit.

    Args:
        model: Callable model exposing `eval()` and returning an object with
            a `logits` tensor when called with the tokenized inputs.
        text: The text to classify.
        tokenizer: Callable that turns `text` into a mapping of tensors.
        device: Device the input tensors are moved to before the forward pass.
        max_length: Truncation limit handed to the tokenizer.

    Returns:
        The argmax over the last logits dimension, as a Python scalar.
    """
    # Switch the model to inference mode before running the forward pass.
    model.eval()

    encoded = tokenizer(
        text,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
        padding=True,
    )

    # Move every tensor to the target device; an "idx" entry is never forwarded.
    model_inputs = {}
    for name, tensor in encoded.items():
        if name == "idx":
            continue
        model_inputs[name] = tensor.to(device)

    # No gradients are needed for prediction.
    with torch.no_grad():
        logits = model(**model_inputs).logits
        return torch.argmax(logits, dim=-1).item()

In [10]:
# Spot-check the retrieval classifier on the first ten questions.
for qa in qa_pairs[:10]:
    query = qa["question"]
    print(query)
    predicted_class = predict_class(classificator, query, tokenizer, device)
    # Any non-zero class is reported as "retrieval needed".
    if predicted_class:
        retrieval_label = "Need Retrieval"
    else:
        retrieval_label = "No Need"
    print(f"Query: {query} | Retrieval: {retrieval_label}")

What did Yongle want to trade with Tibet?
Query: What did Yongle want to trade with Tibet? | Retrieval: Need Retrieval
Extract the cinema industry and the percentage box office share occupied by that industry in the format {Industry} - {Percentage} and show as a bullet list. If no percentage specified just list the industry name. Indian cinema is composed of multilingual and multi-ethnic film art. In 2019, Hindi cinema represented 44% of box office revenue, followed by Telugu and Tamil film industries, each representing 13%, Malayalam and Kannada film industries, each representing 5%. Other prominent languages in the Indian film industry include Bengali, Marathi, Odia, Punjabi, Gujarati and Bhojpuri. As of 2022, the combined revenue of South Indian film industries has surpassed that of the Mumbai-based Hindi film industry (Bollywood). As of 2022, Telugu cinema leads Indian cinema's box-office revenue.[details 2]
Query: Extract the cinema industry and the percentage box office share occ

In [11]:
def rag_query(llm: Any, question: str, classificator, tokenizer, device, top_k: int = 3, context_len: int = 400) -> tuple[str, int]:
    """Answer `question` with the LLM, retrieving context only when needed.

    The question is first classified with `predict_class`. When the predicted
    class is truthy, the question is embedded, the `top_k` nearest documents
    are fetched from Milvus, and their texts (joined with spaces and cut to
    `context_len` characters) are placed in the prompt. Otherwise, or when the
    search returns nothing, a context-free prompt is used.

    Args:
        llm: Runnable that the prompt template is piped into.
        question: The user question to answer.
        classificator: Model passed to `predict_class`.
        tokenizer: Tokenizer passed to `predict_class`.
        device: Device passed to `predict_class`.
        top_k: Maximum number of documents to retrieve.
        context_len: Maximum number of context characters put in the prompt.

    Returns:
        A ``(response, predicted_class)`` tuple: the chain's output and the
        classifier's decision for this question.
    """
    # Classify the question that was passed in, not a notebook-level variable.
    predicted_class = predict_class(classificator, question, tokenizer, device)

    contexts = []
    if predicted_class:
        # Only pay for the embedding and the vector search when the retrieved
        # passages can actually end up in the prompt.
        query_embedding = embedding_model.encode(question).tolist()
        search_results = milvus_client.search(
            collection_name=collection_name,
            data=[query_embedding],
            limit=top_k,
            output_fields=["text"],
        )
        if search_results:
            contexts = [hit["entity"]["text"] for hit in search_results[0]]

    if contexts:
        template = "Answer the question based on context:\nContext: {context}\nQuestion: {question}\nAnswer:"
        chain_input = {
            "question": question,
            "context": " ".join(contexts)[:context_len],
        }
    else:
        template = "Answer this question:\nQuestion: {question}\nAnswer:"
        chain_input = {"question": question}

    llm_chain = PromptTemplate.from_template(template) | llm
    response = llm_chain.invoke(chain_input)
    return response, predicted_class

In [12]:
# End-to-end check: run the RAG pipeline on the first `num_samples` questions,
# print each answer next to the reference, and score the retrieval classifier
# against the dataset's `need_retrieval` label.
print("\nTesting RAG system...\n")
print("-" * 160)

num_samples = 10
eval_pairs = qa_pairs[:num_samples]  # may be shorter than num_samples

correct = 0
for qa in eval_pairs:
    query = qa["question"]
    result, predicted_class = rag_query(llm, query, classificator, tokenizer, device)
    generated_answer = result
    true_answer = qa["answer"]
    need_retrieval = qa["need_retrieval"]

    if need_retrieval == predicted_class:
        correct += 1

    print(
        f"Question: {query[:45]} | Predicted class: {predicted_class} | True class: {need_retrieval}"
    )
    print(f"Generated answer: {generated_answer[:100]}")
    print(f"True answer: {true_answer}")

# Report the tally that the loop accumulates; skip it when nothing was evaluated
# to avoid dividing by zero.
if eval_pairs:
    print("-" * 160)
    print(
        f"Retrieval classifier accuracy: {correct}/{len(eval_pairs)} "
        f"({correct / len(eval_pairs):.2%})"
    )


Testing RAG system...

----------------------------------------------------------------------------------------------------------------------------------------------------------------




Question: What did Yongle want to trade with Tibet? | Predicted class: 1 | True class: 1
Generated answer:  Tea, horses, and salt.
True answer: tea, horses, and salt
Question: Extract the cinema industry and the percentag | Predicted class: 0 | True class: 0
Generated answer:  - Hindi - 44%
- Telugu - 13%
- Tamil - 13%
- Malayalam - 5%
- Kannada - 5%
- Bengali
- Marathi
- Od
True answer: nan
Question: Adapt the text to make it relevant for a corp | Predicted class: 0 | True class: 0
Generated answer:  "Jane is feeling less prepared for the upcoming board meeting, having exhausted all her preparation
True answer: nan
Question: What magazine did Beyoncé write a story for a | Predicted class: 1 | True class: 1
Generated answer:  Vogue
True answer: Essence
Question: What did the China Digital Times report? | Predicted class: 1 | True class: 1
Generated answer:  The China Digital Times reported that Foxconn, Apple's manufacturer, initially denied labor abuses 
True answer: a close analysis 