In [None]:
# 3D object detection for multiple DICOM files
import SimpleITK as sitk 
import numpy as np 
from numpy import asarray 
from PIL import Image
from PIL import ImageOps
import matplotlib
import matplotlib.pyplot as plt 
import cv2 
import sys 
import scipy.ndimage
import pydicom # for reading dicom files 
import os # for doing directory operations 
import time 
import datetime 
import io 
import imgcompare as ic

dirInput = 'C:\\Users\\Administrator\\OneDrive\\Thesis\\Coding\\Leon\\Heatmap_object_detection\\MMWHS\\' #*
dirOutput = 'C:\\Users\\Administrator\\OneDrive\\Thesis\\Coding\\Leon\\Heatmap_object_detection\\cropped\\' #*
dirRef = 'C:\\Users\\Administrator\\OneDrive\\Thesis\\Coding\\Leon\\Heatmap_object_detection\\reference_imgs\\' #*
dirMeta = 'C:\\Users\\Administrator\\OneDrive\\Thesis\\Coding\\Leon\\Heatmap_object_detection\\cropped\\metadata\\' #*
# Target voxel size for resample3d(); readDicom() derives ratio = voxelsize[0] / SliceThickness from it.
voxelsize = [0.16, 0.16, 0.16] # 0.4 to 0.16 --> resize 512 to 128 
# The main loop passes createAtlas and denoise to create_roifile(), so they must be
# defined even when no atlas is built; create_roifile() only uses denoise (as
# mediansize) and posterizeLevels when createAtlas is True.
denoise = (5,5,5) #*
posterizeLevels = 8 #*
createAtlas = False #*

In [None]:
def readDicom(path):
    """Read all DICOM files in *path* and stack them into one 3-D volume.

    The series must be isotropic: the two in-plane pixel spacings must be
    equal to each other and to SliceThickness, otherwise the program exits
    via sys.exit(). A non-standard ImageOrientationPatient only prints a
    warning. As a side effect the header values of the series are appended
    to the metadata CSV through save_metadata().

    Args:
        path: directory containing the DICOM slices of a single series.

    Returns:
        (scan, slices, ratio):
          scan   -- int16 array of the slices' pixel arrays, stacked in
                    reversed file-name order and then flipped along every
                    axis (to match the orientation shown in ITK-SNAP /
                    3D Slicer, per the original author);
          slices -- the pydicom datasets, in that reversed order;
          ratio  -- voxelsize[0] / SliceThickness, used by create_roifile()
                    to map resampled voxel indices back to this scan.
    """
    # os.listdir() order is filesystem-dependent; sort so the slice order the
    # reversal/flip below rely on is reproducible on every platform.
    file_names = sorted(os.listdir(path))
    if not file_names:
        sys.exit('ERROR - no DICOM files found in ' + path)
    slices = [pydicom.dcmread(os.path.join(path, name)) for name in file_names]
    first = slices[0]

    if first.ImageOrientationPatient != [-1, 0, 0, 0, -1, 0]:
        print('WARNING - Unknown image orientation ' + str(first.ImageOrientationPatient))
    if first.PixelSpacing[0] != first.PixelSpacing[1]:
        sys.exit('ERROR - pixel height and width do not match ' + str(first.PixelSpacing))
    if first.PixelSpacing[0] != first.SliceThickness:
        # Both values are converted to str before concatenation (PixelSpacing is
        # a multi-value element and cannot be added to a str directly).
        sys.exit('ERROR - pixelspacing and slicedepth do not match '
                 + str(first.PixelSpacing) + ' - ' + str(first.SliceThickness))
#*    if slices[0].PatientPosition == 'HFP':
#*        if 'diastole' in path.lower():
#*            atlas_a = Image.open(dirRef+'diastole 63 A2.jpg').convert('L') 
#*            atlas_s = Image.open(dirRef+'diastole 67 S2.jpg').convert('L')
#*        elif 'systole' in path.lower():
#*            atlas_a = Image.open(dirRef+'systole 63 A2.jpg').convert('L') 
#*            atlas_s = Image.open(dirRef+'systole 67 S2.jpg').convert('L') 
    ratio = voxelsize[0] / first.SliceThickness

    slices = slices[::-1]  # to comply with orientation in viewers such as ITK-Snap or 3D Slicer
    scan = np.stack([s.pixel_array for s in slices]).astype(np.int16)
    scan = np.flip(scan)  # to comply with orientation in viewers such as ITK-Snap or 3D Slicer

    # Save DICOM metadata for each ROI so it can be added to the registration and segmentation output
    save_metadata(slices, path)
    return scan, slices, ratio

