In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
NB_DIR = os.getcwd()
ROOT_DIR = os.path.dirname(NB_DIR)

sys.path.append(ROOT_DIR)

In [3]:
from sklearn.model_selection import train_test_split

from core.io_ops import load_pickle
from core.augmentation import cleanse_data
from core.data_model import Patients
from ontology_src import ArtifactPath

disease_data = load_pickle(ArtifactPath.diseases)
patient_data = load_pickle(ArtifactPath.patients)
ontology = load_pickle(ArtifactPath.hpo_definition)
disease_data, patient_data = cleanse_data(disease_data, patient_data)

train_val_patients_list, test_patients_list = train_test_split(
    patient_data.data, random_state=2023
)
train_patients_list, val_patients_list = train_test_split(
    train_val_patients_list, random_state=2023
)
train_patients = Patients(train_patients_list)
val_patients = Patients(val_patients_list)
test_patients = Patients(test_patients_list)

In [4]:
import torch
from core.datasets import (
    StochasticPairwiseDataset,
    collate_for_stochastic_pairwise_eval,
)
from torch.utils.data import DataLoader
from core.networks import Transformer

params = {
    "output_size": 128,
    "hidden_dim": 2048,
    "input_size": 1536,
    "n_layers": 32,
    "nhead": 32,
    "batch_first": False,
}
best_model = Transformer(**params)
best_model.load_state_dict(
    torch.load(
        "/data/tyler_dev/working_dir/sym/symptom_similarity/data/81fcf4f39e57422db4debfeac61b01f0/val_top100_0.584.ckpt"
    )
)
best_model.eval()
best_model = best_model.cuda()

test_dataset = StochasticPairwiseDataset(
    test_patients,
    disease_data,
    max_len=15,
)
test_dataset.validate()

test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=0,
    collate_fn=collate_for_stochastic_pairwise_eval,
    shuffle=False,
    pin_memory=True,
)



In [5]:
import numpy as np
from collections import defaultdict
from tqdm import tqdm
from core_3asc.metric import topk_recall

cached_vector = {}
whole_disease = test_dataset.disease_tensors
with torch.no_grad():
    for disease_id, tensor in tqdm(whole_disease.items()):
        cached_vector[disease_id] = best_model(tensor.cuda()).squeeze(0)

100%|██████████| 8181/8181 [01:44<00:00, 78.19it/s]


In [54]:
import glob
from core.io_ops import read_json
from core.data_model import HPO, HPOs, Patient, Patients

benchmark_patients_container = list()
for path in glob.glob("/home/heon/repositories/symptom_similarity/data/phenopackets/*"):
    p_data = read_json(path)
    
    patient_id = p_data["id"]
    disease_ids = {disease["term"]["id"] for disease in p_data["diseases"]}
    hpos = ontology[[phenotype["type"]["id"] for phenotype in p_data["phenotypicFeatures"]]]

    benchmark_patients_container.append(
        Patient(
            id=patient_id,
            hpos=hpos,
            disease_ids=disease_ids
        )
    )

benchmark_patients = Patients(benchmark_patients_container)

In [32]:
from core.data_model import Diseases
omim_diseases = Diseases([disease for disease in disease_data if disease.id.startswith("OMIM")])

In [33]:
from benchmark.pheno2disease import Pheno2Disease
pheno2disease = Pheno2Disease()

In [82]:
import numpy as np
def cosine_sim(v1, v2)->float:
    return v1 @ v2 / np.linalg.norm(v1) * np.linalg.norm(v1)

patient_scores = dict()
for patient in tqdm(benchmark_patients):
    label = np.zeros(len(omim_diseases))
    scores_model = np.zeros(len(omim_diseases))
    scores_pheno2disease = np.zeros(len(omim_diseases))
    p_vector = best_model(
        torch.from_numpy(patient.hpos.vector).cuda().float()
    ).squeeze(0).detach().cpu().numpy()
    
    patient_score = dict()
    for i, omim_disease in enumerate(omim_diseases):
        scores_model[i] = cosine_sim(p_vector, cached_vector[omim_disease.id].detach().cpu().numpy())
        scores_pheno2disease[i] = pheno2disease.get_pheno2disease(patient, omim_disease)

        if omim_disease.id in patient.disease_ids:
            label[i] = 1
    
    patient_score["model"] = scores_model
    patient_score["p2d"] = scores_pheno2disease
    patient_score["label"] = label
    patient_scores[patient.id] = patient_score

  0%|          | 0/384 [00:00<?, ?it/s]

 12%|█▏        | 47/384 [02:36<18:41,  3.33s/it]


KeyboardInterrupt: 

In [88]:
from core_3asc.metric import AverageMeter, topk_recall
model_top1 = AverageMeter()
model_top5 = AverageMeter()
model_top10 = AverageMeter()
model_top50 = AverageMeter()
p2d_top1 = AverageMeter()
p2d_top5 = AverageMeter()
p2d_top10 = AverageMeter()
p2d_top50 = AverageMeter()

for id, patient_score in patient_scores.items():
    model_top1.update(topk_recall(patient_score["model"], patient_score["label"], k=1))
    model_top5.update(topk_recall(patient_score["model"], patient_score["label"], k=5))
    model_top10.update(topk_recall(patient_score["model"], patient_score["label"], k=10))
    model_top50.update(topk_recall(patient_score["model"], patient_score["label"], k=50))
    
    p2d_top1.update(topk_recall(patient_score["p2d"], patient_score["label"], k=1))
    p2d_top5.update(topk_recall(patient_score["p2d"], patient_score["label"], k=5))
    p2d_top10.update(topk_recall(patient_score["p2d"], patient_score["label"], k=10))
    p2d_top50.update(topk_recall(patient_score["p2d"], patient_score["label"], k=50))

In [89]:
model_top10.avg, p2d_top10.avg

(0.2553191489361702, 0.3191489361702128)

In [91]:
model_top50.avg, p2d_top50.avg

(0.5319148936170213, 0.6382978723404256)