# Biocreative IX Task 1: MedHopQA

![title](IBMC_logo.png)

## Team Info

Team Name: ***Orekhovichi***

- Rustam R. Taktashov, IBMC
- Nadezhda Yu. Bizyukova, IBMC
- Olga A. Tarasova, IBMC
- Alexander V. Dmitriev, IBMC

## Description
This notebook demonstrates the tool we used for the MedHopQA task. Follow the repository README to install the required packages and set up the virtual environment.

Some notebook cells are interactive. If you prefer a non-interactive workflow, use the standalone scripts in the repository directory (**duplicates.py** and **decomp_inference.py**).


## Project Pipeline

<img src="Diagram.png" width="400" height="600">

# Setting up a vector store

## Step 1: Get the wikipedia dump

Download the Wikipedia dump in **xml.bz2** format. Make sure you have enough disk space: the compressed file alone is around 20 GB. We used the 20250420 dump, but you can use a more recent one.

**Download Link**
https://dumps.wikimedia.org/enwiki/20250420/



## Step 2: Extract the relevant articles

In our case, we used direct category matching against the prepared **medical_categories.txt** list. To create the list of categories, you can either: \
a) Traverse the [Wikipedia category tree](https://en.wikipedia.org/wiki/Special:CategoryTree) using the Wikipedia API, or \
b) Use the [PetScan](https://petscan.wmcloud.org/) tool.

The code block below implements option a).


In [None]:
import requests
from tqdm import tqdm
import time

def get_all_subcategories(
    root_category="Medicine",
    max_depth=5,
    checkpoint_interval=20,
    output_file="./categories/medical_categories.txt",
    resume=True
):
    """Fetch all subcategories of ``root_category`` breadth-first, with checkpoints.

    Each newly discovered subcategory title (without the ``Category:`` prefix)
    is appended to ``output_file`` and flushed immediately. An interrupted crawl
    therefore keeps what it had found.

    Args:
        root_category: Category to start from (without ``Category:``). The root
            itself is never written to the output.
        max_depth: Deepest level that is expanded. Subcategories of a category
            at level ``max_depth`` are still recorded, but not expanded.
        checkpoint_interval: Print a progress line every N newly written categories.
        output_file: Text file receiving one category title per line. Its
            directory is created if needed.
        resume: When True, titles already in ``output_file`` are loaded and are
            not written again. The tree is still walked again from the root,
            because the file records which categories were found, not which
            ones were expanded.

    Returns:
        Set of subcategory titles (pre-loaded and newly found).
    """
    import os
    from collections import deque

    base_url = "https://en.wikipedia.org/w/api.php"
    max_attempts = 3  # attempts per API request before giving up on the category
    session = requests.Session()
    # Wikimedia's API etiquette asks clients to send a descriptive User-Agent;
    # replace this with something that identifies you (e.g. a contact address).
    session.headers.update({"User-Agent": "MedHopQA-category-crawler/1.0 (research script)"})

    all_categories = set()  # titles already present in / written to output_file
    if resume:
        try:
            with open(output_file, "r", encoding="utf-8") as f:
                existing_cats = {line.strip() for line in f if line.strip()}
            all_categories.update(existing_cats)
            print(f"Resumed with {len(existing_cats)} pre-loaded categories.")
        except FileNotFoundError:
            pass

    os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)

    queue = deque([(root_category, 0)])
    # A category can be reached through several parents; queue each one only once.
    enqueued = {root_category}
    checkpoint_count = 0

    with tqdm(desc="Fetching subcategories") as pbar, \
         open(output_file, "a" if resume else "w", encoding="utf-8") as f_out:

        while queue:
            current_cat, depth = queue.popleft()
            if depth > max_depth:
                continue

            params = {
                "action": "query",
                "list": "categorymembers",
                "cmtitle": f"Category:{current_cat}",
                "cmtype": "subcat",
                "cmlimit": "500",
                "format": "json"
            }
            failures = 0

            while True:
                try:
                    response = session.get(base_url, params=params, timeout=30)
                    response.raise_for_status()
                    payload = response.json()
                except (requests.RequestException, ValueError) as e:
                    failures += 1
                    print(f"\nError fetching {current_cat}: {e}")
                    time.sleep(5)  # Rate limit protection
                    if failures >= max_attempts:
                        break  # give up on this category; the rest of the crawl continues
                    continue  # retry the same request (same page)

                if "query" not in payload:
                    break

                for member in payload["query"].get("categorymembers", []):
                    subcat = member["title"].removeprefix("Category:")
                    if subcat != root_category and subcat not in all_categories:
                        all_categories.add(subcat)
                        f_out.write(f"{subcat}\n")
                        f_out.flush()  # Ensure immediate write
                        checkpoint_count += 1
                        pbar.update(1)

                        # Print checkpoint every N categories
                        if checkpoint_count % checkpoint_interval == 0:
                            print(f"\nCheckpoint ({checkpoint_count}): {subcat}")

                    if subcat not in enqueued:
                        enqueued.add(subcat)
                        queue.append((subcat, depth + 1))

                # Pagination: the API returns the parameters for the next page under "continue"
                if "continue" in payload:
                    params.update(payload["continue"])
                else:
                    break

    return all_categories

# Run: crawl the category tree below the root and save it for the article-extraction step.
# The extraction cell reads its category list from HEALTH_CATEGORIES_FILE; keep both pointing at the same file.
health_subcats = get_all_subcategories(
    root_category="Medicine",  # Change the category if you want
    max_depth=4,  # Be cautious when increasing or decreasing this value
    checkpoint_interval=20,
    output_file="./categories/medical_categories.txt",
    resume=True
)
print(f"\nDone! Total categories: {len(health_subcats)}")

The code block below extracts the matching articles from the dump into a **JSONL** file.

In [None]:
import re
import json
import bz2
from urllib.parse import quote
import mwparserfromhell
import mwxml
from tqdm import tqdm
import os

# ===== CONFIGURATION =====
# Section headings (compared lower-cased) whose content is dropped from articles.
# You can add or remove the unwanted sections.
UNWANTED_SECTIONS = {
    'bibliography', 'citations', 'external links', 'footnotes', 'notes',
    'references', 'see also', 'sources', 'works cited',
}

WIKIPEDIA_BASE_URL = "https://en.wikipedia.org/wiki/"
CHECKPOINT_INTERVAL = 100_000  # Save progress every 100,000 dump pages
CHECKPOINT_FILE = "checkpoint.json"

def load_health_categories(file_path: str) -> set[str]:
    """Load category names (one per line) into a normalized lookup set.

    Each non-blank line is normalized as follows:
    - an optional leading ``Category:`` prefix is dropped;
    - underscores, as used in URL-style titles, become spaces;
    - runs of whitespace collapse to a single space;
    - the result is lower-cased for case-insensitive matching.

    Blank lines are ignored.
    """
    prefix = 'category:'
    categories = set()
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            name = line.strip()
            if name.lower().startswith(prefix):
                name = name[len(prefix):]
            name = ' '.join(name.replace('_', ' ').split()).lower()
            if name:
                categories.add(name)
    return categories

def get_wikipedia_url(title: str) -> str:
    """Build the article URL for *title*: spaces become underscores, then the slug is percent-encoded."""
    slug = quote(title.replace(" ", "_"))
    return f"{WIKIPEDIA_BASE_URL}{slug}"

_SECTION_HEADING_RE = re.compile(r'^(={1,6})\s*(.*?)\s*\1\s*$')


def _drop_unwanted_sections(wikitext: str) -> str:
    """Drop every section whose heading (lower-cased) is in UNWANTED_SECTIONS.

    Runs on raw wikitext, where headings are still ``== Title ==`` lines. A
    skipped section ends at the next heading of the same or a higher level, so
    its sub-sections are dropped with it.
    """
    kept_lines = []
    skip_level = None  # level of the unwanted heading currently being skipped, or None
    for line in wikitext.split('\n'):
        heading = _SECTION_HEADING_RE.match(line.strip())
        if heading:
            level = len(heading.group(1))
            if skip_level is None or level <= skip_level:
                title = heading.group(2).strip().lower()
                skip_level = level if title in UNWANTED_SECTIONS else None
        if skip_level is None:
            kept_lines.append(line)
    return '\n'.join(kept_lines)


