### Import required modules

In [1]:
import os
import csv
import pickle
from time import sleep
import numpy as np
import pydicom as dcm
import nibabel as nib
from itkwidgets import view
from skimage import measure
from skimage import segmentation

### Functions for loading cases, images, and masks

In [2]:
def get_metadata(filename, start=86):
    """Read the dataset metadata CSV into a list of per-case dicts.

    The first CSV row is the header; every following row becomes a dict
    mapping header names to that row's values (a row shorter than the header
    yields a dict with only the leading keys).

    Args:
        filename: Path to the metadata CSV.
        start: Index (among data rows) of the first case to keep. The default
            of 86 reproduces this notebook's historical behavior of dropping
            the first 86 rows. NOTE(review): presumably those rows are a
            subset with no masks under MED_ABD_LYMPH_MASKS -- confirm. Pass 0
            to keep every case.

    Returns:
        List of dicts for the kept cases, in file order. An empty file
        yields an empty list.
    """
    # The csv module documentation requires newline='' for files it reads.
    with open(filename, 'r', newline='') as csvf:
        csv_reader = csv.reader(csvf)
        header = next(csv_reader, None)
        if header is None:
            return []
        cases = [{header[i]: value for i, value in enumerate(row)} for row in csv_reader]
    return cases[start:]

def read_dicom(case: dict):
    """Load every DICOM file of one case's series into a 3-D float array.

    The series directory is "Dataset" followed by the case's
    ``'File Location'`` with its first character dropped and ``\\``
    separators converted to ``/``. Files are read in sorted filename order
    and stacked along the last axis. Each file's ``pixel_array`` is
    transposed before it is stored.

    Args:
        case: One metadata row as returned by ``get_metadata``.

    Returns:
        float64 array of shape
        ``(pixel_array.shape[1], pixel_array.shape[0], n_files)``.

    Raises:
        FileNotFoundError: if the series directory contains no files.
    """
    series_dir = "Dataset" + case['File Location'][1:].replace('\\', '/')
    # NOTE(review): plain sorted() only yields slice order if the file names
    # are zero-padded -- confirm, or sort on a tag such as InstanceNumber.
    dcm_names = sorted(os.listdir(series_dir))
    if not dcm_names:
        raise FileNotFoundError(f"No DICOM files found in {series_dir}")
    # NOTE(review): pixel_array holds stored values, while find_lungs thresholds
    # at -900..-700, which presumes Hounsfield units. If this series has a
    # non-trivial RescaleSlope/RescaleIntercept, apply it here -- confirm.
    volume = None
    for idx, name in enumerate(dcm_names):
        # dcmread is pydicom's reader; read_file is deprecated and was removed in pydicom 3.0.
        pixels = dcm.dcmread(f"{series_dir}/{name}").pixel_array.transpose()
        if volume is None:
            # Size the volume from the *transposed* slice so non-square images fit.
            volume = np.empty((pixels.shape[0], pixels.shape[1], len(dcm_names)))
        volume[:, :, idx] = pixels
    return volume

def read_mask(case: dict):
    """Load the lymph-node mask volume belonging to ``case`` via nibabel.

    The case name is the third ``\\``-separated component of the case's
    ``'File Location'`` once its first character is dropped. The mask is read
    from ``Dataset/MED_ABD_LYMPH_MASKS/<name>/<name>_mask.nii.gz`` and
    returned as produced by ``get_fdata()`` (float64 by default).
    """
    location_parts = case['File Location'][1:].split('\\')
    name = location_parts[2]
    mask_path = f"Dataset/MED_ABD_LYMPH_MASKS/{name}/{name}_mask.nii.gz"
    return nib.load(mask_path).get_fdata()

### Functions for processing images and masks

In [43]:
def man_binary_thresh(arr, lo, hi):
    """Manual binary threshold over the closed window [lo, hi].

    Returns an integer array shaped like ``arr``: 1 where ``lo <= arr <= hi``,
    0 everywhere else.
    """
    in_window = np.logical_and(lo <= arr, arr <= hi)
    return np.where(in_window, 1, 0)

def connected_components_2d(area, lo, hi, return_num=False):
    """Label the connected regions of ``area`` whose values lie in [lo, hi].

    The window is applied with ``man_binary_thresh`` and the result is
    labelled with ``skimage.measure.label``. That function accepts any
    dimensionality; this notebook calls it on a single 2-D slice.

    Args:
        area: Array to threshold and label.
        lo, hi: Inclusive intensity window.
        return_num: Also return the number of labelled components.

    Returns:
        ``(labels, num)`` when ``return_num`` is true, otherwise ``labels`` alone.
    """
    thresh = man_binary_thresh(area, lo, hi)
    comps, num = measure.label(thresh, return_num=True)
    # The parentheses matter: without them the conditional binds only to the
    # second tuple element, and a tuple is returned even when return_num is False.
    return (comps, num) if return_num else comps

