In [154]:
import pandas as pd
import ast
import numpy as np

# Abstract level CT annotation

In [86]:
from abbreviations import schwartz_hearst

In [87]:
pairs = schwartz_hearst.extract_abbreviation_definition_pairs(doc_text='The emergency room (ER) was busy')
pairs

{'ER': 'emergency room'}

In [88]:
annotated_files_path_prefix = "./predictions/"

## Load BERT model annotations

In [89]:
hugging_face_model_name_biolink = "michiyasunaga/BioLinkBERT-base"
model_name_str_biolink = hugging_face_model_name_biolink.split("/")[1]

In [90]:
annotated_files_path = "data/annotated_aact/ct_neuro_test_annotated_BioLinkBERT-base_20240311.csv"
annotated_files_path_second_batch = "data/annotated_aact/ct_neuro_test_annotated_BioLinkBERT-base_20240320.csv"

In [91]:
biolinkbert_col = f'ner_prediction_{model_name_str_biolink}_normalized'
df_first_batch = pd.read_csv(annotated_files_path)[['nct_id', 'text', biolinkbert_col]]
df_second_batch = pd.read_csv(annotated_files_path_second_batch)[['nct_id', 'text', biolinkbert_col]]
df_first_batch.shape, df_second_batch.shape

((32803, 3), (13573, 3))

In [92]:
df = pd.concat([df_first_batch, df_second_batch])
df.shape

(46376, 3)

In [93]:
df.head(10)

Unnamed: 0,nct_id,text,ner_prediction_BioLinkBERT-base_normalized
0,NCT02970292,"A Phase 3, Randomized, Double-Blind, Placebo-C...","[(37, 44, 'CONTROL', 'placebo'), (112, 124, 'D..."
1,NCT03767426,The Effect of Sleep Deprivation and Recovery S...,"[(36, 50, 'OTHER', 'recovery sleep'), (161, 17..."
2,NCT03941067,Effects of Pre-event Massage Over the Neuromus...,"[(11, 28, 'OTHER', 'pre - event massage'), (12..."
3,NCT03542357,The Effect of Sumatriptan and Placebo on CGRP ...,"[(14, 25, 'DRUG', 'sumatriptan'), (30, 37, 'CO..."
4,NCT02776553,A Physical Activity Program in End-stage Liver...,"[(2, 27, 'PHYSICAL', 'physical activity progra..."
5,NCT03034330,A Two-Tier Care Management Program to Empower ...,"[(2, 45, 'BEHAVIOURAL', 'two - tier care manag..."
6,NCT00611559,Immunogenicity and Reactogenicity Study of a N...,"[(81, 94, 'DRUG', 'dtpa - hbv - ipv /'), (94, ..."
7,NCT02746510,Validation of a Clinical Screening Grid for Sy...,"[(44, 67, 'CONDITION', 'syndromic schizophreni..."
8,NCT03656770,Measuring Beliefs and Norms About Persons With...,"[(47, 61, 'CONDITION', 'mental illness'), (152..."
9,NCT00963898,The Clinical Efficacy of the Combination Targe...,"[(41, 67, 'OTHER', 'target controlled infusion..."


In [94]:
# aggregate annotations and sum up how often they were annotated
def extract_summary(annotation_list):
    annotation_list = eval(annotation_list)
    summary = {}
    for annotation in annotation_list:
        _, _, entity_type, entity_name = annotation
        entity_name = entity_name.lower()
        if entity_type not in summary:
            summary[entity_type] = {}
        if entity_name not in summary[entity_type]:
            summary[entity_type][entity_name] = 0
        summary[entity_type][entity_name] += 1
    return summary

In [95]:
# Define a function to extract the unique conditions, drugs, and others from the 'ner_manual_final_annotated_ds' column
def extract_unique_entities_count(annotation_list, abbreviation_definition_pairs):
    unique_conditions = set()
    unique_drugs = set()
    unique_others = set()
    annotation_list = eval(annotation_list)
    for annotation in annotation_list:
        _, _, entity_type, entity_name = annotation
        if entity_name in abbreviation_definition_pairs:
            #print("Skipping entity {} as it is an ABBR".format(entity_name))
            continue
        entity_name = entity_name.lower()
        if entity_type == 'CONDITION':
            unique_conditions.add(entity_name)
        elif entity_type == 'DRUG':
            unique_drugs.add(entity_name)
        elif entity_type == 'OTHER':
            unique_others.add(entity_name)
    return len(unique_conditions), len(unique_drugs), len(unique_others)

def extract_unique_entities(nct_id, annotation_list, abbreviation_definition_pairs, model="linkbert", keep_drug_interventions_only=True):
    unique_conditions = set()
    unique_interventions = set()
    interventions_type = set()
   
    try:
        annotation_list = eval(annotation_list)
    except SyntaxError as e:
        print(nct_id)
        print("Syntax error in eval:", e)
        return "issues processing line"
    
    for annotation in annotation_list:
        _, _, entity_type, entity_name = annotation
        if entity_name.startswith("##"):
            continue ## THERE IS AN ISSUE WITH BIOBERT and BERT
        if (len(entity_name) == 1 or len(entity_name)==2) and model=="biobert":
            continue ## ASSUME TOKENIZER ERROR IN BIOBERT
        # REPLACE ABBREVIATIONS WITH FULL FORM
        if entity_name in abbreviation_definition_pairs:
            #print("Skipping entity {} as it is an ABBR".format(entity_name))
            entity_name = abbreviation_definition_pairs[entity_name] 
            #continue
        if entity_name.upper() in abbreviation_definition_pairs:
            #print("Skipping entity {} as it is an ABBR".format(entity_name))
            entity_name = abbreviation_definition_pairs[entity_name.upper()] 
        entity_name = entity_name.lower()
        if entity_type == 'CONDITION':
            unique_conditions.add(entity_name)
        elif keep_drug_interventions_only and entity_type == 'DRUG':
            unique_interventions.add(entity_name)
            interventions_type.add(entity_type)
        elif not keep_drug_interventions_only:
            unique_interventions.add(entity_name)
            interventions_type.add(entity_type)
        
    return "|".join(list(unique_conditions)), "|".join(list(unique_interventions)), "|".join(list(interventions_type))

# Placeholder function to demonstrate applying the Schwartz-Hearst algorithm (Replace with actual implementation)
def extract_abbreviation_definition_pairs(doc_text):
    pairs = schwartz_hearst.extract_abbreviation_definition_pairs(doc_text=doc_text)
    return pairs

In [96]:
# Add a new column 'BERT failed' and initialize it with 0
df['BERT failed'] = 0

# Replace the rows containing 'Failed NER extraction!' with empty strings and set 'BERT failed' to 1
mask = df['ner_prediction_BioLinkBERT-base_normalized'] == 'Failed NER extraction!'
df.loc[mask, 'ner_prediction_BioLinkBERT-base_normalized'] = ''
df.loc[mask, 'BERT failed'] = 1