# Metadata columns written after the series name and location, in the same order
# as the column header the main script writes to ROI_metadata.csv. 'shape' and
# 'slices' are derived values; every other entry is a DICOM keyword.
_METADATA_FIELDS = (
    'PatientName', 'PatientComments', 'ScanOptions', 'SeriesNumber',
    'ImageComments', 'StudyComments', 'StudyDescription', 'PatientID',
    'PatientBirthDate', 'PatientSex', 'PatientAge', 'PatientSize',
    'PatientWeight', 'StudyDate', 'SeriesDate',
    'shape', 'Rows', 'Columns', 'slices',
    'PixelSpacing', 'SliceThickness', 'WindowCenter', 'WindowWidth',
    'RescaleIntercept', 'RescaleSlope', 'LossyImageCompression',
    'PatientPosition', 'ImagePositionPatient', 'ImageOrientationPatient',
    'FrameOfReferenceUID', 'Laterality', 'PositionReferenceIndicator',
    'SliceLocation', 'ManufacturerModelName', 'SoftwareVersions',
    'ContrastBolusAgent', 'KVP', 'DistanceSourceToDetector',
    'DistanceSourceToPatient', 'RotationDirection', 'ExposureTime',
    'XRayTubeCurrent', 'FilterType', 'FocalSpots', 'StudyInstanceUID',
    'SeriesInstanceUID', 'StudyID', 'ImageType', 'SOPClassUID', 'Modality',
    'PhotometricInterpretation', 'SpecificCharacterSet', 'SamplesPerPixel',
    'PixelAspectRatio', 'BitsAllocated', 'HighBit', 'PixelRepresentation',
)


def save_metadata(slices, path):
    """Append one ';'-separated row describing a series to ROI_metadata.csv.

    The row contains the series name (last component of *path*), the path
    with '/' separators, and then the header values of ``slices[0]`` in the
    order of ``_METADATA_FIELDS``. A DICOM field that is absent from the
    dataset is written as an empty cell instead of raising AttributeError.

    Args:
        slices: non-empty list of pydicom datasets of one series.
        path: directory the series was read from.
    """
    first = slices[0]
    location = path.replace("\\", "/")
    # Derive the series name from the path instead of relying on a global
    # loop variable; for readDicom(dirInput + dicom) this is `dicom`.
    series_name = location.rstrip("/").split("/")[-1]

    cells = [series_name, location]
    for field in _METADATA_FIELDS:
        if field == 'shape':
            value = first.pixel_array.shape
        elif field == 'slices':
            value = len(slices)
        else:
            value = getattr(first, field, '')
        cells.append(str(value))

    # Keep each series on exactly one row with a fixed column count: drop line
    # breaks and replace the ';' delimiter inside values.
    cells = [c.replace('\r', '').replace('\n', '').replace(';', ',') for c in cells]
    # Rows end with '\r' to match the header line written by the main script.
    line = ';'.join(cells) + '\r'

    with open(dirMeta + 'ROI_metadata.csv', 'a', encoding='utf-8') as roifile:
        roifile.write(line)
        
def set_valid_values(arr):
    """Min-max normalise *arr* to a float32 array with values in [0.0, 1.0].

    A constant array maps to all zeros. Non-floating input (e.g. int16 CT
    data) is promoted to float64 before the minimum is subtracted, so the
    subtraction cannot wrap around. Floating input keeps its own dtype for
    the arithmetic.

    Args:
        arr: numeric array-like.

    Returns:
        float32 ndarray of the same shape, clipped to [0.0, 1.0].
    """
    values = np.asarray(arr)
    if not np.issubdtype(values.dtype, np.floating):
        # int16 - int16 wraps when the value range exceeds 32767.
        values = values.astype(np.float64)
    arrmin = np.amin(values)
    arrmax = np.amax(values)
    span = arrmax - arrmin
    if span > 0:
        values = np.float32((values - arrmin) / span)
    else:
        values = np.float32(values - arrmin)
    return np.clip(values, a_min=0.0, a_max=1.0)

def print_duration(msg):
    """Print *msg* prefixed with the time elapsed since the global ``start``.

    The prefix has the form ``HH:MM:SS - `` (each part zero-padded to two
    characters).
    """
    elapsed = time.time() - start
    whole_hours, leftover = divmod(elapsed, 3600)
    whole_minutes, secs = divmod(leftover, 60)
    stamp = f"{int(whole_hours):0>2}:{int(whole_minutes):0>2}:{int(secs):0>2} - "
    print(stamp + msg)
    
