- Figure 3 - Phen2Disease 에서 썼었던 데이터셋으로 benchmark performance 확인
- https://github.com/ZhuLab-Fudan/Phen2Disease?tab=readme-ov-file 에서 benchmark set 1
- https://zenodo.org/records/3905420 에 위치해있음

In [None]:
import os
import sys
import torch

# Resolve project paths relative to the notebook's working directory:
# the project root is the parent of the current directory.
PUBDIR = os.getcwd()
ROOT_DIR = os.path.dirname(PUBDIR)
DATA_DIR = os.path.join(ROOT_DIR, "data")
# Make the project packages importable. The guard keeps sys.path free of
# duplicate entries when this cell is executed more than once.
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load ontology
from core.data_model import Patient, Disease, Ontology
from core.data_model import Diseases
from core.io_ops import load_pickle

# Both inputs are read from pickles stored under DATA_DIR.
disease_data = load_pickle(os.path.join(DATA_DIR, "diseases.pickle"))
vectorized_hpo = load_pickle(
    os.path.join(DATA_DIR, "hpo_definition.vector.pickle")
)
ontology = Ontology(vectorized_hpo)

# Keep only the entries whose id starts with "OMIM" as ranking candidates.
omim_only = [entry for entry in disease_data if entry.id.startswith("OMIM")]
omim_diseases = Diseases(omim_only)

### Load benchmark patient dataset

In [5]:
# download public dataset from zeonodo
! wget https://zenodo.org/records/3905420/files/phenopackets.zip?download=1
! mv phenopackets.zip?download=1 ../data/phenopackets.zip
! unzip -o ../data/phenopackets.zip -d ../data/

--2024-03-30 17:22:04--  https://zenodo.org/records/3905420/files/phenopackets.zip?download=1
Resolving zenodo.org (zenodo.org)... 188.185.79.172, 188.184.103.159, 188.184.98.238, ...
Connecting to zenodo.org (zenodo.org)|188.185.79.172|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 572349 (559K) [application/octet-stream]
Saving to: ‘phenopackets.zip?download=1’


2024-03-30 17:22:07 (425 KB/s) - ‘phenopackets.zip?download=1’ saved [572349/572349]

Archive:  ../data/phenopackets.zip
replace ../data/phenopackets/Naz_Villalba-2016-NLRP3-proband.json? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

#### Phen2Disease Benchmark dataset 가져오기
https://github.com/ZhuLab-Fudan/Phen2Disease?tab=readme-ov-file 에서 benchmark set 1
https://zenodo.org/records/3905420 에 위치해있음

In [None]:
import glob
from core.benchmark import load_phenopacket_patients
from core.data_model import HPO, HPOs, Patient, Patients

# Read the unpacked phenopackets from DATA_DIR/phenopackets; the loader
# receives the ontology built earlier in the notebook.
phenopacket_dir = os.path.join(DATA_DIR, "phenopackets")
benchmark_patients: Patients = load_phenopacket_patients(
    phenopacket_dir=phenopacket_dir, ontology=ontology
)
print(benchmark_patients)

#### Benchmark model 1: Phen2Disease

Download the prerequisite file

In [None]:
! pip install gdown
! gdown 1CSYfDj5fG9SsosIDlG-hLAoKp9eMHxjH
! gunzip lin_similarity_matrix.json.gz

In [None]:
from core.benchmark import get_pheno2disease

#### 모든 데이터(benchmark + disease)를 cache에서 가져오기

In [None]:
from core.datasets import StochasticPairwiseDataset, collate_for_stochastic_pairwise_eval
from torch.utils.data import DataLoader

# Dataset built from the benchmark patients and the full disease list,
# with max_len=15.
benchmark_dataset = StochasticPairwiseDataset(
    benchmark_patients, disease_data, max_len=15
)
benchmark_dataset.validate()

# Evaluation loader: one item per batch, fixed order, no worker processes.
loader_options = {
    "batch_size": 1,
    "shuffle": False,
    "num_workers": 0,
    "pin_memory": True,
    "collate_fn": collate_for_stochastic_pairwise_eval,
}
benchmark_dataloader = DataLoader(benchmark_dataset, **loader_options)

# disease cache
import numpy as np
from collections import defaultdict
from tqdm import tqdm
from core_3asc.metric import topk_recall

# Pre-compute one model output per disease so it can be reused later.
# NOTE: best_model is created in the "LaRa" cell further down in this
# notebook; that cell has to be executed before this loop.
cached_vector = {}
whole_disease = benchmark_dataset.disease_tensors
# Send inputs to whichever device the model lives on instead of assuming
# that GPU index 3 exists.
model_device = next(best_model.parameters()).device
with torch.no_grad():
    for disease_id, tensor in tqdm(whole_disease.items()):
        cached_vector[disease_id] = best_model(tensor.to(model_device)).squeeze(0)