def _strip_with_parser(wikitext: str) -> str:
    """Render wikitext to plain text with mwparserfromhell (may raise on malformed input)."""
    parsed = mwparserfromhell.parse(wikitext)

    nodes_to_remove = []
    for node in parsed.nodes:
        if isinstance(node, mwparserfromhell.nodes.template.Template):
            # Only remove infoboxes and similar templates, not all templates
            if node.name.strip().lower().startswith(('infobox', 'cite', 'reflist')):
                nodes_to_remove.append(node)
        elif isinstance(node, mwparserfromhell.nodes.tag.Tag) and node.tag.lower() == 'ref':
            nodes_to_remove.append(node)
        elif isinstance(node, mwparserfromhell.nodes.wikilink.Wikilink):
            # strip_code() renders a link as its text or title, which would leave image
            # options/captions and "Category:..." titles in the article text.
            if node.title.strip().lower().startswith(('file:', 'image:', 'category:')):
                nodes_to_remove.append(node)

    for node in nodes_to_remove:
        parsed.remove(node)

    text = parsed.strip_code()
    text = re.sub(r'\[\[([^|\]]*?\|)?([^\]]*?)\]\]', r'\2', text)  # Simplify leftover links
    text = re.sub(r'\{\{.*?\}\}', '', text)  # Remove leftover templates
    text = re.sub(r'<[^>]+>', '', text)  # Remove HTML tags
    return text


def _strip_with_regex(wikitext: str) -> str:
    """Best-effort regex fallback used when mwparserfromhell cannot parse the text."""
    text = re.sub(r'\{\{(Infobox|infobox)[^\}]*?\}\}', '', wikitext, flags=re.IGNORECASE | re.DOTALL)
    text = re.sub(r'\[\[(File|Image|Category):.*?\]\]', '', text, flags=re.IGNORECASE)
    # Self-closing <ref .../> first; otherwise the paired pattern below would
    # swallow the article text up to the next </ref>.
    text = re.sub(r'<ref[^>]*/>', '', text)
    text = re.sub(r'<ref[^>]*>.*?</ref>', '', text, flags=re.DOTALL)
    text = re.sub(r'<[^>]+>', '', text)
    text = re.sub(r'\[\[([^|\]]*?\|)?([^\]]*?)\]\]', r'\2', text)
    text = re.sub(r'\{\{.*?\}\}', '', text)
    text = re.sub(r'\[[^\]]+\]', '', text)
    return text


def _tidy_whitespace(text: str) -> str:
    """Collapse whitespace within lines, fix ' .' artefacts, and keep paragraphs separated by one blank line."""
    paragraphs = []
    for block in re.split(r'\n\s*\n', text):
        lines = []
        for line in block.split('\n'):
            line = re.sub(r'\s+', ' ', line).strip()
            line = re.sub(r'(?: \.)+', '.', line)  # Fix spaced dots
            if line:
                lines.append(line)
        if lines:
            paragraphs.append('\n'.join(lines))
    return '\n\n'.join(paragraphs)


def clean_wikitext(wikitext: str) -> str:
    """Convert raw article wikitext into plain text suitable for chunking.

    Steps:
    1. Drop sections listed in UNWANTED_SECTIONS while the raw ``== Heading ==``
       lines needed to find them are still present.
    2. Render the rest with mwparserfromhell, falling back to regex stripping
       if parsing raises.
    3. Keep paragraph breaks as blank lines and collapse all other whitespace.

    Returns:
        The cleaned text, or "" for empty input.
    """
    if not wikitext:
        return ""

    wikitext = _drop_unwanted_sections(wikitext)
    try:
        text = _strip_with_parser(wikitext)
    except Exception:
        # If the initial parsing fails, fall back to plain regex stripping
        text = _strip_with_regex(wikitext)
    return _tidy_whitespace(text)

def save_checkpoint(data):
    """Write ``data`` as JSON to CHECKPOINT_FILE, replacing any previous checkpoint atomically.

    The JSON first goes to a temporary file. ``os.replace`` then swaps it in
    one step, so there is never a moment without a valid checkpoint on disk.
    """
    temp_file = f"{CHECKPOINT_FILE}.tmp"
    with open(temp_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False)
    os.replace(temp_file, CHECKPOINT_FILE)

def load_checkpoint():
    """Return the checkpoint dict stored in CHECKPOINT_FILE, or None when no checkpoint exists."""
    if not os.path.exists(CHECKPOINT_FILE):
        return None
    with open(CHECKPOINT_FILE, 'r', encoding='utf-8') as f:
        return json.load(f)

_CATEGORY_LINK_RE = re.compile(r"\[\[\s*Category\s*:([^|\]]*)(?:\|[^\]]*)?\]\]", re.IGNORECASE)


def _extract_categories(wikitext: str) -> set[str]:
    """Return the categories linked from *wikitext*, lower-cased, with underscores read as spaces."""
    categories = set()
    for match in _CATEGORY_LINK_RE.finditer(wikitext):
        name = " ".join(match.group(1).replace("_", " ").split()).lower()
        if name:  # Non-empty categories
            categories.add(name)
    return categories


def process_dump(dump_path: str, output_file: str, health_categories: set[str]):
    """Stream a Wikipedia XML dump and write articles in ``health_categories`` to JSONL.

    An article is written when it meets all three conditions:
    - it is a main-namespace page;
    - at least one of its ``[[Category:...]]`` links is in ``health_categories``;
    - its cleaned text has at least 50 words.

    Each written line holds title, url, text and categories.

    Every CHECKPOINT_INTERVAL dump pages, the output is flushed and a checkpoint
    is saved. The checkpoint records:
    - the page count;
    - the number of articles found;
    - the output file's size.

    On restart, pages up to the checkpoint are skipped rather than re-processed.
    The output is truncated back to the recorded size, so articles written after
    the last checkpoint are not duplicated. Older checkpoints without a size are
    still accepted, just without the truncation.

    Returns:
        Number of articles written, including those from earlier resumed runs.
    """
    checkpoint = load_checkpoint()
    if checkpoint:
        resume_from = checkpoint['processed_count']
        health_articles_found = checkpoint['health_articles_found']
        output_offset = checkpoint.get('output_offset')
        print(f"Resuming from checkpoint - previously processed {resume_from} articles")
    else:
        resume_from = 0
        health_articles_found = 0
        output_offset = None

    output_dir = os.path.dirname(output_file) or "."
    os.makedirs(output_dir, exist_ok=True)

    if output_offset is not None and os.path.exists(output_file):
        # Discard lines written after the last checkpoint; those pages are processed again below.
        with open(output_file, 'r+b') as f:
            f.truncate(output_offset)

    health_sample = list(health_categories)[:5]
    processed_count = 0
    failed_count = 0

    with (bz2.open(dump_path) as dump_file,
          open(output_file, 'a' if checkpoint else 'w', encoding='utf-8') as out_file):

        dump = mwxml.Dump.from_file(dump_file)

        for page in tqdm(dump, desc="Processing articles"):
            processed_count += 1
            if processed_count <= resume_from:
                continue  # already handled by a previous run

            # Only namespace 0 holds articles; category/template/project pages carry category links too.
            if isinstance(page, mwxml.Page) and page.namespace == 0:
                try:
                    revision = next(iter(page), None)
                    text = (revision.text if revision is not None else None) or ""
                    categories = _extract_categories(text)

                    # Debug output
                    if processed_count < 10:
                        print(f"\nDebug - Article: {page.title}")
                        print(f"Categories found: {categories}")
                        print(f"Health categories sample: {health_sample}")

                    # Check for any overlap between article categories and health categories
                    matching = categories & health_categories
                    if matching:
                        clean_text = clean_wikitext(text)
                        if len(clean_text.split()) >= 50:
                            json.dump({
                                "title": page.title,
                                "url": get_wikipedia_url(page.title),
                                "text": clean_text,
                                "categories": list(categories),
                            }, out_file, ensure_ascii=False)
                            out_file.write("\n")
                            health_articles_found += 1

                            # Debug output for found articles
                            if health_articles_found < 5:
                                print(f"\nFound medical article: {page.title}")
                                print(f"Matching categories: {matching}")
                except Exception as e:
                    # Keep going past a bad page, but report it instead of failing silently.
                    failed_count += 1
                    if failed_count <= 10:
                        print(f"\nSkipping page {page.title!r}: {e}")

            if processed_count % CHECKPOINT_INTERVAL == 0:
                out_file.flush()  # the checkpoint must never point past unflushed output
                save_checkpoint({
                    'processed_count': processed_count,
                    'health_articles_found': health_articles_found,
                    'output_offset': os.fstat(out_file.fileno()).st_size,
                })
                print(f"\nCheckpoint saved - Processed: {processed_count}, Found: {health_articles_found}")

    if failed_count:
        print(f"\n{failed_count} pages were skipped because of errors")

    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)

    return health_articles_found

