Import Libraries

In [2]:
import numpy as np
import json
import matplotlib.pyplot as plt
import pandas as pd
from openai import OpenAI
import pandas as pd
import re

In [52]:
def response_zeroshot(model_id,client, clinical_note):
    prompt = f"""
    You are an expert in extracting phenotype terms and their corresponding HPO IDs from clinical notes. Your task is to identify all relevant phenotype terms mentioned in the given clinical note and map them to their correct HPO IDs.

    Output the results in this format:
    <phenotype term> | <HPO ID>
    
    Input Clinical Note:
    {clinical_note}
    
    Output:
    """
    
    # Send the request to GPT-4
    response = client.chat.completions.create(
    model=model_id,
    messages=[
        {
            "role": "system",
            "content": "You are an AI assistant specialized in extracting medical phenotypes, syndromes, diseases and mapping them to HPO IDs."
        },
        {
            "role": "user",
            "content": prompt
        }
    ]
)
      # Extract and print the output
    output = response.choices[0].message.content
    def post_process(output):
        # Split the output into lines
        lines = output.strip().split("\n")
        
        # Extract the HPO terms and IDs
        hpo_names,ids=[],[]
        for line in lines:
            if "|" in line:
                name, hpo_id = line.split("|")
                hpo_names.append(name.strip())
                ids.append(hpo_id.strip())
        return hpo_names,ids
    A,B=post_process(output)
    return A,B    # return names and ids

def response_finetuning(model_id,client,query):
    response = client.chat.completions.create(
            model=model_id,  # Replace with your fine-tuned model ID
            messages=[
                {
                    "role": "system",
                    "content": "You are an assistant specialized in extracting phenotype terms and their corresponding HPO IDs from clinical notes."
                },
                {
                    "role": "user",
                    "content": f"Extract relevant phenotype terms and their HPO IDs from the following clinical note:\n{query}"
                }
            ]
        )
    x= response.choices[0].message.content
    def parse_gpt_output(text):
        text = text.strip()
        terms = text.split(",")  # Split by commas for predictions
        parsed = []
        ids=[]
        for term in terms:
            term = term.strip()
            if "(" in term and ")" in term:
                # Split only on the first "(" to avoid unpacking issues
                name,_ = term.split("(", 1)
                id_ = term.split("(")[1].split(")")[0]
                parsed.append(name.strip().lower())
                ids.append(id_)
        ids=[id.replace('_', ':') for id in ids]
        return parsed,ids
    A,B=parse_gpt_output(x)
    return A,B



def cosine_similarity(v1, v2):
    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))

def get_top_k_similarities(embeddings, query, k):
    similarities = np.array([cosine_similarity(query, embedding) for embedding in embeddings])
    top_k_indices = similarities.argsort()[-k:][::-1]
    return top_k_indices, similarities[top_k_indices]


In [60]:
h2g=pd.read_csv('2024/phenotype_to_genes.txt',sep='\t')
# Select specific columns from the DataFrame
h2g = h2g[['hpo_id', 'gene_symbol']]
h2g= h2g.drop_duplicates(subset=['hpo_id', 'gene_symbol'])
# create a dictionary with hpo_id as key and it was associated with many gene symbols, so append it as list of strings
hpo_gene_dict = {}
for i in range(len(h2g)):
    if h2g['hpo_id'].iloc[i] in hpo_gene_dict:
        hpo_gene_dict[h2g['hpo_id'].iloc[i]].append(h2g['gene_symbol'].iloc[i])
    else:
        hpo_gene_dict[h2g['hpo_id'].iloc[i]] = [h2g['gene_symbol'].iloc[i]]

print("hp to genes dictionary created")

hpo_dict = {}
with open("2024/hpo_dict.txt") as f:
    for line in f:
        (key, val) = line.split(":H")
        # remove newline character  
        val = val.rstrip()
        hpo_dict[key] = "H"+val

print("hpo dictionary created")

# Load the embeddings

df=pd.read_csv('2024/HPO_embeddings_40k.csv')
db= df.to_numpy()
print("embeddings loaded")

# child parent dict
child_parent_dict = {}
with open("2024/child_parent_dict_merged.json") as f:
    child_parent_dict = json.load(f)

print("child parent dictionary created")

