In [1]:
import pandas as pd
import numpy as np
import torch
import networkx as nx
import pickle


In [2]:
def get_Nodes(file_path):
    print('reading the nodes from the graph...')
    G = nx.read_gml(file_path)    

    Nodes = list(G.nodes())

    C_Nodes = [v for v in Nodes if v[0]=='C']
    V_Nodes = [v for v in Nodes if v[0]=='V']
    M_Nodes = [v for v in Nodes if v[0]=='M']
    D_Nodes = [v for v in Nodes if v[0]=='D']
    P_Nodes = [v for v in Nodes if v[0]=='P']

    print(f'number of patients = {len(C_Nodes)}')
    print(f'number of visits = {len(V_Nodes)}')
    print(f'number of Medication = {len(M_Nodes)}')
    print(f'number of Diagnoses = {len(D_Nodes)}')
    print(f'number of Procedures = {len(P_Nodes)}')

    return C_Nodes, V_Nodes, M_Nodes, D_Nodes, P_Nodes, G

In [3]:
def read_embedding():
    print('Reading the embedding of the entire MIMIC dataset...')

    embedding_df = pd.read_csv(f'../Data/grouped_emb2.csv')

    embedding_df = embedding_df.rename(columns={'SUBJECT_ID': 'Patient'})#, 'Embedding': 'Reduced_Embedding'})

    # convert the df to a dict of {patient: list of }

    patient_emb_dict = embedding_df.set_index('Patient')['Embedding'].to_dict()

    df_patients = embedding_df['Patient'].unique()

    print(f'number of patients in the embedding file is {len(df_patients)}')
    print('\tReturning the list of patients and their embedding...')
    return df_patients, patient_emb_dict, embedding_df

In [4]:
def proc_string_list(s):
    import re
    # Remove brackets and newline characters, then split based on spaces
    numbers_str = re.sub(r'[\[\]\n]', '', s).split()

    # Convert each element to float, then to int (by rounding)
    return [float(num) for num in numbers_str if num]



In [5]:
file_path = 'results/the_complete_hetero_graph.gml'

C_Nodes, V_Nodes, M_Nodes, D_Nodes, P_Nodes, G = get_Nodes(file_path)

Nodes = C_Nodes + V_Nodes + M_Nodes + D_Nodes + P_Nodes

df_patients, patient_emb_dict, embedding_df = read_embedding()

embedding_df.head(5)


reading the nodes from the graph...
number of patients = 10128
number of visits = 13620
number of Medication = 364
number of Diagnoses = 8
number of Procedures = 88
Reading the embedding of the entire MIMIC dataset...
number of patients in the embedding file is 46145
	Returning the list of patients and their embedding...


