# **Mount Google Drive and Install Needed Libraries**

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; main() later reads the UMLS META files
# from /content/drive/MyDrive/concepts.
drive.mount('/content/drive')

In [None]:
!pip install biopython tqdm

# **Import Libraries used**

In [3]:
import os
import re
import json
import torch
import requests
from tqdm import tqdm #For displaying progress bars
from Bio import Entrez #for interacting with biological databases like PubMed
from collections import defaultdict
from transformers import AutoTokenizer, AutoModel #loading pre-trained models and tokenizers
from sklearn.metrics.pairwise import cosine_similarity #for vector comparisons

# **Constants used in helper functions**

In [4]:
# Target number of context sentences gathered per option: UMLS triples are
# collected first and PubMed sentences top up whatever is still missing.
NUM_MENTIONS: int = 60
# Placeholder written over the option text before the sentence is encoded.
MASK_TOKEN: str = "[MASK]"
# Per-sentence token budget handed to the tokenizer (longer inputs are truncated).
MAX_LEN: int = 128
# Hugging Face model id used for both the tokenizer and the encoder.
ENCODER_MODEL: str = "emilyalsentzer/Bio_ClinicalBERT"

# **Load Unique Options**
This function reads a JSON Lines (.jsonl) file (e.g., medqal.jsonl). It iterates through each line, parses the JSON, and collects all unique values from the "options" field. The options are normalized by stripping whitespace, converting to lowercase, and replacing spaces with underscores. It returns a sorted list of these unique options.

In [5]:
def load_unique_options(filepath):
    """Collect the distinct answer options from a MedQA-style JSON Lines file.

    Each non-blank line must be a JSON object with an ``"options"`` field that
    is either a mapping (e.g. ``{"A": "Aspirin", "B": ...}``) or a list of
    strings. Every option is normalized by stripping surrounding whitespace,
    lowercasing, and replacing spaces with underscores.

    Args:
        filepath: Path to the ``.jsonl`` file.

    Returns:
        list[str]: The unique normalized options, sorted alphabetically.
        Empty if the file contains no options.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"options"`` field.
    """
    unique_options = set()
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline at EOF
            options = json.loads(line)["options"]
            values = options.values() if isinstance(options, dict) else options
            for opt in values:
                unique_options.add(opt.strip().lower().replace(" ", "_"))
    option_list = sorted(unique_options)
    print(f"[Step 1] Loaded {len(option_list)} unique options.")
    if option_list:  # indexing an empty list would raise IndexError
        print("Example option:", option_list[0])
    return option_list

# **Relation Mapping**

A dictionary that maps abbreviated UMLS relation types (e.g., "CHD", "PAR") to more descriptive, human-readable phrases (e.g., "has child", "has parent").

In [6]:
# Maps the REL (relationship) abbreviation read from the 4th column of
# MRREL.RRF to the phrase used as the verb of a generated
# "<subject> <phrase> <object>." sentence. extract_and_format_triples falls
# back to the raw code when an abbreviation is missing here.
RELATION_MAPPING = {
    # --- REL values defined in the UMLS Reference Manual (MRREL.RRF) ---
    "AQ": "is an allowed qualifier for",  # UMLS: "Allowed qualifier"
    "CHD": "has child",
    "PAR": "has parent",
    "QB": "can be qualified by",
    "RB": "has a broader relationship",
    "RL": "is similar to",  # UMLS: relationship is similar or "alike"
    "RN": "has a narrower relationship",
    "RO": "has a related relationship",  # UMLS: other than synonymous, narrower, or broader
    "RQ": "is related and possibly synonymous to",
    "RU": "is related to",  # UMLS: related, unspecified
    "SIB": "has sibling",
    "SY": "has synonym",  # UMLS: source asserted synonymy
    "XR": "is not related to",  # UMLS: not related, no mapping
    # --- Legacy entries ---
    # NOTE(review): the codes below are not REL values listed in the UMLS
    # Reference Manual; they are kept only so existing lookups keep working.
    "CO": "is a component of",
    "RT": "has a therapeutic indication",
    "TR": "has a finding site",
    "AL": "is a manifestation of",
    "BR": "has a part",
    "DX": "is a diagnosis of",
    "CI": "is caused by",
    "PM": "has a mechanism of action",
    "HG": "is a gene associated with",
    "LC": "is a chemical constituent of",
    "AD": "is administered via",
    "DI": "is indicated for",
    "TI": "is a treatment for",
    "AO": "is associated with",
    "AT": "has an anatomical site",
    "CC": "is a clinical course of",
    "CD": "has clinical manifestation",
    "CM": "has common finding",
    "DD": "has a differential diagnosis",
    "DP": "is a demographic characteristic of",
    "DR": "is a drug resistance mechanism of",
    "GE": "is a genetic association with",
    "GN": "has a gene",
    "GS": "has a finding site",
    "HT": "has a histopathological finding",
    "IN": "is an indication for",
    "MD": "has a medical device",
    "MH": "is a mental health condition related to",
    "MT": "is a method of treatment for",
    "NG": "is a molecular function of",
    "OP": "is an operative procedure for",
    "OR": "has an organism",
    "PB": "is a pathological process of",
    "PC": "is a physiological process of",
    "PH": "has a pharmacological action",
    "PP": "has a physical finding",
    "PS": "is a pathogenic mechanism of",
    "PX": "is a physical finding in",
    "TC": "is a therapeutic class of",
    "TT": "is a therapeutic target of",
    "VG": "is a genetic variant of",
    "VS": "has a vital sign",
    "XD": "is an imaging finding in",
    "XG": "is a gene variant in",
    "XM": "is a metabolic product of",
    "XO": "is an outcome of",
    "XP": "is a prognosis of",
    "XT": "is a therapeutic intervention for",
    "ZN": "is a zone of",
}

# **Load English Concepts**

This function reads the MRCONSO.RRF file from the specified UMLS META folder. It parses each line to extract the Concept Unique Identifier (CUI) and its string name, keeps only English (ENG) rows, and returns a dictionary mapping each CUI to an English name. This mapping is used to translate the CUIs found in MRREL.RRF into readable terms.

In [7]:
def load_english_concepts(umls_meta_folder):
    """
    Load a CUI -> English concept name mapping from MRCONSO.RRF.

    MRCONSO.RRF holds one row per atom, so a CUI usually has many English
    strings. For each CUI this keeps its preferred English name, i.e. the row
    with TS == "P", STT == "PF" and ISPREF == "Y". If a CUI has no such row,
    the first English string seen for it is kept. Rows with fewer than 15
    pipe-separated fields are ignored.

    Column layout used (0-based): CUI=0, LAT=1, TS=2, STT=4, ISPREF=6, STR=14.

    Args:
        umls_meta_folder (str): The path to the 'META' directory containing RRF files.

    Returns:
        dict: A dictionary mapping CUI to English concept name, or an empty dict
        if the file is missing or an error occurs while reading it.
    """
    mrconso_path = os.path.join(umls_meta_folder, "MRCONSO.RRF")
    cui_to_name = {}
    # CUIs whose stored name already comes from the preferred English atom;
    # later rows for these CUIs cannot improve the choice.
    has_preferred = set()
    print(f"Loading English concepts from {mrconso_path}...")
    try:
        with open(mrconso_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('|')
                if len(parts) <= 14 or parts[1] != "ENG":
                    continue  # too short, or not an English atom
                cui, string = parts[0], parts[14]
                if cui in has_preferred:
                    continue
                if parts[2] == "P" and parts[4] == "PF" and parts[6] == "Y":
                    cui_to_name[cui] = string
                    has_preferred.add(cui)
                elif cui not in cui_to_name:
                    cui_to_name[cui] = string  # provisional until a preferred atom appears
        print(f"Loaded {len(cui_to_name)} English concepts.")
    except FileNotFoundError:
        print(f"Error: MRCONSO.RRF not found at {mrconso_path}. Please check the path.")
        return {}
    except Exception as e:
        print(f"An error occurred while loading MRCONSO.RRF: {e}")
        return {}
    return cui_to_name

# **Extract and Format Triples**

This is a core function that processes the MRREL.RRF file.

* It iterates through relational triples (CUI1, relation type, CUI2).
* For each triple, it uses the cui_to_name_map to convert CUIs into their English names (subject and object).
* It uses RELATION_MAPPING to translate the relation type into a descriptive phrase.
* It then constructs a sentence in the format: "Subject relation_label object."
* It matches these generated sentences to the medical_options (a case-insensitive substring match of the option against the subject or object name).
* It limits the number of distinct sentences collected per option using max_triples_per_option.
* The function returns a dictionary that maps each medical option with at least one match to a list of relevant sentence-style triples.













In [8]:
def extract_and_format_triples(medical_options, umls_meta_folder, cui_to_name_map, max_triples_per_option=50):
    """
    Extract and format triples from MRREL.RRF for the given medical options.

    Each MRREL row (CUI1 | AUI1 | STYPE1 | REL | CUI2 | ...) whose two CUIs are
    both in ``cui_to_name_map`` is rendered as the sentence
    "<name of CUI1> <relation phrase> <name of CUI2>.", where the phrase comes
    from RELATION_MAPPING (the raw REL code is used when unmapped). The sentence
    is attached to every option whose search form (lowercased, "_" -> " ")
    occurs as a substring of either lowercased concept name.

    Args:
        medical_options: Option strings, e.g. the output of load_unique_options.
        umls_meta_folder: Directory containing MRREL.RRF.
        cui_to_name_map: Mapping CUI -> concept name.
        max_triples_per_option: Maximum number of distinct sentences kept per option.

    Returns:
        dict: option -> list of distinct sentences, in the order they were first
        found in MRREL.RRF. Options with no match are absent. An empty dict is
        returned if ``cui_to_name_map`` is empty, the file is missing, or an
        error occurs while reading it.
    """

    if not cui_to_name_map:
        print("❌ No English concepts loaded.")
        return {}

    mrrel_path = os.path.join(umls_meta_folder, "MRREL.RRF")
    print(f"✅ Reading triples from {mrrel_path}")

    # Options that can still accept sentences: original option -> search form.
    # An option is removed once it is full, so later rows no longer test it.
    pending = {}
    if max_triples_per_option > 0:
        for opt in medical_options:
            pending.setdefault(opt, opt.lower().replace("_", " "))

    # option -> {sentence: None}. A dict keeps first-seen order; a set would make
    # the returned lists come out in a different order on every run.
    results = {}

    try:
        with open(mrrel_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not pending:
                    break  # every option is full; the rest of the file cannot change the result
                parts = line.strip().split('|')
                if len(parts) < 5:
                    continue
                cui1, rel, cui2 = parts[0], parts[3], parts[4]
                if cui1 not in cui_to_name_map or cui2 not in cui_to_name_map:
                    continue

                subj = cui_to_name_map[cui1]
                obj = cui_to_name_map[cui2]
                rel_label = RELATION_MAPPING.get(rel, rel)
                sentence = f"{subj} {rel_label} {obj}."

                subj_lower = subj.lower()
                obj_lower = obj.lower()

                saturated = []
                for orig_option, needle in pending.items():
                    if needle in subj_lower or needle in obj_lower:
                        bucket = results.setdefault(orig_option, {})
                        if sentence not in bucket:
                            bucket[sentence] = None
                            if len(bucket) >= max_triples_per_option:
                                saturated.append(orig_option)
                # Removed after the scan: a dict must not change size while iterated.
                for orig_option in saturated:
                    del pending[orig_option]
    except FileNotFoundError:
        print("❌ MRREL.RRF file not found.")
        return {}
    except Exception as e:
        print(f"❌ Error while processing: {e}")
        return {}

    return {opt: list(sentences) for opt, sentences in results.items()}

# **Fetch Pubmed Sentences**

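This function queries the PubMed database using Bio.Entrez. It searches for abstracts matching the given term, downloads them as plain text, splits them into sentences, and returns the sentences that mention the term.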

In [9]:
def fetch_pubmed_sentences(term, max_results=NUM_MENTIONS):
    """
    Fetch PubMed sentences that mention ``term``.

    Searches PubMed for ``term`` (underscores turned into spaces), downloads up
    to ``max_results`` matching abstracts in plain-text form, splits them into
    sentences, and keeps the sentences that contain the term, ignoring case.

    ``Entrez.email`` should be set before calling (main() does this).

    Args:
        term: Search term, e.g. a normalized option such as "myocardial_infarction".
        max_results: Maximum number of abstracts requested and of sentences returned.

    Returns:
        list[str]: Up to ``max_results`` sentences in the order they appear in the
        downloaded text. Empty if nothing matches, ``max_results`` <= 0, or the
        Entrez request fails.
    """
    if max_results <= 0:
        return []
    query = term.replace("_", " ")
    needle = query.lower()  # compared against lowercased sentences below

    try:
        handle = Entrez.esearch(db="pubmed", term=query, retmax=max_results)
        try:
            ids = Entrez.read(handle)["IdList"]
        finally:
            handle.close()
        print(f"[PubMed] {term}: {len(ids)} abstracts found.")
        if not ids:
            return []
        handle = Entrez.efetch(db="pubmed", id=','.join(ids), rettype="abstract", retmode="text")
        try:
            text = handle.read()
        finally:
            handle.close()
    except (OSError, RuntimeError, ValueError) as e:
        # Network/HTTP errors (urllib errors are OSError subclasses) or a bad
        # response: report and continue instead of aborting the whole run.
        print(f"[PubMed] {term}: request failed ({e}); returning no sentences.")
        return []

    sentences = []
    for block in text.split('\n\n'):
        # The text output may hard-wrap lines; collapsing whitespace lets a term
        # broken across a line break still match and keeps sentences on one line.
        flat = " ".join(block.split())
        for s in re.split(r'(?<=[.!?]) +', flat):
            if needle in s.lower():
                sentences.append(s.strip())
                if len(sentences) >= max_results:
                    return sentences
    return sentences

# **Mask Sentences**

This function takes an option and a list of sentences. It keeps only the sentences that contain the option (with underscores treated as spaces) and replaces every occurrence of the option in them with MASK_TOKEN ([MASK]). The model's output at the masked position is later used as a context-dependent embedding of the option.

In [10]:
def mask_sentences(option, sentences):
    """
    Replace every occurrence of ``option`` in each sentence with MASK_TOKEN.

    The option is matched with "_" turned into spaces and case ignored on both
    sides, so "Myocardial_Infarction" masks "myocardial infarction". Only the
    sentences that contain the option are returned.

    NOTE(review): matching is by substring, not whole word (e.g. "iron" is
    also masked inside "environment"). This is consistent with the substring
    checks used to select sentences in extract_and_format_triples and
    fetch_pubmed_sentences.

    Args:
        option: Option string, underscores standing for spaces.
        sentences: Iterable of raw sentences.

    Returns:
        list[str]: The masked sentences, in input order. Empty for an empty option.
    """
    needle = option.replace("_", " ")
    needle_lower = needle.lower()
    masked = []
    raw = []
    # An empty option would "match" everywhere and insert a mask between every character.
    if needle_lower:
        # Compiled once for all sentences instead of once per matching sentence.
        pattern = re.compile(re.escape(needle), re.IGNORECASE)
        for s in sentences:
            if needle_lower in s.lower():
                masked.append(pattern.sub(MASK_TOKEN, s))
                raw.append(s)
    if masked:
        print(f"[Step 3] Example masked sentence for '{option}':\n  Raw: {raw[0]}\n  Masked: {masked[0]}")
    else:
        print(f"[Step 3] No valid masked sentences found for '{option}'.")

    return masked

# **Extract Mask Embeddings**

This function generates embeddings at the MASK_TOKEN position. For each masked sentence, it locates the first mask token and takes its hidden state from the model's last layer; sentences whose mask token was cut off by truncation are skipped. These mask embeddings represent the contextual meaning of the masked option.

In [11]:
def extract_mask_embeddings(sentences, tokenizer, model, device):
    """
    Encode each masked sentence and return the hidden state at its first mask token.

    Args:
        sentences: Sentences that already contain MASK_TOKEN.
        tokenizer: Hugging Face tokenizer whose ``mask_token_id`` marks the mask.
        model: Encoder returning ``last_hidden_state``.
        device: Device the tokenized inputs are moved to.

    Returns:
        list: One hidden-state vector per sentence that still contains a mask
        token after truncation to MAX_LEN; other sentences are skipped.
    """
    collected = []
    for sentence in sentences:
        encoded = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=MAX_LEN).to(device)
        with torch.no_grad():
            hidden = model(**encoded).last_hidden_state
        # Token positions (within the single batch item) holding the mask id.
        positions = (encoded["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        if len(positions) > 0:
            collected.append(hidden[0, positions[0]])
    return collected

# **Average Embedding Function**
Helper function that takes a list of PyTorch embedding tensors and returns their element-wise mean (or None if the list is empty).

In [12]:
def average_embedding(embeddings):
    """
    Return the element-wise mean of equally shaped tensors.

    Args:
        embeddings: List of torch tensors that all have the same shape.

    Returns:
        torch.Tensor | None: The mean tensor, or None when the list is empty.
    """
    if embeddings:
        return torch.stack(embeddings).mean(dim=0)
    return None

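# **Closest Sentences at Selected Ranks**

This function ranks the sentences by the cosine similarity of their [CLS] embeddings to avg_vector and returns the sentences at fixed rank positions (ranks 0, 1, 10, 11, 20, 21, … up to NUM_MENTIONS) rather than simply the k most similar ones.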

In [13]:
def top_custom_closest_sentences(avg_vector, sentences, tokenizer, model, device,
                                 *, stride=10, picks_per_stride=2, span=None):
    """
    Rank sentences by similarity to ``avg_vector`` and return a spread of ranks.

    Each sentence is encoded on its own and represented by its [CLS] hidden
    state (token 0 of ``last_hidden_state``). Sentences are sorted by cosine
    similarity to ``avg_vector``, highest first. The function then returns the
    sentences at ranks ``start .. start + picks_per_stride - 1`` for every
    ``start`` in ``range(0, span, stride)``. Ranks beyond the number of
    sentences are skipped. With the defaults and NUM_MENTIONS = 60 the ranks
    are [0, 1, 10, 11, 20, 21, 30, 31, 40, 41, 50, 51], i.e. up to 12 sentences.

    NOTE(review): main() passes an average of [MASK]-token states as
    ``avg_vector`` while sentences are represented by [CLS] states; confirm that
    comparing these two kinds of vectors is intended.

    Args:
        avg_vector: Reference vector of shape (d,) or (1, d).
        sentences: Candidate sentences (any iterable of str).
        tokenizer, model, device: Encoder components, as in extract_mask_embeddings.
        stride: Distance between consecutive groups of selected ranks (must be > 0).
        picks_per_stride: Number of consecutive ranks taken from each group.
        span: Upper bound for group starts; defaults to NUM_MENTIONS.

    Returns:
        list[str]: The selected sentences, ordered by rank. Empty if no sentences.
    """
    sentences = list(sentences)  # indexed below, so materialize any iterable

    print(f"\n🔢 Total input sentences: {len(sentences)}")
    print("🧠 Encoding sentences and extracting CLS embeddings...")

    # Step 1: Encode sentences one at a time and keep each [CLS] embedding.
    embeddings = []
    for idx, sent in enumerate(sentences):
        inputs = tokenizer(sent, return_tensors="pt", truncation=True, max_length=MAX_LEN).to(device)
        with torch.no_grad():
            outputs = model(**inputs).last_hidden_state
        embeddings.append(outputs[0, 0, :])  # batch item 0, token 0 ([CLS])
        print(f"✅ Encoded sentence {idx + 1}: {sent[:60]}...")  # Truncate long sentences

    if not embeddings:
        print("❌ No valid sentence embeddings found.")
        return []

    # Step 2: Compute cosine similarities; reshape accepts (d,) as well as (1, d).
    all_embs = torch.stack(embeddings)
    query = avg_vector.reshape(1, -1)
    similarities = cosine_similarity(query.cpu().numpy(), all_embs.cpu().numpy())[0]

    print("\n📊 Cosine similarities with avg_vector:")
    for i, (sim, sent) in enumerate(zip(similarities, sentences)):
        print(f"{i + 1:2d}) Similarity: {sim:.4f} | Sentence: {sent[:60]}...")

    # Step 3: Get sorted indices (descending order of similarity)
    sorted_indices = similarities.argsort()[::-1]

    print("\n🏅 Sorted sentence indices by similarity:")
    print(sorted_indices[:20])  # Print top 20 for inspection

    # Step 4: Rank positions to keep, e.g. [0, 1, 10, 11, ..., 50, 51] for the
    # defaults with NUM_MENTIONS = 60.
    limit = NUM_MENTIONS if span is None else span
    custom_positions = [
        start + offset
        for start in range(0, limit, stride)
        for offset in range(picks_per_stride)
    ]
    print(f"\n🔢 Custom positions to extract: {custom_positions}")

    # Step 5: Map rank positions to sentence indices, skipping ranks past the end.
    selected_indices = [sorted_indices[pos] for pos in custom_positions if pos < len(sorted_indices)]
    print(f"✅ Selected top indices: {selected_indices}")

    # Step 6: Return the selected sentences
    return [sentences[i] for i in selected_indices]

# **Save Vectors and Sentences into a JSON File**

This utility function saves processed data to a JSON file

In [14]:
def save_vectors_and_sentences_json(filename, data_dict):
    """
    Write per-option results to ``filename`` as indented JSON.

    Args:
        filename: Output path; an existing file is overwritten.
        data_dict: option -> {"avg_vector": array with ``.tolist()``,
            "top_sentences": list of str}. Other keys in the inner dicts are
            not written.
    """
    payload = {}
    for option, entry in data_dict.items():
        # .tolist() turns the numpy vector main() stores into plain JSON-encodable lists.
        vector = entry["avg_vector"].tolist()
        payload[option] = {
            "avg_vector": vector,
            "top_sentences": entry["top_sentences"],
        }
    with open(filename, "w") as out_file:
        json.dump(payload, out_file, indent=2)
    print(f"[Saved] Data written to {filename}")

In [16]:
def main():
    """
    Run the whole pipeline and write the results to JSON.

    Steps:
      1. Load the ENCODER_MODEL tokenizer/model on CUDA if available, else CPU.
      2. Load unique answer options from /content/medqal.jsonl.
      3. Load English CUI names from MRCONSO.RRF in the UMLS folder.
      4. For each option gather context sentences: UMLS triples first, then
         PubMed sentences until NUM_MENTIONS are available.
      5. Mask the option in those sentences, average the [MASK] embeddings, and
         select representative sentences with top_custom_closest_sentences.
      6. Save {option: {"avg_vector", "top_sentences"}} to
         optionll_embeddings_and_sentences.json (written even when empty).
    """

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load tokenizer and model; eval() switches the model to inference mode.
    tokenizer = AutoTokenizer.from_pretrained(ENCODER_MODEL)
    model = AutoModel.from_pretrained(ENCODER_MODEL).to(device)
    model.eval()

    umls_meta_folder_path = "/content/drive/MyDrive/concepts" # Make sure this path is correct in your environment

    # --- Step 1: Load options from medqa
    print("Step 1: Loading options from medqa...")
    medical_options_from_medqa = load_unique_options("/content/medqal.jsonl")
    print(f"Loaded {len(medical_options_from_medqa)} unique options.")

    # --- Step 2: Load English concepts ---
    print("Step 2: Loading English concepts...")
    Entrez.email = "your.email@example.com" # Replace with your actual email for PubMed access
    cui_to_name = load_english_concepts(umls_meta_folder_path)
    print(f"Loaded {len(cui_to_name)} English concepts.")

    # --- Step 3: Extract and format triples ---
    all_results_for_json = {}

    if cui_to_name: # Only proceed if concepts were loaded successfully
        print("Step 3: Extracting and formatting triples from UMLS...")
        extracted_triples = extract_and_format_triples(medical_options_from_medqa, umls_meta_folder_path, cui_to_name, max_triples_per_option=NUM_MENTIONS)
        print(f"Extracted triples for {len(extracted_triples)} options from UMLS.")

        for option in tqdm(medical_options_from_medqa, desc="Processing options"):
            # dict.fromkeys de-duplicates while keeping insertion order (UMLS
            # triples first, then PubMed). A set would reorder the sentences on
            # every run because string hashing is randomized per process.
            all_sentences_for_option = dict.fromkeys(extracted_triples.get(option, []))

            # If less than NUM_MENTIONS triples, fetch from PubMed
            if len(all_sentences_for_option) < NUM_MENTIONS:
                pubmed_needed = NUM_MENTIONS - len(all_sentences_for_option)
                print(f"  [Info] For '{option}': {len(all_sentences_for_option)} triples found from UMLS. Fetching {pubmed_needed} from PubMed.")
                pubmed_sentences = fetch_pubmed_sentences(option, max_results=pubmed_needed)
                all_sentences_for_option.update(dict.fromkeys(pubmed_sentences))

            list_all_sentences = list(all_sentences_for_option)
            print(f"  [Info] Total sentences for '{option}': {len(list_all_sentences)}")

            if not list_all_sentences:
                print(f"  [Warning] No sentences found for option: {option}. Skipping.")
                continue

            # Mask sentences
            masked_sentences = mask_sentences(option, list_all_sentences)

            if not masked_sentences:
                print(f"  [Warning] No masked sentences generated for option: {option}. Skipping embedding generation.")
                continue

            # Extract mask embeddings
            mask_embeddings = extract_mask_embeddings(masked_sentences, tokenizer, model, device)

            if not mask_embeddings:
                print(f"  [Warning] No mask embeddings generated for option: {option}. Skipping.")
                continue

            # Average embeddings
            avg_mask_embedding = average_embedding(mask_embeddings)

            if avg_mask_embedding is None:
                print(f"  [Warning] Average embedding could not be computed for option: {option}. Skipping.")
                continue

            # Select representative sentences at spread-out similarity ranks
            top_sentences = top_custom_closest_sentences(avg_mask_embedding, list_all_sentences, tokenizer, model, device)

            all_results_for_json[option] = {
                "avg_vector": avg_mask_embedding.cpu().numpy(),
                "top_sentences": top_sentences
            }
            # Report the real count; the number of selected ranks is not fixed at 10.
            print(f"  [Done] Processed '{option}'. {len(top_sentences)} top sentences retrieved.")
    else:
        print("Skipping triple extraction as English concepts were not loaded.")

    # Save results to JSON
    output_filename = "optionll_embeddings_and_sentences.json"
    save_vectors_and_sentences_json(output_filename, all_results_for_json)
    print("Main function execution completed.")

In [None]:
# Entry point: run the full pipeline (model loading, sentence collection,
# embedding, and JSON export) when this code is executed directly.
if __name__ == "__main__":
    main()