def sort_conn_comps(conn_comps, num_conn_comps):
    """Rank component labels by pixel count, largest first.

    Args:
        conn_comps: Non-negative integer label array (as from
            ``measure.label``), with components labelled 1..num_conn_comps and
            0 for everything else.
        num_conn_comps: Number of labelled components.

    Returns:
        List of ``(label, count)`` for every label 0..num_conn_comps inclusive
        (label 0 included), sorted by count descending. Ties keep ascending
        label order, because ``sorted`` is stable even with ``reverse=True``.
    """
    # Labels run 0..num inclusive, so num + 1 bins are needed; range(num)
    # silently dropped the highest-numbered component.
    counts = np.bincount(np.ravel(conn_comps), minlength=num_conn_comps + 1)
    ranked = [(label, int(counts[label])) for label in range(num_conn_comps + 1)]
    return sorted(ranked, key=lambda item: item[1], reverse=True)
        
def find_average_pos(bin_image):
    """Return the mean (x, y, z) index of the nonzero entries of a 3-D array.

    Each coordinate is a numpy float. An all-zero input yields NaNs, and
    numpy warns about the empty mean.
    """
    xs, ys, zs = np.nonzero(bin_image)
    return (xs.mean(), ys.mean(), zs.mean())

def find_lungs(vol):
    """Segment the lungs in ``vol`` by a 3-D flood fill from a 2-D seed.

    Steps:
      1. Threshold ``vol`` to the window [-900, -700].
      2. Use the z index of the thresholded voxels' centroid as a slice that
         is assumed to contain lung.
      3. Label that slice's connected components and take the label ranked
         second by ``sort_conn_comps``. The top-ranked label is assumed to be
         the unlabelled background.
      4. Flood-fill the 3-D thresholded mask from the first pixel (C order)
         of that component.

    Args:
        vol: 3-D array, presumably in Hounsfield units given the window -- confirm.

    Returns:
        ``(lungs_vol, seed)``. ``lungs_vol`` is ``vol`` with everything outside
        the flooded region set to 0; ``seed`` is the 3-D index the flood
        started from.
    """
    window_lo, window_hi = -900, -700
    lung_mask = man_binary_thresh(vol, window_lo, window_hi)
    seed_z = int(find_average_pos(lung_mask)[2])

    seed_slice = np.squeeze(vol[:, :, seed_z])
    labels, n_labels = connected_components_2d(seed_slice, window_lo, window_hi, return_num=True)
    lung_label = sort_conn_comps(labels, n_labels)[1][0]

    seed = (*np.argwhere(labels == lung_label)[0], seed_z)
    flooded = segmentation.flood(lung_mask, seed)
    return np.where(flooded, vol, 0), seed

def find_nonzero_extent(arr, lungs=False, buffer_percent=0.0):
    """Inclusive ``[min, max]`` index bounds of the nonzero entries, per axis.

    Args:
        arr: Array of any dimensionality.
        lungs: Apply a heuristic to axis 0 for the case where only one lung
            was segmented. If exactly one end of the axis-0 extent lies within
            ``shape[0] // 8`` of the centre, that end is replaced by the
            mirror of the other end. If both ends are that close to the centre,
            the extent is returned immediately, without any buffer.
        buffer_percent: Widen every axis' extent on both sides by this
            percentage of its span.

    Returns:
        List with one ``[lo, hi]`` pair of ints per axis. Bounds are clamped
        to ``[0, shape - 1]``, so a widened start can never be negative and
        wrap around when used as a slice bound.

    Raises:
        ValueError: if ``arr`` has no nonzero entries.
    """
    shape = np.shape(arr)
    inds = np.nonzero(arr)
    if len(inds) == 0 or inds[0].size == 0:
        raise ValueError("find_nonzero_extent: array has no nonzero entries")
    # ndarray.min/max reduce in C; the builtin min/max iterated element by element.
    extent = [[int(axis_inds.min()), int(axis_inds.max())] for axis_inds in inds]

    if lungs:
        centre, tol = shape[0] // 2, shape[0] // 8
        lo_near = abs(extent[0][0] - centre) < tol
        hi_near = abs(extent[0][1] - centre) < tol
        if lo_near and hi_near:
            return extent  # Do nothing, lungs weren't found
        if lo_near:
            extent[0][0] = shape[0] - extent[0][1]
        elif hi_near:
            extent[0][1] = shape[0] - extent[0][0]

    for axis, (lo, hi) in enumerate(extent):
        pad = int(buffer_percent * (hi - lo) / 100.0)
        extent[axis] = [max(lo - pad, 0), min(hi + pad, shape[axis] - 1)]
    return extent

### Load cases

In [4]:
# One dict per case (metadata CSV row), as returned by get_metadata.
cases = get_metadata(filename='Dataset/metadata.csv')

### Check whether the bounding prism around the lungs contains the masked lymph nodes

