In [2]:
import os
import sys
import torch

# Project root. Set the SYMPTOM_SIMILARITY_ROOT environment variable to run this
# notebook elsewhere instead of editing the absolute default path.
ROOT_DIR = os.environ.get(
    "SYMPTOM_SIMILARITY_ROOT", "/data/tyler_dev/working_dir/sym/symptom_similarity"
)
DATA_DIR = os.path.join(ROOT_DIR, "data")
# Make the project's `core` package importable. The guard keeps re-runs of this
# cell from appending the same entry to sys.path repeatedly.
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

# load model
from core.networks import Transformer

# Architecture hyperparameters. These presumably must match the training run
# that produced the checkpoint below, since load_state_dict is strict by default.
params = {
    "output_size": 128,
    "hidden_dim": 2048,
    "input_size": 1536,
    "n_layers": 32,
    "nhead": 32,
    "batch_first": False,
}
CHECKPOINT_PATH = os.path.join(
    DATA_DIR, "81fcf4f39e57422db4debfeac61b01f0", "val_top100_0.584.ckpt"
)

best_model = Transformer(**params)
# map_location="cpu" deserializes the tensors on the CPU regardless of the device
# the checkpoint was saved from. The whole model is moved to the GPU just below.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
best_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
# Inference mode: affects train-only layers such as dropout, if the model has any.
best_model.eval()
best_model = best_model.cuda()

In [3]:
from core.augmentation import (
    TruncateOrPad,
)

# sequence padder
# Sequence padder applied to every HPO tensor inside get_score_from_model.
# The first argument (15) is presumably the fixed length that sequences are
# truncated or padded to. Confirm against core.augmentation.TruncateOrPad.
padder = TruncateOrPad(15, stochastic=False, weighted_sampling=True)
# define inference function
def get_score_from_model(patient, disease):
    with torch.no_grad():
        input_src = padder(
            torch.tensor(patient.hpos.vector, dtype=torch.float32, device="cuda:0"), patient
        )
        target_src = padder(
            torch.tensor(disease.hpos.vector, dtype=torch.float32, device="cuda:0"), disease
        )

        input_vector = best_model(input_src)
        target_vector = best_model(target_src)
        scores = (
            torch.nn.functional.cosine_similarity(
                input_vector, target_vector
            )
            .squeeze(-1)
            .detach()
            .cpu()
            .numpy()
        )
        return scores

In [5]:
from core.io_ops import load_pickle

# Load the pickled Disease and Patient collections (diseases first, then patients).
# Both are indexed in the cells below. NOTE(review): unpickling runs arbitrary
# code, so only load trusted files.
disease_data, patient_data = (
    load_pickle(os.path.join(DATA_DIR, file_name))
    for file_name in ("diseases.pickle", "patients.pickle")
)

In [19]:
# Inspect one patient record: the worked example used in the ontology cell below.
example_patient_idx = 1000
patient_data[example_patient_idx]

Patient(id=EPJ22-HCGX, hpos=HP:0012642,HP:0001650,HP:0012494,HP:0100817, disease_id=OMIM:607151)

In [33]:
# Look up the disease recorded as the example patient's disease_id (shown above).
confirmed_disease_id = "OMIM:607151"
disease_data[confirmed_disease_id]

Disease(id='OMIM:607151', name='Moyamoya disease 2, susceptibility to', hpos=HPOs(N HPO=5), clinical_synopsis=ClinicalSynopsis(names={'Transient ischemic attacks (TIA)', 'Internal carotid artery stenosis/occlusion', 'Headache', 'Autosomal recessive', 'Susceptibility conferred by mutation in the ring finger protein 213 gene (RNF213,', 'Seizures', 'Autosomal dominant', 'Thu, 16 Jul 2020 00:00:00 UTC', 'Intraventricular hemorrhage', 'Hemiparesis', 'Posterior cerebral artery stenosis/occlusion', 'Development of moyamoya collaterals', 'Intracranial hemorrhage', 'Middle cerebral artery stenosis/occlusion', 'Moyamoya disease', 'Visual defects', 'Intellectual impairment', 'Anterior cerebral artery stenosis/occlusion', 'Highest prevalence in East Asian countries', 'Occurs worldwide;', 'Ischemic stroke (childhood)', 'Wed, 07 Oct 2020 00:00:00 UTC'}, vector=array([-0.01448052,  0.01769544,  0.02337512, ..., -0.00306422,
       -0.01872689, -0.02514333])))