# Resample for a uniform slice thickness over all scans
def resample3d(image, scan, new_spacing=(1, 1, 1)):
    """Resample *image* to (approximately) the voxel size *new_spacing*.

    The current spacing is taken from ``scan[0]`` as
    (SliceThickness, PixelSpacing[0], PixelSpacing[1]) and matched to image
    axes 0, 1 and 2. The target shape is rounded to whole voxels, so the
    returned spacing is the spacing actually achieved.

    Args:
        image: 3-D numpy array.
        scan: sequence of pydicom-like datasets; only ``scan[0]`` is read.
        new_spacing: target voxel size per axis. The default is a tuple so it
            is not a shared mutable default.

    Returns:
        (resampled_image, actual_spacing), where actual_spacing is a numpy
        array.
    """
    # Determine current pixel spacing
    spacing = np.array([scan[0].SliceThickness] + list(scan[0].PixelSpacing), dtype=np.float32)

    resize_factor = spacing / np.asarray(new_spacing)
    new_shape = np.round(np.array(image.shape) * resize_factor)
    # Factor that yields exactly the rounded integer shape.
    real_resize_factor = new_shape / image.shape
    actual_spacing = spacing / real_resize_factor

    # scipy.ndimage.zoom is the public location of the function formerly
    # reached via the deprecated scipy.ndimage.interpolation namespace.
    image = scipy.ndimage.zoom(image, real_resize_factor, mode='nearest')

    return image, actual_spacing

def _accumulate_plane(resampled, atlas, slice_axis, acc_sum, acc_count, step=4):
    """Score *atlas* against every step-th slice of *resampled* along one axis.

    Slices are taken perpendicular to ``slice_axis`` (0 = axial, 1 = coronal,
    2 = sagittal for a (depth, height, width) volume). The two remaining axes,
    in their original order, become image rows and columns. On each sampled
    slice the atlas is slid over the image in steps of *step* pixels and
    compared with the same-sized crop using ``ic.image_diff_percent``.
    find_object keeps the minimum, so a lower score is treated as a better match.
    Each score is added to *acc_sum*, and *acc_count* is incremented, over a
    step-sized cube around the crop centre. Along ``slice_axis`` the cube is
    clipped to the volume. Cubes from one call do not overlap, so every voxel
    receives at most one score per call.

    Returns:
        (atlas_w, atlas_h): atlas image width and height in pixels.
    """
    atlas_h, atlas_w = np.array(atlas).shape[:2]
    row_axis, col_axis = [axis for axis in range(3) if axis != slice_axis]
    n_slices = resampled.shape[slice_axis]
    n_rows = resampled.shape[row_axis]
    n_cols = resampled.shape[col_axis]
    half = step // 2

    for s in range(0, n_slices, step):
        plane = set_valid_values(np.take(resampled, s, axis=slice_axis))
        plane_img = Image.fromarray(np.uint8(plane * 255), 'L')
        # Clip both ends so the last slab never indexes past the final slice.
        slab = slice(max(s - half, 0), min(s + half, n_slices))
        for r in range(0, n_rows - atlas_h, step):
            r0 = r + atlas_h // 2 - half
            for c in range(0, n_cols - atlas_w, step):
                crop = plane_img.crop((c, r, c + atlas_w, r + atlas_h))
                score = ic.image_diff_percent(atlas, crop)
                c0 = c + atlas_w // 2 - half
                region = [None, None, None]
                region[slice_axis] = slab
                region[row_axis] = slice(r0, r0 + step)
                region[col_axis] = slice(c0, c0 + step)
                region = tuple(region)
                acc_sum[region] += score
                acc_count[region] += 1
    return int(atlas_w), int(atlas_h)