def main():
    """Load the category list, then extract the matching medical articles from the dump into JSONL."""
    # Change the [YYYY-MM-DD] for the dump you downloaded
    # (dump dates are written YYYYMMDD, e.g. 20250420).
    dump_path = "enwiki-[YYYY-MM-DD]-pages-articles-multistream.xml.bz2"
    # Category list produced by the previous step, one category per line
    categories_file = "medical_categories.txt"
    output_file = "medical_articles.jsonl"

    print(f"Loading health categories from {categories_file}...")
    health_categories = load_health_categories(categories_file)

    first_five = sorted(health_categories)[:5]
    print(f"Loaded {len(health_categories)} unique health categories (normalized to lowercase).")
    print("Sample of first 5 categories:")
    print("\n".join(first_five))

    print(f"\nProcessing Wikipedia dump to {output_file}...")
    found = process_dump(dump_path, output_file, health_categories)

    print(f"\nDone! Found {found} health articles in total.")
    print(f"All articles saved to {output_file}")


if __name__ == "__main__":
    main()

We're not done yet, since the output file may contain duplicates. 

To remove the duplicates, either run the cell below or run **duplicates.py** from the repository directory.


In [None]:
import json
from collections import defaultdict
from typing import Optional
from IPython.display import display
import ipywidgets as widgets

def find_duplicates(filename: str, check_by: str = "title") -> dict:
    """
    Check for duplicate articles in a JSONL file.

    Args:
        filename: Path to the JSONL file.
        check_by: Field whose value identifies a record. For the article files
                  built above, "title", "url" or "text" are the meaningful
                  choices. Use "full" to compare entire JSON objects.

    Returns:
        Dictionary mapping each duplicated key to the 1-based line numbers it
        appears on. List/dict field values are keyed by their canonical JSON
        string, because they are not hashable.
    """
    duplicates = defaultdict(list)
    total_articles = 0

    with open(filename, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, 1):
            if not line.strip():
                continue  # blank lines carry no record
            total_articles += 1
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                print(f"! Invalid JSON on line {line_number}")
                continue

            if check_by == "full":
                key = json.dumps(data, sort_keys=True)
            else:
                key = data.get(check_by) if isinstance(data, dict) else None
                if key is None:
                    print(f"Warning: Missing '{check_by}' field in line {line_number}")
                    continue
                if isinstance(key, (list, dict)):
                    key = json.dumps(key, sort_keys=True)

            duplicates[key].append(line_number)

    # Filter non-duplicates
    duplicates = {k: v for k, v in duplicates.items() if len(v) > 1}

    print(f"\nAnalyzed {total_articles} articles")
    print(f"Found {len(duplicates)} sets of duplicates\n")

    for i, (key, lines) in enumerate(duplicates.items(), 1):
        shown = str(key)  # keys may be ints etc., which cannot be sliced
        print(f"Duplicate set #{i}:")
        print(f"Key: {shown[:100]}{'...' if len(shown) > 100 else ''}")
        print(f"Appears on lines: {', '.join(map(str, lines))}\n")

    return duplicates

def remove_duplicates(
    input_file: str,
    output_file: Optional[str] = None,
    check_by: str = "title",
    keep: str = "first"
) -> None:
    """
    Remove duplicate entries from a JSONL file.

    Kept lines stay in their original order.

    Special cases:
    - Lines that are not valid JSON are kept as-is.
    - Records missing the ``check_by`` field are kept as-is.
    - Every output line is newline-terminated.

    Args:
        input_file: Path to the input JSONL file.
        output_file: Path to save the deduplicated file (None to modify in place).
        check_by: Field to check for duplicates (e.g. "title", "url"), or "full".
        keep: Which duplicate to keep ("first" or "last"; anything else acts as "first").
    """
    import os

    in_place = output_file is None
    if in_place:
        output_file = input_file + ".tmp"

    with open(input_file, 'r', encoding='utf-8') as infile:
        # Normalize endings so a final line without "\n" cannot fuse with another kept line.
        lines = [line if line.endswith('\n') else line + '\n' for line in infile]

    indices = range(len(lines))
    if keep == "last":
        indices = reversed(indices)  # the first occurrence seen is then the last in the file

    seen_keys = set()
    kept = [True] * len(lines)
    removed_count = 0

    for idx in indices:
        try:
            data = json.loads(lines[idx])
        except json.JSONDecodeError:
            print(f"! Invalid JSON on line {idx + 1}, keeping as-is")
            continue

        if check_by == "full":
            key = json.dumps(data, sort_keys=True)
        else:
            key = data.get(check_by) if isinstance(data, dict) else None
            if key is None:
                continue  # Keep entries with missing key field
            if isinstance(key, (list, dict)):
                key = json.dumps(key, sort_keys=True)  # make unhashable values usable as keys

        if key in seen_keys:
            kept[idx] = False
            removed_count += 1
        else:
            seen_keys.add(key)

    with open(output_file, 'w', encoding='utf-8') as outfile:
        outfile.writelines(line for line, keep_line in zip(lines, kept) if keep_line)

    # Replace the original file only when no output file was specified
    if in_place:
        os.replace(output_file, input_file)

    total_count = len(lines)
    print(f"\nProcessed {total_count} entries")
    print(f"Removed {removed_count} duplicates")
    print(f"Kept {total_count - removed_count} unique entries")

def run_in_notebook():
    """Interactive version for Jupyter Notebook.

    Shows widgets to upload a JSONL file and then find or remove duplicates in
    it. The upload is saved to the current directory under its own file name
    before ``find_duplicates`` / ``remove_duplicates`` runs on it.

    The widget delivers the whole upload as in-memory content. For the full
    article file, calling ``remove_duplicates`` directly (or duplicates.py) is
    the more practical route.
    """
    import os

    def read_upload(upload):
        """Return (filename, bytes) for the first uploaded file.

        ipywidgets 7 stores uploads as {name: {"content": ...}}. ipywidgets 8
        stores a tuple of dicts with "name" and "content" keys.
        """
        value = upload.value
        if isinstance(value, dict):
            name = next(iter(value))
            return name, bytes(value[name]['content'])
        first = value[0]
        return first['name'], bytes(first['content'])

    # Create widgets
    file_upload = widgets.FileUpload(description="Upload JSONL file", multiple=False)
    check_by = widgets.Dropdown(
        # "title", "url" and "text" are the fields of the article JSONL built above
        options=['title', 'url', 'text', 'id', 'content', 'full'],
        value='title',
        description='Check by:'
    )
    action = widgets.RadioButtons(
        options=['Find duplicates', 'Remove duplicates'],
        description='Action:'
    )
    keep = widgets.Dropdown(
        options=['first', 'last'],
        value='first',
        description='Keep:'
    )
    output_file = widgets.Text(
        value='',
        placeholder='output.jsonl (leave blank to overwrite)',
        description='Output file:'
    )
    run_button = widgets.Button(description="Run")
    output = widgets.Output()

    # Only show keep widget when removing duplicates
    def update_widgets(change):
        visibility = 'visible' if action.value == 'Remove duplicates' else 'hidden'
        keep.layout.visibility = visibility
        output_file.layout.visibility = visibility

    action.observe(update_widgets, names='value')
    update_widgets(None)

    def on_run_button_clicked(b):
        with output:
            output.clear_output()
            if not file_upload.value:
                print("Please upload a file first")
                return

            # Save uploaded file (basename only, so a crafted name cannot write elsewhere)
            name, content = read_upload(file_upload)
            filename = os.path.basename(name)
            with open(filename, 'wb') as f:
                f.write(content)

            if action.value == 'Find duplicates':
                find_duplicates(filename, check_by=check_by.value)
            else:
                out_file = output_file.value if output_file.value else None
                remove_duplicates(
                    filename,
                    output_file=out_file,
                    check_by=check_by.value,
                    keep=keep.value
                )

    run_button.on_click(on_run_button_clicked)

    # Display the interface
    display(widgets.VBox([
        file_upload,
        check_by,
        action,
        keep,
        output_file,
        run_button,
        output
    ]))