hp to genes dictionary created
hpo dictionary created
embeddings loaded
child parent dictionary created


In [9]:
import numpy as np

def process_embeddings(content, db, hpo_dict, client):
    """
    Process content to compute embeddings, find top 1 similar rows in db, and map to HPO terms and IDs.
    
    Args:
    - content: List of lists of text labels.
    - db: Database of embeddings for similarity comparison.
    - hpo_dict: Dictionary mapping HPO terms to IDs.
    - client: API client for generating embeddings.
    
    Returns:
    - Sep: List of lists of unique HPO terms and IDs, grouped by the number of terms in each sublist of content.
    """
    # Step 1: Calculate total number of terms and initialize embedding array
    total_terms = sum(len(i) for i in content)
    embeddings_20 = np.zeros((total_terms, 1536))

    # Step 2: Generate embeddings for each term
    z = 0
    for i in range(len(content)):
        for term in content[i]:
            response = client.embeddings.create(
                input=term,
                model="text-embedding-3-small"
            )
            embeddings_20[z] = response.data[0].embedding
            z += 1
            if z % 500 == 0:
                print(f"Processed {z} embeddings")

    # Step 3: Find top 1 similar rows in db
    top_1_indices = []
    for i in range(total_terms):
        top_1_indices.append(get_top_k_similarities(db, embeddings_20[i], 1)[0][0])
        if i % 50 == 0:
            print(f"Processed similarity for row {i}")

    # Step 4: Retrieve HPO terms and IDs
    hpo_terms = [
        f"{list(hpo_dict.keys())[index]} {list(hpo_dict.values())[index]}"
        for index in top_1_indices
    ]

    # Step 5: Split HPO terms into sublists based on the number of terms in content
    no_ids = [len(i) for i in content]
    Sep = []
    start = 0
    for num_terms in no_ids:
        Sep.append(list(set(hpo_terms[start:start + num_terms])))
        start += num_terms

    return Sep


### Choose the model:

#### ---------------------------------------Finetuned models---------------------------------------
1) GPT4o-mini-2024-07-18 : ft:gpt-4o-mini-2024-07-18:iisc-bangalore::AYf5TC9S
2) GPT4o-2024-08-06      : ft:gpt-4o-2024-08-06:iisc-bangalore::AZ03ME6y

#### ------------------------------------Base models for zeroshot----------------------------------

3) GPT4o-mini-2024-07-18 : gpt-4o-mini-2024-07-18
4) GPT4o-2024-08-06      : gpt-4o-2024-08-06


In [6]:
# openai api key
client = OpenAI(api_key="XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")

In [58]:
clinical_text="Branchio-oto-renal dysplasia, often called the BOR syndrome, in its full expression consists of hearing loss of conductive, sensorineural, or mixed type; preauricular pits; auricular deformities; lateral cervical sinuses, cysts, or fistulas; and renal malformations. The condition is inherited in an autosomal dominant mode. The findings in three affected families are described, and pertinent genetic and clinical aspects are discussed. The potential seriousness of the renal and aural malformations stresses the importance of early recognition of this syndrome."
# spans, HP_ids =response_finetuning("ft:gpt-4o-mini-2024-07-18:iisc-bangalore::AYf5TC9S",client,clinical_text)
spans, HP_ids =response_zeroshot("gpt-4o-mini-2024-07-18",client,clinical_text)
print("phenotypes extracted from clinical text")
print(spans)
print("The corresponding HPO IDs are:")
print(HP_ids)


phenotypes extracted from clinical text
['Branchio-oto-renal dysplasia', 'Hearing loss', 'Preauricular pits', 'Auricular deformities', 'Lateral cervical sinuses', 'Cysts', 'Fistulas', 'Renal malformations']
The corresponding HPO IDs are:
['HP:0000404', 'HP:0000365', 'HP:0000249', 'HP:0008761', 'HP:0000480', 'HP:0011462', 'HP:0000289', 'HP:0000118']


Normalization using embeddings

In [10]:
embed_spans=process_embeddings([spans], db, hpo_dict, client)               # pass lists of list of spans for multiple clinical notes
embed_spans

Processed similarity for row 0


