In [1]:
import argparse

# Dataset locations. An explicit empty argument list is parsed so that argparse does
# not fall back to sys.argv (which, inside a Jupyter kernel, holds the kernel's flags).
parser = argparse.ArgumentParser(description='Build train/test partitions of chest X-ray datasets.')
parser.add_argument('--dataset_path', type=str, default='dataset_partition',
                    help='output directory for the generated .pkl partition files')
parser.add_argument('--covid_chestxray_path', type=str, default='data/covid-chestxray-dataset',
                    help='root of the covid-chestxray-dataset (contains metadata.csv and images/)')
parser.add_argument('--chest_xray_pneumonia', type=str, default='data/chest-xray-pneumonia',
                    help='root of the chest-xray-pneumonia dataset (contains chest_xray/{train,test}/...)')

args = parser.parse_args([])

In [48]:
import os
import pdb
import pickle
import random
import shutil

import pandas as pd
from sklearn.model_selection import train_test_split

# makedir
def make_dir(dirname, rm=False):
    """Ensure ``dirname`` exists as a directory (parents are created as needed).

    Args:
        dirname: Directory path to create.
        rm: If True and ``dirname`` already exists, delete it together with all of
            its contents and recreate it empty. If False, an existing path is left
            untouched.
    """
    if not os.path.exists(dirname):
        # exist_ok guards against another process creating it between the check and here.
        os.makedirs(dirname, exist_ok=True)
    elif rm:
        print('rm and mkdir ', dirname)
        # Requires `shutil` to be imported at module level (it previously was not,
        # so this branch raised NameError).
        shutil.rmtree(dirname)
        os.makedirs(dirname, exist_ok=True)
        
# Create the output directory for the partition pickles; rm=False keeps any existing files.
make_dir(dirname=args.dataset_path, rm=False)

## Binary dataset: COVID-19 vs. normal chest X-rays

In [40]:
def read_covid_dataset(covid_chestxray_path=None):
    """Load the PA-view COVID-19 records of the covid-chestxray-dataset.

    Reads ``metadata.csv`` from the dataset root. It keeps rows whose ``view`` is
    ``'PA'`` and whose ``finding`` contains ``'COVID-19'``. It rewrites ``filename``
    to the full path under ``images/`` and drops rows with duplicate filenames
    (first occurrence wins).

    Args:
        covid_chestxray_path: Dataset root directory. Defaults to
            ``args.covid_chestxray_path``.

    Returns:
        list[dict]: One dict per kept metadata row (all CSV columns as keys), with
        ``filename`` holding the full image path.
    """
    if covid_chestxray_path is None:
        covid_chestxray_path = args.covid_chestxray_path
    metadata_path = os.path.join(covid_chestxray_path, 'metadata.csv')
    images_dir = os.path.join(covid_chestxray_path, 'images')

    metadata = pd.read_csv(metadata_path)
    # Filter the PA view xray items
    pa_rows = metadata[metadata['view'] == 'PA']
    # Get the covid-19 cases. na=False: a missing 'finding' would otherwise yield NaN
    # and make the boolean mask invalid. regex=False: plain substring match.
    is_covid = pa_rows['finding'].str.contains('COVID-19', regex=False, na=False)
    # .copy() so the column assignment below edits an owned frame, not a view.
    covid = pa_rows[is_covid].copy()
    covid['filename'] = covid['filename'].map(lambda name: os.path.join(images_dir, name))
    # Remove duplicates
    covid = covid.drop_duplicates(subset=['filename'])
    return covid.to_dict(orient='records')

In [35]:
def read_chest_xray_pneumonia(chest_xray_pneumonia=None):
    """List the NORMAL train/test image files of the chest-xray-pneumonia dataset.

    Args:
        chest_xray_pneumonia: Dataset root directory. Defaults to
            ``args.chest_xray_pneumonia``.

    Returns:
        (train_files, test_files): Lists of file paths from
        ``chest_xray/train/NORMAL`` and ``chest_xray/test/NORMAL``, sorted by file
        name. Sub-directories are skipped.
    """
    if chest_xray_pneumonia is None:
        chest_xray_pneumonia = args.chest_xray_pneumonia

    def get_files(dirpath):
        # sorted(): os.listdir order is filesystem-dependent; a fixed order keeps
        # downstream random sampling reproducible under a fixed seed.
        return [os.path.join(dirpath, name) for name in sorted(os.listdir(dirpath))
                if os.path.isfile(os.path.join(dirpath, name))]

    images_dir_train = os.path.join(chest_xray_pneumonia, 'chest_xray/train/NORMAL')
    images_dir_test = os.path.join(chest_xray_pneumonia, 'chest_xray/test/NORMAL')
    return get_files(images_dir_train), get_files(images_dir_test)