In [97]:
df.shape

(46376, 4)

In [98]:
# Create a new column 'abbreviation_definition_pairs' using the 'apply' function
df['abbreviation_definition_pairs'] = df['text'].apply(extract_abbreviation_definition_pairs)

# Apply the function to each row and create new columns 'num_unique_conditions', 'num_unique_drugs', and 'num_unique_others'
df[f'unique_conditions_{model_name_str_biolink}_predictions'], df[f'unique_interventions_{model_name_str_biolink}_predictions'], df[f'unique_interventions_type_{model_name_str_biolink}_predictions'] = zip(*df.apply(lambda row: extract_unique_entities(row['nct_id'], row[biolinkbert_col], row['abbreviation_definition_pairs']), axis=1))

#df['num_unique_conditions'], df['num_unique_drugs'], df['num_unique_others'] = zip(*df.apply(lambda row: extract_unique_entities_count(row[col_name_target_annot], row['abbreviation_definition_pairs']), axis=1))


NCT03268187
Syntax error in eval: invalid syntax (<string>, line 0)
NCT04176302
Syntax error in eval: invalid syntax (<string>, line 0)
NCT03810898
Syntax error in eval: invalid syntax (<string>, line 0)
NCT02554487
Syntax error in eval: invalid syntax (<string>, line 0)
NCT06213766
Syntax error in eval: invalid syntax (<string>, line 0)
NCT03442166
Syntax error in eval: invalid syntax (<string>, line 0)
NCT04538521
Syntax error in eval: invalid syntax (<string>, line 0)
NCT03582293
Syntax error in eval: invalid syntax (<string>, line 0)
NCT06140355
Syntax error in eval: invalid syntax (<string>, line 0)
NCT02118610
Syntax error in eval: invalid syntax (<string>, line 0)
NCT05834855
Syntax error in eval: invalid syntax (<string>, line 0)
NCT01201967
Syntax error in eval: invalid syntax (<string>, line 0)
NCT01744548
Syntax error in eval: invalid syntax (<string>, line 0)
NCT05491122
Syntax error in eval: invalid syntax (<string>, line 0)


In [99]:
df.head(2)

Unnamed: 0,nct_id,text,ner_prediction_BioLinkBERT-base_normalized,BERT failed,abbreviation_definition_pairs,unique_conditions_BioLinkBERT-base_predictions,unique_interventions_BioLinkBERT-base_predictions,unique_interventions_type_BioLinkBERT-base_predictions
0,NCT02970292,"A Phase 3, Randomized, Double-Blind, Placebo-C...","[(37, 44, 'CONTROL', 'placebo'), (112, 124, 'D...",0,{},schizophrenia,pimavanserin,DRUG
1,NCT03767426,The Effect of Sleep Deprivation and Recovery S...,"[(36, 50, 'OTHER', 'recovery sleep'), (161, 17...",0,"{'ii': 'information,'}",emotional distress|stress,,


In [100]:
df_unique_labels = df[['nct_id',
                       f'unique_conditions_{model_name_str_biolink}_predictions', f'unique_interventions_{model_name_str_biolink}_predictions', f'unique_interventions_type_{model_name_str_biolink}_predictions']]

In [101]:
df_unique_labels.head(2)

Unnamed: 0,nct_id,unique_conditions_BioLinkBERT-base_predictions,unique_interventions_BioLinkBERT-base_predictions,unique_interventions_type_BioLinkBERT-base_predictions
0,NCT02970292,schizophrenia,pimavanserin,DRUG
1,NCT03767426,emotional distress|stress,,


In [102]:
# Function to remove spaces around ' and -
def remove_spaces_around_apostrophe_and_dash(text):
    text = text.replace(" ' ", "'")  # Remove spaces around '
    text = text.replace("' s", "'s")  # Remove spaces around '
    text = text.replace(" - ", "-")  # Remove spaces around -
    text = text.replace(" / ", "/")  # Remove spaces around /
    text = text.replace("( ", "(")  # Remove spaces around (
    text = text.replace(" )", ")")  # Remove spaces around -
    return text

df_unique_labels[f'unique_conditions_{model_name_str_biolink}_predictions'] = df_unique_labels[f'unique_conditions_{model_name_str_biolink}_predictions'].apply(remove_spaces_around_apostrophe_and_dash)

df_unique_labels[f'unique_interventions_{model_name_str_biolink}_predictions'] = df_unique_labels[f'unique_interventions_{model_name_str_biolink}_predictions'].apply(remove_spaces_around_apostrophe_and_dash)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_unique_labels[f'unique_conditions_{model_name_str_biolink}_predictions'] = df_unique_labels[f'unique_conditions_{model_name_str_biolink}_predictions'].apply(remove_spaces_around_apostrophe_and_dash)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_unique_labels[f'unique_interventions_{model_name_str_biolink}_predictions'] = df_unique_labels[f'unique_interventions_{model_name_str_biolink}_predictions'].apply(remove_spaces_around_apostrophe_and_dash)


In [103]:
df_unique_labels.head(2)

Unnamed: 0,nct_id,unique_conditions_BioLinkBERT-base_predictions,unique_interventions_BioLinkBERT-base_predictions,unique_interventions_type_BioLinkBERT-base_predictions
0,NCT02970292,schizophrenia,pimavanserin,DRUG
1,NCT03767426,emotional distress|stress,,


### add AACT

In [104]:
set(df_aact_labels['aact_intervention_types'])

{'Behavioral',
 'Biological',
 'Combination Product',
 'Device',
 'Diagnostic Test',
 'Dietary Supplement',
 'Drug',
 'Genetic',
 'Other',
 'Procedure',
 'Radiation'}

In [105]:
df_aact_labels= pd.read_csv("./data/combined_neuro_trials_with_interventions_20240325.csv")
#df_aact_labels_2 = pd.read_csv("../data/data_aact_sample/aact_neuro_samples_second_batch_202309171159_annotated.csv")
#df_aact_labels = pd.concat([df_aact_labels_1, df_aact_labels_2], ignore_index=True)
df_aact_labels.rename(columns={'Neurological Disease': 'aact_conditions'}, inplace=True)
df_aact_labels.rename(columns={'intervention_name': 'aact_intervention_names'}, inplace=True)
df_aact_labels.rename(columns={'intervention_type': 'aact_intervention_types'}, inplace=True)

# Function to replace values in aact_intervention_names based on aact_intervention_types
def replace_values(row):
    if 'Drug' not in row['aact_intervention_types'] and 'Genetic' not in row['aact_intervention_types'] and 'Biological' not in row['aact_intervention_types'] and 'Dietary Supplement' not in row['aact_intervention_types']:
        return ""
    else:
        return row['aact_intervention_names']

