In [None]:
import os
import random
from datetime import datetime

# Timestamp that gives every run its own results folder.
now = datetime.now()

# Prefix of the results folder name (transforms and results are saved there).
prefix = "SRCtime"
# Test images processed in parallel: the main process plus up to
# (numberOfProcesses - 1) forked children, one test image each.
numberOfProcesses = 14

# Use 5-fold cross-validation if True, else leave-one-out.
Use5Fold = False


######################################################################
resultsFolder = "results/" + prefix + now.strftime("%Y_%m_%d_%H_%M_%S")
print("Results folder name:", resultsFolder)

# Images and labels are paired by identical file names in these two folders.
path = "./Dataset/data"
labelPath = "Dataset/labels"

# os.listdir() returns entries in arbitrary, filesystem-dependent order. Sort so
# that a test index (and its results sub-folder) always maps to the same subject.
data = sorted(os.listdir(path))
datasetLen = len(data)

# random shuffle data
# random.shuffle(data)


# Create the run's results folder and record the ordered data file names, so a
# test index can later be mapped back to its subject.
from pathlib import Path


try:
    Path(f"./{resultsFolder}").mkdir(parents=True, exist_ok=True)
    print("Saving data names.")
    # `with` closes the file even if a write fails.
    with open(f"./{resultsFolder}/data.txt", "w") as namesFile:
        namesFile.writelines(name + "\n" for name in data)
except OSError:
    # Best effort: the run continues without this bookkeeping file.
    print("Can't save names.")


# Fork worker processes, one per test image.
# os.fork() is POSIX-only and does not work inside a Jupyter notebook kernel:
# run this code as a plain Python script.
# The main process keeps testImageIdx = 0; child i gets testImageIdx = i and
# leaves the loop immediately, so children never fork further.
# Never start more workers than there are images (data[testImageIdx] would
# raise IndexError); set numberOfProcesses >= len(data) to run every image.
testImageIdx = 0
for i in range(1, min(numberOfProcesses, len(data))):
    n = os.fork()
    if n == 0:
        # child
        testImageIdx = i
        break