def find_object(resampled, atlas_a, atlas_c, atlas_s):
    """Locate the voxel whose neighbourhood best matches the three atlas slices.

    The axial (*atlas_a*), sagittal (*atlas_s*) and coronal (*atlas_c*)
    reference images are each slid over the corresponding planes of
    *resampled* (see _accumulate_plane). Only voxels that received a score
    from all three orientations are candidates. The candidate with the lowest
    summed score wins; ties resolve to the first in (z, y, x) order.

    Args:
        resampled: 3-D volume in (depth, height, width) order.
        atlas_a, atlas_c, atlas_s: 2-D reference images (PIL 'L' images in
            the main script).

    Returns:
        (mean_score, best_x, best_y, best_z, atlas_a_w, atlas_a_h,
        atlas_s_w, atlas_s_h, atlas_c_w, atlas_c_h).
        mean_score is the winning summed score divided by 3. best_x/y/z are
        voxel indices along width/height/depth of *resampled*. If no voxel was
        scored in all three orientations, the result is 999999/3 with
        best_x = best_y = best_z = -1.
    """
    # Accumulators use the (depth, height, width) layout of `resampled`.
    acc_sum = np.zeros(resampled.shape, dtype=np.float32)
    acc_count = np.zeros(resampled.shape, dtype=np.uint8)

    atlas_a_w, atlas_a_h = _accumulate_plane(resampled, atlas_a, 0, acc_sum, acc_count)  # Axial
    atlas_s_w, atlas_s_h = _accumulate_plane(resampled, atlas_s, 2, acc_sum, acc_count)  # Sagittal
    atlas_c_w, atlas_c_h = _accumulate_plane(resampled, atlas_c, 1, acc_sum, acc_count)  # Coronal

    bestacc, best_x, best_y, best_z = 999999, -1, -1, -1
    # Only measurements from all 3 axes are valid.
    candidates = np.where(acc_count == 3, acc_sum, np.inf)
    if candidates.size:
        # np.argmin returns the first minimum in C order, i.e. (z, y, x).
        flat_idx = int(np.argmin(candidates))
        if candidates.flat[flat_idx] < bestacc:
            best_z, best_y, best_x = (int(i) for i in np.unravel_index(flat_idx, candidates.shape))
            bestacc = acc_sum[best_z, best_y, best_x]

    return bestacc / 3, best_x, best_y, best_z, atlas_a_w, atlas_a_h, atlas_s_w, atlas_s_h, atlas_c_w, atlas_c_h

def create_roifile(inputdir, 
                   outputdir, 
                   dicom,
                   atlas,
                   ratio, 
                   mediansize, 
                   best_x, 
                   best_y, 
                   best_z, 
                   atlas_w, 
                   atlas_h, 
                   atlas_d,
                   correctie_z):
    """Crop the detected ROI from the full-resolution series and save it as NIfTI.

    The series in ``inputdir + dicom`` is read with SimpleITK. The box centre
    (best_x, best_y, best_z, in voxels of the resampled volume) is first
    shifted by ``correctie_z`` along z. It is then scaled by ``ratio`` and
    mirrored as ``512 - v``. The box size is
    (round(atlas_w * 1.05), atlas_h, atlas_d) scaled by ``ratio``. Box edges
    are truncated with int() and clipped to the image extent. When ``atlas``
    is truthy, the crop is passed through img_denoiseMedian(mediansize,
    posterizeLevels), which is not defined in this file. The result is
    written to ``outputdir/<dicom> ROI.nii.gz``.
    """
    def clipped_span(centre, extent, limit):
        # Box edges centre -/+ extent*ratio/2, truncated and clipped to [0, limit].
        half_extent = (extent * ratio) / 2
        low = np.clip(int(centre - half_extent), a_min=0, a_max=limit)
        high = np.clip(int(centre + half_extent), a_min=0, a_max=limit)
        return low, high

    series_dir = os.path.dirname(inputdir + dicom + '/')
    series_files = sitk.ImageSeriesReader().GetGDCMSeriesFileNames(series_dir)
    original_image = sitk.ReadImage(series_files)
    # SimpleITK reports size as (x, y, z) = (width, height, depth).
    width, height, depth = original_image.GetSize()

    # Mirror (512 - v) to comply with orientation in viewers such as ITK-Snap or 3D Slicer.
    # NOTE(review): 512 assumes a 512-voxel extent on every axis — confirm for other series sizes.
    centre_x = 512 - (best_x * ratio)
    centre_y = 512 - (best_y * ratio)
    # correctie_z corrects a small misalignment along Z to prevent aorta cutoffs.
    centre_z = 512 - ((best_z + correctie_z) * ratio)

    # A 5% wider box prevents (rare) myocardium cutoffs.
    padded_w = int(round(atlas_w * 1.05))

    x1, x2 = clipped_span(centre_x, padded_w, width)
    y1, y2 = clipped_span(centre_y, atlas_h, height)
    z1, z2 = clipped_span(centre_z, atlas_d, depth)

    roi_image = original_image[x1:x2, y1:y2, z1:z2]
    if atlas:
        roi_image = img_denoiseMedian(roi_image, mediansize, posterizeLevels)

    roi_image.SetSpacing(original_image.GetSpacing())
    roi_image.SetDirection(original_image.GetDirection())

    sitk.WriteImage(roi_image, os.path.join(outputdir, dicom + ' ROI.nii.gz'))

start = time.time() 
dt = datetime.datetime.fromtimestamp(start).strftime('%Y-%m-%d %H:%M:%S') 
print('Start:', dt)
print('')