# Apply the custom function to replace values in aact_intervention_names
df_aact_labels['aact_intervention_names'] = df_aact_labels.apply(replace_values, axis=1)
df_aact_labels['aact_intervention_names'] = df_aact_labels['aact_intervention_names'].str.replace('|Placebo|', '')
df_aact_labels['aact_intervention_names'] = df_aact_labels['aact_intervention_names'].str.replace('|Placebo', '')
df_aact_labels['aact_intervention_names'] = df_aact_labels['aact_intervention_names'].str.replace('Placebo|', '')

df_aact_labels.head(10)

Unnamed: 0,nct_id,aact_conditions,Disease Class,brief_title,study_official_title,brief_summary_description,start_date,completion_date,phase,study_type,overall_status,country_name,aact_intervention_names,aact_intervention_types
0,NCT03171649,Stroke,Central Nervous System Diseases,Reaching Training Based on Robotic Hybrid Assi...,Reaching Training Based on Robotic Hybrid Assi...,Stroke is the third most common cause of death...,2016-11-10,2018-05-31,Not Applicable,Interventional,Unknown status,Germany,,Device
1,NCT03171649,Stroke,Central Nervous System Diseases,Reaching Training Based on Robotic Hybrid Assi...,Reaching Training Based on Robotic Hybrid Assi...,Stroke is the third most common cause of death...,2016-11-10,2018-05-31,Not Applicable,Interventional,Unknown status,Italy,,Other
2,NCT03171649,Stroke,Central Nervous System Diseases,Reaching Training Based on Robotic Hybrid Assi...,Reaching Training Based on Robotic Hybrid Assi...,Stroke is the third most common cause of death...,2016-11-10,2018-05-31,Not Applicable,Interventional,Unknown status,Italy,,Other
3,NCT03171649,Stroke,Central Nervous System Diseases,Reaching Training Based on Robotic Hybrid Assi...,Reaching Training Based on Robotic Hybrid Assi...,Stroke is the third most common cause of death...,2016-11-10,2018-05-31,Not Applicable,Interventional,Unknown status,Italy,,Other
4,NCT03171649,Stroke,Central Nervous System Diseases,Reaching Training Based on Robotic Hybrid Assi...,Reaching Training Based on Robotic Hybrid Assi...,Stroke is the third most common cause of death...,2016-11-10,2018-05-31,Not Applicable,Interventional,Unknown status,Germany,,Other
5,NCT03171649,Stroke,Central Nervous System Diseases,Reaching Training Based on Robotic Hybrid Assi...,Reaching Training Based on Robotic Hybrid Assi...,Stroke is the third most common cause of death...,2016-11-10,2018-05-31,Not Applicable,Interventional,Unknown status,Germany,,Other
6,NCT03171649,Stroke,Central Nervous System Diseases,Reaching Training Based on Robotic Hybrid Assi...,Reaching Training Based on Robotic Hybrid Assi...,Stroke is the third most common cause of death...,2016-11-10,2018-05-31,Not Applicable,Interventional,Unknown status,Germany,,Other
7,NCT02397031,Borderline Personality Disorder,Psychiatry and Psychology Category,Mindfulness and Interpersonal Effectiveness Sk...,"Randomized, Active-controlled, Clinical Trial ...",The purpose of the study was to determine whet...,2011-09-30,2014-04-30,Not Applicable,Interventional,Completed,,,Behavioral
8,NCT02397031,Borderline Personality Disorder,Psychiatry and Psychology Category,Mindfulness and Interpersonal Effectiveness Sk...,"Randomized, Active-controlled, Clinical Trial ...",The purpose of the study was to determine whet...,2011-09-30,2014-04-30,Not Applicable,Interventional,Completed,,,Behavioral
9,NCT02397031,Borderline Personality Disorder,Psychiatry and Psychology Category,Mindfulness and Interpersonal Effectiveness Sk...,"Randomized, Active-controlled, Clinical Trial ...",The purpose of the study was to determine whet...,2011-09-30,2014-04-30,Not Applicable,Interventional,Completed,,,Behavioral


In [106]:
df_unique_labels_with_aact = pd.merge(df_unique_labels, df_aact_labels, on='nct_id', how='left')
df_unique_labels_with_aact.head()

Unnamed: 0,nct_id,unique_conditions_BioLinkBERT-base_predictions,unique_interventions_BioLinkBERT-base_predictions,unique_interventions_type_BioLinkBERT-base_predictions,aact_conditions,Disease Class,brief_title,study_official_title,brief_summary_description,start_date,completion_date,phase,study_type,overall_status,country_name,aact_intervention_names,aact_intervention_types
0,NCT02970292,schizophrenia,pimavanserin,DRUG,Schizophrenia,Diseases of the nervous system,Efficacy and Safety of Adjunctive Pimavanserin...,"A Phase 3, Randomized, Double-Blind, Placebo-C...",To evaluate the efficacy and safety of adjunct...,2016-10-26,2019-06-25,Phase 3,Interventional,Completed,Spain,Pimavanserin,Drug
1,NCT02970292,schizophrenia,pimavanserin,DRUG,Schizophrenia,Diseases of the nervous system,Efficacy and Safety of Adjunctive Pimavanserin...,"A Phase 3, Randomized, Double-Blind, Placebo-C...",To evaluate the efficacy and safety of adjunct...,2016-10-26,2019-06-25,Phase 3,Interventional,Completed,Spain,Pimavanserin,Drug
2,NCT02970292,schizophrenia,pimavanserin,DRUG,Schizophrenia,Diseases of the nervous system,Efficacy and Safety of Adjunctive Pimavanserin...,"A Phase 3, Randomized, Double-Blind, Placebo-C...",To evaluate the efficacy and safety of adjunct...,2016-10-26,2019-06-25,Phase 3,Interventional,Completed,Spain,Pimavanserin,Drug
3,NCT02970292,schizophrenia,pimavanserin,DRUG,Schizophrenia,Diseases of the nervous system,Efficacy and Safety of Adjunctive Pimavanserin...,"A Phase 3, Randomized, Double-Blind, Placebo-C...",To evaluate the efficacy and safety of adjunct...,2016-10-26,2019-06-25,Phase 3,Interventional,Completed,Ukraine,Pimavanserin,Drug
4,NCT02970292,schizophrenia,pimavanserin,DRUG,Schizophrenia,Diseases of the nervous system,Efficacy and Safety of Adjunctive Pimavanserin...,"A Phase 3, Randomized, Double-Blind, Placebo-C...",To evaluate the efficacy and safety of adjunct...,2016-10-26,2019-06-25,Phase 3,Interventional,Completed,Ukraine,Pimavanserin,Drug


## Conditions

