In [15]:
import fitz  # PyMuPDF
from sentence_transformers import SentenceTransformer, util
import numpy as np
from transformers import MarianMTModel, MarianTokenizer

### Load PDF and Convert to Text

In [16]:
def load_pdf_text(file_path):
    """
    Read a PDF with PyMuPDF and return the text of all of its pages.

    Pages are visited in order, from the first to the last, and their text is
    concatenated with no separator inserted between pages.

    :param file_path: path to the PDF file
    :return: the text of every page joined into a single string
    """
    with fitz.open(file_path) as pdf:
        page_texts = [pdf[index].get_text() for index in range(pdf.page_count)]
    return "".join(page_texts)

# Source PDFs. A file's position in this list is the index of its text in `documents`.
pdf_files = [
    "healthy_diet.pdf",
    "hiv_testing.pdf",
    "maternal_peripartum_infection.pdf",
]
# Extract the text of every PDF, keeping the same order as `pdf_files`.
documents = list(map(load_pdf_text, pdf_files))

### Document Embedding

In [17]:
# Load SBERT Model and Embed Documents
# Each PDF's full text is passed to the encoder as one input, so every entry of
# `documents` gets its own embedding at the same position.
# NOTE(review): sentence-transformers models truncate inputs that exceed their
# maximum sequence length, so presumably only the opening of each PDF contributes
# to its embedding - TODO confirm, and consider chunking the documents if so.
model = SentenceTransformer('all-MiniLM-L6-v2')
# convert_to_tensor=True requests a tensor result; the retrieval step feeds it to the cosine-similarity helper.
document_embeddings = model.encode(documents, convert_to_tensor=True)

### Translation Model (Optional)
German to English (de-en) and English to German (en-de)

In [18]:
# Load the MarianMT translation pairs used for German queries.
# Each pair is (model, tokenizer), both loaded from the same checkpoint; the model
# is loaded first, then its tokenizer.

# German -> English: used to translate an incoming German query.
translator_ge_en, translator_ge_en_tokenizer = (
    marian_cls.from_pretrained('Helsinki-NLP/opus-mt-de-en')
    for marian_cls in (MarianMTModel, MarianTokenizer)
)

# English -> German: used to translate retrieved snippets back for display.
translator_en_ge, translator_en_ge_tokenizer = (
    marian_cls.from_pretrained('Helsinki-NLP/opus-mt-en-de')
    for marian_cls in (MarianMTModel, MarianTokenizer)
)

# Translation functions
def translate_text(text, model, tokenizer):
    """
    Translate ``text`` with a translation model and its matching tokenizer.

    :param text: the text to translate
    :param model: translation model; must provide ``generate(**inputs)``
    :param tokenizer: tokenizer matching ``model``; it is called to encode
        ``text`` and its ``decode`` method turns the generated ids back into text
    :return: the translation as a string. Empty or whitespace-only ``text`` is
        returned unchanged without calling the model.
    """
    # Nothing to translate: skip the model rather than letting it generate from an empty prompt.
    if isinstance(text, str) and not text.strip():
        return text

    # truncation=True shortens an over-long input to the tokenizer's maximum length
    # instead of feeding the model more positions than it supports.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    outputs = model.generate(**inputs)
    # Only the first generated sequence is decoded and returned.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)




### Retrieve Document