In [None]:
# Quick sanity checks on the objects built above.
print(benchmark_dataset)
print(benchmark_dataloader)

# Draw 300 random loader indices (np.random.choice samples with replacement
# by default, so indices may repeat).
p_samples = np.random.choice(range(len(benchmark_dataloader)), 300).tolist()
print(p_samples)
print(disease_data[:])

# Show only the first element of each collection.
for first_item in benchmark_dataset:
    print(first_item)
    break

for first_disease in omim_diseases:
    print(first_disease)
    break

#### LaRa 가져오기


In [None]:
import torch
from core.datasets import (
    StochasticPairwiseDataset,
    collate_for_stochastic_pairwise_eval,
)
from torch.utils.data import DataLoader
from core.networks import Transformer

# Constructor arguments for the Transformer whose weights are stored in the
# checkpoint loaded below.
params = {
    "output_size": 128,
    "hidden_dim": 2048,
    "input_size": 1536,
    "n_layers": 32,
    "nhead": 32,
    "batch_first": False,
}

# The notebook was written for GPU index 3. Keep using it when it exists,
# otherwise fall back to DEVICE so the model can still be loaded on machines
# with fewer GPUs or none at all.
if torch.cuda.is_available() and torch.cuda.device_count() > 3:
    model_device = torch.device("cuda:3")
else:
    model_device = DEVICE

best_model = Transformer(**params)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
best_model.load_state_dict(
    torch.load(
        os.path.join(DATA_DIR, "val_top100_0.584.ckpt"),
        map_location=model_device,
    )
)
best_model.eval()
best_model = best_model.to(model_device)

#### Phen2Disease 모델 가져오기

#### 사내 증상유사도 모델 가져오기 + 계산수식

In [None]:
from omegaconf import OmegaConf
from SemanticSimilarity.calculator import NodeLevelSimilarityCalculator

# Load the configuration of the in-house semantic-similarity calculator.
# NOTE(review): absolute, machine-specific path — confirm it exists in the
# environment where this notebook runs.
conf = OmegaConf.load("/data1/benny_dev/symptom_similarity/SemanticSimilarity/config.yaml")
# Computation using the original algorithm
tb_cal = NodeLevelSimilarityCalculator(conf)
# presumably both setup calls are required before scoring and must stay in
# this order — TODO confirm against the calculator implementation
tb_cal.set_level()
tb_cal.set_mica_mat()


In [None]:

from SemanticSimilarity.data_model import Phenotype

def calculate_score(p, d):
    """Return the in-house semantic similarity between a patient and a disease.

    Args:
        p: Object exposing ``hpos.id2hpo`` and ``hpos.name2hpo`` mappings
            (the patient side).
        d: Object exposing the same two ``hpos`` mappings (the disease side).

    Returns:
        The value returned by ``tb_cal.get_semantic_similarity`` for the two
        phenotype sets.
    """

    def _to_phenotypes(hpos):
        # NOTE(review): ids and names are paired by position, which assumes
        # id2hpo and name2hpo iterate in the same order — TODO confirm.
        return {
            Phenotype(id_, name)
            for id_, name in zip(hpos.id2hpo.keys(), hpos.name2hpo.keys())
        }

    return tb_cal.get_semantic_similarity(
        _to_phenotypes(p.hpos), _to_phenotypes(d.hpos)
    )

#### Phen2Disease 유사도 계산

In [None]:
import json
# Read the precomputed similarity matrix. The encoding is given explicitly
# so the result does not depend on the platform's default text encoding.
with open(
    os.path.join(DATA_DIR, "tyler_backup/lin_similarity_matrix.json"),
    "r",
    encoding="utf-8",
) as f:
    similarity_matrix = json.load(f)

In [None]:
# Scores are computed following the formula below.
# t in P (patient terms), t' in D (disease terms)
# s_p  = sum_P(max_D(sim(t, t')) * IC(t)) / sum_P(IC(t))
# s_pd = (sum_P(max_D(sim(t, t')) * IC(t)) + sum_D(max_P(sim(t, t')) * IC(t')))
#        / (sum_P(IC(t)) + sum_D(IC(t')))

