In [2]:
import pandas as pd
import os, shutil
import torch
import numpy as np
import matplotlib
import sys; sys.path.append("..")
import torch_geometric; print(torch_geometric.__version__)
import torch_geometric.transforms as T
import utils.constant as constant

from tqdm.notebook import tqdm, tqdm_notebook
from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader
from torch_geometric.utils import negative_sampling
from dataset.hgs import MyOwnDataset

2.3.0


# Design

Heterogeneous Graph:

- Nodes: 
  - hadm_ids (can be merged with patients via subject_id)
  - labitems
  - NDCs (medications)
- Node Features:
  - hadm_ids: 
    - ADMISSION_TYPE
    - ADMISSION_LOCATION
    - DISCHARGE_LOCATION
    - INSURANCE
    - LANGUAGE
    - RELIGION
    - MARITAL_STATUS
    - ETHNICITY
    - Top-10 (or top-k?) diagnoses_icd records, encoded as a 10-dimensional vector of ICD-9 codes (should they be ordered by frequency?); not yet included in the admission features built below
  - labitems:
    - FLUID
    - CATEGORY
  - NDCs(medication):
    - DRUG_TYPE
    - FORM_UNIT_DISP
- Edges:
  - hadm_id *did* labitem
  - hadm_id *took* NDC
