In [None]:
import os
import random
from datetime import datetime

now = datetime.now()

# Folder name to save registration transforms; the timestamp keeps runs apart.
prefix = "MV"
resultsFolder = "results/" + prefix + now.strftime("%Y_%m_%d_%H_%M_%S")
print("Results folder name:", resultsFolder)

path = "./Dataset/data"
labelPath = "Dataset/labels"

# Optional RNG seed for the shuffle below. None keeps the historical behaviour
# (a different random order -- hence different cross-validation folds -- on
# every run); set an int to make the fold assignment reproducible.
shuffleSeed = None

# os.listdir order is filesystem-dependent; sort first so that a fixed seed
# yields the same order on every machine.
data = sorted(os.listdir(path))
datasetLen = len(data)

# random shuffle data (a private RNG so the global `random` state is untouched)
random.Random(shuffleSeed).shuffle(data)


"""Write transforms."""
from pathlib import Path


try:
    Path(f"./{resultsFolder}").mkdir(parents=True, exist_ok=True)
    print("Saving data names.")
    file = open(f"./{resultsFolder}/data.txt", "w") 
    for i, name in enumerate(data):
        # print(f"\tSaving {name}. {i+1}/{len(trainData)}", end='\x1b[1K\r')
        # print(f"\tSaving {name}. {i+1}/{len(trainData)}")
        file.write(name+"\n")
    # print()
    file.close()
except:
    print("Can't save names.")


# Fork processes for each test image: the parent keeps testImageIdx = 0 and
# the child forked at iteration i gets testImageIdx = i.
# Doesn't work on jupyter notebook (os.fork is also POSIX-only).
# Run it on python.
# NOTE(review): the parent never waits for its children; every process runs
# the rest of the script independently.
testImageIdx = 0
for i in range(1, len(data)):
    # for i in range(1, 0): # DEBUG
    pid = os.fork()
    if pid == 0:
        # child
        testImageIdx = i
        break


