In [4]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import clear_border
from skimage.measure import label, regionprops
from scipy.ndimage import measurements, center_of_mass, binary_dilation, zoom, binary_fill_holes
import nibabel as nib
import os
import argparse
parser = argparse.ArgumentParser(description='Perform lung segmentation on NIfTI images.')

# Positional arguments: where the scans are read from and where the masks go.
parser.add_argument('input_path', type=str,
                    help='Input directory containing NIfTI images')
parser.add_argument('output_path', type=str,
                    help='Output directory to save segmented images')

# Optional tuning knobs for the segmentation pipeline.
parser.add_argument('-t', '--threshold', type=float, default=-320,
                    help='Threshold value for segmentation')
parser.add_argument('-l', '--labels', type=int, default=3,
                    help='Number of top labels to retain')
parser.add_argument('-d', '--dilation_iterations', type=int, default=5,
                    help='Number of dilation iterations')

args = parser.parse_args()

# Publish the parsed options as module-level settings: the batch loop reads the
# two paths, lung_segmentation reads the three pipeline parameters.
input_path, output_path = args.input_path, args.output_path
threshold, labels, dilation_iterations = (
    args.threshold,
    args.labels,
    args.dilation_iterations,
)
# Axis-order helpers: reverse the axes of a volume and undo that reversal.
def get_axes_maps(tensor):
    """Return the permutation that reverses ``tensor``'s axes and its inverse.

    Only the number of dimensions of ``tensor`` is used. The forward map sends
    output axis ``i`` to input axis ``ndim - 1 - i`` (e.g. ``(X, Y, Z)`` becomes
    ``(Z, Y, X)``). The second map is its inverse; for a pure reversal it holds
    the same pairs.

    Returns:
        (axes_map, reverse_map): two ``dict[int, int]`` suitable for
        ``transpose_tensor``.
    """
    ndim = len(tensor.shape)
    forward = dict(zip(range(ndim), reversed(range(ndim))))
    inverse = {dst: src for src, dst in forward.items()}
    return forward, inverse

def transpose_tensor(tensor, axes_map):
    """Permute the axes of ``tensor`` according to ``axes_map``.

    Output axis ``i`` is taken from input axis ``axes_map[i]``; ``axes_map``
    must have an entry for every axis index ``0 .. ndim-1``.

    Returns:
        The result of ``np.transpose`` (a view when numpy can provide one).
    """
    order = tuple(map(axes_map.__getitem__, range(len(tensor.shape))))
    return np.transpose(tensor, axes=order)

def mask_and_label(img, threshold):
    """Threshold ``img`` and label connected components slice by slice.

    Voxels strictly below ``threshold`` form the foreground. In every 2-D slice
    spanned by the last two axes (for a 3-D volume: one slice per index of axis
    0), components touching the slice border are removed with ``clear_border``.
    The remaining components are then numbered with ``label``.

    Labels are assigned independently per slice, so the same label value in two
    different slices does not denote the same object.

    Returns:
        Integer array with the same shape as ``img``.
    """
    def _per_slice(func):
        # Wrap a 2-D function so it is applied to each slice over the last two axes.
        return np.vectorize(func, signature='(m,n)->(m,n)')

    below_threshold = img < threshold
    interior = _per_slice(clear_border)(below_threshold)
    return _per_slice(label)(interior)

def keep_top_n_labels(slice, n=5):
    """Keep only the ``n`` largest labelled regions of a label image.

    Parameters
    ----------
    slice : ndarray
        Label image of any dimensionality. Values <= 0 are background, as in
        ``skimage.measure.regionprops``. The name shadows the builtin ``slice``
        but is kept so keyword callers still work.
    n : int or None
        Number of regions to retain, ranked by pixel count (largest first).
        ``n <= 0`` retains nothing; ``None`` retains every region.

    Returns
    -------
    ndarray
        Array with the shape and dtype of ``slice``. Retained regions keep their
        original label values; every other element is 0.
    """
    label_image = np.asarray(slice)
    kept = np.zeros_like(label_image)
    if n is not None and n <= 0:
        return kept

    # Pixel count per positive label (the same "area" regionprops reports).
    ids, areas = np.unique(label_image[label_image > 0], return_counts=True)
    # Largest first. A stable sort on the negated areas breaks ties in favour of
    # the smaller label id, so the selection is deterministic.
    order = np.argsort(-areas, kind='stable')
    keep_mask = np.isin(label_image, ids[order[:n]])
    kept[keep_mask] = label_image[keep_mask]
    return kept