In [47]:
# Positional lookup of an arbitrary disease, used as the unrelated comparison below.
random_disease_idx = 100
disease_data[random_disease_idx]

Disease(id='OMIM:617667', name='Fraser syndrome 3', hpos=HPOs(N HPO=23), clinical_synopsis=ClinicalSynopsis(names={'Hyperechogenic lungs seen on prenatal ultrasound', 'Broad nose', 'Abnormal frontal hairline', 'Autosomal recessive', 'Abnormal anterior chamber', 'Pulmonary hyperplasia', 'Short toes', 'Caused by mutation in the glutamate receptor-interacting protein 1 gene (GRIP1,', 'Based on a report of 2 stillborn male fetuses (last curated September 2017)', 'Hypoplastic bladder', 'Hypoplastic scrotum', 'Wed, 13 Sep 2017 00:00:00 UTC', 'Cryptophthalmos, bilateral', 'Absent kidneys', 'Hypoplastic penis', 'Tue, 25 Aug 2020 00:00:00 UTC', 'Tracheal atresia', 'Cutaneous syndactyly, partial bilateral', 'Malformed larynx', 'Abnormal lung lobation', 'Malformed penis', 'Micrognathia', 'Short fingers', 'Beaked nose', 'Notched alae nasi', 'Abnormally positioned anus', 'Hydrops fetalis', 'Normal retinal pigment epithelium', 'Low-set simple ears'}, vector=array([-0.02173538,  0.01032189,  0.013292

In [46]:
from core.data_model import Patient, Disease, Ontology

# load ontology
# Load the vectorized HPO definitions and wrap them in an Ontology. Calling the
# Ontology with a list of HPO ids builds the hpos object used by Patient/Disease.
# NOTE(review): unpickling runs arbitrary code, so only load trusted files.
vectorized_hpo = load_pickle(os.path.join(DATA_DIR, "hpo_definition.vector.pickle"))
ontology = Ontology(vectorized_hpo)


# Patient and disease example:
#   id: EPJ22-HCGX
#   symptom hpos: HP:0012642,HP:0001650,HP:0012494,HP:0100817
#   confirmed disease_id: OMIM:607151 (MOYAMOYA DISEASE 2; MYMY2)
#   random disease_id:    OMIM:617667 (FRASER SYNDROME 3; FRASRS3)

patient_hpos = ontology(["HP:0012642", "HP:0001650", "HP:0012494", "HP:0100817"])
# NOTE(review): the stored record for OMIM:607151 reports 5 HPOs, but only 4 are
# listed here. Confirm the intended term set.
confirmed_disease_hpos = ontology(["HP:0000007", "HP:0011834", "HP:0002326", "HP:0000006"])
random_disease_hpos = ontology([
    "HP:0001562", "HP:0010958", "HP:0000046", "HP:0002101", "HP:0001541",
    "HP:0005343", "HP:0001790", "HP:0012725", "HP:0034217", "HP:0100682",
    "HP:0012300", "HP:0034198", "HP:0003826", "HP:0000007", "HP:0000238",
    "HP:0000444", "HP:0001126", "HP:0008736", "HP:0000369", "HP:0000347",
    "HP:0020206", "HP:0001831", "HP:0000445",
])

# get similarities
# NOTE(review): Patient's third field is displayed as `disease_id` in its repr,
# yet `{}` is passed here. Confirm whether "OMIM:607151" was intended.
p_data = Patient("EPJ22-HCGX", patient_hpos, {})
confirmed_d_data = Disease("OMIM:607151", "MOYAMOYA DISEASE 2", confirmed_disease_hpos)
# OMIM:617667 is FRASER SYNDROME 3, matching disease_data[100] and the example note.
random_d_data = Disease("OMIM:617667", "FRASER SYNDROME 3", random_disease_hpos)

print(f"sym bt patient and confirmd: {get_score_from_model(p_data, confirmed_d_data)}")
print(f"sym bt patient and random: {get_score_from_model(p_data, random_d_data)}")

  ontology = pronto.Ontology(SORUCE_URL["hpo_obo"])


sym bt patient and confirmd: 0.7595720291137695
sym bt patient and random: -0.14743348956108093