In [107]:
conditions_db = pd.read_csv("./data/neuro_diseases_terminology/diseases_dictionary_mesh_icd_2024.csv")

In [108]:
conditions_db[conditions_db['MeSH Common name']=='Depression']

Unnamed: 0.1,Unnamed: 0,ICD Node URI,ICD Parent URI,Mesh ID,MeSH Tree Number,ICD Title,MeSH Common name,MeSH Disease Class,ICD Disease Class,MeSH Synonyms
8141,8141,,,,,,Depression,Psychiatry and Psychology Category,,Depressive Symptoms | Depressive Symptom | Sym...


In [109]:
def add_variant(canonical_name, variant, drug_variant_to_canonical):
    #print(drug_variant_to_canonical)
    if variant not in drug_variant_to_canonical:
        drug_variant_to_canonical[variant] = set()
    drug_variant_to_canonical[variant].add(canonical_name)
    return drug_variant_to_canonical

In [110]:
def generate_conditions_lookup_dictionary(df):
    synonyms_dict = {}
    
    for index, row in df.iterrows():
        icd_title = row['ICD Title']
        mesh_name = row['MeSH Common name']
        if pd.notna(row['MeSH Synonyms']):
            synonyms_list = row['MeSH Synonyms'].split('|')
            for synonym in synonyms_list:
                synonym = synonym.strip().lower()
                mesh_name = mesh_name.lower()
                synonyms_dict = add_variant(mesh_name, synonym, synonyms_dict)
        elif pd.notna(row['ICD Title']):
            icd_title = icd_title.lower()
            synonyms_dict = add_variant(icd_title, icd_title, synonyms_dict) 
        elif pd.notna(row['MeSH Common name']):
            mesh_nam = mesh_name.lower()
            synonyms_dict = add_variant(mesh_name, mesh_name, synonyms_dict)
    
    return synonyms_dict

In [111]:
synonyms_dict = generate_conditions_lookup_dictionary(conditions_db)

In [112]:
synonyms_dict.get("depressive symptoms")

{'depression'}

In [113]:
synonyms_dict.get("depression")

In [114]:
import re
df_conditions = df_unique_labels_with_aact[["nct_id", f'unique_conditions_{model_name_str_biolink}_predictions', "aact_conditions", "Disease Class"]] #, "aact_conditions"
df_conditions.head(-5)

Unnamed: 0,nct_id,unique_conditions_BioLinkBERT-base_predictions,aact_conditions,Disease Class
0,NCT02970292,schizophrenia,Schizophrenia,Diseases of the nervous system
1,NCT02970292,schizophrenia,Schizophrenia,Diseases of the nervous system
2,NCT02970292,schizophrenia,Schizophrenia,Diseases of the nervous system
3,NCT02970292,schizophrenia,Schizophrenia,Diseases of the nervous system
4,NCT02970292,schizophrenia,Schizophrenia,Diseases of the nervous system
...,...,...,...,...
503993,NCT01036581,addiction|drug addiction,Nicotine Dependence,Diseases Category
503994,NCT01036581,addiction|drug addiction,Nicotine Dependence,Diseases Category
503995,NCT01036581,addiction|drug addiction,Drug Abuse,Diseases Category
503996,NCT01036581,addiction|drug addiction,Drug Abuse,Diseases Category


### Normalize representations

In [115]:
def lookup_canonical(conditions_list, synonyms_dict):
        canonical_list = []
        for condition in conditions_list.split('|'):
            condition = condition.lower().strip()
            if condition == "none" or condition == "" or condition == "none.":
                continue # generated by gpt if no condition was found
            if condition in synonyms_dict:
                canonical_list.extend(synonyms_dict[condition])
            else:
                canonical_list.append(condition)
        return '|'.join(canonical_list)

In [116]:
def process_dataframe(df, synonyms_dict):
    df.loc[:, f'canonical_{model_name_str_biolink}_conditions'] = df[f'unique_conditions_{model_name_str_biolink}_predictions'].apply(lookup_canonical, synonyms_dict=synonyms_dict)
    df.loc[:, 'canonical_aact_conditions'] = df['aact_conditions'].apply(lookup_canonical, synonyms_dict=synonyms_dict)
    return df

In [117]:
df_conditions_mapped = process_dataframe(df_conditions, synonyms_dict)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.loc[:, f'canonical_{model_name_str_biolink}_conditions'] = df[f'unique_conditions_{model_name_str_biolink}_predictions'].apply(lookup_canonical, synonyms_dict=synonyms_dict)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.loc[:, 'canonical_aact_conditions'] = df['aact_conditions'].apply(lookup_canonical, synonyms_dict=synonyms_dict)


In [118]:
df_conditions_mapped.head(-5)

Unnamed: 0,nct_id,unique_conditions_BioLinkBERT-base_predictions,aact_conditions,Disease Class,canonical_BioLinkBERT-base_conditions,canonical_aact_conditions
0,NCT02970292,schizophrenia,Schizophrenia,Diseases of the nervous system,schizophrenia,schizophrenia
1,NCT02970292,schizophrenia,Schizophrenia,Diseases of the nervous system,schizophrenia,schizophrenia
2,NCT02970292,schizophrenia,Schizophrenia,Diseases of the nervous system,schizophrenia,schizophrenia
3,NCT02970292,schizophrenia,Schizophrenia,Diseases of the nervous system,schizophrenia,schizophrenia
4,NCT02970292,schizophrenia,Schizophrenia,Diseases of the nervous system,schizophrenia,schizophrenia
...,...,...,...,...,...,...
503993,NCT01036581,addiction|drug addiction,Nicotine Dependence,Diseases Category,addiction|substance-related disorders,tobacco use disorder|nicotine dependence
503994,NCT01036581,addiction|drug addiction,Nicotine Dependence,Diseases Category,addiction|substance-related disorders,tobacco use disorder|nicotine dependence
503995,NCT01036581,addiction|drug addiction,Drug Abuse,Diseases Category,addiction|substance-related disorders,substance-related disorders
503996,NCT01036581,addiction|drug addiction,Drug Abuse,Diseases Category,addiction|substance-related disorders,substance-related disorders


## Drugs

In [119]:
import csv
import re

### Normalize representations

In [120]:
path_prefix = "./data"

In [121]:
variant_regex = re.compile(r'^[A-Za-z0-9,]+[ -]?[A-Za-z0-9\-]+(?:[ -][A-Z])?$')
drug_variant_to_canonical = {}
drug_canonical_to_data = {}

def add_variant(canonical_name, variant):
    if variant not in drug_variant_to_canonical:
        drug_variant_to_canonical[variant] = set()
    drug_variant_to_canonical[variant].add(canonical_name)


