In [None]:
import os
import random
from datetime import datetime

# Timestamp that gives every run its own results folder.
now = datetime.now()

# Folder name prefix for the saved registration transforms / results.
prefix = ""

# True: 5-fold cross-validation; False: leave-one-out.
Use5Fold = False


######################################################################
resultsFolder = "results/" + prefix + now.strftime("%Y_%m_%d_%H_%M_%S")
print("Results folder name:", resultsFolder)

path = "./Dataset/data"
labelPath = "Dataset/labels"

data = os.listdir(path)
datasetLen = len(data)

# Random shuffle so the split does not follow directory order.
# NOTE(review): unseeded; the resulting order is recorded in data.txt below.
random.shuffle(data)


from pathlib import Path


# Save the shuffled order so the split can be reproduced later.
try:
    Path(f"./{resultsFolder}").mkdir(parents=True, exist_ok=True)
    print("Saving data names.")
    with open(f"./{resultsFolder}/data.txt", "w") as file:
        file.writelines(name + "\n" for name in data)
except OSError as err:
    print("Can't save names.", err)


# Fork one process per test image: the parent keeps index 0 and child i
# (1..len(data)-1) takes index i. os.fork is POSIX-only and does not work in a
# Jupyter kernel, so run this file as a plain Python script.
testImageIdx = 0
for i in range(1, len(data)):
    n = os.fork()
    if n == 0:
        # child
        testImageIdx = i
        break