run_in_notebook()

## Step 3: Create vector stores

### Chroma Vector Store

We used LangChain's Chroma class to create the vector store. Make sure to use the chromadb version that is installed as a dependency of langchain-chroma.

You can change the embedding function as you see fit. 


In [None]:
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
import json
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from tqdm import tqdm
import torch
import os
import gc
from uuid import uuid4
from datetime import datetime
import time
import psutil
from typing import List

# ===== Configuration =====
# Chunking
CHUNK_SIZE = 512  # In characters
CHUNK_OVERLAP = 32  # In characters
MAX_ARTICLE_LENGTH = 50_000  # In characters; longer article text is truncated
MAX_CHUNKS_PER_ARTICLE = 100

# Embedding and storage
EMBEDDING_MODEL_PATH = "abhinand/MedEmbed-small-v0.1"  # Choose any embedding function you want from HuggingFace
EMBEDDING_BATCH_SIZE = 10  # In chunks
CHUNKED_INSERT_SIZE = 1000  # Number of documents to insert at once
PERSIST_DIRECTORY = "./chroma_store"

# Progress tracking
CHECKPOINT_INTERVAL = 50  # In articles: how often a progress line is logged and memory is released
CHECKPOINT_FILE = "processing_checkpoint.txt"  # File to store the last processed line
LOG_FILE = "processing_log.txt"

# ===== Memory Monitoring =====
def memory_safe():
    """Return True when enough system RAM is available to keep processing.

    At or above 1 GiB available, this returns True straight away. Otherwise it
    runs the garbage collector, empties the CUDA cache (on GPU) and waits 5 s.
    It then reports True only if at least 1.5 GiB has become available.
    """
    gib = 1024**3
    if psutil.virtual_memory().available >= 1 * gib:
        return True

    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    time.sleep(5)
    return psutil.virtual_memory().available >= 1.5 * gib

# ===== GPU Configuration =====
# "cuda" when PyTorch sees a GPU, otherwise "cpu"; used by the embedding model and memory checks.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # Allocator tuning is set first; presumably it only takes effect before the
    # first CUDA allocation - confirm against the PyTorch CUDA memory docs.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()

# ===== Logging and Checkpointing =====
def log(message):
    """Append ``message`` with a timestamp to LOG_FILE, then echo it to stdout.

    The log is written as UTF-8, so messages containing non-ASCII article
    titles do not depend on the platform's default encoding.
    """
    with open(LOG_FILE, 'a', encoding='utf-8') as f:
        f.write(f"{time.ctime()}: {message}\n")
    print(message)

def save_checkpoint(line_number):
    """Save the current line number to resume from later.

    The number goes to a temporary file, which ``os.replace`` then moves over
    CHECKPOINT_FILE in one step. This matters because an interrupted direct
    write could leave an empty or partial file. ``load_checkpoint`` reads such
    a file as 0, and the whole input would be re-ingested.
    """
    tmp_path = f"{CHECKPOINT_FILE}.tmp"
    with open(tmp_path, 'w') as f:
        f.write(str(line_number))
    os.replace(tmp_path, CHECKPOINT_FILE)

def load_checkpoint():
    """Return the line number saved in CHECKPOINT_FILE; 0 if the file is missing or unreadable as an int."""
    try:
        with open(CHECKPOINT_FILE, 'r') as handle:
            stored = handle.read()
            return int(stored.strip())
    except (ValueError, FileNotFoundError):
        return 0

# ===== ChromaDB Setup =====
def get_embeddings():
    """Build the HuggingFace embedding model (EMBEDDING_MODEL_PATH) on ``device``.

    Embeddings are normalized. The inference side must use the same model and
    normalization. You can change the encode_kwargs, specifically, the distance
    metrics. Refer to
    https://python.langchain.com/api_reference/chroma/index.html#module-langchain_chroma
    """
    model_options = {"device": device, "trust_remote_code": True}
    encode_options = {
        "batch_size": EMBEDDING_BATCH_SIZE,
        "normalize_embeddings": True,
    }
    return HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_PATH,
        model_kwargs=model_options,
        encode_kwargs=encode_options,
    )

def split_list(splits, chunk_size):
    """Lazily yield consecutive slices of ``splits``, each at most ``chunk_size`` long."""
    starts = range(0, len(splits), chunk_size)
    yield from (splits[start:start + chunk_size] for start in starts)

# ===== Text Processing =====
# Splits on paragraph, then line, sentence and word boundaries; "" allows a plain character split as a last resort.
text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", ". ", " ", ""],
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)

def process_article(article: dict) -> List[Document]:
    """Split one article record into chunk Documents ready for embedding.

    The article text is cut to MAX_ARTICLE_LENGTH characters, split with
    ``text_splitter``, and capped at MAX_CHUNKS_PER_ARTICLE chunks. Each chunk
    carries this metadata:
    - the title (max 200 chars, "Untitled" if absent);
    - the url;
    - the categories joined with "|" (max 1000 chars).

    Raises:
        MemoryError: if ``memory_safe()`` reports insufficient RAM before starting.

    Returns:
        The chunk Documents. The list is empty for blank text, or when a
        processing error occurs (the error is logged).
    """
    if not memory_safe():
        raise MemoryError("Insufficient memory before processing article")

    try:
        text = article.get("text", "")[:MAX_ARTICLE_LENGTH]
        if not text.strip():
            return []

        chunks = text_splitter.split_text(text)[:MAX_CHUNKS_PER_ARTICLE]
        if not chunks:
            return []

        # The metadata is the same for every chunk, so build it once. Each
        # Document still gets its own copy, so mutating one never affects the others.
        categories = article.get("categories", [])
        base_metadata = {
            "title": article.get("title", "Untitled")[:200],
            "url": article.get("url", ""),
            "categories": "|".join(str(cat) for cat in categories)[:1000],
        }
        return [
            Document(page_content=chunk, metadata=dict(base_metadata))
            for chunk in chunks
        ]

    except Exception as e:
        log(f"Article processing error: {str(e)}")
        return []

def _documents_from_line(line: str, line_number: int) -> List[Document]:
    """Turn one JSONL line into chunk Documents, waiting out low-memory periods.

    Invalid JSON and unexpected per-article errors are logged and produce no
    documents. A MemoryError is logged, and the same line is retried after a
    30 s pause.
    """
    while True:
        try:
            if not memory_safe():
                raise MemoryError("Insufficient system memory")
            return process_article(json.loads(line))
        except json.JSONDecodeError:
            log(f"JSON decode error at line {line_number}")
            return []
        except MemoryError as e:
            log(f"Memory error at line {line_number}: {str(e)}")
            time.sleep(30)  # Longer sleep for memory issues
        except Exception as e:
            log(f"Error at line {line_number}: {str(e)}")
            return []