def remove_trachea(slice, threshold = 0.007):
    """Erase connected components that are small relative to the slice.

    ``slice`` is relabelled with 4-connectivity (``connectivity=1``) and 0 as
    background. Every component whose pixel count, divided by
    ``shape[0] * shape[1]``, is below ``threshold`` is set to 0 in a copy.
    Despite the name, this removes *any* small component; presumably it is
    meant to drop the trachea and main bronchi — confirm against real data.

    Returns:
        A copy of ``slice`` with the small components zeroed.
    """
    components = label(slice, connectivity=1, background=0)
    slice_area = slice.shape[0] * slice.shape[1]
    cleaned = slice.copy()
    for region in regionprops(components):
        if region.area / slice_area < threshold:
            cleaned[tuple(region.coords.T)] = 0
    return cleaned

def lung_segmentation(input_path, output_path):
    """Segment the lungs in one NIfTI volume and save the binary mask.

    Reads the module-level settings ``threshold``, ``labels`` and
    ``dilation_iterations`` that are parsed from the command line.

    Parameters
    ----------
    input_path : str
        Path of the NIfTI image to segment.
    output_path : str
        Where the mask is written. The mask reuses the input's affine and header
        and is stored as uint8 with values 0/1.
    """
    ct = nib.load(input_path)
    img = ct.get_fdata()
    # Reverse the axis order so that axis 0 indexes the 2-D slices the per-slice
    # steps below iterate over (assumes the last NIfTI axis is the slice axis —
    # TODO confirm for all inputs).
    axes_map, reverse_map = get_axes_maps(img)
    img = transpose_tensor(img, axes_map)

    per_slice = '(m,n)->(m,n)'
    # mask_and_label numbers components independently in each slice, so the
    # "keep the largest regions" step must also run slice by slice. Applied to
    # the whole volume, label k of one slice would be merged with label k of
    # every other slice.
    mask_labeled = mask_and_label(img, threshold)
    mask_labeled = np.vectorize(
        lambda s: keep_top_n_labels(s, labels), signature=per_slice
    )(mask_labeled)

    mask = mask_labeled != 0
    # Fill enclosed holes in each slice's mask.
    mask = np.vectorize(binary_fill_holes, signature=per_slice)(mask)
    # Drop components that are small relative to the slice area.
    mask = np.vectorize(remove_trachea, signature=per_slice)(mask)
    # The dilation is 3-D, i.e. it also grows the mask across neighbouring slices.
    mask = binary_dilation(mask, iterations=dilation_iterations)
    # Undo the axis reversal so the mask lines up with the input's affine.
    mask = transpose_tensor(mask, reverse_map)

    # Store an explicit uint8 0/1 volume. Otherwise the on-disk type of the bool
    # mask would come from the CT header's data type.
    mask_nii = nib.Nifti1Image(mask.astype(np.uint8), ct.affine, ct.header)
    mask_nii.set_data_dtype(np.uint8)
    nib.save(mask_nii, output_path)

# Batch driver: segment every NIfTI file in input_path and write each mask under
# the same file name in output_path.
os.makedirs(output_path, exist_ok=True)
files = sorted(os.listdir(input_path))  # sorted for a reproducible order
for file in files:
    input_file = os.path.join(input_path, file)
    # Skip sub-directories and anything that is not a NIfTI image.
    if not (os.path.isfile(input_file)
            and file.lower().endswith(('.nii', '.nii.gz'))):
        continue
    output_file = os.path.join(output_path, file)
    # Remove any stale result first, so a failed run cannot leave an old mask
    # that looks like a fresh one.
    if os.path.exists(output_file):
        print(f"deleting {output_file}")
        os.remove(output_file)

    print(f"processing {file}")
    try:
        lung_segmentation(input_file, output_file)
    except Exception as e:
        # Report and carry on with the remaining files. A truncated .nii.gz,
        # for example, raises EOFError while being read.
        print(f"error processing {file}: {e}")




processing LUNG1-001_0000.nii.gz
processing LUNG1-002_0000.nii.gz
processing LUNG1-004_0000.nii.gz
processing LUNG1-005_0000.nii.gz
processing LUNG1-006_0000.nii.gz


EOFError: Compressed file ended before the end-of-stream marker was reached