def _foldOf(idx, numItems, numFolds=5):
    """Return the fold (0..numFolds-1) of position ``idx`` among ``numItems`` items.

    Folds are contiguous chunks of ``numItems // numFolds`` items and the
    remainder is added to the last fold. The chunk size is clamped to 1 so a
    dataset with fewer than ``numFolds`` items does not divide by zero.
    """
    foldSize = max(1, numItems // numFolds)
    return min(idx // foldSize, numFolds - 1)


# fill trainData
if Use5Fold:
    fold = _foldOf(testImageIdx, len(data))
    trainData = [name for j, name in enumerate(data) if _foldOf(j, len(data)) != fold]
else:
    # leave one out
    trainData = [name for j, name in enumerate(data) if j != testImageIdx]


resultsFolder += f"/{testImageIdx:02d}"
try:
    Path(f"./{resultsFolder}").mkdir(parents=True, exist_ok=True)
except OSError as err:
    print("Can't make the results directory.", err)


In [None]:
# Number of label classes passed to the segmentation.
totalLabels = 2

# Pre-processing switches (CurvatureFlow is checked before Gaussian).
UseCurvatureFlow = True
UseGaussian = False
sigma = 0.5
timeStep = 0.04
numberOfIterations = 10
# timeStep = 0.05
# numberOfIterations = 5

# Tests: parameter values to sweep
lassoTols = [0.1, 0.01, 0.001, 0.0001]
atlasesNums = [4]
Ns = [[5, 5, 5]]
Ps = [[3, 3, 3]]

# lassoTols = [0.001]
# Ns = [[7, 7, 7], [5, 5, 5], [3, 3, 3]]
# Ps = [[3, 3, 3], [5, 5, 5], [7, 7, 7]]

# Full grid ordered atlas -> P -> N -> lassoTol (lassoTol varies fastest).
inputData = [
    (atlas, P, N, lassoTol)
    for atlas in atlasesNums
    for P in Ps
    for N in Ns
    for lassoTol in lassoTols
]
                

In [None]:
import SimpleITK as sitk
import numpy as np
from utils import *

# Set the number of threads: one per process, because the first cell forks one
# process per test image.
sitk.ImageRegistrationMethod().SetGlobalDefaultNumberOfThreads(1)

# Imported up front so a missing extension fails before the slow registration.
from cGen import cGen


def _readSmoothed(imagePath):
    """Read an image and apply the pre-processing selected in the config cell.

    * ``UseCurvatureFlow``: CurvatureFlow(timeStep, numberOfIterations).
    * else ``UseGaussian``: SmoothingRecursiveGaussian(sigma) then CurvatureFlow.
    * else: the image exactly as read.
    """
    image = sitk.ReadImage(imagePath)
    if UseCurvatureFlow:
        return sitk.CurvatureFlow(image, timeStep=timeStep,
                                  numberOfIterations=numberOfIterations)
    if UseGaussian:
        smoothed = sitk.SmoothingRecursiveGaussian(image, (sigma, sigma, sigma), True)
        return sitk.CurvatureFlow(smoothed, timeStep=timeStep,
                                  numberOfIterations=numberOfIterations)
    return image


def _readMatched(imagePath, reference):
    """Read and pre-process an atlas image, then histogram-match it to ``reference``."""
    return sitk.HistogramMatching(_readSmoothed(imagePath), reference)


testName = data[testImageIdx]
testPath = os.path.join(path, testName)

###############################################################################
# Register every training image to the test (fixed) image
###############################################################################
transforms = []
mseList = []

fixedImage = _readSmoothed(testPath)

for i, f in enumerate(trainData):
    print(f'Registration of {f}. {i+1}/{len(trainData)}')

    image = _readMatched(os.path.join(path, f), fixedImage)
    label = sitk.ReadImage(os.path.join(labelPath, f))

    # Other variants available in utils: registration, registration2,
    # registrationElastix, registrationElastix2.
    t = registrationElastixMask(fixedImage, image, label, resultsFolder)
    transforms.append(t)

    # Score used to rank the atlases (sorted ascending below, best first).
    mseList.append(mse3DLabels(fixedImage,
                               resampleImage(image, fixedImage, t),
                               resampleLabels(label, fixedImage, t)))

# (trainData index, mse) pairs, lowest mse first.
mseListSorted = sorted(enumerate(mseList), key=lambda x: x[1])

# NOTE: the former `P = np.argmax(Ps, axis=0)[0]` / `N = ...` lines were dropped:
# they produced an index rather than the largest patch size and were
# overwritten by the test loop before any use.

###############################################################################
# For every test
###############################################################################
for testIndex, (atlasesNum, P, N, lassoTol) in enumerate(inputData):
    # The previous iteration replaced fixedImage with a cropped numpy array,
    # so read it again as a SimpleITK image.
    fixedImage = _readSmoothed(testPath)

    ###########################################################################
    # Resample the best `atlasesNum` atlases and labels onto the fixed grid and
    # track the bounding box of all non-zero label voxels.
    ###########################################################################
    minx = miny = minz = 10000
    maxx = maxy = maxz = 0
    images = []
    labels = []
    for i in range(atlasesNum):
        idx = mseListSorted[i][0]
        t = transforms[idx]
        f = trainData[idx]
        print(f'Reading {f}. {i+1}/{atlasesNum}')

        atlas = sitk.GetArrayFromImage(
            resampleImage(_readMatched(os.path.join(path, f), fixedImage), fixedImage, t))
        # Clamp negative intensities before the uint16 conversion.
        atlas[atlas < 0] = 0
        images.append(np.asarray(atlas, dtype=np.uint16))

        label = sitk.ReadImage(os.path.join(labelPath, f))
        label = sitk.GetArrayFromImage(resampleLabels(label, fixedImage, t))
        labels.append(label)

        idxs = np.nonzero(label != 0)
        if idxs[0].size == 0:
            # min()/max() of an empty index array would raise ValueError.
            print(f"Warning: label of {f} is empty after resampling; ignored for the crop box.")
            continue
        minx = min(minx, idxs[0].min())
        maxx = max(maxx, idxs[0].max())
        miny = min(miny, idxs[1].min())
        maxy = max(maxy, idxs[1].max())
        minz = min(minz, idxs[2].min())
        maxz = max(maxz, idxs[2].max())

    if minx > maxx:
        raise RuntimeError("All selected atlas labels are empty after resampling; "
                           "cannot compute the crop box.")

    ###########################################################################
    # Calculate the desired shape and crop images, labels and the fixed image
    ###########################################################################
    shape, copyShape, offset, length = calculateCropShape(images, P, N, minx, maxx, miny, maxy, minz, maxz)

    for i in range(len(images)):
        images[i] = cropImage(images[i], shape, offset, length, copyShape, "uint16")
        labels[i] = cropImage(labels[i], shape, offset, length, copyShape, "uint8")

    fixedImage = sitk.GetArrayFromImage(fixedImage)
    fixedImage = cropImage(fixedImage, shape, offset, length, copyShape, "uint16")

    ###########################################################################
    # Convert to the array layouts / dtypes cGen expects
    ###########################################################################
    print("Fixing data types.")
    images = np.array(images, order='C')
    labels = np.array(labels, order='C')
    P = np.array(P, dtype=np.int32)
    N = np.array(N, dtype=np.int32)

    ###########################################################################
    # Perform segmentation
    ###########################################################################
    print("Running segmentation.")
    segmentationSPBM, segmentationSRC = cGen.applySPBMandSRCSpams(fixedImage, images,
                                  labels, totalLabels, P, N, lassoTol=lassoTol,
                                  lassoMaxIter=1e4, verboseX=True,  verboseY=False,
                                  xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=-1, zmax=-1,
                                  # xmin=10, xmax=11, ymin=10, ymax=11, zmin=10, zmax=11,
                                  numThreads=1, lassoL=-1)

    ###########################################################################
    # Save segmentation results (a failed save is reported, not fatal)
    ###########################################################################
    print("Saving results.")
    for methodName, segmentation in (("SPBM", segmentationSPBM), ("SRC", segmentationSRC)):
        try:
            saveSegmentation(segmentation,
                             testPath,
                             os.path.join(resultsFolder, testName) + f"{methodName}_{testIndex}.mhd",
                             copyShape,
                             offset,
                             length,
                             verbose=False)
        except Exception as err:
            print(f"Can't save {methodName} result.", err)

    ###########################################################################
    # Dice against the ground-truth label of the test image
    ###########################################################################
    originalLabel = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(labelPath, testName)))

    translatedSegmentationSPBM = translateToOriginal(segmentationSPBM, testPath,
                                                     copyShape, offset, length)
    diceSPBM = Dice(originalLabel, translatedSegmentationSPBM, totalLabels)

    translatedSegmentationSRC = translateToOriginal(segmentationSRC, testPath,
                                                    copyShape, offset, length)
    diceSRC = Dice(originalLabel, translatedSegmentationSRC, totalLabels)

    summary = f"{testIndex}: Atlases: {atlasesNum}, P: {P}, N: {N}, lassoTol: {lassoTol}"
    print(summary)
    print("SPBM:", diceSPBM)
    print("SRC:", diceSRC)

    try:
        with open(f"./{resultsFolder}/dice.txt", "a") as diceFile:
            diceFile.write(summary + "\n")
            diceFile.write("SPBM: " + str(diceSPBM) + "\n")
            diceFile.write("SRC: " + str(diceSRC) + "\n")
    except OSError as err:
        print("Can't append dice results.", err)

In [None]:
####################################################
# END OF NORMAL CODE 
####################################################
# Stop here when run as a script: every forked process exits at this point, so
# the exploratory cells below only run interactively. sys.exit() raises
# SystemExit.
import sys
sys.exit()

In [None]:

# Scratch: cast one scan to uint16, keeping its geometry via CopyInformation,
# and print the pixel type id before and after. The write-back is disabled.
# NOTE(review): astype(np.uint16) wraps negative intensities — confirm none exist.
pathName = os.path.join(path, "9036287.nii.gz")
image = sitk.ReadImage(pathName)
print(image.GetPixelID())
test = sitk.GetArrayFromImage(image)
test = test.astype(np.uint16)

newImage = sitk.GetImageFromArray(test)
newImage.CopyInformation(image)
print(newImage.GetPixelID())
pathName = os.path.join(path,  "9036287.nii.gz")
# sitk.WriteImage(newImage, pathName, True)

In [None]:
import numpy as np

In [None]:
# Display the voxels of `image` as a numpy array.
sitk.GetArrayFromImage(image)

In [None]:

# Mean squared difference between `image` and `fixedImage` over voxels where
# `label` > 0. Cast to float64 first: with unsigned pixel types the raw
# difference would wrap around instead of going negative.
diff = (sitk.GetArrayFromImage(image).astype(np.float64)
        - sitk.GetArrayFromImage(fixedImage).astype(np.float64))
(diff[sitk.GetArrayFromImage(label) > 0] ** 2).mean()

In [None]:
# Mean squared difference over the whole volume; float64 avoids unsigned
# wrap-around on subtraction.
((sitk.GetArrayFromImage(image).astype(np.float64)
  - sitk.GetArrayFromImage(fixedImage).astype(np.float64)) ** 2).mean()


In [None]:
# Binary mask from the label (all labels > 1 folded to 1), then the masked sum
# of squared differences computed two ways that should agree. Intensities are
# cast to float64 so unsigned pixel types cannot wrap around.
mask = sitk.GetArrayFromImage(label)
mask[mask > 1] = 1
imageArr = sitk.GetArrayFromImage(image).astype(np.float64)
fixedArr = sitk.GetArrayFromImage(fixedImage).astype(np.float64)
print(((mask * (imageArr - fixedArr))**2).sum())
((mask * imageArr - mask * fixedArr)**2).sum()

In [None]:

# List the experiment grid: index, atlases, P, N, lasso tolerance.
for testIndex, (atlasesNum, P, N, lassoTol) in enumerate(inputData):
    print(f"{testIndex} {atlasesNum} {P} {N} {lassoTol}")

In [None]:
from PIL import Image
import numpy as np

def showImg_(img, z=60):
    """Show one 2-D slice of a SimpleITK image, min-max rescaled to 0..255.

    NOTE(review): SimpleITK images are indexed [x, y, z], so ``img[z, :, :]``
    slices along the first axis despite the parameter name — confirm intent.
    """
    a = sitk.GetArrayFromImage(img[z, :, :])
    disImg = Image.fromarray(np.interp(a, (a.min(), a.max()), (0, 255)).astype('uint8'))
    disImg.show()


def showImg2_(img, z=60):
    """Show one slice with raw values cast to uint8 (values above 255 wrap)."""
    disImg = Image.fromarray(sitk.GetArrayFromImage(img[z, :, :]).astype('uint8'))
    disImg.show()


def showImg3_(img, z=60):
    """Show one slice with the fixed intensity window 0..5000 mapped to 0..255."""
    a = sitk.GetArrayFromImage(img[z, :, :])
    # PIL's Image, not the module-level SimpleITK `image` the old code referenced.
    disImg = Image.fromarray(np.interp(a, (0, 5000), (0, 255)).astype('uint8'))
    disImg.show()


def showImg4_(img, z=60):
    """Show a label slice with each label value multiplied by 60 for contrast."""
    scaled = sitk.GetArrayFromImage(img[z, :, :]) * 60
    # Mode 'L' expects 8-bit data; cast so non-uint8 label arrays are not misread.
    img2 = Image.fromarray(scaled.astype(np.uint8), 'L')
    img2.show()

In [None]:
# Shape of the first atlas label array.
labels[0].shape

In [None]:
# Slice 20 of the fixed image, the first atlas and its label. After the main
# cell these are numpy arrays, hence the GetImageFromArray wrapping.
showImg2_(sitk.GetImageFromArray(fixedImage), 20)
showImg2_(sitk.GetImageFromArray(images[0]), 20)
showImg_(sitk.GetImageFromArray(labels[0]), 20)
# showImg_(sitk.GetImageFromArray(segmentationSPBM), 20)

In [None]:
# Slice 60 of the fixed image, the first atlas and the SPBM segmentation.
showImg2_(sitk.GetImageFromArray(fixedImage), 60)
showImg2_(sitk.GetImageFromArray(images[0]), 60)
showImg_(sitk.GetImageFromArray(segmentationSPBM), 60)
# showImg2_(images[1], 60)

In [None]:
# Non-zero voxels of the SPBM segmentation.
segmentationSPBM[segmentationSPBM > 0]

In [None]:
# Zero out label value 4 and display slice 100 of the remaining labels.
# NOTE(review): requires `label` to be a SimpleITK image; after the main
# pipeline cell it is a numpy array — confirm which cell defined it.
l = sitk.GetArrayFromImage(label)
l[l == 4] = 0
mask = sitk.GetImageFromArray(l)

showImg4_(mask, 100)


In [None]:
# Query the global default thread count. The setter requires an argument, so
# the previous argument-less SetGlobalDefaultNumberOfThreads() call failed.
sitk.ImageRegistrationMethod().GetGlobalDefaultNumberOfThreads()

In [None]:
# API documentation of ElastixImageFilter.
help(sitk.ElastixImageFilter())
# sitk.ElastixImageFilter().GetOutputDirectory()