[['hearing loss HP:0000365',
  'preauricular pits HP:0004467',
  'autosomal dominant HP:0000006',
  'hepatic cysts HP:0001407',
  'lateral neck mass HP:6000174',
  'anal fistula HP:0010447',
  'sensorineural hearing loss HP:0000407',
  'renal malformation HP:0012210',
  'small cervical vertebrae HP:0004629',
  'dichromacy HP:0011518',
  'auricular malformation HP:0000377',
  'bilateral renal dysplasia HP:0012582',
  'obsolete heterogeneous HP:0001425']]

In [47]:
# extract HP ids after embeddings normalisation
HP_ids_embed=re.findall(r'HP:\d+',str(embed_spans[0]) )
print(HP_ids_embed)

['HP:0000365', 'HP:0004467', 'HP:0000006', 'HP:0001407', 'HP:6000174', 'HP:0010447', 'HP:0000407', 'HP:0012210', 'HP:0004629', 'HP:0011518', 'HP:0000377', 'HP:0012582', 'HP:0001425']


BioMED_NER

In [None]:
from transformers import pipeline
pipe = pipeline("token-classification", model="venkatd/BIOMed_NER", aggregation_strategy='simple')

Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [27]:
def get_spans(output):
    x = []
    i = 0
    while i < len(output):
        temp = ""
        if i + 1 < len(output) and output[i]['entity_group'] == 'DETAILED_DESCRIPTION' and output[i + 1]['entity_group'] == 'DISEASE_DISORDER':
            temp = output[i]['word'] + " " + output[i + 1]['word']
            x.append(temp)
            i += 2  # Skip the next element as it is already processed
            continue
        if i + 1 < len(output) and output[i]['entity_group'] == 'DETAILED_DESCRIPTION' and output[i + 1]['entity_group'] == 'SIGN_SYMPTOM':
            temp = output[i]['word'] + " " + output[i + 1]['word']
            x.append(temp)
            i += 2  # Skip the next element as it is already processed
            continue
        if i+1< len(output) and output[i]['entity_group'] == 'BIOLOGICAL_STRUCTURE' and output[i+1]['entity_group'] == 'SIGN_SYMPTOM':
            temp = output[i]['word'] + " " + output[i+1]['word']
            x.append(temp)
            i += 2
            continue
        if output[i]['entity_group'] == 'SIGN_SYMPTOM':
            x.append(output[i]['word'])
        if output[i]['entity_group'] == 'DISEASE_DISORDER':
            x.append(output[i]['word'])
        i += 1
    return x

In [28]:
output = pipe(clinical_text)
NER_spans = get_spans(output)
print(NER_spans)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


['Branchio-oto-renal dysplasia', 'BOR syndrome', 'hearing loss', 'preauricular pits', 'auricular deformities', 'lateral cervical sinuses cysts', 'fistulas', 'renal malformations']


Normalisation using embeddings

In [29]:
embed_spans=process_embeddings([NER_spans], db, hpo_dict, client)               # pass lists of list of spans for multiple clinical notes
embed_spans

Processed similarity for row 0


[['renal malformation HP:0012210',
  'auricular malformation HP:0000377',
  'preauricular pits HP:0004467',
  'compressed lymph-node sinuses HP:0020268',
  'bilateral renal dysplasia HP:0012582',
  'blesovsky syndrome HP:0033658',
  'anal fistula HP:0010447',
  'hearing loss HP:0000365']]

ACMG 81 GENES

In [5]:
genes = [
    "APC", "RET", "BRCA1", "BRCA2", "PALB2", "SDHD", "SDHAF2", "SDHC", "SDHB", 
    "MAX", "TMEM127", "BMPR1A", "SMAD4", "TP53", "MLH1", "MSH2", "MSH6", "PMS2", 
    "MEN1", "MUTYH", "NF2", "FBN1", "TGFBR1", "TGFBR2", "SMAD3", "ACTA2", "MYH11", 
    "PKP2", "DSP", "DSC2", "TMEM43", "DSG2", "RYR2", "CASQ2", "TRDN", "BAG3", 
    "DES", "RBM20", "TNNC1", "TNNT2", "LMNA", "FLNC", "TTN", "CALM1", "CALM2", 
    "CALM3", "COL3A1", "LDLR", "APOB", "PCSK9", "MYH7", "MYBPC3", "TNNI3", 
    "TPM1", "MYL3", "ACTC1", "PRKAG2", "MYL2", "KCNQ1", "KCNH2", "SCN5A", "BTD", 
    "GLA", "OTC", "GAA", "STK11", "HFE", "ACVRL1", "ENG", "RYR1", "CACNA1S", 
    "HNF1A", "RPE65", "ATP7B", "TTR", "PTEN", "RB1", "TSC1", "TSC2", "VHL", "WT1"
]