In [19]:
# Retrieval Function Using Cosine Similarity
def retrieve_documents(query, document_embeddings, documents, model, top_n=3, lang="en"):
    """
    Rank ``documents`` against ``query`` by cosine similarity and return the best matches.

    :param query: the query text
    :param document_embeddings: embeddings of ``documents``; the i-th embedding
        must belong to ``documents[i]``
    :param documents: list of document texts
    :param model: encoder providing ``encode(text, convert_to_tensor=True)``
    :param top_n: maximum number of results; ``None`` returns every document,
        zero or a negative value returns an empty list
    :param lang: ``"de"`` translates the query to English before encoding and
        translates each returned snippet to German; any other value skips translation
    :return: list of ``(index, snippet, score)`` tuples ordered from most to least
        similar, where ``index`` is a plain ``int`` position in ``documents``,
        ``snippet`` is the first 200 characters of the document (translated when
        ``lang == "de"``) and ``score`` is the cosine similarity as a float
    """
    # A negative top_n would otherwise act as a "drop the last k" slice and return
    # the wrong number of results; treat any non-positive request as "nothing".
    if top_n is not None and top_n <= 0:
        return []

    is_german = lang == "de"

    # If the query is in German, translate it to English
    if is_german:
        query = translate_text(query, translator_ge_en, translator_ge_en_tokenizer)

    # Encode the query and score it against every document embedding
    query_embedding = model.encode(query, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(query_embedding, document_embeddings).flatten()
    # .tolist() yields built-in ints, so the returned indices are ordinary Python
    # integers rather than NumPy scalars.
    ranked_indices = cosine_scores.argsort(descending=True)[:top_n].tolist()

    # Retrieve top documents
    results = []
    for idx in ranked_indices:
        text_snippet = documents[idx][:200]  # Limit to first 200 characters for display
        relevance_score = cosine_scores[idx].item()
        if is_german:
            text_snippet = translate_text(text_snippet, translator_en_ge, translator_en_ge_tokenizer)
        results.append((idx, text_snippet, relevance_score))

    return results

# Demo: run one German query through retrieval with lang="de", so the query is
# translated before encoding and the snippets are translated for display.
query = "Was ist eine gesunde Ernährung?"
retrieved_docs = retrieve_documents(
    query,
    document_embeddings,
    documents,
    model,
    top_n=3,
    lang="de",
)

# Show each hit as (index, snippet, score), best match first.
print("\nRetrieved Documents for Query:", query)
for hit in retrieved_docs:
    idx, doc, score = hit
    print(f"Document Index: {idx}")
    print(f"Document Snippet: {doc}\nRelevance Score: {score:.4f}\n")


Retrieved Documents for Query: Was ist eine gesunde Ernährung?
Document Index: 0
Document Snippet: Fiskalpolitik zur Förderung gesunder Ernährung WHO-Leitlinie Fiskalpolitik zur Förderung gesunder Ernährung WHO-Leitlinie Fiskalpolitik zur Förderung gesunder Ernährung: WHO-Leitlinie ISBN 978-92-4-009101-6 (electr)
Relevance Score: 0.4006

Document Index: 1
Document Snippet: Konsolidierte Leitlinien für differenzierte HIV-Testdienste Konsolidierte Leitlinien für differenzierte HIV-Testdienste Konsolidierte Leitlinien für differenzierte HIV-Testdienste ISBN 97
Relevance Score: 0.0057

Document Index: 2
Document Snippet: WHO-Empfehlungen zur Prävention und Behandlung von mütterlichen Peripartuminfektionen WHO-Empfehlungen zur Prävention und Behandlung von mütterlichen Peripartuminfektionen WHO-Bibliothek Cataloguing-in-Pu
Relevance Score: 0.0036



### Evaluation

##### Precision is calculated as the proportion of relevant documents among the retrieved documents.
##### Recall is calculated as the proportion of relevant documents that were successfully retrieved.

In [20]:
# Define Evaluation Metrics (Precision, Recall)
def evaluate_retrieval(queries, relevant_docs, document_embeddings, documents, model, top_n=1, lang="en"):
    """
    Compute precision and recall of ``retrieve_documents``, averaged over ``queries``.

    For each query, precision is the share of retrieved documents that are
    relevant and recall is the share of relevant documents that were retrieved.

    :param queries: sequence of query strings
    :param relevant_docs: ``relevant_docs[i]`` is the collection of document
        indices considered relevant for ``queries[i]``
    :param document_embeddings: forwarded to ``retrieve_documents``
    :param documents: forwarded to ``retrieve_documents``
    :param model: forwarded to ``retrieve_documents``
    :param top_n: number of documents retrieved per query
    :param lang: forwarded to ``retrieve_documents``
    :return: ``(average_precision, average_recall)`` as floats; ``(0.0, 0.0)``
        when ``queries`` is empty
    :raises ValueError: if ``queries`` and ``relevant_docs`` differ in length
    """
    queries = list(queries)
    # Fail before any retrieval work if the labels do not line up with the queries;
    # otherwise extra labels would be ignored silently or indexing would fail mid-run.
    if len(queries) != len(relevant_docs):
        raise ValueError(
            f"queries and relevant_docs must have the same length "
            f"(got {len(queries)} and {len(relevant_docs)})"
        )
    # Averaging an empty list would produce NaN and a RuntimeWarning.
    if not queries:
        return 0.0, 0.0

    precision_scores = []
    recall_scores = []

    for i, query in enumerate(queries):
        # Retrieve documents based on query
        retrieved_docs = retrieve_documents(query, document_embeddings, documents, model, top_n=top_n, lang=lang)
        retrieved_indices = [doc[0] for doc in retrieved_docs]
        # De-duplicate the labels so a repeated index cannot deflate recall.
        relevant_set = set(relevant_docs[i])

        # Calculate Precision and Recall
        true_positives = len(set(retrieved_indices) & relevant_set)
        precision = true_positives / len(retrieved_indices) if retrieved_indices else 0
        recall = true_positives / len(relevant_set) if relevant_set else 0

        precision_scores.append(precision)
        recall_scores.append(recall)

    # Average precision and recall over all queries
    avg_precision = float(np.mean(precision_scores))
    avg_recall = float(np.mean(recall_scores))

    return avg_precision, avg_recall

# Example usage for evaluation
# Evaluation queries, grouped by the document each one is expected to retrieve.
queries = [
    "What are the recommended fiscal policies to promote healthy diets?", 
    "How do subsidies influence the consumption of healthy foods?",
    "What are the key considerations for implementing fiscal policies to promote healthy diets?",
    "What are the research gaps identified in fiscal policies for healthy diets?",
    "What are the WHO's recommendations for HIV self-testing (HIVST)?",
    "How does network-based testing help in HIV diagnosis?",
    "What are the quality assurance measures for HIV testing services?",
    "What are the strategic considerations for retesting individuals for HIV?",
    "Among women in the second or third trimester of pregnancy (P), does routine antibiotic prophylaxis (I), compared with no antibiotic prophylaxis (C), prevent infectious morbidities and improve outcomes (O)?",
    "What are the key considerations for maternal peripartum infection prevention?",
]

# relevant_docs[i] lists the indices (positions in `documents`) relevant to queries[i].
relevant_docs = [
    [0],  # Relevant to "healthy_diet.pdf"
    [0],
    [0],
    [0],
    [1],  # Relevant to "hiv_testing.pdf"
    [1],
    [1],
    [1],
    [2],  # Relevant to "maternal_peripartum_infection.pdf"
    [2],
]

# Every query needs exactly one list of relevant document indices.
assert len(queries) == len(relevant_docs), "queries and relevant_docs are out of sync"

# Run Evaluation for the English queries above.
# The queries are written in English, so they are evaluated with lang="en". With
# lang="de" they would first be passed through the German-to-English translator
# (which may distort text that is already English) and every retrieved snippet would
# be translated to German, adding cost without affecting the metrics' inputs.
# To evaluate the German path, supply German queries and pass lang="de".
avg_precision, avg_recall = evaluate_retrieval(queries, relevant_docs, document_embeddings, documents, model, top_n=1, lang="en")
print(f"Average Precision: {avg_precision:.4f}")
print(f"Average Recall: {avg_recall:.4f}")



Average Precision: 1.0000
Average Recall: 1.0000


### Improvements

1. Fine-tune the model on a labeled dataset while freezing some of the transformer layers to retain the pre-trained knowledge.
2. For scalable retrieval, use a vector index or database (e.g., FAISS).
3. Convert the pipeline into a RAG system with an LLM to handle follow-up queries.