In [None]:
"""
Author: Eunmi Joo

Balance patient-wise dataset

Load RSNA, AISD patient-wise dataset
    - data/patientwise/patient_data_rsna.pickle: rsna dataset including hemorrhage pos(8882 patients), neg(12862 patients)
    - data/patientwise/patient_data_isch.pickle: aisd dataset, ischemic pos(394 patients)

Save separated Hemorrhage pos, neg dataset size in 400, 800, 1200
    - data/patientwise/patient_data_hemo_400.pickle, patient_data_neg_400.pickle
    - data/patientwise/patient_data_hemo_800.pickle, patient_data_neg_800.pickle
    - data/patientwise/patient_data_hemo_1200.pickle, patient_data_neg_1200.pickle
"""

In [2]:
import pickle
import torch
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings("ignore")
device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
# load patient-wise data - hemorrhage, ischemic
root = "/ceph/inestp02/stroke_classifier/data/patientwise/"
hemo_path = "patient_data_rsna.pickle"
isch_path = "patient_data_isch.pickle"

with open(root+hemo_path, "rb") as f:
    patient_data_hemo = pickle.load(f)

with open(root+isch_path, "rb") as f:
    patient_data_isch = pickle.load(f)

In [6]:
patient_data_hemo_pos = dict()
patient_data_hemo_neg = dict()
for patient in tqdm(patient_data_hemo.keys()):
    if patient_data_hemo[patient]['patient_label'] == 1:
        patient_data_hemo_pos[patient] = patient_data_hemo[patient]
    else:
        patient_data_hemo_neg[patient] = patient_data_hemo[patient]

  0%|          | 0/21744 [00:00<?, ?it/s]

In [9]:
print(len(patient_data_hemo_pos.keys()))
print(len(patient_data_hemo_neg.keys()))

patient_data_hemo_pos.keys()

8882
12862


dict_keys(['ID_8042a129_ID_6f7a65817e', 'ID_f309d4c5_ID_cc6347e8d8', 'ID_76f9d249_ID_2ddaee6098', 'ID_17b42890_ID_91cae68878', 'ID_b39240b2_ID_942c005f0c', 'ID_098cadb4_ID_78028bc150', 'ID_96bb7e1d_ID_4b827bde69', 'ID_a2a225e1_ID_f423458ad8', 'ID_2cb66c95_ID_03e7a73ba7', 'ID_41de45c7_ID_3aae78a409', 'ID_7bc83188_ID_7878c102a3', 'ID_1cd2f5f1_ID_3257817c9f', 'ID_740cb910_ID_607672ce6c', 'ID_22bc2cd9_ID_328309234a', 'ID_d8fcc7c7_ID_998e2540f5', 'ID_a9ee6816_ID_4f4963a2d0', 'ID_3d9d3a41_ID_e9d4bdd5cb', 'ID_6377f72a_ID_e16fa0dc81', 'ID_524827d2_ID_4eb57ee89f', 'ID_5f5207a2_ID_043ce30aa0', 'ID_83902425_ID_a5ba853417', 'ID_8cd94fca_ID_3bf1d20dba', 'ID_ea57fe03_ID_e5a90bc477', 'ID_853843b9_ID_9a979e157a', 'ID_98b79de2_ID_df6067a299', 'ID_7e2e48bf_ID_8ac11863ce', 'ID_4ce3f2c5_ID_63f98d02f9', 'ID_db1b9b68_ID_f014ee3046', 'ID_3ed0d2cd_ID_07d1256292', 'ID_01f8f5b1_ID_aca5d39899', 'ID_9dc842f4_ID_aa7659f277', 'ID_216e7cc9_ID_1a80f7c47d', 'ID_2db7ee14_ID_f27e38fdfd', 'ID_a7056d12_ID_421ea624cd', 'ID

In [10]:
print(len(patient_data_hemo_pos["ID_8042a129_ID_6f7a65817e"]['pos2']))
print(len(patient_data_hemo_pos["ID_f309d4c5_ID_cc6347e8d8"]['pos2']))

32
48


In [40]:
import random
patient_list_hemo = list(patient_data_hemo_pos.keys())
patient_list_neg = list(patient_data_hemo_neg.keys())
size = 1200 # 400, 800, 1200
patient_list_hemo = random.sample(patient_list_hemo, size)
patient_list_neg = random.sample(patient_list_neg, size)

patient_data_hemo_pos_save = dict()
patient_data_hemo_neg_save = dict()
for patient in tqdm(patient_list_hemo):
    patient_data_hemo_pos_save[patient] = patient_data_hemo[patient]

for patient in tqdm(patient_list_neg):
    patient_data_hemo_neg_save[patient] = patient_data_hemo[patient]
hemo_path = f"patient_data_hemo_{size}.pickle"
neg_path = f"patient_data_neg_{size}.pickle"

with open(root+hemo_path,"wb") as f:
    pickle.dump(patient_data_hemo_pos_save, f, pickle.HIGHEST_PROTOCOL)
with open(root+neg_path,"wb") as f:
    pickle.dump(patient_data_hemo_neg_save, f, pickle.HIGHEST_PROTOCOL)

  0%|          | 0/1200 [00:00<?, ?it/s]

  0%|          | 0/1200 [00:00<?, ?it/s]

In [41]:
hemo_path = f"patient_data_hemo_{size}.pickle"
with open(root+hemo_path, "rb") as f:
    patient_data_hemo_1200 = pickle.load(f)