<a href="https://colab.research.google.com/github/adidam/rag-impl/blob/main/Real_World_RAG_Impl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install datasets
!pip install sentence-transformers
!pip install rank-bm25
!pip install torch transformers
!pip install huggingface_hub
!pip install langchain-community
!pip install accelerate
!pip install rank-bm25 nltk
!pip install scipy

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m30.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

## **Adding the Imports**

In [3]:
import numpy as np
import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import random_split
import torch.optim as optim
import matplotlib.pyplot as plt
import nltk

nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

## **Loading the RAGBench Dataset**

## **Chunking the Dataset**

In [4]:
# embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
embedding_model = "BAAI/LLM-Embedder"
# embedding_model = "BAAI/bge-large-en"

In [5]:
# New code - 12/4 10 pm

from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer

# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained(embedding_model)
#tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Sliding window configuration
TOKEN_LIMIT = 512
SLIDING_WINDOW_OVERLAP = 100  # Overlap between consecutive chunks (in tokens)

# Function for chunking with token limit and sliding window
def chunk_with_token_limit(text, token_limit=512, overlap=100):
    sentences = sent_tokenize(text)  # Split text into sentences
    chunks = []  # Store resulting chunks
    current_chunk = []  # Temporarily hold sentences for the current chunk
    current_chunk_tokens = 0  # Token count for the current chunk

    for sentence in sentences:
        # Tokenize the sentence and calculate its token count
        sentence_tokens = tokenizer.tokenize(sentence)
        num_tokens = len(sentence_tokens)

        # print(f"Tokens: {sentence_tokens[0]}")

        # If adding this sentence exceeds the token limit
        if current_chunk_tokens + num_tokens > token_limit:
            # Save the current chunk
            chunk_text = " ".join(current_chunk)
            chunks.append(chunk_text)

            # Prepare the next chunk with overlap
            overlap_tokens = tokenizer.tokenize(" ".join(current_chunk[-1:]))
            current_chunk = [sentence for sentence in current_chunk[-(overlap // len(overlap_tokens)) :]] if current_chunk else []
            current_chunk_tokens = sum(len(tokenizer.tokenize(sent)) for sent in current_chunk)

        # Add the sentence to the current chunk
        current_chunk.append(sentence)
        current_chunk_tokens += num_tokens

    # Add the last chunk if it exists
    if current_chunk:
        chunk_text = " ".join(current_chunk)
        chunks.append(chunk_text)

    return chunks

def process_document_with_identifiers(document):
    processed_data = []
    title_count = -1  # to start from 0
    # print("document>>>>>>>",document)
    for section in document:
        section_chunks = []
        passage_count = [ord('a')]  # Passage identifier as a list to handle nested increments
        title_count += 1  # Increment title count

        # Tokenize the section into sentences
        sentences = sent_tokenize(section)
        for sentence in sentences:
            if sentence.startswith("Title:"):
                # New document detected
                identifier = f"{title_count}{''.join(chr(c) for c in passage_count)}"  # Identifier for the title
                chunked_texts = chunk_with_token_limit(sentence, TOKEN_LIMIT, SLIDING_WINDOW_OVERLAP)
                for chunk in chunked_texts:
                    section_chunks.append([identifier, chunk])
                passage_count = [ord('a')]  # Reset passage count for the new document
            else:
                # Sentence under the current document
                identifier = f"{title_count}{''.join(chr(c) for c in passage_count)}"
                chunked_texts = chunk_with_token_limit(sentence, TOKEN_LIMIT, SLIDING_WINDOW_OVERLAP)
                #print("chunked_texts>>>>process_document_with_identifiers>>>>> "+ "".join(chunked_texts))
                for chunk in chunked_texts:
                    section_chunks.append([identifier, chunk])

                # Increment passage_count intelligently
                i = len(passage_count) - 1
                while i >= 0:
                    passage_count[i] += 1
                    if passage_count[i] > ord('z'):
                        passage_count[i] = ord('a')
                        if i == 0:
                            passage_count.insert(0, ord('a'))  # Add a new character to the identifier
                        i -= 1
                    else:
                        break


        print("section_chunks>>>>>>>",section_chunks)
        processed_data.append(section_chunks)

    return processed_data

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/712k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

# **Small to Big Chunking**

In [None]:
def small_to_big_chunking(text, token_limit):
    sentences = sent_tokenize(text)  # Start small: split into sentences
    chunks = []
    current_chunk = []
    current_tokens = 0

    for sentence in sentences:
        sentence_tokens = tokenizer.tokenize(sentence)
        num_tokens = len(sentence_tokens)

        # If adding this sentence exceeds the token limit, finalize the current chunk
        if current_tokens + num_tokens > token_limit:
            chunks.append(" ".join(current_chunk))
            current_chunk = []  # Start a new chunk
            current_tokens = 0

        # Add sentence to the current chunk
        current_chunk.append(sentence)
        current_tokens += num_tokens

    # Add the last chunk
    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks


# **Load and view first 5 rows of the Dataset**

In [6]:
# Print the top 5 rows of the dataset, for debugging purpose, we only generate embeddings for these 5 rows
from datasets import load_dataset
datasets = ['techqa']
data = load_dataset("rungalileo/ragbench", datasets[0], split="train")
top_5_rows = data.select(range(2))

for i, row in enumerate(top_5_rows):
    print(f"Row {i + 1}:")
    for field, value in row.items():
        print(f"  {field}: {value}")
    print()

README.md:   0%|          | 0.00/24.7k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/22.5M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/5.40M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/5.35M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1192 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/304 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/314 [00:00<?, ? examples/s]

Row 1:
  id: techqa_TRAIN_Q337
  question: Why does the other instance of my multi-instance qmgr seem to hang after a failover? Queue manager will not start after failover.
  documents: ['HL083112 mqlpgrlg ZX000001 ExecCtrlrMain lpiRC_LOG_NOT_AVAILABLE mscs TECHNOTE (TROUBLESHOOTING)\n\nPROBLEM(ABSTRACT)\n You attempt to failover from the primary to secondary node under MSCS. Your WebSphere MQ queue manager fails to come up on the secondary node, and errors are generated. \n\nSYMPTOM\nThe sequence seen in the FDC files show:\n\n\nProbe Id :- HL083112 \nComponent :- mqlpgrlg \nProcess Name :- D:\\Programs\\MQSeries\\bin\\amqzxma0.exe \nMajor Errorcode :- hrcE_MQLO_UNEXPECTED_OS_ERROR \nMQM Function Stack\nkpiStartup\napiStartup\nalmPerformReDoPass\nhlgScanLogBegin\nmqlpgrlg\nxcsFFST\n\n\nProbe Id :- ZX000001\nComponent :- ExecCtrlrMain \nProcess Name :- D:\\Programs\\MQSeries\\bin\\amqzxma0.exe \nMajor Errorcode :- xecF_E_UNEXPECTED_RC \nMinor Errorcode :- lpiRC_LOG_NOT_AVAILABLE \nProb

## **Generate Embeddings**

In [7]:
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

#datasets = ['pubmedqa','tatqa', 'techqa'] # d4

datasets = ['techqa']

# Initialize storage for documents, IDs, and metadata
all_documents = []
all_ids = []
all_metadatas = []

# Process each dataset
doc_idx = 0  # Global document index for unique IDs
for dataset in datasets:
    data = load_dataset("rungalileo/ragbench", dataset, split="train")
    #only select first 5 records for debugging duplicate records. **PLEASE REMOVE THIS AFTER DEBUGGING**
    data = data.select(range(15))
    for idx, row in tqdm(enumerate(data), desc=f"Processing {dataset}"):
        # Extract document text
        doc_text = row.get('documents', '')

        # Skip if no documents found
        if not doc_text:
            continue

        # Process the document
        processed_output = process_document_with_identifiers(doc_text)
        added_item_idxs = set()

        # Populate the lists
        for section_idx, section in enumerate(processed_output):
            for item_idx, (prefix, content) in enumerate(section):
                # Skip if this item_idx has already been processed
                if item_idx in added_item_idxs:
                    continue

                # Add the item_idx to the set to track it
                added_item_idxs.add(item_idx)

                # Add the document
                document = f"[{prefix}] {content}"
                all_documents.append(document)

                # Construct a globally unique ID
                doc_id = f"{dataset}_{doc_idx}_{section_idx}_{item_idx}"
                all_ids.append(doc_id)

                # Construct metadata
                metadata = {
                    "dataset": dataset,
                    "global_index": doc_idx,
                    "section_index": section_idx,
                    "item_index": item_idx,
                    "prefix": prefix,
                    "type": "Title" if prefix.endswith("a") else "Passage",
                }
                all_metadatas.append(metadata)

        doc_idx += 1  # Increment global document index

# Step 4: Generate Embeddings
#embedder = SentenceTransformer("all-MiniLM-L6-v2")  # Pretrained sentence transformer
embedder = SentenceTransformer(embedding_model)  # Pretrained sentence transformer
batch_size = 2500  # Adjust based on available memory

# Generate embeddings in batches
all_embeddings = []
for i in tqdm(range(0, len(all_documents), batch_size), desc="Generating embeddings"):
    batch_docs = all_documents[i:i + batch_size]
    batch_embeddings = embedder.encode(batch_docs, show_progress_bar=True)
    all_embeddings.extend(batch_embeddings)

Processing techqa: 1it [00:00,  3.28it/s]

section_chunks>>>>>>> [['0a', 'HL083112 mqlpgrlg ZX000001 ExecCtrlrMain lpiRC_LOG_NOT_AVAILABLE mscs TECHNOTE (TROUBLESHOOTING)\n\nPROBLEM(ABSTRACT)\n You attempt to failover from the primary to secondary node under MSCS.'], ['0b', 'Your WebSphere MQ queue manager fails to come up on the secondary node, and errors are generated.'], ['0c', 'SYMPTOM\nThe sequence seen in the FDC files show:\n\n\nProbe Id :- HL083112 \nComponent :- mqlpgrlg \nProcess Name :- D:\\Programs\\MQSeries\\bin\\amqzxma0.exe \nMajor Errorcode :- hrcE_MQLO_UNEXPECTED_OS_ERROR \nMQM Function Stack\nkpiStartup\napiStartup\nalmPerformReDoPass\nhlgScanLogBegin\nmqlpgrlg\nxcsFFST\n\n\nProbe Id :- ZX000001\nComponent :- ExecCtrlrMain \nProcess Name :- D:\\Programs\\MQSeries\\bin\\amqzxma0.exe \nMajor Errorcode :- xecF_E_UNEXPECTED_RC \nMinor Errorcode :- lpiRC_LOG_NOT_AVAILABLE \nProbe Description :- AMQ6118: An internal WebSphere MQ error has occurred \n(7017) \nArith1 :- 28695 7017 \nMQM Function Stack\nxcsFFST\n\n\nCA

Processing techqa: 3it [00:00,  7.13it/s]

section_chunks>>>>>>> [['0a', ' SUBSCRIBE TO THIS APAR\nBy subscribing, you receive periodic emails alerting you to the status of the APAR, along with a link to the fix after it becomes available.'], ['0b', 'You can track this item individually or track all items by product.'], ['0c', 'Notify me when this APAR changes.'], ['0d', 'Notify me when an APAR for this component changes.'], ['0e', 'APAR STATUS\n * CLOSED AS PROGRAM ERROR.'], ['0f', 'ERROR DESCRIPTION\n *  DASH menu hiding behind AEL in IE 11 Windows Server 2012 R2,\n   Windows 2010\n   \n   Reproduced from DASH 3.1.1.x, DASH 3.1.2 including CP5 (WebGUI\n   8.1 FP4).'], ['0g', 'to DASH 3.1.3.0 but NOT on DASH 3.1.0.3 (WebGUI\n   8.1.0)\n   \n   Per WebGUI L3, there are no changes in AEL html from 8.1 to 8.1\n   FP4.'], ['0h', 'LOCAL FIX\n *  No workaround found.'], ['0i', 'PROBLEM SUMMARY\n *  Please follow below link\n   http://www-01.ibm.com/support/docview.wss?uid=swg22011801 [http://www-01.ibm.com/support/docview.wss?uid=sw

Token indices sequence length is longer than the specified maximum sequence length for this model (839 > 512). Running this sequence through the model will result in indexing errors
Processing techqa: 4it [00:00,  6.71it/s]

 [['1a', ' NEWS\n\nABSTRACT\n Information on support for the WebSphere Cast Iron Cloud integration products including the Cloud, Appliance, and Virtual Appliance.'], ['1b', 'CONTENT\nWelcome to WebSphere Cast Iron Support.'], ['1c', 'This technote provides links to more information that can help you avoid opening service requests with IBM Support.'], ['1d', 'Use the documents referenced here to answer questions about your appliance or virtual appliance and shorten the time to resolution if you need to open a service request.'], ['1e', 'Useful links: \n\n * Use this Cast Iron Support site [http://www-01.ibm.com/software/info/cloud-integration/castiron/support/] to access self-help tools and resources.'], ['1f', '* Ask questions, find answers and engage with the community via the dW Answers forum.'], ['1g', '[https://developer.ibm.com/answers/topics/castiron/] \n * Use this link [http://www.ibm.com/developerworks/views/websphere/libraryview.jsp?site_id=1&contentarea_by=WebSphere&sort_by=

Processing techqa: 6it [00:00,  6.77it/s]

section_chunks>>>>>>> [['0a', ' SECURITY BULLETIN\n\nSUMMARY\n IBM Business Process Manager is vulnerable to stored cross-site scripting, caused by improper validation of user-supplied input.'], ['0b', "A remote attacker could exploit this vulnerability using a specially-crafted URL to execute script in a victim's Web browser within the security context of the hosting Web site, once the URL is clicked."], ['0c', "An attacker could use this vulnerability to steal the victim's cookie-based authentication credentials."], ['0d', 'VULNERABILITY DETAILS\nCVEID:CVE-2015-0156 [http://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2015-0156]\n\nDESCRIPTION:IBM Business Process Manager is vulnerable to stored cross-site scripting, caused by improper validation of user-supplied input.'], ['0e', "A remote attacker could exploit this vulnerability using a specially-crafted URL to execute script in a victim's Web browser within the security context of the hosting Web site, once the URL is clicked."], ['

Processing techqa: 7it [00:01,  6.06it/s]

section_chunks>>>>>>> [['1a', ' PRODUCT DOCUMENTATION\n\nABSTRACT\n This document lists APARs and other key customer issues that are fixed in IBM® Data Studio, Version 3.1.1.'], ['1b', 'CONTENT\n \n\nFor a list of fixed APARs that are included in InfoSphere Optim Query Workload Tuner Version 3.1.1, see the following techdoc: Problems Fixed in the IBM Data Studio Clients and InfoSphere Optim Query Workload Tuner, Version 3.1.1 [http://www.ibm.com/support/docview.wss?uid=swg27024835].'], ['1c', 'For a list of fixed APARs that are included in InfoSphere Data Architect Version 8.1, see the following techdoc: Fixed APAR list for InfoSphere Data Architect, version 8.1 [http://www.ibm.com/support/docview.wss?uid=swg27023730].'], ['1d', 'The following table lists the APARs that are included IBM Data Studio Version 3.1.1:\n\nFixed APARs APAR Description IC74353 OPTIM DATABASE ADMINISTRATOR 2.2.3 DOES NOT GENERATE THE DROP INDEX STATEMENT WHEN TRYING TO ALTER IT.'], ['1e', "IC76777 WHEN HAVING -

Processing techqa: 8it [00:01,  5.96it/s]

section_chunks>>>>>>> [['1a', ' DOWNLOADABLE FILES\n\nABSTRACT\n This document describes how to download the JD Edwards data growth and application retirement portion of the IBM® InfoSphere® Optim™ Enterprise Edition for Oracle Applications 9.1.'], ['1b', 'DOWNLOAD DESCRIPTION\nYou can download the JD Edwards portion of the 9.1 Enterprise Edition for Oracle Applications using the Passport Advantage website.'], ['1c', 'The IBM InfoSphere Optim Data Growth Solution for JD Edwards EnterpriseOne supports the removal of older, infrequently accessed, data from your EnterpriseOne database and storing it where business analysts can retrieve it when necessary.'], ['1d', 'If desired, the expunged data can be accessed in realtime using the EnterpriseOne interface.'], ['1e', 'This Native Application Access allows JD Edwards users to view current and archived data seamlessly, while simultaneously allowing you to reduce the size of your production database, resulting in better performance for your d

Processing techqa: 10it [00:01,  6.82it/s]

 [['0a', 'Libery Java ; 5655S9700 R660 660 R600 600 HCI6600 R670 670 R700 700 HCI6700 5655-S97 5655S97 ; 5655Y0400 R680 680 R800 800 HCI6800 R690 690 R900 900 HCI6900 R700 700 R000 000 HCI7000 R710 710 R100 100 HCI7100 5655Y0401 5655-Y04 5655Y04 KIXINFO PRODUCT DOCUMENTATION\n\nABSTRACT\n CICS TS support is delivering the information in this document to help customers plan for updates to CICS Transaction Server for z/OS (CICS TS) and related software components such as WebSphere Liberty and CICS Explorer.'], ['0b', 'Installing APAR fixes (PTFs) using Recommended Service Upgrades (RSUs) as early as possible can help avoid problems that could result in a service call, and as long as you test appropriately, help reduce risks to your business.'], ['0c', 'CONTENT\n\n\n\n[/support/docview.wss?uid=swg27048530&amp;aid=1]Embedded components in CICS TS [http://www-01.ibm.com/support/docview.wss?uid=swg27012749#delivering]\n[/support/docview.wss?uid=swg27048530&amp;aid=2]Update recommendations\n[

Processing techqa: 11it [00:01,  6.61it/s]

section_chunks>>>>>>> [['1a', ' SECURITY BULLETIN\n\nSUMMARY\n There are vulnerabilities in IBM® Runtime Environment Java™ Technology Edition Version 7 that is used by IBM Planning Analytics Express and IBM Cognos Express.'], ['1b', 'These issues were disclosed as part of the IBM Java SDK updates in Oct 2016 and Jan 2017.'], ['1c', 'OpenSSL vulnerabilities were disclosed by the OpenSSL Project.'], ['1d', 'OpenSSL is used by IBM Planning Analytics Express and IBM Cognos Express.'], ['1e', 'The applicable CVEs have been addressed.'], ['1f', 'VULNERABILITY DETAILS\nCVEID: CVE-2016-2183 [http://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2016-2183]\nDESCRIPTION: OpenSSL could allow a remote attacker to obtain sensitive information, caused by an error in the DES/3DES cipher, used as a part of the SSL/TLS protocol.'], ['1g', 'By capturing large amounts of encrypted traffic between the SSL/TLS server and the client, a remote attacker able to conduct a man-in-the-middle attack could exploit thi

Processing techqa: 13it [00:01,  7.72it/s]

 [['1a', 'arssyscr; Report Distribution TECHNOTE (TROUBLESHOOTING)\n\nPROBLEM(ABSTRACT)\n Running arssyscr -I instance_name -r during upgrade without RDF installed or enabled can lead to errors during load.'], ['1b', 'SYMPTOM\nYou will see errors similar to (depending on database brand) \nDB Error: ORA-00942: table or view does not exist -- SQLSTATE=, SQLCODE=942, File=arsrddb.c, \nLine=951 in your systemlog\n\n\nCAUSE\nIssuing the arssyscr -I (instance name) -r command will trigger arssockd to check for the ARSDBBUNDT (RDF tables), which do not exist.'], ['1c', 'DIAGNOSING THE PROBLEM\nYou will see the "table or view does not exist" errors in your Systemlog\n\nRESOLVING THE PROBLEM\nYou will need to rename the Report Distribution tables and delete the Application, Application Group and Folders by performing the steps below:\n\nYou are going to want to change the Report Distribution table name: \n\nThis test is done on DB2, but It should be the same for (Oracle or SQL \nServer).'], ['1

Processing techqa: 15it [00:02,  7.01it/s]

section_chunks>>>>>>> [['2a', "WOMaterialStatusUpdateCronTask; TPAEWORK; TPAEINVENTORY; back order TECHNOTE (TROUBLESHOOTING)\n\nPROBLEM(ABSTRACT)\n The Work Order's status does not change if the reserved Item is issued."], ['2b', 'SYMPTOM\nA Work Order\'s status is supposed to change from "WMATL" back to "APPR" when materials have been issued to the Work Order.'], ['2c', 'RESOLVING THE PROBLEM\nSteps: \n\n1.'], ['2d', 'Go to the Organizations application.'], ['2e', 'From the More Actions menu, select the "Work Order Options" and select "Other Organization options".'], ['2f', '[/support/docview.wss?uid=swg21643029&aid=1] [/support/docview.wss?uid=swg21643029&aid=1]\nDe-select the "Ignore storeroom availability for work order status?"'], ['2g', 'checkbox.'], ['2h', '[/support/docview.wss?uid=swg21643029&amp;aid=2]\n\n\nClick on the "More Action" menu, select "Inventory Options" and click on the "Inventory Defaults" option.'], ['2i', '[/support/docview.wss?uid=swg21643029&amp;aid=3]\nSel




modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/123 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/28.8k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/731 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Generating embeddings:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

Generating embeddings: 100%|██████████| 1/1 [02:51<00:00, 171.40s/it]


## **Store Embeddings into Milvus**

In [8]:
!pip install pymilvus pymilvus[model]

Collecting pymilvus
  Downloading pymilvus-2.5.4-py3-none-any.whl.metadata (5.7 kB)
Collecting grpcio<=1.67.1,>=1.49.1 (from pymilvus)
  Downloading grpcio-1.67.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)
Collecting ujson>=2.0.0 (from pymilvus)
  Downloading ujson-5.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.3 kB)
Collecting milvus-lite>=2.4.0 (from pymilvus)
  Downloading milvus_lite-2.4.11-py3-none-manylinux2014_x86_64.whl.metadata (9.2 kB)
Collecting milvus-model>=0.1.0 (from pymilvus[model])
  Downloading milvus_model-0.2.12-py3-none-any.whl.metadata (1.6 kB)
Collecting onnxruntime (from milvus-model>=0.1.0->pymilvus[model])
  Downloading onnxruntime-1.20.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting coloredlogs (from onnxruntime->milvus-model>=0.1.0->pymilvus[model])
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (fr

# **Functions to check uniqueness of data being inserted to db**

In [9]:
import hashlib

# Function to generate a hash based on content and key metadata
def generate_hash(content, metadata):
    """Generate a unique hash for the document content and key metadata."""
    key_fields = f"{content}|{metadata.get('item_index')}|{metadata.get('prefix')}"
    return hashlib.md5(key_fields.encode('utf-8')).hexdigest()

# Function to retrieve existing hashes from the database
def get_existing_hashes(collection):
    """Retrieve all existing hashes (IDs) currently in the database."""
    all_records = collection.get(include=["documents", "metadatas"])  # Fetch documents and metadata
    existing_hashes = set()
    for doc, metadata in zip(all_records["documents"], all_records["metadatas"]):
        doc_hash = generate_hash(doc, metadata)
        existing_hashes.add(doc_hash)
    return existing_hashes

# Function to retrieve existing hashes from the database
def get_existing_hashes_milvus(all_records):
    """Retrieve all existing hashes (IDs) currently in the database."""
    existing_hashes = set()
    print(f"all records >>> {len(all_records)}")
    if all_records == None or len(all_records) == 0:
        return existing_hashes

    for record in all_records:
        doc = record.get("documents")
        metadata = record.get("metadata")
        doc_hash = generate_hash(doc, metadata)
        existing_hashes.add(doc_hash)
    return existing_hashes

In [10]:
import numpy as np
from pymilvus import connections
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
from pymilvus import MilvusClient
from pymilvus import utility

class VectorDataStore:
    db_url = "http://localhost:19530"

    #description = f"collection created for {self.name}"

    def __init__(self, path="/content/ragbench.db"):
        self.client = MilvusClient(path)


    def get_or_create_collection(self, name, vec_dim=128):
        try:
            self.get_collection(name)
        except:
            print(f"Collection {name} doesn't exist. Creating...")
            self.create_collection(name, vec_dim)


    def create_collection(self, name, vec_dim=128):
        if self.client.has_collection(name):
            self.default_collection_name = name

        self.description = f"collection to store {name}"

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="embedding",
            index_type="AUTOINDEX",
            metric_type="COSINE"
        )
        index_params.add_index(
            field_name="sparse",
            index_type="SPARSE_INVERTED_INDEX",
            metric_type="IP"
        )
        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_fields=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.VARCHAR, max_length=64, is_primary=True)
        schema.add_field(field_name="metadata", datatype=DataType.JSON)
        schema.add_field(field_name="documents", datatype=DataType.VARCHAR, max_length=512)
        schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
        schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=vec_dim)

        collection = self.client.create_collection(collection_name=name,
                                       schema=schema,
                                       index_params=index_params)
        self.current_collection = collection
        return collection


    def get_collection(self, name):
        if not self.client.has_collection(name):
            raise ValueError(f"Collection '{name}' does not exist.")
        self.current_collection = Collection(name)
        return self.current_collection

    def get_all_records(self, collection):
        all_records = self.client.query(
            collection_name=collection,
            filter=None,
            output_fields=["documents", "metadata"],
            limit=10000
        )
        if all_records == None:
            all_records = []

        return all_records

    def has_entities(self, name):
        if not self.client.has_collection(name):
            raise ValueError(f"Collection '{name}' does not exists.")
        self.default_collection = name
        collection_stats = self.client.get_collection_stats(collection_name)
        count = collection_stats.get("row_count", 0)  # Retrieve the number of entities
        return count

    def insert(self, collection_name: str, metadata: list[dict[str, any]],
                documents: list[str], sparse_embs: np.ndarray, embeddings: np.ndarray, ids: list[int]):

        if not self.client.has_collection(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist. Create it first.")

        if len(metadata) != len(embeddings) != len(documents) != len(ids):
           raise ValueError("Metadata, documnets, ids and embeddings must have the same length.")

        data = []
        for meta, doc, sp_embs, emb, id in zip(metadata, documents, sparse_embs, embeddings, ids):
          datum = {
              "pk": id,
              "metadata": meta,
              "documents": doc,
              "sparse": sp_embs,
              "embedding": emb.tolist(),
          }
          data.append(datum)

        self.client.insert(collection_name, data)
        print(f"Inserted {len(metadata)} records into collection '{collection_name}'.")

    def drop_collection(self, collection_name: str):
        if not self.client.has_collection(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist.")
        self.client.drop_collection(collection_name)
        print(f"Dropped collection '{collection_name}'.")

    def delete_all(self, collection_name: str):
        if not self.client.has_collection(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist.")
        self.client.delete(collection_name, expr="pk >= 0")
        self.client.flush([collection_name])

    def search(self, query_embedding: np.ndarray, top_k: int = 10) -> list[dict[str, any]]:
        """
        Search across all collections for the top-k closest embeddings.
        :param query_embedding: The embedding vector to search for.
        :param top_k: Number of top results to retrieve.
        :return: A list of dictionaries containing collection name, id, metadata, and distance.
        """
        results = []
        collections = self.client.list_collections()
        # collections = ["ragbench_collection_techqa_v09"]
        start_time = time.time()
        for collection_name in collections:
            if not self.client.has_collection(collection_name):
                continue

            # Set params to COSINE to match chromadb
            search_params = {"metric_type": "COSINE", "params": {"ef": 128}}

            search_results = self.client.search(
                collection_name=collection_name,
                data=[query_embedding],
                anns_field="embedding",
                search_params=search_params,
                limit=top_k,
                output_fields=["metadata", "documents"]
            )

            print(f"search results size : {len(search_results)}")

            for hits in search_results:
                for hit in hits:
                    print(f"Collection: {collection_name}, data: {str(hit)}")
                    results.append({
                        "collection": collection_name,
                        "id": hit["id"],
                        "metadata": hit["entity"]["metadata"],
                        "distance": hit["distance"],
                        "documents": hit["entity"]["documents"]
                      })

        results = sorted(results, key=lambda x: x["distance"])[:top_k]
        end_time = time.time()
        print(f"Search completed. Found {len(results)} results. in {end_time - start_time} secs")
        return results

    def sparse_search(self, query_embedding: np.ndarray, top_k : int=10)-> list[dict[str, any]] :
        results = []
        collections = self.client.list_collections()
        start_time = time.time()
        for collection_name in collections:
            if not self.client.has_collection(collection_name):
                continue

            # Set params to COSINE to match chromadb
            search_params = {"metric_type": "IP", "params": {"ef": 128}}

            search_results = self.client.search(
                collection_name=collection_name,
                data=[query_embedding],
                anns_field="sparse",
                search_params=search_params,
                limit=top_k,
                output_fields=["metadata", "documents"]
            )

            print(f"search results size : {len(search_results)}")

            for hits in search_results:
                for hit in hits:
                    print(f"Collection: {collection_name}, data: {str(hit)}")
                    results.append({
                        "collection": collection_name,
                        "id": hit["id"],
                        "metadata": hit["entity"]["metadata"],
                        "distance": hit["distance"],
                        "documents": hit["entity"]["documents"]
                      })

        results = sorted(results, key=lambda x: x["distance"])[:top_k]
        end_time = time.time()
        print(f"Search completed. Found {len(results)} results. in {end_time - start_time} secs")
        return results

    def hybrid_search(self, sparse_query_embedding: np.ndarray, dense_query_embedding: np.ndarray, top_k : int=10, alpha=0.3)-> list[dict[str, any]] :
        results = []
        collections = self.client.list_collections()
        start_time = time.time()
        sparse_results = self.sparse_search(sparse_query_embedding, top_k)
        n = int(len(sparse_results) * alpha)
        alpha_sparse_results = sparse_results[:n]
        dense_results = self.search(dense_query_embedding, top_k)
        #'results = sorted(results, key=lambda x: x["distance"])[:top_k]
        results = dense_results + alpha_sparse_results
        end_time = time.time()
        print(f"Hybrid Search completed. Found {len(results)} results. in {end_time - start_time} secs")
        return results

    def extract_documents(self, search_results: list[dict[str, any]]) -> list[np.ndarray]:
      """
      Extract embedding values from search results.
      :param search_results: List of dictionaries containing search results.
      :return: List of embedding vectors as NumPy arrays.
      """
      return [result["documents"] for result in search_results if "documents" in result]

## **Instantiate Milvus and add data to milvus db**

In [11]:
collection_name = "ragbench_collection_techqa_v01"

In [12]:
from sentence_transformers import SentenceTransformer
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer

# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained(embedding_model)
embedder = SentenceTransformer(embedding_model)  # Pretrained sentence transformer

In [13]:
datastor = VectorDataStore()

insert_data = False
store_client = "Milvus"
num_records = 0

vector_dim = embedder.get_sentence_embedding_dimension()

datastor.get_or_create_collection(collection_name, vector_dim)
num_records = datastor.has_entities(collection_name)
if num_records == 0:
    insert_data = True

print(f"count >>> {num_records} insert_data >>> {insert_data}")

Collection ragbench_collection_techqa_v01 doesn't exist. Creating...
count >>> 0 insert_data >>> True


# **Store Embeddings into a chroma DB and Milvus**

In [None]:
!pip install chromadb

Collecting chromadb
  Downloading chromadb-0.6.3-py3-none-any.whl.metadata (6.8 kB)
Collecting build>=1.0.3 (from chromadb)
  Downloading build-1.2.2.post1-py3-none-any.whl.metadata (6.5 kB)
Collecting chroma-hnswlib==0.7.6 (from chromadb)
  Downloading chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (252 bytes)
Collecting fastapi>=0.95.2 (from chromadb)
  Downloading fastapi-0.115.6-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn>=0.18.3 (from uvicorn[standard]>=0.18.3->chromadb)
  Downloading uvicorn-0.34.0-py3-none-any.whl.metadata (6.5 kB)
Collecting posthog>=2.4.0 (from chromadb)
  Downloading posthog-3.8.4-py2.py3-none-any.whl.metadata (2.8 kB)
Collecting opentelemetry-exporter-otlp-proto-grpc>=1.2.0 (from chromadb)
  Downloading opentelemetry_exporter_otlp_proto_grpc-1.29.0-py3-none-any.whl.metadata (2.2 kB)
Collecting opentelemetry-instrumentation-fastapi>=0.41b0 (from chromadb)
  Downloading opentelemetry_instrumentation_fastapi-0.

In [None]:
import chromadb

client = chromadb.PersistentClient(path="./content/rag_chroma_db_d4")

# chromba_db_collection_name = "ragbench_chroma_db_collection_techqa_v21"
chroma_db_collection = client.get_or_create_collection(name=collection_name)

count = chroma_db_collection.count()
print("count >>>", count)
store_client = "Chromadb"

if count > 0:
    insert_data = False
else:
    insert_data = True


count >>> 0


## **Insert data into database**
### the store_client value stores if the code should call milvus or chromadb code
#### if insert_data is set to true it means there is no data in the collection


In [None]:
# Adding data to ChromaDB with enhanced duplicate check
existing_hashes = get_existing_hashes(chroma_db_collection)

for i in tqdm(range(0, len(all_documents), batch_size), desc="Adding data to DB"):
    batch_embeddings = all_embeddings[i:i + batch_size]
    batch_metadatas = all_metadatas[i:i + batch_size]
    batch_documents = all_documents[i:i + batch_size]
    batch_ids = []

    # Generate hashes for each document in the batch
    for doc, metadata in zip(batch_documents, batch_metadatas):
        doc_hash = generate_hash(doc, metadata)
        if doc_hash not in existing_hashes:
            batch_ids.append(doc_hash)
            existing_hashes.add(doc_hash)  # Add hash to local set to avoid duplicates in the same batch
        else:
            print(f"Skipping duplicate document: {doc[:50]}...")  # Print a preview of the duplicate doc

    # Add non-duplicate documents to the database
    if batch_ids:  # Ensure there are non-duplicate documents to add
        print(f"Adding {len(batch_ids)} documents to the database... {store_client} with {insert_data}")
        if store_client == "Chromadb" and insert_data:
            chroma_db_collection.add(
                embeddings=batch_embeddings[:len(batch_ids)],  # Trim embeddings to match batch_ids
                metadatas=batch_metadatas[:len(batch_ids)],    # Trim metadatas to match batch_ids
                documents=batch_documents[:len(batch_ids)],    # Trim documents to match batch_ids
                ids=batch_ids
            )


NameError: name 'chroma_db_collection' is not defined

## **Hybrid search implementation** ##
### Sparse vector generation ##

In [20]:
from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize
import scipy.sparse as sp

# Tokenize corpus and prepare BM25 encoder
def prepare_bm25_encoder(texts):
    tokenized_corpus = [word_tokenize(text.lower()) for text in texts]
    bm25_encoder = BM25Okapi(tokenized_corpus)
    return bm25_encoder

def generate_sparse_vector_bm25(query, bm25_encoder):
    tokenized_query = word_tokenize(query.lower())
    scores = bm25_encoder.get_scores(tokenized_query)
    # Convert scores to CSR format
    sparse_vector = sp.csr_matrix(scores)
    return sparse_vector

In [21]:
bm25_encoder = prepare_bm25_encoder(all_documents)

sparse_vectors = [generate_sparse_vector_bm25(text, bm25_encoder) for text in all_documents]

In [22]:
len(sparse_vectors)

507

# **Insert data to milvus**

In [23]:
store_client = "Milvus"

In [24]:
# Adding data to milvus with enhanced duplicate check
all_recs = datastor.get_all_records(collection_name)
#print(f"sample: {str(all_recs[0])}")
existing_hashes = get_existing_hashes_milvus(all_recs)

for i in tqdm(range(0, len(all_documents), batch_size), desc="Adding data to DB"):
    batch_embeddings = all_embeddings[i:i + batch_size]
    batch_sparse_embs = sparse_vectors[i:i + batch_size]
    batch_metadatas = all_metadatas[i:i + batch_size]
    batch_documents = all_documents[i:i + batch_size]
    batch_ids = []

    # Generate hashes for each document in the batch
    for doc, metadata in zip(batch_documents, batch_metadatas):
        doc_hash = generate_hash(doc, metadata)
        if doc_hash not in existing_hashes:
            batch_ids.append(doc_hash)
            existing_hashes.add(doc_hash)  # Add hash to local set to avoid duplicates in the same batch
        else:
            print(f"Skipping duplicate document: {doc[:15]}...")  # Print a preview of the duplicate doc

    # Add non-duplicate documents to the database
    if batch_ids:  # Ensure there are non-duplicate documents to add
        # Add the batch to the Milvus collection
        if store_client == "Milvus" and insert_data:
            datastor.insert(collection_name,
                metadata=batch_metadatas,
                documents=batch_documents,
                sparse_embs = np.array(batch_sparse_embs),
                embeddings=np.array(batch_embeddings),
                ids=batch_ids
            )



print(f"total data inserted into {collection_name} : {datastor.has_entities(collection_name)}")

all records >>> 507


Adding data to DB: 100%|██████████| 1/1 [00:00<00:00, 23.64it/s]

Skipping duplicate document: [0a] HL083112 m...
Skipping duplicate document: [0b] Your WebSp...
Skipping duplicate document: [0c] SYMPTOM
Th...
Skipping duplicate document: [0d] RESOLVING ...
Skipping duplicate document: [0e] PRODUCT AL...
Skipping duplicate document: [1f] ERROR DESC...
Skipping duplicate document: [1g] In this ca...
Skipping duplicate document: [1h] If the mod...
Skipping duplicate document: [1i] LOCAL FIX
...
Skipping duplicate document: [1j] PROBLEM SU...
Skipping duplicate document: [1k] *
   *****...
Skipping duplicate document: [1l] *
   *****...
Skipping duplicate document: [2m] 13366 The ...
Skipping duplicate document: [2n] If no inst...
Skipping duplicate document: [2o] 14347 Timi...
Skipping duplicate document: [2p] 14487 Aben...
Skipping duplicate document: [2q] 14591Runni...
Skipping duplicate document: [2r] 14820 Pack...
Skipping duplicate document: [2s] However, o...
Skipping duplicate document: [2t] For cases ...
Skipping duplicate document: [2u] 15149 




## **Verifying retrival logic for the Relevant documents**

In [19]:
# Retrieve from Chroma DB
import time

question ="Why does the other instance of my multi-instance qmgr seem to hang after a failover? Queue manager will not start after failover"
query_embedding = embedder.encode(question).tolist()

# Search for relevant chunks in the vector database
# Retrieve from ChromaDB
start_time = time.time()
results_chroma_db = chroma_db_collection.query(query_embeddings=[query_embedding], n_results=5)
end_time = time.time()
print(f"ChromaDB - Search completed. Found {len(results_chroma_db)} results. in {end_time - start_time} secs")

for idx, (doc, doc_id, metadata) in enumerate(zip(results_chroma_db["documents"][0], results_chroma_db["ids"][0], results_chroma_db["metadatas"][0])):
    print(f"Result {idx + 1}:")
    print(f"  Document: {doc}")
    print(f"  ID: {doc_id}")
    print(f"  Metadata: {metadata}")
    print()


NameError: name 'chroma_db_collection' is not defined

***Desnse Search from Milvus***

In [25]:
# Dense search, Retrive from milvus
import time

question ="Why does the other instance of my multi-instance qmgr seem to hang after a failover? Queue manager will not start after failover"
query_embedding = embedder.encode(question).tolist()

# Retrieve from Milvus
start_time = time.time()
results_milvus = datastor.search(query_embedding, top_k=5)
end_time = time.time()
print(f"Search completed. Found {len(results_milvus)} results. in {end_time - start_time} secs")

for doc in results_milvus:
    print("Relevant Docs from Milvus:\n", doc['documents'])

search results size : 1
Collection: ragbench_collection_techqa_v01, data: {'id': '3683bbefd68a3161c00c001f3e58b2d2', 'distance': 0.8853264451026917, 'entity': {'documents': '[0b] Your WebSphere MQ queue manager fails to come up on the secondary node, and errors are generated.', 'metadata': {'dataset': 'techqa', 'global_index': 0, 'section_index': 0, 'item_index': 1, 'prefix': '0b', 'type': 'Passage'}}}
Collection: ragbench_collection_techqa_v01, data: {'id': '5bb2a66f88c5070dd0f5d37785e3a748', 'distance': 0.8725427389144897, 'entity': {'documents': '[0d] RESOLVING THE PROBLEM\nRename the file amqalchk.fil, which is found under mq\\qmgrs\\qmgrname\\ on the shared drive (to something like amqalchk.fil_OLD); then restart the queue manager.', 'metadata': {'dataset': 'techqa', 'global_index': 0, 'section_index': 0, 'item_index': 3, 'prefix': '0d', 'type': 'Passage'}}}
Collection: ragbench_collection_techqa_v01, data: {'id': 'c2b5c816355deb079e4290bd60b23860', 'distance': 0.8629571199417114,

***Sparse Search from Milvus***

In [26]:
import time

question ="Why does the other instance of my multi-instance qmgr seem to hang after a failover? Queue manager will not start after failover"
query_embedding = generate_sparse_vector_bm25(question, bm25_encoder)

# Retrieve from Milvus
start_time = time.time()
results_milvus = datastor.sparse_search(query_embedding, top_k=5)
end_time = time.time()
print(f"Search completed. Found {len(results_milvus)} results. in {end_time - start_time} secs")

for doc in results_milvus:
    print("Relevant Docs from Milvus:\n", doc['documents'])

search results size : 1
Collection: ragbench_collection_techqa_v01, data: {'id': '0f6c1e3a33d02e3055f8298ce8d6143f', 'distance': 189942.765625, 'entity': {'documents': '[1ad] LOCAL FIX\n\nPROBLEM SUMMARY\n *  ****************************************************************\n   * USERS AFFECTED:                                              *\n   * P8 FEM users                                                 *\n   ****************************************************************\n   * PROBLEM DESCRIPTION:                                         *\n   * Adding a document in FEM with date/time for midnight and it  *\n   *                                                              *\n   * shows up in FEM with 1 hour ahead                            *\n   ****************************************************************\n   * RECOMMENDATION:                                              *\n   * Upgrade FEM and CE server to P8CE-4.0.1-011                  *\n   ********************************

**Hybrid Search from Milvus**

In [27]:
question ="Why does the other instance of my multi-instance qmgr seem to hang after a failover? Queue manager will not start after failover"
# Generate the dense embedding for the question
dense_query_embedding = embedder.encode(question).tolist()

# Generate the sparse embedding for the question
sparse_query_embedding = generate_sparse_vector_bm25(question, bm25_encoder)

# Retrieve from Milvus, passing both dense and sparse embeddings

# Retrieve from Milvus
start_time = time.time()
results_milvus = datastor.hybrid_search(sparse_query_embedding=sparse_query_embedding, dense_query_embedding=dense_query_embedding, top_k=5)
end_time = time.time()
print(f"Hybrid Search completed. Found {len(results_milvus)} results. in {end_time - start_time} secs")

for doc in results_milvus:
    print("Relevant Docs from Milvus:\n", doc['documents'])

search results size : 1
Collection: ragbench_collection_techqa_v01, data: {'id': '0f6c1e3a33d02e3055f8298ce8d6143f', 'distance': 189942.765625, 'entity': {'documents': '[1ad] LOCAL FIX\n\nPROBLEM SUMMARY\n *  ****************************************************************\n   * USERS AFFECTED:                                              *\n   * P8 FEM users                                                 *\n   ****************************************************************\n   * PROBLEM DESCRIPTION:                                         *\n   * Adding a document in FEM with date/time for midnight and it  *\n   *                                                              *\n   * shows up in FEM with 1 hour ahead                            *\n   ****************************************************************\n   * RECOMMENDATION:                                              *\n   * Upgrade FEM and CE server to P8CE-4.0.1-011                  *\n   ********************************

## **Retrival of Relevant Chunks**

In [37]:
# Function to retrieve relevant chunks
def retrieve_docs_milvus(query, top_k=10):
    # Generate embedding for the query
    # Generate the dense embedding for the question
    start_time = time.time()
    dense_query_embedding = embedder.encode(question).tolist()

    # Generate the sparse embedding for the question
    #sparse_query_embedding = generate_sparse_vector_bm25(question, bm25_encoder)

    # Perform vector search to find relevant chunks
    #results = datastor.extract_documents(datastor.search(query_embedding, top_k))
    results = datastor.search( dense_query_embedding, top_k)
    #results = datastor.hybrid_search(sparse_query_embedding, dense_query_embedding, top_k)
    print(f"results: retrieve_docs_milvus >>>  {results}")

    # HyDE search with pseudo document
    #pseudo_docs = fetch_docs_pseudo(query)
    #print(f"pseudo_docs: retrieve_docs_milvus >>>  {pseudo_docs}")

    # Extract 'documents' field
    documents_list = [item['documents'] for item in results]
    #documents_list += pseudo_docs

    end_time = time.time()
    print(f"Hybrid Search completed. Found {len(results)} results (approx. 2 * {top_k} + alpha * {top_k}). in {end_time - start_time} secs")
    # Print the extracted documents
    print("retrieve_docs_milvus >>> documents_list from Hybrid + HyDE search >>>>", documents_list)

    # Extract the retrieved chunks
    # chunks = documents_list
    # should sort and push context - but later

    return documents_list

In [29]:
# Function to retrieve relevant chunks
def retrieve_docs(query, top_k=5):
    # Generate embedding for the query
    query_embedding = embedder.encode(query).tolist()
    # Perform vector search to find relevant chunks
    results = chroma_db_collection.query(query_embeddings=[query_embedding], n_results=top_k)
    # Extract the retrieved chunks
    chunks = results["documents"]
    # should sort and push context - but later
    return chunks

In [30]:
def retrieve_docs_query(query, top_k=5):
    query_embedding = embedder.encode(query)

    if store_client == 'Milvus':
        results = datastor.search(query_embedding, top_k)
        results = datastor.extract_documents(results)
    elif store_client == "Chromadb":
        results = chroma_db_collection.query(query_embeddings=[query_embedding.tolist()], n_results=top_k)
        results = results['documents']
    return results

## **Query Classification**

In [None]:
!pip install transformers datasets torch



In [None]:
# Get the databricks dolly dataset (as mentioned in  "Searching for Best Practices in Retrieval-Augmented Generation" paper) for training the query classifer.
from datasets import load_dataset

ds = load_dataset("databricks/databricks-dolly-15k")

README.md:   0%|          | 0.00/8.20k [00:00<?, ?B/s]

databricks-dolly-15k.jsonl:   0%|          | 0.00/13.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/15011 [00:00<?, ? examples/s]

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset
import torch

# Step 1: Load Pre-trained BERT and Tokenizer
model_name = "bert-base-multilingual-cased"
#model_name = "distilbert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Step 2: Prepare the Dataset
# Replace with your labeled data
# Example data: {"query": ["What is COVID19?", "What do the S1 and S2 subunits contain?"], "label": [0, 1]}

data = {
    "query": [
        "What is SQLCODE=-1585",
        "The configuration task database-transfer failed with DB2 SQL Error: SQLCODE=-1585, SQLSTATE=54048 While attempting to run the database-transfer task the following error is logged to the ConfigTrace.log: action-process-constraints: Fri Oct 10 13:20:34 CDT 2014 Target started: action-process-constraints [java] Executing java with empty input string [java] [10/10/14 13:20:35.877 CDT] Attempting to create a new Instance of com.ibm.db2.jcc.DB2Driver [java] [10/10/14 13:20:36.016 CDT] Instance of com.ibm.db2.jcc.DB2Driver created successfully [java] [10/10/14 13:20:36.016 CDT] Attempting to make connection using: jdbc:db2://:60500/:returnAlias=0; :: d2svc :: PASSWORD_REMOVED [java] [10/10/14 13:20:36.954 CDT] Connection successfully made [java] [10/10/14 13:20:37.073 CDT] ERROR: Error occurred gathering data from the source database [java] com.ibm.db2.jcc.am.SqlException: DB2 SQL Error: SQLCODE=-1585, SQLSTATE=54048, SQLERRMC=null, DRIVER=4.18.60 [java] at com.ibm.db2.jcc.am.kd.a(kd.java:752)",
        "What is Websphere Application Server",
        "How do I change from shared to unshared connection? in WAS, how do I change from shared to unshared connection. I am seeing connections max out and take a long time to release.",
        "What is JMS",
        "Why is my MQ Java / JMS application getting 2035 NOT_AUTHORIZED error after upgrade of MQ? Why is my MQ Java / JMS application getting 2035 NOT_AUTHORIZED error after upgrade of MQ?",
        "What is TLS",
        "TLS protocol with ITCAM for Datapower We have a DataPower appliance with TLS security protocol enabled. Can we configure ITCAM for DataPower appliance v7.1 to specifically use the TLS protocol v1.2 (not v1.0)?"

    ],
    "label": [0, 1, 0, 1, 0, 1, 0, 1]  # 0: retrieval not needed, 1: retrieval needed
}

# Convert to Hugging Face Dataset
dataset = Dataset.from_dict(data)

#dataset = ds
# Tokenize the Dataset
def preprocess(data):
    return tokenizer(data["query"], padding="max_length", truncation=True, max_length=128)

encoded_dataset = dataset.map(preprocess, batched=True)
encoded_dataset = encoded_dataset.train_test_split(test_size=0.2)

# Step 3: Define the Training Loop
training_args = TrainingArguments(
    output_dir="./query_classifier_results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="none"
)

# Custom Metric for Accuracy
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = torch.argmax(torch.tensor(logits), dim=-1)
    accuracy = (predictions == torch.tensor(labels)).float().mean().item()
    return {"accuracy": accuracy}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

# Step 4: Train the Model
trainer.train()

# Step 5: Save the Model
model.save_pretrained("./query_classifier")
tokenizer.save_pretrained("./query_classifier")

# Step 6: Inference Function
def classify_query(query, model_path="./query_classifier"):
    # Load saved model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    # Tokenize the input query
    inputs = tokenizer(query, return_tensors="pt", padding="max_length", truncation=True, max_length=128)

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = torch.argmax(logits, dim=1).item()

    return predicted_class


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/8 [00:00<?, ? examples/s]

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.660873,0.5
2,No log,0.629771,0.5
3,No log,0.597554,1.0
4,No log,0.566563,1.0
5,No log,0.537317,1.0
6,No log,0.510272,1.0
7,No log,0.48781,1.0
8,No log,0.469979,1.0
9,No log,0.458428,1.0
10,0.560600,0.452327,1.0


**Install groq**

In [None]:
! pip install groq
! pip install -q langchain langchain-groq

**Use LLM for classifying if a query needs RAG or not (instead of using the above classifer)**

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq

def query_llm_classification(query: str, data: dict):
    """
    Classify whether a query requires retrieval using a few-shot learning approach with Groq.
    """

    # Groq Chat configuration
    chat = ChatGroq(
        temperature=0.3,
        groq_api_key="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",
        model_name="llama3-8b-8192"
    )

    # Few-shot prompt construction
    prompt = ChatPromptTemplate.from_template(
        """
        Decide if the input query needs retrieval for additional context.
        Follow the examples below:

        {few_shot_examples}

        Query: {query}
        Decision:
        """
    )

    # Construct few-shot examples from the provided data
    few_shot_examples = ""
    for q, label in zip(data["query"], data["label"]):
        decision = "retrieval required" if label == 1 else "retrieval not required"
        few_shot_examples += f"Query: {q}\nDecision: {decision}\n\n"

    # Format the input to the Groq model
    input_data = {
        "few_shot_examples": few_shot_examples.strip(),
        "query": query
    }

    # Generate the response using the chain
    chain = prompt | chat
    groq_response = chain.invoke(input_data)

    # Extract and return the decision
    decision = groq_response+"".strip()
    return decision

data = {
    "query": [
        "What is SQLCODE=-1585",
        "The configuration task database-transfer failed with DB2 SQL Error: SQLCODE=-1585...",
        "What is Websphere Application Server",
        "How do I change from shared to unshared connection?",
        "What is JMS",
        "Why is my MQ Java / JMS application getting 2035 NOT_AUTHORIZED error after upgrade?",
        "What is TLS",
        "TLS protocol with ITCAM for Datapower..."
    ],
    "label": [0, 1, 0, 1, 0, 1, 0, 1]
}

#new_query = "How can I resolve SQLCODE=-1585 in DB2?"
new_query = "What is Java Messaging Service?"
decision = query_llm_classification(new_query, data)

print(f"Query: {new_query}\nDecision: {decision}")


Query: What is Java Messaging Service?
Decision: input_variables=[] input_types={} partial_variables={} messages=[AIMessage(content='Decision: retrieval not required', additional_kwargs={}, response_metadata={'token_usage': {'completion_tokens': 6, 'prompt_tokens': 187, 'total_tokens': 193, 'completion_time': 0.005, 'prompt_time': 0.033584515, 'queue_time': 0.01802098, 'total_time': 0.038584515}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_a97cfe35ae', 'finish_reason': 'stop', 'logprobs': None}, id='run-fb3873f0-e0dc-48a0-843e-9ab58dc647c2-0', usage_metadata={'input_tokens': 187, 'output_tokens': 6, 'total_tokens': 193}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template=''), additional_kwargs={})]


# **Use cross encoder for re-ranking**

In [44]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import pandas as pd

# Load Cross-Encoder model and tokenizer
cross_encoder_model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
tokenizer = AutoTokenizer.from_pretrained(cross_encoder_model_name)
model = AutoModelForSequenceClassification.from_pretrained(cross_encoder_model_name)

# Ensure PyTorch is in evaluation mode
model.eval()

def rerank_with_cross_encoder(query):
    """
    Rerank documents based on relevance scores from a Cross-Encoder model.

    Args:
        query (str): The query string.
        documents (list): A list of document strings.

    Returns:
        list: A list of tuples (document, score), sorted by score in descending order.
    """
    query_embedding = embedder.encode(question)
    results = datastor.search(query_embedding, 5)
    documents = datastor.extract_documents(results)

    scores = []
    for doc in documents:
        # Tokenize query-document pair
        inputs = tokenizer(
            query,
            doc,
            return_tensors="pt",
            max_length=512,  # Limit for most transformer models
            truncation=True,
            padding="max_length",
        )
        # Compute relevance scores
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

            # Handle binary classification or regression logits
            if logits.size(1) == 2:  # Binary classification
                score = torch.softmax(logits, dim=1)[:, 1].item()  # Probability of relevance (class 1)
            else:  # Regression or single-class output
                score = logits.squeeze().item()  # Direct score (e.g., relevance regression)

            scores.append((doc, score))
    # Sort by score in descending order
    return sorted(scores, key=lambda x: x[1], reverse=True)



## **Presort of results with monoT5 model** ##

In [32]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5Tokenizer, T5ForConditionalGeneration
import torch.nn.functional as F
import torch

# Load MonoT5 model and tokenizer
monot5_model_name = "castorini/monot5-base-msmarco"  # or a suitable MonoT5 variant
#monot5_tokenizer = AutoTokenizer.from_pretrained(monot5_model_name)
#monot5_model = AutoModelForSeq2SeqLM.from_pretrained(monot5_model_name)

monot5_tokenizer = T5Tokenizer.from_pretrained(monot5_model_name)
monot5_model = T5ForConditionalGeneration.from_pretrained(monot5_model_name)


tokenizer_config.json:   0%|          | 0.00/1.89k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

In [33]:
pip install python-terrier

Collecting python-terrier
  Downloading python_terrier-0.13.0-py3-none-any.whl.metadata (11 kB)
Collecting ir-datasets>=0.3.2 (from python-terrier)
  Downloading ir_datasets-0.5.9-py3-none-any.whl.metadata (12 kB)
Collecting wget (from python-terrier)
  Downloading wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyjnius>=1.4.2 (from python-terrier)
  Downloading pyjnius-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting ir-measures>=0.3.1 (from python-terrier)
  Downloading ir_measures-0.3.6-py3-none-any.whl.metadata (7.0 kB)
Collecting pytrec-eval-terrier>=0.5.3 (from python-terrier)
  Downloading pytrec_eval_terrier-0.5.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (777 bytes)
Collecting chest (from python-terrier)
  Downloading chest-0.2.3.tar.gz (9.6 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting lz4 (from python-terrier)
  Downloading lz4-4.3.3-cp311-cp311-ma

In [34]:
pip install pyterrier_t5

Collecting pyterrier_t5
  Downloading pyterrier_t5-0.1.0-py3-none-any.whl.metadata (4.4 kB)
Downloading pyterrier_t5-0.1.0-py3-none-any.whl (13 kB)
Installing collected packages: pyterrier_t5
Successfully installed pyterrier_t5-0.1.0


In [38]:
import pyterrier as pt
from pyterrier_t5 import MonoT5ReRanker
# Initialize PyTerrier (if not already initialized)
if not pt.started():
    pt.init()

def transform_to_tuples(doc_list):
    transformed = []
    for doc in doc_list:
        transformed.append((doc,))  # Wrap each document in a tuple
    return transformed


def rerank_with_monot5(question):
  # Initialize MonoT5 reranker
  monot5 = MonoT5ReRanker()

  query_embedding = embedder.encode(question)

  db_docs = datastor.search(query_embedding, 5)

  df_docs = pd.DataFrame({"text": db_docs})
  df_docs["qid"] = 1  # Query ID
  df_docs["query"] = question  # Add query column

  # Apply MonoT5 reranking
  reranked_docs = monot5.transform(df_docs)

  # Sort documents by rank (ascending)
  reranked_docs_sorted = reranked_docs.sort_values(by="score", ascending=False)
  print("reranked_docs_sorted>>>",reranked_docs_sorted)
  monot5_doc_list=[]
  # Print reranked and sorted documents
  for _, doc in reranked_docs_sorted.iterrows():
      print(f"Rank: {doc['score']}, Doc: {doc['text']}")
      monot5_doc_list.append(doc['text']['documents'])
  monot5_doc_list = transform_to_tuples(monot5_doc_list)
  return monot5_doc_list

  if not pt.started():


 ### function to sort data ###   

In [None]:
def sort_documents(query, documents, top_n=5):

  # Prepare inputs
  inputs = [f"Query: {query} Document: {doc}" for doc in documents]

  # Tokenize inputs
  tokenized_inputs = monot5_tokenizer(inputs, padding=True, truncation=True, return_tensors="pt")

  # Get model predictions
  outputs = monot5_model.generate(**tokenized_inputs)

  # Decode outputs to get relevance scores (e.g., 'true' or 'false')
  predictions = [monot5_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

  # Convert predictions to scores
  scores = [1.0 if pred.lower() == "true" else 0.0 for pred in predictions]

  # Rank documents
  ranked_docs = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)

  # Output ranked documents
  results = []
  for i, (doc, score) in enumerate(ranked_docs, start=1):
    print(f"debug: Rank {i}: {doc} (Score: {score})")
    if score == 1.0:
      if isinstance(doc, str):
        results.append(doc)
      else:
        results.append(doc['documents'])

  print(f"debug: results: {results}")

  return results[:top_n]

### Test code ###

In [39]:
'''
# call sort documents

print(f"quesiton: {question}")
db_docs = datastor.extract_documents(results_milvus)
print(f"documents: {db_docs}")

result_docs = sort_documents(question, db_docs)
print("result_docs >>>",result_docs)
'''
# Test with monot5 re-ranker
monot5_reranked_doc_list = rerank_with_monot5(question)
print("monot5_reranked_doc_list >>> ",monot5_reranked_doc_list)

search results size : 1
Collection: ragbench_collection_techqa_v01, data: {'id': '3683bbefd68a3161c00c001f3e58b2d2', 'distance': 0.8853264451026917, 'entity': {'documents': '[0b] Your WebSphere MQ queue manager fails to come up on the secondary node, and errors are generated.', 'metadata': {'dataset': 'techqa', 'global_index': 0, 'section_index': 0, 'item_index': 1, 'prefix': '0b', 'type': 'Passage'}}}
Collection: ragbench_collection_techqa_v01, data: {'id': '5bb2a66f88c5070dd0f5d37785e3a748', 'distance': 0.8725427389144897, 'entity': {'documents': '[0d] RESOLVING THE PROBLEM\nRename the file amqalchk.fil, which is found under mq\\qmgrs\\qmgrname\\ on the shared drive (to something like amqalchk.fil_OLD); then restart the queue manager.', 'metadata': {'dataset': 'techqa', 'global_index': 0, 'section_index': 0, 'item_index': 3, 'prefix': '0d', 'type': 'Passage'}}}
Collection: ragbench_collection_techqa_v01, data: {'id': 'c2b5c816355deb079e4290bd60b23860', 'distance': 0.8629571199417114,

monoT5:   0%|          | 0/2 [00:00<?, ?batches/s]

reranked_docs_sorted>>>                                                 text  qid  \
4  {'collection': 'ragbench_collection_techqa_v01...    1   
3  {'collection': 'ragbench_collection_techqa_v01...    1   
1  {'collection': 'ragbench_collection_techqa_v01...    1   
0  {'collection': 'ragbench_collection_techqa_v01...    1   
2  {'collection': 'ragbench_collection_techqa_v01...    1   

                                               query     score  rank  
4  Why does the other instance of my multi-instan... -0.060666     0  
3  Why does the other instance of my multi-instan... -0.282272     1  
1  Why does the other instance of my multi-instan... -0.536859     2  
0  Why does the other instance of my multi-instan... -2.973851     3  
2  Why does the other instance of my multi-instan... -4.154908     4  
Rank: -0.06066613271832466, Doc: {'collection': 'ragbench_collection_techqa_v01', 'id': '3683bbefd68a3161c00c001f3e58b2d2', 'metadata': {'dataset': 'techqa', 'global_index': 0, 'secti

In [45]:
# Re-rank the documents with cross encoder

reranked_docs = rerank_with_cross_encoder(question)

# Display results
print("Cross Encoder Re-ranked Documents:")
for rank, (doc, score) in enumerate(reranked_docs, 1):
    print(f"Rank {rank}: Score = {score:.4f}, Doc = {doc}")


search results size : 1
Collection: ragbench_collection_techqa_v01, data: {'id': '3683bbefd68a3161c00c001f3e58b2d2', 'distance': 0.8853264451026917, 'entity': {'documents': '[0b] Your WebSphere MQ queue manager fails to come up on the secondary node, and errors are generated.', 'metadata': {'dataset': 'techqa', 'global_index': 0, 'section_index': 0, 'item_index': 1, 'prefix': '0b', 'type': 'Passage'}}}
Collection: ragbench_collection_techqa_v01, data: {'id': '5bb2a66f88c5070dd0f5d37785e3a748', 'distance': 0.8725427389144897, 'entity': {'documents': '[0d] RESOLVING THE PROBLEM\nRename the file amqalchk.fil, which is found under mq\\qmgrs\\qmgrname\\ on the shared drive (to something like amqalchk.fil_OLD); then restart the queue manager.', 'metadata': {'dataset': 'techqa', 'global_index': 0, 'section_index': 0, 'item_index': 3, 'prefix': '0d', 'type': 'Passage'}}}
Collection: ragbench_collection_techqa_v01, data: {'id': 'c2b5c816355deb079e4290bd60b23860', 'distance': 0.8629571199417114,

# **HyDE Implementation with an LLM to generate a pseudo document** ##

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq

chat = ChatGroq(temperature=0.3, groq_api_key="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", model_name="llama3-8b-8192")

def generate_pseudo_document_llm(query: str):

    prompt = ChatPromptTemplate.from_template(
      """
      Please provide a response to the query below, in a passage.
      Respond just with the passage only.

      Question: {query}

      Passage:
      """
    )

    chain = prompt | chat

    groq_response = chain.invoke({"query": query})

    groq_response.pretty_print()

    answer = groq_response.content
    return answer


def get_pseudo_doc_embeddings(answer: str):
    chunks = chunk_with_token_limit(answer, TOKEN_LIMIT, SLIDING_WINDOW_OVERLAP)
    return embedder.encode(chunks).tolist()

def fetch_docs_pseudo(query: str):
    answer = generate_pseudo_document_llm(query)
    pseudo_doc_embeddings = get_pseudo_doc_embeddings(answer)
    return datastor.extract_documents(datastor.search(pseudo_doc_embeddings))


In [None]:
question = "When was Rolex registered?"

answer = generate_pseudo_document_llm(question)
print(f'answer {answer}')
doc_results = fetch_docs_pseudo(question)

print(f'doc_results: {doc_results}')

# **Doc Summarization (Recomp)**

In [None]:
from transformers import pipeline

# Initialize the summarization pipeline
summarizer = pipeline('summarization')

# Abstractive Compression: Generate a concise summary by synthesizing information from multiple documents.
def summarize_docs(retrieved_docs):

    # Concatenate documents and perform abstractive summarization
    concatenated_docs = ' '.join(retrieved_docs)
    summary = summarizer(concatenated_docs, max_length=150, min_length=30, do_sample=False)[0]['summary_text']

    return summary


No model was supplied, defaulted to sshleifer/distilbart-cnn-12-6 and revision a4f8f3e (https://huggingface.co/sshleifer/distilbart-cnn-12-6).
Using a pipeline without specifying a model name and revision in production is not recommended.


config.json:   0%|          | 0.00/1.80k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Device set to use cuda:0


# **LLM Inference with groq**

In [46]:
! pip install langchain
! pip install groq
! pip install -q langchain langchain-groq

Collecting groq
  Downloading groq-0.15.0-py3-none-any.whl.metadata (14 kB)
Downloading groq-0.15.0-py3-none-any.whl (109 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.6/109.6 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: groq
Successfully installed groq-0.15.0


In [47]:
import langchain

from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq

# Initializing the context variable which later gets populated with retrieved chunks
context = ""

In [69]:
def query_response_from_llm(query: str):

    # retrieve chunks from milvus db
    chunks = retrieve_docs_milvus(query)

    # retrieve chunks from chroma db
    #chunks = retrieve_docs(query)

    # Flatten the list if necessary
    if any(isinstance(chunk, list) for chunk in chunks):
      chunks = [item for sublist in chunks for item in (sublist if isinstance(sublist, list) else [sublist])]

    chat = ChatGroq(temperature=0.3, groq_api_key="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", model_name="llama3-8b-8192")

    prompt=ChatPromptTemplate.from_template(
      """
      Please provide a response to the query below, strictly adhering to the
      information presented in the following documents.
      Do not generate any text beyond what is explicitly stated in the documents.

      Context: {context}

      Question: {query}

      Answer:
      """
    )


    # Summarize the retrieved doc chunks
    # Compress documents using abstractive summarization before appending to context
    chunk_summary = ""

    '''
    print("chunks>>>>",chunks)
    sorted_chunks = sort_documents(query, chunks)
    print(f"sorted_chunks: {sorted_chunks}")
    chunk_summary = summarize_docs(sorted_chunks)
    print("chunk_summary>>>>",chunk_summary)
    chunks = chunk_summary.join(chunks)
    '''

    #Cross encoder reranked docs
    reranked_chunks = rerank_with_cross_encoder(question)

    #monot5 reranked docs
    #reranked_chunks = rerank_with_monot5(question)

    doc_chunks = [chunk[0] for chunk in reranked_chunks]
    print("doc_chunks>>after re-ranking>>", doc_chunks)
    chain = prompt | chat

    context = "".join(doc_chunks)

    print("context>>>from 1st RAG>>>>>> ",context)

    groq_response = chain.invoke({"context": context, "query": query})

    print("groq_response>>>from 1st RAG>>>>>> ",groq_response)

    answer = groq_response
    return answer, context

In [None]:
def query_response_from_llm_no_rag(query: str):

    chat = ChatGroq(temperature=0.3, groq_api_key="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", model_name="llama3-8b-8192")

    prompt=ChatPromptTemplate.from_template(
      """
      Please provide a response to the query below

      Question: {query}

      Answer:
      """
    )

    chain = prompt | chat

    groq_response_no_rag = chain.invoke({"query": query})

    print("groq_response_no_rag>>>no RAG>>>>>> ",groq_response_no_rag)

    answer = groq_response_no_rag
    return answer

## **USER QUERY**

In [70]:
#query = "What are the most frequent clinical manifestations of human adenovirus type 55 (HAdV-55) induced ARDS?"
#query ="The configuration task database-transfer failed with DB2 SQL Error: SQLCODE=-1585, SQLSTATE=54048 While attempting to run the database-transfer task the following error is logged to the ConfigTrace.log: action-process-constraints: Fri Oct 10 13:20:34 CDT 2014 Target started: action-process-constraints [java] Executing java with empty input string [java] [10/10/14 13:20:35.877 CDT] Attempting to create a new Instance of com.ibm.db2.jcc.DB2Driver [java] [10/10/14 13:20:36.016 CDT] Instance of com.ibm.db2.jcc.DB2Driver created successfully [java] [10/10/14 13:20:36.016 CDT] Attempting to make connection using: jdbc:db2://:60500/:returnAlias=0; :: d2svc :: PASSWORD_REMOVED [java] [10/10/14 13:20:36.954 CDT] Connection successfully made [java] [10/10/14 13:20:37.073 CDT] ERROR: Error occurred gathering data from the source database [java] com.ibm.db2.jcc.am.SqlException: DB2 SQL Error: SQLCODE=-1585, SQLSTATE=54048, SQLERRMC=null, DRIVER=4.18.60 [java] at com.ibm.db2.jcc.am.kd.a(kd.java:752)"
query = "Why does the other instance of my multi-instance qmgr seem to hang after a failover? Queue manager will not start after failover."

#query = "What is SQLCODE=-1585"

# Query Classification
#classification = classify_query(query)

classification = 1

#print('classification >>>> ',classification)

if classification == 1:
    print("Perform retrieval using RAG.")
    # Call RAG pipeline for retrieval and response generation
    #answer,context = query_response_from_llm(query)
    print('Question for LLM>>>>', question)
    answer,context = query_response_from_llm(question)
else:
    print("No RAG - only use LLM for response generation.")
    # Call LLM directly
    answer = query_response_from_llm_no_rag(query)

print("context  >>from llm>>> ", context)
print("answer  >>from llm>>> ", answer)

Perform retrieval using RAG.
Question for LLM>>>> Why does the other instance of my multi-instance qmgr seem to hang after a failover? Queue manager will not start after failover
search results size : 1
Collection: ragbench_collection_techqa_v01, data: {'id': '3683bbefd68a3161c00c001f3e58b2d2', 'distance': 0.8853264451026917, 'entity': {'documents': '[0b] Your WebSphere MQ queue manager fails to come up on the secondary node, and errors are generated.', 'metadata': {'dataset': 'techqa', 'global_index': 0, 'section_index': 0, 'item_index': 1, 'prefix': '0b', 'type': 'Passage'}}}
Collection: ragbench_collection_techqa_v01, data: {'id': '5bb2a66f88c5070dd0f5d37785e3a748', 'distance': 0.8725427389144897, 'entity': {'documents': '[0d] RESOLVING THE PROBLEM\nRename the file amqalchk.fil, which is found under mq\\qmgrs\\qmgrname\\ on the shared drive (to something like amqalchk.fil_OLD); then restart the queue manager.', 'metadata': {'dataset': 'techqa', 'global_index': 0, 'section_index': 0,

## **PROMPT for generating metrics as JSON response**

In [71]:
def generate_prompt():
    """
    Generate a prompt template for assessing the support and relevance of an LLM-generated response.
    """
    return """
    I asked someone to answer a question based on one or more documents.
    Your task is to review their response and assess whether or not each sentence
    in that response is supported by text in the documents. And if so, which
    sentences in the documents provide that support. You will also tell me which
    of the documents contain useful information for answering the question, and
    which of the documents the answer was sourced from.
    Here are the documents, each of which is split into sentences.Alongside each
    sentence is associated key, such as ’[0a].’ or ’[0b].’ that you can use to refer
    to it:

    ‘‘‘
    {documents}
    ‘‘‘
    The question was:
    ‘‘‘
    {question}
    ‘‘‘

    Here is their response, split into sentences. Alongside each sentence is
    associated key, such as ’a.’ or ’b.’ that you can use to refer to it. Note
    that these keys are unique to the response, and are not related to the keys
    in the documents:
    ‘‘‘
    {answer}
    ‘‘‘
    You must respond with a JSON object matching this schema:
    ‘‘‘
    {{
    "relevance_explanation": string,
    "all_relevant_sentence_keys": [string],
    "overall_supported_explanation": string,
    "overall_supported": boolean,
    "sentence_support_information": [
    {{
    "response_sentence_key": string,
    "explanation": string,
    "supporting_sentence_keys": [string],
    "fully_supported": boolean
    }},
    ],
    "all_utilized_sentence_keys": [string]
    }}
    ‘‘‘
    The relevance_explanation field is a string explaining which documents
    contain useful information for answering the question. Provide a step-by-step
    breakdown of information provided in the documents and how it is useful for
    answering the question.
    The all_relevant_sentence_keys field is a list of all document sentences keys
    (e.g. ’0a’) that are relevant to the question. Include every sentence that is
    useful and relevant to the question, even if it was not used in the response,
    or if only parts of the sentence are useful. Ignore the provided response when
    making this judgement and base your judgement solely on the provided documents
    and question. Omit sentences that, if removed from the document, would not
    impact someone’s ability to answer the question.
    The overall_supported_explanation field is a string explaining why the response
    *as a whole* is or is not supported by the documents. In this field, provide a
    step-by-step breakdown of the claims made in the response and the support (or
    lack thereof) for those claims in the documents. Begin by assessing each claim
    separately, one by one; don’t make any remarks about the response as a whole
    until you have assessed all the claims in isolation.
    The overall_supported field is a boolean indicating whether the response as a
    whole is supported by the documents. This value should reflect the conclusion
    you drew at the end of your step-by-step breakdown in overall_supported_explanation.
    In the sentence_support_information field, provide information about the support
    *for each sentence* in the response.
    The sentence_support_information field is a list of objects, one for each sentence
    in the response. Each object MUST have the following fields:
    - response_sentence_key: a string identifying the sentence in the response.
    This key is the same as the one used in the response above.

    - explanation: a string explaining why the sentence is or is not supported by the
    documents.
    - supporting_sentence_keys: keys (e.g. ’[0a]’) of sentences from the documents that
    support the response sentence. If the sentence is not supported, this list MUST
    be empty. If the sentence is supported, this list MUST contain one or more keys.
    In special cases where the sentence is supported, but not by any specific sentence,
    you can use the string "supported_without_sentence" to indicate that the sentence
    is generally supported by the documents. Consider cases where the sentence is
    expressing inability to answer the question due to lack of relevant information in
    the provided context as "supported_without_sentence". In cases where the sentence
    is making a general statement (e.g. outlining the steps to produce an answer, or
    summarizing previously stated sentences, or a transition sentence), use the
    string "general". In cases where the sentence is correctly stating a well-known fact,
    like a mathematical formula, use the string "well_known_fact". In cases where the
    sentence is performing numerical reasoning (e.g. addition, multiplication), use
    the string "numerical_reasoning".
    - fully_supported: a boolean indicating whether the sentence is fully supported by
    the documents.
    - This value should reflect the conclusion you drew at the end of your step-by-step
    breakdown in explanation.
    - If supporting_sentence_keys is an empty list, then fully_supported must be false.
    - Otherwise, use fully_supported to clarify whether everything in the response
    sentence is fully supported by the document text indicated in supporting_sentence_keys
    (fully_supported = true), or whether the sentence is only partially or incompletely
    supported by that document text (fully_supported = false).
    The all_utilized_sentence_keys field is a list of all sentences keys (e.g. ’0a’) that
    were used to construct the answer. Include every sentence that either directly supported
    the answer, or was implicitly used to construct the answer, even if it was not used
    in its entirety. Omit sentences that were not used, and could have been removed from
    the documents without affecting the answer.
    You must respond with a valid JSON string. Use escapes for quotes, e.g. \\"\\", and
    newlines, e.g. \\n. Do not write anything before or after the JSON string. Do not
    wrap the JSON string in backticks like ‘‘‘ or ‘‘‘json.
    As a reminder: your task is to review the response and assess which documents contain
    useful information pertaining to the question, and how each sentence in the response
    is supported by the text in the documents.
    """.strip()


## **Response generation using groq using llama3-70b-8192**

In [72]:
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq

chat = ChatGroq(temperature=0.3, groq_api_key="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", model_name="llama3-70b-8192")

prompt_template_with_docs = PromptTemplate(
    input_variables=["documents", "question", "answer"],
    template=generate_prompt(),
)

# For testing
query = question

print('context for groq >>>> ', context)
print('query for groq >>>> ', query)
print('answer for groq >>>> ', answer)

chain = prompt_template_with_docs | chat

# if classification == 1:
#  groq_response_with_context_qanda = chain.invoke({"documents": context, "question": query, "answer":groq_response})
# else
#  groq_response_with_out-rag = chain.invoke({"documents": "", "question": query, "answer":groq_response})

groq_response_with_context_qanda = chain.invoke({"documents": context, "question": query, "answer":answer})

print("groq_response>>>>with context, query and answer>>>>> ",groq_response_with_context_qanda)

context for groq >>>>  [1j] PROBLEM SUMMARY
 *  ****************************************************************
   * USERS AFFECTED:                                              *
   * Users with an HDR pair that is an ER participant             *
   ****************************************************************
   * PROBLEM DESCRIPTION:                                         *
   * After failover, the primary shows online mode but the        *
   * clients are not able to connect to the server.[0d] RESOLVING THE PROBLEM
Rename the file amqalchk.fil, which is found under mq\qmgrs\qmgrname\ on the shared drive (to something like amqalchk.fil_OLD); then restart the queue manager.[0b] Your WebSphere MQ queue manager fails to come up on the secondary node, and errors are generated.[0a] HL083112 mqlpgrlg ZX000001 ExecCtrlrMain lpiRC_LOG_NOT_AVAILABLE mscs TECHNOTE (TROUBLESHOOTING)

PROBLEM(ABSTRACT)
 You attempt to failover from the primary to secondary node under MSCS.[1h] If the mode

## **JSON Data parsing to retrieve metrics**

In [73]:
import re
import json

In [74]:
# Extract the content field using regular expressions
content_match = re.search(r"content='(.*?)' additional_kwargs=", str(groq_response_with_context_qanda), re.DOTALL)
if content_match:
    content = content_match.group(1)
    print("Extracted Content:")
    print(content)

    json_match = re.search(r"\{.*\}", content, re.DOTALL)
    if json_match:
      json_str = json_match.group(0)

      json_str = json_str.replace("'", '"').replace("\\n","").replace("\\","")
      print(json_str)
      try:
        # Parse the JSON
        parsed_json = json.loads(json_str)
        print("Extracted JSON:")
        print(json.dumps(parsed_json, indent=4))

      except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
    else:
      print("No JSON found in the provided string.")

else:
    print("Content field not found in the provided string.")


Extracted Content:
{\n"relevance_explanation": "The document contains useful information for answering the question, specifically the problem description and resolution sections.",\n"all_relevant_sentence_keys": ["[1h]", "[1j]", "[0d]"],\n"overall_supported_explanation": "The response claims that the reason for the queue manager hanging after failover is because the mode is not set until after ER has finished syncing and connections are accepted or a new return value is used to indicate that the instance is online and accepting connections. This claim is supported by the document, specifically sentence [1h], which describes the problem as attempting to failover from the primary to secondary node under MSCS, and sentence [1j], which provides a possible solution to the problem.",\n"overall_supported": true,\n"sentence_support_information": [\n{\n"response_sentence_key": "a.",\n"explanation": "The sentence is supported by the document, specifically sentences [1h] and [1j].",\n"supporting_

In [75]:
data = parsed_json

## **Computation Metrics from JSON response in comparison with ground truth**

In [76]:
import json

In [77]:
# Helper function for length computation (mocked as sentence count here)
def compute_length(keys):
    return len(keys)

# Metrics Computation
def compute_metrics(data):
    all_relevant = data["all_relevant_sentence_keys"]
    all_utilized = data["all_utilized_sentence_keys"]
    sentences_info = data["sentence_support_information"]

    # Context Relevance
    total_relevant_length = compute_length(all_relevant)
    total_context_length = total_relevant_length  # Assuming all relevant are part of the context
    context_relevance = total_relevant_length / total_context_length if total_context_length > 0 else 0

    # Context Utilization
    total_utilized_length = compute_length(all_utilized)
    context_utilization = total_utilized_length / total_context_length if total_context_length > 0 else 0

    # Completeness
    total_relevant_utilized = sum(
        1 for s in sentences_info if set(s["supporting_sentence_keys"]).intersection(all_utilized)
    )
    completeness = total_relevant_utilized / total_relevant_length if total_relevant_length > 0 else 0

    # Adherence
    adherence = all(s["fully_supported"] for s in sentences_info)

    return {
        "Context Relevance": context_relevance,
        "Context Utilization": context_utilization,
        "Completeness": completeness,
        "Adherence": adherence
    }

In [78]:
# Compute and print metrics
predicted_metrics = compute_metrics(data)
print(json.dumps(predicted_metrics, indent=4))

{
    "Context Relevance": 1.0,
    "Context Utilization": 0.6666666666666666,
    "Completeness": 0.3333333333333333,
    "Adherence": true
}


## **Fetching the groud truth values**

In [None]:
sub_dataset = load_dataset("rungalileo/ragbench", "techqa")

In [None]:
def fetch_ground_truth(dataset, question):
    for sample in dataset["train"]:  # Change "train" to the correct split if needed
        if sample["question"] == question:
            return {
                "Context Relevance": sample["relevance_score"],  # Adjust column name if needed
                "Context Utilization": sample["utilization_score"],  # Adjust column name if needed
                "Adherence": sample["adherence_score"]  # Adjust column name if needed
            }
    return None


In [None]:
print(f"query >>>> {query}")
ground_truth = fetch_ground_truth(sub_dataset, query)

if ground_truth:
    print("Ground Truth Values:")
    print(json.dumps(ground_truth, indent=4))
else:
    print(f"Question not found in the dataset: {query}")

query >>>> Why does the other instance of my multi-instance qmgr seem to hang after a failover? Queue manager will not start after failover.
Ground Truth Values:
{
    "Context Relevance": 0.01411764705882353,
    "Context Utilization": 0.009411764705882352,
    "Adherence": true
}


In [None]:
## DONOT RUN THIS
ground_truth = """{
    "Context Relevance": 0.6470588235294118,
    "Context Utilization": 0.35294117647058826,
    "Adherence": false
}"""

In [None]:
## DONOT RUN THIS
predicted_metrics = """{
    "Context Relevance": 1.0,
    "Context Utilization": 0.6,
    "Completeness": 1.0,
    "Adherence": true
}"""

## **Evaluation Metrics**

In [None]:
from sklearn.metrics import mean_squared_error, roc_auc_score
import numpy as np
import json

# **RMSE**

In [None]:
def compute_rmse(predicted, ground_truth):

    # Extract true and predicted values
    y_true_relevance = ground_truth["Context Relevance"]
    y_true_utilization = ground_truth["Context Utilization"]

    y_pred_relevance = predicted["Context Relevance"]
    y_pred_utilization = predicted["Context Utilization"]

    # Compute RMSE for Context Relevance and Context Utilization
    rmse_relevance = np.sqrt((y_pred_relevance - y_true_relevance) ** 2)
    rmse_utilization = np.sqrt((y_pred_utilization - y_true_utilization) ** 2)

    return {
        "RMSE-Relevance": rmse_relevance,
        "RMSE-Utililization": rmse_utilization,
    }


# **AUCROC**

In [None]:
#The AUC-ROC is a metric for binary classification, but it requires:
# 1) Both positive and negative classes in the ground truth.
# 2) A set of predictions (not just a single value).
# The AUR-ROC code that we have written earlier will always fail because at any given point of time only one class (either true or false) of ground truth and predicted value is
# passed to the roc_auc_score library function. However, by definition, AUC-ROC requires a 'set of 2 class' values. Following is a toy example
# TODO : For our project metric, we need to pass 2 sets of values to roc_auc_score function
  # 1) multiple adherence values of ground truth queries
  # 2) multiple adherence values from our predictions

y_true = ["true","false","false","true"]  # Ground truth with both classes
y_pred = ["true","true","true","false"]  # Model probabilities

mapping = {"true": 1, "false": 0}
y_true_numeric_alt = [mapping[val] for val in y_true]
y_pred_numeric_alt = [mapping[val] for val in y_pred]

auc_roc = roc_auc_score(y_true_numeric_alt, y_pred_numeric_alt)
print("AUC-ROC:", auc_roc)

AUC-ROC: 0.25


In [None]:
evaluation_metrics = compute_rmse(predicted_metrics, ground_truth)

TypeError: string indices must be integers, not 'str'

In [None]:
# Print Results
print("Ground Truth Values (JSON):")
print(json.dumps(ground_truth, indent=4))
print("\nPredicted Metrics:")
print(json.dumps(predicted_metrics, indent=4))
print("\nEvaluation Metrics (RMSE and AUC-ROC):")
print(json.dumps(evaluation_metrics, indent=4))