# In [5]:
'''
import necessary packages
'''

import os
import sys
import glob
import cv2
import re
import imageio
import itertools

import numpy as np
import matplotlib.pyplot as plt

from scipy import ndimage

from skimage import io, color
from skimage.io import imshow, imsave
from skimage.color import rgb2gray
from skimage.transform import rescale, resize, downscale_local_mean
from skimage.segmentation import clear_border
from skimage.morphology import ball, disk, remove_small_objects
from skimage.morphology import erosion, dilation, closing, reconstruction
from skimage.morphology import binary_erosion, binary_dilation, binary_closing, binary_opening
from skimage.measure import label, regionprops, perimeter
from skimage.filters import roberts, sobel, threshold_otsu


# In [1]:
'''
define lung segmentation and image viewer functions
'''

def segment_lungs(im, dilation_radius=10):
    """
    Build a binary lung mask for a single 2D CT slice.

    Input: 2D lung CT image in the form of an array
    Output: boolean mask for the lung CT, same shape as ``im``

    Parameters
    ----------
    im : 2D array
        Grayscale CT slice. Negative values are clipped to 0 before
        thresholding.
    dilation_radius : int, optional
        Radius of the disk structuring element used to grow the kept
        regions back out. Defaults to 10, the value that was previously
        hard-coded.

    Returns
    -------
    2D bool array
        Hole-filled mask built from the (at most) two largest dark regions
        that do not touch the image border.
    """
    ##### BINARIZE IMAGE
    # clip anything < 0 first (some images contain negative values)
    im_b = np.where(im < 0, 0, im)
    # Otsu threshold, then invert so the dark lung fields become foreground
    thresh = threshold_otsu(im_b)
    binary = np.invert(im_b > thresh)

    ##### REMOVE FOREGROUND OBJECTS TOUCHING THE IMAGE BORDER
    # (e.g. the air surrounding the body)
    cleared = clear_border(binary)

    ##### LABEL DISTINCT BODIES IN IMAGE
    label_image = label(cleared)

    ##### KEEP TWO LARGEST AREAS
    regions = regionprops(label_image)
    if len(regions) > 2:
        # Second-largest area. Regions strictly smaller than it are dropped;
        # regions tied with it are kept, so more than two may survive a tie.
        cutoff = sorted(r.area for r in regions)[-2]
        small_labels = [r.label for r in regions if r.area < cutoff]
        # one vectorized pass instead of zeroing pixels one at a time
        label_image[np.isin(label_image, small_labels)] = 0
    binary = label_image > 0

    ##### DILATION TO BRING BACK LUNG INFORMATION
    # (an erosion with disk(4) before this step was tried and left disabled)
    binary = binary_dilation(binary, disk(dilation_radius))

    ##### FILL SMALL HOLES
    # fill everything enclosed by the Sobel edges of the dilated mask
    edges = sobel(binary)
    binary = ndimage.binary_fill_holes(edges)

    return binary

'''
import images from directory, convert to grayscale, append to list
'''
def import_images(dir_name, images, image_names):
    """
    Recursively load every image file under ``dir_name`` as grayscale.

    Files are matched by extension (jpg, jpeg, png, bmp, tiff). Matching is
    case-sensitive, so e.g. 'scan.PNG' is skipped. Directories and files are
    visited in sorted order, so the resulting lists are reproducible across
    machines and filesystems.

    Parameters
    ----------
    dir_name : str
        Root directory to search. A missing directory yields no images.
    images : list
        Extended in place with one grayscale float array per file.
    image_names : list
        Extended in place with the bare file name of each image. It is kept
        index-aligned with ``images``.
    """
    # raw string: "\." in a plain literal is an invalid escape sequence
    pattern = re.compile(r"\.(jpg|jpeg|png|bmp|tiff)$")
    for root, dirnames, filenames in os.walk(dir_name):
        # os.walk order is filesystem-dependent; sorting dirnames in place
        # also fixes the order in which subdirectories are descended into
        dirnames.sort()
        for filename in sorted(filenames):
            if not pattern.search(filename):
                continue
            filepath = os.path.join(root, filename)

            ## read image first, so a failed read cannot leave a name
            ## without a matching image
            image = io.imread(fname=filepath)
            image_names.append(filename)
            images.append(_to_grayscale(image))


def _to_grayscale(image):
    """Convert a 2D (already gray), RGB or RGBA image array to grayscale float."""
    from skimage.util import img_as_float

    if image.ndim == 2:
        # rgb2gray rejects single-channel input in current scikit-image
        return img_as_float(image)
    if image.ndim == 3 and image.shape[-1] == 4:
        # rgb2gray expects 3 channels; drop alpha as scikit-image recommends
        image = color.rgba2rgb(image)
    return rgb2gray(image)