def _foldOf(index, total, numFolds=5):
    """Return the cross-validation fold (0..numFolds-1) of item `index`.

    `total` items are split into contiguous chunks of total // numFolds items;
    the remainder is absorbed by the last fold. When total < numFolds the
    chunk size is clamped to 1 instead of dividing by zero.
    """
    foldSize = max(1, total // numFolds)
    return min(index // foldSize, numFolds - 1)


# fill trainData according to the fold: every image outside the test image's
# fold is a candidate atlas.
fold = _foldOf(testImageIdx, len(data))
trainData = [name for idx, name in enumerate(data) if _foldOf(idx, len(data)) != fold]

resultsFolder += f"/{testImageIdx:02d}"
try:
    Path(f"./{resultsFolder}").mkdir(parents=True, exist_ok=True)
except OSError as err:
    print("Can't make the results directory.", err)


In [None]:
# Preprocessing: optional curvature-flow smoothing of every image.
UseCurvatureFlow = True
timeStep = 0.04
numberOfIterations = 10
# Number of label values handed to the segmentation and the Dice computation.
totalLabels = 2

# Tests: every combination of atlas count and P / N parameter triples
# (presumably patch / neighbourhood sizes -- confirm against cGen).
atlasesNums = [5]
Ps = [[1,1,1]]
Ns = [[1,1,1]]

inputData = [
    (atlasCount, pTriple, nTriple)
    for atlasCount in atlasesNums
    for pTriple in Ps
    for nTriple in Ns
]
                

In [None]:
import SimpleITK as sitk
import numpy as np
from utils import *

# Set the number of threads (SimpleITK global default; presumably kept at 1
# because one process per test image is forked earlier in the script).
sitk.ImageRegistrationMethod().SetGlobalDefaultNumberOfThreads(1)

transforms = []  # one registration transform per trainData entry
mseList = []     # matching post-registration MSE per trainData entry

# find biggest P and N: element-wise maximum over all configured triples.
# (np.argmax(...)[0] returned the *index* of the row holding the largest
# first component, not a size.)
P = np.max(Ps, axis=0)
N = np.max(Ns, axis=0)

# Read the fixed (test) image, optionally smoothing it with curvature flow.
pathName = os.path.join(path, data[testImageIdx])
fixedImage = sitk.ReadImage(pathName)
if UseCurvatureFlow:
    fixedImage = sitk.CurvatureFlow(fixedImage,
                                    timeStep=timeStep,
                                    numberOfIterations=numberOfIterations)


# Register every training image to the fixed image and score the result with
# a label-restricted MSE (mse3DLabels).
# for i, f in enumerate(trainData[:1]): # DEBUG
for i, f in enumerate(trainData):
    print(f'Registration of {f}. {i+1}/{len(trainData)}')

    pathName = os.path.join(path, f)

    # Optional smoothing, then histogram matching to the fixed image.
    image = sitk.ReadImage(pathName)
    if UseCurvatureFlow:
        image = sitk.CurvatureFlow(image,
                                   timeStep=timeStep,
                                   numberOfIterations=numberOfIterations)
    image = sitk.HistogramMatching(image, fixedImage)

    label = sitk.ReadImage(os.path.join(labelPath, f))

    # Alternatives: registration, registration2, registrationElastix,
    # registrationElastix2 (presumably all provided by utils).
    t = registrationElastixMask(fixedImage, image, label, resultsFolder)
    transforms.append(t)

    # Unmasked alternative: mse3D(fixedImage, resampleImage(image, fixedImage, t))
    warpedImage = resampleImage(image, fixedImage, t)
    warpedLabel = resampleLabels(label, fixedImage, t)
    mseList.append(mse3DLabels(fixedImage, warpedImage, warpedLabel))

# (trainData index, mse) pairs ordered from lowest to highest MSE
mseListSorted = sorted(enumerate(mseList), key=lambda pair: pair[1])

###############################################################################
# For every test
###############################################################################
# Segmentation backend; imported once instead of on every iteration.
from cGen import cGen

for testIndex, (atlasesNum, P, N) in enumerate(inputData):
# for testIndex, (atlasesNum, P, N, lassoTol) in enumerate(inputData[0:1]): # DEBUG
    # Bounding box (over numpy axes 0, 1, 2) of the foreground of all selected,
    # warped atlas labels; the sentinels are replaced by the first non-empty label.
    minx = miny = minz = 10000
    maxx = maxy = maxz = 0
    images = []
    labels = []

    # Read fixed image. Re-read for every test because `fixedImage` is replaced
    # by a cropped numpy array further down in this loop body.
    pathName = os.path.join(path, data[testImageIdx])
    if UseCurvatureFlow:
        fixedImage = sitk.CurvatureFlow(sitk.ReadImage(pathName),
                                        timeStep = timeStep,
                                        numberOfIterations = numberOfIterations)
    else:
        fixedImage = sitk.ReadImage(pathName)

    ###############################################################################
    # Read atlases and labels and transform them
    ###############################################################################
    # The atlasesNum training images with the lowest registration MSE are used.
    for i in range(atlasesNum):
    # for i in range(1): # DEBUG
        idx = mseListSorted[i][0]
        t = transforms[idx]
        f = trainData[idx]
        pathName = os.path.join(path, f)
        print(f'Reading {f}. {i+1}/{atlasesNum}')

        # Same preprocessing as during registration (optional curvature flow,
        # histogram matching), then warp onto the fixed grid with `t`.
        moving = sitk.ReadImage(pathName)
        if UseCurvatureFlow:
            moving = sitk.CurvatureFlow(moving,
                                        timeStep = timeStep,
                                        numberOfIterations = numberOfIterations)
        moving = sitk.HistogramMatching(moving, fixedImage)
        warped = sitk.GetArrayFromImage(resampleImage(moving, fixedImage, t))

        # Clip to the uint16 range before casting: negatives become 0 (as
        # before) and values above 65535 saturate instead of wrapping around.
        images.append(np.clip(warped, 0, np.iinfo(np.uint16).max).astype(np.uint16))

        label = sitk.ReadImage(os.path.join(labelPath, f))
        label = sitk.GetArrayFromImage(resampleLabels(label, fixedImage, t))
        labels.append(label)

        idxs = np.nonzero(label)
        if idxs[0].size == 0:
            # min()/max() of an empty index array would raise ValueError; an
            # empty label simply does not contribute to the bounding box.
            print(f"Warning: label of {f} is empty after resampling.")
            continue
        minx = min(minx, idxs[0].min())
        maxx = max(maxx, idxs[0].max())
        miny = min(miny, idxs[1].min())
        maxy = max(maxy, idxs[1].max())
        minz = min(minz, idxs[2].min())
        maxz = max(maxz, idxs[2].max())

    if minx > maxx:
        raise ValueError(f"Test {testIndex}: all selected atlas labels are empty "
                         "after resampling; cannot compute a crop region.")

    ###############################################################################
    # Calculate the desired shape
    ###############################################################################
    shape, copyShape, offset, length = calculateCropShape(images, P, N, minx, maxx, miny, maxy, minz, maxz)

    ###############################################################################
    # crop images and labels
    ###############################################################################
    for i in range(len(images)):
        images[i] = cropImage(images[i], shape, offset, length, copyShape, "uint16")
        labels[i] = cropImage(labels[i], shape, offset, length, copyShape, "uint8")

    fixedImage = sitk.GetArrayFromImage(fixedImage)
    fixedImage = cropImage(fixedImage, shape, offset, length, copyShape, "uint16")

    ###############################################################################
    # Make to the proper type
    ###############################################################################
    print("Fixing data types.")
    images = np.array(images, order='C')
    labels = np.array(labels, order='C')
    P = np.array(P, dtype=np.int32)
    N = np.array(N, dtype=np.int32)

    ###############################################################################
    # Perform segmentation (MV -- presumably majority voting)
    ###############################################################################
    print("Running segmentation.")
    segmentation = cGen.applyMV(fixedImage, images,
                                labels, totalLabels, verboseX=True,
                                verboseY=False,
                                xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=-1, zmax=-1,
                               )

    ###############################################################################
    # Save segmentation results (best effort: a failed save must not stop the
    # Dice evaluation below)
    ###############################################################################
    print("Saving results.")
    # NOTE(review): the suffix is appended directly to the input file name,
    # e.g. "<name>.nii.gzMV_0.mhd" if data names end in ".nii.gz".
    try:
        saveSegmentation(segmentation,
                         os.path.join(path, data[testImageIdx]),
                         os.path.join(resultsFolder, data[testImageIdx]) + f"MV_{testIndex}.mhd",
                         copyShape,
                         offset,
                         length,
                         verbose=False)
    except Exception as err:
        print("Can't save result.", err)

    ###############################################################################
    # Save dice index
    ###############################################################################
    originalLabel = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(labelPath, data[testImageIdx])))

    # Map the cropped segmentation back onto the original image grid.
    translatedSegmentation = translateToOriginal(segmentation,
                                                 os.path.join(path, data[testImageIdx]),
                                                 copyShape,
                                                 offset,
                                                 length)

    dice = Dice(originalLabel, translatedSegmentation, totalLabels)

    print(f"{testIndex}: Atlases: {atlasesNum}, P: {P}, N: {N}")
    print("MV:", dice)

    try:
        # `with` closes the file even if a write fails.
        with open(f"./{resultsFolder}/dice.txt", "a") as diceFile:
            diceFile.write(f"{testIndex}: Atlases: {atlasesNum}, P: {P}, N: {N}"+"\n")
            diceFile.write("MV: " + str(dice) + "\n")
    except OSError as err:
        print("Can't append dice results.", err)