def add_drug(id, synonyms):
    synonyms = [s.strip() for s in synonyms]

    #TODO: add using an exclusion list as a parameter option to the function
    #if re.sub("[- ].+", "", synonyms[0].upper()) in exclusions:
    #    return
    if not variant_regex.match(synonyms[0]):
        return
    if synonyms[0] not in drug_canonical_to_data:
        drug_canonical_to_data[synonyms[0]] = {"name": synonyms[0], "synonyms": set()}
    if id.startswith("a"):
        drug_canonical_to_data[synonyms[0]]["medline_plus_id"] = id
    elif id.startswith("https://www.nhs.uk"):
        drug_canonical_to_data[synonyms[0]]["nhs_url"] = id
    elif id.startswith("https://en.wikipedia"):
        drug_canonical_to_data[synonyms[0]]["wikipedia_url"] = id
    elif id.startswith("DB"):
        drug_canonical_to_data[synonyms[0]]["drugbank_id"] = id
    else:
        drug_canonical_to_data[synonyms[0]]["mesh_id"] = id
    for variant in synonyms:
        #if re.sub(" .+", "", variant.upper()) in exclusions:
        #    return
        if variant_regex.match(variant):
            drug_canonical_to_data[synonyms[0]]["synonyms"].add(variant)
            add_variant(synonyms[0], variant.lower())
            #add_variant(synonyms[0], variant)
            #add_variant(synonyms[0], variant.upper())
            #if variant.lower() in words_to_allow_lower_case:    

with open(path_prefix + "/drug_names_terminology/drugs_dictionary_medlineplus.csv", 'r', encoding="utf-8") as csvfile:
    spamreader = csv.reader(csvfile, delimiter=',')
    headers = None
    for row in spamreader:
        if not headers:
            headers = row
            continue
        id = row[0]
        name = row[1]
        synonyms = row[2].split(r"|")
        name = re.sub(
            " (Injection|Oral Inhalation|Transdermal|Ophthalmic|Topical|Vaginal Cream|Nasal Spray|Transdermal Patch|Rectal)",
            "", name)
        name = name.lower()
        if name == "abobotulinumtoxina":
            print(row[1], synonyms)

        add_drug(id, [name] + synonyms)

AbobotulinumtoxinA Injection ['Dysport', 'BoNT-A']


In [122]:

with open(path_prefix + "/drug_names_terminology/drugs_dictionary_nhs.csv", 'r', encoding="utf-8") as csvfile:
    spamreader = csv.reader(csvfile, delimiter=',')
    headers = None
    for row in spamreader:
        if not headers:
            headers = row
            continue
        id = row[0]
        name = row[1]
        synonyms = row[2].split(r"|")
        name = name.lower()
        add_drug(id, [name] + synonyms)


with open(path_prefix + "/drug_names_terminology/drugs_dictionary_wikipedia.csv", 'r', encoding="utf-8") as csvfile:
    spamreader = csv.reader(csvfile, delimiter=',')
    headers = None
    for row in spamreader:
        if not headers:
            headers = row
            continue
        id = row[0]
        name = row[1]
        synonyms = row[2].split(r"|")
        name = name.lower()
        add_drug(id, [name] + synonyms)
        
with open(path_prefix + "/drug_names_terminology/drugs_dictionary_mesh.csv", 'r', encoding="utf-8") as csvfile:
    spamreader = csv.reader(csvfile, delimiter=',')
    headers = None
    for row in spamreader:
        if not headers:
            headers = row
            continue
        id = row[0]
        name = row[1]
        synonyms = row[2].split(r"\|")
        name = name.lower()
        add_drug(id, [name] + synonyms)

# adding for the full db with product names included as synonyms
# TODO: create a parametrized function from this, not hard-coded inline as it is
is_new_format = False
if is_new_format:
    with open(path_prefix + "/drug_names_terminology/drugdb_full_database_parsed.csv", 'r', encoding="utf-8") as csvfile:
        spamreader = csv.reader(csvfile, delimiter=';')
        headers = None
        for row in spamreader:
            #print(row)
            if not headers:
                headers = row
                continue
            id = row[0]
            name = row[1]
            synonyms = row[4].split(r"|")
            products = row[5].split(r"|")
            syn_prod = synonyms + products
            name = name.lower()
            add_drug(id, [name] + syn_prod)
# no product names considered
else:
    with open(path_prefix + "/drug_names_terminology/drugbank vocabulary.csv", 'r', encoding="utf-8") as csvfile:
        spamreader = csv.reader(csvfile, delimiter=',')
        headers = None
        for row in spamreader:
            if not headers:
                headers = row
                continue
            id = row[0]
            name = row[2]
            synonyms = row[5].split(r"|")
            name = name.lower()
            add_drug(id, [name] + synonyms)

In [139]:
drug_variant_to_canonical.get("exelon")

{'rivastigmine', 'rivastigmine patch'}

In [124]:
df_interventions = df_unique_labels_with_aact[["nct_id",  f'unique_interventions_{model_name_str_biolink}_predictions', 'aact_intervention_names', 'aact_intervention_types']] #'aact_intervention_names'


In [125]:
df_interventions.head(2)

Unnamed: 0,nct_id,unique_interventions_BioLinkBERT-base_predictions,aact_intervention_names,aact_intervention_types
0,NCT02970292,pimavanserin,Pimavanserin,Drug
1,NCT02970292,pimavanserin,Pimavanserin,Drug


In [126]:
def process_dataframe_interventions(df, synonyms_dict):
    df.loc[:, f'canonical_{model_name_str_biolink}_interventions'] = df[f'unique_interventions_{model_name_str_biolink}_predictions'].apply(lookup_canonical, synonyms_dict=synonyms_dict)
    df.loc[:, 'canonical_aact_interventions'] = df['aact_intervention_names'].apply(lookup_canonical, synonyms_dict=synonyms_dict)
    return df

In [127]:
df_interventions_mapped = process_dataframe_interventions(df_interventions, drug_variant_to_canonical)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.loc[:, f'canonical_{model_name_str_biolink}_interventions'] = df[f'unique_interventions_{model_name_str_biolink}_predictions'].apply(lookup_canonical, synonyms_dict=synonyms_dict)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.loc[:, 'canonical_aact_interventions'] = df['aact_intervention_names'].apply(lookup_canonical, synonyms_dict=synonyms_dict)


In [128]:
df_interventions_mapped

Unnamed: 0,nct_id,unique_interventions_BioLinkBERT-base_predictions,aact_intervention_names,aact_intervention_types,canonical_BioLinkBERT-base_interventions,canonical_aact_interventions
0,NCT02970292,pimavanserin,Pimavanserin,Drug,pimavanserin,pimavanserin
1,NCT02970292,pimavanserin,Pimavanserin,Drug,pimavanserin,pimavanserin
2,NCT02970292,pimavanserin,Pimavanserin,Drug,pimavanserin,pimavanserin
3,NCT02970292,pimavanserin,Pimavanserin,Drug,pimavanserin,pimavanserin
4,NCT02970292,pimavanserin,Pimavanserin,Drug,pimavanserin,pimavanserin
...,...,...,...,...,...,...
503998,NCT01036581,,,Device,,
503999,NCT01036581,,,Device,,
504000,NCT01036581,,,Device,,
504001,NCT01036581,,,Device,,