In [77]:
def _labelled_rows(filenames, label):
    """Wrap each filename into a ``{'filename': ..., 'label': label}`` dict."""
    return [{'filename': filename, 'label': label} for filename in filenames]


def create_dataset(seed=None):
    """Build a balanced binary COVID-19 vs. normal train/test split.

    COVID-19 records are split 80/20 into train/test. For each split, the same
    number of normal images is sampled without replacement from the matching
    NORMAL folder of the chest-xray-pneumonia dataset.

    Args:
        seed: Optional int. When given, it is passed as ``random_state`` to
            ``train_test_split`` and seeds a private ``random.Random`` for the
            sampling, making the split reproducible. ``None`` (default) keeps the
            previous behaviour of using the global random state.

    Returns:
        (train, test): Lists of ``{'filename': str, 'label': int}`` dicts, where
        label 1 is COVID-19 and 0 is normal. Positives come first, then negatives.

    Raises:
        ValueError: From ``random.sample`` if a NORMAL folder holds fewer images
            than there are COVID-19 images in the corresponding split.
    """
    rng = random if seed is None else random.Random(seed)
    positive_all = read_covid_dataset()
    negative_train, negative_test = read_chest_xray_pneumonia()
    positive_train, positive_test = train_test_split(
        positive_all, train_size=0.8, shuffle=True, random_state=seed)
    negative_sample_train = rng.sample(negative_train, len(positive_train))
    negative_sample_test = rng.sample(negative_test, len(positive_test))

    train = (_labelled_rows([item['filename'] for item in positive_train], 1)
             + _labelled_rows(negative_sample_train, 0))
    test = (_labelled_rows([item['filename'] for item in positive_test], 1)
            + _labelled_rows(negative_sample_test, 0))
    return train, test

def test_pkl(trainfname='train.pkl', testfname='test.pkl', n=10):
    """Display a random preview of the pickled train and test splits.

    Each pickle is loaded from ``args.dataset_path``. The loaded list is shuffled,
    and its first ``n`` entries are shown with IPython's ``display``.

    Args:
        trainfname: Train pickle file name inside ``args.dataset_path``.
        testfname: Test pickle file name inside ``args.dataset_path``.
        n: Number of rows to preview per split (default 10).
    """
    # NOTE: pickle.load can execute arbitrary code; only point this at files this
    # notebook generated itself.
    for fname in (trainfname, testfname):
        with open(os.path.join(args.dataset_path, fname), 'rb') as f:
            rows = pickle.load(f)
        random.shuffle(rows)
        display(rows[:n])

def generate_dataset_files():
    """Build the binary COVID-19/normal split and pickle it.

    Writes ``train.pkl`` and ``test.pkl`` (the lists returned by
    ``create_dataset``) into ``args.dataset_path``.
    """
    train, test = create_dataset()
    for fname, rows in (('train.pkl', train), ('test.pkl', test)):
        # The context manager guarantees the file is flushed and closed.
        with open(os.path.join(args.dataset_path, fname), 'wb') as f:
            pickle.dump(rows, f)
    

In [56]:
# Set REGENERATE = True to rebuild the binary split (sampling is random, so the split changes).
# It is also built when the pickles are missing, so Restart & Run All works on a fresh checkout.
REGENERATE = False
_binary_pickles = [os.path.join(args.dataset_path, fname) for fname in ('train.pkl', 'test.pkl')]
if REGENERATE or not all(os.path.exists(path) for path in _binary_pickles):
    generate_dataset_files()
test_pkl()

