In [None]:
import pickle
from pyhealth.tokenizer import Tokenizer
# NOTE: pickle.load can execute arbitrary code; only load these files from a trusted source.
with open('retain_patient_time_mimic3_0.05/mimic3_box_dataset_0.05.pkl', 'rb') as f:
    task_dataset = pickle.load(f)

with open('ccs9.pkl', 'rb') as f:
    ccs9 = pickle.load(f)
# Tokenizer over the CCS label vocabulary (used to encode label lists into indices).
label_tokenizer = Tokenizer(tokens=ccs9)
# Sorted CCS codes. Later cells map model output column j -> ccs_idx[j]; this assumes the
# RETAIN label columns are in sorted-code order — TODO confirm against label_tokenizer.
ccs_idx = sorted(ccs9)

In [None]:
import torch
def batch_to_multihot(label, num_labels: int) -> torch.Tensor:
    """Convert per-sample label index lists into a dense multi-hot matrix.

    Args:
        label: sequence with one entry per sample; each entry is the list (or single
            index) of label positions that should be set for that sample.
        num_labels: width of the output (size of the label vocabulary).

    Returns:
        Float tensor of shape ``(len(label), num_labels)`` with 1 at every listed
        position and 0 elsewhere.
    """
    multihot = torch.zeros((len(label), num_labels))
    for row, indices in enumerate(label):
        # Advanced indexing sets all listed columns of this row at once.
        multihot[row, indices] = 1
    return multihot

# Sanity check: encode two single-code label lists with the CCS tokenizer and turn
# them into a multi-hot tensor with one row per sample.
labels_index = label_tokenizer.batch_encode_2d(
    [['99'], ['12']], padding=False, truncation=False
)
num_labels = label_tokenizer.get_vocabulary_size()
labels = batch_to_multihot(labels_index, num_labels)



In [None]:
import pickle
import numpy as np

import pickle


# NOTE: pickle.load can execute arbitrary code; only load trusted artifacts.
with open('retain_patient_time_mimic3_0.05/patient_time_m3_0.05_y_results.pkl', 'rb') as f:
    retain_result = pickle.load(f)

y_true = retain_result['y_true']  # presumably (N, num_labels) multi-hot ground truth
y_prob = retain_result['y_prob']  # per-label scores (argsorted below), presumably (N, num_labels)
# Rows of y_true and y_prob describe the same samples; a mismatch means corrupted results.
assert len(y_true) == len(y_prob), f"y_true has {len(y_true)} rows, y_prob has {len(y_prob)}"

# For each sample: every CCS code, ordered from highest to lowest RETAIN score.
retain_y_prob_code = [[ccs_idx[j] for j in np.argsort(-row)] for row in y_prob]

In [None]:
from torch.utils.data import Dataset
class MMDataset(Dataset):
    """Map-style Dataset over the ``samples`` of a task dataset.

    Each item is returned as ``(sample, idx)`` so that the position of the sample in
    the underlying list travels with it through a DataLoader.
    """

    def __init__(self, dataset):
        # Materialize the samples once into a plain list.
        self.sequence_dataset = list(dataset.samples)

    def __len__(self):
        return len(self.sequence_dataset)

    def __getitem__(self, idx):
        return self.sequence_dataset[idx], idx
def custom_collate_fn(batch):
    """Collate ``(sample_dict, idx)`` pairs into ``(dict_of_lists, idx_list)``.

    For every key of the first sample, the values of all samples are gathered into a
    list, skipping samples whose value for that key is an empty list. The skip is
    applied per key, so the per-key lists may have different lengths and are not
    necessarily aligned sample-by-sample.
    """
    samples = []
    positions = []
    for entry in batch:
        samples.append(entry[0])
        positions.append(entry[1])

    collated = {}
    for key in samples[0]:
        collated[key] = [sample[key] for sample in samples if sample[key] != []]
    return collated, positions



In [None]:
import torch
from torch.utils.data import Subset
mdataset = MMDataset(task_dataset)
# Indices (into task_dataset.samples) of the held-out test split.
# NOTE: torch.load unpickles the file; only load trusted artifacts.
indices = torch.load('retain_patient_time_mimic3_0.05/testset.pt')
testset = Subset(mdataset, indices)

In [None]:

# NOTE: pickle.load can execute arbitrary code; only load trusted artifacts.
# Looked up later by str(CCS code) -> (presumably ranked) list of associated CCS codes.
with open('retain_patient_time_mimic3_0.05/ccs_ccs.pkl', 'rb') as f:
    ccs_sorted_pred_ccs_list = pickle.load(f)