def _flush_to_store(vectorstore, docs: List[Document]) -> None:
    """Embed and persist ``docs`` in slices of CHUNKED_INSERT_SIZE, freeing memory between slices."""
    for docs_chunk in split_list(docs, CHUNKED_INSERT_SIZE):
        vectorstore.add_documents(docs_chunk)
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()


def process_file(file_path: str):
    """Chunk, embed and store every article of a JSONL file in the Chroma store.

    CHECKPOINT_FILE holds the number of input lines whose chunks are already
    in the store. It advances only right after a batch has been flushed to
    Chroma. Chunks still buffered in memory when a run dies are therefore
    regenerated on restart instead of being lost.

    Args:
        file_path: JSONL file with one article object per line.

    Raises:
        Exception: any error that escapes per-line handling (e.g. file or
            vector-store errors) is logged as fatal and re-raised.
    """
    embeddings = get_embeddings()
    # One store instance for the whole run instead of a new client for every batch.
    vectorstore = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embeddings)
    total_chunks = 0
    processed_articles = 0
    accumulated_docs = []
    pbar = None  # created once the file is open; guarded in `finally`

    # Load the last checkpoint
    start_line = load_checkpoint()
    current_line = start_line
    if start_line > 0:
        log(f"Resuming from line {start_line}")

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # Skip lines that are already stored, counting their bytes for the
            # progress bar. A text file's tell() raises OSError once next() has
            # been used on it, so positions are tracked by hand.
            skipped_bytes = 0
            for _, skipped_line in zip(range(start_line), f):
                skipped_bytes += len(skipped_line.encode('utf-8'))

            pbar = tqdm(total=os.path.getsize(file_path), initial=skipped_bytes,
                        desc="Processing articles", unit='B', unit_scale=True)

            for line in f:
                current_line += 1
                documents = _documents_from_line(line, current_line)
                accumulated_docs.extend(documents)
                total_chunks += len(documents)
                processed_articles += 1
                pbar.update(len(line.encode('utf-8')))

                # Persist in batches, and only then move the checkpoint forward
                if len(accumulated_docs) >= CHUNKED_INSERT_SIZE:
                    _flush_to_store(vectorstore, accumulated_docs)
                    accumulated_docs = []
                    save_checkpoint(current_line)

                if processed_articles % CHECKPOINT_INTERVAL == 0:
                    log(f"Checkpoint: Processed {current_line} lines, {total_chunks} chunks total")
                    gc.collect()
                    if device == "cuda":
                        torch.cuda.empty_cache()

        # Process any remaining documents
        if accumulated_docs:
            _flush_to_store(vectorstore, accumulated_docs)
            accumulated_docs = []
        save_checkpoint(current_line)

    except Exception as e:
        log(f"Fatal error: {str(e)}")
        raise  # Re-raise to exit the program
    finally:
        if pbar is not None:
            pbar.close()
        log(f"Processing completed up to line {current_line} with {total_chunks} chunks")

if __name__ == "__main__":
    log(f"Starting processing on {device}")
    if device == "cuda":
        log(f"GPU: {torch.cuda.get_device_name(0)}")
        # Note: this reports the device's total memory, not what is currently free.
        total_vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
        log(f"Available VRAM: {total_vram_gb:.2f} GB")

    run_started_at = time.time()

    try:
        process_file("medical_articles.jsonl")
        # Finished cleanly: drop the checkpoint so the next run starts from the top
        if os.path.exists(CHECKPOINT_FILE):
            os.remove(CHECKPOINT_FILE)
        log("\nProcessing completed successfully!")
    except Exception as e:
        log(f"\nProcessing interrupted due to: {str(e)}")
        log("The program can be restarted to resume from the last checkpoint")

    elapsed_minutes = (time.time() - run_started_at) / 60
    log(f"Time elapsed: {elapsed_minutes:.2f} minutes")

### BM25S Store
We use a custom BM25S retriever implementation, which also persists the index to disk.

In [None]:
from custom_langchain.retrievers import BM25SRetriever 
from langchain_community.document_loaders import JSONLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pprint import pprint

def metadata_func(record: dict, metadata: dict) -> dict:
    """Copy the article's title, url and categories into the loader-provided metadata.

    JSONLoader calls this once per JSONL record. ``metadata`` is updated in
    place and returned. Fields missing from the record become None.
    """
    metadata['title'] = record.get('title')
    metadata['url'] = record.get('url')
    metadata['categories'] = record.get('categories')
    return metadata

# Load each article as a Document: "text" becomes page_content, and metadata_func adds title/url/categories.
loader = JSONLoader(
    file_path="medical_articles.jsonl",
    jq_schema=".",
    content_key="text",
    metadata_func=metadata_func,
    json_lines=True,
)
docs = loader.load()

# Drop JSONLoader's bookkeeping keys, but only when both are present.
for doc in docs:
    meta = doc.metadata
    if 'source' in meta and 'seq_num' in meta:
        del meta['source'], meta['seq_num']

# Same chunking values as the Chroma store: 512-character chunks, 32 characters of overlap.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    chunk_overlap=32,
    separators=["\n\n", "\n", ". ", " ", ""],
)
split_docs = text_splitter.split_documents(docs)

# Parallel lists of chunk texts and their metadata, as passed to BM25SRetriever.from_texts.
text_contents = []
metadata = []
for chunk in split_docs:
    text_contents.append(chunk.page_content)
    metadata.append(chunk.metadata)

retriever = BM25SRetriever.from_texts(text_contents, metadata, k=2, persist_directory='bm25s_store')

We now have two stores. In the retrieval and generation steps, we can use two separate retrievers: a dense one (Chroma) and a sparse one (BM25S).

## Step 4: Retrieve



You can either run the cell below or run **decomp_inference.py** from the repository directory.

In [None]:
from langchain_ollama import ChatOllama
from pydantic import BaseModel
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains import RetrievalQA
from langchain.chains import LLMChain
from langchain.chains import RetrievalQAWithSourcesChain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks import StdOutCallbackHandler
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from custom_langchain.retrievers import BM25SRetriever 
from langchain.retrievers import EnsembleRetriever
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List
from langchain_chroma import Chroma
from langchain_core.documents import Document
from typing import Dict, Any
import gc
import warnings
from tqdm import tqdm
import torch
import pandas as pd
import csv
import os
from time import sleep
from IPython.display import display
import ipywidgets as widgets

# ==== CONFIG ====
# Prefer the GPU when present, but fall back to the CPU instead of hard-coding
# "cuda"; the rest of this cell already guards CUDA-only calls with
# torch.cuda.is_available(), and the embedding model below is pinned to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Suppress all Python warnings to keep the notebook output readable.
warnings.filterwarnings("ignore")
# Streams LLM tokens to stdout; attached to the combine-documents chain below.
callbacks = CallbackManager([StreamingStdOutCallbackHandler()])
# NOTE(review): `handler` is not referenced in this cell.
handler = StdOutCallbackHandler()

# CSV output configuration
# NOTE(review): CHECKPOINT_INTERVAL, RETRY_DELAY and MAX_RETRIES are not referenced
# in this cell; presumably kept for parity with decomp_inference.py — confirm
# before removing.
CHECKPOINT_INTERVAL = 2
RETRY_DELAY = 5 
MAX_RETRIES = 3
# Every answered question is appended here; batch runs also read it to resume.
OUTPUT_CSV_PATH = "./qa_results.csv"

# ==== MODELS AND VECTOR STORE ====
# Dense embedding model used to embed queries for the Chroma store. It must match
# the model and normalization used when the store was built; otherwise query
# vectors and stored vectors are not comparable.
core_embeddings_model = HuggingFaceEmbeddings(
    model_name="abhinand/MedEmbed-small-v0.1", # Embedding function should be the same as the embedding function used to create vector store
    model_kwargs={
        'device': "cpu", 
        'trust_remote_code': True,
    },
    encode_kwargs={
        'normalize_embeddings': True, # Normalization (and hence the distance metric) should match the one used in vector store creation
        'batch_size': 8 # Encoding batch size
    }
)