In [44]:
def _case_extents(case):
    """Buffered lung extent and mask extent for the upper half (last axis) of one case."""
    vol = read_dicom(case)
    mask = read_mask(case)
    lungs, _ = find_lungs(vol[:, :, vol.shape[2] // 2:])
    lungs_ext = find_nonzero_extent(lungs, lungs=True, buffer_percent=10.0)
    mask_ext = find_nonzero_extent(mask[:, :, mask.shape[2] // 2:])
    return lungs_ext, mask_ext


def _mask_outside_lungs(lungs_ext, mask_ext):
    """True if, on any axis, the mask extent reaches past the lung extent."""
    return any(
        m_hi > l_hi or m_lo < l_lo
        for (l_lo, l_hi), (m_lo, m_hi) in zip(lungs_ext, mask_ext)
    )


def test_prisms(cases, indices=None):
    """Check whether each case's lung bounding prism contains its lymph-node mask.

    Only the upper half of each volume along the last axis
    (``shape[2] // 2`` onward) is considered, for both the DICOM volume and
    the mask. The lung extent is widened by 10% per axis.

    Args:
        cases: Metadata rows from ``get_metadata``.
        indices: Case indices to check. ``None`` or an empty sequence means
            every case, with a carriage-return progress line.

    Returns:
        ``(failures, extents)``, both dicts keyed by case index with values
        ``{"lungs": lungs_ext, "mask": mask_ext}``. ``extents`` covers every
        checked case. ``failures`` holds only the cases whose mask extent
        leaves the lung extent on some axis.
    """
    show_progress = not indices
    if show_progress:
        indices = range(len(cases))
    failures, extents = {}, {}
    for idx in indices:
        lungs_ext, mask_ext = _case_extents(cases[idx])
        if _mask_outside_lungs(lungs_ext, mask_ext):
            print(f"\nFAIL WITH CASE {idx}")
            failures[idx] = {"lungs": lungs_ext, "mask": mask_ext}
        extents[idx] = {"lungs": lungs_ext, "mask": mask_ext}
        if show_progress:
            # The line ends in '\r' rather than a newline, so flush it explicitly.
            print(f"Done with {idx} of {len(cases)-1}", end="\r", flush=True)
    return failures, extents

In [45]:
# Run the prism check over every case (no explicit indices given).
failures, extents = test_prisms(cases=cases)

Done with 87 of 87

### Write failures and extents to file for later evaluation

In [46]:
# Persist both result dicts so they can be reloaded without recomputation.
for out_path, payload in (('lung_failures', failures), ('extents', extents)):
    with open(out_path, 'wb') as out_file:
        pickle.dump(payload, out_file)

In [47]:
# These pickles are the ones written by the previous cell. Never pickle.load
# files from an untrusted source: unpickling can execute arbitrary code.
with open('lung_failures', 'rb') as failures_file:
    read_failures = pickle.load(failures_file)

with open('extents', 'rb') as extents_file:
    read_extents = pickle.load(extents_file)

### View a few examples

In [48]:
def _prism_like(template, ext, z_offset):
    """Zero array like ``template`` with 1s inside the inclusive extent ``ext``.

    ``ext`` holds inclusive ``[lo, hi]`` index pairs for three axes, and the
    z pair is shifted by ``z_offset``. Ends get +1 because slice stops are
    exclusive. Starts are clamped at 0 so a negative (buffered) bound cannot
    wrap around to the far end of the axis.
    """
    prism = np.zeros_like(template)
    (x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi) = ext
    prism[max(x_lo, 0):x_hi + 1,
          max(y_lo, 0):y_hi + 1,
          max(z_lo + z_offset, 0):z_hi + z_offset + 1] = 1
    return prism


def overlay_prism_on_vol(case_idx):
    """Return ``(vol, prism)`` for case ``case_idx``: its CT volume and a
    same-shaped array that is 1 inside the stored lung extent.

    Reads the notebook globals ``cases`` and ``read_extents``.
    """
    vol = read_dicom(cases[case_idx])
    ext = read_extents[case_idx]['lungs']
    # Lung extents were computed on vol[:, :, vol.shape[2]//2:], so shift z back.
    return vol, _prism_like(vol, ext, vol.shape[2] // 2)


def overlay_prism_on_mask(case_idx):
    """Return ``(mask, prism)`` for case ``case_idx``: its lymph-node mask and
    a same-shaped array that is 1 inside the stored lung extent.

    Reads the notebook globals ``cases`` and ``read_extents``.
    """
    mask = read_mask(cases[case_idx])
    ext = read_extents[case_idx]['lungs']
    # Same upper-half z offset as used when the extents were computed.
    return mask, _prism_like(mask, ext, mask.shape[2] // 2)

In [53]:
# Case 3: pass the CT volume and its lung prism to itkwidgets' view
# (the prism as the second positional argument, presumably a label overlay).
view(*overlay_prism_on_vol(case_idx=3))

Viewer(geometries=[], gradient_opacity=0.22, interpolation=False, point_sets=[], rendered_image=<itk.itkImageP…