# Indexed later by test-sample position; each entry is that sample's CCS ordering.
with open('retain_patient_time_mimic3_0.05/visit_ccs.pkl', 'rb') as f:
    visit_sorted_ccs_list = pickle.load(f)


In [None]:
from torch.utils.data import DataLoader
import numpy as np
# Pair each test sample with its RETAIN outputs by position. This relies on testset being
# in the same order as the rows of retain_result (hence shuffle=False).
dataloader = DataLoader(testset, batch_size=1, shuffle=False)
y_true = retain_result['y_true']
y_prob = retain_result['y_prob']
# zip() below would silently stop at the shortest input and misalign/drop samples.
assert len(testset) == len(y_true) == len(y_prob), (
    f"length mismatch: testset={len(testset)}, y_true={len(y_true)}, y_prob={len(y_prob)}"
)

diag_test_data = []
for num, (data, row1, row2) in enumerate(zip(dataloader, y_true, y_prob)):
    # data[0] is the collated sample dict for this batch of size 1 (presumably with leaf
    # values wrapped in length-1 lists by the default collate — hence t[0] later) — TODO confirm.
    sample = data[0].copy()
    # Ground-truth CCS codes: positions where the multi-hot row equals 1.
    sample['ccs_answer'] = [ccs_idx[j] for j in np.where(row1 == 1)[0].tolist()]
    # All CCS codes ordered by descending RETAIN score.
    sample['retain_top'] = [ccs_idx[j] for j in np.argsort(-row2)]
    sample['sorted_ccs'] = visit_sorted_ccs_list[num]
    diag_test_data.append(sample)


In [None]:

from datetime import datetime

# Check that each sample's admission times are chronological: the prompt builder treats
# adm_time[-1] as the current visit and describes earlier visits as "... ago".
print(len(diag_test_data))
unsorted_count = 0
for item in diag_test_data:
    times = [datetime.strptime(t[0], '%Y-%m-%d %H:%M') for t in item['adm_time']]
    if times != sorted(times):
        unsorted_count += 1
print(f"{unsorted_count} samples have admission times out of chronological order")


In [None]:
import pandas as pd


# Lookup table: CCS category code (the "code" column of CCS9.csv) -> display name.
df = pd.read_csv("CCS9.csv", encoding="utf-8")
ccs_to_name = {code: name for code, name in zip(df["code"], df["name"])}

# Sanity check on one category.
print(ccs_to_name[101])  

In [None]:
import pandas as pd

# Load the ICD-9-CM ontology. The code columns are read as strings: with type inference,
# pandas may parse code-like values as numbers (losing leading/trailing zeros), which would
# corrupt them after the dot-stripping below.
df = pd.read_csv("ICD9CM.csv", encoding="utf-8", dtype={"code": str, "parent_code": str})

# Drop rows whose code contains "-" (presumably range/chapter rows); keep concrete codes.
df = df[~df["code"].astype(str).str.contains("-", regex=False)]

# Normalize to dot-less codes so they match the dataset's ICD codes. Using assign() builds
# a new frame instead of writing into a filtered slice (avoids SettingWithCopyWarning).
df = df.assign(
    code=df["code"].astype(str).str.replace(".", "", regex=False),
    parent_code=df["parent_code"].astype(str).str.replace(".", "", regex=False),
)

# code -> {"name": ..., "parent_code": ...}
icd_to_info = dict(zip(
    df["code"],
    df[["name", "parent_code"]].to_dict(orient="records")
))



In [None]:
import pandas as pd

# ICD codes are read as strings so numeric-looking codes keep their leading/trailing zeros;
# the CCSCM column keeps its inferred (numeric) type, since later cells use it as an int key.
df = pd.read_csv("ICD9CM_to_CCSCM.csv", encoding="utf-8", dtype={"ICD9CM": str})
df["ICD9CM"] = df["ICD9CM"].astype(str).str.replace(".", "", regex=False)

# Dot-less ICD-9-CM code -> CCS category code.
ICD9CM_to_CCSCM = dict(zip(df["ICD9CM"], df["CCSCM"]))


In [None]:
from collections import defaultdict