[{'filename': 'data/covid-chestxray-dataset/images/01E392EE-69F9-4E33-BFCE-E5C968654078.jpeg',
  'label': 1},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/train/NORMAL/IM-0676-0001.jpeg',
  'label': 0},
 {'filename': 'data/covid-chestxray-dataset/images/F2DE909F-E19C-4900-92F5-8F435B031AC6.jpeg',
  'label': 1},
 {'filename': 'data/covid-chestxray-dataset/images/ryct.2020200034.fig5-day4.jpeg',
  'label': 1},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/train/NORMAL/NORMAL2-IM-0855-0001.jpeg',
  'label': 0},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/train/NORMAL/IM-0497-0001-0002.jpeg',
  'label': 0},
 {'filename': 'data/covid-chestxray-dataset/images/93FE0BB1-022D-4F24-9727-987A07975FFB.jpeg',
  'label': 1},
 {'filename': 'data/covid-chestxray-dataset/images/ciaa199.pdf-001-b.png',
  'label': 1},
 {'filename': 'data/covid-chestxray-dataset/images/kjr-21-e24-g002-l-a.jpg',
  'label': 1},
 {'filename': 'data/covid-chestxray-dataset/images/radiol.2020200490.fig3.jp

[{'filename': 'data/covid-chestxray-dataset/images/all14238-fig-0001-m-b.jpg',
  'label': 1},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/test/NORMAL/NORMAL2-IM-0300-0001.jpeg',
  'label': 0},
 {'filename': 'data/covid-chestxray-dataset/images/31BA3780-2323-493F-8AED-62081B9C383B.jpeg',
  'label': 1},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/test/NORMAL/NORMAL2-IM-0290-0001.jpeg',
  'label': 0},
 {'filename': 'data/covid-chestxray-dataset/images/1-s2.0-S1684118220300682-main.pdf-003-b2.png',
  'label': 1},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/test/NORMAL/NORMAL2-IM-0238-0001.jpeg',
  'label': 0},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/test/NORMAL/IM-0037-0001.jpeg',
  'label': 0},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/test/NORMAL/NORMAL2-IM-0288-0001.jpeg',
  'label': 0},
 {'filename': 'data/covid-chestxray-dataset/images/auntminnie-a-2020_01_28_23_51_6665_2020_01_28_Vietnam_coronavirus.jpeg',
  'label': 1},
 {'filename': 'da

## Multilabel dataset generation (covid / normal / viral / bacterial)

In [69]:
def read_chest_xray_pneumonia_full(chest_xray_pneumonia=None):
    """List NORMAL and PNEUMONIA image files for both splits of chest-xray-pneumonia.

    Args:
        chest_xray_pneumonia: Dataset root directory. Defaults to
            ``args.chest_xray_pneumonia``.

    Returns:
        tuple: Four lists of file paths, each sorted by file name:
        (train NORMAL, test NORMAL, train PNEUMONIA, test PNEUMONIA).
        Sub-directories are skipped.
    """
    if chest_xray_pneumonia is None:
        chest_xray_pneumonia = args.chest_xray_pneumonia

    def get_files(dirpath):
        # sorted(): os.listdir order is filesystem-dependent; a fixed order keeps
        # downstream random sampling reproducible under a fixed seed.
        return [os.path.join(dirpath, name) for name in sorted(os.listdir(dirpath))
                if os.path.isfile(os.path.join(dirpath, name))]

    subdirs = ('chest_xray/train/NORMAL', 'chest_xray/test/NORMAL',
               'chest_xray/train/PNEUMONIA', 'chest_xray/test/PNEUMONIA')
    return tuple(get_files(os.path.join(chest_xray_pneumonia, sub)) for sub in subdirs)

In [74]:
def create_multilabel_dataset(seed=None):
    """Build and pickle a multilabel (covid / normal / viral / bacterial) train/test split.

    COVID-19 records from ``read_covid_dataset`` are split 80/20 into train/test and
    labelled ``covid=1, viral=1``. For each split, up to ``2 * len(covid_split)``
    images are then added from each of three pools of the chest-xray-pneumonia
    dataset: normal, viral pneumonia and bacterial pneumonia.

    - Normal images are drawn with ``random.sample``.
    - Pneumonia images are taken from a shuffled pool. They may be fewer than
      requested if the pool runs short.

    Both splits are pickled as lists of ``{'filename': str, 'label': dict}`` to
    ``train_multilabel.pkl`` and ``test_multilabel.pkl`` in ``args.dataset_path``.

    Args:
        seed: Optional int. When given, it is passed as ``random_state`` to
            ``train_test_split`` and seeds a private ``random.Random`` for all
            sampling/shuffling, making the output reproducible. ``None`` (default)
            keeps the previous behaviour of using the global random state.

    Raises:
        ValueError: From ``random.sample`` if a NORMAL folder holds fewer than
            ``2 * len(covid_split)`` images.
    """
    rng = random if seed is None else random.Random(seed)
    label_keys = ['covid', 'normal', 'viral', 'bacterial']

    covid_all = read_covid_dataset()
    (train_files_normal, test_files_normal,
     train_files_pneumonia, test_files_pneumonia) = read_chest_xray_pneumonia_full()

    covid_train, covid_test = train_test_split(
        covid_all, train_size=0.8, shuffle=True, random_state=seed)

    # Take twice as many images of every non-COVID class as there are COVID images;
    # this might help the model pick up COVID-specific features.
    n_train = 2 * len(covid_train)
    n_test = 2 * len(covid_test)

    normal_sample_train = rng.sample(train_files_normal, n_train)
    normal_sample_test = rng.sample(test_files_normal, n_test)

    # Shuffle so the prefixes taken by pick_pneumonia() are random subsets.
    rng.shuffle(train_files_pneumonia)
    rng.shuffle(test_files_pneumonia)

    def pick_pneumonia(files, kind, limit):
        # Match on the file name only: the dataset root path itself could contain
        # 'virus' or 'bacteria' and would otherwise match every file.
        return [f for f in files if kind in os.path.basename(f)][:limit]

    def add_data_row(filenames, output_list, positive_keys):
        # One row per file; every label key starts at 0, positive_keys are set to 1.
        for filename in filenames:
            label = dict.fromkeys(label_keys, 0)
            for key in positive_keys:
                label[key] = 1
            output_list.append({'filename': filename, 'label': label})

    train = []
    test = []
    add_data_row([item['filename'] for item in covid_train], train, ['covid', 'viral'])
    add_data_row([item['filename'] for item in covid_test], test, ['covid', 'viral'])
    add_data_row(normal_sample_train, train, ['normal'])
    add_data_row(normal_sample_test, test, ['normal'])
    add_data_row(pick_pneumonia(train_files_pneumonia, 'virus', n_train), train, ['viral'])
    add_data_row(pick_pneumonia(train_files_pneumonia, 'bacteria', n_train), train, ['bacterial'])
    add_data_row(pick_pneumonia(test_files_pneumonia, 'virus', n_test), test, ['viral'])
    add_data_row(pick_pneumonia(test_files_pneumonia, 'bacteria', n_test), test, ['bacterial'])

    for fname, rows in (('train_multilabel.pkl', train), ('test_multilabel.pkl', test)):
        # The context manager guarantees the file is flushed and closed.
        with open(os.path.join(args.dataset_path, fname), 'wb') as f:
            pickle.dump(rows, f)

In [79]:
# Rebuild the multilabel pickles (a fresh random split on every run), then preview them.
create_multilabel_dataset()
test_pkl(testfname='test_multilabel.pkl', trainfname='train_multilabel.pkl')

[{'filename': 'data/chest-xray-pneumonia/chest_xray/train/NORMAL/IM-0560-0001.jpeg',
  'label': {'covid': 0, 'normal': 1, 'viral': 0, 'bacterial': 0}},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/train/NORMAL/NORMAL2-IM-0530-0001.jpeg',
  'label': {'covid': 0, 'normal': 1, 'viral': 0, 'bacterial': 0}},
 {'filename': 'data/covid-chestxray-dataset/images/6CB4EFC6-68FA-4CD5-940C-BEFA8DAFE9A7.jpeg',
  'label': {'covid': 1, 'normal': 0, 'viral': 1, 'bacterial': 0}},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/train/PNEUMONIA/person569_bacteria_2362.jpeg',
  'label': {'covid': 0, 'normal': 0, 'viral': 0, 'bacterial': 1}},
 {'filename': 'data/covid-chestxray-dataset/images/8FDE8DBA-CFBD-4B4C-B1A4-6F36A93B7E87.jpeg',
  'label': {'covid': 1, 'normal': 0, 'viral': 1, 'bacterial': 0}},
 {'filename': 'data/covid-chestxray-dataset/images/SARS-10.1148rg.242035193-g04mr34g09a-Fig9a-day17.jpeg',
  'label': {'covid': 1, 'normal': 0, 'viral': 1, 'bacterial': 0}},
 {'filename': 'data/ches

[{'filename': 'data/chest-xray-pneumonia/chest_xray/test/PNEUMONIA/person91_bacteria_445.jpeg',
  'label': {'covid': 0, 'normal': 0, 'viral': 0, 'bacterial': 1}},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/test/PNEUMONIA/person1668_virus_2882.jpeg',
  'label': {'covid': 0, 'normal': 0, 'viral': 1, 'bacterial': 0}},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/test/PNEUMONIA/person1680_virus_2897.jpeg',
  'label': {'covid': 0, 'normal': 0, 'viral': 1, 'bacterial': 0}},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/test/NORMAL/IM-0105-0001.jpeg',
  'label': {'covid': 0, 'normal': 1, 'viral': 0, 'bacterial': 0}},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/test/NORMAL/NORMAL2-IM-0348-0001.jpeg',
  'label': {'covid': 0, 'normal': 1, 'viral': 0, 'bacterial': 0}},
 {'filename': 'data/covid-chestxray-dataset/images/nejmc2001573_f1b.jpeg',
  'label': {'covid': 1, 'normal': 0, 'viral': 1, 'bacterial': 0}},
 {'filename': 'data/chest-xray-pneumonia/chest_xray/test/PN