# fill trainData
if Use5Fold:
    # 5 contiguous folds; the last fold also takes the remainder images.
    # max(1, ...) avoids a ZeroDivisionError with fewer than 5 images.
    foldSize = max(1, len(data) // 5)
    fold = min(testImageIdx // foldSize, 4)
    trainData = [name for j, name in enumerate(data)
                 if min(j // foldSize, 4) != fold]
else:
    # leave one out
    trainData = [name for j, name in enumerate(data) if j != testImageIdx]

# Each test image writes into its own sub-folder, e.g. <resultsFolder>/03.
resultsFolder += f"/{testImageIdx:02d}"
try:
    Path(f"./{resultsFolder}").mkdir(parents=True, exist_ok=True)
except OSError:
    print("Can't make the results directory.")


In [None]:
# Number of label classes passed to the segmentation and to the Dice computation.
totalLabels = 2

# Pre-filtering applied to every image. UseCurvatureFlow takes precedence;
# UseGaussian applies a recursive Gaussian (sigma) followed by the same
# CurvatureFlow (timeStep, numberOfIterations).
UseCurvatureFlow = True
UseGaussian = False
sigma = 0.5
timeStep = 0.04
numberOfIterations = 10
# timeStep = 0.05
# numberOfIterations = 5

# Tests: every combination of the values below is run.
# atlasesNums: how many registered atlases (lowest post-registration MSE first)
# are used for the segmentation.
# NOTE(review): P and N are 3-element lists passed to calculateCropShape and
# cGen.applySRCSpams; presumably per-axis patch / neighbourhood sizes - TODO confirm.
lassoTols = [0.1]
atlasesNums = [9]
Ns = [[3, 3, 3]]
Ps = [[9, 9, 9]]

# lassoTols = [0.001]
# Ns = [[7, 7, 7], [5, 5, 5], [3, 3, 3]]
# Ps = [[3, 3, 3], [5, 5, 5], [7, 7, 7]]

# Cartesian product of all test parameters as (atlasesNum, P, N, lassoTol)
# tuples, in the same nesting order as before. A comprehension keeps the loop
# variables out of the global namespace, where the next cell reuses the names
# P, N and lassoTol.
inputData = [
    (atlas, P, N, lassoTol)
    for atlas in atlasesNums
    for P in Ps
    for N in Ns
    for lassoTol in lassoTols
]
                

In [None]:
import SimpleITK as sitk
import numpy as np
from utils import *
from time import time

totalStartTime = time()

# Use a single ITK thread; the forked worker processes (one per test image)
# already run in parallel.
sitk.ImageRegistrationMethod().SetGlobalDefaultNumberOfThreads(1)

# Per training image, in trainData order: registration transform and the
# post-registration error used to rank atlases.
transforms = []
mseList = []

# Read and pre-filter the fixed (test) image. UseCurvatureFlow takes precedence
# over UseGaussian; the Gaussian option is followed by CurvatureFlow as well.
pathName = os.path.join(path, data[testImageIdx])
fixedImage = sitk.ReadImage(pathName)
if UseCurvatureFlow:
    fixedImage = sitk.CurvatureFlow(fixedImage,
                                    timeStep=timeStep,
                                    numberOfIterations=numberOfIterations)
elif UseGaussian:
    smoothed = sitk.SmoothingRecursiveGaussian(fixedImage, (sigma, sigma, sigma), True)
    fixedImage = sitk.CurvatureFlow(smoothed,
                                    timeStep=timeStep,
                                    numberOfIterations=numberOfIterations)


# Register every training image (atlas) to the fixed image and score the result.
for i, f in enumerate(trainData):
    print(f'Registration of {f}. {i+1}/{len(trainData)}')

    pathName = os.path.join(path, f)

    # Pre-filter exactly like the fixed image, then histogram-match to it.
    image = sitk.ReadImage(pathName)
    if UseCurvatureFlow:
        image = sitk.CurvatureFlow(image,
                                   timeStep=timeStep,
                                   numberOfIterations=numberOfIterations)
    elif UseGaussian:
        image = sitk.CurvatureFlow(
            sitk.SmoothingRecursiveGaussian(image, (sigma, sigma, sigma), True),
            timeStep=timeStep,
            numberOfIterations=numberOfIterations)
    image = sitk.HistogramMatching(image, fixedImage)

    label = sitk.ReadImage(os.path.join(labelPath, f))

    # Alternative registration back-ends that were tried: registration,
    # registration2, registrationElastix, registrationElastix2.
    t = registrationElastixMask(fixedImage, image, label, resultsFolder)
    transforms.append(t)

    # Post-registration error between the fixed image and the resampled atlas
    # (mse3DLabels also receives the resampled label - presumably it restricts
    # the error to the labelled region; TODO confirm). Lower is better.
    mseList.append(mse3DLabels(fixedImage,
                               resampleImage(image, fixedImage, t),
                               resampleLabels(label, fixedImage, t)))

registrationEndTime = time()

try:
    with open(f"./{resultsFolder}/time.txt", "a") as timeFile:
        timeFile.write("Registration: " + str(registrationEndTime - totalStartTime) + "\n")
except OSError:
    print("Can't append time.")

# (index into trainData, error) pairs ordered by increasing error: best atlas first.
mseListSorted = sorted(enumerate(mseList), key=lambda x: x[1])

###############################################################################
# For every test
###############################################################################
# The pre-filtered fixed image is the same for every test configuration, so the
# SimpleITK image computed above is reused instead of re-reading and re-filtering
# it on every iteration. The cropped NumPy copy goes to a separate name
# (fixedArray), so the SimpleITK image is never overwritten.
fixedReference = fixedImage
fixedPath = os.path.join(path, data[testImageIdx])

for testIndex, (atlasesNum, P, N, lassoTol) in enumerate(inputData):
    startTestTime = time()

    # Cannot use more atlases than there are registered training images.
    numAtlases = min(atlasesNum, len(mseListSorted))
    if numAtlases < atlasesNum:
        print(f"Only {numAtlases} atlases available, {atlasesNum} requested.")

    # Bounding box (array index order) of the foreground of all atlas labels.
    minx = miny = minz = 10000
    maxx = maxy = maxz = 0
    images = []
    labels = []

    ###########################################################################
    # Read the best-matching atlases and labels and transform them
    ###########################################################################
    # mseListSorted is ordered by increasing error, so the first entries are
    # the atlases that matched the fixed image best after registration.
    for i in range(numAtlases):
        idx = mseListSorted[i][0]
        t = transforms[idx]
        f = trainData[idx]
        pathName = os.path.join(path, f)
        print(f'Reading {f}. {i+1}/{numAtlases}')

        # Same pre-filtering as the fixed image, histogram matching to it, then
        # resampling into the fixed image's space with the registration transform.
        moving = sitk.ReadImage(pathName)
        if UseCurvatureFlow:
            moving = sitk.CurvatureFlow(moving,
                                        timeStep=timeStep,
                                        numberOfIterations=numberOfIterations)
        elif UseGaussian:
            moving = sitk.CurvatureFlow(
                sitk.SmoothingRecursiveGaussian(moving, (sigma, sigma, sigma), True),
                timeStep=timeStep,
                numberOfIterations=numberOfIterations)
        moving = sitk.HistogramMatching(moving, fixedReference)
        image = sitk.GetArrayFromImage(resampleImage(moving, fixedReference, t))
        # Negative intensities are set to 0 before the cast to uint16.
        images.append(np.maximum(image, 0).astype(np.uint16))

        label = sitk.ReadImage(os.path.join(labelPath, f))
        label = sitk.GetArrayFromImage(resampleLabels(label, fixedReference, t))
        labels.append(label)

        # Grow the bounding box with NumPy reductions (Python's min()/max() over
        # a large index array is very slow). An empty resampled label has no
        # foreground and would make the reductions fail, so it is skipped.
        idxs = np.nonzero(label)
        if idxs[0].size:
            minx = min(minx, idxs[0].min())
            maxx = max(maxx, idxs[0].max())
            miny = min(miny, idxs[1].min())
            maxy = max(maxy, idxs[1].max())
            minz = min(minz, idxs[2].min())
            maxz = max(maxz, idxs[2].max())
        else:
            print(f"Warning: label of {f} is empty after resampling.")

    # No atlas contributed any foreground: there is nothing to crop or segment.
    if minx > maxx:
        print(f"{testIndex}: no atlas foreground available, skipping this test.")
        continue

    ###########################################################################
    # Calculate the desired shape
    ###########################################################################
    shape, copyShape, offset, length = calculateCropShape(images, P, N, minx, maxx, miny, maxy, minz, maxz)

    ###########################################################################
    # crop images and labels (and the fixed image) to that shape
    ###########################################################################
    images = [cropImage(img, shape, offset, length, copyShape, "uint16") for img in images]
    labels = [cropImage(lbl, shape, offset, length, copyShape, "uint8") for lbl in labels]
    fixedArray = cropImage(sitk.GetArrayFromImage(fixedReference),
                           shape, offset, length, copyShape, "uint16")

    ###########################################################################
    # Make to the proper type
    ###########################################################################
    print("Fixing data types.")
    images = np.array(images, order='C')
    labels = np.array(labels, order='C')
    P = np.array(P, dtype=np.int32)
    N = np.array(N, dtype=np.int32)

    segmentationStart = time()

    # Identifies this test configuration in time.txt, dice.txt and stdout.
    header = f"{testIndex}: Atlases: {atlasesNum}, P: {P}, N: {N}, lassoTol: {lassoTol}"

    try:
        with open(f"./{resultsFolder}/time.txt", "a") as timeFile:
            timeFile.write(header + "\n")
            timeFile.write("Data preparation: " + str(segmentationStart - startTestTime) + "\n")
    except OSError:
        print("Can't append time.")

    ###########################################################################
    # Perform segmentation
    ###########################################################################
    from cGen import cGen
    print("Running segmentation.")
    segmentationSRC = cGen.applySRCSpams(fixedArray, images,
                                         labels, totalLabels, P, N, lassoTol=lassoTol,
                                         lassoMaxIter=1e4, verboseX=True, verboseY=False,
                                         xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=-1, zmax=-1,
                                         numThreads=1, lassoL=-1)

    try:
        with open(f"./{resultsFolder}/time.txt", "a") as timeFile:
            timeFile.write("Segmentation: " + str(time() - segmentationStart) + "\n")
    except OSError:
        print("Can't append time.")

    ###########################################################################
    # Save segmentation results
    ###########################################################################
    print("Saving results.")
    try:
        saveSegmentation(segmentationSRC,
                         fixedPath,
                         os.path.join(resultsFolder, data[testImageIdx]) + f"SRC_{testIndex}.mhd",
                         copyShape,
                         offset,
                         length,
                         verbose=False)
    except Exception as e:
        # Best effort: the Dice score below is still computed and logged.
        print("Can't save SRC result.", e)

    ###########################################################################
    # Save dice index
    ###########################################################################
    originalLabel = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(labelPath, data[testImageIdx])))

    # Undo the crop so the segmentation is comparable with the original label.
    translatedSegmentationSRC = translateToOriginal(segmentationSRC,
                                                    fixedPath,
                                                    copyShape,
                                                    offset,
                                                    length)

    diceSRC = Dice(originalLabel, translatedSegmentationSRC, totalLabels)

    print(header)
    print("SRC:", diceSRC)

    try:
        with open(f"./{resultsFolder}/dice.txt", "a") as diceFile:
            diceFile.write(header + "\n")
            diceFile.write("SRC: " + str(diceSRC) + "\n")
    except OSError:
        print("Can't append dice results.")

# Total wall-clock time of this worker process (registration plus all tests).
try:
    with open(f"./{resultsFolder}/time.txt", "a") as timeFile:
        timeFile.write("Total time: " + str(time() - totalStartTime) + "\n")
except OSError:
    print("Can't append time.")

In [None]:
####################################################
# END OF NORMAL CODE
####################################################
import sys

# Wait for (reap) every worker process this process forked, so the script only
# finishes once all test images are done. Forked children never fork further,
# so for them (and for a run without forking) os.wait() raises
# ChildProcessError immediately.
while True:
    try:
        os.wait()
    except ChildProcessError:
        break

# Stop here: nothing after this point is part of the normal run.
sys.exit()