# Reference slices per cardiac phase. find_object() uses *_a for the axial,
# *_c for the coronal and *_s for the sagittal search.
dia_a = Image.open(dirRef + 'diastole 63 A1.jpg').convert('L')
dia_c = Image.open(dirRef + 'diastole 39 C1.jpg').convert('L')
dia_s = Image.open(dirRef + 'diastole 67 S1.jpg').convert('L')
sys_a = Image.open(dirRef + 'systole 63 A1.jpg').convert('L')
sys_c = Image.open(dirRef + 'systole 39 C1.jpg').convert('L')
sys_s = Image.open(dirRef + 'systole 67 S1.jpg').convert('L')

# Column headers for 'ROI_metadata.csv'; the order must match the row written by save_metadata().
metadata_columns = (
    'dicom', 'filelocation', 'PatientName', 'PatientComments', 'ScanOptions',
    'SeriesNumber', 'ImageComments', 'StudyComments', 'StudyDescription',
    'PatientID', 'PatientBirthDate', 'PatientSex', 'PatientAge',
    'PatientSize', 'PatientWeight', 'StudyDate', 'SeriesDate', 'shape',
    'Rows', 'Columns', 'slices', 'PixelSpacing', 'SliceThickness',
    'WindowCenter', 'WindowWidth', 'RescaleIntercept', 'RescaleSlope',
    'LossyImageCompression', 'PatientPosition', 'ImagePositionPatient',
    'ImageOrientationPatient', 'FrameOfReferenceUID', 'Laterality',
    'PositionReferenceIndicator', 'SliceLocation', 'ManufacturerModelName',
    'SoftwareVersions', 'ContrastBolusAgent', 'KVP',
    'DistanceSourceToDetector', 'DistanceSourceToPatient',
    'RotationDirection', 'ExposureTime', 'XRayTubeCurrent', 'FilterType',
    'FocalSpots', 'StudyInstanceUID', 'SeriesInstanceUID', 'StudyID',
    'ImageType', 'SOPClassUID', 'Modality', 'PhotometricInterpretation',
    'SpecificCharacterSet', 'SamplesPerPixel', 'PixelAspectRatio',
    'BitsAllocated', 'HighBit', 'PixelRepresentation',
)
line = ';'.join(metadata_columns) + '\r'

# Make sure the output locations exist before anything is written there.
os.makedirs(dirMeta, exist_ok=True)
os.makedirs(dirOutput, exist_ok=True)

metadata_file = dirMeta + 'ROI_metadata.csv'
# The CSV is appended to across runs; only write the header into a new or
# empty file so repeated runs do not insert duplicate header rows.
if not os.path.isfile(metadata_file) or os.path.getsize(metadata_file) == 0:
    with open(metadata_file, 'a') as roifile:
        roifile.write(line)

# Series folders to process; folders whose name contains 'copy' are skipped.
dicom_dirs = [entry for entry in os.listdir(dirInput) if 'copy' not in entry.lower()]
allFiles = len(dicom_dirs)

print_duration('{} dicoms to process'.format(allFiles))
count_files = 0
for dicom in dicom_dirs:
    phase = dicom.lower()
    if 'diastole' in phase:
        atlas_a = dia_a
        atlas_c = dia_c
        atlas_s = dia_s
        correctie_z = 15
    elif 'systole' in phase:
        atlas_a = sys_a
        atlas_c = sys_c
        atlas_s = sys_s
        correctie_z = 7
    else:
        sys.exit('ERROR - Diastole or systole could not be determined for '+dicom)

    scan, slices, ratio = readDicom(dirInput + dicom)
    # Resize 512x512x512 to 128x128x128 for faster object detection
    resampled, spacing = resample3d(scan, slices, voxelsize)
    # Object detection
    (bestacc, best_x, best_y, best_z,
     atlas_a_w, atlas_a_h, atlas_s_w, atlas_s_h, atlas_c_w, atlas_c_h) = find_object(resampled,
                                                                                     atlas_a,
                                                                                     atlas_c,
                                                                                     atlas_s)

    # Approx. 90 seconds per DICOM
    create_roifile(dirInput,
                   dirOutput,
                   dicom,
                   createAtlas,
                   ratio,
                   denoise,
                   best_x,
                   best_y,
                   best_z,
                   atlas_c_w,  # ROI width (x): coronal slice columns run along width
                   atlas_s_w,  # ROI height (y): sagittal slice columns run along height
                   atlas_c_h,  # ROI depth (z): coronal slice rows run along depth
                   correctie_z)
    count_files += 1
    print_duration('dicom {}/{} done'.format(count_files, allFiles))

print()
dt = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S')
print('End:', dt)