# Persisted Chroma store (the dense store built earlier), queried with the
# embedding model above.
persist_directory = "./chroma_store"
vectordb = Chroma(persist_directory=persist_directory, embedding_function=core_embeddings_model)

# Local Ollama server hosting the Med42 Llama-3 8B model. temperature=0.0 keeps
# answers as reproducible as possible; num_ctx sets the context window size.
llm = ChatOllama(
    base_url = "http://localhost:11434",
    model="thewindmom/llama3-med42-8b", 
    timeout=300, 
    temperature = 0.0,
    disable_streaming = True,
    num_ctx=8192
)

# ==== PROMPT TEMPLATES ====
# Per-(sub-)question answer prompt used by the RetrievalQA chain.
# Variables: {context} (the stuffed retrieved documents) and {question}.
prompt_template = """Context information is below.
---------------------
{context}
---------------------
You are an expert in medicine, molecular biology and biochemistry. Answer the question below based strictly on the context above and using common sense. 
If the answer cannot be found in the context, say "I couldn't find a definitive answer in my sources."
For complex questions, break them down into logical sub-questions.
Query: {question}
Answer: """

QA_CHAIN_PROMPT = PromptTemplate.from_template(prompt_template)

# Asks the LLM for a numbered list of sub-questions (parsed by decompose_question).
# Variable: {question}.
DECOMPOSITION_PROMPT = ChatPromptTemplate.from_template("""
Break down this medical question into simple, factual sub-questions that can be answered independently from medical literature and be used to answer the main question.
Each sub-question should:
1. Be answerable with a specific fact or short answer
2. Build logically toward answering the main question
3. Use clear medical terminology. Don't turn "diagnostic procedure" into "test"
4. Duplicate the initial question in the numbered list of sub-questions
5. Don't make the questions redundant
6. Always make no more than 4 subquestions
7. NEVER lose context (e.g. never make "Which chromosome does this disorder primarily affect?" for "Which gene is associated with a rare genetic disorder characterized by bilateral congenital hearing loss and brain malformations?", since you would lose both "bilateral congenital hearing loss" and "brain malformations")
8. NEVER make the sub-questions more complex than the initial question

Output ONLY the sub-questions as a numbered list, nothing else.

Question: {question}
Sub-questions:
""")

# Merges the sub-question answers into the final answer. The "SHORT ANSWER:" line
# is parsed by extract_short_answer and "SOURCE REFERENCES:" is checked for in
# process_question, so keep those headings unchanged.
# Variables: {intermediate_answers}, {source_references}, {main_question}.
COMPOSITION_PROMPT = ChatPromptTemplate.from_template("""
Combine these answers to sub-questions into a coherent final answer.
Be precise and cite sources when available using the provided source references.

Sub-question Answers:
{intermediate_answers}

Source References:
{source_references}

Original question: {main_question}

ALWAYS extract a concise, SINGLE and SHORT final answer (1-3 words long) following these rules:
1. Use complete formal names (e.g., "Diabetes Mellitus, type 2")
2. For chromosome questions: format as "Chromosome X"
3. For syndromes named after people: use full name (e.g., "Carpenter's syndrome")
4. For yes/no: answer only "Yes" or "No"
5. For true/false: answer only "TRUE" or "FALSE"
6. Pick a single answer without writing synonyms (e.g., either write "PID" or "Pelvic Inflammatory Disease", NOT "PID (Pelvic Inflammatory Disease)")
7. Don't overthink or overengineer the answers (e.g., write "Cardiology", NOT "Adult Congenital Cardiology")
8. Make sure the answers are not recursive (e.g. for question "What is the primary cause of the physical changes observed in males with Klinefelter syndrome during puberty?" the answer should be not "Klinefelter Syndrome", but the cause of it, like "X Chromosome")
9. For drug questions: don't mention drug form (e.g., "Calamine" but NOT "Calamine lotion")
10. Answers like "Indirectly" or "Probably" are forbidden.
11. If you don't know the answer, just write "N/A" and nothing else
12. For protein questions: just give a protein name (e.g., "Myelin", but not "Myelin protein")

Structure your final answer as follows:

SHORT ANSWER: [your 1-3 word answer here]

DETAILED ANSWER:
[your detailed explanation here, citing sources like [1], [2] where appropriate]

SOURCE REFERENCES:
{source_references}
""")

# ==== HELPER FUNCTIONS ====
def format_source_references(sources: list[dict]) -> str:
    """Build a numbered reference list of the Wikipedia sources for citation.

    Only entries whose ``source_url`` contains ``wikipedia.org`` are listed.
    Entries repeating an already-listed URL are skipped, because several
    retrieved chunks usually come from the same article. Numbering runs
    consecutively over the listed entries, so citations ``[1]``, ``[2]``, ...
    have no gaps.

    Args:
        sources: Dicts with ``source_title`` and ``source_url`` keys, as
            produced by ``format_for_json``.

    Returns:
        One ``"[n] title - url"`` line per reference joined by newlines, or
        ``""`` when no Wikipedia source is present.
    """
    seen_urls = set()
    source_refs = []
    for source in sources:
        url = source.get('source_url') or ''
        if 'wikipedia.org' not in url or url in seen_urls:
            continue
        seen_urls.add(url)
        title = source.get('source_title', 'N/A')
        source_refs.append(f"[{len(source_refs) + 1}] {title} - {url}")
    return "\n".join(source_refs)

def improve_retrieval(query: str, is_subquestion: bool = False, main_question: str = "") -> str:
    """Ask the LLM to rewrite *query* into a form better suited to document retrieval.

    Args:
        query: The question or sub-question to rewrite.
        is_subquestion: Whether *query* is a decomposed sub-question rather than
            the user's original question.
        main_question: The original question. It only selects the prompt: the
            sub-question prompt is used when *is_subquestion* is true and this
            is non-empty. Its text is not sent to the LLM.

    Returns:
        The rewritten query with surrounding whitespace removed, or *query*
        unchanged when the LLM returns nothing usable.
    """
    if is_subquestion and main_question:
        optimization_prompt = ChatPromptTemplate.from_template("""
You are a medical question optimizer. Optimize this specific medical sub-question for document retrieval but never answer them (e.g. "Which medical specialty is likely involved in the diagnosis and treatment of angiolipomas?" is not "Dermatology involvement in angiolipoma diagnosis and management" but "Angiolipoma medical specialty")
Focus only on the specific aspect asked about. Don't incorporate the main question.
NEVER lose context (e.g. "Which medical subspecialty primarily focuses on the diagnosis and management of skin lesions?" is not "Dermatology" but "Dermatology skin lesions medical specialty")

Don't explain your reasoning.

Sub-question: {query}
Optimized:""")
    else:
        optimization_prompt = ChatPromptTemplate.from_template("""
Optimize this medical question for better document retrieval. 
Don't explain your reasoning.

Original: {query}
Optimized:""")

    optimizer = optimization_prompt | llm | StrOutputParser()
    optimized = optimizer.invoke({"query": query}).strip()
    # An empty rewrite would make the retrievers search for nothing; fall back
    # to the original wording instead.
    return optimized or query

def format_document(doc: Document, max_chars: int = 500) -> str:
    """Render one retrieved document as a snippet for the QA prompt's {context}.

    The page content is cut to *max_chars* characters for every document
    (presumably to keep several documents within the LLM's num_ctx window —
    confirm). When the document has non-empty metadata, a ``Source:`` line with
    its ``title`` (empty if missing) is appended.

    Args:
        doc: The retrieved document.
        max_chars: Maximum number of page-content characters to keep.

    Returns:
        ``"Content: <text>\\n"``, optionally followed by ``"Source: <title>\\n"``.
    """
    formatted = f"Content: {doc.page_content[:max_chars]}\n"
    metadata = getattr(doc, 'metadata', None)
    if metadata:
        formatted += f"Source: {metadata.get('title', '')}\n"
    return formatted