In [14]:
embed_spans[0]

['hearing loss HP:0000365',
 'preauricular pits HP:0004467',
 'autosomal dominant HP:0000006',
 'hepatic cysts HP:0001407',
 'lateral neck mass HP:6000174',
 'anal fistula HP:0010447',
 'sensorineural hearing loss HP:0000407',
 'renal malformation HP:0012210',
 'small cervical vertebrae HP:0004629',
 'dichromacy HP:0011518',
 'auricular malformation HP:0000377',
 'bilateral renal dysplasia HP:0012582',
 'obsolete heterogeneous HP:0001425']

### Genes extracted after embedding normalisation

In [44]:
# get the gene symbols associated with the extracted HP ids
genes_associated = []
for hp_id in HP_ids_embed:
    if hp_id in hpo_gene_dict:
        genes_associated.extend(hpo_gene_dict[hp_id])
genes_associated = list(set(genes_associated))
print("Genes extracted after embedding normalisation")
print("-------------------------------------------------------------------------------------------")
print(f"The total number of genes associated with the extracted HPO IDs are : {len(genes_associated)}")
print("-------------------------------------------------------------------------------------------")
print(genes_associated)

Genes extracted after embedding normalisation
-------------------------------------------------------------------------------------------
The total number of genes associated with the extracted HPO IDs are : 3544
-------------------------------------------------------------------------------------------
['KCNA2', 'TRIM44', 'BLVRA', 'DEPDC5', 'SI', 'GJA8', 'XPR1', 'OFD1', 'YME1L1', 'EGR2', 'FBXO11', 'NDUFS4', 'TMEM107', 'POMGNT2', 'PDZD8', 'FRMD4A', 'CREB3L1', 'RNF220', 'LEMD2', 'MYL9', 'SARDH', 'UPF3B', 'STX11', 'IRF3', 'GFI1B', 'MAPT', 'POLR1D', 'AUTS2', 'SDHB', 'MINAR2', 'PER2', 'STX1A', 'GCH1', 'NAA10', 'ATP2C1', 'GATAD1', 'ABCC6', 'DISC2', 'SPTBN1', 'ARHGAP31', 'NUP214', 'SPINT2', 'BLTP1', 'RYR2', 'UGDH', 'MET', 'COL2A1', 'NDUFB10', 'PRSS1', 'NRIP1', 'PANX1', 'CHD2', 'ZNF408', 'SMOC1', 'B9D2', 'C1GALT1C1', 'ILDR1', 'CSF3R', 'MYO3A', 'HNF1B', 'VAMP1', 'COL8A2', 'HBG1', 'KRT17', 'TMIE', 'COG3', 'PRG4', 'MCTP2', 'NAF1', 'CACNB4', 'PMVK', 'DDR2', 'CEP152', 'ZNF644', 'TNNC1', 'BEST1', '

### Genes extacted after embedding normalisation + additional of parent ids

In [37]:
# Adding parent IDs to the extracted HPO IDs
HP_ids_embed_parents = []
for hp_id in HP_ids_embed:
    HP_ids_embed_parents.append(hp_id)
    if hp_id in child_parent_dict:
        HP_ids_embed_parents.append(child_parent_dict[hp_id])
print(f"HP ids after addition of parent ids: {HP_ids_embed_parents}")

HP ids after addition of parent ids: ['HP:0000365', 'HP:0000364', 'HP:0004467', 'HP:0100277', 'HP:0000006', 'HP:0001407', 'HP:6000174', 'HP:0010447', 'HP:0000407', 'HP:0012210', 'HP:0000077', 'HP:0004629', 'HP:0008479', 'HP:0011518', 'HP:0007641', 'HP:0000377', 'HP:0012582', 'HP:0000110', 'HP:0001425']


In [45]:
# get the gene symbols associated with the added parent HP ids
genes_associated_parents = []
for hp_id in HP_ids_embed_parents:
    if hp_id in hpo_gene_dict:
        genes_associated_parents.extend(hpo_gene_dict[hp_id])
genes_associated_parents = list(set(genes_associated_parents))
print("Genes extracted after adding parent IDs")
print("-------------------------------------------------------------------------------------------------------------")
print(f"The total number of genes associated with the extracted HPO IDs and their parent IDs are : {len(genes_associated_parents)}")
print("-------------------------------------------------------------------------------------------------------------")
print(genes_associated_parents)

Genes extracted after adding parent IDs
-------------------------------------------------------------------------------------------------------------
The total number of genes associated with the extracted HPO IDs and their parent IDs are : 3658
-------------------------------------------------------------------------------------------------------------
['KCNA2', 'MFSD8', 'SPARC', 'TRIM44', 'BLVRA', 'ALDOA', 'DEPDC5', 'SI', 'GJA8', 'XPR1', 'OFD1', 'YME1L1', 'EGR2', 'FBXO11', 'NDUFS4', 'TMEM107', 'POMGNT2', 'PDZD8', 'FRMD4A', 'CREB3L1', 'RNF220', 'LEMD2', 'MYL9', 'SARDH', 'UPF3B', 'STX11', 'IRF3', 'GFI1B', 'MAPT', 'POLR1D', 'AUTS2', 'SDHB', 'MINAR2', 'PER2', 'STX1A', 'GCH1', 'NAA10', 'ATP2C1', 'GATAD1', 'ABCC6', 'DISC2', 'SPTBN1', 'ARHGAP31', 'NUP214', 'SPINT2', 'BLTP1', 'RYR2', 'UGDH', 'MET', 'COL2A1', 'NDUFB10', 'PRSS1', 'NRIP1', 'PANX1', 'CHD2', 'ZNF408', 'SMOC1', 'B9D2', 'C1GALT1C1', 'ILDR1', 'CSF3R', 'MYO3A', 'HNF1B', 'VAMP1', 'COL8A2', 'GPR35', 'HBG1', 'KRT17', 'TMIE', 'COG3', 'PR

### Genes extacted after embedding normalisation + additional of parent ids + ACMG 81 genes addition

In [46]:
# Adding the ACMG 81 genes to the list of genes
final_genes = genes_associated_parents + genes
final_genes = list(set(final_genes))
print("Final list of genes after 3 optimizations: embeddings normalisation+parent ids+ACMG 81 genes")
print("------------------------------------------------------------------------------------------------------------------------")
print(f"The total number of genes associated with the extracted HPO IDs, their parent IDs, and the ACMG 81 genes are : {len(final_genes)}")
print("------------------------------------------------------------------------------------------------------------------------")
print(final_genes)


Final list of genes after 3 optimizations: embeddings normalisation+parent ids+ACMG 81 genes
------------------------------------------------------------------------------------------------------------------------
The total number of genes associated with the extracted HPO IDs, their parent IDs, and the ACMG 81 genes are : 3659
------------------------------------------------------------------------------------------------------------------------
['KCNA2', 'MFSD8', 'SPARC', 'TRIM44', 'BLVRA', 'ALDOA', 'DEPDC5', 'SI', 'GJA8', 'XPR1', 'EGR2', 'OFD1', 'YME1L1', 'TMEM107', 'FBXO11', 'NDUFS4', 'POMGNT2', 'FRMD4A', 'PDZD8', 'CREB3L1', 'RNF220', 'LEMD2', 'MYL9', 'SARDH', 'UPF3B', 'STX11', 'IRF3', 'GFI1B', 'MAPT', 'POLR1D', 'AUTS2', 'SDHB', 'MINAR2', 'PER2', 'STX1A', 'GCH1', 'NAA10', 'ATP2C1', 'GATAD1', 'ABCC6', 'DISC2', 'SPTBN1', 'ARHGAP31', 'NUP214', 'SPINT2', 'BLTP1', 'RYR2', 'UGDH', 'MET', 'COL2A1', 'NDUFB10', 'PRSS1', 'NRIP1', 'PANX1', 'CHD2', 'ZNF408', 'SMOC1', 'B9D2', 'C1GALT1C1', 'ILDR