def get_sym(patient, disease, similarity=None):
    """Return the IC-weighted best-match similarity sum and the IC sum.

    For every term of ``patient.hpos`` the highest similarity to any term of
    ``disease.hpos`` is looked up (floored at 0) and weighted by the term's
    ``ic``.

    Args:
        patient: Object whose ``hpos`` iterates over terms with ``id`` and
            ``ic`` attributes; these terms are the "query" side.
        disease: Object whose ``hpos`` iterates over terms with an ``id``
            attribute; these terms are the "target" side.
        similarity: Optional nested mapping ``{query_id: {target_id: score}}``.
            Defaults to the module-level ``similarity_matrix``.

    Returns:
        Tuple ``(sum_sym, sum_ic)``: the weighted similarity sum and the sum
        of ``ic`` over the query terms.
    """
    matrix = similarity_matrix if similarity is None else similarity

    sum_sym = 0.0
    sum_ic = 0.0
    for p_sym in patient.hpos:
        # A pair that is absent from the matrix counts as similarity 0. Only
        # the "missing key" case is tolerated; other errors now surface
        # instead of being swallowed by a bare except.
        row = matrix.get(p_sym.id, {})
        max_sim = 0.0
        for d_sym in disease.hpos:
            score = row.get(d_sym.id, 0.0)
            if score > max_sim:
                max_sim = score

        sum_sym += max_sim * p_sym.ic
        sum_ic += p_sym.ic

    return sum_sym, sum_ic


def get_pheno2disease(patient, disease, similarity=None):
    """Return the score ``s_pd + s_p`` for a patient/disease pair.

    ``s_pd`` is the bidirectional IC-weighted similarity and ``s_p`` the
    patient-to-disease one, as given in the formula comment above.

    Args:
        patient: Object whose ``hpos`` iterates over terms with ``id``/``ic``.
        disease: Object whose ``hpos`` iterates over terms with ``id``/``ic``.
        similarity: Optional nested similarity mapping forwarded to
            ``get_sym``; defaults to the module-level ``similarity_matrix``.

    Returns:
        The score as a float. A term whose denominator (IC sum) is 0, e.g.
        because a side has no terms, contributes 0 instead of raising
        ``ZeroDivisionError``.
    """
    sum_sym_p, sum_ic_p = get_sym(patient, disease, similarity)
    sum_sym_d, sum_ic_d = get_sym(disease, patient, similarity)

    total_ic = sum_ic_p + sum_ic_d
    sym_pd = (sum_sym_p + sum_sym_d) / total_ic if total_ic else 0.0
    sym_p = sum_sym_p / sum_ic_p if sum_ic_p else 0.0

    return sym_pd + sym_p


#### LaRa 계산 수식

In [None]:
from core.augmentation import (
    TruncateOrPad,
)

# Shared TruncateOrPad instance used by the scoring helpers below
# (first argument 15, non-stochastic, weighted sampling enabled).
padder = TruncateOrPad(15, stochastic=False, weighted_sampling=True)

def get_score_from_model(patient, disease):
    """Return the cosine similarity between model outputs for a pair.

    Both ``patient.hpos.vector`` and ``disease.hpos.vector`` are converted to
    float32 tensors, passed through ``padder`` and then ``best_model``; the
    two outputs are compared with cosine similarity.

    Args:
        patient: Object exposing ``hpos.vector``; also passed to ``padder``.
        disease: Object exposing ``hpos.vector``; also passed to ``padder``.

    Returns:
        The similarity as a NumPy array (moved to the CPU).
    """
    # Build inputs on the device the model lives on instead of a hard-coded
    # "cuda:3", so the function works wherever the model was placed.
    device = next(best_model.parameters()).device

    with torch.no_grad():
        input_src = padder(
            torch.tensor(patient.hpos.vector, dtype=torch.float32, device=device),
            patient,
        )
        target_src = padder(
            torch.tensor(disease.hpos.vector, dtype=torch.float32, device=device),
            disease,
        )

        input_vector = best_model(input_src)
        # NOTE(review): the disease side could presumably be read from a
        # precomputed cache instead of being re-encoded on every call.
        target_vector = best_model(target_src)

        scores = (
            torch.nn.functional.cosine_similarity(input_vector, target_vector)
            .squeeze(-1)
            .detach()
            .cpu()
            .numpy()
        )

    # Drop every intermediate tensor (the target side as well) before asking
    # the caching allocator to release memory.
    del input_src, input_vector, target_src, target_vector
    if device.type == "cuda":
        torch.cuda.empty_cache()

    return scores
    

def get_att_weight(disease):
    """Return ``best_model.get_att_weight`` for a disease's padded input.

    Args:
        disease: Object exposing ``hpos.vector``; also passed to ``padder``.

    Returns:
        Whatever ``best_model.get_att_weight`` returns for the padded tensor.
    """
    # Use the model's own device rather than a hard-coded GPU index.
    device = next(best_model.parameters()).device
    with torch.no_grad():
        input_src = padder(
            torch.tensor(disease.hpos.vector, dtype=torch.float32, device=device),
            disease,
        )
        return best_model.get_att_weight(input_src)