In [129]:
df_interventions_mapped[df_interventions_mapped['nct_id'] == 'NCT00000173']

Unnamed: 0,nct_id,unique_interventions_BioLinkBERT-base_predictions,aact_intervention_names,aact_intervention_types,canonical_BioLinkBERT-base_interventions,canonical_aact_interventions
197999,NCT00000173,donepezil hcl|aricept|vitamin e|donepezil|alph...,Donepezil,Drug,donepezil hcl|donepezil|vitamin e|donepezil|vi...,donepezil
198000,NCT00000173,donepezil hcl|aricept|vitamin e|donepezil|alph...,Donepezil,Drug,donepezil hcl|donepezil|vitamin e|donepezil|vi...,donepezil
198001,NCT00000173,donepezil hcl|aricept|vitamin e|donepezil|alph...,Donepezil,Drug,donepezil hcl|donepezil|vitamin e|donepezil|vi...,donepezil
198002,NCT00000173,donepezil hcl|aricept|vitamin e|donepezil|alph...,Donepezil,Drug,donepezil hcl|donepezil|vitamin e|donepezil|vi...,donepezil
198003,NCT00000173,donepezil hcl|aricept|vitamin e|donepezil|alph...,Donepezil,Drug,donepezil hcl|donepezil|vitamin e|donepezil|vi...,donepezil
198004,NCT00000173,donepezil hcl|aricept|vitamin e|donepezil|alph...,Donepezil,Drug,donepezil hcl|donepezil|vitamin e|donepezil|vi...,donepezil
198005,NCT00000173,donepezil hcl|aricept|vitamin e|donepezil|alph...,Vitamin E,Drug,donepezil hcl|donepezil|vitamin e|donepezil|vi...,vitamin e
198006,NCT00000173,donepezil hcl|aricept|vitamin e|donepezil|alph...,Vitamin E,Drug,donepezil hcl|donepezil|vitamin e|donepezil|vi...,vitamin e
198007,NCT00000173,donepezil hcl|aricept|vitamin e|donepezil|alph...,Vitamin E,Drug,donepezil hcl|donepezil|vitamin e|donepezil|vi...,vitamin e
198008,NCT00000173,donepezil hcl|aricept|vitamin e|donepezil|alph...,Vitamin E,Drug,donepezil hcl|donepezil|vitamin e|donepezil|vi...,vitamin e


## Drugs and conditions

In [130]:
# Perform an inner join on the 'nct_id' column
merged_df = pd.merge(df_interventions_mapped, df_conditions_mapped, on='nct_id', how='left')
# Remove all duplicates, keeping only rows that are unique across all columns
merged_df = merged_df.drop_duplicates()

merged_df_canonical = merged_df[['nct_id','canonical_BioLinkBERT-base_interventions','canonical_aact_interventions', 'aact_intervention_types', 'canonical_BioLinkBERT-base_conditions', 'canonical_aact_conditions',"Disease Class"]]
# Display the merged DataFrame
merged_df_canonical.head(10)

Unnamed: 0,nct_id,canonical_BioLinkBERT-base_interventions,canonical_aact_interventions,aact_intervention_types,canonical_BioLinkBERT-base_conditions,canonical_aact_conditions,Disease Class
0,NCT02970292,pimavanserin,pimavanserin,Drug,schizophrenia,schizophrenia,Diseases of the nervous system
2178,NCT02970292,pimavanserin,placebo,Drug,schizophrenia,schizophrenia,Diseases of the nervous system
4356,NCT03767426,,,Behavioral,emotional distress|stress,sleep deprivation,Sleep Wake Disorders
4359,NCT03767426,,,Behavioral,emotional distress|stress,sleep,unknown
4420,NCT03941067,,,Other,,neuromuscular diseases,Neuromuscular Diseases
4429,NCT03542357,sumatriptan|acalcitonine gene related peptide|...,calcitonin gene related peptide,Drug,migraine disorders|migraine|headache|migraine ...,migraine disorders|migraine,Diseases of the nervous system|Central Nervous...
4456,NCT03542357,sumatriptan|acalcitonine gene related peptide|...,sumatriptan 50 mg,Drug,migraine disorders|migraine|headache|migraine ...,migraine disorders|migraine,Diseases of the nervous system|Central Nervous...
4483,NCT03542357,sumatriptan|acalcitonine gene related peptide|...,placebo oral tablet,Drug,migraine disorders|migraine|headache|migraine ...,migraine disorders|migraine,Diseases of the nervous system|Central Nervous...
4510,NCT02776553,,,Other,sarcopenia|esld|end-stage liver disease|liver ...,end-stage liver disease (esld),unknown
4511,NCT02776553,,,Other,sarcopenia|esld|end-stage liver disease|liver ...,liver transplant,unknown


In [131]:
merged_df.shape

(133632, 11)

In [132]:
len(set(merged_df['nct_id']))

46376

In [133]:
# Remove rows where both 'canonical_BioLinkBERT-base_interventions' and 'canonical_aact_interventions' are empty
filtered_df = merged_df_canonical.loc[~((merged_df_canonical['canonical_BioLinkBERT-base_interventions'].isna() | merged_df_canonical['canonical_BioLinkBERT-base_interventions'].eq('')) & (merged_df_canonical['canonical_aact_interventions'].isna() | merged_df_canonical['canonical_aact_interventions'].eq('')))]
filtered_df = filtered_df[~filtered_df['canonical_aact_interventions'].str.contains('placebo', na=False)]

filtered_df.head(10)

