# Rx Copilot

A prototype exploring semantic search that maps unstructured drug indication text from DailyMed to ICD-10 codes and descriptions.

## Baseline: Zero Shot

_Semantic search with embeddings_: using the latest [ICD-10-CM codes](https://www.cdc.gov/nchs/icd/icd-10-cm/files.html) published by the CDC (April 2025), build an embedding space over the code descriptions.

Key Assumptions:

- [BiomedBERT](https://huggingface.co/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext) is a sufficient pre-trained model for generating embeddings in semantic space for ICD-10 codes and drug label indications

- [DailyMed SPL Full Release - Part 1](https://dailymed-data.nlm.nih.gov/public-release-files/dm_spl_release_human_rx_part1.zip) is sufficient for the diagnosis and drug prediction plan


`#biobert` `#dailymed` `#faiss` `#embeddings` `#cosine-similarity`


> Objectives:
>
> - Observe vectorization applied to natural language inputs
> - Use vector similarity to perform semantic search


### Context and goals


Approach: Pre-trained Embeddings + FAISS

**Pros:**
- Zero-shot Learning: No need for training, leveraging existing knowledge in pre-trained models.
- High Accuracy for Complex Phrases: Especially with biomedical models.
- Scalable: Suitable for massive datasets.
- Real-Time Search: FAISS provides quick nearest-neighbor lookups.
- Handling Synonyms: Embeddings naturally capture semantic similarity.

**Cons:**
- Black Box: Hard to interpret why a specific match occurred.
- Embedding Quality Dependency: Highly reliant on the quality of pre-trained embeddings.
- Domain-Specific Overfitting: Using domain-specific models may fail on generalized queries.
- Cold Start Problem: Poor performance if the query significantly differs from training data.

**Goal**: To illustrate semantic search (i.e., searching by _meaning_), we select a reputable pre-trained model trained on biomedical text and use it to embed the ICD-10 code descriptions as a reference set of vectors, which natural language queries are then searched against.

For quality control (QC), we plan to tune the embeddings by:

1. Using DailyMed drug indications as a supervised learning dataset.

2. Using human-in-the-loop feedback as the loss function.


### Imports

In [241]:
%pip install numpy pandas faiss-cpu torch transformers tqdm

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Note: you may need to restart the kernel to use updated packages.


### Pre-processing pipeline

Use a semantically rich biomedical pipeline to transform input text into a dense vector embedding that supports similarity search queries:

- Load a domain-relevant transformer model and its tokenizer from the Hugging Face hub

- Tokenize the input and generate embeddings with a BioBERT-style model (preferably trained on PubMed), mean-pooling the token vectors into one vector per text

In [242]:
import torch
from transformers import AutoTokenizer, AutoModel

# Load a pre-trained biomedical BERT model (PubMedBERT) from the Hugging Face model hub
pubmed_model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
print(f"Loading model: {pubmed_model_name}")

tokenizer = AutoTokenizer.from_pretrained(pubmed_model_name)
model = AutoModel.from_pretrained(pubmed_model_name)
print(f"Model {pubmed_model_name} loaded successfully!")

# Set the device to run on: GPU, MPS, or CPU
# The model and its input tensors must live on the same device; each backend (CUDA GPU, Apple MPS, CPU) has its own memory
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")
print(f"Model: {model}")

model = model.to(device)

def generate_embeddings_batch(texts):
    """Generates embeddings for a batch of texts."""
    # cap input at BERT's 512-token limit; drug indications vary widely in length, so longer ones are truncated
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = torch.mean(outputs.last_hidden_state, dim=1).cpu().numpy()
    return embeddings

Loading model: microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract
Model microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract loaded successfully!
Using mps device
Model: BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_f
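`generate_embeddings_batch` averages `last_hidden_state` over every position with `torch.mean(..., dim=1)`. With `padding=True`, shorter texts in a batch are padded to the longest one, and those padding positions still produce hidden states, so a text's vector can shift depending on which other texts share its batch. A common alternative is to weight the mean by the attention mask. The sketch below shows the arithmetic in NumPy with made-up numbers (the `masked_mean_pool` name is hypothetical); in the PyTorch code the same idea would multiply `outputs.last_hidden_state` by `inputs["attention_mask"].unsqueeze(-1)` before summing over `dim=1`.

```python
import numpy as np


def masked_mean_pool(last_hidden_state: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors per text, ignoring padding positions.

    last_hidden_state: (batch, seq_len, hidden) token embeddings
    attention_mask:    (batch, seq_len) with 1 for real tokens, 0 for padding
    """
    mask = attention_mask[..., np.newaxis].astype(last_hidden_state.dtype)  # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(axis=1)                         # sum over real tokens only
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                          # number of real tokens
    return summed / counts


# Toy example (made-up numbers): one text with 2 real tokens, padded to length 3.
hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
print(masked_mean_pool(hidden, mask))  # [[2. 3.]]
print(hidden.mean(axis=1))             # naive mean is pulled toward the padding row
```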

### 0. Generate Reference Embeddings for ICD-10-CM Codes

ICD-10-CM codes are published as plain text; the April 2025 file contains 74,260 codes, one embedding each: [Code-descriptions-April-2025/icd10cm-codes-April-2025.txt](https://ftp.cdc.gov/pub/Health_Statistics/NCHS/Publications/ICD10CM/2025-Update/Code-desciptions-April-2025.zip)

- Download and unzip the contents into the `assets/downloads` directory.
- Check that `assets/downloads/Code-desciptions-April-2025/icd10cm-codes-April-2025.txt` exists before continuing.
- We use FAISS as the vector store. With hardware acceleration, the full set of embeddings can be generated in under a minute (about 51 s on MPS in the run below), and it only needs to be regenerated once per code-set release.

In [243]:
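def load_icd10_codes(icd10_path):
    """Loads ICD-10-CM codes from the given file into a {code: description} dict."""
    icd10_codes = {}
    with open(icd10_path, 'r') as f:
        for line in f:
            # Each line is a code, whitespace padding, then the description.
            # split(maxsplit=1) handles any amount of padding, including a single space.
            parts = line.strip().split(maxsplit=1)
            if len(parts) != 2:
                continue  # skip blank or malformed lines
            code, description = parts
            # print(code, '=>', description)
            icd10_codes[code] = description
    return icd10_codes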

# load_icd10_codes('assets/downloads/Code-desciptions-April-2025/icd10cm-codes-April-2025.txt')
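A quick way to sanity-check how one line of the codes file splits into `(code, description)`. This sketch assumes (verify against your download) that the file is fixed-width, with the code padded by spaces before the description, so a 7-character code may be followed by a single space. `split(maxsplit=1)` splits on the first whitespace run of any length, whereas `split('  ')` needs at least two consecutive spaces. The first sample mirrors the `A000` entry shown in the output below with assumed padding; the second is a made-up line, and both helper names are hypothetical.

```python
def split_with_double_space(line: str):
    """Two-space split: code is the first piece, description the last."""
    parts = line.strip().split('  ')
    return parts[0].strip(), parts[-1].strip()


def split_on_first_whitespace(line: str):
    """Code is the first token; everything after the first whitespace run is the description."""
    code, description = line.strip().split(maxsplit=1)
    return code, description


samples = [
    "A000    Cholera due to Vibrio cholerae 01, biovar cholerae",  # assumed padding
    "X000000 Hypothetical seven-character code description",       # made-up, single space
]
for line in samples:
    print(split_with_double_space(line), split_on_first_whitespace(line))
```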

Generate ICD-10 embeddings

In [244]:
import torch
import pandas as pd
from tqdm import tqdm
import faiss


# Load ICD-10 code, description pairs and generate embeddings with BioBERT
def generate_icd10_embeddings(icd10_df, batch_size=32):
    """Generates BioBERT embeddings for ICD-10 descriptions using batching."""
    icd_embeddings = []

    # Batch processing
    descriptions = icd10_df['description'].tolist()
    codes = icd10_df['code'].tolist()

    for i in tqdm(range(0, len(descriptions), batch_size), desc="Generating Embeddings"):
        batch_texts = descriptions[i:i + batch_size]
        batch_codes = codes[i:i + batch_size]
        batch_embeddings = generate_embeddings_batch(batch_texts)

        for code, description, embedding in zip(batch_codes, batch_texts, batch_embeddings):
            icd_embeddings.append((code, description, embedding))

    return icd_embeddings


# Load ICD-10 codes and descriptions
icd10_codes = load_icd10_codes('assets/downloads/Code-desciptions-April-2025/icd10cm-codes-April-2025.txt')
icd10_codes_df = pd.DataFrame(icd10_codes.items(), columns=['code', 'description'])
print(f"Loaded {len(icd10_codes_df)} ICD-10 codes.")

icd10_embeddings = generate_icd10_embeddings(icd10_codes_df)

print(f"Generated {len(icd10_embeddings)} ICD-10 embeddings.")
print(icd10_embeddings[0])

Loaded 74260 ICD-10 codes.


Generating Embeddings: 100%|██████████| 2321/2321 [00:51<00:00, 45.50it/s]

Generated 74260 ICD-10 embeddings.
('A000', 'Cholera due to Vibrio cholerae 01, biovar cholerae', array([-3.03309172e-01,  9.49256569e-02,  3.90974045e-01, -9.96407941e-02,
       -2.15424329e-01, -2.71187127e-01,  2.79311121e-01, -3.02834451e-01,
        6.73074603e-01, -6.53716251e-02,  4.96145785e-01,  2.19713479e-01,
       -5.35560489e-01, -4.42970365e-01,  4.40391213e-01, -7.81174451e-02,
        1.55149341e-01,  1.45223618e-01, -7.12256730e-01, -2.28009503e-02,
       -1.87848303e-02,  2.71798432e-01,  1.05305515e-01,  5.21313965e-01,
       -2.19086632e-01,  4.31417555e-01,  2.25695446e-01, -3.05367261e-01,
       -7.38111138e-01,  2.96821386e-01, -2.32227057e-01,  4.22214270e-02,
       -3.50495242e-02,  4.24171448e-01,  4.58415300e-01,  2.72942752e-01,
       -8.63884166e-02, -8.85700509e-02,  4.02216196e-01, -3.10418576e-01,
        1.69298276e-01, -9.63555455e-01,  5.66147327e-01, -6.76564753e-01,
       -4.13298517e-01,  1.42448023e-01, -4.20432627e-01,  2.46382326e-01,
  




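Create a FAISS vector index and store the embeddings for subsequent similarity searches from query inputs.
  - The choice of similarity metric depends on how the embedding model was trained: the search metric should match the one the model used to pull related texts together. PubMedBERT was pre-trained as a masked language model rather than with a similarity objective, so here the metric is a modeling choice rather than something the model dictates.

  - This notebook uses an _inner product (IP)_ index (`faiss.IndexFlatIP`). Inner product equals cosine similarity only when vectors are L2-normalized; since the embeddings are not normalized before indexing, the scores below are raw inner products, which also reward vectors with larger norms.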
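A brute-force NumPy sketch of cosine-similarity search, i.e. what `IndexFlatIP` computes once both the indexed vectors and the query are L2-normalized. The 2-D vectors are made up to show how raw inner products favor a long vector; `l2_normalize` and `cosine_top_k` are hypothetical helpers. With FAISS itself, calling `faiss.normalize_L2` (which normalizes a float32 matrix in place) on the embedding matrix before `index.add` and on each query before `index.search` gives the same effect.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit length so that inner product equals cosine similarity."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.clip(norms, 1e-12, None)


def cosine_top_k(query: np.ndarray, corpus: np.ndarray, k: int):
    """Exhaustive top-k search by cosine similarity; returns (scores, row indices)."""
    q = l2_normalize(query.reshape(1, -1))
    c = l2_normalize(corpus)
    scores = (c @ q.T).ravel()     # one cosine score per corpus row
    top = np.argsort(-scores)[:k]  # indices of the k highest scores
    return scores[top], top


corpus = np.array([[1.0, 0.0], [10.0, 10.0], [0.0, 1.0]])
query = np.array([1.0, 0.1])
print(corpus @ query)                       # raw inner products: row 1 wins on length alone
print(cosine_top_k(query, corpus, k=1)[1])  # cosine picks row 0
```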


In [265]:
import numpy as np


def softmax(x):
    e_x = np.exp(x - np.max(x)) #subtract max for numerical stability.
    return e_x / e_x.sum()


def get_probability_distribution(query_embedding, index, top_k):
    """
    Retrieves the top_k nearest neighbors in semantic space and converts their similarity scores to a probability distribution (softmax over the top_k hits only).
    """
    similarities, indices = index.search(query_embedding.reshape(1, -1).astype('float32'), top_k)
    probabilities = softmax(similarities[0])
    return probabilities, indices[0] # remove batch dimension.


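def build_faiss_index(embeddings, faiss_index_path):
    """Builds and saves a FAISS index with ICD-10 or Drug Indication embeddings.

    Vectors are added in list order, so FAISS id i refers to embeddings[i]; this must line
    up with the DataFrame rows used to look results up (e.g. icd10_df.iloc[idx]).
    """
    # Stack all vectors into one (n, dimension) float32 matrix and add them in a single call
    vectors = np.vstack([embedding for _, _, embedding in embeddings]).astype('float32')
    dimension = vectors.shape[1]  # 768 for BERT-base models such as PubMedBERT
    index = faiss.IndexFlatIP(dimension)
    index.add(vectors)

    # Save the index
    faiss.write_index(index, faiss_index_path)
    print(f"FAISS index saved to: {faiss_index_path}")

    return index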


def query_faiss_index(query_text, faiss_index_path, icd10_df, top_k=5):
    """Queries the FAISS index to find the best matching ICD-10 codes."""
    # Load the FAISS index
    index = faiss.read_index(faiss_index_path)

    # Generate the embedding for the input query
    query_embedding = generate_embeddings_batch([query_text]).astype('float32')

    # Search the FAISS index
    probabilities, indices = get_probability_distribution(query_embedding, index, top_k)

    # Get the top-k matching ICD-10 codes
    results = []
    for probability, idx in zip(probabilities, indices):
        code = icd10_df.iloc[idx]['code']
        description = icd10_df.iloc[idx]['description']
        confidence = probability
        results.append((code, description, confidence))

    # Print the results
    print(f"Query: {query_text}")
    print(f"Top {top_k} Matching ICD-10 Codes:")
    for code, desc, conf in results:
        print(f"  - {code}: {desc} (Confidence: {conf:.2f})")

    return results


faiss_index_path = 'icd10_faiss.index'
icd10_faiss_index = build_faiss_index(icd10_embeddings, faiss_index_path)
print()

# example query
medical_condition_query_text = "Chronic renal failure with edema"
query_faiss_index(medical_condition_query_text, faiss_index_path, icd10_codes_df, top_k=10)

FAISS index saved to: icd10_faiss.index

Query: Chronic renal failure with edema
Top 10 Matching ICD-10 Codes:
  - E8352: Hypercalcemia (Confidence: 0.91)
  - F205: Residual schizophrenia (Confidence: 0.05)
  - K9083: Intestinal failure (Confidence: 0.03)
  - M6284: Sarcopenia (Confidence: 0.02)
  - I7091: Generalized atherosclerosis (Confidence: 0.00)
  - O925: Suppressed lactation (Confidence: 0.00)
  - Z931: Gastrostomy status (Confidence: 0.00)
  - K561: Intussusception (Confidence: 0.00)
  - E8351: Hypocalcemia (Confidence: 0.00)
  - J9381: Chronic pneumothorax (Confidence: 0.00)


[('E8352', 'Hypercalcemia', np.float32(0.90616024)),
 ('F205', 'Residual schizophrenia', np.float32(0.04979781)),
 ('K9083', 'Intestinal failure', np.float32(0.026428055)),
 ('M6284', 'Sarcopenia', np.float32(0.015734745)),
 ('I7091', 'Generalized atherosclerosis', np.float32(0.0011176385)),
 ('O925', 'Suppressed lactation', np.float32(0.000429313)),
 ('Z931', 'Gastrostomy status', np.float32(0.000121864396)),
 ('K561', 'Intussusception', np.float32(8.502315e-05)),
 ('E8351', 'Hypocalcemia', np.float32(6.912333e-05)),
 ('J9381', 'Chronic pneumothorax', np.float32(5.6162582e-05))]
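`get_probability_distribution` applies a softmax to the top_k raw inner-product scores only, so the "confidence" values are relative to the other retrieved candidates rather than calibrated probabilities: a score gap of 3 between two hits becomes a probability ratio of e^3 ≈ 20, regardless of whether either hit is correct. One common way to control how peaked the distribution is, sketched below with a hypothetical `temperature` parameter and made-up scores, is to divide the scores by a temperature before the softmax.

```python
import numpy as np


def softmax_with_temperature(scores, temperature: float = 1.0) -> np.ndarray:
    """Softmax over the top-k scores; a higher temperature gives a flatter distribution."""
    z = np.asarray(scores, dtype=np.float64) / temperature
    e = np.exp(z - z.max())  # subtract max for numerical stability
    return e / e.sum()


scores = np.array([12.0, 9.0, 8.5])  # made-up top-3 inner products
print(softmax_with_temperature(scores, 1.0).round(3))
print(softmax_with_temperature(scores, 5.0).round(3))
```

A temperature of 1.0 reproduces the notebook's plain softmax; any other value would need to be tuned, for example against human-reviewed matches.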

### 1. Preprocess Drug Labels (optional)

Load, extract, and format drug indication text from Indications and Usage sections.

__SKIP__: you can skip downloading and parsing DailyMed (several GB) and use the pre-generated `dm_spl_release_human_rx_part1_indications.csv` directly, placed at `assets/dm_spl_release_human_rx_part1_indications.csv`.

In [269]:
import zipfile
import os
import xml.etree.ElementTree as ET
import re
import pandas as pd


def unzip(zip_path, extract_to):
    """Extracts nested zip files recursively."""
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
        extracted_files = zip_ref.namelist()
        for file_name in extracted_files:
            file_path = os.path.join(extract_to, file_name)
            if file_name.endswith('.zip'):
                nested_dir = os.path.join(extract_to, file_name.replace('.zip', ''))
                os.makedirs(nested_dir, exist_ok=True)
                unzip(file_path, nested_dir)


class DrugInfo:
    def __init__(self, xml_file_path, should_print=False):
        if should_print:
            print(f"Processing {xml_file_path}")
        self.xml_file_path = xml_file_path
        self.root = None
        self.drug_name = ''
        self.indications = []
        self.synonyms = []
        self.extract_drug_info(xml_file_path, should_print)

    def __repr__(self):
        return f"DrugInfo(drug_name={self.drug_name}, synonyms={self.synonyms}, indications={self.indications})"

    def extract_drug_info(self, xml_file, should_print=False):
        """
        Parse a DailyMed SPL XML file and extract the drug name and indications.

        Args:
            xml_file (str): The path to the XML file.
            should_print (bool): If True, print the extracted values.

        Returns:
            Tuple[str, str]: A tuple (drug_name, indications) where:
                - drug_name is the first non-empty candidate from several possible locations,
                  or "Unknown Drug" if none is found.
                - indications is the text from the first <section> whose <title> contains
                  'indication', or "Not Found" if no such section exists.
        """
        tree = ET.parse(xml_file)
        root = tree.getroot()

        # Extract the default namespace (if any)
        m = re.match(r'\{(.*)\}', root.tag)
        namespace = m.group(1) if m else ''
        ns = {'ns': namespace} if namespace else {}

        # --- Extract the drug name ---
        # Try several candidate paths for a drug name.
        drug_name = None
        # 1. Check if a top-level <title> exists (child of the root)
        title_elem = root.find('./ns:title', ns)
        if title_elem is not None and title_elem.text:
            drug_name = title_elem.text.strip()

        # 2. Fallback: search for any <title> element in the document
        if not drug_name:
            title_elem = root.find('.//ns:title', ns)
            if title_elem is not None and title_elem.text:
                drug_name = title_elem.text.strip()

        # 3. Fallback: search for manufactured product names in several possible paths
        if not drug_name:
            for path in [
                './/ns:manufacturedProduct/ns:manufacturedMaterial/ns:name',
                './/ns:manufacturedMaterial/ns:name',
                './/ns:consumable/ns:manufacturedProduct/ns:manufacturedMaterial/ns:name'
            ]:
                name_elem = root.find(path, ns)
                if name_elem is not None and name_elem.text:
                    drug_name = name_elem.text.strip()
                    break

        if not drug_name:
            drug_name = "Unknown Drug"

        # --- Extract the indications section ---
        indications = None
        # Iterate over all <section> elements looking for one with a title containing "indication"
        for section in root.findall('.//ns:section', ns):
            sec_title = section.find('./ns:title', ns)
            if sec_title is not None and sec_title.text and 'indication' in sec_title.text.lower():
                # Look first for a nested <text> element to extract detailed text
                text_elem = section.find('.//ns:text', ns)
                if text_elem is not None:
                    indications = " ".join(text_elem.itertext()).strip()
                else:
                    # Fallback: use all the section's text content
                    indications = " ".join(section.itertext()).strip()
                break

        if not indications:
            indications = "Not Found"

        if should_print:
            print(f"Drug Name: {drug_name}")
            print(f"Indications: {indications}")

        self.drug_name = drug_name
        self.indications = indications
        return (drug_name, indications)


def save_drug_data_to_csv(xml_files, count, output_file="drug_data.csv"):
    """Saves the drug data to a CSV file."""
    drug_data = []

    for xml_file in xml_files[:count]:
        drug_info = DrugInfo(xml_file)
        drug_data.append({"Drug Name": drug_info.drug_name, "Indications": drug_info.indications})

    df = pd.DataFrame(drug_data)
    df.to_csv(output_file, index=False)
    print(f"Drug data saved to {output_file}")
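DailyMed SPL files declare a default XML namespace (`urn:hl7-org:v3`), which is why `extract_drug_info` maps it to an `ns:` prefix before calling `find`/`findall`. The sketch below exercises the same indications lookup on a tiny, hand-written SPL-like snippet; the snippet and the `extract_indications` name are hypothetical, and a real label is far larger.

```python
import re
import xml.etree.ElementTree as ET

# Hypothetical, heavily simplified SPL-like document (not a real label)
SPL_SNIPPET = """<document xmlns="urn:hl7-org:v3">
  <title>EXAMPLEDRUG tablets</title>
  <component><structuredBody><component>
    <section>
      <title>INDICATIONS AND USAGE</title>
      <text><paragraph>EXAMPLEDRUG is indicated for <content>hypertension</content>.</paragraph></text>
    </section>
  </component></structuredBody></component>
</document>"""


def extract_indications(xml_text: str) -> str:
    """Return the text of the first <section> whose <title> mentions 'indication'."""
    root = ET.fromstring(xml_text)
    m = re.match(r'\{(.*)\}', root.tag)  # default namespace, e.g. urn:hl7-org:v3
    ns = {'ns': m.group(1)} if m else {}
    p = 'ns:' if m else ''               # use the prefix only when a namespace exists
    for section in root.findall(f'.//{p}section', ns):
        title = section.find(f'./{p}title', ns)
        if title is not None and title.text and 'indication' in title.text.lower():
            text = section.find(f'.//{p}text', ns)
            target = text if text is not None else section
            # itertext() walks nested markup such as <paragraph> and <content>;
            # the outer split/join collapses runs of whitespace
            return " ".join(" ".join(target.itertext()).split())
    return "Not Found"


print(extract_indications(SPL_SNIPPET))
```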


In [None]:
# Note: skip this cell and the next one if you are using the pre-generated indications CSV.
target_rx = "assets/downloads/dm_spl_release_human_rx_part1"
dest_rx = 'assets/dm_spl_release_human_rx_part1_indications'

# 1. extract the test corpus (~20 s)
unzip(f'{target_rx}.zip', target_rx)

In [270]:
# 2. load all xml files from the test corpus and extract indications
xml_files = []
for root, dirs, files in os.walk(target_rx):
    for file in files:
        if file.endswith('.xml'):
            xml_files.append(os.path.join(root, file))

print(f"Found {len(xml_files)} XML files (drugs) in the test corpus.")
xml_files.sort()

drug_sample_size = 15079  # number of drugs to extract indications for
n_drugs = min(drug_sample_size, len(xml_files))

# 3. Save the extracted drug data to a CSV file
save_drug_data_to_csv(xml_files, count=n_drugs, output_file=f'{dest_rx}.csv')

print(f"Extracted {n_drugs} drugs w/ indications from the test corpus.")

Found 15079 XML files (drugs) in the test corpus.
Drug data saved to assets/dm_spl_release_human_rx_part1_indications.csv
Extracted 15079 drugs w/ indications from the test corpus.


### 2. Vectorize Drug Indications

Generate embeddings for the drug label indications from DailyMed. In semantic search applications the input can range from short snippets to entire documents; for now, we restrict it to the drug indication text.

- [x] start with a naive implementation that truncates input to 512 tokens (lossy)
- [ ] introduce a synthesis autoencoder step to further compress lengthy (> 512 tokens) indications
- [ ] leverage an unsupervised clustering algorithm to group semantic synonyms into representative drug "clusters"
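Before building a compression step, it may help to measure how lossy truncation actually is. A minimal sketch that reports the fraction of texts longer than the 512-token limit; `truncation_report` and its `count_tokens` parameter are hypothetical, and the example substitutes a whitespace word count for a real tokenizer (WordPiece counts are typically higher, since words often split into several subword tokens). With the notebook's Hugging Face `tokenizer`, `lambda t: len(tokenizer(t)["input_ids"])` is one option for `count_tokens`.

```python
from typing import Callable, Iterable


def truncation_report(texts: Iterable[str], count_tokens: Callable[[str], int], max_length: int = 512) -> float:
    """Fraction of texts whose token count exceeds max_length (and so would be truncated)."""
    counts = [count_tokens(t) for t in texts]
    if not counts:
        return 0.0
    return sum(c > max_length for c in counts) / len(counts)


# Stand-in tokenizer for illustration only: whitespace word count
texts = ["short indication", "word " * 600]
print(truncation_report(texts, lambda t: len(t.split()), max_length=512))  # 0.5
```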

Use the same pre-trained model as for the ICD-10 codes so that both live in the same "meaning space", enabling fuzzy similarity search from ailment (ICD-10) to drug (label).

__WARNING__: generating embeddings for the full set of drugs (DailyMed Part 1) takes about 4 minutes on a typical circa-2024 laptop (4:02 on MPS in the run below), so plan to make a pot of coffee.


In [271]:
# Load Drug Indications
def load_drug_indications(filepath):
    """Load drug indications from a CSV file."""
    df = pd.read_csv(filepath)
    return df


# Generate drug indication embeddings
def generate_drug_indications_embeddings(drug_indications_df, batch_size=32):
    """Generates BioBERT embeddings for drug indications using batching."""
    drug_embeddings = []

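    # Batch processing.
    # NOTE: .astype(str) turns missing values into the literal string 'nan', so in practice the
    # filter below drops no rows; this keeps embedding i aligned with drug_indications_df row i,
    # which the FAISS lookup (drug_df.iloc[idx]) relies on.
    indications = drug_indications_df['Indications'].astype(str).tolist()
    drug_names = drug_indications_df['Drug Name'].astype(str).tolist()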

    for i in tqdm(range(0, len(indications), batch_size), desc="Generating Drug Embeddings"):
        batch_texts = indications[i:i + batch_size]
        batch_drugs = drug_names[i:i + batch_size]

        # Filter out any non-string or empty texts
        filtered_batch_texts = [text for text in batch_texts if isinstance(text, str) and text.strip()]
        filtered_batch_drugs = [drug for drug, text in zip(batch_drugs, batch_texts) if isinstance(text, str) and text.strip()]

        # Only proceed if the filtered batch is not empty
        if len(filtered_batch_texts) == 0:
            continue

        # Generate embeddings for the filtered batch
        batch_embeddings = generate_embeddings_batch(filtered_batch_texts)

        for drug_name, indication, embedding in zip(filtered_batch_drugs, filtered_batch_texts, batch_embeddings):
            drug_embeddings.append((drug_name, indication, embedding))

    return drug_embeddings


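# Load the drug indications file (CSV format); the path is spelled out so this cell
# also works when the optional DailyMed extraction cells were skipped
drug_indications_df = load_drug_indications('assets/dm_spl_release_human_rx_part1_indications.csv')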
print(f"Number of non-string indications: {sum(~drug_indications_df['Indications'].apply(lambda x: isinstance(x, str)))}")

print(f"Loaded {len(drug_indications_df)} drug indications.")
print(drug_indications_df.head())
print()

# Generate embeddings for drug indications
drug_embeddings = generate_drug_indications_embeddings(drug_indications_df)

print()
print(f"Generated {len(drug_embeddings)} drug indication embeddings.")
print(drug_embeddings[0])

Number of non-string indications: 1
Loaded 15079 drug indications.
                              Drug Name  \
0                                RENESE   
1                       Mykrox® Tablets   
2                              Tolinase   
3                               HYPAQUE   
4  MEPERIDINE HYDROCHLORIDE Tablets USP   

                                         Indications  
0  Renese is indicated as adjunctive therapy in e...  
1  MYKROX Tablets are indicated for the treatment...  
2  TOLINASE Tablets are indicated as an adjunct t...  
3  HYPAQUE-76 is indicated for excretory urograph...  
4  Meperidine is indicated for the relief of mode...  



Generating Drug Embeddings: 100%|██████████| 472/472 [04:02<00:00,  1.95it/s]


Generated 15079 drug indication embeddings.
('RENESE', 'Renese is indicated as adjunctive therapy in edema associated with congestive heart failure, hepatic cirrhosis, and corticosteroid and estrogen therapy. \n\t\t\t\t\t\t Renese has also been found useful in edema due to various forms of renal dysfunction as: Nephrotic syndrome; Acute glomerulonephritis; and Chronic renal failure. \n\t\t\t\t\t\t Renese is indicated in the management of hypertension either as the sole therapeutic agent or to enhance the effectiveness of other antihypertensive drugs in the more severe forms of hypertension.', array([-1.04940534e-01,  4.42041829e-02, -6.47793338e-02, -9.85202044e-02,
       -1.11142956e-02,  1.23856559e-01, -1.46111101e-01,  9.84481275e-02,
       -5.99351013e-03, -7.81668797e-02,  1.18984148e-01,  1.13775134e-01,
       -3.10066342e-02,  1.83795139e-01,  3.42779934e-01, -1.67320237e-01,
        5.68514541e-02, -3.38297039e-02,  2.94664279e-02,  4.86081466e-03,
        1.99748695e-01, 




Save the generated embeddings to a FAISS index for subsequent semantic search.


In [272]:
faiss_index_path = 'drugs_faiss.index'
drugs_faiss_index = build_faiss_index(drug_embeddings, faiss_index_path)
print(f"Number of vectors in the index: {drugs_faiss_index.ntotal}")

FAISS index saved to: drugs_faiss.index
Number of vectors in the index: 15079


### 3. Implement Rx Copilot Structured Queries

Rx Copilot probabilistic process flow:

1. The probability that a specific medical condition is correct: **P(condition)**

2. The probability of choosing the correct drug given that condition: **P(drug | condition)**

In a Bayesian sense, we are really interested in the joint probability that both are true simultaneously, **P(condition, drug) = P(condition) × P(drug | condition)**, because the correct drug decision is contingent on the condition being present. In this zero-shot baseline, the drug factor is approximated by the softmax over drug-index matches to the query itself (not to each candidate condition), so the two factors are computed independently.
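Because each softmax output sums to 1 over its own top-k list, multiplying every ICD-10 probability by every drug probability yields a distribution over all top-k × top-k (condition, drug) pairs that also sums to 1; summing duplicate pairs, as `query_rx_copilot` does for drugs that share a name, preserves that total. A NumPy sketch with made-up probabilities:

```python
import numpy as np

p_condition = np.array([0.7, 0.2, 0.1])  # made-up softmax over top-3 ICD-10 hits
p_drug = np.array([0.6, 0.4])            # made-up softmax over top-2 drug hits

# joint[i, j] = P(condition_i) * P(drug_j): the per-pair product, for all pairs at once
joint = np.outer(p_condition, p_drug)
print(joint.round(2))
print(round(float(joint.sum()), 6))  # 1.0, because each factor sums to 1
```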


In [273]:
# P(condition and correct drug) = P(condition) x P(drug | condition)
def p_icd10_drug(p_icd10, p_drug):
    """Joint confidence of an (ICD-10 code, drug) pair."""
    return p_icd10 * p_drug

Use the two FAISS indices built above to construct an end-to-end NLP clinical flow for HCP (healthcare professional) team members:

- Run a nearest-neighbor search of the query against both the ICD-10 and drug indication indices.
- For each (drug, ICD-10 code) pair, report the drug name, indication, ICD-10 code, and combined confidence score.

In [274]:
from collections import defaultdict


def query_rx_copilot(query_text, icd10_faiss_path, drug_faiss_path, icd10_df, drug_df, top_k=5):
    """Handles a natural language query and returns matched drugs and ICD-10 codes."""

    # Step 1: Generate query embedding
    [query_embedding] = generate_embeddings_batch([query_text]).astype('float32')

    # Step 2: Load FAISS indexes
    icd10_index = faiss.read_index(icd10_faiss_path)
    drug_index = faiss.read_index(drug_faiss_path)

    # Step 3: Search ICD-10 Index and get probability distribution
    icd10_probabilities, icd10_indices = get_probability_distribution(query_embedding, icd10_index, top_k)
    icd10_results = [
        (icd10_df.iloc[idx]['code'], icd10_df.iloc[idx]['description'], prob)
        for prob, idx in zip(icd10_probabilities, icd10_indices)
    ]

    # Step 4: Search Drug Indication Index and get probability distribution
    drug_probabilities, drug_indices = get_probability_distribution(query_embedding, drug_index, top_k)
    drug_results = [
        (drug_df.iloc[idx]['Drug Name'], drug_df.iloc[idx]['Indications'], prob)
        for prob, idx in zip(drug_probabilities, drug_indices)
    ]

    # Step 5: Combine and format results (joint prob confidence)
    response = defaultdict(lambda: [None, None, 0.0])  # Accumulate probabilities for duplicates

    for drug_name, indication, drug_conf in drug_results:
        for icd_code, icd_desc, icd_conf in icd10_results:
            combined_conf = p_icd10_drug(icd_conf, drug_conf)
            key = (drug_name, icd_code)

            # Several SPL labels can share a drug name; if the (drug, code) pair already exists, add its probability mass instead of overwriting it
            if response[key][2] > 0:
                response[key][2] = response[key][2] + combined_conf
            else:
                response[key] = [indication, icd_desc, combined_conf]

    # Step 6: Sort by combined confidence
    sorted_response = sorted(response.items(), key=lambda x: x[1][2], reverse=True)

    # Step 7: Print and return results
    print(f"Query: {query_text}")
    print()
    print("Top Diagnoses (Rx included):")
    for (drug_name, icd_code), (indication, icd_desc, conf) in sorted_response:
        print(f"ICD-10: {icd_code} ({icd_desc})")
        print(f"Drug: {drug_name}, Indication: {indication}")
        print(f"    -> Indication: {indication}")
        print(f"Copilot Confidence: {conf:.2f}")
        print()

    return response
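`query_rx_copilot` returns a dict keyed by `(drug_name, icd_code)` with `[indication, icd_desc, confidence]` values. A minimal plain-Python sketch for viewing those pairs as a table, highest confidence first; `response_to_rows` and `format_table` are hypothetical helpers, and the example values are rounded from the run below. `pd.DataFrame(response_to_rows(response))` would give the same rows as a pandas DataFrame.

```python
def response_to_rows(response):
    """Flatten {(drug, icd_code): [indication, icd_desc, conf]} into rows, highest confidence first."""
    rows = [
        {"drug": drug, "icd10": code, "description": desc,
         "indication": indication, "confidence": float(conf)}
        for (drug, code), (indication, desc, conf) in response.items()
    ]
    return sorted(rows, key=lambda r: r["confidence"], reverse=True)


def format_table(rows, width=28):
    """Render rows as a fixed-width text table (long cells are truncated to `width`)."""
    cols = ["drug", "icd10", "description", "confidence"]
    lines = [" | ".join(c.upper().ljust(width) for c in cols)]
    for r in rows:
        cells = [str(r["drug"]), r["icd10"], r["description"], f"{r['confidence']:.2f}"]
        lines.append(" | ".join(c[:width].ljust(width) for c in cells))
    return "\n".join(lines)


# Example input shaped like query_rx_copilot's return value
example = {
    ("METHYLDOPA TABLETS USP", "F205"): ["Hypertension.", "Residual schizophrenia", 0.048],
    ("METHYLDOPA TABLETS USP", "E8352"): ["Hypertension.", "Hypercalcemia", 0.878],
}
print(format_table(response_to_rows(example)))
```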

Run through case study examples to test behavior.

In [284]:
# Run through more examples to verify the pipeline and begin to address QC concerns like multiple diagnoses, confidence intervals, and mechanisms to mitigate misdiagnoses like human-in-the-loop
soap_note_text = "Chronic renal failure with edema"
# soap_note_text = "Congestive heart failure"

query_rx_copilot(
    soap_note_text,
    icd10_faiss_path="icd10_faiss.index",
    drug_faiss_path="drugs_faiss.index",
    icd10_df=icd10_codes_df,
    drug_df=drug_indications_df
)

Query: Chronic renal failure with edema

Top Diagnoses (Rx included):
ICD-10: E8352 (Hypercalcemia)
Drug: METHYLDOPA TABLETS USP, Indication: Hypertension.
    -> Indication: Hypertension.
Copilot Confidence: 0.88

ICD-10: F205 (Residual schizophrenia)
Drug: METHYLDOPA TABLETS USP, Indication: Hypertension.
    -> Indication: Hypertension.
Copilot Confidence: 0.05

ICD-10: E8352 (Hypercalcemia)
Drug: Methyldopa Tablets, USP Rx only, Indication: Hypertension.
    -> Indication: Hypertension.
Copilot Confidence: 0.03

ICD-10: K9083 (Intestinal failure)
Drug: METHYLDOPA TABLETS USP, Indication: Hypertension.
    -> Indication: Hypertension.
Copilot Confidence: 0.03

ICD-10: M6284 (Sarcopenia)
Drug: METHYLDOPA TABLETS USP, Indication: Hypertension.
    -> Indication: Hypertension.
Copilot Confidence: 0.02

ICD-10: F205 (Residual schizophrenia)
Drug: Methyldopa Tablets, USP Rx only, Indication: Hypertension.
    -> Indication: Hypertension.
Copilot Confidence: 0.00

ICD-10: I7091 (Generaliz

defaultdict(<function __main__.query_rx_copilot.<locals>.<lambda>()>,
            {('METHYLDOPA TABLETS USP', 'E8352'): ['Hypertension.',
              'Hypercalcemia',
              np.float32(0.8779762)],
             ('METHYLDOPA TABLETS USP', 'F205'): ['Hypertension.',
              'Residual schizophrenia',
              np.float32(0.04824896)],
             ('METHYLDOPA TABLETS USP', 'K9083'): ['Hypertension.',
              'Intestinal failure',
              np.float32(0.02560607)],
             ('METHYLDOPA TABLETS USP', 'M6284'): ['Hypertension.',
              'Sarcopenia',
              np.float32(0.015245352)],
             ('METHYLDOPA TABLETS USP', 'I7091'): ['Hypertension.',
              'Generalized atherosclerosis',
              np.float32(0.0010828769)],
             ('Methyldopa Tablets, USP Rx only', 'E8352'): ['Hypertension.',
              'Hypercalcemia',
              np.float32(0.028874736)],
             ('Methyldopa Tablets, USP Rx only', 'F205'): ['Hypert