In [None]:
import os
import random

from pathlib import Path

# Output folder; this cell writes the train/atlas split into it.
resultsFolder = "results"

path = "./Dataset/data"
labelPath = "Dataset/labels"

# Fraction of the volumes used as training (moving) images; the rest become atlases.
trainFraction = 0.7
# Seed for the split shuffle. None gives a different split on every run (previous
# behaviour); set an int to make the split reproducible.
splitSeed = None

# Sort first so a fixed seed gives the same split regardless of os.listdir order.
data = sorted(os.listdir(path))
datasetLen = len(data)

# Random shuffle with a private RNG so the global `random` state is untouched.
random.Random(splitSeed).shuffle(data)

# Split dataset
splitIdx = int(datasetLen * trainFraction)
trainData = data[:splitIdx]
atlases = data[splitIdx:]


def _saveNames(names, fileName, folder):
    """Write one name per line to ``folder/fileName``, printing progress.

    Parameters
    ----------
    names : list of str
        Names to write, in order.
    fileName : str
        Target file name inside ``folder``; an existing file is overwritten.
    folder : str
        Directory that must already exist.
    """
    with open(Path(folder) / fileName, "w") as out:
        for i, name in enumerate(names):
            # Progress is relative to this list, not to the whole dataset.
            print(f"\tSaving {name}. {i+1}/{len(names)}")
            out.write(name + "\n")


# Record the split in text files inside the results folder.
Path(f"./{resultsFolder}").mkdir(parents=True, exist_ok=True)

print("Saving training data names.")
_saveNames(trainData, "trainData.txt", resultsFolder)

print("Saving atlases names.")
_saveNames(atlases, "atlases.txt", resultsFolder)

In [None]:
import SimpleITK as sitk
import numpy as np
from utils import *

# NOTE(review): registration, mse3D, resampleImage and resampleSegmentation are not defined
# in this notebook; presumably they come from `from utils import *` — TODO confirm.

# Curvature-flow denoising parameters applied to every volume read below.
curvatureTimeStep = 0.04
curvatureIterations = 10

# Number of training volumes to register; None registers all of them.
# Set a small int (e.g. 2) for a quick debug run.
registerLimit = None

# Maximum number of best-matching (lowest-MSE) atlases to keep.
atlasesNum = 5


def _preprocess(pathName, reference=None):
    """Read a volume, denoise it with curvature flow and optionally histogram-match it.

    Parameters
    ----------
    pathName : str
        File passed to ``sitk.ReadImage``.
    reference : SimpleITK.Image, optional
        When given, the denoised volume is histogram-matched to this image.

    Returns
    -------
    SimpleITK.Image
        The filtered (and possibly histogram-matched) volume.
    """
    smoothed = sitk.CurvatureFlow(sitk.ReadImage(pathName),
                                  timeStep=curvatureTimeStep,
                                  numberOfIterations=curvatureIterations)
    if reference is None:
        return smoothed
    return sitk.HistogramMatching(smoothed, reference)


# Fixed (reference) image: the first atlas, denoised but not histogram-matched.
fixedImage = _preprocess(os.path.join(path, atlases[0]))

registerSet = trainData if registerLimit is None else trainData[:registerLimit]

transforms = []
mseList = []

# Register every selected training volume to the fixed image and score the fit by the
# MSE between the fixed image and the resampled moving image.
for i, f in enumerate(registerSet):
    print(f'Reading {f}. {i+1}/{len(registerSet)}')

    pathName = os.path.join(path, f)
    image = _preprocess(pathName, fixedImage)
    label = sitk.ReadImage(os.path.join(labelPath, f))

    t = registration(fixedImage, image, label)
    transforms.append(t)
    mseList.append(mse3D(fixedImage, resampleImage(image, fixedImage, t)))

# (index into registerSet, mse) pairs, best match first.
mseListSorted = sorted(enumerate(mseList), key=lambda x: x[1])

# Never ask for more atlases than were actually registered
# (previously 5 atlases were requested after registering only 2 -> IndexError).
atlasesNum = min(atlasesNum, len(mseListSorted))

# Union bounding box of all non-zero atlas labels. The names x/y/z refer to numpy
# axes 0/1/2 of the arrays from GetArrayFromImage (SimpleITK returns them indexed
# [z, y, x], so "x" here is the image's last dimension).
minx = miny = minz = 10000
maxx = maxy = maxz = 0
foundLabel = False
images = []
labels = []

# Read the chosen atlases and their labels and resample them onto the fixed image grid.
for i in range(atlasesNum):
    idx = mseListSorted[i][0]
    t = transforms[idx]
    f = registerSet[idx]
    pathName = os.path.join(path, f)

    images.append(sitk.GetArrayFromImage(
        resampleImage(_preprocess(pathName, fixedImage), fixedImage, t)))

    label = sitk.ReadImage(os.path.join(labelPath, f))
    label = sitk.GetArrayFromImage(resampleSegmentation(label, fixedImage, t))
    labels.append(label)

    idxs = np.nonzero(label != 0)
    if idxs[0].size == 0:
        # An empty label contributes nothing to the bounding box (min() would raise).
        print(f'Warning: resampled label of {f} is empty; ignored for the bounding box.')
        continue
    foundLabel = True
    minx = min(minx, idxs[0].min())
    maxx = max(maxx, idxs[0].max())
    miny = min(miny, idxs[1].min())
    maxy = max(maxy, idxs[1].max())
    minz = min(minz, idxs[2].min())
    maxz = max(maxz, idxs[2].max())

if not foundLabel:
    raise ValueError("None of the selected atlases has a non-empty resampled label.")
    
    
###############################################################################
# Crop the atlases, their labels and the fixed image to the union label bounding box,
# grown by a margin of P//2 + N//2 voxels per axis. Parts of that window that fall
# outside the volume are zero-filled.
# NOTE(review): P and N are not defined in this notebook; presumably they come from
# `from utils import *` (SPBM patch / neighbourhood sizes?) — TODO confirm.

if len(images) == 0:
    raise ValueError("No atlas images to crop.")

margin = [P[k] // 2 + N[k] // 2 for k in range(3)]

# Desired inclusive crop window [lo, hi] per numpy axis; may extend past the volume.
shape = [[lo - m, hi + m]
         for lo, hi, m in zip((minx, miny, minz), (maxx, maxy, maxz), margin)]

# The part of the window that lies inside the volume (clamped to [0, size - 1]).
volumeShape = images[0].shape
copyShape = [[max(lo, 0), min(hi, size - 1)]
             for (lo, hi), size in zip(shape, volumeShape)]


def _padCrop(volume, window, inside, dtype):
    """Cut ``inside`` out of ``volume`` and place it in a zero array spanning ``window``.

    Parameters
    ----------
    volume : numpy.ndarray
        3-D array to crop.
    window : list of [lo, hi]
        Desired inclusive crop range per axis (may exceed the volume bounds).
    inside : list of [lo, hi]
        ``window`` clamped to the volume bounds.
    dtype : str or numpy.dtype
        dtype of the returned array (values are cast on assignment).

    Returns
    -------
    numpy.ndarray
        Fortran-ordered array of size ``hi - lo + 1`` per axis, zero outside the volume.
    """
    out = np.zeros(tuple(hi - lo + 1 for lo, hi in window), dtype=dtype, order="F")
    # Destination starts at the distance from the window start to the clamped start.
    dst = tuple(slice(ilo - wlo, ihi - wlo + 1)
                for (wlo, _), (ilo, ihi) in zip(window, inside))
    src = tuple(slice(ilo, ihi + 1) for ilo, ihi in inside)
    out[dst] = volume[src]
    return out


# NOTE(review): images are cast to uint16 here; if the resampled images are floating
# point, fractional parts are truncated — confirm this is what cGen expects.
images = [_padCrop(img, shape, copyShape, "uint16") for img in images]
# Replace each label with its cropped version (previously the cropped labels were
# appended, leaving uncropped and cropped arrays mixed in the same list).
labels = [_padCrop(lab, shape, copyShape, "uint8") for lab in labels]

# Keep fixedImage a SimpleITK image (later cells use it as one); crop a numpy copy.
fixedArray = _padCrop(sitk.GetArrayFromImage(fixedImage), shape, copyShape, "uint16")

# Make to the proper type
images = np.array(images, order='C')
labels = np.array(labels, order='C')
P = np.array(P, dtype=np.int32)
N = np.array(N, dtype=np.int32)

from cGen import cGen

# NOTE(review): the meaning of the positional 3 and whether lassoMaxIter may be a float
# (1e4) are not visible here — confirm against cGen.applySPBM.
segmentation = cGen.applySPBM(fixedArray, images, labels, 3, P, N, lassoTol=0.00001, lassoMaxIter=1e4, verboseX=True,
                              verboseY=False, xmin=-1, xmax=-1, ymin=-1, ymax=-1, zmin=-1, zmax=-1)

In [None]:
# Preview cell: re-read the best-ranked registered volumes (lowest MSE first) as
# SimpleITK images, with their resampled labels, for the visual checks below.

# (index into trainData, mse) pairs, best match first.
mseListSorted = sorted(enumerate(mseList), key=lambda x: x[1])

# Number of volumes to preview; never more than were registered.
previewNum = min(2, len(mseListSorted))

# Union bounding box of the previewed labels (numpy axes 0/1/2).
minx = miny = minz = 10000
maxx = maxy = maxz = 0
images = []
labels = []

for i in range(previewNum):
    idx, mse = mseListSorted[i]
    t = transforms[idx]
    f = trainData[idx]
    print(f'{i+1}/{previewNum}: {f} (index {idx}, MSE {mse})')
    pathName = os.path.join(path, f)

    # Denoise, histogram-match to the fixed image, then resample with the stored transform.
    smoothed = sitk.CurvatureFlow(sitk.ReadImage(pathName),
                                  timeStep=0.04,
                                  numberOfIterations=10)
    images.append(resampleImage(sitk.HistogramMatching(smoothed, fixedImage),
                                fixedImage, t))

    labelImage = resampleSegmentation(sitk.ReadImage(os.path.join(labelPath, f)),
                                      fixedImage, t)
    labels.append(labelImage)  # keep labels aligned with images (was never filled)
    label = sitk.GetArrayFromImage(labelImage)

    idxs = np.nonzero(label != 0)
    if idxs[0].size == 0:
        # min()/max() of an empty index array would raise.
        print(f'Warning: resampled label of {f} is empty.')
        continue
    minx = min(minx, idxs[0].min())
    maxx = max(maxx, idxs[0].max())
    miny = min(miny, idxs[1].min())
    maxy = max(maxy, idxs[1].max())
    minz = min(minz, idxs[2].min())
    maxz = max(maxz, idxs[2].max())

In [None]:
from PIL import Image

def showImg_(img, z=60):
    """Show slice ``img[z, :, :]`` of a SimpleITK image, min-max stretched to 0..255.

    SimpleITK indexes images in (x, y, z) order, so ``z`` selects along the image's
    first index. The slice is opened in the system image viewer via PIL.
    """
    sliceArray = sitk.GetArrayFromImage(img[z, :, :])
    lo, hi = sliceArray.min(), sliceArray.max()
    stretched = np.interp(sliceArray, (lo, hi), (0, 255))
    Image.fromarray(stretched.astype('uint8')).show()
    
def showImg2_(img, z=60):
    """Show slice ``img[z, :, :]`` with its values cast directly to uint8 (no rescaling).

    Intended for images whose values already fit 0..255; values are not clipped
    before the cast. SimpleITK indexes images in (x, y, z) order.
    """
    pixels = sitk.GetArrayFromImage(img[z, :, :]).astype('uint8')
    Image.fromarray(pixels).show()
    
def showImg3_(img, z=60, window=(0, 5000)):
    """Show slice ``img[z, :, :]`` with a fixed intensity window mapped to 0..255.

    Parameters
    ----------
    img : SimpleITK.Image
        Volume to display; SimpleITK indexes it in (x, y, z) order.
    z : int
        Index along the image's first axis.
    window : tuple of two numbers
        Intensities mapped to 0 and 255; values outside the window are clamped
        to the endpoints by ``np.interp``.
    """
    a = sitk.GetArrayFromImage(img[z, :, :])
    # Fixed: used ``image.fromarray`` (a SimpleITK image variable) instead of PIL's Image.
    disImg = Image.fromarray(np.interp(a, window, (0, 255)).astype('uint8'))
    disImg.show()
    
def showImg4_(img, z=60, scale=60):
    """Show a label slice ``img[z, :, :]`` with each value multiplied by ``scale``.

    The scaled values are clipped to 0..255 and converted to uint8 before display.
    This prevents wrap-around in narrow integer dtypes (e.g. 5 * 60 in uint8). It also
    guarantees the buffer really is 8-bit grayscale, which PIL then infers as mode 'L'.

    Parameters
    ----------
    img : SimpleITK.Image
        Label volume; SimpleITK indexes it in (x, y, z) order.
    z : int
        Index along the image's first axis.
    scale : number
        Multiplier applied to label values for visibility.
    """
    scaled = np.asarray(sitk.GetArrayFromImage(img[z, :, :]), dtype=np.float64) * scale
    grey = np.clip(scaled, 0, 255).astype(np.uint8)
    Image.fromarray(grey).show()

In [None]:
# Visual check at slice 60: fixed image, then the first two entries of `images`.
showImg2_(fixedImage, z=60)
showImg2_(images[0], z=60)
showImg2_(images[1], z=60)

In [None]:
# Preview the label map at slice 100 with label value 4 set to background.
# `label` is a numpy array after the atlas loops (and a SimpleITK image only right
# after a raw read), so accept both instead of always calling GetArrayFromImage.
if isinstance(label, sitk.Image):
    labelArray = sitk.GetArrayFromImage(label)
else:
    labelArray = np.array(label)  # copy, so `label` itself is not modified
labelArray[labelArray == 4] = 0
mask = sitk.GetImageFromArray(labelArray)

showImg4_(mask, 100)


In [None]:
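'''
Label values in the label volumes (original Greek terms in parentheses):
1: lower knee (katw gonato)
2: lower cartilage (katw xondros)
3: upper knee (panw gonato)
4: upper cartilage (panw xondros)
5: upper cartilage (panw xondros) -- same name as 4; TODO confirm which structure 5 is
'''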

In [None]:
# Register one training volume (trainData[idx]) to the fixed image with elastix,
# using elastix's default affine parameter map.
idx = 0
f = trainData[idx]
pathName = os.path.join(path, f)

# An earlier cell may have left fixedImage as a numpy array; elastix needs a SimpleITK image.
if not isinstance(fixedImage, sitk.Image):
    raise TypeError("fixedImage must be a SimpleITK image; re-run the cell that reads the fixed image.")

# Same preprocessing as the atlases: curvature-flow denoising, then histogram matching.
image = sitk.HistogramMatching(sitk.CurvatureFlow(sitk.ReadImage(pathName),
                                                  timeStep=0.04,
                                                  numberOfIterations=10),
                               fixedImage)

# Set to True to pass a binary mask of the labelled voxels as elastix's moving mask
# (this was previously disabled by wrapping the code in a string literal).
useMovingMask = False

elastixImageFilter = sitk.ElastixImageFilter()
elastixImageFilter.SetFixedImage(fixedImage)
elastixImageFilter.SetMovingImage(image)
if useMovingMask:
    label = sitk.ReadImage(os.path.join(labelPath, f))
    # Values greater than 1 are set to 1, so every labelled voxel becomes 1.
    l = sitk.GetArrayFromImage(label)
    l[l > 1] = 1
    mask = sitk.GetImageFromArray(l)
    mask.CopyInformation(label)  # keep the label's spacing/origin/direction
    elastixImageFilter.AddMovingMask(mask)
elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap("affine"))
elastixImageFilter.Execute()

In [None]:
# Fetch the outputs of the elastix run: the transform parameters and the warped moving image.
transformParameterMap = elastixImageFilter.GetTransformParameterMap()
resultImage2 = elastixImageFilter.GetResultImage()

In [None]:
# Dump the parameter map the elastix filter was configured with.
sitk.PrintParameterMap(
    elastixImageFilter.GetParameterMap()
)

In [None]:
# Exploratory: print the built-in documentation of SimpleITK's ElastixImageFilter.
help(sitk.ElastixImageFilter)

In [None]:
# Show the transform parameter map(s) produced by the elastix run.
print(transformParameterMap)

In [None]:
# Compare slice 60 of the fixed image, images[0] and the elastix result (raw uint8 cast).
showImg2_(fixedImage, 60)
showImg2_(images[0], 60)
# The elastix output is bound to `resultImage2` (was `resultImage`, which no cell defines).
showImg2_(resultImage2, 60)

In [None]:
# Same comparison with min-max stretching; note this cell shows images[1], not images[0].
showImg_(fixedImage, 60)
showImg_(images[1], 60)
# The elastix output is bound to `resultImage2` (was `resultImage`, which no cell defines).
showImg_(resultImage2, 60)

In [None]:
# Exploratory: print the built-in documentation of sitk.GetImageFromArray.
help(sitk.GetImageFromArray)