Unnamed: 0,nct_id,canonical_BioLinkBERT-base_interventions,canonical_aact_interventions,aact_intervention_types,canonical_BioLinkBERT-base_conditions,canonical_aact_conditions,Disease Class
0,NCT02970292,pimavanserin,pimavanserin,Drug,schizophrenia,schizophrenia,Diseases of the nervous system
4429,NCT03542357,sumatriptan|acalcitonine gene related peptide|...,calcitonin gene related peptide,Drug,migraine disorders|migraine|headache|migraine ...,migraine disorders|migraine,Diseases of the nervous system|Central Nervous...
4456,NCT03542357,sumatriptan|acalcitonine gene related peptide|...,sumatriptan 50 mg,Drug,migraine disorders|migraine|headache|migraine ...,migraine disorders|migraine,Diseases of the nervous system|Central Nervous...
5122,NCT00611559,hib vaccine|hib|dtpa-hbv-ipv|dtpa-hbv-ipv /,infanrix™ penta,Biological,,poliomyelitis,Neuromuscular Diseases
5125,NCT00611559,hib vaccine|hib|dtpa-hbv-ipv|dtpa-hbv-ipv /,infanrix™ penta,Biological,,acellular pertussis,unknown
5126,NCT00611559,hib vaccine|hib|dtpa-hbv-ipv|dtpa-hbv-ipv /,infanrix™ penta,Biological,,tetanus,unknown
5127,NCT00611559,hib vaccine|hib|dtpa-hbv-ipv|dtpa-hbv-ipv /,infanrix™ penta,Biological,,diphtheria,unknown
5128,NCT00611559,hib vaccine|hib|dtpa-hbv-ipv|dtpa-hbv-ipv /,infanrix™ penta,Biological,,hepatitis b,unknown
5220,NCT00611559,hib vaccine|hib|dtpa-hbv-ipv|dtpa-hbv-ipv /,infanrix™ hexa,Biological,,poliomyelitis,Neuromuscular Diseases
5223,NCT00611559,hib vaccine|hib|dtpa-hbv-ipv|dtpa-hbv-ipv /,infanrix™ hexa,Biological,,acellular pertussis,unknown


In [134]:
len(set(filtered_df['nct_id'])), filtered_df.shape

(19607, (57475, 7))

In [135]:
# Function to join unique values
def join_unique(values):
    return '|'.join(set(values))

# Group by 'nct_id' and other relevant columns, then join interventions and conditions with '|', ensuring uniqueness
grouped_df = filtered_df.groupby(['nct_id', 'canonical_BioLinkBERT-base_interventions', 'canonical_BioLinkBERT-base_conditions'], as_index=False).agg({
    'canonical_aact_interventions': join_unique,
    'aact_intervention_types': join_unique,
    'canonical_aact_conditions': join_unique,
    'Disease Class': join_unique
})


In [136]:
# making sure there are no duplicate entities
grouped_df['canonical_BioLinkBERT-base_interventions'] = grouped_df['canonical_BioLinkBERT-base_interventions'].apply(lambda x: '|'.join(sorted(set(x.split('|')))))
grouped_df['canonical_BioLinkBERT-base_conditions'] = grouped_df['canonical_BioLinkBERT-base_conditions'].apply(lambda x: '|'.join(sorted(set(x.split('|')))))
grouped_df = grouped_df[['nct_id', 'canonical_BioLinkBERT-base_interventions', 'canonical_aact_interventions', 'aact_intervention_types', 'canonical_BioLinkBERT-base_conditions', 'canonical_aact_conditions', 'Disease Class']]
grouped_df['canonical_aact_interventions'] = grouped_df['canonical_aact_interventions'].str.lstrip('|')

grouped_df.head(10)

Unnamed: 0,nct_id,canonical_BioLinkBERT-base_interventions,canonical_aact_interventions,aact_intervention_types,canonical_BioLinkBERT-base_conditions,canonical_aact_conditions,Disease Class
0,NCT00000117,intravenous immunoglobulin|ivig,immunoglobulin,Drug,multiple sclerosis|optic neuritis,optic neuritis,Cranial Nerve Diseases
1,NCT00000146,corticosteroid,methylprednisolone|prednisone,Drug,multiple sclerosis|optic neuritis,multiple sclerosis|optic neuritis,Cranial Nerve Diseases|Demyelinating Diseases
2,NCT00000147,corticosteroid,methylprednisolone|prednisone,Drug,multiple sclerosis|optic neuritis,multiple sclerosis|optic neuritis,Cranial Nerve Diseases|Demyelinating Diseases
3,NCT00000151,acetylsalicylic acid|aspirin,aspirin|acetylsalicylic acid,Drug|Procedure,blindness|diabetes mellitus|diabetic retinopat...,diabetic retinopathy|blindness,Neurologic Manifestations|unknown
4,NCT00000170,atropine,atropine|,Drug|Device,amblyopia|anisometropia|moderate amblyopia|str...,amblyopia,Neurologic Manifestations
5,NCT00000171,melatonin,melatonin,Drug,alzheimer disease|sleep disturbances,alzheimer disease|dyssomnias,Sleep Wake Disorders|Neurodegenerative Diseases
6,NCT00000172,galantamine,galantamine,Drug,alzheimer disease,alzheimer disease,Neurodegenerative Diseases
7,NCT00000173,donepezil|donepezil hcl|vitamin e,donepezil|vitamin e,Drug,alzheimer disease|dementia|mild cognitive impa...,alzheimer disease,Neurodegenerative Diseases
8,NCT00000174,rivastigmine|rivastigmine patch,rivastigmine,Drug,alzheimer disease|dementia|mild cognitive impa...,cognition disorders|alzheimer disease,Neurodegenerative Diseases|Psychiatry and Psyc...
9,NCT00000175,estrogen|testosterone,estrogen|testosterone,Drug,,cognition disorders|mood disorders,Psychiatry and Psychology Category


In [137]:
len(grouped_df)

19607

In [138]:
grouped_df.to_csv(f'data/annotated_aact/normalized_annotations_unique_{len(grouped_df)}.csv')

In [83]:
len(set(grouped_df['nct_id'])), grouped_df.shape

(18644, (18644, 7))

In [199]:
# Splitting 'canonical_BioLinkBERT-base_interventions' and 'canonical_BioLinkBERT-base_conditions' into separate rows
filtered_df['canonical_BioLinkBERT-base_interventions'] = filtered_df['canonical_BioLinkBERT-base_interventions'].str.split('|')
filtered_df['canonical_BioLinkBERT-base_conditions'] = filtered_df['canonical_BioLinkBERT-base_conditions'].str.split('|')
filtered_df['canonical_aact_conditions'] = filtered_df['canonical_aact_conditions'].str.split('|')

# Exploding both columns to create new rows for each value
df_exploded = filtered_df.explode('canonical_BioLinkBERT-base_interventions')
df_exploded = df_exploded.explode('canonical_BioLinkBERT-base_conditions')
df_exploded = df_exploded.explode('canonical_aact_conditions')
df_exploded = df_exploded.drop_duplicates()
# Resetting the index to have a continuous index after exploding
df_exploded.reset_index(drop=True, inplace=True)

In [200]:
df_exploded.head()

Unnamed: 0,nct_id,canonical_BioLinkBERT-base_interventions,canonical_aact_interventions,aact_intervention_types,canonical_BioLinkBERT-base_conditions,canonical_aact_conditions
0,NCT02970292,pimavanserin,pimavanserin,Drug,schizophrenia,schizophrenia
1,NCT03542357,acalcitonine gene related peptide,calcitonin gene related peptide,Drug,migraine,migraine
2,NCT03542357,acalcitonine gene related peptide,calcitonin gene related peptide,Drug,migraine,migraine disorders
3,NCT03542357,acalcitonine gene related peptide,calcitonin gene related peptide,Drug,migraine disorders,migraine
4,NCT03542357,acalcitonine gene related peptide,calcitonin gene related peptide,Drug,migraine disorders,migraine disorders


