In [5]:
import pandas as pd
from tqdm.notebook import tqdm

In [6]:
def ct_dict2pd(study: dict, missing_val=None) -> pd.Series:
    """ETL process to convert a CT study in a JSON format to the format
    exemplified in the trial2vec demo data.
    See: https://pypi.org/project/Trial2Vec/

    Parameters
    ----------
    study : dict
        as provided through the clinicaltrials.gov API

    missing_val: (default:None)
        How to encode missing values

    Returns
    -------
    pd.Series
        fields :
        - nct_id
        - description
        - study_type
        - title
        - intervention name
        - disease
        - keyword
        - outcome measure
        - (selection) criteria
        - references
        - overall status
    """
    missing_val = None

    ct_protocol = study.get("protocolSection", {})

    nct_id = ct_protocol.get("identificationModule", {}).get("nctId", missing_val)

    description = ct_protocol.get("descriptionModule", {}).get(
        "briefSummary", missing_val
    )

    study_type = ct_protocol.get("designModule", {}).get("studyType", missing_val)

    title = ct_protocol["identificationModule"].get(
        "officialTitle",
        ct_protocol["identificationModule"].get("briefTitle", missing_val),
    )

    # Intervention name
    if study_type == "OBSERVATIONAL":
        intervention_name = study_type
    else:
        interventions = ct_protocol.get("armsInterventionsModule", {}).get(
            "interventions", []
        )
        intervention_name = ", ".join(
            set(i.get("name", "").split(":")[-1] for i in interventions)
        )

    disease = ", ".join(
        sorted(ct_protocol.get("conditionsModule", {}).get("conditions", []))
    )

    keyword = (
        ", ".join(sorted(ct_protocol.get("conditionsModule", {}).get("keywords", [])))
        if study_type != "OBSERVATIONAL"
        else missing_val
    )

    # Outcome measurement
    if study_type == "OBSERVATIONAL":
        try:
            design_info = ct_protocol["designModule"]["designInfo"]
            outcome_measure = design_info.get("observationalModel", study_type)
            outcome_measure += (
                "-" + design_info.get("timePerspective", "")
                if "timePerspective" in design_info
                else ""
            )
        except KeyError:
            outcome_measure = study_type
    else:
        primary_outcomes = ct_protocol.get("outcomesModule", {}).get(
            "primaryOutcomes", []
        )
        outcome_measure = ", ".join(set(i.get("measure", "") for i in primary_outcomes))

    # Selection criteria
    try:
        criteria = ct_protocol.get("eligibilityModule", {}).get(
            "eligibilityCriteria", ""
        )
        criteria = criteria.replace("\n* ", "~").replace("\n", "~").replace("~~", "~")
    except:
        try:
            eligibility = ct_protocol.get("eligibilityModule", {})
            criteria = ", ".join(
                [": ".join([k, str(v)]) for k, v in eligibility.items()]
            )
        except:
            criteria = missing_val

    # References
    try:
        references = ct_protocol.get("referencesModule", {}).get("references", [])
        reference = ", ".join(
            r.get("citation", "").split(".")[1].lstrip(" ")
            for r in references
            if "citation" in r
        )
    except KeyError:
        reference = missing_val

    overall_status = (
        ct_protocol.get("statusModule", {}).get("overallStatus", "").lower()
    )

    return (
        pd.Series(
            {
                "nct_id": nct_id,
                "description": description,
                "title": title,
                "intervention_name": intervention_name,
                "disease": disease,
                "keyword": keyword,
                "outcome_measure": outcome_measure,
                "criteria": criteria,
                "reference": reference,
                "overall_status": overall_status,
            }
        )
        .to_frame()
        .transpose()
    )

In [7]:
from src.utils.utils import connect_to_mongoDB

import os
from dotenv import load_dotenv

load_dotenv(".env")

MONGODB_USER = os.getenv("MONGODB_USER")
MONGODB_PWD = os.getenv("MONGODB_PWD")

client = connect_to_mongoDB(MONGODB_USER, MONGODB_PWD)
db = client["ctGov"]
collection = db["heart_failure"]
studies = collection.find({})

Pinged your deployment. You successfully connected to MongoDB!


In [8]:
study_pd = pd.DataFrame()

i = 0
for study in studies:
    tmp = ct_dict2pd(study)

    study_pd = pd.concat([study_pd, tmp])
    i += 1
    if i > 10:
        break

In [9]:
study_pd.head()

