In [None]:

from __future__ import print_function, division
import SimpleITK as sitk
import numpy as np
import csv
from glob import glob
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
try:
    from tqdm import tqdm # long waits are not fun
except ImportError:
    # Only a missing package should trigger the fallback; other errors must surface.
    print('TQDM does make much nicer wait bars...')

    def tqdm(iterable, *args, **kwargs):
        """No-op stand-in for tqdm: return the iterable unchanged, ignoring tqdm options."""
        return iterable
    

In [None]:
def load_scan(img_file):
    """Read a scan from disk with SimpleITK.

    Returns a tuple (img_array, origin, spacing):
      img_array -- voxel data as a numpy array indexed [z, y, x] (note the ordering)
      origin    -- np.array, origin in world coordinates (mm), ordered x, y, z
      spacing   -- np.array, voxel spacing in world coordinates (mm), ordered x, y, z
    """
    image = sitk.ReadImage(img_file)
    world_origin = np.array(image.GetOrigin())
    voxel_spacing = np.array(image.GetSpacing())
    voxels = sitk.GetArrayFromImage(image)
    return voxels, world_origin, voxel_spacing

def worldToVoxelCoord(worldCoords, offset, EleSpacing):
    """Convert world coordinates (mm) into continuous voxel coordinates.

    The distance from ``offset`` (the scan origin) is taken as an absolute value
    before dividing by the voxel spacing ``EleSpacing``.
    NOTE(review): because of the absolute value, a point lying below the origin on an
    axis maps to the same voxel coordinate as its mirror image above it; presumably
    this is meant to cope with flipped axes -- confirm before using it for points
    outside the volume.
    """
    return np.absolute(worldCoords - offset) / EleSpacing


def make_mask(center,diam,z,width,height,spacing,origin):
    '''
    Rasterise one axial slice of a spherical nodule into a binary mask.

        center : nodule centre in world coordinates (mm), ordered x,y,z
        diam : nodule diameter in mm (the sphere radius is diam / 2)
        z : z position of the slice in world coordinates (mm)
        width, height : pixel dimensions of the slice
        spacing : mm/voxel conversion rate, np.array ordered x,y,z
        origin : world coordinates (mm) of voxel index 0, np.array ordered x,y,z

    Returns a float array of shape (height, width) -- rows are y, columns are x to
    match the image layout -- holding 1.0 at every pixel whose world position
    (spacing * index + origin, on slice z) lies within diam / 2 of `center`,
    and 0.0 elsewhere.
    '''
    center = np.asarray(center, dtype=float)
    radius = diam / 2.0
    mask = np.zeros([height, width])

    # Voxel-space bounding box around the nodule. The exact distance test below
    # decides membership; the box only limits the work, so it is padded generously
    # and sized per axis (x and y spacing may differ).
    pad = 5
    v_center = (center - origin) / spacing
    half_x = int(np.ceil(radius / spacing[0])) + pad
    half_y = int(np.ceil(radius / spacing[1])) + pad
    v_xmin = max(0, int(np.floor(v_center[0])) - half_x)
    v_xmax = min(width - 1, int(np.ceil(v_center[0])) + half_x)
    v_ymin = max(0, int(np.floor(v_center[1])) - half_y)
    v_ymax = min(height - 1, int(np.ceil(v_center[1])) + half_y)
    if v_xmin > v_xmax or v_ymin > v_ymax:
        return mask  # the nodule's box lies entirely outside this slice

    v_x = np.arange(v_xmin, v_xmax + 1)
    v_y = np.arange(v_ymin, v_ymax + 1)
    p_x = spacing[0] * v_x + origin[0]  # world x (mm) of each column in the box
    p_y = spacing[1] * v_y + origin[1]  # world y (mm) of each row in the box

    # Distance of every (row, col) in the box from the nodule centre, shape (ny, nx).
    dist = np.sqrt((p_x[np.newaxis, :] - center[0]) ** 2
                   + (p_y[:, np.newaxis] - center[1]) ** 2
                   + (z - center[2]) ** 2)

    # Index the mask with the integer voxel indices directly; converting world
    # coordinates back to indices with int() can truncate e.g. 2.9999 to 2.
    mask[v_ymin:v_ymax + 1, v_xmin:v_xmax + 1][dist <= radius] = 1.0
    return mask


def show_nodules(imgs, masks):
    """Plot three slices side by side, each with its mask outlined in red."""
    fig, panels = plt.subplots(1, 3, figsize=(15, 15))
    for idx, panel in enumerate(panels):
        panel.imshow(imgs[idx], cmap="gray")
        panel.contour(masks[idx], colors="r", linewidths=1)
        panel.axis("off")
    plt.show()



def get_nodules_array(imgs, masks):
    """Stack each slice with its mask and the mask's outline into a 3-channel array.

    Parameters
    ----------
    imgs : sequence of 2-D arrays, all with the same (height, width) shape
    masks : sequence of 2-D binary masks, one per entry of ``imgs``

    Returns
    -------
    np.ndarray of shape (len(imgs), height, width, 3), float64, where per slice:
        channel 0 = the image, channel 1 = the mask,
        channel 2 = the mask's external contour drawn with value 255.
    """
    n_slices = len(imgs)
    nodules_array = np.zeros((n_slices, *np.shape(imgs[0]), 3))

    for i in range(n_slices):
        nodules_array[i, :, :, 0] = imgs[i]
        nodules_array[i, :, :, 1] = masks[i]

        # findContours needs an 8-bit single-channel image; astype() also returns a
        # copy, so the caller's mask is never modified by OpenCV.
        mask_u8 = np.asarray(masks[i]).astype(np.uint8)
        # [-2] selects the contour list under both the OpenCV 3 (3-tuple) and
        # OpenCV 4 (2-tuple) return signatures.
        contours = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
        outline = np.zeros_like(mask_u8, dtype=np.uint8)
        cv2.drawContours(outline, contours, -1, 255, 1)  # single channel: value 255
        nodules_array[i, :, :, 2] = outline

    return nodules_array