def get_diag_ccs_ontology_prior_ccs(hist_diag, sorted_ccs):
    """Render one visit's ICD codes grouped by their CCS category, in ``sorted_ccs`` order.

    Args:
        hist_diag: dot-less ICD-9-CM codes of the visit (keys of ``ICD9CM_to_CCSCM``).
        sorted_ccs: CCS codes (as strings) giving the order in which groups are listed.

    Returns:
        A string like ``{[("icd name", "icd name") BELONGS TO "ccs name"], ...}``.
        CCS categories not present in ``sorted_ccs`` are omitted (a warning is printed
        for them); ``"{}"`` is returned when no group remains.
    """
    ccs_icd_dict = defaultdict(list)
    for icd_code in hist_diag:
        ccs_code = ICD9CM_to_CCSCM[icd_code]
        ccs_icd_dict[ccs_code].append(icd_code)
        if str(ccs_code) not in sorted_ccs:
            print("warning!!!!!")

    groups = []
    seen = set()
    for ccs in sorted_ccs:
        ccs_int = int(ccs)
        # `in` does not insert into the defaultdict; `seen` keeps the first occurrence only.
        if ccs_int not in ccs_icd_dict or ccs_int in seen:
            continue
        seen.add(ccs_int)
        icd_names = ", ".join(f'"{icd_to_info[icd]["name"]}"' for icd in ccs_icd_dict[ccs_int])
        groups.append(f'[({icd_names}) BELONGS TO "{ccs_to_name[ccs_int]}"]')

    # Joining avoids the trailing-separator slicing that turned an empty result into "}".
    return "{" + ", ".join(groups) + "}"


In [None]:
import inflect

# One shared engine instead of constructing a new one on every call (this runs once per
# historical visit of every sample).
_INFLECT_ENGINE = inflect.engine()


def number_to_capitalized_ordinal(n):
    """Return ``n`` spelled out as a capitalized English ordinal, e.g. 2 -> "Second"."""
    words = _INFLECT_ENGINE.number_to_words(n)
    return _INFLECT_ENGINE.ordinal(words).capitalize()




In [None]:
def ccs_codes_to_names(ccs_list):
    """Map CCS codes to their display names, each wrapped in double quotes.

    Codes missing from ``ccs_to_name`` become ``"Unknown CCS: <code>"``.
    """
    return [
        '"{}"'.format(ccs_to_name.get(int(code), f"Unknown CCS: {code}"))
        for code in ccs_list
    ]

def icd_codes_to_names(icd_list):
    """Map ICD codes to their ontology names, each wrapped in double quotes.

    Codes missing from ``icd_to_info`` become ``"Unknown ICD: <code>"``.
    """
    return [
        '"{}"'.format(icd_to_info.get(code, {"name": f"Unknown ICD: {code}"}).get("name"))
        for code in icd_list
    ]


In [None]:

from collections import defaultdict
from datetime import datetime

_ADM_TIME_FMT = '%Y-%m-%d %H:%M'

# The instruction is identical for every sample, so it is built once.
instruction = (
    "You are a medical diagnosis expert. Your task is to re-rank the provided candidate diseases (CCS categories) based on:\n"
    "- The patient’s past diagnosis history"
    "\n- The Model Evidence\n"
    "\n- Your medical knowledge\n\n---\n\n"
    "The Model Evidence:\n- The deep model provides evidence by associating each candidate CCS with one or more historical diagnosis.\n- These associations reflect statistically learned medical co-occurrence patterns and should be treated as supportive clues, not as output content.\n"
    "\n\nYour task:\n- Re-rank the candidate CCS categories from most to least likely."
    "\nDirectly provide the reordered list of disease names in descending order of likelihood. \nOutput format:\nAnswer: <CCS name 1>, <CCS name 2>, ...\n\n---\n"
)


def _collect_history_ccs(item):
    """Return (distinct history CCS codes as ints, CCS code -> ICD codes in encounter order)."""
    hist_ccs_icd_dict = defaultdict(list)
    hist_ccs = []
    for cond_hist in item['cond_hist']:
        for entry in cond_hist:
            icd_code = entry[0]
            ccs_code = ICD9CM_to_CCSCM[icd_code]
            hist_ccs.append(int(ccs_code))
            # ICD codes repeated across visits are kept (one entry per occurrence).
            hist_ccs_icd_dict[ccs_code].append(icd_code)
    return list(set(hist_ccs)), hist_ccs_icd_dict


def _format_elapsed(hist_time, cur_time):
    """Describe the gap between two admission times using 365-day years and 30-day months."""
    delta_days = (
        datetime.strptime(cur_time, _ADM_TIME_FMT) - datetime.strptime(hist_time, _ADM_TIME_FMT)
    ).days
    years = delta_days // 365
    months = (delta_days % 365) // 30
    days = (delta_days % 365) % 30
    parts = []
    if years:
        parts.append(f"{years} years")
    if months:
        parts.append(f"{months} months")
    if days:
        parts.append(f"{days} days")
    return " ".join(parts) + " ago" if parts else "today"


