Chest X-Ray dataset

**About the data**:
- The dataset consists of chest X-ray images and segmented lung masks from two different sources: MCU and CHN.
- Some images do not have a corresponding mask (mostly in the Shenzhen/CHN set).






The dataset is available on Kaggle [here](https://www.kaggle.com/datasets/kmader/pulmonary-chest-xray-abnormalities).

Published in Quant Imaging Med Surg. 2014 Dec; 4(6): 475–477. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4256233/

In [None]:
import numpy as np 
import pandas as pd
from tqdm import tqdm
import os
%matplotlib inline
import matplotlib.pyplot as plt
# Dataset directories (paths follow the notebook's '../input' layout). Each path
# ends with '/', so file names are appended directly with '+' further below.
imgDir='../input/Lung Segmentation/CXR_png/'  # chest X-ray images
mskDir='../input/Lung Segmentation/masks/'  # segmentation masks
clinicalReadings='../input/Lung Segmentation/ClinicalReadings/'  # per-patient clinical reading text files


In [None]:
# Inspect the dataset: find the number of images, masks and clinical readings.
# Each directory is listed once and the counts are taken from those listings.
# Note: os.listdir returns names in arbitrary order.
imgNames=os.listdir(imgDir)
mskNames=os.listdir(mskDir)
readings=os.listdir(clinicalReadings)
numImgs=len(imgNames)
numMsks=len(mskNames)
numReadings=len(readings)
print(f'Images:{numImgs} Msks:{numMsks} Readings:{numReadings}')

In [None]:
# Inspect a random sample of num2see file names for naming inconsistencies.
# A fresh RandomState(2500) is built for each list so the sample is reproducible.
randPermImgs=np.random.RandomState(2500).permutation(len(imgNames)).reshape(-1,1)
randPermMsks=np.random.RandomState(2500).permutation(len(mskNames)).reshape(-1,1)

num2see=20
# Slicing the permutation caps the sample at the list length, so a directory with
# fewer than num2see files does not raise IndexError.
selImgNames=[imgNames[idx] for idx in randPermImgs[:num2see,0]]
selMskNames=[mskNames[idx] for idx in randPermMsks[:num2see,0]]

In [None]:
# Display the sampled image file names as the cell output.
selImgNames

In [None]:
selMskNames 
# Observation: mask names from the MCU source (MCUCXR_*) do not end in '_mask'.
# There are 704 masks but 800 images and readings, so some images have no mask;
# the images that do have a mask need to be identified and the rest filtered out.

In [None]:
def getImgsMsksCond(imgNames,mskNames,cond):
    """
    Return the image and mask file names that belong to one data source.

    inputs:
        imgNames (list of str): image file names
        mskNames (list of str): mask file names
        cond (str): source prefix, e.g. 'CHN' or 'MCU'
    returns:
        (splitImgs, splitMsks): the names from each list that start with cond,
        in their original order
    """
    # startswith matches a prefix of any length, not only 3-character ones.
    splitImgs=[name for name in imgNames if name.startswith(cond)]
    splitMsks=[name for name in mskNames if name.startswith(cond)]
    return splitImgs,splitMsks

# Partition the image and mask lists by data-source prefix.
splitCHNImgs,splitCHNMsks=getImgsMsksCond(imgNames,mskNames,'CHN')
splitMCUImgs,splitMCUMsks=getImgsMsksCond(imgNames,mskNames,'MCU')
print(
    f'CHNImgs:{len(splitCHNImgs)}, CHNMsks:{len(splitCHNMsks)}, '
    f'MCUImgs:{len(splitMCUImgs)}, MCUMsks:{len(splitMCUMsks)}'
)
# Note: some CHN images have no corresponding mask.

In [None]:
# Keep only the CHN images that have a mask. Names look like
# <source>_<id>_<label>[...].png, so field 1 of the '_'-split stem is the ID.
CHNImgIds=[name.split('.')[0].split('_')[1] for name in splitCHNImgs]
CHNMskIds=[name.split('.')[0].split('_')[1] for name in splitCHNMsks]
# Look IDs up in a set (O(1) per test) instead of scanning the list each time.
CHNMskIdSet=set(CHNMskIds)
CHNImgsWithMask=[name for name,imgId in zip(splitCHNImgs,CHNImgIds) if imgId in CHNMskIdSet]
print(f'{len(CHNImgsWithMask)} Image with mask {len(CHNMskIds)} mask Images')

In [None]:
def splitImgsByLabel(imgs):
    """
    Split file names by their label field and print the count for each label.

    The label is field 2 of the '_'-split stem (the text before the first '.').
    inputs: imgs (list of str): file names
    returns: (names labelled '0', names labelled '1'); names with any other
             label are left out of both lists
    """
    byLabel={'0':[],'1':[]}
    for name in imgs:
        label=name.split('.')[0].split('_')[2]
        if label in byLabel:
            byLabel[label].append(name)
    for label in ('0','1'):
        print(f'Label {label}')
        print(len(byLabel[label]))
    return(byLabel['0'],byLabel['1'])
# Label counts for each group: CHN images with a mask, CHN masks, MCU images, MCU masks.
print('CHNImgs')
CHNImgsLabels0,CHNImgsLabels1=splitImgsByLabel(CHNImgsWithMask)

print('CHNMasks')
CHNMskLabels0,CHNMskLabels1=splitImgsByLabel(splitCHNMsks)

print('MCUImgs')
MCUImgsLabels0,MCUImgsLabels1=splitImgsByLabel(splitMCUImgs)

print('MCUMsks')
MCUMsksLabels0,MCUMsksLabels1=splitImgsByLabel(splitMCUMsks)


In [None]:
def getSortEnums(myArr):
    """
    Pair each file name's position with its ID and order the pairs by ID.

    The ID is field 1 of the '_'-split stem (the text before the first '.').
    inputs: myArr (list of str): file names
    returns: list of (original_index, id) tuples ordered by id (string
             comparison); equal IDs keep their original relative order
    """
    pairs=[]
    for position,fileName in enumerate(myArr):
        stem=fileName.split('.')[0]
        pairs.append((position,stem.split('_')[1]))
    pairs.sort(key=lambda pair:pair[1])
    return(pairs)
def sortImgByMskIds(Imgs,Msks):
    """
    Sort the image and the mask file names, each independently, by ID.

    Pairing afterwards is positional: index i of the two returned lists refers
    to the same ID only if both inputs contain exactly the same set of IDs.

    inputs: Imgs (list of str): image file names; Msks (list of str): mask file names
    returns: (ImgsSort, MsksSort), each ordered by the ID that getSortEnums extracts
    """
    def _reorder(names):
        # getSortEnums yields (original_index, id) pairs already in ID order.
        return [names[position] for position,_ in getSortEnums(names)]
    return(_reorder(Imgs),_reorder(Msks))

def showOverlay(splitImgs,splitMsks,ind=5,colStart=300):
    """
    Display an image, its mask and an RGB overlay side by side in a new figure.

    The overlay puts the image in the red channel and the mask in the blue
    channel. Rows from the first all-zero image row downward are cropped, and
    only columns from ``colStart`` onward are shown.

    inputs:
    splitImgs: image file names (list of str), read from imgDir
    splitMsks: mask file names (list of str), read from mskDir; must be aligned
               with splitImgs by position
    ind: index of the image/mask pair to show, defaults to 5
    colStart: first column to display, defaults to 300
    """
    def _firstChannel(arr):
        # Multi-channel PNGs (RGB/RGBA) are reduced to their first channel so
        # that image and mask are both 2-D and can be stacked.
        return arr[:,:,0] if arr.ndim>2 else arr

    ImgRead=_firstChannel(plt.imread(os.path.join(imgDir,splitImgs[ind])))
    MskRead=_firstChannel(plt.imread(os.path.join(mskDir,splitMsks[ind])))
    OverlaySample=np.stack([ImgRead,np.zeros_like(ImgRead),MskRead],axis=2)
    fig,ax=plt.subplots(nrows=1,ncols=3,figsize=(21,7))

    # Crop at the first all-zero image row. A blank row 0 is not used as the
    # cut, because cropping there would leave nothing to display.
    blankRows=np.where(np.sum(ImgRead,axis=1)==0)[0]
    if len(blankRows)>0 and blankRows[0]>0:
        cut=blankRows[0]
        ImgRead=ImgRead[:cut,:]
        MskRead=MskRead[:cut,:]
        OverlaySample=OverlaySample[:cut,:,:]
    ax[2].imshow(OverlaySample[:,colStart:])
    ax[2].set_title('Overlay',fontsize=14,fontweight='bold')
    ax[1].imshow(MskRead[:,colStart:])
    ax[1].set_title('Mask',fontsize=14,fontweight='bold')
    ax[0].imshow(ImgRead[:,colStart:])
    ax[0].set_title('Image',fontsize=14,fontweight='bold')

# Align images and masks by ID, then show the same index from each source.
ind=7
splitMCUImgs,splitMCUMsks=sortImgByMskIds(splitMCUImgs,splitMCUMsks)
CHNImgsWithMask,splitCHNMsks=sortImgByMskIds(CHNImgsWithMask,splitCHNMsks)
# showOverlay creates its own figure (plt.subplots), so no plt.figure() call is
# needed; plt.suptitle then targets that newly created, current figure.
showOverlay(splitMCUImgs,splitMCUMsks,ind)
plt.suptitle(f'MCU Imgs Msk overlay Index:{ind}',fontsize=16,fontweight='bold')
showOverlay(CHNImgsWithMask,splitCHNMsks,ind)
plt.suptitle(f'CHN Imgs Msk overlay Index:{ind}',fontsize=16,fontweight='bold')



In [None]:
def sortReadingsIds(readings,cond):
    """
    Select the clinical-reading files of one data source, ordered by ID.

    inputs: readings (list of str): reading file names;
            cond (str): source prefix such as 'CHN' or 'MCU'
    returns: list of the names starting with cond, ordered by the ID field
             (field 1 of the '_'-split stem) via getSortEnums
    """
    # startswith matches a prefix of any length, not only 3-character ones.
    splitReadings=[name for name in readings if name.startswith(cond)]
    return([splitReadings[position] for position,_ in getSortEnums(splitReadings)])

splitReadingsCHN=sortReadingsIds(readings,'CHN')
splitReadingsMCU=sortReadingsIds(readings,'MCU')
# Keep the CHN readings whose own ID has a mask. The filter is applied to each
# reading's ID rather than by position in CHNImgIds, because CHNImgIds follows
# the unsorted directory listing while splitReadingsCHN is sorted by ID. The
# result stays in ID order, matching the ID-sorted CHNImgsWithMask.
CHNMskIdSet=set(CHNMskIds)
CHNReadingsWithMask=[name for name in splitReadingsCHN
                     if name.split('.')[0].split('_')[1] in CHNMskIdSet]
print(f'Original readings (CHN) {len(splitReadingsCHN)} Original readings (MCU) {len(splitReadingsMCU)}')
print(f'Readings with mask (CHN) {len(CHNReadingsWithMask)}')

In [None]:
def prettify(ax,fontsize=12,spineWidth=3.0):
    """
    Apply the notebook's common axis styling in place.

    Tick labels get size ``fontsize`` and bold weight, and all four spines are
    drawn ``spineWidth`` thick. Only font properties are changed: tick positions
    and the tick formatter stay under matplotlib's control, so labels cannot
    drift away from their ticks.

    inputs: ax (matplotlib Axes); fontsize (int, default 12);
            spineWidth (float, default 3.0)
    """
    ax.tick_params(axis='both',labelsize=fontsize)
    for label in list(ax.get_xticklabels())+list(ax.get_yticklabels()):
        label.set_fontweight('bold')
    for side in ('top','bottom','left','right'):
        ax.spines[side].set_linewidth(spineWidth)

    
def analyseReadings(readingList):
    """
    Count the lines of each clinical-reading file.

    inputs: readingList (list of str): file names inside clinicalReadings
    returns: list of int, the number of lines in each file, in input order
    """
    def _lineCount(fileName):
        with open(clinicalReadings+fileName,'r') as handle:
            return len(handle.readlines())
    return [_lineCount(fileName) for fileName in readingList]

def vizFindingDist(ax,dSet,dSetName='CHN'):
    """
    Plot a 20-bin histogram of reading line counts on ``ax`` and style it.

    inputs: ax (matplotlib Axes); dSet (list of int): line count per reading;
            dSetName (str): data-set name used in the title
    """
    labelFont={'fontsize':14,'fontweight':'bold'}
    ax.hist(dSet,20,linewidth=3.0,edgecolor='k')
    ax.set_title(f'{dSetName} reading lines',fontsize=16,fontweight='bold')
    ax.set_ylabel('Number of patients',**labelFont)
    ax.set_xlabel('Number of lines \n in finding',**labelFont)
    prettify(ax)
# Distribution of the number of lines per clinical reading, one panel per source.
fig,ax=plt.subplots(nrows=1,ncols=2,figsize=(14,7))
lenLinesCHN=analyseReadings(CHNReadingsWithMask)
lenLinesMCU=analyseReadings(splitReadingsMCU)
vizFindingDist(ax[0],lenLinesCHN,dSetName='CHN')
# Right panel: the MCU line counts.
vizFindingDist(ax[1],lenLinesMCU,dSetName='MCU')

In [None]:
# Parse each CHN clinical reading: the first line holds gender and age, the last
# line holds the finding.
genderListCHN=[]
ageListCHN=[]
findingsListCHN=[]
for readingName in CHNReadingsWithMask:
    with open(clinicalReadings+readingName,'r') as handle:
        dl=handle.readlines()
    # Tabs, newlines and spaces are used inconsistently: drop all spaces and the
    # surrounding tabs/newlines, then lower-case for consistent matching.
    stripped_string=dl[0].replace(' ','').strip('\t\n').lower()
    # The first character is used as the gender code ('m'/'f'), which assumes
    # the line starts with the gender word.
    genderListCHN.append(stripped_string[0])
    # str.strip removes a *set of characters* from both ends, not a word:
    # 'femal' (covering both 'female' and the 'femal' typo) strips any of
    # f/e/m/a/l, and 'male' strips any of m/a/l/e. This relies on the age text
    # not starting or ending with those letters; confirm for new data.
    genderChars='femal' if 'femal' in stripped_string else 'male'
    ageListCHN.append(stripped_string.strip(genderChars).strip(','))
    # Strip all surrounding whitespace. The previous chained single-character
    # strips could leave, e.g., a tab that sits before the trailing newline.
    findingsListCHN.append(dl[-1].strip())


In [None]:
# MCU readings are far more consistent than CHN (few typos), so fixed line
# positions are used: line 0 gives the gender and line 1 the age (each is the
# text after ':', with spaces removed and lower-cased); line 2 is the finding.
genderListMCU,ageListMCU,findingMCU=[],[],[]
for readingName in splitReadingsMCU:
    with open(clinicalReadings+readingName,'r') as handle:
        dl=handle.readlines()
        genderField=dl[0].strip('\n').replace(' ','').lower()
        genderListMCU.append(genderField.split(':')[1])
        ageField=dl[1].strip('\n').replace(' ','').lower()
        ageListMCU.append(ageField.split(':')[1])
        findingMCU.append(dl[2].strip('\n'))

In [None]:
# find overall gender distribution
def genderDistOverall(ax,genderList,dSetName='CHN'):
    """
    Plot the number of patients per gender code on ``ax``.

    Bars are drawn at explicit positions in the order 'm', 'f', then any other
    codes (e.g. 'o' in MCU) alphabetically, and each bar is labelled with its own
    code. The older approach (``hist`` on the raw strings, then relabelling the
    ticks with a sorted list) could put labels on the wrong bars, because
    matplotlib places string categories in order of first appearance.

    inputs: ax (matplotlib Axes); genderList (list of str): one code per
            patient; dSetName (str): data-set name used in the title
    """
    countByCode={}
    for code in genderList:
        countByCode[code]=countByCode.get(code,0)+1
    order=[code for code in ('m','f') if code in countByCode]
    order+=sorted(code for code in countByCode if code not in ('m','f'))
    positions=np.arange(len(order))
    ax.bar(positions,[countByCode[code] for code in order],edgecolor='k',linewidth=3.0)
    ax.set_xticks(positions)
    ax.set_title(f'Gender distribution {dSetName}',fontsize=16,fontweight='bold')
    ax.set_xlabel('Gender',fontsize=14,fontweight='bold')
    ax.set_ylabel('Number of patients',fontsize=14,fontweight='bold')
    prettify(ax)
    # Set the category labels after prettify, so they are the final tick text.
    ax.set_xticklabels(order)
# One gender-distribution panel per data source.
fig,ax=plt.subplots(nrows=1,ncols=2,figsize=(14,7))
for panelIdx,(panelGenders,panelName) in enumerate(((genderListCHN,'CHN'),(genderListMCU,'MCU'))):
    genderDistOverall(ax[panelIdx],panelGenders,dSetName=panelName)



In [None]:
#find gender distribution between normal and TB for both MCU and CHN
def getGenderDistribNormalTB(genderList,findingList,findingCond='normal'):
    """
    Split genders by whether the matching finding equals ``findingCond``.

    genderList[i] and findingList[i] describe the same patient.
    returns: (genders whose finding equals findingCond,
              genders of all other patients, treated as TB in this notebook)
    """
    matching,others=[],[]
    for idx,finding in enumerate(findingList):
        target=matching if finding==findingCond else others
        target.append(genderList[idx])
    return(matching,others)

def vizGenderDistribWithDisease(ax,genderListWithDis,dSetName='CHN',finding='normal'):
    """
    Plot the number of patients per gender code for one finding group on ``ax``.

    Bars are drawn at explicit positions in the order 'm', 'f', then any other
    codes alphabetically, each labelled with its own code (relabelling a ``hist``
    of strings could mislabel bars, since matplotlib orders string categories by
    first appearance).

    inputs: ax (matplotlib Axes); genderListWithDis (list of str): gender codes;
            dSetName (str): data-set name for the title;
            finding (str): finding group named in the title, default 'normal'
    """
    countByCode={}
    for code in genderListWithDis:
        countByCode[code]=countByCode.get(code,0)+1
    order=[code for code in ('m','f') if code in countByCode]
    order+=sorted(code for code in countByCode if code not in ('m','f'))
    positions=np.arange(len(order))
    ax.bar(positions,[countByCode[code] for code in order],edgecolor='k',linewidth=3.0)
    ax.set_xticks(positions)
    ax.set_title(f'Gender distribution {finding} {dSetName}',fontsize=16,fontweight='bold')
    ax.set_ylabel('Number of patients',fontsize=14,fontweight='bold')
    ax.set_xlabel('Gender',fontsize=14,fontweight='bold')
    prettify(ax)
    # Set the category labels after prettify, so they are the final tick text.
    ax.set_xticklabels(order)

genderDistribNormalCHN,genderDistribTBCHN=getGenderDistribNormalTB(genderListCHN,findingsListCHN,findingCond='normal')
genderDistribNormalMCU,genderDistribTBMCU=getGenderDistribNormalTB(genderListMCU,findingMCU,findingCond='normal')

# Left column: normal findings; right column: TB (every non-normal finding).
fig,ax=plt.subplots(nrows=2,ncols=2,figsize=(14,14))
vizGenderDistribWithDisease(ax[0,0],genderDistribNormalCHN,dSetName='CHN')
vizGenderDistribWithDisease(ax[0,1],genderDistribTBCHN,dSetName='CHN',finding='TB')

vizGenderDistribWithDisease(ax[1,0],genderDistribNormalMCU,dSetName='MCU')
vizGenderDistribWithDisease(ax[1,1],genderDistribTBMCU,dSetName='MCU',finding='TB')




In [None]:
def getOutliersInDset(ageList,dSetName='CHN'):
    """
    Split age strings into those not expressed in years and those that are.

    For 'CHN' an age counts as being in years when it contains 'yr'; any other
    data set uses the marker 'y'. Ages given in other units (e.g. months or days)
    are the outliers that the later age analyses exclude.

    returns: (ageListNotInYrs, ageListInYrs), both keeping the input order
    """
    yearMarker='yr' if dSetName=='CHN' else 'y'
    notInYears,inYears=[],[]
    for age in ageList:
        (inYears if yearMarker in age else notInYears).append(age)
    return(notInYears,inYears)
# Separate ages given in years from the outliers, per data set.
ageListNotInYrsCHN,ageListInYrsCHN=getOutliersInDset(ageListCHN,dSetName='CHN')
ageListNotInYrsMCU,ageListInYrsMCU=getOutliersInDset(ageListMCU,dSetName='MCU')

for dSetLabel,outliers in (('CHN',ageListNotInYrsCHN),('MCU',ageListNotInYrsMCU)):
    print(f'{len(outliers)} outliers in {dSetLabel}')

In [None]:
def vizAgeDist(ax,dSet,dSetName='CHN'):
    """
    Plot a 20-bin age histogram for one data set on ``ax`` and style it.

    inputs: ax (matplotlib Axes); dSet (list of int): ages in years;
            dSetName (str): data-set name used in the title
    """
    bold14={'fontsize':14,'fontweight':'bold'}
    ax.hist(dSet,20,edgecolor='k',linewidth=3.0)
    ax.set_title(f'Age distribution {dSetName}\nExcluding outliers',fontsize=16,fontweight='bold')
    ax.set_ylabel('Number of patients',**bold14)
    ax.set_xlabel('Age (Yrs)',**bold14)
    prettify(ax)
    
def getAgeListNum(ageListInYrsCHN,ageListInYrsMCU):
    """
    Convert year-age strings to ints and plot both age distributions.

    Each age string is cut at its first 'y' and the remaining prefix is parsed
    as an int. Opens a new 1x2 figure (CHN left, MCU right).
    returns: (CHN ages as int, MCU ages as int)
    """
    def _years(ageStrings):
        return [int(age.split('y')[0]) for age in ageStrings]
    chnYears=_years(ageListInYrsCHN)
    mcuYears=_years(ageListInYrsMCU)
    fig,ax=plt.subplots(nrows=1,ncols=2,figsize=(14,7))
    vizAgeDist(ax[0],chnYears,'CHN')
    vizAgeDist(ax[1],mcuYears,'MCU')
    return(chnYears,mcuYears)

# Note: this rebinds the age lists from strings to ints.
ageListInYrsCHN,ageListInYrsMCU=getAgeListNum(ageListInYrsCHN,ageListInYrsMCU)
    

In [None]:
# Keep the CHN genders of patients whose age is given in years ('yr' marker),
# matching ageListInYrsCHN; ages in other units (months/days) are excluded.
genderListCHNAgeInYrs=[gender for idx,gender in enumerate(genderListCHN) if 'yr' in ageListCHN[idx]]

def getAgeGenderDist(ageListInYrs,genderListAgeInYrs):
    """
    Split ages by gender code; ageListInYrs[i] belongs to genderListAgeInYrs[i].

    returns: (ages with gender 'm', ages with gender 'f'); other codes are dropped
    """
    males,females=[],[]
    for idx,age in enumerate(ageListInYrs):
        gender=genderListAgeInYrs[idx]
        if gender=='m':
            males.append(age)
        elif gender=='f':
            females.append(age)
    return(males,females)

def vizAgeGenderDist(ax,dSet,dSetName='CHN',gender='M'):
    """
    Plot a histogram of ages for one gender of one data set on ``ax``.

    inputs: ax (matplotlib Axes); dSet (list of int): ages in years;
            dSetName (str): data-set name for the title;
            gender (str): gender tag for the title, e.g. 'M' or 'F'
    """
    ax.hist(dSet,edgecolor='k',linewidth=3.0)
    ax.set_ylabel('Number of patients',fontsize=14,fontweight='bold')
    ax.set_xlabel('Age (Yrs)',fontsize=14,fontweight='bold')
    # A space separates the gender tag from the data-set name, giving
    # 'Age distribution (M) CHN' rather than '(M)CHN'.
    ax.set_title(f'Age distribution ({gender}) {dSetName}',fontsize=16,fontweight='bold')
    prettify(ax)

ageListCHNM,ageListCHNF=getAgeGenderDist(ageListInYrsCHN,genderListCHNAgeInYrs)
# ageListInYrsMCU holds only the MCU ages given in years, so restrict the MCU
# genders with the same 'y' criterion getOutliersInDset uses for MCU. Without
# this, the two lists fall out of step whenever an MCU age is an outlier.
genderListMCUAgeInYrs=[gender for idx,gender in enumerate(genderListMCU) if 'y' in ageListMCU[idx]]
ageListMCUM,ageListMCUF=getAgeGenderDist(ageListInYrsMCU,genderListMCUAgeInYrs)

fig,ax=plt.subplots(nrows=2,ncols=2,figsize=(14,14))
vizAgeGenderDist(ax[0,0],ageListCHNM,dSetName='CHN',gender='M')
vizAgeGenderDist(ax[0,1],ageListCHNF,dSetName='CHN',gender='F')
vizAgeGenderDist(ax[1,0],ageListMCUM,dSetName='MCU',gender='M')
vizAgeGenderDist(ax[1,1],ageListMCUF,dSetName='MCU',gender='F')



In [None]:
# Keep the CHN findings of patients whose age is given in years ('yr' marker),
# so they stay aligned with ageListInYrsCHN and genderListCHNAgeInYrs.
findingsCHNAgeInYrs=[finding for idx,finding in enumerate(findingsListCHN) if 'yr' in ageListCHN[idx]]
def getAgeGenderFindingDist(ageListInYrs,genderListAgeInYrs,findingsAgeInYrs):
    """
    Split ages by gender ('m'/'f') and by whether the finding is 'normal'.

    Index i of all three lists refers to the same patient; patients with any
    other gender code are dropped.
    returns: (male normal, female normal, male non-normal, female non-normal) ages
    """
    buckets={('m',True):[],('f',True):[],('m',False):[],('f',False):[]}
    for idx,age in enumerate(ageListInYrs):
        gender=genderListAgeInYrs[idx]
        if gender not in ('m','f'):
            continue
        buckets[(gender,findingsAgeInYrs[idx]=='normal')].append(age)
    return(buckets[('m',True)],buckets[('f',True)],buckets[('m',False)],buckets[('f',False)])


def vizAgeGenderFindingDist(ax,ageListMNormal,ageListFNormal,ageListMTB,ageListFTB,dSetName='CHN'):
    """
    Draw a new 2x2 figure of age histograms split by gender and finding.

    Rows are gender (top M, bottom F) and columns are finding (left Normal,
    right TB). The ``ax`` argument is not used: a fresh figure and axes grid
    are always created, so the caller's axes are left untouched.
    """
    fig,grid=plt.subplots(nrows=2,ncols=2,figsize=(14,14))
    panels=(
        (0,0,ageListMNormal,'Age distribution (M) Normal'),
        (0,1,ageListMTB,'Age distribution (M) TB'),
        (1,0,ageListFNormal,'Age distribution (F) Normal'),
        (1,1,ageListFTB,'Age distribution (F) TB'),
    )
    for row,col,ages,title in panels:
        panelAx=grid[row,col]
        panelAx.hist(ages,edgecolor='k',linewidth=3.0)
        panelAx.set_title(title,fontsize=16,fontweight='bold')
        # Only the bottom row gets an x label and only the left column a y label.
        if row==1:
            panelAx.set_xlabel('Age (Yrs)',fontsize=14,fontweight='bold')
        if col==0:
            panelAx.set_ylabel('Number of patients',fontsize=14,fontweight='bold')
        prettify(panelAx)
    plt.suptitle(f'Distribution of age and gender with and without TB {dSetName}',fontsize=18,fontweight='bold')

# CHN: the age, gender and finding lists are all restricted to ages given in years.
ageListCHNMNormal,ageListCHNFNormal,ageListCHNMTB,ageListCHNFTB=getAgeGenderFindingDist(
    ageListInYrsCHN,genderListCHNAgeInYrs,findingsCHNAgeInYrs)
vizAgeGenderFindingDist(ax,ageListCHNMNormal,ageListCHNFNormal,ageListCHNMTB,ageListCHNFTB,
                        dSetName='CHN')

In [None]:
# ageListInYrsMCU holds only the MCU ages given in years, so filter the gender
# and finding lists with the same 'y' criterion (as getOutliersInDset uses for
# MCU) to keep all three lists aligned by index.
genderListMCUAgeInYrs=[gender for idx,gender in enumerate(genderListMCU) if 'y' in ageListMCU[idx]]
findingsMCUAgeInYrs=[finding for idx,finding in enumerate(findingMCU) if 'y' in ageListMCU[idx]]
ageListMCUMNormal,ageListMCUFNormal,ageListMCUMTB,ageListMCUFTB=getAgeGenderFindingDist(ageListInYrsMCU,genderListMCUAgeInYrs,findingsMCUAgeInYrs)
vizAgeGenderFindingDist(ax,ageListMCUMNormal,ageListMCUFNormal,ageListMCUMTB,ageListMCUFTB,dSetName='MCU')