In [None]:
# Print the vector of the first disease sample only.
# NOTE: disease_data_sample is assigned in the evaluation cell further down,
# so that cell must have been executed before this one.
for first_sample in disease_data_sample:
    print(first_sample.vector)
    break

# for i in benchmark_patients:
#     print(i)
#     break

#### 비교 평가


In [None]:
from tqdm import tqdm
from core_3asc.metric import topk_recall

# Candidate diseases (OMIM only) and the patients to evaluate.
disease_data_sample = list(omim_diseases)
patient_data_sample = list(benchmark_patients)

# Cut-offs at which top-k recall is reported.
TOP_KS = (1, 5, 10, 15, 20, 30, 40, 50, 75, 100)

result = []
for p in tqdm(patient_data_sample):
    n_diseases = len(disease_data_sample)

    # label[i] is 1 when disease i is one of the patient's diagnoses.
    label = np.zeros(n_diseases)
    scores_base = np.zeros(n_diseases)
    scores_pd = np.zeros(n_diseases)
    scores_model = np.zeros(n_diseases)

    # Score this patient against every candidate disease with each method.
    for i, d in enumerate(disease_data_sample):
        if d.id in p.disease_ids:
            label[i] = 1

        scores_base[i] = calculate_score(p, d)
        scores_pd[i] = get_pheno2disease(p, d)
        scores_model[i] = get_score_from_model(p, d)

    record = {
        "p_id": p.id,
        "scores_base": scores_base,
        "scores_pd": scores_pd,
        "scores_model": scores_model,
    }
    # Add "top_<k>_<method>" entries in the same order as the former
    # hand-written dict: k ascending, then base / pd / model.
    scores_by_method = {
        "base": scores_base,
        "pd": scores_pd,
        "model": scores_model,
    }
    for k in TOP_KS:
        for method, scores in scores_by_method.items():
            record[f"top_{k}_{method}"] = topk_recall(scores, label, k=k)

    result.append(record)

In [None]:
# Placeholder table: one "top<k>" key per cut-off, every value None.
base_data = dict.fromkeys(
    f"top{k}" for k in (1, 5, 10, 15, 20, 30, 40, 50, 75, 100)
)
print(base_data)

In [None]:
# Dump the raw per-patient records collected by the evaluation loop.
print(result)

In [None]:
import pandas as pd
# Per-patient table of top-k recall values, indexed by patient id.
result_df = pd.DataFrame(result).set_index("p_id")
# Drop the raw score arrays by name. Selecting columns through a set
# difference gave a column order that could change from run to run.
result_df = result_df.drop(columns=["scores_base", "scores_pd", "scores_model"])

print(result_df)

# Average of every recall column over all patients.
data = (result_df.sum(0) / len(result_df)).to_dict()
print(data)

print(scores_base)

# One "top<k>" -> mean recall table per method, pre-filled in cut-off order.
top_ks = [1, 5, 10, 15, 20, 30, 40, 50, 75, 100]
base_data = {f"top{k}": None for k in top_ks}
pd_data = {f"top{k}": None for k in top_ks}
model_data = {f"top{k}": None for k in top_ks}

# Column names have the form "top_<k>_<method>". Route each mean by its exact
# method suffix instead of a substring test on the whole key.
summary_by_method = {"base": base_data, "pd": pd_data, "model": model_data}
for key, value in data.items():
    parts = key.split("_")
    target = summary_by_method.get(parts[-1])
    if target is not None:
        target[f"top{parts[1]}"] = value

# Rows are the three methods, columns the "top<k>" cut-offs.
df = pd.DataFrame([base_data, pd_data, model_data], index=['baseline', 'Pheno2Disease', 'LLM-based'])
df.index.name = 'Method'

# Display the DataFrame
print(df)

In [None]:
# Figure 3

import matplotlib.pyplot as plt
import pandas as pd

# Same three summary tables as above, labelled with the names used in the figure.
method_labels = ['Resnik-based IC', 'Pheno2Disease', 'LaRa']
fig_df = pd.DataFrame([base_data, pd_data, model_data], index=method_labels)
fig_df.index.name = 'Method'

# Draw one recall curve per method: x = "top<k>" label, y = mean recall.
plt.figure(figsize=(6, 6))
for method in fig_df.index:
    recalls = fig_df.loc[method]
    plt.plot(list(recalls.index), list(recalls.values), marker='o', label=method)

# Axis labelling and layout.
plt.title('Real world dataset: rare disease patient data')
plt.xlabel('Top-k')
plt.ylabel('Top-k Recall')
plt.xticks(rotation=45)
plt.legend()
plt.grid(False)
plt.tight_layout()
plt.show()