In [1]:
import os
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
import random
# Seed both the stdlib and the legacy global NumPy RNG so that the random
# patient selection performed further down is repeatable.
seed = 0
for _seed_rng in (random.seed, np.random.seed):
    _seed_rng(seed)


# Generate the BenchReAD split files (valid / train_labeled / train_unlabeled / test CSVs) from OCT_2017

In [None]:
oct_root = './data/OCT_2017'

splits = ["train", "val", "test"]
classes = ["CNV", "DME", "DRUSEN", "NORMAL"]

# Maps "<split>_<class>_images" -> list of image file names found in
# <oct_root>/<split>/<class>.
# - The listing is sorted so that it (and every seeded random draw that
#   consumes it downstream) does not depend on the filesystem's os.listdir order.
# - Dot-files (e.g. .DS_Store) are skipped: they have no "-"-separated patient id
#   and would make split("-")[1] raise IndexError.
dict_OCT_2017 = {}
for split in splits:
    for cur_class in classes:
        key = f"{split}_{cur_class}_images"
        class_dir = os.path.join(oct_root, split, cur_class)
        fnames = sorted(f for f in os.listdir(class_dir) if not f.startswith("."))
        dict_OCT_2017[key] = fnames
        # Report number of images and number of distinct patient ids
        # (the second "-"-separated token of each file name).
        n_patients = len({f.split("-")[1] for f in fnames})
        print(key, len(fnames), n_patients)


In [None]:
# Patient-level split of the OCT_2017 `train` folder, done independently per class:
#   * valid     - N_VALID_PATIENTS random patients, ONE random image per patient
#   * labeled   - 1/LABELED_DENOMINATOR of the remaining patients, all their images
#   * unlabeled - every other remaining patient, all their images
# Splitting on the patient id (second "-"-separated token of the file name)
# keeps all scans of one patient inside a single subset.
N_VALID_PATIENTS = 50
LABELED_DENOMINATOR = 3


def _patient_id(fname):
    """Return the patient id of an OCT_2017 image name (its second '-'-separated token)."""
    return fname.split("-")[1]


def _extend_subset(fnames_out, splits_out, labels_out, abnormal_out, fnames, label):
    """Append one row per file in `fnames` to the four parallel column lists.

    Every row is tagged with split "train" (all subsets come from the OCT_2017
    train folder). The abnormal flag is 0 for the NORMAL class and 1 otherwise.
    """
    n = len(fnames)
    fnames_out.extend(fnames)
    splits_out.extend(["train"] * n)
    labels_out.extend([label] * n)
    abnormal_out.extend([0 if label == "NORMAL" else 1] * n)


train_images_valid = []
train_images_valid_splits = []
train_images_valid_labels = []
train_images_valid_abnormal = []

train_images_labeled = []
train_images_labeled_splits = []
train_images_labeled_labels = []
train_images_labeled_abnormal = []

train_images_unlabeled = []
train_images_unlabeled_splits = []
train_images_unlabeled_labels = []
train_images_unlabeled_abnormal = []