- Edge_features:
  - hadm_id *did* labitem
    - z-score (for value type labitems)
    - category (for non-value type labitems)
    - ~~FLAG~~ (Don't use this, because many entries were found with obviously abnormal values but without an abnormal flag!)
  - hadm_id *took* NDC
    - DRUG_TYPE
    - PROD_STRENGTH
    - DOSE_VAL_RX
    - DOSE_UNIT_RX
    - FORM_VAL_DISP
    - FORM_UNIT_DISP
    - ROUTE
- Labels: 
  - edge-level prediction task
  - predict whether edges exist at $T_n$ based on the graphs from $T_0$ to $T_{n-1}$
- Timesteps:
  - half a day or a full day as the interval? 
    - **Finally, we chose a full day (24 hours)**
  - vacant days (without any labevent entry) are automatically omitted; the next non-vacant day becomes the next timestep
- Temporal graph shape: 
  - Dynamic
  - edges change over time

Note: 
In practice, a patient has no *diagnoses_icd* when he/she first enters a hospital, and the *diagnoses_icd* records don't have timestamps. 
So, the connection between *hadm_ids* and *diagnoses_icd* is a final result, available only after all *labitems* have been executed. 
Hence, there are 3 choices:
- letting *diagnoses_icd* act as **nodes** that connect *hadm_ids*
- letting *diagnoses_icd* act as a **node_feature** of *hadm_ids* (chosen)
- **not using it at all**

In [3]:
path_dataset = constant.PATH_MIMIC_III_ETL_OUTPUT


def _read_etl_csv(filename: str) -> pd.DataFrame:
    """Read one gzipped CSV produced by the MIMIC-III ETL step from ``path_dataset``.

    ``low_memory=False`` makes pandas infer each column's dtype from the whole file
    instead of chunk by chunk. The stray warning line in this cell's saved output
    looks like the mixed-type ``DtypeWarning`` that chunked inference emits.
    """
    return pd.read_csv(os.path.join(path_dataset, filename), low_memory=False)


df_admissions    = _read_etl_csv("ADMISSIONS_NEW.csv.gz")
df_labitems      = _read_etl_csv("D_LABITEMS_NEW.csv.gz")
df_labevents     = _read_etl_csv("LABEVENTS_PREPROCESSED.csv.gz")
df_prescriptions = _read_etl_csv("PRESCRIPTIONS_PREPROCESSED.csv.gz")
df_drug_ndc_feat = _read_etl_csv("DRUGS_NDC_FEAT.csv.gz")

  has_raised = await self.run_ast_nodes(code_ast.body, cell_name,


# Pre-check

In [4]:
# Any missing values in the labitem node-feature columns?
print(*(df_labitems[col].isnull().any() for col in ("FLUID", "CATEGORY")))

# Categorical admission attributes that serve as admission-node features.
list_selected_admission_columns = [
    'ADMISSION_TYPE', 'ADMISSION_LOCATION', 'DISCHARGE_LOCATION', 'INSURANCE',
    'LANGUAGE', 'RELIGION', 'MARITAL_STATUS', 'ETHNICITY',
]
df_admissions.loc[:, list_selected_admission_columns].isnull().any()

False False


ADMISSION_TYPE        False
ADMISSION_LOCATION    False
DISCHARGE_LOCATION    False
INSURANCE             False
LANGUAGE              False
RELIGION              False
MARITAL_STATUS        False
ETHNICITY             False
dtype: bool

In [9]:
# NOTE(review): these columns appear to hold integer category codes (whole-number
# min/max), so mean/std below describe the codes rather than the categories.
df_admissions.loc[:, list_selected_admission_columns].describe()

Unnamed: 0,ADMISSION_TYPE,ADMISSION_LOCATION,DISCHARGE_LOCATION,INSURANCE,LANGUAGE,RELIGION,MARITAL_STATUS,ETHNICITY
count,58976.0,58976.0,58976.0,58976.0,58976.0,58976.0,58976.0,58976.0
mean,1.462612,2.149145,2.90703,1.711222,0.992573,2.825166,1.530487,2.190773
std,0.803654,1.152766,2.162587,0.841009,2.667782,2.217806,1.185236,3.136717
min,1.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0
25%,1.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0
50%,1.0,2.0,2.0,2.0,1.0,2.0,1.0,1.0
75%,2.0,3.0,4.0,2.0,1.0,4.0,2.0,2.0
max,4.0,9.0,17.0,5.0,75.0,20.0,7.0,41.0


In [4]:
# One line per labevent edge-attribute column: does it contain any missing value?
for column in ("CATAGORY", "VALUENUM_Z-SCORED", "TIMESTEP"):
    print(df_labevents[column].isnull().any())

False
False
False


In [8]:
labevent_edge_columns = ['CATAGORY', 'VALUENUM_Z-SCORED', 'TIMESTEP']
# NOTE(review): the saved output shows "z-scores" from about -6.35e8 to 7.73e9,
# far outside any plausible z-score range; the upstream normalisation should be checked.
df_labevents.loc[:, labevent_edge_columns].describe()

Unnamed: 0,CATAGORY,VALUENUM_Z-SCORED,TIMESTEP
count,15309100.0,15309100.0,15309100.0
mean,0.261844,270108.2,8.990683
std,8.049515,24870570.0,13.43537
min,0.0,-635000000.0,0.0
25%,0.0,-0.9556186,1.0
50%,0.0,0.0,5.0
75%,0.0,1.28508,11.0
max,1351.0,7730000000.0,291.0


In [6]:
# Prescription attributes used as (admission, took, drug) edge features.
prescription_edge_columns = [
    "DRUG_TYPE", "PROD_STRENGTH", "DOSE_VAL_RX", "DOSE_UNIT_RX",
    "FORM_VAL_DISP", "FORM_UNIT_DISP", "ROUTE",
]
df_prescriptions.loc[:, prescription_edge_columns].describe()

Unnamed: 0,DRUG_TYPE,PROD_STRENGTH,DOSE_VAL_RX,DOSE_UNIT_RX,FORM_VAL_DISP,FORM_UNIT_DISP,ROUTE
count,10189840.0,10189840.0,10189840.0,10189840.0,10189840.0,10189840.0,10189840.0
mean,1.135368,1.138353,1.737589,3.141555,1.748559,3.854,3.089786
std,0.3464677,0.4612439,4.069737,4.108374,4.048352,3.77013,3.495601
min,1.0,1.0,0.0,0.0,1.0,0.0,1.0
25%,1.0,1.0,1.0,1.0,1.0,1.0,1.0
50%,1.0,1.0,1.0,2.0,1.0,3.0,2.0
75%,1.0,1.0,1.0,3.0,1.0,5.0,4.0
max,3.0,13.0,389.0,89.0,385.0,52.0,74.0


In [4]:
# NOTE(review): the saved output shows index-like "Unnamed: 0" / "Unnamed: 0.1" columns
# (presumably an index written to CSV upstream), and DRUG_TYPE_ADDITIVE_Proportion is
# constant 0 (std 0), so that drug-node feature carries no information.
pd.DataFrame.describe(df_drug_ndc_feat)

Unnamed: 0.1,Unnamed: 0,NDC,DRUG_TYPE_MAIN_Proportion,DRUG_TYPE_BASE_Proportion,DRUG_TYPE_ADDITIVE_Proportion,FORM_UNIT_DISP_Freq_1,FORM_UNIT_DISP_Freq_2,FORM_UNIT_DISP_Freq_3,FORM_UNIT_DISP_Freq_4,FORM_UNIT_DISP_Freq_5,rxnorm_id
count,4294.0,4294.0,4294.0,4294.0,4294.0,4294.0,4294.0,4294.0,4294.0,4294.0,4101.0
mean,2146.5,18723950000.0,0.997196,0.002804,0.0,4.638333,0.539823,0.090824,0.018398,0.003959,652941.3
std,1239.715357,26255640000.0,0.030354,0.030354,0.0,5.546855,3.060029,1.45981,0.838669,0.259429,482806.3
min,0.0,1.0,0.510132,0.0,0.0,1.0,0.0,0.0,0.0,0.0,93006.0
25%,1073.25,74231650.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,212219.0
50%,2146.5,409402100.0,1.0,0.0,0.0,3.0,0.0,0.0,0.0,0.0,544393.0
75%,3219.75,51079030000.0,1.0,0.0,0.0,6.0,0.0,0.0,0.0,0.0,902317.0
max,4293.0,87701090000.0,1.0,0.489868,0.0,50.0,46.0,43.0,52.0,17.0,1992303.0


In [10]:
# Summary of the two labitem node-feature columns.
df_labitems.loc[:, ["FLUID", "CATEGORY"]].describe()

Unnamed: 0,FLUID,CATEGORY
count,753.0,753.0
mean,2.945551,1.618858
std,2.742196,0.874895
min,1.0,1.0
25%,1.0,1.0
50%,2.0,1.0
75%,4.0,2.0
max,16.0,6.0


# Preprocessing

In [6]:
def get_list_total_patient_num(*list_df_single_edges_type: pd.DataFrame) -> int:
    r"""
    Count the patients (``SUBJECT_ID``) present in every given edge DataFrame.

    With one DataFrame this is its number of distinct ``SUBJECT_ID`` values; with
    several it is the size of the intersection of their ``SUBJECT_ID`` sets.

    Raises:
        ValueError: if no DataFrame is passed.
    """
    if not list_df_single_edges_type:
        raise ValueError("at least one DataFrame is required")
    if len(list_df_single_edges_type) == 1:
        return len(list_df_single_edges_type[0].SUBJECT_ID.unique())
    sets_subjid = [set(df_single_edges_type.SUBJECT_ID.unique())
                   for df_single_edges_type in list_df_single_edges_type]
    return len(set.intersection(*sets_subjid))
    
# Patients that have both lab events and prescriptions.
n_common_patients = get_list_total_patient_num(df_labevents, df_prescriptions)
n_common_patients

38817

## Train/validation split by `HADM_ID`

In [14]:
def get_list_total_hadmid(*list_df_single_edges_type: pd.DataFrame) -> list:
    r"""
    Get the intersection of ``HADM_ID``s across the DataFrame(s) that record edge connections.

    With one DataFrame, its distinct ``HADM_ID`` values are returned in order of first
    appearance. With several, the ``HADM_ID``s present in all of them are returned in
    set-iteration order (no particular order should be assumed).

    Raises:
        ValueError: if no DataFrame is passed.
    """
    if not list_df_single_edges_type:
        raise ValueError("at least one DataFrame is required")
    if len(list_df_single_edges_type) == 1:
        return list(list_df_single_edges_type[0].HADM_ID.unique())
    sets_hadmid = [set(df_single_edges_type.HADM_ID.unique())
                   for df_single_edges_type in list_df_single_edges_type]
    return list(set.intersection(*sets_hadmid))

# HADM_IDs that have both lab events and prescriptions.
edge_tables = (df_labevents, df_prescriptions)
list_total_hadmid = get_list_total_hadmid(*edge_tables)
len(list_total_hadmid)

49491

In [15]:
def get_train_val_test_hadmid_list(list_total_hadmid, split_ratio: float, shuffle: bool,
                                   seed=None):
    r"""
    Split HADM_IDs into a training list and a validation list.

    Args:
        list_total_hadmid: sequence of HADM_IDs to split. It is NOT modified.
        split_ratio: fraction in [0, 1]; the first ``int(len * split_ratio)`` IDs
            (after optional shuffling) go to the training list, the rest to validation.
        shuffle: whether to shuffle before splitting.
        seed: optional seed for a reproducible shuffle. When ``None``, numpy's global
            random state is used, as before.

    Returns:
        ``(list_train_hadmid, list_val_hadmid)``. Despite the name, no test split is made.

    Raises:
        ValueError: if ``split_ratio`` is outside [0, 1].
    """
    if not 0.0 <= split_ratio <= 1.0:
        raise ValueError(f"split_ratio must be in [0, 1], got {split_ratio}")

    # Shuffle a copy so the caller's list is not reordered as a side effect.
    list_hadmid = list(list_total_hadmid)
    if shuffle:
        rng = np.random if seed is None else np.random.default_rng(seed)
        rng.shuffle(list_hadmid)

    n_train = int(len(list_hadmid) * split_ratio)
    return list_hadmid[:n_train], list_hadmid[n_train:]

# 80 % of the admissions for training, the remaining 20 % for validation.
list_train_hadmid, list_val_hadmid = get_train_val_test_hadmid_list(
    list_total_hadmid, split_ratio=0.8, shuffle=True
)
len(list_train_hadmid), len(list_val_hadmid)

(39592, 9899)

In [20]:
# def batches_spliter(list_hadmid: list, 
#                     df_admissions, 
#                     df_labevents, 
#                     df_prescriptions,
#                     batch_size: int):
#     r"""split dfs into many batches by `HADM_ID`."""
#     idx = 0
#     length = len(list_hadmid)
#     batches_hadmids = []
#     while (idx+batch_size) <= length:
#         batches_hadmids += [list_hadmid[idx:idx+batch_size]]
#         idx += batch_size
#     if idx < length:
#         batches_hadmids += [list_hadmid[idx:]]
    
    
#     list_df_admissions_single_batch    = []
#     list_df_labevents_single_batch     = []
#     list_df_prescriptions_single_batch = []
#     for batch_hadmids in tqdm(batches_hadmids):
#         list_df_admissions_single_batch.append(
#             df_admissions[df_admissions.HADM_ID.isin(batch_hadmids)].copy()
#         )
#         list_df_labevents_single_batch.append(
#             df_labevents[df_labevents.HADM_ID.isin(batch_hadmids)].copy()
#         )
#         list_df_prescriptions_single_batch.append(
#             df_prescriptions[df_prescriptions.HADM_ID.isin(batch_hadmids)].copy()
#         )
    
#     return list_df_admissions_single_batch, \
#            list_df_labevents_single_batch, \
#            list_df_prescriptions_single_batch

In [19]:
def batches_spliter(list_hadmid: list, batch_size: int, *dfs: pd.DataFrame):
    r"""
    Split the df(s) into many batches by the `HADM_ID` batch.

    ``list_hadmid`` is cut into consecutive chunks of ``batch_size`` IDs (the last
    chunk may be shorter). For each chunk and each DataFrame in ``dfs``, the rows whose
    ``HADM_ID`` is in the chunk are copied out, keeping their original order.

    Note: the order of returned list_dfs is consistent with the order of dfs passed in.

    Returns:
        A list with one entry per DataFrame in ``dfs``; each entry is a list holding
        one DataFrame per batch.

    Raises:
        ValueError: if ``batch_size`` is not positive (the chunking would otherwise
            never terminate).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    batches_hadmids = [list_hadmid[start:start + batch_size]
                       for start in range(0, len(list_hadmid), batch_size)]

    list_list_dfs = [[] for _ in dfs]
    for batch_hadmids in tqdm(batches_hadmids):
        for list_dfs, df in zip(list_list_dfs, dfs):
            list_dfs.append(df[df.HADM_ID.isin(batch_hadmids)].copy())

    return list_list_dfs

# TODO: try 64? 256? 512? 1024? ...
batch_size = 128
split_tables = (df_admissions, df_labevents, df_prescriptions)
(list_df_admissions_single_batch_train,
 list_df_labevents_single_batch_train,
 list_df_prescriptions_single_batch_train) = batches_spliter(list_train_hadmid, batch_size, *split_tables)
(list_df_admissions_single_batch_val,
 list_df_labevents_single_batch_val,
 list_df_prescriptions_single_batch_val) = batches_spliter(list_val_hadmid, batch_size, *split_tables)

  0%|          | 0/310 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

# Constructing dynamic graph

In [13]:
def _sorted_node_ids(ids: pd.Series, what: str) -> pd.Index:
    """Return the ascending unique IDs of one node table; an ID's position is its node id."""
    node_ids = pd.Index(ids.sort_values().unique())
    # Node features are taken row by row from the ID-sorted table, so they only line up
    # with these node ids when every ID occurs exactly once.
    assert len(node_ids) == len(ids), f"duplicate {what} values in the node table"
    return node_ids


def _to_node_ids(node_ids: pd.Index, ids: pd.Series, what: str) -> torch.Tensor:
    """Map the raw IDs referenced by edges to node ids (positions in ``node_ids``)."""
    positions = node_ids.get_indexer(ids)
    # get_indexer returns -1 for IDs that have no matching node.
    assert (positions >= 0).all(), f"edges reference {what} values missing from the node table"
    return torch.from_numpy(positions)


def construct_dynamic_hetero_graph(df_admissions_curr, 
                                   df_labitems_curr, 
                                   df_labevents_curr, 
                                   df_drug_ndc_feat_curr, 
                                   df_prescriptions_curr):
    r"""
    Build one heterogeneous graph (``HeteroData``) for a batch of admissions.

    Node types (node id = position of the ID in ascending order):
        * ``admission``: one per row of ``df_admissions_curr`` (by ``HADM_ID``), with the
          8 selected admission columns as ``x``.
        * ``labitem``: one per row of ``df_labitems_curr`` (by ``ITEMID``), with
          ``FLUID``/``CATEGORY`` as ``x``.
        * ``drug``: one per row of ``df_drug_ndc_feat_curr`` (by ``NDC``), with the
          8 DRUG_TYPE / FORM_UNIT_DISP columns as ``x``.

    Edge types:
        * ``(admission, did, labitem)``: one edge per row of ``df_labevents_curr``,
          ordered by (HADM_ID, ITEMID), with ``CATAGORY``/``VALUENUM_Z-SCORED`` as ``x``
          and ``TIMESTEP`` as ``timestep``.
        * ``(admission, took, drug)``: one edge per row of ``df_prescriptions_curr``,
          ordered by (HADM_ID, NDC), with the 7 prescription columns as ``x`` and
          ``TIMESTEP`` as ``timestep``.

    The input DataFrames are not modified. The labitem and drug tables are shared
    across batches.

    Raises:
        AssertionError: if an ID in a node table is duplicated, if an edge references an
            ID that has no node, if an edge type is empty, or if a feature contains NaN.
    """
    ############################################### Nodes #######################################################
    # Sorted copies: never sort the caller's frames in place.
    df_admissions_sorted    = df_admissions_curr.sort_values(by='HADM_ID')
    df_labitems_sorted      = df_labitems_curr.sort_values(by='ITEMID')
    df_drug_ndc_feat_sorted = df_drug_ndc_feat_curr.sort_values(by='NDC')

    list_selected_admission_columns = ['ADMISSION_TYPE', 
                                       'ADMISSION_LOCATION', 
                                       'DISCHARGE_LOCATION', 
                                       'INSURANCE', 
                                       'LANGUAGE', 
                                       'RELIGION', 
                                       'MARITAL_STATUS', 
                                       'ETHNICITY']
    list_selected_labitems_columns = ['FLUID', 'CATEGORY']
    list_selected_drug_ndc_columns = ["DRUG_TYPE_MAIN_Proportion", 
                                      "DRUG_TYPE_BASE_Proportion", 
                                      "DRUG_TYPE_ADDITIVE_Proportion", 
                                      "FORM_UNIT_DISP_Freq_1", 
                                      "FORM_UNIT_DISP_Freq_2", 
                                      "FORM_UNIT_DISP_Freq_3", 
                                      "FORM_UNIT_DISP_Freq_4", 
                                      "FORM_UNIT_DISP_Freq_5"]

    nodes_feature_admission_curr = torch.from_numpy(df_admissions_sorted[list_selected_admission_columns].values)
    nodes_feature_labitems       = torch.from_numpy(df_labitems_sorted[list_selected_labitems_columns].values)
    nodes_feature_drug_ndc       = torch.from_numpy(df_drug_ndc_feat_sorted[list_selected_drug_ndc_columns].values)

    # Raw ID -> node id lookups, in the same ascending order as the feature rows above.
    hadm_node_ids = _sorted_node_ids(df_admissions_sorted.HADM_ID,   "HADM_ID")
    item_node_ids = _sorted_node_ids(df_labitems_sorted.ITEMID,      "ITEMID")
    ndc_node_ids  = _sorted_node_ids(df_drug_ndc_feat_sorted.NDC,    "NDC")

    ############################################### Edges #######################################################
    df_labevents_sorted     = df_labevents_curr.sort_values(by=["HADM_ID", "ITEMID"])
    df_prescriptions_sorted = df_prescriptions_curr.sort_values(by=["HADM_ID", "NDC"])

    ## Edge indexes: row 0 = admission node id, row 1 = labitem / drug node id.
    edge_index_hadm_to_item = torch.stack([
        _to_node_ids(hadm_node_ids, df_labevents_sorted.HADM_ID, "HADM_ID"),
        _to_node_ids(item_node_ids, df_labevents_sorted.ITEMID,  "ITEMID"),
    ], dim=0)
    edge_index_hadm_to_ndc = torch.stack([
        _to_node_ids(hadm_node_ids, df_prescriptions_sorted.HADM_ID, "HADM_ID"),
        _to_node_ids(ndc_node_ids,  df_prescriptions_sorted.NDC,     "NDC"),
    ], dim=0)

    ## Edge features
    list_selected_labevents_columns = ['CATAGORY', 'VALUENUM_Z-SCORED']
    list_selected_prescriptions_columns = ["DRUG_TYPE", 
                                           "PROD_STRENGTH", 
                                           "DOSE_VAL_RX", 
                                           "DOSE_UNIT_RX", 
                                           "FORM_VAL_DISP", 
                                           "FORM_UNIT_DISP", 
                                           "ROUTE"]
    edges_feature_labevents     = torch.from_numpy(df_labevents_sorted[list_selected_labevents_columns].values)
    edges_feature_prescriptions = torch.from_numpy(df_prescriptions_sorted[list_selected_prescriptions_columns].values)

    ## Timesteps
    edges_timestep               = torch.from_numpy(df_labevents_sorted['TIMESTEP'].values)
    edges_timestep_prescriptions = torch.from_numpy(df_prescriptions_sorted['TIMESTEP'].values)

    ############################################## assemble #####################################################
    data = HeteroData()

    ## Nodes
    data["admission"].node_id = torch.arange(len(hadm_node_ids))
    data["labitem"].node_id   = torch.arange(len(item_node_ids))
    data["drug"].node_id      = torch.arange(len(ndc_node_ids))
    data["admission"].x = nodes_feature_admission_curr
    data["labitem"].x   = nodes_feature_labitems
    data["drug"].x      = nodes_feature_drug_ndc

    ## Edges
    data["admission", "did", "labitem"].edge_index = edge_index_hadm_to_item
    data["admission", "did", "labitem"].x          = edges_feature_labevents
    data["admission", "did", "labitem"].timestep   = edges_timestep

    data["admission", "took", "drug"].edge_index = edge_index_hadm_to_ndc
    data["admission", "took", "drug"].x          = edges_feature_prescriptions
    data["admission", "took", "drug"].timestep   = edges_timestep_prescriptions

    ############################################# sanity checks #################################################
    for node_type in ("admission", "labitem", "drug"):
        assert not data[node_type].node_id.isnan().any()
        assert not data[node_type].x.isnan().any()

    for edge_type in (("admission", "did", "labitem"), ("admission", "took", "drug")):
        assert     data[edge_type].edge_index.shape[-1] > 0
        assert not data[edge_type].edge_index.isnan().any()
        assert not data[edge_type].x.isnan().any()

    return data

In [14]:
def _build_hgs(list_df_admissions, list_df_labevents, list_df_prescriptions):
    """Build one hetero graph per batch; the labitem and NDC tables are shared by all batches."""
    return [construct_dynamic_hetero_graph(df_admissions_single_batch,
                                           df_labitems,
                                           df_labevents_single_batch,
                                           df_drug_ndc_feat,
                                           df_prescriptions_single_batch)
            for df_admissions_single_batch, df_labevents_single_batch, df_prescriptions_single_batch
            in tqdm(zip(list_df_admissions, list_df_labevents, list_df_prescriptions),
                    total=len(list_df_admissions))]


def _save_hgs(hgs, path_split):
    """Save graph ``idx`` of ``hgs`` as ``<path_split>/<idx>.pt``, creating the directory tree."""
    # makedirs (unlike mkdir) also creates missing parents such as the output root.
    os.makedirs(path_split)
    for idx, hg in enumerate(hgs):
        torch.save(hg, os.path.join(path_split, f"{idx}.pt"))


train_hgs = _build_hgs(list_df_admissions_single_batch_train,
                       list_df_labevents_single_batch_train,
                       list_df_prescriptions_single_batch_train)
val_hgs   = _build_hgs(list_df_admissions_single_batch_val,
                       list_df_labevents_single_batch_val,
                       list_df_prescriptions_single_batch_val)

path_hgs = constant.PATH_MIMIC_III_HGS_OUTPUT
path_hgs_curr = os.path.join(path_hgs, f'batch_size_{batch_size}')

# Start from an empty directory so graphs from an earlier run are not mixed in.
if os.path.isdir(path_hgs_curr):
    shutil.rmtree(path_hgs_curr)

_save_hgs(train_hgs, os.path.join(path_hgs_curr, "train"))
_save_hgs(val_hgs,   os.path.join(path_hgs_curr, "val"))

0it [00:00, ?it/s]

0it [00:00, ?it/s]

# Sub-Graph by `timestep`

In [66]:
def _sample_link_labels(pos_index, num_src_nodes, num_dst_nodes, neg_ratio):
    r"""
    Build shuffled binary link-prediction labels from positive edges.

    Draws ``pos_index.shape[1] * neg_ratio`` negative (src, dst) pairs with
    ``negative_sampling`` on the bipartite ``num_src_nodes x num_dst_nodes`` grid,
    appends them after the positives, and applies one random permutation to both the
    index and the labels.

    Returns:
        ``(neg_index, label_index, labels)`` where ``labels`` is 1.0 for the positive
        columns of ``label_index`` and 0.0 for the negative ones.
    """
    neg_index = negative_sampling(pos_index,
                                  num_neg_samples=pos_index.shape[1] * neg_ratio,
                                  num_nodes=(num_src_nodes, num_dst_nodes))
    label_index = torch.cat((pos_index, neg_index), dim=1)
    labels = torch.cat((torch.ones(pos_index.shape[1]),
                        torch.zeros(neg_index.shape[1])), dim=0)
    perm = torch.randperm(label_index.shape[1])
    return neg_index, label_index[:, perm], labels[perm]


def get_subgraph_by_timestep(hg, timestep, neg_ratio: int = 2):
    r"""
    Snapshot of ``hg`` at one timestep, labelled with the edges of the next timestep.

    The snapshot keeps every node (ids and features are cloned) and only the
    ``(admission, did, labitem)`` / ``(admission, took, drug)`` edges whose ``timestep``
    equals ``timestep``, together with their edge features.

    Link-prediction labels are attached as graph-level attributes. Positive edges are
    those at ``timestep + 1``, and ``neg_ratio`` times as many negatives are sampled.
    Attribute names are unchanged for downstream code, including the misspelled
    ``lables4item_index`` / ``lables4item``. The drug counterparts are
    ``labels4drug_index`` / ``labels4drug``.

    Finally ``T.ToUndirected`` adds the reverse ``rev_did`` / ``rev_took`` edges so that
    messages can flow in both directions.

    Args:
        hg: graph from ``construct_dynamic_hetero_graph`` (edge stores carry ``timestep``).
        timestep: snapshot timestep; must be below the largest timestep of both edge
            types so that a next timestep exists.
        neg_ratio: number of negatives sampled per positive edge (default 2, as before).

    Raises:
        AssertionError: if ``timestep`` is not below the last timestep of either edge type.
    """
    store_item = hg["admission", "did", "labitem"]
    store_drug = hg["admission", "took", "drug"]

    # Labels come from timestep + 1, so the last timestep cannot be used.
    assert timestep < torch.max(store_item.timestep), "last timestep has not labels!"
    assert timestep < torch.max(store_drug.timestep), "last timestep has not labels!"

    sub_hg = HeteroData()

    # Nodes: every node is kept at every timestep.
    for node_type in ("admission", "labitem", "drug"):
        sub_hg[node_type].node_id = hg[node_type].node_id.clone()
        sub_hg[node_type].x       = hg[node_type].x.clone()

    # Edges observed at `timestep` (boolean-mask indexing already returns copies).
    mask4item = store_item.timestep == timestep
    sub_hg["admission", "did", "labitem"].edge_index = store_item.edge_index[:, mask4item]
    sub_hg["admission", "did", "labitem"].x          = store_item.x[mask4item, :]

    mask4drug = store_drug.timestep == timestep
    sub_hg["admission", "took", "drug"].edge_index = store_drug.edge_index[:, mask4drug]
    sub_hg["admission", "took", "drug"].x          = store_drug.x[mask4drug, :]

    num_admissions = sub_hg["admission"].node_id.shape[0]

    # Labels for (admission, did, labitem): positives are the edges at timestep + 1.
    sub_hg.labels4item_pos_index = store_item.edge_index[:, store_item.timestep == (timestep + 1)]
    (sub_hg.labels4item_neg_index,
     sub_hg.lables4item_index,
     sub_hg.lables4item) = _sample_link_labels(sub_hg.labels4item_pos_index,
                                               num_admissions,
                                               sub_hg["labitem"].node_id.shape[0],
                                               neg_ratio)

    # Labels for (admission, took, drug): positives are the edges at timestep + 1.
    sub_hg.labels4drug_pos_index = store_drug.edge_index[:, store_drug.timestep == (timestep + 1)]
    (sub_hg.labels4drug_neg_index,
     sub_hg.labels4drug_index,
     sub_hg.labels4drug) = _sample_link_labels(sub_hg.labels4drug_pos_index,
                                               num_admissions,
                                               sub_hg["drug"].node_id.shape[0],
                                               neg_ratio)

    # Add reverse edges (labitem -> admission, drug -> admission) so a GNN can pass
    # messages in both directions.
    sub_hg = T.ToUndirected()(sub_hg)

    return sub_hg

# subgraph_temp = get_subgraph_by_timestep(train_hgs[0], timestep=20)
# subgraph_temp

In [67]:
# train_set = MyOwnDataset(root_path=os.path.join(constant.PATH_MIMIC_III_HGS_OUTPUT, rf"batch_size_{batch_size}"), usage="train")
# Must match the batch_size the graphs were saved with (it selects the directory).
batch_size = 128
val_set = MyOwnDataset(root_path=os.path.join(constant.PATH_MIMIC_III_HGS_OUTPUT, rf"batch_size_{batch_size}"), usage="val")

# Labels for timestep t come from t + 1, so snapshots 0 .. max_timestep - 1 are built.
max_timestep = 20
curr_subgraphs = [get_subgraph_by_timestep(val_set[0], timestep=t) for t in range(max_timestep)]
curr_subgraphs[-1]

HeteroData(
  labels4item_pos_index=[2, 379],
  labels4item_neg_index=[2, 758],
  lables4item_index=[2, 1137],
  lables4item=[1137],
  labels4drug_pos_index=[2, 389],
  labels4drug_neg_index=[2, 778],
  labels4drug_index=[2, 1167],
  labels4drug=[1167],
  [1madmission[0m={
    node_id=[128],
    x=[128, 8]
  },
  [1mlabitem[0m={
    node_id=[753],
    x=[753, 2]
  },
  [1mdrug[0m={
    node_id=[4294],
    x=[4294, 8]
  },
  [1m(admission, did, labitem)[0m={
    edge_index=[2, 410],
    x=[410, 2]
  },
  [1m(admission, took, drug)[0m={
    edge_index=[2, 410],
    x=[410, 7]
  },
  [1m(labitem, rev_did, admission)[0m={
    edge_index=[2, 410],
    x=[410, 2]
  },
  [1m(drug, rev_took, admission)[0m={
    edge_index=[2, 410],
    x=[410, 7]
  }
)

In [68]:
# Positive (label == 1) drug edges of snapshot 19, selected with a boolean column mask.
drug_labels_t19 = curr_subgraphs[19].labels4drug
drug_label_index_t19 = curr_subgraphs[19].labels4drug_index
existing_edge_indices = drug_label_index_t19[:, drug_labels_t19 != 0]
existing_edge_indices[0]

tensor([105,  72, 120,  13,  48,  39,  42,   7,  93, 120, 120,  72,  72,  88,
        120,  49, 124,  29,  39, 124,  60, 104,  29,  51,   7,  48,  13, 112,
        105, 112,  51, 105,  49, 124,  13, 105,  42,  88,  29,  51,   7, 120,
         60, 105, 124, 124,  48,  60,  13,  88,  13,  42,  48,  49, 120, 112,
         49,  51,  48, 120,  39,  48, 105, 120, 104, 120,  72,  48, 112,  72,
         93, 105,  29, 120, 105,  72,  51,  48, 112,   7, 105,  72, 120, 124,
         88,  51, 124,  51, 124,  51, 105, 112, 104, 112, 104,  60,  13,  49,
        120,  72,   7,  51, 124,  49,  48,  60,  29,  93,  60, 120, 105, 105,
        120,  93,  51,  39, 120,  42,  39,  93, 104,  51,  51,  13, 112, 105,
         42,  88,  60, 104,  39,  48, 105, 105, 124, 124,  72, 112, 124,  48,
         29,   7, 105, 105, 120, 105,  48,  72, 104, 105,  48, 124,  48,  42,
         48,  60,   7, 124,  48, 120,  48,   7,  93,  72,  72, 120, 124, 105,
        120,  60,  72,  51,  39,  39,   7,   7,  49, 105,  49,  

In [69]:
torch.unique(existing_edge_indices[0])

tensor([  7,  13,  29,  39,  42,  48,  49,  51,  60,  72,  88,  93, 104, 105,
        112, 120, 124])

In [70]:
curr_admi = 13
# Column positions of the positive edges that start at admission node `curr_admi`.
is_curr_hadm = existing_edge_indices[0] == curr_admi
indices_curr_hadm = is_curr_hadm.nonzero().flatten()
indices_curr_hadm

tensor([  3,  26,  34,  48,  50,  96, 123, 206, 214, 223, 254, 276, 305, 308,
        364])

In [71]:
# Drug node ids linked to `curr_admi`, in ascending order.
torch.sort(existing_edge_indices[1, indices_curr_hadm]).values

tensor([   2,  338,  771, 1157, 1636, 1768, 1941, 2282, 2446, 2466, 2758, 2767,
        2779, 4120, 4204])