In [None]:
# `elastixImageFilter` is only created inside a disabled string block further
# down, so ask for help on the class, which always exists.
help(sitk.ElastixImageFilter)

In [None]:
import SimpleITK as sitk
import os
from utils import *

path = "./Dataset/data"
labelPath = "Dataset/labels"

data = os.listdir(path)
datasetLen = len(data)

# split dataset: 70 % train / 30 % test, in directory order (no shuffling here)
trainData = data[:int(datasetLen * 0.7)]
testData = data[int(datasetLen * 0.7):]

from pathlib import Path

timeStep = 0.04
numberOfIterations = 10

# Fixed image: first test scan, denoised with CurvatureFlow.
pathName = os.path.join(path, testData[0])
# fixedImage = sitk.ReadImage(pathName)
fixedImage = sitk.CurvatureFlow(sitk.ReadImage(pathName), timeStep = timeStep,
                                numberOfIterations = numberOfIterations)


# Debug run on the single training image at index 8 (empty if fewer than 9 images).
debugData = trainData[8:9]
for i, f in enumerate(debugData):
    print(f'Registration of {f}. {i+1}/{len(debugData)}')

    pathName = os.path.join(path, f)

    # Denoise the moving image like the fixed one, then match its histogram.
    # image = sitk.HistogramMatching(sitk.ReadImage(pathName), fixedImage)
    image = sitk.HistogramMatching(sitk.CurvatureFlow(sitk.ReadImage(pathName),
                                                      timeStep = timeStep,
                                                      numberOfIterations = numberOfIterations),
                                   fixedImage)

    label = sitk.ReadImage(os.path.join(labelPath, f))

    # t = registrationElastixMask(fixedImage, image, label)
    # t2 = registration(fixedImage, image, label)

    # Resample onto the fixed image grid (identity transform while `t` is
    # disabled). The output size comes from SetReferenceImage; the former
    # argument-less resample.SetSize() call raised an error.
    resample = sitk.ResampleImageFilter()
    resample.SetReferenceImage(fixedImage)
    resample.SetInterpolator(sitk.sitkLinear)
    # resample.SetTransform(t)

    newImage = resample.Execute(image)


In [None]:

# Downsample `image` by 2 per axis: half the voxel count and double the
# spacing, with origin/direction taken from fixedImage.
halfSize = [extent // 2 for extent in image.GetSize()]
doubleSpacing = [step * 2 for step in image.GetSpacing()]

resample = sitk.ResampleImageFilter()
resample.SetReferenceImage(fixedImage)
resample.SetInterpolator(sitk.sitkLinear)
# resample.SetTransform(t)
resample.SetSize(halfSize)
resample.SetOutputSpacing(doubleSpacing)

newImage = resample.Execute(image)

In [None]:
# Upsample the half-resolution `newImage` back to full resolution so it can be
# compared with the original `image` (round-trip check).
resample = sitk.ResampleImageFilter()
resample.SetReferenceImage(newImage)

resample.SetInterpolator(sitk.sitkLinear)
# resample.SetTransform(t)
resample.SetSize([i * 2 for i in list(newImage.GetSize())])
resample.SetOutputSpacing([i / 2 for i in list(newImage.GetSpacing())])

# Resample the downsampled image; resampling `image` here made the round trip
# a near no-op.
image2 = resample.Execute(newImage)

In [None]:
# Shrink/expand round trip with SimpleITK filters.
# NOTE(review): with a factor of 1 both filters leave the grid unchanged —
# presumably a placeholder for a real factor such as 2.
shrinkFilter = sitk.ShrinkImageFilter()
shrinkFilter.SetShrinkFactor(1)
newImage2 = shrinkFilter.Execute(image)


expandFilter = sitk.ExpandImageFilter()
expandFilter.SetExpandFactor(1)
expandFilter.SetInterpolator(sitk.sitkLinear)
image2 = expandFilter.Execute(newImage2)


In [None]:
# Size of the shrunk image and a look at slice 30.
print(newImage2.GetSize())

showImg_(newImage2, 30)

In [None]:
# Half of each image dimension (integer division).
size = [extent // 2 for extent in image.GetSize()]
print(size)

In [None]:
# Compare the original image with the round-tripped one.
showImg_(image)
showImg_(image2)

In [None]:
# Pixel type id of the fixed image.
fixedImage.GetPixelID()

In [None]:
import SimpleITK as sitk
# API documentation of the SimpleITK image object.
help(image)

In [None]:
# Inspect the default affine parameter map and the filter's current maps.
# NOTE(review): `elastixImageFilter` is only created inside a disabled string
# block in the registration scratch cell below; needs a kernel where it ran.
help(sitk.GetDefaultParameterMap("affine"))
elastixImageFilter.PrintParameterMap()

In [None]:
import SimpleITK as sitk
import os
from utils import *

# Scratch: register one training image to the first test image with both
# registrationElastixMask and registration, and view the results.
path = "./Dataset/data"
labelPath = "Dataset/labels"

data = os.listdir(path)
datasetLen = len(data)

# split dataset: 70 % train / 30 % test, in directory order
trainData = data[:int(datasetLen * 0.7)]
testData = data[int(datasetLen * 0.7):]

"""Write transforms."""
from pathlib import Path

timeStep = 0.04
numberOfIterations = 10

pathName = os.path.join(path, testData[0])
# fixedImage = sitk.ReadImage(pathName)

# fixedImage = sitk.CurvatureFlow(sitk.ReadImage(pathName), timeStep = 0.04, numberOfIterations = 10)
fixedImage = sitk.CurvatureFlow(sitk.ReadImage(pathName), timeStep = timeStep,
                                numberOfIterations = numberOfIterations)


# NOTE(review): iterates one image but the progress message divides by len(trainData).
for i, f in enumerate(trainData[8:9]): # for test
# for i, f in enumerate(trainData):
    # print(f'Registration of {f}. {i+1}/{len(trainData)}', end='\x1b[1K\r')
    print(f'Registration of {f}. {i+1}/{len(trainData)}')

    pathName = os.path.join(path, f)

    # image = sitk.ReadImage(pathName)

    # image = sitk.HistogramMatching(sitk.ReadImage(pathName), fixedImage)
    image = sitk.HistogramMatching(sitk.CurvatureFlow(sitk.ReadImage(pathName), 
                                                      # timeStep = 0.04, 
                                                      # numberOfIterations = 10),
                                                      timeStep = timeStep, 
                                                      numberOfIterations = numberOfIterations),
                                   fixedImage)

    label = sitk.ReadImage(os.path.join(labelPath, f))

    # NOTE(review): the main pipeline passes resultsFolder as a 4th argument;
    # this 3-argument call assumes utils provides a default — confirm.
    t = registrationElastixMask(fixedImage, image, label)
    t2 = registration(fixedImage, image, label)

    showImg_(resampleImage(image, fixedImage, t))
    showImg_(resampleImage(image, fixedImage, t2))
    showImg_(fixedImage)

# Disabled experiment kept as a string literal: masked affine elastix
# registration. Later cells use `elastixImageFilter`, which only exists if this
# code was run manually in the kernel.
'''
# showImg_(fixedImage)
# showImg_(image)

# label = sitk.ReadImage(os.path.join(labelPath, f))
# Create the mask
l = sitk.GetArrayFromImage(label)
l[l>1] = 1
mask = sitk.GetImageFromArray(l)
mask.CopyInformation(label)
    
elastixImageFilter = sitk.ElastixImageFilter()
elastixImageFilter.SetFixedImage(fixedImage)
elastixImageFilter.SetMovingImage(image)
elastixImageFilter.AddMovingMask(mask)
elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap("affine"))
elastixImageFilter.SetParameter("Metric", "AdvancedMeanSquares")

# elastixImageFilter.SetParameter("MaximumNumberOfIterations", "8192")
elastixImageFilter.SetParameter("MaximumNumberOfIterations", "2048")

elastixImageFilter.SetParameter("MaximumNumberOfSamplingAttempts", "32")
elastixImageFilter.SetParameter("NumberOfSamplesForExactGradient", "8192")
elastixImageFilter.SetParameter("NumberOfSpatialSamples", "8192")
elastixImageFilter.SetParameter("SamplingPercentage", "0.8")

elastixImageFilter.SetParameter("NumberOfJacobianMeasurements", "10000")
# elastixImageFilter.SetParameter("SelfHessianSmoothingSigma", "0")
elastixImageFilter.SetParameter("UseNormalization", "true")

elastixImageFilter.SetParameter("AutomaticTransformInitialization", "true")
elastixImageFilter.SetParameter("ErodeMovingMask", "false")
elastixImageFilter.SetParameter("CheckNumberOfSamples", "false")
elastixImageFilter.SetParameter("WriteResultImage", "false")
elastixImageFilter.LogToFileOff()

# elastixImageFilter.SetOutputDirectory("AHE")
# elastixImageFilter.SetParameter("RandomSparseMask", "true")

"""
elastixImageFilter.SetParameter("FixedImagePyramid", "FixedRecursiveImagePyramid")
elastixImageFilter.SetParameter("MovingImagePyramid", "MovingRecursiveImagePyramid")
elastixImageFilter.SetParameter("ImagePyramidSchedule", "8 8 8  4 4 4  2 2 2  1 1 1")
(FixedImagePyramid "FixedRecursiveImagePyramid")
(MovingImagePyramid "MovingRecursiveImagePyramid")
(ImagePyramidSchedule 8 8 8  4 4 4  2 2 2  1 1 1)
"""

# elastixImageFilter.SetParameter("NumberOfBandStructureSamples", "100")
# elastixImageFilter.SetParameter("ImagePyramidSchedule", "[[1,1,1],[1,1,1],[1,1,1],[1,1,1]]")
# elastixImageFilter.SetParameter("ImagePyramidSchedule", "1")
# elastixImageFilter.SetParameter("AutomaticTransformInitializationMethod", "CenterOfGravity")
# elastixImageFilter.SetParameter("ImageSampler", "RandomSparseMask")
# elastixImageFilter.SetParameter("AutomaticScalesEstimation", "false")

elastixImageFilter.Execute()


showImg_(elastixImageFilter.GetResultImage())
showImg_(fixedImage)

'''


In [None]:
# resultImage2 = elastixImageFilter.GetResultImage()
# Keep elastix's transform parameter maps for reuse below.
# NOTE(review): `elastixImageFilter` only exists if the disabled string block
# of the registration scratch cell was run manually.
transformParameterMap = elastixImageFilter.GetTransformParameterMap()

In [None]:
# Image produced by the elastix registration.
showImg_(elastixImageFilter.GetResultImage())

In [None]:
# First transform parameter map as a plain dict.
print(transformParameterMap[0].asdict())

In [None]:
# Rebuild elastix's affine result as a native SimpleITK AffineTransform and
# apply it to the label (nearest neighbour) and the image (linear).
af = sitk.AffineTransform(3)
affineParams = transformParameterMap[0].asdict()
# The first 9 parameters are used as the 3x3 matrix, the rest as translation.
test = [float(value) for value in affineParams['TransformParameters']]
af.SetTranslation(test[9:])
af.SetMatrix(test[0:9])
test = [float(value) for value in affineParams['CenterOfRotationPoint']]
af.SetCenter(test)


def _affineResampler(interpolator):
    """Return a ResampleImageFilter onto the fixedImage grid using `af`."""
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixedImage)
    resampler.SetInterpolator(interpolator)
    resampler.SetTransform(af)
    return resampler


# Labels must not blend between classes, hence nearest neighbour.
resample = _affineResampler(sitk.sitkNearestNeighbor)
testL = resample.Execute(label)

# from utils import *
# testL = resampleLabels(label, fixedImage, af)

resample = _affineResampler(sitk.sitkLinear)
testI = resample.Execute(image)

print(resample.GetTransform())

In [None]:
# help(sitk.ResampleImageFilter())
# Inspect the reconstructed affine transform.
print(af.GetMatrix())
print(af.GetCenter())
print(af.GetTranslation())
help(af)

In [None]:
# Visual check: elastix result next to the fixed image.
showImg_(elastixImageFilter.GetResultImage())
# showImg_(image)
showImg_(fixedImage)
# showImg_(testL)
# showImg_(testI)
# showImg_(label)
# TODO: maybe use a deformation field?

In [None]:
# Voxels of the nearest-neighbour resampled label.
sitk.GetArrayFromImage(testL)

In [None]:
# Warp the atlas label with the transform found by elastix (transformix).
transformixImageFilter = sitk.TransformixImageFilter()
# Load the maps first: SetTransformParameterMap replaces all maps, so a
# parameter set before it (as the binary-format flag was) is discarded.
transformixImageFilter.SetTransformParameterMap(elastixImageFilter.GetTransformParameterMap())
transformixImageFilter.SetTransformParameter("UseBinaryFormatForTransformationParameters", "true")
transformixImageFilter.SetMovingImage(label)
# Transformix resamples with "ResampleInterpolator"; "Interpolator" is the
# elastix optimisation interpolator. Nearest neighbour keeps label values intact.
transformixImageFilter.SetTransformParameter("ResampleInterpolator", "FinalNearestNeighborInterpolator")
transformixImageFilter.SetTransformParameter("ResultImagePixelType", "double")
transformixImageFilter.SetTransformParameter("MovingInternalImagePixelType", "double")
transformixImageFilter.SetTransformParameter("FixedInternalImagePixelType", "double")
# transformixImageFilter.SetTransformParameter("FixedInternalImagePixelType", "sitk.sitkInt8")
lel = transformixImageFilter.Execute()
showImg_(lel)

In [None]:
# Transform label map using the deformation field from above
# NOTE(review): nothing here forces nearest-neighbour resampling, and the Int8
# cast truncates; the unique values shown below reveal any non-label values.
resultLabel = sitk.Cast(sitk.Transformix(label, elastixImageFilter.GetTransformParameterMap()), sitk.sitkInt8)
showImg_(resultLabel)
np.unique(sitk.GetArrayFromImage(resultLabel))

In [None]:
# Numeric value of the sitkInt8 pixel-type constant.
sitk.sitkInt8

In [None]:
# Distinct values in the transformix label output.
np.unique(sitk.GetArrayFromImage(lel))

In [None]:
# Print the transformix parameter maps.
transformixImageFilter.PrintParameterMap()

In [None]:
# Which optional transformix outputs are enabled.
print(transformixImageFilter.GetComputeSpatialJacobian())
print(transformixImageFilter.GetComputeDeterminantOfSpatialJacobian())
print(transformixImageFilter.GetComputeDeformationField())

In [None]:
# API documentation of the procedural Transformix function.
help(sitk.Transformix)

In [None]:
# API documentation of TransformixImageFilter.
help(sitk.TransformixImageFilter())

In [None]:
# API documentation of ElastixImageFilter.
help(sitk.ElastixImageFilter())

In [None]:
# Try loading the stored transform maps into a fresh elastix filter.
# NOTE(review): transformParameterMap holds all maps from
# GetTransformParameterMap(); AddParameterMap presumably expects a single map
# (e.g. transformParameterMap[0]) — confirm.
test = sitk.ElastixImageFilter()
test.AddParameterMap(transformParameterMap)

In [None]:
# Print the registration parameter maps elastix used.
sitk.PrintParameterMap(elastixImageFilter.GetParameterMap())

In [None]:
# API documentation of the ElastixImageFilter class.
help(sitk.ElastixImageFilter)

In [None]:
# Raw transform parameter maps.
print(transformParameterMap)

In [None]:
# Slice 60 of fixed image, first atlas and a result image (raw uint8 cast).
# NOTE(review): `resultImage` is not defined anywhere in this file, and after
# the main cell fixedImage/images are numpy arrays while showImg2_ indexes a
# SimpleITK image — wrap with sitk.GetImageFromArray if needed.
showImg2_(fixedImage, 60)
showImg2_(images[0], 60)
showImg2_(resultImage, 60)

In [None]:
# Same comparison with min-max scaling (second atlas).
# NOTE(review): `resultImage` is undefined in this file.
showImg_(fixedImage, 60)
showImg_(images[1], 60)
showImg_(resultImage, 60)

In [None]:
# API documentation of GetImageFromArray.
help(sitk.GetImageFromArray)

In [None]:

# Re-run segmentation on the current arrays with uint8 labels and 3 label
# classes at a tighter lasso tolerance.
labels = np.array(labels, order='C', dtype=np.uint8)

segmentationSPBM, segmentationSRC = cGen.applySPBMandSRCSpams(fixedImage, images, 
                              labels, 3, P, N, lassoTol=0.00001, 
                              lassoMaxIter=1e4, verboseX=True,  verboseY=False, 
                              xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=-1, zmax=-1)

In [None]:
# Sanity run with every label voxel set to 51 (ones + 50).
segmentationSPBM, segmentationSRC = cGen.applySPBMandSRCSpams(fixedImage,
                                                         images, 
                              np.ones(labels.shape, dtype=np.uint8, order='C')+50, 3, P, N, lassoTol=0.00001, 
                              lassoMaxIter=1e4, verboseX=True,  verboseY=False, 
                              xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=-1, zmax=-1)

In [None]:
# Shape, dtype and container type of each array handed to cGen.
for arr in (fixedImage, images, labels):
    print(arr.shape)
    print(arr.dtype)
    print(type(arr))

In [None]:
from cGen import cGen
import numpy as np
import segmentation


# Cross-check the C extension (cGen) against the Python module (segmentation):
# build A, B and L from the same random volumes and print the index of every
# element where the two implementations disagree.
P = [5, 5, 5]
N = [5, 5, 5]
P = np.array(P, dtype=np.int32)
N = np.array(N, dtype=np.int32)

# Fixed seed so any reported mismatch can be reproduced.
np.random.seed(0)
ii = np.array(np.random.randint(0, high=500, size=(390, 276, 129)), dtype=np.uint16, order='C')
i = np.array(np.random.randint(0, high=500, size=(5, 390, 276, 129)), dtype=np.uint16, order='C')
iii = np.array(np.random.randint(0, high=3, size=(5, 390, 276, 129)), dtype=np.uint8, order='C')
v = [50,50,50]
v = np.array(v, dtype=np.int32)

a = cGen.createA(i, v, P, N)
aa = segmentation.createA(i, v, P, N)

b = cGen.createB(ii, v, P)
bb = segmentation.createB(ii, v, P)

l = cGen.createL(iii, v, N)
ll = segmentation.createL(iii, v, N)


def _printMismatches(ours, reference, tag=None):
    """Print the (optionally tagged) index of each element where the arrays differ.

    Vectorised replacement for element-by-element Python loops; indices come
    out in the same row-major order. A shape mismatch is reported once instead
    of failing part-way through.
    """
    ours = np.asarray(ours)
    reference = np.asarray(reference)
    if ours.shape != reference.shape:
        print(tag, "shape mismatch:", ours.shape, "vs", reference.shape)
        return
    prefix = () if tag is None else (tag,)
    for index in np.argwhere(ours != reference):
        print(*prefix, *index)


_printMismatches(l, ll, "l")
_printMismatches(b, bb, "b")
_printMismatches(a, aa)

In [None]:
import os
# Set environment variables
# NOTE(review): OPENBLAS_NUM_THREADS only takes effect if OpenBLAS has not been
# loaded yet in this process; earlier cells already import numpy.
os.environ['OPENBLAS_NUM_THREADS'] = '1'
#export OPENBLAS_NUM_THREADS=1

from cGen import cGen
import numpy as np

# Only the last P/N assignments are used.
P = [7, 7, 7]
N = [7, 7, 7]
P = [3, 3, 3]
N = [5, 5, 5]
P = np.array(P, dtype=np.int32)
N = np.array(N, dtype=np.int32)

# Random fixed image, 5 random atlases and 3-class labels (unseeded).
ii = np.array(np.random.randint(0, high=500, size=(390, 276, 129)), dtype=np.uint16, order='C')
i = np.array(np.random.randint(0, high=500, size=(5, 390, 276, 129)), dtype=np.uint16, order='C')
l = np.array(np.random.randint(0, high=3, size=(5, 390, 276, 129)), dtype=np.uint8, order='C')
# segmentationSPBM, segmentationSRC = cGen.applySPBMandSRCSpams(ii,
# '''
# Timing/debug run restricted to z in [60, 61).
segmentationSPBM, segmentationSRC = cGen.applySPBMandSRCSpams(ii,
                                                              i, 
                                                              l, 3, P, N, lassoTol=0.1, 
                              lassoMaxIter=1e0, verboseX=True,  verboseY=False, 
                              xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=60, zmax=61,
                              # xmin=-1, xmax=-1, ymin=60, ymax=61, zmin=60, zmax=61,
                              numThreads=1, lassoL=1)
# '''
# Alternative methods kept disabled as string literals.
'''
segmentation = cGen.applySPEP(ii, i,  l, 3, P, N, verboseX=True,  verboseY=True, 
                              xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=-1, zmax=-1,)
'''
'''
segmentation = cGen.applySPBMandSRC(ii,
                              i, 
                              l, 3, P, N, lassoTol=0.1, 
                              lassoMaxIter=1e0, verboseX=True,  verboseY=False, 
                              xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=60, zmax=61,
                              # xmin=-1, xmax=-1, ymin=60, ymax=61, zmin=60, zmax=61,
                              )
'''
'''
segmentation = cGen.applyMV(ii,
                              i, 
                              l, 3, verboseX=True,  verboseY=False, 
                              xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=-1, zmax=-1,
                              # xmin=-1, xmax=-1, ymin=60, ymax=61, zmin=60, zmax=61,
                              )
'''

In [None]:
from utils import Dice

# Smoke test of Dice on two independent random label volumes with values in
# {0, 1, 2}. Drawing directly as uint8 avoids randint's int64 temporaries.
# The results are already fresh ndarrays, so no extra np.array copies are needed.
# A fixed seed makes the displayed score reproducible.
labelRng = np.random.default_rng(1)
i = labelRng.integers(0, 3, size=(390, 276, 129), dtype=np.uint8)
j = labelRng.integers(0, 3, size=(390, 276, 129), dtype=np.uint8)
# The third argument (2) is passed through unchanged; its meaning is defined by
# utils.Dice.
Dice(i, j, 2)

In [None]:
# Jupyter echoes only a cell's final expression, so the sum must be printed
# explicitly; otherwise it is computed and silently discarded.
# np.sum on a uint8 array accumulates in the platform (unsigned) integer, so the
# total does not wrap at 255.
print("Sum of positive label values:", np.sum(l[l > 0]))
l

In [None]:
import spams
import numpy as np

# Minimal spams.lasso call: a 20x10 all-zero float32 dictionary (A) and one
# all-zero 20-sample signal (B). Both are allocated in Fortran order, as SPAMS's
# Python interface requires.
A = np.zeros((20, 10), np.single, 'F')
B = np.zeros((20, 1), np.single, 'F')

# mode=2 with lambda1=0.1 and lambda2=0; coefficients constrained to be
# non-negative (pos=True); only the final solution is returned, not the path.
alpha = spams.lasso(
    B, A,
    mode=2,
    lambda1=0.1,
    lambda2=0.0,
    pos=True,
    numThreads=-1,
    return_reg_path=False,
)


In [None]:
# Python type of one column of the lasso coefficients. The `.A` shorthand on
# SciPy sparse matrices is deprecated and removed in recent SciPy releases, so
# densify with toarray() instead.
type(alpha.toarray()[:, 0])

In [None]:
# Dense coefficients for the first (and only) signal in B: column 0 of alpha,
# taken as row 0 of the transposed dense array.
alpha.toarray().T[0]

In [None]:
import SimpleITK as sitk
import numpy as np

# Preview two denoising steps on one dataset volume. First, recursive Gaussian
# smoothing with sigma 0.5 on every axis and normalizeAcrossScale=True. Second,
# curvature flow applied to the smoothed image (time step 0.05, 5 iterations).
image = sitk.ReadImage("./Dataset/data/9001104.nii.gz")
imageSmooth = sitk.SmoothingRecursiveGaussian(image, (0.5,) * 3, True)
imageCurv = sitk.CurvatureFlow(imageSmooth, 0.05, 5)

# The second argument (50) is presumably a slice index for showImg2_ -- confirm.
for _shownImage in (imageSmooth, imageCurv):
    showImg2_(_shownImage, 50)


In [None]:

# Save the fixed image plus the first moving image/label pair as .mhd files so
# they can be reloaded without re-running the preprocessing.
for _arrayToSave, _mhdName in [
    (fixedImage, "fixedImage.mhd"),
    (images[0], "movingImage.mhd"),
    (labels[0], "movinglabel.mhd"),
]:
    sitk.WriteImage(sitk.GetImageFromArray(_arrayToSave), _mhdName)


In [None]:
import SimpleITK as sitk
import numpy as np

# Reload the saved .mhd volumes as NumPy arrays (SimpleITK returns them in
# reversed, z-y-x index order).
fixedImage, image, label = (
    sitk.GetArrayFromImage(sitk.ReadImage(mhdName))
    for mhdName in ("fixedImage.mhd", "movingImage.mhd", "movinglabel.mhd")
)

# C-ordered copy of the fixed image. The moving image and label each become a
# stack of one, with a new leading axis.
fixedImage = fixedImage.copy(order='C')
images = np.stack([image])
labels = np.stack([label])

In [None]:
# One shape per line: the fixed image, then the stacked moving images and labels.
print(fixedImage.shape, images.shape, labels.shape, sep="\n")

In [None]:
from cGen import cGen
import numpy as np

# Per-axis int32 size parameters for the fusion call (all 3s and all 5s).
P = np.full(3, 3, dtype=np.int32)
N = np.full(3, 5, dtype=np.int32)

# Fuse the reloaded moving image/label stack onto the fixed image with 3 labels,
# restricted to z in [60, 61]. -1 bounds on x/y presumably mean "full extent".
# Alternative entry point: cGen.applySPBMandSRC with the same arguments.
# Alternative ranges tried:
#   xmin=60, xmax=61, ymin=-1, ymax=-1, zmin=-1, zmax=-1
#   xmin=10, xmax=11, ymin=10, ymax=11, zmin=10, zmax=11
segmentationSPBM, segmentationSRC = cGen.applySPBMandSRCSpams(
    fixedImage, images, labels, 3, P, N,
    lassoTol=0.001, lassoMaxIter=1e4,
    verboseX=True, verboseY=False,
    xmin=-1, xmax=-1,
    ymin=-1, ymax=-1,
    zmin=60, zmax=61,
)

In [None]:
# Side-by-side check at index 60: SPBM result, SRC result, then the moving label.
for _volume in (segmentationSPBM, segmentationSRC, labels[0]):
    showImg_(sitk.GetImageFromArray(_volume), 60)

In [None]:
# Voxel values equal to 3 in the SRC output, gathered in C order. Since the
# fusion was run with 3 labels, an empty result is presumably expected -- confirm.
segmentationSRC[np.nonzero(segmentationSRC == 3)]

In [None]:

# Map both fused segmentations back into the original image space and score
# each against the ground-truth label of the first test case.
_testImagePath = os.path.join(path, testData[0])
originalLabel = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(labelPath, testData[0])))

translatedSegmentationSPBM = translateToOriginal(
    segmentationSPBM, _testImagePath, copyShape, offset, length
)
print(Dice(originalLabel, translatedSegmentationSPBM, 3))

translatedSegmentationSRC = translateToOriginal(
    segmentationSRC, _testImagePath, copyShape, offset, length
)
print(Dice(originalLabel, translatedSegmentationSRC, 3))