for each_class in classes:
    # Sort both lists so the seeded draws below are reproducible.
    # - list(set(...)) order would vary between Python processes
    #   (string-hash randomisation).
    # - os.listdir order varies between filesystems.
    class_images = sorted(dict_OCT_2017[f"train_{each_class}_images"])
    patient_ids = sorted({_patient_id(f) for f in class_images})

    # .tolist() yields plain str; sets make the membership tests below O(1).
    patient_ids_valid = set(
        np.random.choice(patient_ids, N_VALID_PATIENTS, replace=False).tolist()
    )
    remaining_ids = [p for p in patient_ids if p not in patient_ids_valid]
    patient_ids_labeled = set(
        np.random.choice(
            remaining_ids, len(remaining_ids) // LABELED_DENOMINATOR, replace=False
        ).tolist()
    )
    patient_ids_unlabeled = set(remaining_ids) - patient_ids_labeled

    # Validation: group images by patient, in order of first appearance in
    # the sorted file list, then keep one random image per patient.
    valid_groups = {}
    for fname in class_images:
        pid = _patient_id(fname)
        if pid in patient_ids_valid:
            valid_groups.setdefault(pid, []).append(fname)
    cur_train_images_valid = [str(np.random.choice(group)) for group in valid_groups.values()]

    cur_train_images_labeled = [f for f in class_images if _patient_id(f) in patient_ids_labeled]
    cur_train_images_unlabeled = [f for f in class_images if _patient_id(f) in patient_ids_unlabeled]

    _extend_subset(train_images_valid, train_images_valid_splits,
                   train_images_valid_labels, train_images_valid_abnormal,
                   cur_train_images_valid, each_class)
    _extend_subset(train_images_labeled, train_images_labeled_splits,
                   train_images_labeled_labels, train_images_labeled_abnormal,
                   cur_train_images_labeled, each_class)
    _extend_subset(train_images_unlabeled, train_images_unlabeled_splits,
                   train_images_unlabeled_labels, train_images_unlabeled_abnormal,
                   cur_train_images_unlabeled, each_class)

    print(f"{each_class}: valid={len(cur_train_images_valid)} "
          f"labeled={len(cur_train_images_labeled)} "
          f"unlabeled={len(cur_train_images_unlabeled)}")

In [32]:
# Write the three train-derived subsets to <oct_root>/BenchReAD/*.csv.
# Each row holds one image: file name, source split folder, class label and
# the abnormal flag.
bench_dir = os.path.join(oct_root, "BenchReAD")
# DataFrame.to_csv does not create missing parent directories.
os.makedirs(bench_dir, exist_ok=True)

oct_train_dict_valid = pd.DataFrame({"fnames": train_images_valid, "split": train_images_valid_splits, "labels": train_images_valid_labels, "abnormal": train_images_valid_abnormal})
oct_train_dict_labeled = pd.DataFrame({"fnames": train_images_labeled, "split": train_images_labeled_splits, "labels": train_images_labeled_labels, "abnormal": train_images_labeled_abnormal})
oct_train_dict_unlabeled = pd.DataFrame({"fnames": train_images_unlabeled, "split": train_images_unlabeled_splits, "labels": train_images_unlabeled_labels, "abnormal": train_images_unlabeled_abnormal})

for subset_frame, csv_name in (
    (oct_train_dict_valid, "valid.csv"),
    (oct_train_dict_labeled, "train_labeled.csv"),
    (oct_train_dict_unlabeled, "train_unlabeled.csv"),
):
    subset_frame.to_csv(os.path.join(bench_dir, csv_name), index=False)

In [None]:
# Evaluation set: every image from the original OCT_2017 `val` and `test`
# folders goes into BenchReAD/test.csv. The `split` column records which
# folder each image came from.
test_images = []
test_images_splits = []
test_images_labels = []
test_images_abnormal = []

for each_class in classes:
    abnormal_flag = 0 if each_class == "NORMAL" else 1
    # Per class, val images come first, then test images.
    for source_split in ("val", "test"):
        # Sorted so the CSV row order does not depend on os.listdir order.
        cur_images = sorted(dict_OCT_2017[f"{source_split}_{each_class}_images"])
        n_images = len(cur_images)
        test_images.extend(cur_images)
        test_images_splits.extend([source_split] * n_images)
        test_images_labels.extend([each_class] * n_images)
        test_images_abnormal.extend([abnormal_flag] * n_images)
    print(each_class, "cumulative test images:", len(test_images))

bench_dir = os.path.join(oct_root, "BenchReAD")
# DataFrame.to_csv does not create missing parent directories.
os.makedirs(bench_dir, exist_ok=True)
oct_test_dict = pd.DataFrame({"fnames": test_images, "split": test_images_splits, "labels": test_images_labels, "abnormal": test_images_abnormal})
oct_test_dict.to_csv(os.path.join(bench_dir, "test.csv"), index=False)