In [None]:
####################################################
# END OF NORMAL CODE
# Everything below is scratch / debugging; stop the script here.
####################################################
import sys
raise SystemExit

In [None]:

pathName = os.path.join(path, "9036287.nii.gz")
image = sitk.ReadImage(pathName)
print(image.GetPixelID())
test = sitk.GetArrayFromImage(image)
test = test.astype(np.uint16)

newImage = sitk.GetImageFromArray(test)
newImage.CopyInformation(image)
print(newImage.GetPixelID())
pathName = os.path.join(path,  "9036287.nii.gz")
# sitk.WriteImage(newImage, pathName, True)

In [None]:
import numpy as np

In [None]:
# Display the voxels of `image` as a numpy array (SimpleITK reverses the axis
# order, i.e. z, y, x).
sitk.GetArrayFromImage(image)

In [None]:

# MSE between `image` and `fixedImage`, restricted to voxels where `label` > 0.
# Cast to float first: with unsigned pixel types the difference would wrap
# around instead of going negative.
diff = (sitk.GetArrayFromImage(image).astype(np.float64)
        - sitk.GetArrayFromImage(fixedImage).astype(np.float64))
(diff[sitk.GetArrayFromImage(label) > 0] ** 2).mean()

In [None]:
# Whole-volume MSE between `image` and `fixedImage` (float cast avoids
# unsigned wrap-around in the subtraction and overflow in the square).
((sitk.GetArrayFromImage(image).astype(np.float64)
  - sitk.GetArrayFromImage(fixedImage).astype(np.float64)) ** 2).mean()