'''
establish function to create lung masks for all images in collection
'''

def create_masks(images, image_masks):
    """
    Segment every image and append its lung mask to ``image_masks``.

    ``image_masks`` is extended in place in the same order as ``images``, so
    ``image_masks[i]`` is the mask of ``images[i]``. Nothing is returned.
    """
    image_masks.extend(segment_lungs(img) for img in images)

'''
establish function to save all lung masks to file
'''
def save_masks(dir_out, image_masks, image_names):
    """
    Write each boolean lung mask to ``dir_out`` as an 8-bit (0/255) image.

    Parameters
    ----------
    dir_out : str
        Destination directory. It is created if missing, and a trailing
        path separator is optional. An empty string writes to the current
        working directory.
    image_masks : list of bool arrays
        Masks to save.
    image_names : list of str
        Output file name for each mask, paired with ``image_masks`` by
        position. The extension selects the format cv2 writes.

    Raises
    ------
    ValueError
        If there are fewer names than masks.
    OSError
        If cv2 fails to write a file. ``cv2.imwrite`` reports failure only
        through its return value, so it is checked explicitly here.
    """
    if len(image_names) < len(image_masks):
        raise ValueError(
            'got {} masks but only {} names'.format(len(image_masks), len(image_names)))

    if dir_out:
        os.makedirs(dir_out, exist_ok=True)

    for img_mask, img_name in zip(image_masks, image_names):
        # define destination filename
        filename = os.path.join(dir_out, img_name)

        # cv2 cannot write bool arrays: map False/True to 0/255 as uint8
        mask_u8 = img_mask.astype(np.uint8) * 255

        # write to file
        if not cv2.imwrite(filename, mask_u8):
            raise OSError('cv2.imwrite failed for {!r}'.format(filename))

'''
create function to display multiple images in a single plot
'''

def show_images(images, cols = 4, titles = None):
    """Display a list of images in a single figure with matplotlib.

    Parameters
    ---------
    images: List of np.arrays compatible with plt.imshow. 2D arrays are
            drawn with a gray colormap.

    cols (Default = 4): Number of columns in figure (number of rows is
                        set to ceil(n_images / cols)).

    titles: List of titles corresponding to each image. Must have
            the same length as images. Defaults to 'Image (1)', 'Image (2)', ...

    Raises ValueError if cols < 1 or titles has the wrong length.
    Does nothing when images is empty.
    """
    n_images = len(images)
    if titles is not None and len(titles) != n_images:
        raise ValueError(
            'got {} images but {} titles'.format(n_images, len(titles)))
    if cols < 1:
        raise ValueError('cols must be >= 1, got {}'.format(cols))
    if n_images == 0:
        # nothing to draw; a zero-size figure would make matplotlib raise
        return
    if titles is None:
        titles = ['Image (%d)' % i for i in range(1, n_images + 1)]

    # integer ceiling division: matplotlib requires integer grid dimensions
    rows = -(-n_images // cols)
    fig = plt.figure()
    for n, (image, title) in enumerate(zip(images, titles)):
        ax = fig.add_subplot(rows, cols, n + 1)
        # per-axes colormap instead of plt.gray(), which changes global state
        ax.imshow(image, cmap='gray' if image.ndim == 2 else None)
        ax.set_title(title, fontsize = 100)
    # scale the canvas with the image count so the large titles stay legible
    fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
    plt.show()
    


# In [15]:
'''
import images from the dataset directory and populate the image and name lists
'''
# define source directory for image dataset
img_dir = '../../dataset/images'

# establish empty lists for images and image names
images = []
image_names = []

# import images and populate lists (image names and arrays)
import_images(img_dir, images, image_names)

# fail early: os.walk silently yields nothing for a missing directory, so a
# wrong img_dir would otherwise only show up later as empty masks and plots
if not images:
    raise FileNotFoundError('no images found under {!r}'.format(img_dir))



# In [16]:
'''
create lung masks on all images and add to list
'''
# one mask per entry of `images`, in the same order
image_masks = list()
create_masks(images=images, image_masks=image_masks)



# In [19]:
'''
save lung masks to file
'''
# masks are written under this directory, reusing each source image's file name
dir_out = '../../dataset/masks/'
save_masks(dir_out=dir_out, image_masks=image_masks, image_names=image_names)


# In [None]:
'''
plot 20 images
'''
# entries at indices 40..59 (fewer if the dataset is smaller)
img_set, name_set = images[40:60], image_names[40:60]
show_images(img_set, titles=name_set)

# In [None]:
'''
plot 20 image masks
'''
# masks at indices 40..59 (fewer if the dataset is smaller)
mask_set, name_set = image_masks[40:60], image_names[40:60]
show_images(mask_set, titles=name_set)