## Entity normalization tests

### Load SNOMED terminology
Note: 
- SNOMED CT is the most comprehensive, multilingual clinical healthcare terminology in the world. See https://www.snomed.org/five-step-briefing
- SNOMED CT is a terminology that can cross-map to other international terminologies, classifications and code systems. Maps are associations between particular concepts or terms in one system and concepts or terms in another system that have the same (or similar) meaning. See https://www.snomed.org/maps
- source from: https://www.nlm.nih.gov/healthit/snomedct/international.html
- we use the snapshot release: A snapshot release is a release type in which the release files contain only the most recent version of every component and reference set member released, as at the release date. See https://confluence.ihtsdotools.org/display/DOCRELFMT/3.2+Release+Types

In [144]:
import networkx as nx
from tqdm import tqdm
from src.Snomed import Snomed
# code from https://github.com/cambridgeltl/sapbert/tree/main

In [146]:
release_id = '20240401'
SNOMED_PATH = './data/snomed/SnomedCT_InternationalRF2_PRODUCTION_20240401T120000Z' # you need to download your own SNOMED distribution
snomed = Snomed(SNOMED_PATH, release_id=release_id)
snomed.load_snomed()

In [147]:
snomed_sf_id_pairs = []

for snomed_id in tqdm(snomed.graph.nodes):
    
    node_descs = snomed.index_definition[snomed_id]
    for d in node_descs:
        snomed_sf_id_pairs.append((d, snomed_id))

print(len(snomed_sf_id_pairs))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 366908/366908 [00:00<00:00, 2027081.37it/s]

971409





In [148]:
snomed_sf_id_pairs[:10]


[('Neoplasm of anterior aspect of epiglottis', '126813005'),
 ('Neoplasm of anterior aspect of epiglottis (disorder)', '126813005'),
 ('Neoplasm of junctional region of epiglottis', '126814004'),
 ('Neoplasm of junctional region of epiglottis (disorder)', '126814004'),
 ('Neoplasm of lateral wall of oropharynx', '126815003'),
 ('Neoplasm of lateral wall of oropharynx (disorder)', '126815003'),
 ('Neoplasm of posterior wall of oropharynx', '126816002'),
 ('Neoplasm of posterior wall of oropharynx (disorder)', '126816002'),
 ('Tumour of posterior wall of oropharynx', '126816002'),
 ('Tumor of posterior wall of oropharynx', '126816002')]

In [149]:
snomed_sf_id_pairs_100k = snomed_sf_id_pairs[:100000] # for simplicity

all_names = [p[0] for p in snomed_sf_id_pairs_100k]
all_ids = [p[1] for p in snomed_sf_id_pairs_100k]

In [150]:
all_names[:10]


['Neoplasm of anterior aspect of epiglottis',
 'Neoplasm of anterior aspect of epiglottis (disorder)',
 'Neoplasm of junctional region of epiglottis',
 'Neoplasm of junctional region of epiglottis (disorder)',
 'Neoplasm of lateral wall of oropharynx',
 'Neoplasm of lateral wall of oropharynx (disorder)',
 'Neoplasm of posterior wall of oropharynx',
 'Neoplasm of posterior wall of oropharynx (disorder)',
 'Tumour of posterior wall of oropharynx',
 'Tumor of posterior wall of oropharynx']

In [151]:
all_ids[:10]


['126813005',
 '126813005',
 '126814004',
 '126814004',
 '126815003',
 '126815003',
 '126816002',
 '126816002',
 '126816002',
 '126816002']

### Load SAPBert

In [161]:
from transformers import AutoTokenizer, AutoModel 
import torch
from scipy.spatial.distance import cdist


tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")  
model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")

#### encode snomed labels

In [155]:
bs = 128
all_reps = []
for i in tqdm(np.arange(0, len(all_names), bs)):
    toks = tokenizer.batch_encode_plus(all_names[i:i+bs], 
                                       padding="max_length", 
                                       max_length=25, 
                                       truncation=True,
                                       return_tensors="pt")
    #toks_cuda = {}
    #for k,v in toks.items():
    #    toks_cuda[k] = v.cuda(1)
    #output = model(**toks_cuda)
    
    output = model(**toks)
    cls_rep = output[0][:,0,:]
    
    all_reps.append(cls_rep.cpu().detach().numpy())
all_reps_emb = np.concatenate(all_reps, axis=0)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [11:54<00:00,  1.09it/s]


In [156]:
print (all_reps_emb.shape)


(100000, 768)


#### encode query

In [246]:
query = "early-stage schizophrenia"
query_toks = tokenizer.batch_encode_plus([query], 
                                       padding="max_length", 
                                       max_length=25, 
                                       truncation=True,
                                       return_tensors="pt")

In [247]:
query_output = model(**query_toks)
query_cls_rep = query_output[0][:,0,:]

In [248]:
query_cls_rep.shape


torch.Size([1, 768])

#### find query nearest neighbour

In [249]:
dist = cdist(query_cls_rep.cpu().detach().numpy(), all_reps_emb)
nn_index = np.argmin(dist)
print ("predicted label:", snomed_sf_id_pairs_100k[nn_index])

predicted label: ('Chronic undifferentiated schizophrenia', '29599000')


In [250]:
snomed['29599000'], snomed.predecessors('29599000')

({'desc': 'Chronic undifferentiated schizophrenia'}, ['111484002', '83746006'])

In [236]:
snomed['111484002'], snomed.predecessors('111484002')

({'desc': 'Undifferentiated schizophrenia'}, ['58214004'])

In [237]:
 snomed['83746006'], snomed.predecessors('83746006')

({'desc': 'Chronic schizophrenia'}, ['128293007', '58214004'])

In [223]:
snomed['128293007'], snomed['58214004']

({'desc': 'Chronic mental illness'}, {'desc': 'Schizophrenia'})

In [227]:
snomed.predecessors('58214004')

['69322001']

In [230]:
snomed['69322001'], snomed.predecessors('69322001')

({'desc': 'Psychotic disorder'}, ['74732009'])

In [231]:
snomed['74732009'], snomed.predecessors('74732009')

({'desc': 'Mental disorder'}, ['64572001'])

In [232]:
snomed['64572001'], snomed.predecessors('64572001')

({'desc': 'Disease'}, ['404684003'])

In [233]:
snomed['404684003'], snomed.predecessors('404684003')

({'desc': 'Clinical finding (finding)'}, ['138875005'])

In [234]:
snomed['138875005'], snomed.predecessors('138875005')

({'desc': 'SNOMED CT Concept'}, [])