Unnamed: 0.1,Unnamed: 0,Patient,Embedding
0,0,2,[ 1.95153690e-02 5.16873449e-01 -2.77101226e-...
1,1,3,[ 5.07365704e+00 8.60597286e+00 -8.62598368e+...
2,2,4,[ 2.78797361e+00 2.62662276e-01 -3.01898561e+...
3,3,5,[-3.46906440e-03 3.76553104e-01 -6.26694559e-...
4,4,6,[ 3.62095121e-01 6.58498007e-01 -3.55797152e+...


In [6]:
def mark_all_nodes(C_Nodes, V_Nodes, M_Nodes, D_Nodes, P_Nodes):
    print('Assingning embedding to the other nodes.')
    print('\t The GNN model shall identify the other nodes as well, right?.')

    nodes_by_type = {'C':0, 'V':1, 'M': 2, 'D': 3, 'P':4}

    nodes_list = {'C': C_Nodes, 'V': V_Nodes, 'M': M_Nodes, 'D': D_Nodes, 'P': P_Nodes}

    emb = {}
    for n in Nodes:
        t = nodes_by_type[n[0]]
        emb[n] = [t, nodes_list[n[0]].index(n)]

    print('\tDone --> Assingning embedding to the other nodes.')
    return emb

In [7]:
int_Patients     = [int(p[2:]) for p in C_Nodes]
patient_list     = [p for p in int_Patients if p in df_patients]  
ignored_patients = [p for p in int_Patients if p not in df_patients] 
print(f'Number of Ignored patients are : {len(ignored_patients)}')

# --------------------------------------------------------------------

emb = mark_all_nodes(C_Nodes, V_Nodes, M_Nodes, D_Nodes, P_Nodes)


Number of Ignored patients are : 64
Assingning embedding to the other nodes.
	 The GNN model shall identify the other nodes as well, right?.
	Done --> Assingning embedding to the other nodes.


In [8]:
emb

{'C_115': [0, 0],
 'C_123': [0, 1],
 'C_124': [0, 2],
 'C_132': [0, 3],
 'C_79': [0, 4],
 'C_81': [0, 5],
 'C_92': [0, 6],
 'C_94': [0, 7],
 'C_96': [0, 8],
 'C_3': [0, 9],
 'C_6': [0, 10],
 'C_12': [0, 11],
 'C_18': [0, 12],
 'C_21': [0, 13],
 'C_25': [0, 14],
 'C_32': [0, 15],
 'C_41': [0, 16],
 'C_63': [0, 17],
 'C_247': [0, 18],
 'C_253': [0, 19],
 'C_255': [0, 20],
 'C_177': [0, 21],
 'C_184': [0, 22],
 'C_188': [0, 23],
 'C_198': [0, 24],
 'C_302': [0, 25],
 'C_304': [0, 26],
 'C_307': [0, 27],
 'C_309': [0, 28],
 'C_319': [0, 29],
 'C_140': [0, 30],
 'C_164': [0, 31],
 'C_165': [0, 32],
 'C_173': [0, 33],
 'C_174': [0, 34],
 'C_146': [0, 35],
 'C_149': [0, 36],
 'C_156': [0, 37],
 'C_107': [0, 38],
 'C_200': [0, 39],
 'C_205': [0, 40],
 'C_214': [0, 41],
 'C_221': [0, 42],
 'C_223': [0, 43],
 'C_226': [0, 44],
 'C_228': [0, 45],
 'C_234': [0, 46],
 'C_235': [0, 47],
 'C_239': [0, 48],
 'C_243': [0, 49],
 'C_263': [0, 50],
 'C_268': [0, 51],
 'C_438': [0, 52],
 'C_464': [0, 53],


In [9]:
# --------------------------------------------------------------------
print('Delete the ignored patient from Nodes and emb')
for r in ignored_patients:
    p = f'C_{r}'
    ind = Nodes.index(p)
    Nodes.remove(p)
    emb.pop(p)



Delete the ignored patient from Nodes and emb


In [10]:
w = len(proc_string_list(patient_emb_dict[list(patient_emb_dict.keys())[0]]))
print(w)

temp_emb = {}
for v in Nodes:
    if v[0]=='C':
        p = int(v[2:])
        temp = proc_string_list(patient_emb_dict[p]) + emb[v]
    else:
        temp = [0] * w + emb[v]
    temp_emb[v] = temp

x = torch.tensor([temp_emb[v] for v in temp_emb], dtype=torch.float)

torch.save(x, 'results/X.pt')

x.shape

768


torch.Size([24144, 770])

In [11]:
with open('results/Nodes.pkl', 'wb') as file:
    pickle.dump(Nodes, file)

In [12]:
nodes_to_delete = [f'C_{i}' for i in ignored_patients]

print(f'Number of nodes before removal: {len(G.nodes())}')

for nd in nodes_to_delete:
    if nd in G.nodes():
        G.remove_node(nd)

print(f'Number of nodes after removal: {len(G.nodes())}')

print('Saving the new graph...')

# Now, G has the ultimate structure. We should save it, right??
nx.write_gml(G, 'results/the_complete_hetero_graph2.gml')


Number of nodes before removal: 24208
Number of nodes after removal: 24144
Saving the new graph...