Unnamed: 0,nct_id,description,title,intervention_name,disease,keyword,outcome_measure,criteria,reference,overall_status
0,NCT00000475,To assess the impact of a multidisciplinary tr...,Prevention of Early Readmission in Elderly Con...,health education,"Cardiovascular Diseases, Heart Diseases, Heart...",,,"Men and women, ages 70 or older, with document...",A multidisciplinary intervention to prevent th...,completed
0,NCT00000476,"To determine if digitalis had a beneficial, ha...",Digitalis Investigation Group (DIG),digitalis,"Arrhythmia, Cardiovascular Diseases, Heart Dis...",,,"Men and women with clinical heart failure, sin...",Protecting patient's rights: the DIG study exp...,completed
0,NCT00000560,To determine if addition of a beta-blocker to ...,Beta-Blocker Evaluation in Survival Trial (BEST),adrenergic beta antagonists,"Cardiovascular Diseases, Heart Diseases, Heart...",,,"Men and women, ages 18 and over. Patients had ...","The BEST Steering Committee, A trial of the be...",completed
0,NCT00000607,"To conduct a randomized, unblinded clinical tr...","The REMATCH Trial: Rationale, Design, and End ...","Optimal medical therapy, Left ventricular assi...","Cardiovascular Diseases, Heart Diseases, Heart...",,Survival rate in LVAD group,Inclusion Criteria~1. Men and women with Class...,"The REMATCH trial: rationale, design, and end ...",completed
0,NCT00000609,To compare conventional treatment of congestiv...,Sudden Cardiac Death in Heart Failure Trial (S...,"amiodarone, defibrillators, implantable","Arrhythmia, Cardiovascular Diseases, Death, Su...",,,Patients with New York Heart Association class...,Amiodarone or an implantable cardioverter-defi...,completed


In [10]:
from trial2vec import Trial2Vec

model = Trial2Vec(device="cpu")
model.from_pretrained()

Load pretrained Trial2Vec model from ./trial_search/pretrained_trial2vec
load predictor config file from ./trial_search/pretrained_trial2vec\model_config.json


In [21]:
from sentence_transformers import SentenceTransformer
clinicalBERT = SentenceTransformer("emilyalsentzer/Bio_ClinicalBERT")

No sentence-transformers model found with name emilyalsentzer/Bio_ClinicalBERT. Creating a new one with MEAN pooling.


In [15]:
# test_data = {'x': df} # contains trial documents

emb = model.encode({"x": study_pd})  # make inference
# emb.to_csv("./data/ct.trial2vec_embedding.csv")

# # # or just find the pre-encoded trial documents
# emb2 = [model[nct_id] for test_data['x']['nct_id']]

  _ = df.applymap(str)
Encoding: 100%|██████████| 1/1 [00:05<00:00,  5.34s/it]


In [32]:
emb

{'NCT00000475': array([ 0.1387456 ,  0.03155407, -0.00298959, -0.12880342,  0.06189645,
        -0.02382167, -0.10314611, -0.12430413,  0.15415034, -0.07071654,
         0.13423762,  0.01993478, -0.06761336,  0.00990986,  0.05189272,
        -0.17783299, -0.0703306 ,  0.03008412,  0.07903063, -0.07783525,
         0.04445762,  0.01213957,  0.0002565 , -0.02367849,  0.16438083,
        -0.04542102,  0.04863609,  0.04252839, -0.00506172, -0.06073992,
        -0.04101189, -0.09020843, -0.12913786,  0.06918579, -0.01727096,
        -0.07175722,  0.0029251 , -0.04499697, -0.03253841, -0.0316362 ,
         0.02847278, -0.19580099,  0.03814256,  0.03822841,  0.11539964,
         0.02902259,  0.16552168,  0.0430192 ,  0.05023267,  0.08685552,
        -0.14534307,  0.23547299,  0.04868193,  0.14032781,  0.02263975,
        -0.11445466, -0.00221165,  0.11863811,  0.15749435,  0.0155226 ,
         0.12800963,  0.02377683,  0.04912435, -0.04896666, -0.10938001,
        -0.06061439,  0.02112046,  0

In [33]:
emb["NCT00000475"]

array([ 0.1387456 ,  0.03155407, -0.00298959, -0.12880342,  0.06189645,
       -0.02382167, -0.10314611, -0.12430413,  0.15415034, -0.07071654,
        0.13423762,  0.01993478, -0.06761336,  0.00990986,  0.05189272,
       -0.17783299, -0.0703306 ,  0.03008412,  0.07903063, -0.07783525,
        0.04445762,  0.01213957,  0.0002565 , -0.02367849,  0.16438083,
       -0.04542102,  0.04863609,  0.04252839, -0.00506172, -0.06073992,
       -0.04101189, -0.09020843, -0.12913786,  0.06918579, -0.01727096,
       -0.07175722,  0.0029251 , -0.04499697, -0.03253841, -0.0316362 ,
        0.02847278, -0.19580099,  0.03814256,  0.03822841,  0.11539964,
        0.02902259,  0.16552168,  0.0430192 ,  0.05023267,  0.08685552,
       -0.14534307,  0.23547299,  0.04868193,  0.14032781,  0.02263975,
       -0.11445466, -0.00221165,  0.11863811,  0.15749435,  0.0155226 ,
        0.12800963,  0.02377683,  0.04912435, -0.04896666, -0.10938001,
       -0.06061439,  0.02112046,  0.06741887,  0.04994828, -0.09

In [22]:
inputs = ['I am a sentence', 'I am another sentence abcdefg xyz']

In [28]:
len(model.sentence_vector(inputs)[0])

128

In [27]:
len(clinicalBERT.encode(inputs)[0])

768

In [29]:
clinicalBERT.encode(inputs)[0]

array([ 9.14228782e-02,  2.07971454e-01, -2.09350541e-01,  2.93758541e-01,
        2.56028503e-01, -4.56299752e-01,  4.40826088e-01, -5.57930805e-02,
        1.48435131e-01, -2.25606099e-01, -2.16735959e-01,  8.60819146e-02,
       -4.45429444e-01,  2.43752941e-01, -6.33598924e-01, -1.32210806e-01,
       -8.17837566e-02, -2.39041135e-01, -5.20994961e-01, -1.11547470e-01,
        5.90904765e-02,  3.00880432e-01, -1.83278918e-02, -3.54681462e-02,
       -1.32419556e-01, -1.23944938e-01,  8.07736099e-01,  2.95291573e-01,
        1.07702434e-01,  4.47419614e-01, -3.70611757e-04, -1.70179442e-01,
        9.98176113e-02, -2.11752459e-01, -4.63927239e-01,  1.33983390e-02,
       -4.06245619e-01,  3.67949247e-01, -6.01417184e-01, -2.21887946e-01,
        2.33610436e-01, -1.53267071e-01,  5.18833399e-01,  3.71078588e-02,
        2.59550124e-01, -2.50122279e-01, -2.51298308e-01, -3.61252546e-01,
       -4.25588675e-02,  2.58032888e-01,  1.49974719e-01, -6.38096333e-02,
        2.32220992e-01,  

In [20]:
pd.DataFrame(emb)

Unnamed: 0,NCT00000475,NCT00000476,NCT00000560,NCT00000607,NCT00000609,NCT00000619,NCT00001313,NCT00001402,NCT00001628,NCT00001629,NCT00004562
0,0.138746,0.040624,0.104108,0.165003,0.112361,0.135371,0.081229,0.053268,0.120190,0.113337,0.025053
1,0.031554,0.063589,0.011923,-0.012164,0.000821,0.006450,0.178092,-0.027048,0.060667,0.044872,-0.066249
2,-0.002990,-0.010372,-0.126291,-0.047869,-0.099737,-0.081933,-0.050669,-0.000881,-0.075458,-0.065103,-0.130493
3,-0.128803,-0.118481,-0.137736,-0.069651,-0.135369,-0.113227,-0.103267,-0.066385,0.047523,0.014866,-0.096861
4,0.061896,-0.123055,-0.057879,-0.069621,-0.043407,-0.008484,-0.083458,-0.102128,0.118483,0.081735,0.077726
...,...,...,...,...,...,...,...,...,...,...,...
123,-0.023395,-0.068702,0.003924,-0.006702,-0.074611,-0.035129,-0.079416,0.035715,-0.021810,-0.018815,0.024305
124,-0.078355,-0.145258,-0.186262,-0.084890,-0.086567,-0.070208,-0.130423,-0.107883,-0.218984,-0.208242,-0.077465
125,0.094270,0.123368,0.080032,-0.003876,0.025504,-0.025423,0.065408,-0.075339,-0.009425,0.019551,0.073762
126,0.021166,-0.013623,0.045496,-0.017552,0.064643,-0.002614,-0.014949,-0.059166,0.045631,0.081300,0.026403


In [18]:
pd.DataFrame(emb, columns=["nctId", "trial2vec"])

Unnamed: 0,nctId,trial2vec


In [None]:
for disease in ["heart_failure", "asthma"]:
    collection = db[disease]
    # Set trial2vec to default null
    collection.update_many({}, {"$set": {"trial2vec": []}})
    # Load embedding from file
    emb = pd.read_csv(f"./data/ct.trial2vec_embedding.{disease}.csv", index_col=0)

    for study in emb.columns:
        collection.update_one(
            {"_id": study}, {"$set": {"trial2vec": list(emb[study].values)}}
        )

In [None]:
from biobert_embedding.embedding import BiobertEmbedding

text = "Breast cancers with HER2 amplification have a higher risk of CNS metastasis and poorer prognosis."
# Class Initialization (You can set default 'model_path=None' as your finetuned BERT model path while Initialization)
biobert = BiobertEmbedding()

word_embeddings = biobert.word_vector(text)
sentence_embedding = biobert.sentence_vector(text)

print("Text Tokens: ", biobert.tokens)
# Text Tokens:  ['breast', 'cancers', 'with', 'her2', 'amplification', 'have', 'a', 'higher', 'risk', 'of', 'cns', 'metastasis', 'and', 'poorer', 'prognosis', '.']

print(
    "Shape of Word Embeddings: %d x %d"
    % (len(word_embeddings), len(word_embeddings[0]))
)
# Shape of Word Embeddings: 16 x 768

print("Shape of Sentence Embedding = ", len(sentence_embedding))
# Shape of Sentence Embedding =  768