# Retrieval-Augmented Generation (RAG) Notebook

## Overview
This notebook demonstrates a Retrieval-Augmented Generation (RAG) pipeline for answering user queries. It combines dense-embedding retrieval over a medical knowledge source (a CSV of question-answer pairs) with a generative language model, so that each answer is generated from the entries retrieved as most relevant to the query.

## Features
- **Text Preprocessing**: Normalizes questions and queries (trims surrounding whitespace and lowercases) before encoding.
- **Knowledge Source Embedding**: Embeds the knowledge-source questions into a vector space with the Dense Passage Retrieval (DPR) context encoder; answers are not embedded and are returned as stored.
- **Similarity Search**: Retrieves the top-k most relevant questions for a user query by cosine similarity, using a FAISS inner-product index over L2-normalized embeddings.
- **Answer Generation**: Generates an answer from the retrieved context with an instruction-tuned large language model (`google/gemma-2-2b-it`).
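The similarity search rests on a standard identity: after L2 normalization, the inner product of two vectors equals their cosine similarity. That is why an inner-product index such as FAISS's `IndexFlatIP` can rank normalized DPR embeddings by cosine. A quick NumPy check of the identity, with made-up 4-dimensional vectors standing in for real embeddings (`l2_normalize` is a hypothetical helper, not part of the notebook):

```python
import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    # Divide each row by its Euclidean norm (like torch.nn.functional.normalize(p=2, dim=1))
    return x / np.linalg.norm(x, axis=1, keepdims=True)

# Toy "embeddings": 3 documents and 1 query in a 4-dimensional space
docs = np.array([[1.0, 2.0, 0.0, 1.0],
                 [0.0, 1.0, 3.0, 0.0],
                 [2.0, 0.0, 1.0, 1.0]])
query = np.array([[1.0, 1.0, 0.0, 1.0]])

# Cosine similarity computed directly from the raw vectors
cosine = (docs @ query.T).ravel() / (np.linalg.norm(docs, axis=1) * np.linalg.norm(query))

# Inner product of the normalized vectors (what an inner-product index computes)
inner = (l2_normalize(docs) @ l2_normalize(query).T).ravel()
assert np.allclose(cosine, inner)

# Top-k document indices by descending similarity (k=2, the default in retrieve_top_k)
top_k = np.argsort(-inner)[:2]
print(top_k, inner[top_k])
```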

## Steps
1. **Load Knowledge Source**: Load a CSV containing question-answer pairs.
2. **Text Preprocessing**: Clean and normalize text before encoding.
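3. **Embedding and Indexing**: Use the DPR context encoder to embed the questions, then add the L2-normalized vectors to a FAISS inner-product index for fast retrieval.
4. **Query Retrieval**: Given a user query, retrieve the top-k (default k=2) most relevant question-answer pairs. For multiple-choice queries, each numbered option is first turned into a sub-query ("What are the symptoms of <option>?") and retrieved separately.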
5. **Answer Generation**: Use the Gemma-2-2B model to generate an answer from the retrieved contexts.
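The steps above can be traced end to end in a toy version that runs without downloading any models. This is a sketch under stated assumptions, not the notebook's implementation: the DPR encoders are replaced by a bag-of-words embedding over the knowledge-source vocabulary, the knowledge source is three hypothetical rows, and a NumPy matrix product stands in for the FAISS index.

```python
import re
import numpy as np

def preprocess_text(text: str) -> str:
    # Same normalization as the notebook: trim whitespace, lowercase
    return text.strip().lower()

def tokenize(text: str) -> list[str]:
    return re.findall(r"[a-z0-9]+", preprocess_text(text))

# Step 1: a tiny in-memory knowledge source (hypothetical rows, not the real CSV)
knowledge = [
    {"Question": "What is fibromyalgia?", "Answer": "A chronic widespread pain condition."},
    {"Question": "What is a chalazion?", "Answer": "A small lump in the eyelid."},
    {"Question": "What is atrial fibrillation?", "Answer": "An irregular heart rhythm."},
]

# Stand-in encoder (assumption): bag-of-words over the questions' vocabulary,
# L2-normalized. It replaces DPR purely so the sketch runs offline.
vocab = {tok: i for i, tok in enumerate(sorted({t for row in knowledge for t in tokenize(row["Question"])}))}

def embed(text: str) -> np.ndarray:
    vec = np.zeros(len(vocab))
    for tok in tokenize(text):
        if tok in vocab:  # out-of-vocabulary query words are ignored
            vec[vocab[tok]] += 1.0
    norm = np.linalg.norm(vec)
    return vec / norm if norm else vec

# Steps 2-3: embed only the questions and stack them into a matrix acting as the index
index = np.stack([embed(row["Question"]) for row in knowledge])

# Step 4: score every row by inner product (cosine, since rows are normalized), keep top-k
def retrieve(query: str, k: int = 2):
    scores = index @ embed(query)
    top = np.argsort(-scores)[:k]
    return [(knowledge[i]["Question"], knowledge[i]["Answer"], float(scores[i])) for i in top]

hits = retrieve("What are the symptoms of atrial fibrillation?")
for question, answer, score in hits:
    print(f"{score:.3f}  {question} -> {answer}")
# Step 5 (not shown): the retrieved answers become the context in the generator's prompt.
```

Swapping `embed` for the DPR question/context encoders and `index` for a FAISS index gives back the structure of the notebook's pipeline.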

## Requirements
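- Python 3.x
- PyTorch
- Hugging Face Transformers
- FAISS (`faiss-cpu`)
- pandas
- `huggingface_hub`, plus a Hugging Face access token for an account with access to the gated `google/gemma-2-2b-it` model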

## How to Use
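1. Save your question-answer CSV as `knowledge-source.csv` in the same directory as the notebook, or change `knowledge_source_path` in `main()`. The file must contain `Question` and `Answer` columns.
2. Run the cells in order. Each code cell redefines the pipeline functions and runs its own example query.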
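`load_knowledge_source` reads the file with pandas and accesses `row['Question']` and `row['Answer']`, so the CSV needs those two column headers exactly. Below is a standard-library sketch of an up-front check. It assumes that rows with an empty field should be skipped, which the notebook itself does not do, and `check_knowledge_csv` is a hypothetical name:

```python
import csv
import io

REQUIRED_COLUMNS = {"Question", "Answer"}

def check_knowledge_csv(fileobj) -> list[dict]:
    """Read a question-answer CSV and fail early if the expected columns are missing."""
    reader = csv.DictReader(fileobj)
    missing = REQUIRED_COLUMNS - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"knowledge source is missing columns: {sorted(missing)}")
    # Skip rows where either field is empty (short rows yield None, hence the `or ""`),
    # since preprocess_text expects a non-empty string
    return [
        row for row in reader
        if (row["Question"] or "").strip() and (row["Answer"] or "").strip()
    ]

sample = io.StringIO(
    "Question,Answer\n"
    '"What is trichiasis?","Eyelashes that grow back toward the eye."\n'
    ",Orphan answer with no question\n"
)
rows = check_knowledge_csv(sample)
print(len(rows), rows[0]["Question"])
```

With a real file, the same function can be called as `check_knowledge_csv(open("knowledge-source.csv", newline=""))`.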


In [None]:
import warnings
warnings.filterwarnings('ignore')


In [None]:
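import os
from huggingface_hub import login

# Read the access token from the HF_TOKEN environment variable instead of hard-coding it;
# if the variable is unset, login() prompts for the token interactively
login(token=os.environ.get("HF_TOKEN"))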

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:
!pip install pandas transformers faiss-cpu sentence-transformers
!pip install openpyxl
!pip install huggingface_hub

Collecting faiss-cpu
  Downloading faiss_cpu-1.9.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.4 kB)
Collecting sentence-transformers
  Downloading sentence_transformers-3.3.1-py3-none-any.whl.metadata (10 kB)
Downloading faiss_cpu-1.9.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (27.5 MB)
Downloading sentence_transformers-3.3.1-py3-none-any.whl (268 kB)
Installing collected packages: faiss-cpu, sentence-transformers
Successfully installed faiss-cpu-1.9.0.post1 sentence-transformers-3.3.1


In [None]:
import re

import pandas as pd
import faiss
import torch
from transformers import (
    DPRContextEncoder,
    DPRQuestionEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoderTokenizer
)

# Check if a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing Function
def preprocess_text(text):
    return text.strip().lower()

# Load the knowledge source
def load_knowledge_source(file_path):
    df = pd.read_csv(file_path)
    questions_answers = []
    for _, row in df.iterrows():
        questions_answers.append({
            "Question": preprocess_text(row['Question']),  # Apply preprocessing
            "Answer": row['Answer']
        })
    return questions_answers

# Embed and index the knowledge source
def embed_and_index_knowledge(questions_answers):
    context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(device)
    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    embeddings = []
    max_length = 256

    for qa in questions_answers:
        text = preprocess_text(f"Question: {qa['Question']}")  # Only include the question
        inputs = context_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(device)

        with torch.no_grad():
            embedding = context_encoder(**inputs).pooler_output
            normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
            embeddings.append(normalized_embedding.cpu())

    embeddings = torch.cat(embeddings).numpy()
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)  # Inner product on L2-normalized vectors equals cosine similarity
    index.add(embeddings)

    return context_encoder, context_tokenizer, index

# Retrieve the top-k results
def retrieve_top_k(question_encoder, question_tokenizer, index, query, questions_answers, k=2):
    query = preprocess_text(query)  # Preprocess query

    # Tokenize the query
    inputs = question_tokenizer(
        query,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256
    ).to(device)

    with torch.no_grad():
        query_embedding = question_encoder(**inputs).pooler_output
        normalized_query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=1).cpu().numpy()

    # Retrieve the top k most relevant questions
    distances, indices = index.search(normalized_query_embedding, k)

    # Get the corresponding questions and answers
    retrieved_data = [
        {"Question": questions_answers[i]['Question'], "Answer": questions_answers[i]['Answer']}
        for i in indices[0]
    ]
    return retrieved_data, distances[0]


def extract_diseases_from_query(query):
    """Extract disease options from the query."""
    # Match a digit followed by ')' and whitespace, capturing the rest of that line
    pattern = r"\d\)\s([^\n]+)"
    diseases = re.findall(pattern, query)
    return diseases

def generate_sub_queries(diseases):
    """Generate sub-queries for each disease."""
    sub_queries = [f"What are the symptoms of {disease}?" for disease in diseases]
    return sub_queries

def retrieve_for_sub_queries(question_encoder, question_tokenizer, index, sub_queries, questions_answers):
    """Retrieve answers for each sub-query."""
    results = {}
    for sub_query in sub_queries:
        retrieved_data, distances = retrieve_top_k(
            question_encoder,
            question_tokenizer,
            index,
            sub_query,
            questions_answers
        )
        results[sub_query] = retrieved_data
    return results

def main():
    # Load the question encoder model and tokenizer
    question_encoder = DPRQuestionEncoder.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    ).to(device)
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    )

    # Load the knowledge source (CSV with question-answer pairs)
    knowledge_source_path = "knowledge-source.csv"
    questions_answers = load_knowledge_source(knowledge_source_path)

    # Embed and index the questions
    context_encoder, context_tokenizer, index = embed_and_index_knowledge(questions_answers)

    # Example query
    query = """A patient is presenting with the following symptoms: Back pain, Ache all over, Neck pain. Based on these symptoms, which of the following diseases is the most likely diagnosis?
1) Fibromyalgia
2) Spondylitis
3) Polycystic ovarian syndrome (PCOS)
4) Breast infection (mastitis)"""

    print(f"\nQuery: {query}")

    # Extract diseases from the query
    diseases = extract_diseases_from_query(query)
    print(f"\nDiseases Identified: {diseases}")

    # Generate sub-queries
    sub_queries = generate_sub_queries(diseases)
    print(f"\nGenerated Sub-Queries: {sub_queries}")

    # Retrieve information for each sub-query
    retrieved_results = retrieve_for_sub_queries(
        question_encoder,
        question_tokenizer,
        index,
        sub_queries,
        questions_answers
    )

    # Display the results
    print("\nRetrieved Information for Each Disease:")
    for sub_query, data in retrieved_results.items():
        print(f"\nSub-Query: {sub_query}")
        for i, retrieved in enumerate(data, 1):
            print(f"  {i}. Question: {retrieved['Question']}, Answer: {retrieved['Answer'][:200]}...")
        print("-" * 80)

if __name__ == "__main__":
    main()


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the


Query: A patient is presenting with the following symptoms: Back pain, Ache all over, Neck pain. Based on these symptoms, which of the following diseases is the most likely diagnosis?
1) Fibromyalgia
2) Spondylitis
3) Polycystic ovarian syndrome (PCOS)
4) Breast infection (mastitis)

Diseases Identified: ['Fibromyalgia', 'Spondylitis', 'Polycystic ovarian syndrome (PCOS)', 'Breast infection (mastitis)']

Generated Sub-Queries: ['What are the symptoms of Fibromyalgia?', 'What are the symptoms of Spondylitis?', 'What are the symptoms of Polycystic ovarian syndrome (PCOS)?', 'What are the symptoms of Breast infection (mastitis)?']

Retrieved Information for Each Disease:

Sub-Query: What are the symptoms of Fibromyalgia?
  1. Question: what is fibromyalgia?, Answer: Fibromyalgia (FM or FMS) is characterised by chronic widespread pain and allodynia (a heightened and painful response to pressure). Its exact cause is unknown but is believed to involve psychological,...
  2. Question: what a
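`extract_diseases_from_query` uses the unanchored pattern `\d\)\s([^\n]+)`, which works for the example query but can also fire in the middle of a sentence and keeps trailing spaces. A sketch of a line-anchored alternative (an assumption about what counts as an option: a line that starts with a number and `)`); `extract_options` and `to_sub_queries` are hypothetical names:

```python
import re

# The notebook's pattern, kept for comparison
LOOSE_PATTERN = r"\d\)\s([^\n]+)"

# Anchored variant: the option must start its own line; [ \t] (not \s) keeps the
# match from running across newlines, and trailing spaces are left out of the capture
OPTION_PATTERN = re.compile(r"^[ \t]*\d+\)[ \t]*(.+?)[ \t]*$", re.MULTILINE)

def extract_options(query: str) -> list[str]:
    return OPTION_PATTERN.findall(query)

def to_sub_queries(options: list[str]) -> list[str]:
    return [f"What are the symptoms of {option}?" for option in options]

query = """A patient reports back pain and neck pain.
Pain score (scale 0-10) is 7.
Which diagnosis is most likely?
1) Fibromyalgia
2) Spondylitis
3) Atrial fibrillation """

loose = re.findall(LOOSE_PATTERN, query)   # also picks up "is 7." from "0-10) is 7."
options = extract_options(query)
print(loose)
print(options)
print(to_sub_queries(options))
```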

In [None]:
import re
import pandas as pd
import faiss
import torch
from transformers import (
    DPRContextEncoder,
    DPRQuestionEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoderTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM
)

# Check if a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing Function
def preprocess_text(text):
    return text.strip().lower()

# Load the knowledge source
def load_knowledge_source(file_path):
    df = pd.read_csv(file_path)
    questions_answers = []
    for _, row in df.iterrows():
        questions_answers.append({
            "Question": preprocess_text(row['Question']),  # Apply preprocessing
            "Answer": row['Answer']
        })
    return questions_answers

# Embed and index the knowledge source
def embed_and_index_knowledge(questions_answers):
    context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(device)
    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    embeddings = []
    max_length = 256

    for qa in questions_answers:
        text = preprocess_text(f"Question: {qa['Question']}")  # Only include the question
        inputs = context_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(device)

        with torch.no_grad():
            embedding = context_encoder(**inputs).pooler_output
            normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
            embeddings.append(normalized_embedding.cpu())

    embeddings = torch.cat(embeddings).numpy()
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)  # Inner product on L2-normalized vectors equals cosine similarity
    index.add(embeddings)

    return context_encoder, context_tokenizer, index

# Retrieve the top-k results
def retrieve_top_k(question_encoder, question_tokenizer, index, query, questions_answers, k=2):
    query = preprocess_text(query)  # Preprocess query

    # Tokenize the query
    inputs = question_tokenizer(
        query,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256
    ).to(device)

    with torch.no_grad():
        query_embedding = question_encoder(**inputs).pooler_output
        normalized_query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=1).cpu().numpy()

    # Retrieve the top k most relevant questions
    distances, indices = index.search(normalized_query_embedding, k)

    # Get the corresponding questions and answers
    retrieved_data = [
        {"Question": questions_answers[i]['Question'], "Answer": questions_answers[i]['Answer']}
        for i in indices[0]
    ]
    return retrieved_data, distances[0]

# Extract diseases from query
def extract_diseases_from_query(query):
    pattern = r"\d\)\s([^\n]+)"
    diseases = re.findall(pattern, query)
    return diseases

# Generate sub-queries for diseases
def generate_sub_queries(diseases):
    sub_queries = [f"What are the symptoms of {disease}?" for disease in diseases]
    return sub_queries

# Retrieve information for sub-queries
def retrieve_for_sub_queries(question_encoder, question_tokenizer, index, sub_queries, questions_answers):
    results = {}
    for sub_query in sub_queries:
        retrieved_data, distances = retrieve_top_k(
            question_encoder,
            question_tokenizer,
            index,
            sub_query,
            questions_answers
        )
        results[sub_query] = retrieved_data
    return results

# Generate an answer based on retrieved contexts
def generate_answer_with_model(main_query, retrieved_results):
    # The access token saved by login() above is picked up automatically
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it").to(device)

    # Combine retrieved contexts
    context_summary = "\n".join([
        f"{i+1}. {retrieved['Answer']}" for i, retrieved in enumerate(sum(retrieved_results.values(), []))
    ])

    prompt = f"""
Below are some relevant documents that might help answer the question.
If the answer can be found in these documents, please provide it.

Retrieved Documents:
{context_summary}

Question: {main_query}

Answer: """

    # If the prompt exceeds 1024 tokens, right-side truncation (the usual default) cuts off the question
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)

    outputs = model.generate(
        **inputs,  # input_ids and attention_mask
        max_new_tokens=150,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        num_return_sequences=1
    )

    # generate() returns the prompt tokens followed by the new ones; decode only the new tokens
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

# Main function
def main():
    question_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base").to(device)
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

    knowledge_source_path = "knowledge-source.csv"
    questions_answers = load_knowledge_source(knowledge_source_path)
    context_encoder, context_tokenizer, index = embed_and_index_knowledge(questions_answers)

    query = """A patient is presenting with the following symptoms: Back pain, Ache all over, Neck pain. Based on these symptoms, which of the following diseases is the most likely diagnosis?
1) Fibromyalgia
2) Chalazion
3) Polycystic ovarian syndrome (PCOS)
4) Breast infection (mastitis)"""

    diseases = extract_diseases_from_query(query)
    sub_queries = generate_sub_queries(diseases)
    retrieved_results = retrieve_for_sub_queries(
        question_encoder,
        question_tokenizer,
        index,
        sub_queries,
        questions_answers
    )

    answer = generate_answer_with_model(query, retrieved_results)
    print(f"\nQuery: {query}")
    print(f"\nAnswer: {answer}")

if __name__ == "__main__":
    main()


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the



Query: A patient is presenting with the following symptoms: Back pain, Ache all over, Neck pain. Based on these symptoms, which of the following diseases is the most likely diagnosis?
1) Fibromyalgia
2) Chalazion
3) Polycystic ovarian syndrome (PCOS)
4) Breast infection (mastitis)

Answer: **1) Fibromyalgia**

Explanation:**
The provided text mentions Fibromyalgia as a condition characterized by widespread pain, and allodynia (pain response to pressure). The patient's symptoms of back pain, ache all over, and neck pain are all consistent with the description of fibromyalgia. 


Let me know if you have other questions.
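`generate_answer_with_model` flattens the per-sub-query results with `sum(retrieved_results.values(), [])`, which rebuilds a list at every addition. A sketch of the same prompt assembly using `itertools.chain`, which additionally drops an answer that was retrieved by more than one sub-query (a refinement the notebook does not make). `build_prompt` is a hypothetical helper; the prompt wording is copied from the notebook, and the retrieved dictionaries below are made-up placeholders:

```python
from itertools import chain

def build_prompt(main_query: str, retrieved_results: dict[str, list[dict]]) -> str:
    # Flatten {sub_query: [hits]} into one stream, in sub-query order
    hits = chain.from_iterable(retrieved_results.values())
    # Keep the first occurrence of each answer; the same document can be
    # retrieved for more than one sub-query
    answers = list(dict.fromkeys(hit["Answer"] for hit in hits))
    context = "\n".join(f"{i}. {answer}" for i, answer in enumerate(answers, 1))
    return (
        "Below are some relevant documents that might help answer the question.\n"
        "If the answer can be found in these documents, please provide it.\n\n"
        f"Retrieved Documents:\n{context}\n\n"
        f"Question: {main_query}\n\n"
        "Answer: "
    )

retrieved = {
    "What are the symptoms of A?": [{"Answer": "doc about A"}, {"Answer": "shared doc"}],
    "What are the symptoms of B?": [{"Answer": "shared doc"}, {"Answer": "doc about B"}],
}
prompt = build_prompt("Which is more likely, A or B?", retrieved)
print(prompt)
```

Deduplicating also shortens the prompt, which matters here because the notebook truncates it at 1024 tokens.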


In [None]:
import re
import pandas as pd
import faiss
import torch
from transformers import (
    DPRContextEncoder,
    DPRQuestionEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoderTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM
)

# Check if a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing Function
def preprocess_text(text):
    return text.strip().lower()

# Load the knowledge source
def load_knowledge_source(file_path):
    df = pd.read_csv(file_path)
    questions_answers = []
    for _, row in df.iterrows():
        questions_answers.append({
            "Question": preprocess_text(row['Question']),  # Apply preprocessing
            "Answer": row['Answer']
        })
    return questions_answers

# Embed and index the knowledge source
def embed_and_index_knowledge(questions_answers):
    context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(device)
    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    embeddings = []
    max_length = 256

    for qa in questions_answers:
        text = preprocess_text(f"Question: {qa['Question']}")  # Only include the question
        inputs = context_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(device)

        with torch.no_grad():
            embedding = context_encoder(**inputs).pooler_output
            normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
            embeddings.append(normalized_embedding.cpu())

    embeddings = torch.cat(embeddings).numpy()
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)  # Inner product on L2-normalized vectors equals cosine similarity
    index.add(embeddings)

    return context_encoder, context_tokenizer, index

# Retrieve the top-k results
def retrieve_top_k(question_encoder, question_tokenizer, index, query, questions_answers, k=2):
    query = preprocess_text(query)  # Preprocess query

    # Tokenize the query
    inputs = question_tokenizer(
        query,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256
    ).to(device)

    with torch.no_grad():
        query_embedding = question_encoder(**inputs).pooler_output
        normalized_query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=1).cpu().numpy()

    # Retrieve the top k most relevant questions
    distances, indices = index.search(normalized_query_embedding, k)

    # Get the corresponding questions and answers
    retrieved_data = [
        {"Question": questions_answers[i]['Question'], "Answer": questions_answers[i]['Answer']}
        for i in indices[0]
    ]
    return retrieved_data, distances[0]

# Extract diseases from query
def extract_diseases_from_query(query):
    pattern = r"\d\)\s([^\n]+)"
    diseases = re.findall(pattern, query)
    return diseases

# Generate sub-queries for diseases
def generate_sub_queries(diseases):
    sub_queries = [f"What are the symptoms of {disease}?" for disease in diseases]
    return sub_queries

# Retrieve information for sub-queries
def retrieve_for_sub_queries(question_encoder, question_tokenizer, index, sub_queries, questions_answers):
    results = {}
    for sub_query in sub_queries:
        retrieved_data, distances = retrieve_top_k(
            question_encoder,
            question_tokenizer,
            index,
            sub_query,
            questions_answers
        )
        results[sub_query] = retrieved_data
    return results

# Generate an answer based on retrieved contexts
def generate_answer_with_model(main_query, retrieved_results):
    # The access token saved by login() above is picked up automatically
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it").to(device)

    # Combine retrieved contexts
    context_summary = "\n".join([
        f"{i+1}. {retrieved['Answer']}" for i, retrieved in enumerate(sum(retrieved_results.values(), []))
    ])

    prompt = f"""
Below are some relevant documents that might help answer the question.
If the answer can be found in these documents, please provide it.

Retrieved Documents:
{context_summary}

Question: {main_query}

Answer: """

    # If the prompt exceeds 1024 tokens, right-side truncation (the usual default) cuts off the question
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)

    outputs = model.generate(
        **inputs,  # input_ids and attention_mask
        max_new_tokens=150,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        num_return_sequences=1
    )

    # generate() returns the prompt tokens followed by the new ones; decode only the new tokens
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

# Main function
def main():
    question_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base").to(device)
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

    knowledge_source_path = "knowledge-source.csv"
    questions_answers = load_knowledge_source(knowledge_source_path)
    context_encoder, context_tokenizer, index = embed_and_index_knowledge(questions_answers)

    query = """A patient is presenting with the following symptoms: Shortness of breath, Sharp chest pain, Palpitations, Dizziness. Based on these symptoms, which of the following diseases is the most likely diagnosis?
1) Astigmatism
2) Pinworm infection
3) Cerebral palsy
4) Atrial fibrillation """

    diseases = extract_diseases_from_query(query)
    sub_queries = generate_sub_queries(diseases)
    retrieved_results = retrieve_for_sub_queries(
        question_encoder,
        question_tokenizer,
        index,
        sub_queries,
        questions_answers
    )

    answer = generate_answer_with_model(query, retrieved_results)
    print(f"\nQuery: {query}")
    print(f"\nAnswer: {answer}")

if __name__ == "__main__":
    main()


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the



Query: A patient is presenting with the following symptoms: Shortness of breath, Sharp chest pain, Palpitations, Dizziness. Based on these symptoms, which of the following diseases is the most likely diagnosis?
1) Astigmatism
2) Pinworm infection
3) Cerebral palsy
4) Atrial fibrillation 

Answer: **4) Atrial fibrillation** 

Explanation:** 

The provided symptoms of shortness of breath, sharp chest pain, palpitations, and dizziness are all consistent with atrial fibrillation. 

Here's why:

* **Shortness of breath:** This can be a sign of fluid buildup in the lungs, a potential complication of atrial fibrillation.
* **Sharp chest pain:**  Could be caused by the heart struggling to pump effectively due to atrial fibrillation. 
* **Palpitations:**  A common symptom of atrial fibrillation due to the irregular heart rhythm.
* **Dizziness:** Could be caused by the rapid, irregular heartbeat and blood flow changes. 


Let me know if you have any other questions.
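The original answer extraction keeps whatever follows the last `"Answer:"` in the decoded text, which silently drops part of the answer if the model itself writes `"Answer:"`. A string-level sketch of the difference, using a hypothetical completion; with `transformers`, the token-level equivalent of removing the prompt is `outputs[0][inputs["input_ids"].shape[1]:]`, because `generate()` on a decoder-only model returns the prompt tokens followed by the new ones:

```python
# Hypothetical prompt and model completion, for illustration only
prompt = (
    "Retrieved Documents:\n1. Atrial fibrillation is an irregular heart rhythm.\n\n"
    "Question: Which is most likely?\n\nAnswer: "
)
completion = "**4) Atrial fibrillation**\nAnswer: it matches the palpitations."
generated_text = prompt + completion

# Marker-based extraction (the original approach): keeps only text after the LAST "Answer:"
split_result = generated_text.split("Answer:")[-1].strip()

# Prefix-based extraction: remove the known prompt and keep the whole continuation
prefix_result = generated_text.removeprefix(prompt).strip()

print(repr(split_result))   # the diagnosis line is lost
print(repr(prefix_result))  # the full completion survives
```

Slicing token IDs is preferable to slicing strings in practice, since decoding the prompt tokens does not always reproduce the prompt string exactly.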


In [None]:
import re
import pandas as pd
import faiss
import torch
from transformers import (
    DPRContextEncoder,
    DPRQuestionEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoderTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM
)

# Check if a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing Function
def preprocess_text(text):
    return text.strip().lower()

# Load the knowledge source
def load_knowledge_source(file_path):
    df = pd.read_csv(file_path)
    questions_answers = []
    for _, row in df.iterrows():
        questions_answers.append({
            "Question": preprocess_text(row['Question']),  # Apply preprocessing
            "Answer": row['Answer']
        })
    return questions_answers

# Embed and index the knowledge source
def embed_and_index_knowledge(questions_answers):
    context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(device)
    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    embeddings = []
    max_length = 256

    for qa in questions_answers:
        text = preprocess_text(f"Question: {qa['Question']}")  # Only include the question
        inputs = context_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(device)

        with torch.no_grad():
            embedding = context_encoder(**inputs).pooler_output
            normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
            embeddings.append(normalized_embedding.cpu())

    embeddings = torch.cat(embeddings).numpy()
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)  # Inner product on L2-normalized vectors equals cosine similarity
    index.add(embeddings)

    return context_encoder, context_tokenizer, index

# Retrieve the top-k results
def retrieve_top_k(question_encoder, question_tokenizer, index, query, questions_answers, k=2):
    query = preprocess_text(query)  # Preprocess query

    # Tokenize the query
    inputs = question_tokenizer(
        query,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256
    ).to(device)

    with torch.no_grad():
        query_embedding = question_encoder(**inputs).pooler_output
        normalized_query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=1).cpu().numpy()

    # Retrieve the top k most relevant questions
    distances, indices = index.search(normalized_query_embedding, k)

    # Get the corresponding questions and answers
    retrieved_data = [
        {"Question": questions_answers[i]['Question'], "Answer": questions_answers[i]['Answer']}
        for i in indices[0]
    ]
    return retrieved_data, distances[0]

# Generate an answer using a language model
def generate_answer_with_model(question, retrieved_contexts):
    """
    Generate an answer based on the retrieved contexts and the question.

    Args:
        question (str): The original query.
        retrieved_contexts (list[dict]): Retrieved question-answer pairs; the 'Answer' fields of the first two are used as context.

    Returns:
        str: The generated answer.
    """
    # Initialize the language model and tokenizer
    # The access token saved by login() above is picked up automatically
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it").to(device)

    # Construct the input prompt with retrieved contexts
    prompt = f"""
Below are some relevant documents that might help answer the question.
If the answer can be found in these documents, please provide it.

Retrieved Documents:
1. {retrieved_contexts[0]['Answer']}
2. {retrieved_contexts[1]['Answer']}

Question: {question}

Answer: """

    # Tokenize the input prompt
    # (if it exceeds 1024 tokens, right-side truncation, the usual default, cuts off the question)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)

    # Generate the answer, passing the attention mask along with the input IDs
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        num_return_sequences=1
    )

    # generate() returns the prompt tokens followed by the new ones;
    # decode only the newly generated tokens to extract the answer
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

def main():
    # Load the question encoder model and tokenizer
    question_encoder = DPRQuestionEncoder.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    ).to(device)
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    )

    # Load the knowledge source (CSV with question-answer pairs)
    knowledge_source_path = "knowledge-source.csv"
    questions_answers = load_knowledge_source(knowledge_source_path)

    # Embed and index the questions
    context_encoder, context_tokenizer, index = embed_and_index_knowledge(questions_answers)

    # Example query
    query = "help me. what is Trichiasis?"

    # Retrieve information for the query
    retrieved_data, distances = retrieve_top_k(
        question_encoder,
        question_tokenizer,
        index,
        query,
        questions_answers
    )

    # Generate an answer using the retrieved data
    answer = generate_answer_with_model(query, retrieved_data[:2])  # Use top 2 retrieved results

    # Display the results
    print(f"\nQuery: {query}")
    print(f"\nGenerated Answer:\n{answer}")

if __name__ == "__main__":
    main()


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']


Query: help me. what is Trichiasis?

Generated Answer:
Trichiasis is a medical term for abnormally positioned eyelashes that grow back toward the eye, touching the cornea or conjunctiva. This can be caused by infection, inflammation, autoimmune conditions, congenital defects, eyelid agenesis and trauma such as burns or eyelid injury. 


**Explanation:**

The answer is directly provided in the document.
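
The answer extraction in `generate_answer_with_model` splits the decoded text on `"Answer:"` and keeps the last piece. That works for the outputs shown, but if the model writes `"Answer:"` again inside its reply, everything before the final marker is lost. For decoder-only models, `generate()` returns the prompt ids followed by the newly generated ids, so dropping the first `prompt_length` ids is a more robust alternative. The sketch below contrasts the two approaches using a toy word-level "tokenizer" (an assumption: real token ids and `tokenizer.decode` differ in detail) and a hypothetical model continuation; `answer_by_split` and `answer_by_slicing` are illustrative names.

```python
def decode(tokens):
    """Toy decoder (assumption): tokens are already words, so just join them."""
    return " ".join(tokens)

def answer_by_split(generated_tokens):
    """The notebook's approach: keep the text after the last 'Answer:' marker."""
    text = decode(generated_tokens)
    parts = text.split("Answer:")
    return parts[-1].strip() if len(parts) > 1 else text.strip()

def answer_by_slicing(generated_tokens, prompt_length):
    """Alternative: drop the prompt tokens and decode only what the model added."""
    return decode(generated_tokens[prompt_length:]).strip()

prompt = ["Question:", "What", "is", "trichiasis?", "Answer:"]
# Hypothetical continuation that repeats the "Answer:" marker mid-reply
continuation = ["Ingrown", "eyelashes.", "Answer:", "see", "a", "doctor."]
output = prompt + continuation  # mimics generate(): prompt ids, then new ids

print(answer_by_split(output))                 # see a doctor.
print(answer_by_slicing(output, len(prompt)))  # Ingrown eyelashes. Answer: see a doctor.
```

With the notebook's own variables, the slicing version would read `tokenizer.decode(outputs[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True)`.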


In [None]:
import re
import pandas as pd
import faiss
import torch
from transformers import (
    DPRContextEncoder,
    DPRQuestionEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoderTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM
)

# Check if a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing Function
def preprocess_text(text):
    return text.strip().lower()

# Load the knowledge source
def load_knowledge_source(file_path):
    df = pd.read_csv(file_path)
    questions_answers = []
    for _, row in df.iterrows():
        questions_answers.append({
            "Question": preprocess_text(row['Question']),  # Apply preprocessing
            "Answer": row['Answer']
        })
    return questions_answers

# Embed and index the knowledge source
def embed_and_index_knowledge(questions_answers):
    context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(device)
    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    embeddings = []
    max_length = 256

    for qa in questions_answers:
        text = preprocess_text(f"Question: {qa['Question']}")  # Only include the question
        inputs = context_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(device)

        with torch.no_grad():
            embedding = context_encoder(**inputs).pooler_output
            normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
            embeddings.append(normalized_embedding.cpu())

    embeddings = torch.cat(embeddings).numpy()
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)  # Inner product on L2-normalized vectors equals cosine similarity
    index.add(embeddings)

    return context_encoder, context_tokenizer, index

# Retrieve the top-k results
def retrieve_top_k(question_encoder, question_tokenizer, index, query, questions_answers, k=2):
    query = preprocess_text(query)  # Preprocess query

    # Tokenize the query
    inputs = question_tokenizer(
        query,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256
    ).to(device)

    with torch.no_grad():
        query_embedding = question_encoder(**inputs).pooler_output
        normalized_query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=1).cpu().numpy()

    # Retrieve the top k most relevant questions
    distances, indices = index.search(normalized_query_embedding, k)

    # Get the corresponding questions and answers
    retrieved_data = [
        {"Question": questions_answers[i]['Question'], "Answer": questions_answers[i]['Answer']}
        for i in indices[0]
    ]
    return retrieved_data, distances[0]

# Generate an answer using a language model
def generate_answer_with_model(question, retrieved_contexts):
    """
    Generate an answer based on the retrieved contexts and the question.

    Args:
        question (str): The original query.
        retrieved_contexts (list): Retrieved documents to use for answering the question.

    Returns:
        str: The generated answer.
    """
    # Initialize the language model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token='')
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", token='').to(device)

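    # Join every retrieved answer so the prompt works for any number of contexts
    # (indexing [0] and [1] directly raises IndexError when fewer than two are passed)
    documents = "\n".join(str(ctx['Answer']) for ctx in retrieved_contexts)

    # Construct the input prompt with retrieved contexts
    prompt = f"""
Below are some relevant documents that might help answer the question.
If the answer can be found in these documents, please provide it.


Documents:
{documents}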

Question: {question}

Answer:"""

    # Tokenize the input prompt
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)

    # Generate the answer
    outputs = model.generate(
        input_ids=input_ids["input_ids"],
        max_new_tokens=150,
        temperature=0.5,  # Lower temperature for more deterministic output
        top_p=0.9,
        do_sample=True,
        num_return_sequences=1
    )

    # Decode the output
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract the generated answer (after "Answer:")
    answer_parts = generated_text.split("Answer:")
    if len(answer_parts) > 1:
        return answer_parts[-1].strip()
    return generated_text.strip()


def main():
    # Load the question encoder model and tokenizer
    question_encoder = DPRQuestionEncoder.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    ).to(device)
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    )

    # Load the knowledge source (CSV with question-answer pairs)
    knowledge_source_path = "knowledge-source.csv"
    questions_answers = load_knowledge_source(knowledge_source_path)

    # Embed and index the questions
    context_encoder, context_tokenizer, index = embed_and_index_knowledge(questions_answers)

    # Example query
    query = "What are some of the common medications for Myoclonus?"

    # Retrieve information for the query
    retrieved_data, distances = retrieve_top_k(
        question_encoder,
        question_tokenizer,
        index,
        query,
        questions_answers
    )

    # Generate an answer using the retrieved data
    answer = generate_answer_with_model(query, retrieved_data[:2])  # Use top 2 retrieved results

    # Display the results
    print(f"\nQuery: {query}")
    print(f"\nGenerated Answer:\n{answer}")

if __name__ == "__main__":
    main()




Query: What are some of the common medications for Myoclonus?

Generated Answer:
**The following medications are commonly used to treat myoclonus:**

* **Clonazepam**
* **Levetiracetam**
* **Lamotrigine**
* **Oxazepam**
* **Gabapentin** 
* **Ethosuximide** 


**Explanation:**

Myoclonus is a sudden, involuntary muscle contraction.  These medications are generally used to treat various types of epilepsy, but they can also be effective in managing myoclonus.


**Please note:** This is not an exhaustive list, and the best treatment for myoclonus will vary depending on the underlying cause and individual patient factors. Always consult with a healthcare professional for diagnosis and treatment.
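
Two details of the retrieval step are easy to miss. First, `torch.nn.functional.normalize(..., p=2, dim=1)` makes every embedding unit-length, so the inner product that `faiss.IndexFlatIP` computes is exactly cosine similarity. Second, the `distances` returned by `retrieve_top_k` are therefore similarity scores, where larger means closer. The plain-Python sketch below reproduces that exhaustive search on toy 2-D vectors (assumed values, not DPR embeddings) so it runs without FAISS or a GPU; `flat_ip_search` is a hypothetical name.

```python
import math

def l2_normalize(vec):
    """Scale to unit L2 norm, as torch.nn.functional.normalize(p=2) does per row."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def inner_product(a, b):
    return sum(x * y for x, y in zip(a, b))

def cosine_similarity(a, b):
    return inner_product(a, b) / math.sqrt(inner_product(a, a) * inner_product(b, b))

def flat_ip_search(query, corpus, k=2):
    """Exhaustive inner-product search over normalized vectors.

    Returns (scores, indices) sorted from most to least similar, i.e. the
    same scores an exact inner-product index would report.
    """
    q = l2_normalize(query)
    scored = [(inner_product(q, l2_normalize(doc)), i) for i, doc in enumerate(corpus)]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    top = scored[:k]
    return [score for score, _ in top], [i for _, i in top]

# Toy 2-D "embeddings" (assumed values, not DPR outputs)
corpus = [[1.0, 0.0], [3.0, 4.0], [0.0, 2.0]]
query = [0.0, 1.0]

scores, indices = flat_ip_search(query, corpus, k=2)
print(indices)  # [2, 1]
print(scores)   # [1.0, 0.8]
# Inner product of normalized vectors matches cosine similarity of the raw vectors
print(math.isclose(cosine_similarity(query, corpus[1]), scores[1]))  # True
```

Because the vectors are normalized before indexing, vector length plays no role: `[0.0, 2.0]` scores a perfect 1.0 against `[0.0, 1.0]` even though their raw inner product is 2.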


In [None]:
import re
import pandas as pd
import faiss
import torch
from transformers import (
    DPRContextEncoder,
    DPRQuestionEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoderTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM
)

# Check if a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing Function
def preprocess_text(text):
    return text.strip().lower()

# Load the knowledge source
def load_knowledge_source(file_path):
    df = pd.read_csv(file_path)
    questions_answers = []
    for _, row in df.iterrows():
        questions_answers.append({
            "Question": preprocess_text(row['Question']),
            "Answer": row['Answer']
        })
    return questions_answers

# Embed and index the knowledge source
def embed_and_index_knowledge(questions_answers):
    context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(device)
    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    embeddings = []
    max_length = 256

    for qa in questions_answers:
        text = preprocess_text(f"Question: {qa['Question']}")
        inputs = context_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(device)

        with torch.no_grad():
            embedding = context_encoder(**inputs).pooler_output
            normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
            embeddings.append(normalized_embedding.cpu())

    embeddings = torch.cat(embeddings).numpy()
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)

    return context_encoder, context_tokenizer, index

# Retrieve the top-k results
def retrieve_top_k(question_encoder, question_tokenizer, index, query, questions_answers, k=2):
    query = preprocess_text(query)

    inputs = question_tokenizer(
        query,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256
    ).to(device)

    with torch.no_grad():
        query_embedding = question_encoder(**inputs).pooler_output
        normalized_query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=1).cpu().numpy()

    distances, indices = index.search(normalized_query_embedding, k)

    retrieved_data = [
        {"Question": questions_answers[i]['Question'], "Answer": questions_answers[i]['Answer']}
        for i in indices[0]
    ]
    return retrieved_data, distances[0]

# Extract the disease from the query
def extract_disease_from_query(query):
    pattern = r"which person is more likely to have (\w+)\??"
    match = re.search(pattern, query, re.IGNORECASE)
    return match.group(1).strip() if match else None

# Generate sub-query for disease symptoms
def generate_disease_symptoms_query(disease):
    return f"What are the symptoms of {disease}"

# Generate answer using the model
def generate_answer_with_model(main_query, symptoms_context):
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token='')
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", token='').to(device)

    prompt = f"""
Below is the relevant context to answer the question.

Retrieved Context:
{symptoms_context}

Question: {main_query}

Answer: """

    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)

    outputs = model.generate(
        input_ids=input_ids["input_ids"],
        max_new_tokens=150,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        num_return_sequences=1
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    answer_parts = generated_text.split("Answer:")
    if len(answer_parts) > 1:
        return answer_parts[-1].strip()
    return generated_text.strip()

# Main function
def main():
    question_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base").to(device)
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

    knowledge_source_path = "knowledge-source.csv"
    questions_answers = load_knowledge_source(knowledge_source_path)
    context_encoder, context_tokenizer, index = embed_and_index_knowledge(questions_answers)

    query = "Sofia has symptoms such as focal weakness, vaginal pain and pelvic pain. John has symptoms such as pain in eye, diminished vision and spots in vision. Which person is more likely to have Aphakia?"

    disease = extract_disease_from_query(query)
    if not disease:
        print("No disease found in query.")
        return

    disease_query = generate_disease_symptoms_query(disease)
    retrieved_data, _ = retrieve_top_k(question_encoder, question_tokenizer, index, disease_query, questions_answers)

    if not retrieved_data:
        print(f"No information found for disease: {disease}")
        return

    symptoms_context = retrieved_data[0]['Answer']
    answer = generate_answer_with_model(query, symptoms_context)

    print(f"Disease: {disease}")
    print(f"Symptoms of {disease}: {symptoms_context}")
    print(f"Answer: {answer}")

if __name__ == "__main__":
    main()



Disease: Aphakia
Symptoms of Aphakia: Diminished vision, Symptoms of eye, Pain in eye, Eye redness, Itchiness of eye, Spots or clouds in vision, Eye burns or stings, Mass on eyelid
Answer: **John is more likely to have Aphakia.**

**Explanation:**

Aphakia is a condition where the natural lens of the eye is absent or has become significantly impaired. The symptoms you've described are consistent with this condition. 

* **John's symptoms:** Pain in the eye, diminished vision, and spots or clouds in vision all point to a problem with the eye's structure or function.
* **Sofia's symptoms:** While her symptoms are concerning, they are more associated with issues like eye infections, inflammation, or other eye conditions.


**Important Note:**  This is an educated guess based on the provided information. It is crucial to remember that **diagnosing medical conditions requires a professional medical
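
For these comparison questions the model is asked to match each person's symptoms against the retrieved symptom list. A deterministic baseline makes that reasoning checkable. The hypothetical helpers below (not part of the original pipeline) count exact symptom-term overlaps, using the Aphakia symptom list retrieved above and the symptoms of the two people in the query. This is an assumed sanity check with clear limits: exact string matching misses near-matches such as "spots in vision" versus "spots or clouds in vision", and splitting on "and" would break a symptom name that contains that word.

```python
import re

def split_terms(text):
    """Split 'a, b and c' into {'a', 'b', 'c'} (lowercased, whitespace-stripped)."""
    parts = re.split(r",|\band\b", text.lower())
    return {part.strip() for part in parts if part.strip()}

def overlap_score(person_symptoms, disease_symptoms):
    """Number of a person's symptom terms that appear verbatim in the disease list."""
    return len(split_terms(person_symptoms) & split_terms(disease_symptoms))

def more_likely(people, disease_symptoms):
    """Name with the highest overlap; on a tie, max() keeps the first name listed."""
    return max(people, key=lambda name: overlap_score(people[name], disease_symptoms))

aphakia_symptoms = (
    "Diminished vision, Symptoms of eye, Pain in eye, Eye redness, "
    "Itchiness of eye, Spots or clouds in vision, Eye burns or stings, Mass on eyelid"
)
people = {
    "Sofia": "focal weakness, vaginal pain and pelvic pain",
    "John": "pain in eye, diminished vision and spots in vision",
}

print({name: overlap_score(s, aphakia_symptoms) for name, s in people.items()})
# {'Sofia': 0, 'John': 2}
print(more_likely(people, aphakia_symptoms))  # John
```

The baseline agrees with the generated answer here, which is one way to spot-check the LLM on cases where the retrieved context really is a symptom list.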


In [None]:
import re
import pandas as pd
import faiss
import torch
from transformers import (
    DPRContextEncoder,
    DPRQuestionEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoderTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM
)

# Check if a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing Function
def preprocess_text(text):
    return text.strip().lower()

# Load the knowledge source
def load_knowledge_source(file_path):
    df = pd.read_csv(file_path)
    questions_answers = []
    for _, row in df.iterrows():
        questions_answers.append({
            "Question": preprocess_text(row['Question']),
            "Answer": row['Answer']
        })
    return questions_answers

# Embed and index the knowledge source
def embed_and_index_knowledge(questions_answers):
    context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(device)
    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    embeddings = []
    max_length = 256

    for qa in questions_answers:
        text = preprocess_text(f"Question: {qa['Question']}")
        inputs = context_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(device)

        with torch.no_grad():
            embedding = context_encoder(**inputs).pooler_output
            normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
            embeddings.append(normalized_embedding.cpu())

    embeddings = torch.cat(embeddings).numpy()
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)

    return context_encoder, context_tokenizer, index

# Retrieve the top-k results
def retrieve_top_k(question_encoder, question_tokenizer, index, query, questions_answers, k=2):
    query = preprocess_text(query)

    inputs = question_tokenizer(
        query,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256
    ).to(device)

    with torch.no_grad():
        query_embedding = question_encoder(**inputs).pooler_output
        normalized_query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=1).cpu().numpy()

    distances, indices = index.search(normalized_query_embedding, k)

    retrieved_data = [
        {"Question": questions_answers[i]['Question'], "Answer": questions_answers[i]['Answer']}
        for i in indices[0]
    ]
    return retrieved_data, distances[0]

# Extract the disease from the query
def extract_disease_from_query(query):
    pattern = r"which person is more likely to have ([\w\s]+)\??"
    match = re.search(pattern, query, re.IGNORECASE)
    return match.group(1).strip() if match else None

# Generate sub-query for disease symptoms
def generate_disease_symptoms_query(disease):
    return f"What are the symptoms of {disease}"

# Generate answer using the model
def generate_answer_with_model(main_query, symptoms_context):
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token='')
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", token='').to(device)

    prompt = f"""
Below is the relevant context to answer the question.

Retrieved Context:
{symptoms_context}

Question: {main_query}

Answer: """

    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)

    outputs = model.generate(
        input_ids=input_ids["input_ids"],
        max_new_tokens=100,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        num_return_sequences=1
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    answer_parts = generated_text.split("Answer:")
    if len(answer_parts) > 1:
        return answer_parts[-1].strip()
    return generated_text.strip()

# Main function
def main():
    question_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base").to(device)
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

    knowledge_source_path = "knowledge-source.csv"
    questions_answers = load_knowledge_source(knowledge_source_path)
    context_encoder, context_tokenizer, index = embed_and_index_knowledge(questions_answers)

    query = "Charlotte is a judge and has symptoms such as Knee pain, back pain, ankle pain, finger swelling, wrist swelling and problems with movement. Sally is unemployed and has symptoms such as vomiting, feeling ill, fever, fluid retention, headache and fainting. Which person is more likely to have Juvenile rheumatoid arthritis?"

    disease = extract_disease_from_query(query)
    if not disease:
        print("No disease found in query.")
        return

    disease_query = generate_disease_symptoms_query(disease)
    retrieved_data, _ = retrieve_top_k(question_encoder, question_tokenizer, index, disease_query, questions_answers)

    if not retrieved_data:
        print(f"No information found for disease: {disease}")
        return

    symptoms_context = retrieved_data[0]['Answer']
    answer = generate_answer_with_model(query, symptoms_context)

    print(f"Disease: {disease}")
    print(f"Symptoms of {disease}: {symptoms_context}")
    print(f"Answer: {answer}")

if __name__ == "__main__":
    main()



Disease: Juvenile rheumatoid arthritis
Symptoms of Juvenile rheumatoid arthritis: Juvenile idiopathic arthritis (JIA), also known as juvenile rheumatoid arthritis (JRA), is the most common form of arthritis in children and adolescents. (Juvenile in this context refers to an onset before age 16, idiopathic refers to a condition with no defined cause, and arthritis is the inflammation of the synovium of a joint.) 
Answer: **Charlotte** is more likely to have Juvenile rheumatoid arthritis. 

**Reasoning:**

* **Charlotte's symptoms:** The symptoms Charlotte experiences (knee pain, back pain, ankle pain, finger swelling, wrist swelling, and problems with movement) are common in Juvenile idiopathic arthritis (JIA).
* **Sally's symptoms:** Sally's symptoms (vomiting, feeling ill, fever, fluid retention, headache, and fainting) are more consistent with systemic illnesses or infections
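
The `[\w\s]+` group in this cell fixes the previous cell's `(\w+)`, which would have captured only "Juvenile" from this query. It still stops at punctuation, so a name such as "Crohn's disease" comes back as "Crohn". The sketch below compares both patterns with a third, assumed pattern `([^?]+)` that captures everything up to the question mark. The JRA query is a shortened version of the one above, and the Crohn's query is illustrative. One caveat of the third pattern: without a trailing `?`, it captures to the end of the string.

```python
import re

SINGLE_WORD = r"which person is more likely to have (\w+)\??"     # previous cell
MULTI_WORD = r"which person is more likely to have ([\w\s]+)\??"  # this cell
UP_TO_QMARK = r"which person is more likely to have ([^?]+)\??"   # assumed alternative

def extract(pattern, query):
    """Same logic as extract_disease_from_query, parameterized by the pattern."""
    match = re.search(pattern, query, re.IGNORECASE)
    return match.group(1).strip() if match else None

jra_query = ("Charlotte has knee pain. Sally has a fever. "
             "Which person is more likely to have Juvenile rheumatoid arthritis?")
crohn_query = "Which person is more likely to have Crohn's disease?"

for pattern in (SINGLE_WORD, MULTI_WORD, UP_TO_QMARK):
    print(extract(pattern, jra_query), "|", extract(pattern, crohn_query))
# Juvenile | Crohn
# Juvenile rheumatoid arthritis | Crohn
# Juvenile rheumatoid arthritis | Crohn's disease
```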