def extract_short_answer(long_answer: str) -> str:
    """Extract the 1-3 word short answer from the composed final answer.

    If the text contains ``SHORT ANSWER:``, the first non-empty text after the
    marker is returned. This works whether the answer is on the same line or on
    the next one, and markdown bold such as ``**SHORT ANSWER:** X`` is
    tolerated. If the marker is followed directly by another section heading,
    ``"N/A"`` is returned. Without the marker, the first non-empty line of at
    most three words that does not start with ``[`` is used.

    Returns:
        The short answer, or ``"N/A"`` if none can be found.
    """
    marker = "SHORT ANSWER:"
    if marker in long_answer:
        tail = long_answer.split(marker, 1)[1]
        for line in tail.split("\n"):
            # Drop whitespace and markdown emphasis left around the answer.
            candidate = line.strip().strip("*").strip()
            if not candidate:
                continue
            if candidate.upper().startswith(("DETAILED ANSWER", "SOURCE REFERENCES")):
                break
            return candidate
        return "N/A"
    # Fallback to first line that meets criteria
    for line in long_answer.split('\n'):
        line = line.strip()
        if line and not line.startswith('[') and len(line.split()) <= 3:
            return line
    return "N/A"

def write_to_csv(qidx: str, question: str, short_answer: str, long_answer: str, file_path: str):
    """Append one result row to the CSV at *file_path*.

    The header (QIDX, Question, Short Answer, Long Answer) is written first when
    the file does not exist yet or is empty. This way a pre-created zero-byte
    file still gets column names, which the batch-resume logic needs in order to
    read the ``QIDX`` column back.
    """
    fieldnames = ['QIDX', 'Question', 'Short Answer', 'Long Answer']
    needs_header = not os.path.isfile(file_path) or os.path.getsize(file_path) == 0

    with open(file_path, mode='a', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerow({
            'QIDX': qidx,
            'Question': question,
            'Short Answer': short_answer,
            'Long Answer': long_answer
        })

# ==== CHAIN SETUP ====
class CustomStuffDocumentsChain(StuffDocumentsChain):
    """StuffDocumentsChain that renders each document with ``format_document``.

    The rendered documents are joined with blank lines and passed under
    ``document_variable_name``. Any extra keyword inputs are forwarded as-is and
    take precedence on a key clash.
    """

    def _get_inputs(self, docs, **kwargs):
        stuffed = "\n\n".join(map(format_document, docs))
        inputs = {self.document_variable_name: stuffed}
        inputs.update(kwargs)
        return inputs

# QA prompt chain: fills QA_CHAIN_PROMPT with the stuffed context and the question.
llm_chain = LLMChain(llm=llm, prompt=QA_CHAIN_PROMPT, callbacks=None, verbose=False)

# Dense retriever over the Chroma store (top 4 documents).
vectorstore_retriever = vectordb.as_retriever(
    search_type="similarity", # You can use different types of search, refer to https://python.langchain.com/api_reference/chroma/index.html#module-langchain_chroma
    search_kwargs={
        "k": 4 # You can use different numbers of top-k retrieved contexts
    }
)

# Joins the retrieved documents (rendered by format_document) into the prompt's
# {context} variable. The streaming stdout callback manager from CONFIG is
# attached here.
combine_documents_chain = CustomStuffDocumentsChain(
    llm_chain=llm_chain,
    document_variable_name="context",
    callbacks=callbacks,
)

# Sparse (BM25) retriever reloaded from the 'bm25s_store' index persisted in the
# previous step, also returning the top 4 documents.
keyword_retriever = BM25SRetriever.from_persisted_directory("bm25s_store", k=4)

# Hybrid retrieval: results of the dense and sparse retrievers are combined,
# weighted 0.6/0.4 in favour of the dense one.
ensemble_retriever = EnsembleRetriever(
    retrievers=[vectorstore_retriever, keyword_retriever],
    weights=[0.6, 0.4] # You can use different weights for each retriever
)

# End-to-end retrieval + answer chain. return_source_documents=True exposes the
# retrieved documents under 'source_documents', which process_question cites.
qa = RetrievalQA(
    combine_documents_chain=combine_documents_chain,
    retriever=ensemble_retriever,
    verbose=False,
    return_source_documents=True
)

# ==== MAIN PROCESSING FUNCTIONS ====
def decompose_question(question: str) -> List[str]:
    """Split *question* into retrieval-friendly sub-questions using the LLM.

    DECOMPOSITION_PROMPT asks for a numbered list. Only lines that start with a
    number followed by ``.`` or ``)`` are taken as sub-questions; any other
    output is ignored.

    Returns:
        A list whose first element is always the original *question* (callers
        treat index 0 as the main question). It is followed by the distinct
        generated sub-questions in the order the LLM produced them.
    """
    import re

    numbered_item = re.compile(r"^\d+[.)]\s+(.+)$")

    decomposition_chain = DECOMPOSITION_PROMPT | llm | StrOutputParser()
    result = decomposition_chain.invoke({"question": question})

    sub_questions = [question]
    for line in result.split('\n'):
        match = numbered_item.match(line.strip())
        if not match:
            continue
        candidate = match.group(1).strip()
        # The prompt asks the LLM to repeat the original question; it is already
        # pinned at index 0, so skip it (and any other repeats) wherever it appears.
        if candidate and candidate not in sub_questions:
            sub_questions.append(candidate)

    return sub_questions

def format_for_json(docs: list[Document]) -> list[dict]:
    """Convert retrieved documents into plain, JSON-friendly source dicts.

    Each dict carries ``content`` (the full page content), plus ``source_title``
    and ``source_url`` taken from the ``title``/``url`` metadata keys
    (``'N/A'`` when absent).
    """
    def to_record(document):
        meta = getattr(document, 'metadata', {})
        return {
            "content": document.page_content,
            "source_title": meta.get('title', 'N/A'),
            "source_url": meta.get('url', 'N/A'),
        }

    return [to_record(document) for document in docs]

def process_question(question: str, qidx: str = "0"):
    """Answer *question* by decomposition, per-sub-question retrieval QA and composition.

    Steps:
      1. ``decompose_question`` splits the question; entry 0 is the main question.
      2. Each entry is rewritten by ``improve_retrieval`` and answered by the
         ``qa`` RetrievalQA chain, up to three at a time in a thread pool.
      3. ``COMPOSITION_PROMPT`` merges the intermediate answers into the final
         answer, from which the short answer is extracted.
    One row is appended to OUTPUT_CSV_PATH via ``write_to_csv`` on success and
    on failure alike.

    Args:
        question: The user's question.
        qidx: Identifier stored in the QIDX column of the output CSV.

    Returns:
        A dict with keys ``Short_Answer``, ``Long_Answer``, ``Sources`` (at most
        10 source dicts) and ``Intermediate_steps`` (keyed ``q0``, ``q1``, ...).
        On any exception, ``Short_Answer`` is ``"ERROR: <message>"`` and the
        other values are empty.
    """
    try:
        sub_questions = decompose_question(question)
        print("\nDecomposed into sub-questions:")
        for i, q in enumerate(sub_questions, 1):
            print(f"{i}. {q}")

        def process_subquestion(i, q):
            # Index 0 is the main question; every later entry is a sub-question.
            optimized_query = improve_retrieval(q, i > 0, sub_questions[0])
            print(f"Optimized query: {optimized_query}")

            # Cap the query length before it reaches the retrievers.
            qa_result = qa.invoke({"query": optimized_query[:512]})
            return {
                "question": q,
                "optimized_query": optimized_query,
                "answer": qa_result['result'],
                # Keep only the two top-ranked documents per (sub-)question.
                "sources": format_for_json(qa_result['source_documents'][:2]),
            }

        # executor.map yields results in submission order, so q0, q1, ... stay
        # aligned with sub_questions even though they run concurrently.
        with ThreadPoolExecutor(max_workers=3) as executor:
            results = list(executor.map(process_subquestion, range(len(sub_questions)), sub_questions))

        intermediate_answers = {f"q{i}": step for i, step in enumerate(results)}
        all_sources = [source for step in results for source in step["sources"]]

        # Only the first ten sources are offered for citation and returned.
        source_references = format_source_references(all_sources[:10])

        composition_chain = COMPOSITION_PROMPT | llm | StrOutputParser()
        final_answer = composition_chain.invoke({
            "intermediate_answers": "\n".join(
                f"Q{n}: {step['question']}\nA: {step['answer']}"
                for n, step in enumerate(results, 1)
            ),
            "source_references": source_references,
            "main_question": question
        })

        # Ensure sources are included even if LLM didn't add them
        if "SOURCE REFERENCES:" not in final_answer:
            final_answer += f"\n\nSOURCE REFERENCES:\n{source_references}"

        short_answer = extract_short_answer(final_answer)
        write_to_csv(qidx, question, short_answer, final_answer, OUTPUT_CSV_PATH)

        return {
            "Short_Answer": short_answer,
            "Long_Answer": final_answer,
            "Sources": all_sources[:10],
            "Intermediate_steps": intermediate_answers
        }

    except Exception as e:
        print(f"Error during QA processing: {str(e)}")
        write_to_csv(qidx, question, f"ERROR: {str(e)}", "", OUTPUT_CSV_PATH)
        return {
            "Short_Answer": f"ERROR: {str(e)}",
            "Long_Answer": "",
            "Sources": [],
            "Intermediate_steps": {}
        }

# ==== INTERACTIVE NOTEBOOK INTERFACE ====
def _uploaded_files(file_upload) -> list:
    """Return the files held by a FileUpload widget as ``(name, bytes)`` pairs.

    ipywidgets 7 stores ``value`` as a dict ``{name: {'content': bytes, ...}}``.
    ipywidgets 8 stores a tuple of dicts with ``'name'`` and ``'content'``
    (a memoryview) keys. Both layouts are accepted.
    """
    value = file_upload.value
    if not value:
        return []
    if isinstance(value, dict):
        return [(name, bytes(meta['content'])) for name, meta in value.items()]
    return [(item['name'], bytes(item['content'])) for item in value]


def _print_single_result(result: dict) -> None:
    """Print the final answer, intermediate steps and top-5 sources of one question."""
    print("\n" + "="*80)
    print("FINAL ANSWER:")
    print("-"*80)
    print(result["Long_Answer"])

    print("\n" + "="*80)
    print("INTERMEDIATE STEPS:")
    print("-"*80)
    for step in result.get("Intermediate_steps", {}).values():
        print(f"\nQ: {step['question']}")
        print(f"A: {step['answer']}")

    print("\n" + "="*80)
    print("SOURCE ATTRIBUTION:")
    print("-"*80)
    for i, source in enumerate(result.get("Sources", [])[:5], 1):
        print(f"\n[{i}] {source.get('source_title', 'N/A')}")
        print(f"URL: {source.get('source_url', 'N/A')}")
        print(f"Content: {source.get('content', '')[:200]}...")


def _run_batch(input_path: str) -> None:
    """Answer every question in *input_path* that is not yet in OUTPUT_CSV_PATH.

    The input CSV must have a ``Question`` column. An optional ``QIDX`` column
    supplies question ids; otherwise the row index is used. Ids already present
    in OUTPUT_CSV_PATH are skipped, so an interrupted run can be resumed.
    ``process_question`` appends each result row itself; this function only
    writes an ERROR row if ``process_question`` raises.
    """
    try:
        df = pd.read_csv(input_path)
    except Exception as e:
        print(f"Error reading input CSV: {str(e)}")
        return
    if len(df) == 0:
        print("Input CSV is empty")
        return
    if 'Question' not in df.columns:
        print("Input CSV must contain a 'Question' column")
        return

    # Initialize processed questions tracking
    processed_qidx = set()
    if os.path.exists(OUTPUT_CSV_PATH):
        try:
            existing_df = pd.read_csv(OUTPUT_CSV_PATH, usecols=['QIDX'])
            processed_qidx = set(existing_df['QIDX'].astype(str).unique())
            print(f"Resuming processing with {len(processed_qidx)} already completed questions")
        except Exception as e:
            print(f"Warning: Could not read existing output file - {str(e)}")

    has_qidx = 'QIDX' in df.columns
    unprocessed = []
    for row_index, row in df.iterrows():
        qidx = str(row['QIDX']) if has_qidx else str(row_index)
        if qidx not in processed_qidx:
            unprocessed.append((qidx, row['Question']))

    total_to_process = len(unprocessed)
    if total_to_process == 0:
        print("No new questions to process")
        return

    print(f"Starting sequential processing of {total_to_process} questions")

    # Questions run one at a time; batches only decide how often memory is freed.
    batch_size = 4
    with tqdm(total=total_to_process, desc="Processing questions") as pbar:
        for batch_start in range(0, total_to_process, batch_size):
            for qidx, question in unprocessed[batch_start:batch_start + batch_size]:
                try:
                    # process_question appends its own row (answer or error) to
                    # OUTPUT_CSV_PATH; writing it again here would duplicate it.
                    process_question(question, qidx)
                except Exception as e:
                    print(f"\nError processing question {qidx}: {str(e)}")
                    write_to_csv(qidx, question, f"ERROR: {str(e)}", "", OUTPUT_CSV_PATH)
                pbar.update(1)

            # Memory management
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    print(f"\nProcessing complete. Results saved to {OUTPUT_CSV_PATH}")


def notebook_interface():
    """Display an ipywidgets UI to answer one question or a CSV of questions.

    "Single Question" mode answers the typed question and prints the answer,
    the intermediate steps and the sources. "Batch from CSV" mode answers every
    not-yet-answered question of an uploaded CSV. In both modes, results are
    appended to the path in the "Output CSV" field. That path is stored in the
    module-level OUTPUT_CSV_PATH, which process_question writes to.
    """
    # Create widgets
    mode = widgets.RadioButtons(
        options=['Single Question', 'Batch from CSV'],
        description='Mode:',
        disabled=False
    )

    question_input = widgets.Textarea(
        value='',
        placeholder='Enter your medical question here...',
        description='Question:',
        disabled=False,
        layout={'width': '80%', 'height': '100px'}
    )

    file_upload = widgets.FileUpload(
        description='Upload CSV',
        multiple=False,
        accept='.csv',
        disabled=False
    )

    output_csv = widgets.Text(
        value=OUTPUT_CSV_PATH,
        placeholder='output.csv',
        description='Output CSV:',
        disabled=False
    )

    run_button = widgets.Button(description="Run")
    output_area = widgets.Output()

    # Show/hide widgets based on mode
    def update_widgets(change):
        if change['new'] == 'Single Question':
            question_input.layout.display = 'flex'
            file_upload.layout.display = 'none'
        else:
            question_input.layout.display = 'none'
            file_upload.layout.display = 'flex'

    mode.observe(update_widgets, names='value')
    update_widgets({'new': mode.value})

    def on_run_button_clicked(b):
        global OUTPUT_CSV_PATH
        with output_area:
            output_area.clear_output()
            # Honour the "Output CSV" field in both modes.
            OUTPUT_CSV_PATH = output_csv.value.strip() or OUTPUT_CSV_PATH

            if mode.value == 'Single Question':
                question = question_input.value.strip()
                if not question:
                    print("Please enter a question")
                    return
                print("Processing question...")
                _print_single_result(process_question(question, "1"))
                return

            # Batch mode
            uploads = _uploaded_files(file_upload)
            if not uploads:
                print("Please upload a CSV file")
                return
            name, content = uploads[0]
            # Save under the bare file name so an upload cannot write outside the CWD.
            input_path = os.path.basename(name)
            with open(input_path, 'wb') as f:
                f.write(content)
            _run_batch(input_path)

    run_button.on_click(on_run_button_clicked)

    # Display the interface
    display(widgets.VBox([
        mode,
        question_input,
        file_upload,
        widgets.HBox([output_csv, run_button]),
        output_area
    ]))


# Build and display the interactive widget UI when this cell runs.
notebook_interface()