In [None]:
# Two algebraically equal forms of the label-masked sum of squared differences.
mask = sitk.GetArrayFromImage(label)
mask[mask > 1] = 1  # clamp label values above 1 (binary mask if labels are >= 0)
# Float copies: unsigned subtraction would otherwise wrap around.
movingArr = sitk.GetArrayFromImage(image).astype(np.float64)
fixedArr = sitk.GetArrayFromImage(fixedImage).astype(np.float64)
print(((mask * (movingArr - fixedArr))**2).sum())
((mask * movingArr - mask * fixedArr)**2).sum()

In [None]:

# List the configured tests. inputData holds (atlasesNum, P, N) triples; the
# former 4-tuple unpacking (with lassoTol) raised ValueError.
for testIndex, (atlasesNum, P, N) in enumerate(inputData):
    print(testIndex, atlasesNum, P, N)

In [None]:
from PIL import Image
import numpy as np

def showImg_(img, z=60):
    """Show slice `z` of a SimpleITK image, min-max scaled to 0..255.

    `z` indexes the image's first SimpleITK index axis (SimpleITK indexes
    images as [x, y, z]).
    """
    sliceArr = sitk.GetArrayFromImage(img[z, :, :])
    lo, hi = sliceArr.min(), sliceArr.max()
    scaled = np.interp(sliceArr, (lo, hi), (0, 255))
    Image.fromarray(scaled.astype('uint8')).show()
    
def showImg2_(img, z=60):
    """Show slice `z` of a SimpleITK image cast straight to uint8 (no scaling).

    NOTE(review): intensities outside 0..255 wrap on the uint8 cast.
    """
    pixels = sitk.GetArrayFromImage(img[z, :, :]).astype('uint8')
    Image.fromarray(pixels).show()
    
def showImg3_(img, z=60, window=(0, 5000)):
    """Show slice `z` with intensities in `window` mapped linearly to 0..255.

    Values outside the window are clamped to its ends by np.interp. Previously
    called `image.fromarray` (the notebook's SimpleITK image variable) instead
    of PIL's `Image.fromarray`, so the function always failed.
    """
    a = sitk.GetArrayFromImage(img[z,:,:])
    disImg = Image.fromarray(np.interp(a, window, (0, 255)).astype('uint8'))
    disImg.show()
    
def showImg4_(img, z=60):
    """Show slice `z` multiplied by 60 as a grayscale ('L') image.

    Presumably meant for small-integer label maps, where x60 makes each label
    visibly distinct.
    """
    brightened = sitk.GetArrayFromImage(img[z, :, :]) * 60
    Image.fromarray(brightened, 'L').show()

In [None]:
# Shape of the first atlas label (cropped, if the main loop has run).
labels[0].shape