In [None]:


# Dataset locations -- edit luna_path for your machine; the rest derive from it.
luna_path = "D:/Desktop/LUNA16-Data/"
luna_subset_path = luna_path + "Data/"
output_path = luna_path + "Output/"
# glob() order is filesystem-dependent; sort so the scan index (fcount) used in
# the output file names is reproducible across runs and machines.
file_list = sorted(glob(os.path.join(luna_subset_path, "*.mhd")))

def get_filename(file_list, case):
    """Return the path in ``file_list`` whose file name (minus extension) is ``case``.

    ``case`` is a scan's seriesuid. An exact match on the file stem is used rather
    than a substring test, because one seriesuid can be a prefix of another
    (e.g. "1.2.3" is contained in "1.2.34.mhd"). Returns None when no file matches.
    """
    for f in file_list:
        stem, _ = os.path.splitext(os.path.basename(f))
        if stem == case:
            return f
    return None

# Nodule annotations, one row per nodule. Each row is tagged with the path of its
# scan (via get_filename); rows containing any missing value -- notably those whose
# scan is not in file_list -- are then dropped.
df_node = (
    pd.read_csv("D:/Desktop/LUNA16-Data/annotations.csv")
    .assign(file=lambda frame: frame["seriesuid"].map(lambda uid: get_filename(file_list, uid)))
    .dropna()
)

# np.save does not create directories; make sure the output folder exists first.
os.makedirs(output_path, exist_ok=True)

for fcount, img_file in enumerate(tqdm(file_list)):
    mini_df = df_node[df_node["file"] == img_file]  # all nodules annotated on this scan
    if mini_df.empty:
        continue  # some scans have no nodule annotation -- nothing to save

    # Load the volume once per scan; the array is indexed (z, y, x).
    img_array, origin, spacing = load_scan(img_file)
    num_z, height, width = img_array.shape  # height x width is the transverse plane

    for node_idx, cur_row in mini_df.iterrows():
        center = np.array([cur_row["coordX"], cur_row["coordY"], cur_row["coordZ"]])  # world mm, x,y,z
        diam = cur_row["diameter_mm"]

        # Nodule centre in voxel space (still x,y,z ordering). Round instead of
        # truncating so the middle saved slice is the one nearest the centre.
        v_center = np.rint((center - origin) / spacing).astype(int)

        # Keep 3 slices: the centre slice and one on each side. Clipping stays
        # inside the volume (at the edges a slice is repeated rather than failing).
        slice_idx = np.clip(np.arange(v_center[2] - 1, v_center[2] + 2), 0, num_z - 1)

        imgs = np.empty([len(slice_idx), height, width], dtype=np.float32)
        masks = np.empty([len(slice_idx), height, width], dtype=np.uint8)
        for i, i_z in enumerate(slice_idx):
            slice_z_mm = i_z * spacing[2] + origin[2]  # world z of this slice
            masks[i] = make_mask(center, diam, slice_z_mm, width, height, spacing, origin)
            imgs[i] = img_array[i_z]

        nodules_array = get_nodules_array(imgs, masks)
        suffix = "%04d_%04d.npy" % (fcount, node_idx)
        np.save(os.path.join(output_path, "nodules_" + suffix), nodules_array)
        np.save(os.path.join(output_path, "images_" + suffix), imgs)
        np.save(os.path.join(output_path, "masks_" + suffix), masks)

            

In [None]:

working_path = "D:/Desktop/LUNA16-Data/Output/"
max_cases = 10  # number of nodule cases (not files) to display

# Get the list of all npy files in the directory, sorted so the shown cases are reproducible.
files = sorted(f for f in os.listdir(working_path) if f.endswith('.npy'))

# Each case was saved as images_/nodules_/masks_<scan>_<nodule>.npy; drive the
# loop from the images_ files so the limit counts cases rather than raw files.
image_files = [f for f in files if f.startswith('images_')]

for file in image_files[:max_cases]:
    num = file[len('images_'):-len('.npy')]  # "<scan>_<nodule>" suffix shared by the triple
    imgs = np.load(os.path.join(working_path, 'images_' + num + '.npy'))
    labelnodules = np.load(os.path.join(working_path, 'nodules_' + num + '.npy'))
    nodule_mask = np.load(os.path.join(working_path, 'masks_' + num + '.npy'))

    for i in range(len(imgs)):
        fig, ax = plt.subplots(2, 2, figsize=(10, 10))
        ax[0, 0].imshow(imgs[i], cmap='gray')
        ax[0, 0].set_title('CT slice %s [%d]' % (num, i))
        ax[0, 1].imshow(labelnodules[i, :, :, 0], cmap='gray')
        ax[0, 1].contour(labelnodules[i, :, :, 2], colors="r", linewidths=1)
        ax[0, 1].set_title('Slice with nodule outline')
        ax[1, 0].imshow(nodule_mask[i], cmap='gray')
        ax[1, 0].set_title('Nodule mask')
        ax[1, 1].imshow(imgs[i] * nodule_mask[i], cmap='gray')
        ax[1, 1].set_title('Slice restricted to nodule mask')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the loop
