# Retriever Customization - Synthetic Data Generation (Part 1/2)

Authors - Aditya Malte, Vinay Raman, Dora Li

## Introduction
Text retrievers and embedding models play a crucial role in modern information retrieval systems by converting both queries and documents into dense numerical vectors (embeddings) that capture their semantic meaning. This allows the system to find relevant documents by measuring the similarity between a query's embedding and document embeddings in the database.
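To make the similarity step concrete, here is a minimal sketch of ranking documents by cosine similarity. The three-dimensional vectors are made up for illustration and the helper names are hypothetical. Real embedding models produce much larger vectors, and production systems typically use a vector index rather than a linear scan.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def rank_documents(query_embedding, doc_embeddings):
    """Return (ranking, scores): document indices sorted from most to least similar."""
    scores = [cosine_similarity(query_embedding, d) for d in doc_embeddings]
    ranking = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranking, scores

# Toy vectors, invented for this example only.
doc_embeddings = [
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
    [0.7, 0.7, 0.0],
]
query_embedding = [1.0, 0.0, 0.0]

ranking, scores = rank_documents(query_embedding, doc_embeddings)
print(ranking)  # document 0 is closest to the query, document 1 is furthest
```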

The accuracy of these models directly impacts their usefulness. When a retriever has been trained primarily on one type of content (like general web text or news articles) but is asked to retrieve documents from a specialized domain (such as medical literature), its performance can degrade significantly.

This is why many organizations fine-tune domain-specific retrievers for their particular use cases, ensuring more accurate and relevant document retrieval. As with all fine-tuning, this requires high-quality domain-specific data, which can be generated with LLMs such as [NVIDIA's Nemotron-4-340B-Instruct](https://blogs.nvidia.com/blog/nemotron-4-synthetic-data-generation-llm-training/) that are specially trained and licensed for synthetic data generation.

## Overview 

This two-part tutorial demonstrates how to improve retrieval performance by fine-tuning embedding models using synthetic training data. The process is split across two notebooks:
 
1. `synthetic_data_generation_nemo.ipynb` **(this notebook)**:
    - Use an LLM from build.nvidia.com (or deploy your own using NIM!) to create training examples containing generated queries and positive chunks. By default the notebook will use nfcorpus, but you can easily swap in your own data.
    - Implement hard negative mining to find challenging negative examples
    - Save results to a `.jsonl` file 


2. `retriever_customization.ipynb`
    - Use the generated training data in the `.jsonl` file to fine-tune a retriever model using NeMo Framework
    - Evaluate the results of your fine-tuned embedding model against the original using BeIR Benchmark

NOTE: This tutorial is a demo, so it generates training data from only a small subset of the corpus to keep the notebook's run time reasonable.

## Setup Instructions

#### NeMo Framework Docker Container ####
This notebook runs in a Docker environment built from the NeMo Framework repository. Refer to https://github.com/NVIDIA/NeMo/tree/main for instructions on how to build and run the Docker containers. Ensure that the container you run this notebook in is built from the main branch of the NeMo repository. The current notebooks were tested on NeMo Framework 24.07 on a single-GPU machine (L40S).

From inside the `synthetic-data-retriever-customization` directory, start the container with this command:

`docker run -it --rm --gpus all --ipc=host --network host -v $(pwd):/workspace nvcr.io/nvidia/nemo:24.07`

<br> 

#### NVIDIA AI Endpoints
You'll need access to an **LLM** for generating queries and a **Text Embedding Model** for mining hard negatives. By default, this notebook uses the [Nemotron-4-340b-Instruct](https://build.nvidia.com/nvidia/nemotron-4-340b-instruct) and [NV-EmbedQA-E5-V5](https://build.nvidia.com/nvidia/nv-embedqa-e5-v5) API endpoints from [www.build.nvidia.com](https://www.build.nvidia.com), for the LLM and text embedding models respectively. 

**An API Key is required.** Get your API Key by following the link above to the model and clicking on "Build with this NIM". All new users will get a number of tokens upon registering. Set the environment variable NVIDIA_API_KEY with your API key value.
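Because a missing key otherwise only surfaces later as a `KeyError` or an authentication failure, it can help to check for it up front. The sketch below is one way to do that. `get_api_key` is a hypothetical helper, not part of any NVIDIA library.

```python
import os

def get_api_key(env=None, name="NVIDIA_API_KEY"):
    """Return the API key from the environment, or raise a clear error if it is missing."""
    env = os.environ if env is None else env
    key = env.get(name, "").strip()
    if not key:
        raise RuntimeError(
            f"{name} is not set. Export it in your shell or set os.environ[{name!r}] "
            "before creating the API clients."
        )
    return key
```

Calling `get_api_key()` could then stand in for `os.environ["NVIDIA_API_KEY"]` when the clients are created later in the notebook.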

Optionally, you can self-host either model using **[NIM (NVIDIA Inference Microservice)](https://docs.nvidia.com/nim/large-language-models/latest/getting-started.html)** and pass in the local URL when creating your LLM client later on. Follow the instructions in the link. Note that system GPU requirements depend on the model you choose to deploy.


## Import Libraries

In [1]:
import os
import json
import pandas as pd
from collections import OrderedDict
import torch
import math
import numpy as np

import re
from nltk.tokenize import sent_tokenize
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [2]:
!pip install ipywidgets
!pip install beir

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


### Download Nfcorpus Dataset

In [3]:
from datasets import load_dataset

As an example, we use the public [`nfcorpus`](https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/) text dataset as the source for synthetic data generation. You can substitute any other existing dataset or, ideally, your own proprietary documents.

In [4]:
DOMAIN = "BeIR/nfcorpus"
corpus = load_dataset(DOMAIN, "corpus")["corpus"]
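If you swap in your own documents, the rest of the notebook only needs records with the same `_id`, `title`, and `text` fields that the nfcorpus corpus split provides. The following is a minimal sketch under that assumption. `build_corpus_records`, the `DOC-` id scheme, and the sample documents are all hypothetical.

```python
def build_corpus_records(documents):
    """Convert (title, text) pairs into records with `_id`, `title`, and `text` keys."""
    records = []
    for i, (title, text) in enumerate(documents):
        text = text.strip()
        if not text:
            continue  # skip empty documents, which cannot yield useful queries
        records.append({"_id": f"DOC-{i}", "title": title.strip(), "text": text})
    return records

# Placeholder documents, invented for illustration.
my_documents = [
    ("Reset procedure", "Hold the power button for ten seconds. The device restarts."),
    ("Empty page", "   "),
]
corpus_records = build_corpus_records(my_documents)
```

A list like `corpus_records` can be passed to `pd.DataFrame(...)` in place of `corpus` when the knowledge base is built.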

## Synthetic Data Generation from Knowledge Base

In this section we will: 
1. Break each text sample in the nfcorpus dataset that we downloaded into smaller chunks.
2. Compose an LLM prompt that provides detailed instructions on how to generate queries based on each chunk.
3. Send the prompts to our LLM as an asynchronous batch job.
4. Parse the queries and populate our synthetic dataset with query + positive chunks.

### 1. Chunk Knowledge Base

Chunking breaks large documents into smaller pieces that an LLM can take as input. Here we split each text into chunks of roughly 300 words, without breaking sentences.

In [5]:
def chunk_text(samples, chunk_size=300):
    """Split each (text, title, paragraph_id) sample into chunks of about `chunk_size` words, keeping sentences intact."""
    final_chunks = []
    for text, title, paragraph_id in samples:
        sentence_list = sent_tokenize(text)
        chunk = []
        chunk_id = 0
        word_count = 0
        for sentence in sentence_list:
            word_count += len(sentence.split())
            chunk.append(sentence)

            # Close the current chunk once it reaches the target word count
            if word_count >= chunk_size:
                final_chunks.append((" ".join(chunk), title, paragraph_id, chunk_id))
                chunk_id += 1
                chunk = []
                word_count = 0

        # Keep the trailing chunk only if the sample produced no other chunk (chunk_id == 0)
        # or if it has a significant number of words (more than 40% of chunk_size)
        if chunk and (chunk_id == 0 or word_count > int(0.4 * chunk_size)):
            final_chunks.append((" ".join(chunk), title, paragraph_id, chunk_id))

    return final_chunks
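The trailing-chunk rule in the function above is easy to miss: leftover sentences are dropped unless they are the only chunk or exceed 40% of the chunk size. The sketch below reproduces that logic on a tiny input. To stay dependency-free it uses a naive regex splitter instead of `sent_tokenize`, and `chunk_sentences` is a hypothetical simplification that omits titles and ids.

```python
import re

def chunk_sentences(sentences, chunk_size=300, min_tail_ratio=0.4):
    """Group sentences into chunks of at least `chunk_size` words without splitting a sentence."""
    chunks, current, word_count = [], [], 0
    for sentence in sentences:
        current.append(sentence)
        word_count += len(sentence.split())
        if word_count >= chunk_size:
            chunks.append(" ".join(current))
            current, word_count = [], 0
    # Keep leftover sentences only if they are the sole chunk or are long enough.
    if current and (not chunks or word_count > int(min_tail_ratio * chunk_size)):
        chunks.append(" ".join(current))
    return chunks

# Naive sentence splitter, used here only to avoid the NLTK dependency.
text = "One two three four. Five six seven eight. Nine ten."
sentences = re.split(r"(?<=[.!?])\s+", text)

dropped_tail = chunk_sentences(sentences, chunk_size=8)  # 2-word tail is too short and is dropped
kept_tail = chunk_sentences(sentences, chunk_size=4)     # 2-word tail exceeds 40% of 4 words
```

If discarding short tails is not acceptable for your data, one option is to merge the tail into the previous chunk instead.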

In [6]:
kb = pd.DataFrame(corpus)
kb["paragraph_id"] = range(len(kb)) # assign a paragraph id to keep track of the original source document of each chunk
kb

Unnamed: 0,_id,title,text,paragraph_id
0,MED-10,Statin Use and Breast Cancer Survival: A Natio...,"Recent studies have suggested that statins, an...",0
1,MED-14,Statin use after diagnosis of breast cancer an...,BACKGROUND: Preclinical studies have shown tha...,1
2,MED-118,Alkylphenols in human milk and their relations...,The aims of this study were to determine the c...,2
3,MED-301,Methylmercury: A Potential Environmental Risk ...,Epilepsy or seizure disorder is one of the mos...,3
4,MED-306,Sensitivity of Continuous Performance Test (CP...,Hit Reaction Time latencies (HRT) in the Conti...,4
...,...,...,...,...
3628,MED-917,Effect of freezing and storage on the phenolic...,Scottish-grown red raspberries are a rich sour...,3628
3629,MED-941,Topical vitamin A treatment of recalcitrant co...,BACKGROUND: Common warts (verruca vulgaris) ar...,3629
3630,MED-942,Esophageal injury by apple cider vinegar table...,Apple cider vinegar products are advertised in...,3630
3631,MED-952,Cannabis and the lung.,The use of cannabis is embedded within many so...,3631


Notes: 
- We sample only 100 of the roughly 3,600 documents in the corpus so that the notebook completes in a reasonable time for this tutorial. Feel free to increase the sample size, especially if you are running this with your own data.

- Most of the nfcorpus documents are already very short passages so they will only contain a single chunk. 

In [7]:
kb = kb.sample(n=100)

# Chunk the sampled documents and name the resulting columns in a single step
kb_chunked = pd.DataFrame(
    chunk_text(kb[["text", "title", "paragraph_id"]].itertuples(index=False)),
    columns=["chunk_text", "title", "paragraph_id", "chunk_id"],
)
kb_chunked = kb_chunked.drop(columns=["title"])
kb_chunked

Unnamed: 0,chunk_text,paragraph_id,chunk_id
0,PURPOSE: To report the rate of recanalization ...,3458,0
1,Recent evidence underlines the role of Western...,1116,0
2,PURPOSE: The data on the association between c...,3488,0
3,Starch in white wheat bread (WB) induces high ...,759,0
4,1. Bowel transit time has been investigated in...,2927,0
...,...,...,...
96,OBJECTIVES/HYPOTHESIS: This study aimed to ass...,2873,0
97,Research finding on the composition of macronu...,929,0
98,OBJECTIVE: To demonstrate the effects of a ver...,3122,0
99,The widely used food additive carrageenan (CGN...,2025,0


In [8]:
kb = kb.merge(kb_chunked, how="left", on="paragraph_id")
kb

Unnamed: 0,_id,title,text,paragraph_id,chunk_text,chunk_id
0,MED-5212,Surgical punctal occlusion with a high heat-en...,PURPOSE: To report the rate of recanalization ...,3458,PURPOSE: To report the rate of recanalization ...,0
1,MED-2117,Diet in acne: further evidence for the role of...,Recent evidence underlines the role of Western...,1116,Recent evidence underlines the role of Western...,0
2,MED-5243,Coffee consumption and risk of fractures: a sy...,PURPOSE: The data on the association between c...,3488,PURPOSE: The data on the association between c...,0
3,MED-1676,Berries reduce postprandial insulin responses ...,Starch in white wheat bread (WB) induces high ...,759,Starch in white wheat bread (WB) induces high ...,0
4,MED-4637,Fibre and bowel transit times.,1. Bowel transit time has been investigated in...,2927,1. Bowel transit time has been investigated in...,0
...,...,...,...,...,...,...
96,MED-4565,The clinical significance of nasal irrigation ...,OBJECTIVES/HYPOTHESIS: This study aimed to ass...,2873,OBJECTIVES/HYPOTHESIS: This study aimed to ass...,0
97,MED-1873,Dietary saturated fat intake is negatively ass...,Research finding on the composition of macronu...,929,Research finding on the composition of macronu...,0
98,MED-4853,"Effects of a very low-fat, vegan diet in subje...",OBJECTIVE: To demonstrate the effects of a ver...,3122,OBJECTIVE: To demonstrate the effects of a ver...,0
99,MED-3496,Pro-inflammatory NF-κB and early growth respon...,The widely used food additive carrageenan (CGN...,2025,The widely used food additive carrageenan (CGN...,0


In [9]:
kb["chunk_title_text_concat"] = kb["title"] + "\n" + kb["chunk_text"]
kb

Unnamed: 0,_id,title,text,paragraph_id,chunk_text,chunk_id,chunk_title_text_concat
0,MED-5212,Surgical punctal occlusion with a high heat-en...,PURPOSE: To report the rate of recanalization ...,3458,PURPOSE: To report the rate of recanalization ...,0,Surgical punctal occlusion with a high heat-en...
1,MED-2117,Diet in acne: further evidence for the role of...,Recent evidence underlines the role of Western...,1116,Recent evidence underlines the role of Western...,0,Diet in acne: further evidence for the role of...
2,MED-5243,Coffee consumption and risk of fractures: a sy...,PURPOSE: The data on the association between c...,3488,PURPOSE: The data on the association between c...,0,Coffee consumption and risk of fractures: a sy...
3,MED-1676,Berries reduce postprandial insulin responses ...,Starch in white wheat bread (WB) induces high ...,759,Starch in white wheat bread (WB) induces high ...,0,Berries reduce postprandial insulin responses ...
4,MED-4637,Fibre and bowel transit times.,1. Bowel transit time has been investigated in...,2927,1. Bowel transit time has been investigated in...,0,Fibre and bowel transit times.\n1. Bowel trans...
...,...,...,...,...,...,...,...
96,MED-4565,The clinical significance of nasal irrigation ...,OBJECTIVES/HYPOTHESIS: This study aimed to ass...,2873,OBJECTIVES/HYPOTHESIS: This study aimed to ass...,0,The clinical significance of nasal irrigation ...
97,MED-1873,Dietary saturated fat intake is negatively ass...,Research finding on the composition of macronu...,929,Research finding on the composition of macronu...,0,Dietary saturated fat intake is negatively ass...
98,MED-4853,"Effects of a very low-fat, vegan diet in subje...",OBJECTIVE: To demonstrate the effects of a ver...,3122,OBJECTIVE: To demonstrate the effects of a ver...,0,"Effects of a very low-fat, vegan diet in subje..."
99,MED-3496,Pro-inflammatory NF-κB and early growth respon...,The widely used food additive carrageenan (CGN...,2025,The widely used food additive carrageenan (CGN...,0,Pro-inflammatory NF-κB and early growth respon...


### 2. Prompt Generation

A prompt gives the LLM the context it needs for generation. Modify this prompt as appropriate for your specific domain.

The default prompt in this example is from the NVIDIA documentation/help page. It gives detailed instructions and examples of the types of queries the model should generate. In this prompt we ask the LLM to generate three unique questions for each chunk.

In [10]:
system_prompt = """You are a data annotator trying to generate three search queries for the Document 2. The generated queries must be answerable from Document 2. Each generated query must be enclosed within the <q> and </q> tags as shown in Example. Only generate the query, do not generate the answer. An example is:\n"""

example = """Example:AV Sync 
Use of an AV Receiver with HDMI for video may result in audio lagging behind video.  First try 
using the receiver AV sync settings to calibrate.  If this does not work, use the AV sync slider 
utility in Settings  > Display & sound > Advanced settings > Audio video sync to calibrate for 
any audio delay.  The AV sync slider allows you to advance audio by 1 second (in small 
increments of 10ms) to synchronize the audio and video. 
Note that this tool is effective only when SHIELD is connected to your AV Receiver over HDMI 
(i.e. audio/video over HDMI); it is not meant to be used when a headset is plugged into SHIELD 
Controller/SHIELD Remote or USB audio device or Bluetooth audio device. 
If video lags behind audio (i.e. audio is ahead of video) then use your AV receiver’s settings to 
delay audio.
ADJUST FOR OVERSCAN

For TVs that don't provide their own overscan settings, use this setting to adjust the picture size to fit the screen.

Go to Settings > Device Preferences > Display & Sound > Advanced Settings > Adjust for overscan to resize the picture on your TV or display.  Use the UP and DOWN d-pad buttons on your remote to maximize the picture on your TV.  Make sure the green triangles are completely visible to avoid overscan.
Generated Queries:
1. <q>How do I adjust the display so that my picture does not go out of the screen?</q>
2. <q>Why is AV Sync not working when I'm plugging my SHIELD into my bluetooth earphone?</q>
3. <q>How many seconds can I delay audio by in AV Sync?</q>
"""
system_prompt += example
system_prompt += "Do not use text from Example to generate queries for Document 2."
print(system_prompt)

You are a data annotator trying to generate three search queries for the Document 2. The generated queries must be answerable from Document 2. Each generated query must be enclosed within the <q> and </q> tags as shown in Example. Only generate the query, do not generate the answer. An example is:
Example:AV Sync 
Use of an AV Receiver with HDMI for video may result in audio lagging behind video.  First try 
using the receiver AV sync settings to calibrate.  If this does not work, use the AV sync slider 
utility in Settings  > Display & sound > Advanced settings > Audio video sync to calibrate for 
any audio delay.  The AV sync slider allows you to advance audio by 1 second (in small 
increments of 10ms) to synchronize the audio and video. 
Note that this tool is effective only when SHIELD is connected to your AV Receiver over HDMI 
(i.e. audio/video over HDMI); it is not meant to be used when a headset is plugged into SHIELD 
Controller/SHIELD Remote or USB audio device or Bluetooth a

In [11]:
# compose the full prompt that will be sent to the LLM for each chunk of data
def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ') 
    message = message.strip() if do_strip else message
    message = "\nDocument 2:" + message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)

In [12]:
kb["prompt"] = kb["chunk_title_text_concat"].apply(get_prompt, system_prompt=system_prompt, chat_history=[])
kb

Unnamed: 0,_id,title,text,paragraph_id,chunk_text,chunk_id,chunk_title_text_concat,prompt
0,MED-5212,Surgical punctal occlusion with a high heat-en...,PURPOSE: To report the rate of recanalization ...,3458,PURPOSE: To report the rate of recanalization ...,0,Surgical punctal occlusion with a high heat-en...,<s>[INST] <<SYS>>\nYou are a data annotator tr...
1,MED-2117,Diet in acne: further evidence for the role of...,Recent evidence underlines the role of Western...,1116,Recent evidence underlines the role of Western...,0,Diet in acne: further evidence for the role of...,<s>[INST] <<SYS>>\nYou are a data annotator tr...
2,MED-5243,Coffee consumption and risk of fractures: a sy...,PURPOSE: The data on the association between c...,3488,PURPOSE: The data on the association between c...,0,Coffee consumption and risk of fractures: a sy...,<s>[INST] <<SYS>>\nYou are a data annotator tr...
3,MED-1676,Berries reduce postprandial insulin responses ...,Starch in white wheat bread (WB) induces high ...,759,Starch in white wheat bread (WB) induces high ...,0,Berries reduce postprandial insulin responses ...,<s>[INST] <<SYS>>\nYou are a data annotator tr...
4,MED-4637,Fibre and bowel transit times.,1. Bowel transit time has been investigated in...,2927,1. Bowel transit time has been investigated in...,0,Fibre and bowel transit times.\n1. Bowel trans...,<s>[INST] <<SYS>>\nYou are a data annotator tr...
...,...,...,...,...,...,...,...,...
96,MED-4565,The clinical significance of nasal irrigation ...,OBJECTIVES/HYPOTHESIS: This study aimed to ass...,2873,OBJECTIVES/HYPOTHESIS: This study aimed to ass...,0,The clinical significance of nasal irrigation ...,<s>[INST] <<SYS>>\nYou are a data annotator tr...
97,MED-1873,Dietary saturated fat intake is negatively ass...,Research finding on the composition of macronu...,929,Research finding on the composition of macronu...,0,Dietary saturated fat intake is negatively ass...,<s>[INST] <<SYS>>\nYou are a data annotator tr...
98,MED-4853,"Effects of a very low-fat, vegan diet in subje...",OBJECTIVE: To demonstrate the effects of a ver...,3122,OBJECTIVE: To demonstrate the effects of a ver...,0,"Effects of a very low-fat, vegan diet in subje...",<s>[INST] <<SYS>>\nYou are a data annotator tr...
99,MED-3496,Pro-inflammatory NF-κB and early growth respon...,The widely used food additive carrageenan (CGN...,2025,The widely used food additive carrageenan (CGN...,0,Pro-inflammatory NF-κB and early growth respon...,<s>[INST] <<SYS>>\nYou are a data annotator tr...


### 3. Synthetic Data Generation

Now we'll use [Nemotron-4-340B-Instruct](https://build.nvidia.com/nvidia/nemotron-4-340b-instruct) from NVIDIA AI Endpoints (www.build.nvidia.com) to generate synthetic data. Make sure you have a valid API key stored in the environment variable NVIDIA_API_KEY. If you don't have one yet, generate it by following the link in the setup instructions above.

The NVIDIA AI endpoint follows the OpenAI API schema, so we use the `AsyncOpenAI()` client to send many requests to the server asynchronously.

In [13]:
texts = kb["prompt"].tolist()
len(texts)

101

In [14]:
from openai import AsyncOpenAI
import asyncio
import nest_asyncio

nest_asyncio.apply()

# If you are using a self-hosted NIM or any other API endpoint, modify base_url and other relevant parameters here.
llm_client = AsyncOpenAI(
    base_url = "https://integrate.api.nvidia.com/v1",
    api_key = os.environ["NVIDIA_API_KEY"]
)

In [15]:
async def generate_response(client, prompt):
    try:
        response = await client.chat.completions.create(
            model="nvidia/nemotron-4-340b-instruct", # specify which model to use
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            top_p=0.7,
            max_tokens=1024
        )

        if hasattr(response, 'choices') and len(response.choices) > 0:
            return response.choices[0].message.content
            
    except Exception as e:
        return f"Error occurred: {str(e)}"
    

async def generate_batch_response(client, all_prompts):
    tasks = [generate_response(client, prompt) for prompt in all_prompts]
    results_list = await asyncio.gather(*tasks)
    return results_list
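`asyncio.gather` starts every request at once, which can run into rate limits when there are many prompts. One common way to bound the number of requests in flight is an `asyncio.Semaphore`. The sketch below shows the idea with a stand-in worker. `gather_with_limit` and `fake_llm_call` are hypothetical names, and a suitable `max_concurrency` depends on your endpoint's limits.

```python
import asyncio

async def gather_with_limit(worker, items, max_concurrency=8):
    """Run `worker(item)` for every item with at most `max_concurrency` calls in flight."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run_one(item):
        async with semaphore:
            return await worker(item)

    # gather preserves input order, so results line up with `items`.
    return await asyncio.gather(*(run_one(item) for item in items))

async def fake_llm_call(prompt):
    """Stand-in for a real API call: sleeps briefly and returns the prompt length."""
    await asyncio.sleep(0.01)
    return len(prompt)

results = asyncio.run(gather_with_limit(fake_llm_call, ["a", "bb", "ccc"], max_concurrency=2))
```

In this notebook the worker could be `lambda prompt: generate_response(llm_client, prompt)`, awaited directly instead of wrapped in `asyncio.run`.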

In [16]:
# test to see that the API endpoint is responding
result = await generate_response(llm_client, texts[0])
print(result)

1. <q>What type of device was used for punctal occlusion surgery in patients with severe dry eye disease and recurrent punctal plug extrusion?</q>
2. <q>How many puncta out of the 70 that underwent thermal cautery recanalized after the surgery?</q>
3. <q>What were the improvements observed in patients with severe dry eye disease three months after punctal occlusion surgery with a high heat-energy-releasing cautery device?</q>



In [17]:
# this could take a while depending on the number of LLM calls
generations = await generate_batch_response(llm_client, texts)

In [18]:
kb["generated_text"] = generations

In [19]:
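# Some requests may be dropped or fail for various reasons. Retry them once here (workaround).
# generate_response returns None when the reply has no choices and an
# "Error occurred: ..." string when the call raises, so check for both.
failed = kb["generated_text"].isna() | kb["generated_text"].astype(str).str.startswith("Error occurred:")
print("Requests to retry: " + str(failed.sum()))
for idx in kb[failed].index.tolist():
    kb.loc[idx, "generated_text"] = await generate_response(llm_client, kb.loc[idx, "prompt"])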

Requests to retry: 0


### 4. Parsing Generations

We'll do some simple text parsing to extract the generated queries, then store them as individual entries in the dataset.

In [20]:
def extract_questions_from_generations(kb):
    paragraph_id_question = []
    for row in kb.to_dict(orient='records'):
        paragraph_id = row["paragraph_id"]
        title = row["title"]
        text = row["chunk_text"]
        chunk_id = row["chunk_id"]
        questions = re.findall(r'<q>(.+?)</q>', row["generated_text"])
        print(questions)
        print("-"*200)
        paragraph_id_question += [(paragraph_id, chunk_id, question) for question in questions]
    return pd.DataFrame(paragraph_id_question, columns=["paragraph_id", "chunk_id", "chunk_question"])
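The extraction step can be tried in isolation on a sample generation. This sketch uses the same `<q>...</q>` pattern and adds two things that are assumptions, not part of the notebook's function: `re.DOTALL`, so a query spanning a line break is still captured, and a guard for non-string input such as `None`.

```python
import re

QUERY_PATTERN = re.compile(r"<q>(.+?)</q>", flags=re.DOTALL)

def extract_queries(generated_text):
    """Return the queries wrapped in <q>...</q> tags, or an empty list for missing output."""
    if not isinstance(generated_text, str):
        return []  # e.g. None when a request returned nothing
    return [q.strip() for q in QUERY_PATTERN.findall(generated_text)]

sample = "1. <q>What is AV sync?</q>\n2. <q>How do I fix\noverscan?</q>\nSome trailing text"
queries_found = extract_queries(sample)
print(queries_found)
```

Text outside the tags, such as the numbering or any commentary the model adds, is ignored.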

In [21]:
extracted_questions = extract_questions_from_generations(kb)

['What type of device was used for punctal occlusion surgery in patients with severe dry eye disease and recurrent punctal plug extrusion?', 'How many puncta out of the 70 that underwent thermal cautery recanalized after the surgery?', 'What were the improvements observed in patients with severe dry eye disease three months after punctal occlusion surgery with a high heat-energy-releasing cautery device?']
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
['What type of diet has been linked to the absence of acne in certain populations, and what are its key characteristics?', 'Which specific components of Western diet have been associated with the pathogenesis and aggravation of acne?', 'How does the consumption of milk and dairy products contribute to the development of acne, and what is the role of the mTORC1 pathway in this process?']

In [22]:
kb = kb.merge(extracted_questions, how="left", on=["paragraph_id", "chunk_id"])
kb

Unnamed: 0,_id,title,text,paragraph_id,chunk_text,chunk_id,chunk_title_text_concat,prompt,generated_text,chunk_question
0,MED-5212,Surgical punctal occlusion with a high heat-en...,PURPOSE: To report the rate of recanalization ...,3458,PURPOSE: To report the rate of recanalization ...,0,Surgical punctal occlusion with a high heat-en...,<s>[INST] <<SYS>>\nYou are a data annotator tr...,1. <q>What type of device was used for punctal...,What type of device was used for punctal occlu...
1,MED-5212,Surgical punctal occlusion with a high heat-en...,PURPOSE: To report the rate of recanalization ...,3458,PURPOSE: To report the rate of recanalization ...,0,Surgical punctal occlusion with a high heat-en...,<s>[INST] <<SYS>>\nYou are a data annotator tr...,1. <q>What type of device was used for punctal...,How many puncta out of the 70 that underwent t...
2,MED-5212,Surgical punctal occlusion with a high heat-en...,PURPOSE: To report the rate of recanalization ...,3458,PURPOSE: To report the rate of recanalization ...,0,Surgical punctal occlusion with a high heat-en...,<s>[INST] <<SYS>>\nYou are a data annotator tr...,1. <q>What type of device was used for punctal...,What were the improvements observed in patient...
3,MED-2117,Diet in acne: further evidence for the role of...,Recent evidence underlines the role of Western...,1116,Recent evidence underlines the role of Western...,0,Diet in acne: further evidence for the role of...,<s>[INST] <<SYS>>\nYou are a data annotator tr...,1. <q>What type of diet has been linked to the...,What type of diet has been linked to the absen...
4,MED-2117,Diet in acne: further evidence for the role of...,Recent evidence underlines the role of Western...,1116,Recent evidence underlines the role of Western...,0,Diet in acne: further evidence for the role of...,<s>[INST] <<SYS>>\nYou are a data annotator tr...,1. <q>What type of diet has been linked to the...,Which specific components of Western diet have...
...,...,...,...,...,...,...,...,...,...,...
298,MED-3496,Pro-inflammatory NF-κB and early growth respon...,The widely used food additive carrageenan (CGN...,2025,The widely used food additive carrageenan (CGN...,0,Pro-inflammatory NF-κB and early growth respon...,<s>[INST] <<SYS>>\nYou are a data annotator tr...,1. <q>Which food additive has been linked to i...,How does the suppression of NF-κB or EGR-1 imp...
299,MED-3496,Pro-inflammatory NF-κB and early growth respon...,The widely used food additive carrageenan (CGN...,2025,The widely used food additive carrageenan (CGN...,0,Pro-inflammatory NF-κB and early growth respon...,<s>[INST] <<SYS>>\nYou are a data annotator tr...,1. <q>Which food additive has been linked to i...,In the context of carrageenan-induced intestin...
300,MED-3752,Questionnaire survey on use of placebo,Objectives To gauge the frequency and circumst...,2248,Objectives To gauge the frequency and circumst...,0,Questionnaire survey on use of placebo\nObject...,<s>[INST] <<SYS>>\nYou are a data annotator tr...,1. <q>What percentage of the surveyed healthca...,What percentage of the surveyed healthcare pro...
301,MED-3752,Questionnaire survey on use of placebo,Objectives To gauge the frequency and circumst...,2248,Objectives To gauge the frequency and circumst...,0,Questionnaire survey on use of placebo\nObject...,<s>[INST] <<SYS>>\nYou are a data annotator tr...,1. <q>What percentage of the surveyed healthca...,How did the majority of placebo-prescribing pr...


In [23]:
qa_pairs = kb[["chunk_question", "chunk_title_text_concat", "chunk_id", "paragraph_id"]]
qa_pairs.columns = ["question", "positive_chunk", "positive_chunk_id", "paragraph_id"]
qa_pairs

Unnamed: 0,question,positive_chunk,positive_chunk_id,paragraph_id
0,What type of device was used for punctal occlu...,Surgical punctal occlusion with a high heat-en...,0,3458
1,How many puncta out of the 70 that underwent t...,Surgical punctal occlusion with a high heat-en...,0,3458
2,What were the improvements observed in patient...,Surgical punctal occlusion with a high heat-en...,0,3458
3,What type of diet has been linked to the absen...,Diet in acne: further evidence for the role of...,0,1116
4,Which specific components of Western diet have...,Diet in acne: further evidence for the role of...,0,1116
...,...,...,...,...
298,How does the suppression of NF-κB or EGR-1 imp...,Pro-inflammatory NF-κB and early growth respon...,0,2025
299,In the context of carrageenan-induced intestin...,Pro-inflammatory NF-κB and early growth respon...,0,2025
300,What percentage of the surveyed healthcare pro...,Questionnaire survey on use of placebo\nObject...,0,2248
301,How did the majority of placebo-prescribing pr...,Questionnaire survey on use of placebo\nObject...,0,2248


In [24]:
# Optionally save the data here
#GENERATIONS_MODEL_NAME_OR_PATH = "nvidia/nemotron-4-340b-instruct"
#GENERATIONS_SAVE_DIR = "/workspace/files/data"
#GENERATIONS_SAVE_FILENAME =  f"qa_pairs_{GENERATIONS_MODEL_NAME_OR_PATH}_num_questions_{len(qa_pairs)}_{DOMAIN}"
#GENERATIONS_SAVE_FILENAME = re.sub(r'\W+', '_', GENERATIONS_SAVE_FILENAME)
#GENERATIONS_SAVE_PATH = os.path.join(GENERATIONS_SAVE_DIR, f"{GENERATIONS_SAVE_FILENAME}.csv")

In [25]:
#qa_pairs.to_csv(GENERATIONS_SAVE_PATH, index=None)
#print(f"Generated QA Pairs saved to {GENERATIONS_SAVE_PATH}")

## Mining Hard Negatives

Hard negative mining means selecting negative examples that are difficult for the model to tell apart from the positives. Rather than sampling negatives at random, which yields easy negatives, we mine for passages that look similar to the query but are not its positive chunk.

Because these negatives are not obvious to the model during training, they provide a more useful training signal.

However, hard negative mining is more likely to produce false negatives, that is, passages labeled as negative that actually answer the query. To reduce this risk, we set a safety `margin`. The margin is a hyperparameter that you can tune if too many false negatives are generated. For instance, a larger corpus is more likely to yield false negatives than a smaller one, because the chance of finding another positive passage increases. In such cases a lower `margin` value may be more helpful.
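As an illustration of how a margin can filter out likely false negatives, the sketch below keeps only passages whose similarity to the query falls below `margin` times the positive passage's similarity, then takes the highest-scoring ones. The thresholding rule, the default of 0.95, and the function name are assumptions for this example. The mining code used later in the tutorial may differ.

```python
def mine_hard_negatives(scores, positive_id, margin=0.95, num_negatives=5):
    """Pick the highest-scoring passages that still fall below margin * positive score.

    scores: dict mapping passage id -> similarity between the query and that passage.
    """
    threshold = margin * scores[positive_id]
    candidates = [
        (pid, score) for pid, score in scores.items()
        if pid != positive_id and score < threshold
    ]
    # Hardest negatives first: highest similarity among the allowed candidates.
    candidates.sort(key=lambda item: item[1], reverse=True)
    return [pid for pid, _ in candidates[:num_negatives]]

# Made-up similarity scores for one query; "p0" is its positive passage.
scores = {"p0": 0.80, "p1": 0.79, "p2": 0.60, "p3": 0.10, "p4": 0.55}
negatives = mine_hard_negatives(scores, positive_id="p0", margin=0.95, num_negatives=2)
print(negatives)  # "p1" is excluded because it scores too close to the positive
```

Lowering `margin` excludes more near-positive passages, trading some hardness for a lower risk of false negatives.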

#### NV-EmbedQA-E5-V5
To do hard negative mining, we'll need to create embeddings for all of our text chunks using the [NV-EmbedQA-E5-V5](https://build.nvidia.com/nvidia/nv-embedqa-e5-v5) model from www.build.nvidia.com. You can reuse the same NVIDIA_API_KEY as before. This is also the embedding model we will fine-tune in the next part of this tutorial. 

Since the NV-EmbedQA-E5-V5 model is quite small, you can also easily host it as a self-deployed NIM Docker container by following the instructions [here](https://build.nvidia.com/nvidia/nv-embedqa-e5-v5?snippet_tab=Docker). If you already have the model weights downloaded in preparation for fine-tuning, you can also restore the model using NeMo Framework. To do that, simply copy the encode_text() function from Notebook 2 and use it here.

#### BEIR
BEIR is a heterogeneous benchmark containing diverse IR tasks. It also provides a common and easy-to-use framework for evaluating NLP-based retrieval models within the benchmark ([source](https://github.com/beir-cellar/beir)). First we'll do some basic processing so that our synthetic dataset matches the BEIR format. 

In [26]:
passages = OrderedDict()  # maps passage text -> integer passage id
queries = []
positive_passage_ids = []
for _, row in qa_pairs.iterrows():
    queries.append(row["question"])
    positive_passage_str = row["positive_chunk"]
    # Assign a new id the first time a passage is seen; reuse it afterwards
    if positive_passage_str not in passages:
        passages[positive_passage_str] = len(passages)
    positive_passage_ids.append(passages[positive_passage_str])

In [27]:
queries[12], positive_passage_ids[12]

('What is the relationship between bowel transit time and dietary fiber intake in vegetarians and non-vegetarians?',
 4)

### Generate Embeddings for all Queries and Positive Passages

In [28]:
embedding_client = AsyncOpenAI(
    base_url = "https://integrate.api.nvidia.com/v1",
    api_key = os.environ["NVIDIA_API_KEY"]
)

In [29]:
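async def encode_text(client, text, input_type):
    # Request an embedding for a single text; input_type is "query" or "passage".
    # API errors are left to propagate: returning an error string instead would
    # silently end up in the embedding matrix built in a later cell.
    response = await client.embeddings.create(
        input=[text],
        model="nvidia/nv-embedqa-e5-v5",
        encoding_format="float",
        extra_body={"input_type": input_type, "truncate": "END"}
    )

    if not getattr(response, "data", None):
        raise RuntimeError(f"No embedding returned for text: {text[:50]!r}")
    return response.data[0].embedding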
    

async def batch_encode_text(client, all_texts, input_type):
    tasks = [encode_text(client, text, input_type) for text in all_texts]
    results_list = await asyncio.gather(*tasks)
    return results_list

In [30]:
query_embeddings = await batch_encode_text(embedding_client, [("query: "+query) for query in queries], "query")
passage_embeddings = await batch_encode_text(embedding_client, [("passage: "+passage) for passage in list(passages)], "passage")
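`batch_encode_text` starts one request per text all at once, which may run into rate limits on larger corpora. The following is a minimal sketch of capping the number of in-flight requests with `asyncio.Semaphore`. The names `bounded_gather`, `max_concurrency`, and `fake_encode` are hypothetical and introduced here only for illustration; `fake_encode` is a stand-in for a real embedding call.

```python
import asyncio

async def bounded_gather(encode_fn, texts, max_concurrency=8):
    """Run encode_fn over texts with at most max_concurrency calls in flight."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def guarded(text):
        # Wait for a free slot before issuing the request
        async with semaphore:
            return await encode_fn(text)

    # asyncio.gather returns results in the same order as the inputs
    return await asyncio.gather(*(guarded(text) for text in texts))

# Hypothetical stand-in for a real embedding request
async def fake_encode(text):
    await asyncio.sleep(0)
    return [float(len(text))]
```

In the notebook, this could be called as `await bounded_gather(lambda t: encode_text(embedding_client, t, "query"), texts)`; a suitable `max_concurrency` depends on the limits of your endpoint.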

### Find Hard Negatives Using Similarity Score

In [31]:
def hard_negative_mining(
        query_embeddings,
        passage_embeddings,
        batch_size,
        margin, 
        num_negs,
        query_positive_paragraph_idxs
):
    hard_negative_idxs = []
    num_batches = int(math.ceil(query_embeddings.shape[0] / batch_size))
    # Split the query embeddings into batches of given batch size
    for current_batch_idx in range(num_batches):
        start = (current_batch_idx)*batch_size
        end = (current_batch_idx+1)*(batch_size)
        batch_query_embeddings = query_embeddings[start:end]
        batch_query_positive_paragraph_idxs = query_positive_paragraph_idxs[start:end]
        
        # Find minimum query-positive_chunk similarity score for each query in a batch
        query_passage_pos_scores = np.matmul(batch_query_embeddings, passage_embeddings.T)

        min_pos_scores = []
        for query_id, row in enumerate(query_passage_pos_scores):
            min_value = float("inf")
            for query_positive_paragraph_idx in query_positive_paragraph_idxs[query_id+start]:
                min_value = min(min_value, row[query_positive_paragraph_idx])
            min_pos_scores.append(min_value)
        min_pos_scores = np.array(min_pos_scores)
            
        # For each query set minimum threshold as margin*minimum_batch_positive_score 
        mining_thresholds = min_pos_scores*margin
        
        # Filter out all chunks belonging to the same paragraph as positive passage OR those manually labelled as positives
        for query_idx, positive_paragraph_idxs in enumerate(batch_query_positive_paragraph_idxs):
            batch_query_idx = query_idx%batch_size
            query_passage_pos_scores[batch_query_idx][positive_paragraph_idxs] = -float("inf")
        
        # Filter out all chunks with score>mining_threshold
        for row_idx in range(query_passage_pos_scores.shape[0]):
            row = query_passage_pos_scores[row_idx]
            row[row>mining_thresholds[row_idx]] = -float("inf")
            
        # For each query get top_k hard negatives from all that remains
        for row in query_passage_pos_scores:
            top_k_hard_negative_idxs = np.argpartition(row, -num_negs)[-num_negs:]
            hard_negative_idxs.append(list(top_k_hard_negative_idxs))
            
    return hard_negative_idxs

In [32]:
# Here we set a margin of 0.95 to prevent false negatives and we mine 5 negative docs (num_negs)
query_embeddings = torch.tensor(query_embeddings).numpy()
passage_embeddings = torch.tensor(passage_embeddings).numpy()

positive_passage_ids_list = [[element] for element in positive_passage_ids]
hard_negative_idxs = hard_negative_mining(query_embeddings=query_embeddings, passage_embeddings=passage_embeddings, query_positive_paragraph_idxs=positive_passage_ids_list,
                    batch_size=32, num_negs=5, margin=0.95)
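One thing to watch for: `np.argpartition` always returns `num_negs` indices, so if fewer than `num_negs` passages survive the filtering for a query, some of the returned indices will point at masked (`-inf`) entries such as the positive itself. The sketch below shows one possible way to keep only valid candidates; `top_k_valid` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def top_k_valid(row, k):
    """Return up to k indices with the highest finite scores, hardest first."""
    valid = np.flatnonzero(np.isfinite(row))   # indices that were not masked out
    ordered = valid[np.argsort(-row[valid])]   # sort survivors by descending score
    return ordered[:k].tolist()

# Toy score row after filtering: only passages 1 and 3 remain eligible
row = np.array([-np.inf, 0.7, -np.inf, 0.4, -np.inf])
print(top_k_valid(row, 3))  # [1, 3]
```

With this approach a query may end up with fewer than `num_negs` negatives, which is usually preferable to training on a masked passage.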

The function above uses the similarity scores together with the `margin` variable to select hard negatives. For this example we generate 5 hard negatives per query, but you can change this number. The records are collected in a list with the following structure and are later written one per line to a `.jsonl` file: 

```
[
    {
        "query": "Query",
        "pos_doc": "Positive",
        "neg_doc": ["Negative_1", "Negative_2", ..., "Negative_n"]
    },
    {
        // Next data instance
    },
    ...,
    {
        // Subsequent data instance
    }
]
```

In [33]:
data = []
# Passage ids were assigned in insertion order, so each id doubles as a list index
passage_texts = list(passages)
for query_id, query in enumerate(queries):
    hard_negative_passages = [passage_texts[idx] for idx in hard_negative_idxs[query_id]]
    positive_passage = passage_texts[positive_passage_ids[query_id]]

    datapoint = {
        "query" : query,
        "pos_doc" : positive_passage,
        "neg_doc" : hard_negative_passages
    }
    data.append(datapoint)

In [34]:
print(len(data))
print(data[0])

303
{'query': 'What type of device was used for punctal occlusion surgery in patients with severe dry eye disease and recurrent punctal plug extrusion?', 'pos_doc': 'Surgical punctal occlusion with a high heat-energy releasing cautery device for severe dry eye with recurrent punctal plug extrusion.\nPURPOSE: To report the rate of recanalization and the efficacy of punctal occlusion surgery with a high heat-energy-releasing cautery device in patients with severe dry eye disease and recurrent punctal plug extrusion. DESIGN: Prospective, interventional case series. METHODS: Seventy puncta from 44 eyes of 28 dry eye patients underwent punctal occlusion with thermal cautery. All patients had a history of recurrent punctal plug extrusion. A high heat-energy-releasing thermal cautery device (Optemp II V; Alcon Japan) was used for punctal occlusion surgery. Symptom scores, best-corrected visual acuity, fluorescein staining score, rose bengal staining score, tear film break-up time, and Schirme

In [35]:
# Save data to JSONL file
OUTPUT_DATA_PATH = "/tmp/data/output_data.jsonl"
output_dir_path = os.path.dirname(OUTPUT_DATA_PATH)
# Create the output directory (and any missing parent directories) if needed
os.makedirs(output_dir_path, exist_ok=True)

print(f"Saving data to: {OUTPUT_DATA_PATH}")

with open(OUTPUT_DATA_PATH, "w") as f:
    for entry in data:
        f.write(json.dumps(entry) + '\n')

Saving data to: /tmp/data/output_data.jsonl
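Before moving on, it can be worth reading the file back and checking that every record has the expected keys. This is a minimal sketch; `load_jsonl` and `check_records` are hypothetical helpers, and the sample record is made up for illustration.

```python
import json

REQUIRED_KEYS = {"query", "pos_doc", "neg_doc"}

def load_jsonl(path):
    """Read a JSONL file into a list of dicts, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

def check_records(records):
    """Raise ValueError on malformed records; return the number checked."""
    for i, record in enumerate(records):
        missing = REQUIRED_KEYS - record.keys()
        if missing:
            raise ValueError(f"Record {i} is missing keys: {sorted(missing)}")
        if record["pos_doc"] in record["neg_doc"]:
            raise ValueError(f"Record {i} lists its positive as a negative")
    return len(records)

# Made-up record in the same shape as the entries written above
sample = [{"query": "q", "pos_doc": "p", "neg_doc": ["n1", "n2"]}]
print(check_records(sample))  # 1
```

In the notebook, `check_records(load_jsonl(OUTPUT_DATA_PATH))` would run the same checks on the saved file.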


Congratulations, you've now successfully generated a synthetic dataset for fine-tuning a text embedding model! In the next notebook you'll use the `.jsonl` file you've just generated to fine-tune NV-EmbedQA-E5-V5 using NeMo Framework. 