def _format_history(item, sorted_ccs):
    """Render every non-empty visit of ``item`` as an ordinal-labelled, CCS-grouped line."""
    cur_time = item['adm_time'][-1][0]  # the last admission is treated as "now"
    hist_diag_str = ""
    for num, cond_hist in enumerate(item['cond_hist']):
        # dict.fromkeys de-duplicates while keeping first-seen order; list(set(...)) over
        # strings has hash-seed-dependent order, which made prompts differ between runs.
        hist_diag = list(dict.fromkeys(entry[0] for entry in cond_hist))
        if not hist_diag:
            continue
        time_str = _format_elapsed(item['adm_time'][num][0], cur_time)
        hist_diag_str += number_to_capitalized_ordinal(num + 1)
        hist_diag_str += f" Visit({time_str}): "
        hist_diag_str += get_diag_ccs_ontology_prior_ccs(hist_diag, sorted_ccs)
        hist_diag_str += "\n\n"
    return hist_diag_str


def _format_model_evidence(hist_ccs, hist_ccs_icd_dict, candidate_ccs):
    """Build numbered "{ICDs} → {history CCS} → {predicted CCS}" lines; return (text, count)."""
    paths = []
    for ccs_code in hist_ccs:
        if str(ccs_code) not in ccs_sorted_pred_ccs_list:
            continue
        # Only the first associated CCS ([0:1]) is considered; it is used only if it is new
        # for this patient and is one of the RETAIN candidates.
        target_code = None
        for pred_ccs_code in ccs_sorted_pred_ccs_list[str(ccs_code)][0:1]:
            if int(pred_ccs_code) not in hist_ccs and str(pred_ccs_code) in candidate_ccs:
                target_code = pred_ccs_code
                break
        if target_code is None:
            continue
        icd_names = " ".join(
            f'"{icd_to_info[str(icd_code)]["name"]}"'
            for icd_code in hist_ccs_icd_dict[int(ccs_code)]
        )
        hist_ccs_name = f'"{ccs_to_name[int(ccs_code)]}"'
        pred_ccs_name = f'"{ccs_to_name[int(target_code)]}"'
        paths.append(
            f"{len(paths) + 1}." + "{" + icd_names + "} → {" + hist_ccs_name + "} → {" + pred_ccs_name + "}\n"
        )
    return "".join(paths), len(paths)


total_filter_path_length = 0
total_path_length = 0  # never incremented; kept for any later cell that reads it
result = []
all_filter = []
for item in diag_test_data:
    candidate_ccs = item['retain_top']
    sorted_ccs = item['sorted_ccs']

    hist_ccs, hist_ccs_icd_dict = _collect_history_ccs(item)

    candidate_ccs_str = ", ".join(f'"{ccs_to_name[int(code)]}"' for code in candidate_ccs)

    # Always 0: no candidate filtering is performed; one entry per sample is kept.
    all_filter.append(0)

    hist_diag_str = _format_history(item, sorted_ccs)
    evidence_str, n_paths = _format_model_evidence(hist_ccs, hist_ccs_icd_dict, candidate_ccs)
    total_filter_path_length += n_paths

    input_ = (
        "Patient history summary:\n" + hist_diag_str + "\n\n"
        + "The Model Evidence:\n" + evidence_str
        + "\nCandidate diseases:\n{" + candidate_ccs_str + "}\n\n"
    )

    # Answers not already diagnosed in the history; sorted so the JSON is reproducible.
    hist_ccs_str = {str(ccs) for ccs in hist_ccs}
    new_ccs = sorted(set(item['ccs_answer']) - hist_ccs_str)

    result.append({
        'visit_id': item['visit_id'],
        'patient_id': item['patient_id'],
        "instruction": instruction,
        "input": input_,
        "ccs_answer": item['ccs_answer'],
        "new_ccs_answer": new_ccs,
        "candidate_ccs": candidate_ccs,
    })



In [None]:
import json
# Write the prompt/answer records as pretty-printed UTF-8 JSON (non-ASCII kept as-is).
OUTPUT_PATH = "dataset.json"
with open(OUTPUT_PATH, "w", encoding="utf-8") as out_file:
    json.dump(result, out_file, ensure_ascii=False, indent=2)