In [None]:
# Slice 20: fixed image and first atlas (raw uint8 cast), first label (scaled).
showImg2_(img=sitk.GetImageFromArray(fixedImage), z=20)
showImg2_(img=sitk.GetImageFromArray(images[0]), z=20)
showImg_(img=sitk.GetImageFromArray(labels[0]), z=20)
# showImg_(sitk.GetImageFromArray(segmentationSPBM), 20)

In [None]:
# Slice 60: fixed image, first atlas and the SPBM segmentation.
# NOTE(review): segmentationSPBM is only created by a later cell; this cell
# fails on a fresh top-to-bottom run.
showImg2_(img=sitk.GetImageFromArray(fixedImage), z=60)
showImg2_(img=sitk.GetImageFromArray(images[0]), z=60)
showImg_(img=sitk.GetImageFromArray(segmentationSPBM), z=60)
# showImg2_(images[1], 60)

In [None]:

# SPBM + SRC segmentation of the cropped data; labels converted to a
# C-contiguous uint8 array (presumably what cGen expects).
labels = np.array(labels, order='C', dtype=np.uint8)

segmentationSPBM, segmentationSRC = cGen.applySPBMandSRCSpams(
    fixedImage, images, labels, 3, P, N,
    lassoTol=0.00001, lassoMaxIter=1e4,
    verboseX=True, verboseY=False,
    xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=-1, zmax=-1,
)

In [None]:
# Same call, but with a constant label volume (every voxel = 51) -- presumably
# to probe how cGen handles out-of-range labels.
constantLabels = np.ones(labels.shape, dtype=np.uint8, order='C') + 50
segmentationSPBM, segmentationSRC = cGen.applySPBMandSRCSpams(
    fixedImage, images, constantLabels, 3, P, N,
    lassoTol=0.00001, lassoMaxIter=1e4,
    verboseX=True, verboseY=False,
    xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=-1, zmax=-1,
)

In [None]:
from cGen import cGen
import numpy as np

# Smoke test of cGen.applySPEP on random data shaped like a real volume.
# (An earlier 7x7x7 P/N setting was immediately overwritten; 5x5x5 is used.)
P = np.array([5, 5, 5], dtype=np.int32)
N = np.array([5, 5, 5], dtype=np.int32)

volumeShape = (390, 276, 129)
numAtlases = 5
# Draw directly in the target dtype: randint's default int64 output for the
# 5-atlas stacks would need ~550 MB before the cast.
ii = np.random.randint(0, 500, size=volumeShape, dtype=np.uint16)
i = np.random.randint(0, 500, size=(numAtlases,) + volumeShape, dtype=np.uint16)
l = np.random.randint(0, 3, size=(numAtlases,) + volumeShape, dtype=np.uint8)

# Alternative call kept for reference:
# segmentationSPBM, segmentationSRC = cGen.applySPBMandSRCSpams(ii,
#                                                               i,
#                                                               l, 3, P, N, lassoTol=0.1,
#                               lassoMaxIter=1e0, verboseX=True,  verboseY=False,
#                               xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=60, zmax=61,
#                               # xmin=-1, xmax=-1, ymin=60, ymax=61, zmin=60, zmax=61,
#                                   numThreads=1, lassoL=1)
segmentation = cGen.applySPEP(ii, i, l, 3, P, N, verboseX=True, verboseY=True,
                              xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=-1, zmax=-1)

In [None]:
# Sum of the non-zero label values in `l`. Printed explicitly: previously it
# was computed and discarded, since only a cell's last expression is shown.
print(np.sum(l[l > 0]))
l

In [None]:
import spams
import numpy as np

# Minimal spams.lasso call on all-zero, Fortran-ordered float32 inputs:
# B (20x1) is passed as the signal matrix, A (20x10) as the dictionary.
A = np.zeros((20, 10), dtype=np.single, order='F')
B = np.zeros((20, 1), dtype=np.single, order='F')

lassoParams = dict(
    return_reg_path=False,
    lambda1=0.1,
    lambda2=0.,
    pos=True,
    mode=2,
    numThreads=-1,
)
alpha = spams.lasso(B, A, **lassoParams)


In [None]:
# Type of the first column of the (sparse) lasso solution. The `.A` shortcut
# was removed from scipy.sparse matrices in SciPy 1.14; toarray() is the
# supported equivalent.
type(alpha.toarray()[:, 0])

In [None]:
# Dense first column of the lasso coefficients (B has a single column, so
# presumably this is the whole solution).
alpha.toarray()[:,0]

In [None]:

sitk.WriteImage(sitk.GetImageFromArray(fixedImage), "fixedImage.mhd")
sitk.WriteImage(sitk.GetImageFromArray(images[0]), "movingImage.mhd")
sitk.WriteImage(sitk.GetImageFromArray(labels[0]), "movinglabel.mhd")


In [None]:
import SimpleITK as sitk
import numpy as np


def _readArray(fileName):
    """Read an image file and return its voxels as a numpy array."""
    return sitk.GetArrayFromImage(sitk.ReadImage(fileName))


fixedImage = np.array(_readArray("fixedImage.mhd"), order='C')
image = _readArray("movingImage.mhd")
label = _readArray("movinglabel.mhd")

# Stack into single-atlas arrays, the layout the segmentation calls receive.
images = np.array([image], order='C')
labels = np.array([label], order='C')

In [None]:
# Shapes of the fixed volume and the atlas / label stacks, one per line.
print(fixedImage.shape, images.shape, labels.shape, sep="\n")

In [None]:
from cGen import cGen

# Debug SPBM + SRC run on the reloaded volumes; zmin=60 / zmax=61 presumably
# limits processing to a thin slab (confirm in cGen).
P = np.array([3, 3, 3], dtype=np.int32)
N = np.array([5, 5, 5], dtype=np.int32)

# Non-Spams variant: cGen.applySPBMandSRC(...) with the same arguments.
segmentationSPBM, segmentationSRC = cGen.applySPBMandSRCSpams(
    fixedImage, images, labels, 3, P, N,
    lassoTol=0.001, lassoMaxIter=1e4,
    verboseX=True, verboseY=False,
    xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=60, zmax=61,
    # xmin=60, xmax=61, ymin=-1, ymax=-1, zmin=-1, zmax=-1,
    # xmin=10, xmax=11, ymin=10, ymax=11, zmin=10, zmax=11,
)

In [None]:
# Slice 60 of both debug segmentations and of the first atlas label.
showImg_(img=sitk.GetImageFromArray(segmentationSPBM), z=60)
showImg_(img=sitk.GetImageFromArray(segmentationSRC), z=60)
# showImg_(sitk.GetImageFromArray(originalLabel), 60)
showImg_(img=sitk.GetImageFromArray(labels[0]), z=60)

In [None]:
# Voxels of the SRC segmentation equal to 3. NOTE(review): 3 is also the label
# count passed to cGen above, so presumably valid labels are 0..2 -- any hit
# here would be out of range.
segmentationSRC[segmentationSRC == 3]

In [None]:

# Dice of the debug SPBM / SRC segmentations against the ground-truth label of
# the test image, after mapping them back to the original grid. The test image
# is data[testImageIdx]; the former `testData` was never defined (NameError).
# 3 matches the label count passed to applySPBMandSRCSpams in the debug cells.
testImagePath = os.path.join(path, data[testImageIdx])
originalLabel = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(labelPath, data[testImageIdx])))

translatedSegmentationSPBM = translateToOriginal(segmentationSPBM,
                                                 testImagePath,
                                                 copyShape,
                                                 offset,
                                                 length)
print(Dice(originalLabel, translatedSegmentationSPBM, 3))

translatedSegmentationSRC = translateToOriginal(segmentationSRC,
                                                testImagePath,
                                                copyShape,
                                                offset,
                                                length)
print(Dice(originalLabel, translatedSegmentationSRC, 3))