In [None]:
import SimpleITK as sitk
import numpy as np

In [None]:
# Load the two DICOM series that are registered and compared below. A single
# ImageSeriesReader is shared; later cells reuse it as `reader`.
reader = sitk.ImageSeriesReader()


def _read_dicom_series(series_dir):
    """Point the shared `reader` at every file of the series in `series_dir` and read it.

    Returns the (file-name list, image) pair so the caller can keep `dicom_names`.
    """
    file_names = reader.GetGDCMSeriesFileNames(series_dir)
    reader.SetFileNames(file_names)
    return file_names, reader.Execute()


dicom_names, img1 = _read_dicom_series("../MRI_Data/Baseline/KL0/9003430/10557212")
dicom_names, img2 = _read_dicom_series("../MRI_Data/Baseline/KL0/9005075/10593811")

In [None]:
# Registration
# Affine registration of img2 (moving) onto img1 (fixed); the registered image
# img2_1 is resampled onto img1's voxel grid.
# NOTE(review): the `registration` helper defined further down implements the
# same pipeline; this cell sets globals (t, resample, ...) that later cells use.

# Initial affine transform that aligns the geometric centres of the two volumes.
initial_transform = sitk.CenteredTransformInitializer(img1, 
                                                      img2, 
                                                      sitk.AffineTransform(3), 
                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY)
registration_method = sitk.ImageRegistrationMethod()

# Similarity metric: Mattes mutual information with 50 bins, evaluated on a
# random 1 % of the voxels.
# NOTE(review): no sampling seed is passed, so repeated runs may give slightly
# different transforms.
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)

registration_method.SetInterpolator(sitk.sitkLinear)
   
# Plain gradient descent, 100 iterations per pyramid level; parameter scales
# are estimated from physical shifts.
registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100) #, estimateLearningRate=registration_method.EachIteration)
registration_method.SetOptimizerScalesFromPhysicalShift() 

final_transform = sitk.AffineTransform(initial_transform)

# Three-level multi-resolution pyramid: shrink factors 4/2/1 with smoothing
# sigmas 2/1/0 given in physical units.
registration_method.SetInitialTransform(final_transform)
registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas = [2,1,0])
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# The metric is evaluated on float32 copies of both images.
t = registration_method.Execute(sitk.Cast(img1, sitk.sitkFloat32), sitk.Cast(img2, sitk.sitkFloat32))

# Resample img2 onto img1's grid with the optimised transform. `resample` is
# reused by a later cell to carry img2's labels into img1's space.
resample = sitk.ResampleImageFilter()
resample.SetReferenceImage(img1)

resample.SetInterpolator(sitk.sitkLinear)  
resample.SetTransform(t)
img2_1 = resample.Execute(img2)

In [None]:
def createA(images, voxel, P, N):
    """Build the patch dictionary matrix used for sparse (Lasso) label fusion.

    For each image in ``images`` and each offset of an ``N``-sized neighbourhood
    around ``voxel``, the ``P``-sized patch centred on the shifted voxel is
    flattened into one column.

    Parameters
    ----------
    images : sequence of SimpleITK images or numpy arrays
        SimpleITK images are indexed (x, y, z). numpy arrays must use the
        ``sitk.GetArrayFromImage`` layout, i.e. (z, y, x).
    voxel : sequence of 3 ints
        Centre voxel as an (x, y, z) index.
    P : sequence of 3 ints
        Patch size along x, y, z. Even sizes are allowed; the extra voxel goes
        on the high side.
    N : sequence of 3 ints
        Neighbourhood (search window) size along x, y, z.

    Returns
    -------
    numpy.ndarray of float64, shape (P[0]*P[1]*P[2], len(images)*N[0]*N[1]*N[2])
        Column k is the patch for the k-th (image, offset) pair. Images vary
        slowest, then the x, y and z offsets (z fastest). Within a column the
        voxels are in (z, y, x) C order, matching ``createB``.

    Raises
    ------
    IndexError
        If a patch reaches outside an image.
    """
    from itertools import product

    # Same offsets as the original nested loops: x slowest, z fastest.
    offsets = list(product(*[range(-(n - 1) // 2, (n - 1) // 2 + 1) for n in N[:3]]))
    n_rows = P[0] * P[1] * P[2]
    # float64 rather than uint16: uint16 would silently wrap negative intensities
    # and truncate non-integer ones; the regression works in floating point anyway.
    A = np.zeros(shape=(n_rows, len(images) * len(offsets)), dtype=np.float64)

    index = 0
    for image in images:
        # Convert once per image (not once per column); the array is (z, y, x).
        arr = image if isinstance(image, np.ndarray) else sitk.GetArrayViewFromImage(image)
        for offset in offsets:
            bounds = []
            for d in range(3):
                # start..start+P keeps exactly P voxels, also for even P (the
                # original symmetric half-width dropped one voxel in that case).
                start = voxel[d] + offset[d] - (P[d] - 1) // 2
                stop = start + P[d]
                if start < 0 or stop > arr.shape[2 - d]:
                    raise IndexError(
                        f"patch [{start}, {stop}) on axis {d} is outside the image "
                        f"(size {arr.shape[2 - d]})")
                bounds.append(slice(start, stop))
            A[:, index] = arr[bounds[2], bounds[1], bounds[0]].reshape((-1,))
            index += 1

    return A

In [None]:
def createB(image, voxel, P):
    """Return the ``P``-sized patch of ``image`` centred on ``voxel``, flattened.

    Parameters
    ----------
    image : SimpleITK image or numpy array
        SimpleITK images are indexed (x, y, z). numpy arrays must use the
        ``sitk.GetArrayFromImage`` layout, i.e. (z, y, x).
    voxel : sequence of 3 ints
        Centre voxel as an (x, y, z) index.
    P : sequence of 3 ints
        Patch size along x, y, z. Even sizes keep the extra voxel on the high side.

    Returns
    -------
    numpy.ndarray, shape (P[0]*P[1]*P[2],)
        A copy of the patch in (z, y, x) C order with the image's pixel dtype
        (same voxel order as the columns of ``createA``).

    Raises
    ------
    IndexError
        If the patch reaches outside the image.
    """
    arr = image if isinstance(image, np.ndarray) else sitk.GetArrayViewFromImage(image)
    bounds = []
    for d in range(3):
        # start..start+P keeps exactly P voxels, also for even P.
        start = voxel[d] - (P[d] - 1) // 2
        stop = start + P[d]
        if start < 0 or stop > arr.shape[2 - d]:
            raise IndexError(
                f"patch [{start}, {stop}) on axis {d} is outside the image "
                f"(size {arr.shape[2 - d]})")
        bounds.append(slice(start, stop))
    # flatten() always copies, so the result never aliases the image buffer.
    return arr[bounds[2], bounds[1], bounds[0]].flatten()
    

In [None]:
from sklearn.linear_model import Lasso

# Single-voxel experiment: express the target patch B (from img1) as a sparse,
# non-negative combination of the neighbouring patches of the registered img2_1.
v = [100, 100, 100]  # centre voxel (x, y, z)
P = [3, 3, 3]        # patch size
N = [5, 5, 5]        # search-neighbourhood size
A = createA([img2_1], v, P, N)
B = createB(img1, v, P)

# lasso = Lasso(alpha=0.0001, positive=True, max_iter=1000000000, tol=0.00000000000001)
# max_iter must be an int: recent scikit-learn versions validate parameter
# types and reject the float 1e6.
lasso = Lasso(alpha=100, positive=True, max_iter=1000000)
lasso.fit(A, B)

In [None]:
# Residual of the fitted patch model. Lasso fits an intercept by default
# (fit_intercept=True), which `A @ lasso.coef_` alone leaves out;
# `lasso.predict(A)` includes it. Show the residual together with the weight shape.
residual = lasso.predict(A) - B
residual, lasso.coef_.shape


In [None]:
# Voxel-wise version of the Lasso fit above (kept disabled: it fits one Lasso
# model per voxel, which is very slow). Corrections folded into this sketch:
#   - the innermost loop variable is `z` (an `x` there would shadow the outer
#     loop and leave `z` undefined);
#   - every loop keeps a margin of N//2 + P//2 voxels from the border, because
#     createA reads P-sized patches around each of the N neighbourhood offsets;
#   - `max_iter` is an int, as scikit-learn's parameter validation requires.
# P = [3, 3, 3]
# N = [5, 5, 5]
# size = img1.GetSize()
# w = np.zeros(shape=(size[0], size[1], size[2] ,N[0]*N[1]*N[2]))
# margin = [N[d]//2 + P[d]//2 for d in range(3)]
# for x in range(margin[0], size[0]-margin[0]):
#     for y in range(margin[1], size[1]-margin[1]):
#         for z in range(margin[2], size[2]-margin[2]):
#             v = [x, y, z]
#             A = createA([img2_1], v, P, N)
#             B = createB(img1, v, P)

#             lasso = Lasso(alpha=100, positive=True, max_iter=1000000)
#             lasso.fit(A,B)
#             w[x,y,z,:] = lasso.coef_
            

In [None]:
def labelFusion(w, L):
    """Fuse candidate labels into one value as their weighted mean.

    Parameters
    ----------
    w : array-like
        Non-negative weight per candidate (e.g. ``lasso.coef_``).
    L : array-like, same length as ``w``
        Label of each candidate (e.g. the output of ``createL``).

    Returns
    -------
    float
        ``sum(w * L) / sum(w)``, or ``nan`` when the weights sum to zero
        (nothing to fuse).
    """
    # asarray lets plain Python lists work too (list * list is a TypeError).
    w = np.asarray(w, dtype=np.float64)
    L = np.asarray(L, dtype=np.float64)
    total = np.sum(w)
    if total == 0:
        # An all-zero weight vector is possible (e.g. a strongly regularised,
        # non-negative Lasso fit). Return nan explicitly instead of dividing by
        # zero with a RuntimeWarning.
        return float('nan')
    return np.sum(w * L) / total

def segmentation(L):
    """Binarise a fused label value: 1 when it is at least 0.5, otherwise 0."""
    if L >= 0.5:
        return 1
    return 0

def transformSegmentation(original_image, original_segmentation):
    """Resample a label map onto the voxel grid of ``original_image``.

    No spatial transform is applied (a default ``sitk.Transform()``). The
    segmentation is re-sampled in physical space onto the image's size, origin,
    spacing and direction. Nearest-neighbour interpolation keeps label values
    intact; voxels outside the segmentation become 0; the segmentation's pixel
    type is kept.
    """
    target = original_image
    return sitk.Resample(
        original_segmentation,
        target.GetSize(),
        sitk.Transform(),
        sitk.sitkNearestNeighbor,
        target.GetOrigin(),
        target.GetSpacing(),
        target.GetDirection(),
        0,
        original_segmentation.GetPixelID(),
    )

In [None]:
img1_seg = sitk.ReadImage("../MRI_Data/Baseline/KL0/9003430/9003430.segmentation_masks.mhd")
img2_seg = sitk.ReadImage("../MRI_Data/Baseline/KL0/9005075/9005075.segmentation_masks.mhd")

# `img1S` (img1's labels on img1's voxel grid) is read by the following cells
# (pixel check at voxel `v`, slice display). Create it here so they do not
# depend on a cell further down having been run first.
img1S = transformSegmentation(img1, img1_seg)


In [None]:
def printData(img):
    """Print an image's geometry and layout, one property per line.

    Order: direction, height, origin, spacing, depth, size, width, and the
    number of components per pixel.
    """
    for getter_name in ("GetDirection", "GetHeight", "GetOrigin", "GetSpacing",
                        "GetDepth", "GetSize", "GetWidth",
                        "GetNumberOfComponentsPerPixel"):
        print(getattr(img, getter_name)())
    

In [None]:
# Put img2's labels on the voxel grid of the registered image img2_1 and compare
# the geometry of the two.
# NOTE(review): transformSegmentation (as defined above) resamples with an
# identity transform, so img2_seg is placed on img2_1's grid *without* the
# registration transform `t` that produced img2_1. The labels are therefore
# likely misaligned with img2_1. Compare the later cell that builds
# `seg1`/`seg2` using the registration's resampler.
img2S = transformSegmentation(img2_1, img2_seg)

printData(img2_1)
printData(img2S)

In [None]:
# Label value of both segmentations at the Lasso test voxel `v` (x, y, z).
print(img1S[tuple(v[:3])])
print(img2S[tuple(v[:3])])

In [None]:
def createL(images, voxel, N):
    """Collect the label of every voxel in the ``N`` neighbourhood of ``voxel``.

    The enumeration order (image slowest, then x, y, z offsets with z fastest)
    is the same as the column order of ``createA``. So ``L[k]`` is the label
    belonging to column ``k`` of the dictionary matrix built from the matching
    images.

    Parameters
    ----------
    images : sequence of SimpleITK label images or numpy arrays
        SimpleITK images are indexed (x, y, z). numpy arrays must use the
        ``sitk.GetArrayFromImage`` layout, i.e. (z, y, x).
    voxel : sequence of 3 ints
        Centre voxel as an (x, y, z) index.
    N : sequence of 3 ints
        Neighbourhood size along x, y, z.

    Returns
    -------
    numpy.ndarray of uint8, shape (len(images) * N[0] * N[1] * N[2],)

    Raises
    ------
    IndexError
        If a neighbourhood voxel lies outside an image (negative indices are
        rejected rather than silently wrapping around).
    """
    from itertools import product

    offsets = list(product(*[range(-(n - 1) // 2, (n - 1) // 2 + 1) for n in N[:3]]))
    # One entry per (image, offset) pair. The length must include len(images),
    # as in createA; otherwise a second image writes past the end.
    L = np.zeros(shape=(len(images) * len(offsets),), dtype='uint8')

    index = 0
    for image in images:
        is_array = isinstance(image, np.ndarray)
        for offset in offsets:
            x, y, z = (voxel[d] + offset[d] for d in range(3))
            if min(x, y, z) < 0:
                raise IndexError(f"neighbourhood voxel {(x, y, z)} is outside the image")
            L[index] = image[z, y, x] if is_array else image[x, y, z]
            index += 1

    return L

In [None]:
# Labels of img2S at the same neighbourhood offsets used for the Lasso columns
# (createA was called with [img2_1] and the same v, N), so L[k] pairs with
# lasso.coef_[k].
L = createL([img2S], v, N)
print(L)

In [None]:
# Weighted-mean label at voxel v, then thresholded at 0.5.
label = labelFusion(lasso.coef_, L)
    
segmentation(label)

In [None]:
print(label)

In [None]:
from IPython.display import display
from PIL import Image as image

def showImg(img, z=60):
    """Display slice ``z`` (third index, ``img[:, :, z]``) of a label image inline.

    Label ids are multiplied by 60 so they show up as distinct grey levels.
    The product is computed in a wide integer type and clipped to [0, 255]
    before conversion to uint8. Otherwise a uint8 array wraps for ids >= 5,
    and a non-uint8 array forced into PIL's 8-bit 'L' mode is misread.
    """
    labels = sitk.GetArrayFromImage(img[:, :, z]).astype(np.int64)
    display(image.fromarray(np.clip(labels * 60, 0, 255).astype(np.uint8)))
    
def showImg2(img, x=60):
    """Display the plane at first index ``x`` (``img[x, :, :]``) of a label image inline.

    Label ids are multiplied by 60 in a wide integer type and clipped to
    [0, 255] before conversion to uint8. This avoids uint8 wrap-around and PIL
    misreading a non-uint8 buffer as 8-bit data.
    """
    labels = sitk.GetArrayFromImage(img[x, :, :]).astype(np.int64)
    display(image.fromarray(np.clip(labels * 60, 0, 255).astype(np.uint8)))
    
def showImg3(img, z=60):
    """Display slice ``z`` (``img[:, :, z]``) of an intensity image inline, min-max scaled to 0..255.

    A plain ``astype('uint8')`` keeps only the low byte of each value, so
    intensities above 255 wrap around. The slice is linearly stretched
    instead; a constant slice is shown black.
    """
    plane = sitk.GetArrayFromImage(img[:, :, z]).astype(np.float64)
    lo, hi = plane.min(), plane.max()
    if hi == lo:
        scaled = np.zeros_like(plane)
    else:
        scaled = np.rint((plane - lo) * (255.0 / (hi - lo)))
    display(image.fromarray(scaled.astype(np.uint8)))
    
# Registered intensity image next to its label map (slice 60).
showImg3(img2_1)
showImg(img2S)

In [None]:
# Size of img2's labels before and after resampling onto img2_1's grid.
print(img2_seg.GetSize())
print(img2S.GetSize())

In [None]:
# Carry img2's labels into img1's space: first onto img2's own voxel grid, then
# through the transform of the registration resampler `resample`.
seg1 = transformSegmentation(img2, img2_seg)
# Label maps need nearest-neighbour interpolation. Running them through the
# linear `resample` filter would blend neighbouring label ids into values that
# are not valid labels. Same reference grid (img1), default value (0) and
# output pixel type (seg1's) as the resampler, only a different interpolator.
seg2 = sitk.Resample(seg1, img1, resample.GetTransform(),
                     sitk.sitkNearestNeighbor, 0, seg1.GetPixelID())

In [None]:
def showImg_(img, z=60):
    """Open slice ``z`` (``img[:, :, z]``) of a label image in PIL's external viewer.

    Label ids are multiplied by 60 in a wide integer type and clipped to
    [0, 255] before conversion to uint8. This avoids uint8 wrap-around for
    ids >= 5 and PIL misreading a non-uint8 buffer in 8-bit 'L' mode.
    """
    labels = sitk.GetArrayFromImage(img[:, :, z]).astype(np.int64)
    img2 = image.fromarray(np.clip(labels * 60, 0, 255).astype(np.uint8))
    img2.show()
 
def showImg2_(img, z=60):
    """Open the plane at first index ``z`` (``img[z, :, :]``) in PIL's viewer, min-max scaled to 0..255.

    ``astype('uint8')`` alone keeps only the low byte, so intensities above
    255 wrap around. The plane is linearly stretched instead; a constant
    plane is shown black.
    """
    plane = sitk.GetArrayFromImage(img[z, :, :]).astype(np.float64)
    lo, hi = plane.min(), plane.max()
    if hi == lo:
        scaled = np.zeros_like(plane)
    else:
        scaled = np.rint((plane - lo) * (255.0 / (hi - lo)))
    disImg = image.fromarray(scaled.astype(np.uint8))
    disImg.show()
    
def showImg3_(img, z=60, t=''):
    """Open slice ``z`` (``img[:, :, z]``) in PIL's viewer, min-max scaled to 0..255, with window title ``t``.

    ``astype('uint8')`` alone keeps only the low byte, so intensities above
    255 wrap around. The slice is linearly stretched instead; a constant
    slice is shown black.
    """
    plane = sitk.GetArrayFromImage(img[:, :, z]).astype(np.float64)
    lo, hi = plane.min(), plane.max()
    if hi == lo:
        scaled = np.zeros_like(plane)
    else:
        scaled = np.rint((plane - lo) * (255.0 / (hi - lo)))
    disImg = image.fromarray(scaled.astype(np.uint8))
    disImg.show(title=t)
    
def showImg4_(img, z=60):
    """Open the plane at first index ``z`` (``img[z, :, :]``) of a label image in PIL's viewer.

    Label ids are multiplied by 60 in a wide integer type and clipped to
    [0, 255] before conversion to uint8. This avoids uint8 wrap-around and PIL
    misreading a non-uint8 buffer in 8-bit 'L' mode.
    """
    labels = sitk.GetArrayFromImage(img[z, :, :]).astype(np.int64)
    img2 = image.fromarray(np.clip(labels * 60, 0, 255).astype(np.uint8))
    img2.show()

In [None]:
# img2 (intensity) and its labels resampled onto img2's own grid (seg1).
showImg3_(img2)
showImg_(seg1)

In [None]:
# Geometry of img2 vs. its segmentation file.
printData(img2)
printData(img2_seg)

In [None]:
# Fixed and (unregistered) moving image, slice 60.
showImg3_(img1)
showImg3_(img2)

In [None]:
# Same affine pipeline as the first registration cell, but with 1000
# gradient-descent iterations per level. Result: img2_2 on img1's voxel grid.
# NOTE(review): overwrites the globals initial_transform, registration_method,
# final_transform, t and resample set by the first registration cell.
initial_transform = sitk.CenteredTransformInitializer(img1, 
                                                      img2, 
                                                      sitk.AffineTransform(3), 
                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY)
registration_method = sitk.ImageRegistrationMethod()

# Mattes MI (50 bins) on a random 1 % of voxels; no seed, so runs may differ.
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)

registration_method.SetInterpolator(sitk.sitkLinear)
   
registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=1000) #, estimateLearningRate=registration_method.EachIteration)
registration_method.SetOptimizerScalesFromPhysicalShift() 

final_transform = sitk.AffineTransform(initial_transform)

# Three pyramid levels: shrink 4/2/1, smoothing sigmas 2/1/0 (physical units).
registration_method.SetInitialTransform(final_transform)
registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas = [2,1,0])
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

t = registration_method.Execute(sitk.Cast(img1, sitk.sitkFloat32), sitk.Cast(img2, sitk.sitkFloat32))

# Linear resampling of img2 onto img1's grid with the optimised transform.
resample = sitk.ResampleImageFilter()
resample.SetReferenceImage(img1)

resample.SetInterpolator(sitk.sitkLinear)  
resample.SetTransform(t)
img2_2 = resample.Execute(img2)

In [None]:
# Fixed image vs. the 1000-iteration registration result.
showImg3_(img1)
showImg3_(img2_2)

In [None]:
# Rigid (Euler 3-D) variant of the registration: img2 (moving) onto img1 (fixed).
# The cell used `fixed_image` / `moving_image` without assigning them anywhere
# in this notebook (NameError on a fresh kernel), so bind them explicitly.
fixed_image = img1
moving_image = img2

# Centre-align the volumes. The fixed image is cast to the moving image's pixel
# type, presumably so that both initialiser inputs share one pixel type.
initial_transform = sitk.CenteredTransformInitializer(sitk.Cast(fixed_image, moving_image.GetPixelID()),
                                                      moving_image,
                                                      sitk.Euler3DTransform(),
                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY)

registration_method = sitk.ImageRegistrationMethod()

# Mattes MI (50 bins) on a random 1 % of voxels.
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)

registration_method.SetInterpolator(sitk.sitkLinear)

registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100)
registration_method.SetOptimizerScalesFromPhysicalShift()

final_transform = sitk.Euler3DTransform(initial_transform)
registration_method.SetInitialTransform(final_transform)
registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# Use the transform returned by Execute rather than relying on
# `final_transform` having been optimised in place.
rigid_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32),
                                              sitk.Cast(moving_image, sitk.sitkFloat32))

resample = sitk.ResampleImageFilter()
resample.SetReferenceImage(fixed_image)

# SimpleITK supports several interpolation options, we go with the simplest that gives reasonable results.
resample.SetInterpolator(sitk.sitkLinear)
resample.SetTransform(rigid_transform)
img2_3 = resample.Execute(moving_image)

In [None]:
# Slice z = 50 of img1 next to its label map.
# NOTE(review): requires `img1S` to exist before this cell runs.
z = 50
showImg3_(img1, z)
showImg_(img1S, z)

In [None]:
def transformSegmentation2(original_image, original_segmentation, interpolator=None):
    """Resample ``original_segmentation`` onto the grid of ``original_image`` (identity transform).

    Uses the reference-image overload of ``sitk.Resample``: the output takes
    its size, origin, spacing and direction from ``original_image``.
    Out-of-volume voxels get 0, and the pixel type of ``original_segmentation``
    is kept.

    Parameters
    ----------
    interpolator : SimpleITK interpolator enum, optional
        Defaults to ``sitk.sitkNearestNeighbor`` so label ids are never blended.
        Pass e.g. ``sitk.sitkLinear`` when resampling an intensity image.
    """
    if interpolator is None:
        interpolator = sitk.sitkNearestNeighbor
    # Identity transform (the previous `sitk.sitk.Transform()()` raised AttributeError).
    return sitk.Resample(original_segmentation,
                         original_image,
                         sitk.Transform(),
                         interpolator,
                         0,
                         original_segmentation.GetPixelID())
img1S2 = transformSegmentation2(img1, img1_seg)

In [None]:
# Reverse use of the helper: resample the MR image img1 onto img1_seg's voxel grid.
# NOTE(review): `img1_1` is later passed to registration3 and plotted; the
# commented-out AdaptiveHistogramEqualization line further down suggests an
# equalised img1 may have been intended there instead.
img1_1 = transformSegmentation2(img1_seg, img1)

In [None]:
def showImg_(img, z=60):
    """Open slice ``z`` (``img[:, :, z]``) of a label image in PIL's external viewer.

    Redefinition of the earlier helper of the same name. Label ids are
    multiplied by 60 in a wide integer type and clipped to [0, 255] before
    conversion to uint8. This avoids uint8 wrap-around for ids >= 5 and PIL
    misreading a non-uint8 buffer in 8-bit 'L' mode.
    """
    labels = sitk.GetArrayFromImage(img[:, :, z]).astype(np.int64)
    img2 = image.fromarray(np.clip(labels * 60, 0, 255).astype(np.uint8))
    img2.show()
    
def showImg3_(img, z=60, t=''):
    """Open slice ``z`` (``img[:, :, z]``) in PIL's viewer, min-max stretched to 0..255.

    ``t`` is the optional window title. Later cells call ``showImg3_(img, z, "1")``,
    so this redefinition must keep accepting a third argument; without it
    those calls raise TypeError. A constant slice is shown black; for
    multi-component (e.g. RGB overlay) images all channels are stretched together.
    """
    a = sitk.GetArrayFromImage(img[:, :, z])
    lo, hi = a.min(), a.max()
    if hi == lo:
        # np.interp needs increasing sample points, which a constant slice lacks.
        scaled = np.zeros(a.shape)
    else:
        scaled = np.interp(a, (lo, hi), (0, 255))
    disImg = image.fromarray(scaled.astype('uint8'))
    disImg.show(title=t or None)
    
def showImg4_(img, z=60):
    """Open slice ``z`` (``img[:, :, z]``) in PIL's external viewer.

    Multi-component images (e.g. RGB overlays) are shown in colour. Scalar
    images are treated as label maps: ids are multiplied by 60 and clipped to
    0..255. Forcing mode 'RGB' on a scalar slice makes PIL expect three bytes
    per pixel and fail with "not enough image data"; PIL instead infers 'L'
    or 'RGB' from the array shape.
    """
    a = sitk.GetArrayFromImage(img[:, :, z])
    if a.ndim == 2:
        a = a.astype(np.int64) * 60
    disImg = image.fromarray(np.clip(a, 0, 255).astype(np.uint8))
    disImg.show()
    
def showImg2_(img, z=60):
    """Open slice ``z`` (``img[:, :, z]``) in PIL's viewer, linearly stretched to 0..255.

    The slice is halved before the min-max stretch. Because the stretch is
    linear, the halving has no visible effect beyond floating-point rounding.
    """
    halved = sitk.GetArrayFromImage(img[:, :, z]) / 2
    grey = np.interp(halved, [halved.min(), halved.max()], [0, 255])
    image.fromarray(grey.astype('uint8')).show()

In [None]:
def transformSegmentation(originalImage, originalSegmentation):
    """Resample a label map onto ``originalImage``'s grid after aligning the volumes' geometric centres.

    Redefines the earlier identity-transform ``transformSegmentation``. A
    rigid ``Euler3DTransform`` is initialised in GEOMETRY mode, with the image
    cast to the segmentation's pixel type as the fixed input. The label map is
    then resampled onto the image's size, origin, spacing and direction with
    nearest-neighbour interpolation. Outside voxels get 0, and the
    segmentation's pixel type is kept.
    """
    image_as_label_type = sitk.Cast(originalImage, originalSegmentation.GetPixelID())
    centring = sitk.CenteredTransformInitializer(
        image_as_label_type,
        originalSegmentation,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )

    grid = originalImage
    return sitk.Resample(
        originalSegmentation,
        grid.GetSize(),
        centring,
        sitk.sitkNearestNeighbor,
        grid.GetOrigin(),
        grid.GetSpacing(),
        grid.GetDirection(),
        0,
        originalSegmentation.GetPixelID(),
    )

In [None]:
def transformSegmentation2(originalImage, originalSegmentation):
    """Resample ``originalSegmentation`` onto ``originalImage``'s grid with an identity transform.

    Redefines the earlier ``transformSegmentation2``. The output size, origin,
    spacing and direction come from ``originalImage``. Interpolation is
    nearest-neighbour, out-of-volume voxels are 0, and the segmentation's pixel
    type is kept.
    """
    size = originalImage.GetSize()
    identity = sitk.Transform()
    origin = originalImage.GetOrigin()
    spacing = originalImage.GetSpacing()
    direction = originalImage.GetDirection()
    return sitk.Resample(originalSegmentation, size, identity, sitk.sitkNearestNeighbor,
                         origin, spacing, direction, 0, originalSegmentation.GetPixelID())

In [None]:
# Centre-aligned (Euler-initialised) labels of img1 on img1's grid.
img1S = transformSegmentation(img1, img1_seg)

In [None]:
# Identity-transform labels of img1 on img1's grid.
img1S1 = transformSegmentation2(img1, img1_seg)

In [None]:
# Reverse: img1 resampled onto the segmentation's grid.
img11 = transformSegmentation2(img1_seg, img1)

In [None]:
# Compare the two label placements on slice 60.
z = 60
showImg3_(img1, z)
showImg_(img1S1, z)
showImg_(img1S, z)

In [None]:
img1_seg.GetSize()

In [None]:
# NOTE(review): img1S2 comes from the earlier transformSegmentation2 cell.
z = 90
showImg3_(img1, z)
showImg_(img1S2, z)

In [None]:
def print_image_info(img):
    """Print size, spacing, origin and the row-wise 3x3 direction matrix of a 3-D image."""
    size = img.GetSize()
    spacing = img.GetSpacing()
    origin = img.GetOrigin()
    direction = img.GetDirection()

    print("Size:      %d   %d   %d" % (size[0], size[1], size[2]))
    print("Spacing:   %.2f %.2f %.2f" % (spacing[0], spacing[1], spacing[2]))
    print("Origin:    %.2f %.2f %.2f" % (origin[0], origin[1], origin[2]))
    rows = ["%.2f %.2f %.2f" % (direction[3 * r], direction[3 * r + 1], direction[3 * r + 2])
            for r in range(3)]
    print("Direction: \n" + " \n".join(rows))



In [None]:
# Geometry of img1 and of its labels resampled onto img1's grid (should match).
print_image_info(img1)
print_image_info(img1S)

In [None]:
z = 80
showImg3_(img1, z)
showImg3_(img1S, z)

In [None]:
import matplotlib.pyplot as plt

# Middle z-slice of img1 with matplotlib's default colormap and a reversed grey
# colormap, next to the same slice of the label map img1S (labels x60).
mr_image = img1
npa = sitk.GetArrayViewFromImage(mr_image)

z = int(mr_image.GetDepth() / 2)
npa_zslice = npa[z, :, :]

fig = plt.figure(figsize=(20, 6))

ax_default = fig.add_subplot(1, 3, 1)
ax_default.imshow(npa_zslice)
ax_default.set_title('default colormap', fontsize=10)
ax_default.axis('off')

ax_grey = fig.add_subplot(1, 3, 2)
ax_grey.imshow(npa_zslice, cmap=plt.cm.Greys_r)
ax_grey.set_title('grey colormap', fontsize=10)
ax_grey.axis('off')

ax_labels = fig.add_subplot(1, 3, 3)
ax_labels.imshow(sitk.GetArrayViewFromImage(img1S[:, :, z]) * 60, cmap=plt.cm.Greys_r)
ax_labels.axis('off');

In [None]:
z = 80
showImg3_(img1, z)
showImg3_(img1S, z)

In [None]:
# RGB colours for the label colormaps below (each list is one R, G, B triple;
# `gold` is not used in the visible colormaps).
pink= [255,105,180]
green = [0,255,0]
gold = [255,215,0]
red = [255,0,0]
blue = [0,0,255]

# Colour the labels of img1S on top of img1 (label 0 treated as background).
# NOTE(review): opacity=0.0005 makes the label colours practically invisible;
# confirm this value is intended.
comp = sitk.LabelOverlay(image=img1, 
                         labelImage=img1S, 
                         opacity=0.0005, 
                         backgroundValue=0,
                         colormap=blue+red+green+pink)
# showImg3_(comp, z)

In [None]:
# Show the RGB overlay for slice 60.
z = 60
plt.figure(figsize=(20,10))
plt.imshow(sitk.GetArrayViewFromImage(comp[:,:,z]))


In [None]:
showImg3_(comp, z)
showImg3_(img1, z)

In [None]:
# Label contours rendered as an RGB image.
# NOTE(review): backgroundValue=255 here, whereas the overlay above treats 0 as
# background; confirm which value is background in img1S.
contour_image = sitk.LabelToRGB(sitk.LabelContour(img1S, fullyConnected=True, backgroundValue=255), 
                                colormap=blue+red+green+pink , backgroundValue=255)

In [None]:
showImg3_(contour_image, z)

In [None]:
# Draw 1-voxel label contours of img1S over img1.
# NOTE(review): img1 is cast to uint8 for this filter; intensities above 255 do
# not fit that type, so rescaling the intensities first (e.g. with
# sitk.RescaleIntensity) may give a more faithful background.
contour_overlaid_image = sitk.LabelMapContourOverlay(sitk.Cast(img1S, sitk.sitkLabelUInt8), 
#                                                      img1,
                                                     sitk.Cast(img1, sitk.sitkUInt8),
                                                     opacity = 1, 
                                                     contourThickness=[1,1,1],
                                                     dilationRadius= [1,1,1],
                                                     colormap=blue+red+green+pink)
#                                                      colormap=blue+red)
# showImg3_(contour_overlaid_image, z)

In [None]:
z = 50
plt.figure(figsize=(20,10))
plt.imshow(sitk.GetArrayViewFromImage(contour_overlaid_image[:,:,z]))
# plt.savefig('seg1.jpg')

In [None]:
# NOTE(review): np.rollaxis(arr, 0, 2) swaps the first two axes, so applying it
# twice is the identity permutation. img1S1 therefore keeps the
# GetArrayViewFromImage layout (z, y, x); for an (x, y, z) layout use
# np.transpose(arr, (2, 1, 0)). This also overwrites the SimpleITK image
# previously stored in img1S1.
img1S1 = np.rollaxis(np.rollaxis(sitk.GetArrayViewFromImage(img1_seg), 0, 2), 0, 2)

In [None]:
# NOTE(review): the outer rollaxis(..., 1, -1) is a no-op here, so this is the
# shape after one axis swap: (y, z, x).
np.rollaxis(np.rollaxis(sitk.GetArrayViewFromImage(img1_seg), 0, 2), 1, -1).shape

In [None]:
# NOTE(review): with the (z, y, x) layout above, `[..., z]` indexes the last
# array axis (x), not z.
z = 60
plt.figure(figsize=(20,10))
plt.imshow(img1S1[::-1,:,z])

# z = 60
# plt.figure(figsize=(20,10))
# plt.imshow(sitk.GetArrayViewFromImage(img1[:,:,z]),cmap=plt.cm.Greys_r)

In [None]:
showImg3_(img1, z)

In [None]:
# Adaptive histogram equalisation of img1 over a grid of alpha / beta values
# (and, further down, neighbourhood radii), compared visually in the plot grids below.
img1F = sitk.AdaptiveHistogramEqualization(img1, alpha=0.3, beta=0.3)
img1F1 = sitk.AdaptiveHistogramEqualization(img1, alpha=0.1, beta=0.3)
img1F2 = sitk.AdaptiveHistogramEqualization(img1, alpha=0.3, beta=0.1)

In [None]:
img1F3 = sitk.AdaptiveHistogramEqualization(img1, alpha=0.8, beta=0.3)
img1F4 = sitk.AdaptiveHistogramEqualization(img1, alpha=0.3, beta=0.8)

In [None]:
img1F5 = sitk.AdaptiveHistogramEqualization(img1, alpha=0.8, beta=0.1)
img1F6 = sitk.AdaptiveHistogramEqualization(img1, alpha=0.6, beta=0.05)

In [None]:
img1F7 = sitk.AdaptiveHistogramEqualization(img1, alpha=0.3, beta=0.2)

In [None]:
# Same alpha/beta as img1F7 with an explicit radius of 3 voxels per axis.
# img1F8 = sitk.AdaptiveHistogramEqualization(img1, alpha=0.3, beta=0.2, radius = np.ones(3, dtype='uint8')*3)
img1F8 = sitk.AdaptiveHistogramEqualization(img1, alpha=0.3, beta=0.2, radius = [3,3,3])

In [None]:

# Larger radii (7 and 9 voxels per axis).
img1F9 = sitk.AdaptiveHistogramEqualization(img1, alpha=0.3, beta=0.2, radius = [7,7,7])
img1F10 = sitk.AdaptiveHistogramEqualization(img1, alpha=0.3, beta=0.2, radius = [9,9,9])

In [None]:
# Equalise img2 with the same settings as img1F.
img2F = sitk.AdaptiveHistogramEqualization(img2, alpha=0.3, beta=0.3)

In [None]:
# Compare the adaptive-histogram-equalisation sweep on slice z of img1. All
# panels share the 0..4096 grey window so their brightness is comparable.
# A loop replaces twelve copy-pasted subplot blocks; each panel gets a title.
z = 60
ahe_panels = [("img1", img1), ("img1F", img1F), ("img1F1", img1F1), ("img1F2", img1F2),
              ("img1F3", img1F3), ("img1F4", img1F4), ("img1F5", img1F5), ("img1F6", img1F6),
              ("img1F7", img1F7), ("img1F8", img1F8), ("img1F9", img1F9), ("img1F10", img1F10)]

plt.figure(figsize=(30, 20))
for position, (title, panel) in enumerate(ahe_panels, start=1):
    plt.subplot(4, 3, position)
    plt.imshow(sitk.GetArrayViewFromImage(panel[:, :, z]), cmap="gray", vmin=0, vmax=4096)
    plt.title(title)
plt.show()

In [None]:
# Effect of the AHE radius: img1F7 vs img1F8/9/10 on slice z, shared 0..4096 window.
z = 60
radius_panels = [("img1F7", img1F7), ("img1F8", img1F8), ("img1F9", img1F9), ("img1F10", img1F10)]

plt.figure(figsize=(10, 10))
for position, (title, panel) in enumerate(radius_panels, start=1):
    plt.subplot(2, 2, position)
    plt.imshow(sitk.GetArrayViewFromImage(panel[:, :, z]), cmap="gray", vmin=0, vmax=4096)
    plt.title(title)
plt.show()

In [None]:
# Raw img1 / img2 (grey window 0..2092) vs. their equalised versions
# (window 0..4096) on slice z.
z = 60
comparison_panels = [("img1", img1, 2092), ("img2", img2, 2092),
                     ("img1F", img1F, 4096), ("img2F", img2F, 4096)]

plt.figure(figsize=(10, 10))
for position, (title, panel, upper) in enumerate(comparison_panels, start=1):
    plt.subplot(2, 2, position)
    plt.imshow(sitk.GetArrayViewFromImage(panel[:, :, z]), cmap="gray", vmin=0, vmax=upper)
    plt.title(title)
plt.show()

In [None]:
def transformSegmentation3(originalImage, originalSegmentation):
    """Centre-align a label map to ``originalImage`` and resample it onto that image's grid.

    A rigid ``Euler3DTransform`` is initialised in GEOMETRY mode with the image
    as fixed input and the segmentation, cast to the image's pixel type, as
    moving input. The label map is then resampled onto the image's size,
    origin, spacing and direction with nearest-neighbour interpolation.
    Outside voxels get 0, and the segmentation's pixel type is kept.
    """
    segmentation_as_image_type = sitk.Cast(originalSegmentation, originalImage.GetPixelID())
    centring = sitk.CenteredTransformInitializer(
        originalImage,
        segmentation_as_image_type,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )

    grid = originalImage
    return sitk.Resample(
        originalSegmentation,
        grid.GetSize(),
        centring,
        sitk.sitkNearestNeighbor,
        grid.GetOrigin(),
        grid.GetSpacing(),
        grid.GetDirection(),
        0,
        originalSegmentation.GetPixelID(),
    )

In [None]:
# Place img1's labels with each transformSegmentation variant.
# NOTE(review): `img1_seg2` is not assigned anywhere in this notebook, so the
# last line (and the later cells using img1_seg2) raise NameError on a fresh kernel.
img1S1 = transformSegmentation(img1, img1_seg)
img1S2 = transformSegmentation2(img1, img1_seg)
img1S3 = transformSegmentation3(img1, img1_seg)
img1S32 = transformSegmentation3(img1, img1_seg2)

In [None]:
print(img1S1.GetSize())
print(img1S2.GetSize())
print(img1S3.GetSize())

In [None]:
z = 90
showImg3_(img1, z)
showImg_(img1S3, z)
showImg_(img1S32, z)

In [None]:
# NOTE(review): relative path, unlike the ../MRI_Data paths used elsewhere.
img12 = sitk.ReadImage("9003430i.nii")

In [None]:
img12.GetSize()

In [None]:
z = 60
showImg2_(img12, z)
showImg4_(img1_seg2, z)

In [None]:
# numpy views (z, y, x) of the two label maps for a voxel-wise comparison.
a = sitk.GetArrayViewFromImage(img1_seg2)
b = sitk.GetArrayViewFromImage(img1_seg)

In [None]:
np.unique(a)

In [None]:
# Print the (z, y, x) array index of every voxel where the two label arrays
# differ, skipping voxels where `a` is 5 and `b` is 4 (that pair is treated as
# equivalent). Vectorised instead of a Python triple loop over the whole
# volume. np.argwhere yields the indices in the same C order as nested
# i/j/k loops, so the printed lines are unchanged.
mismatch = (a != b) & ~((a == 5) & (b == 4))
for i, j, k in np.argwhere(mismatch):
    print(i, j, k)
                    

In [None]:
# NOTE(review): `img1_seg2` is not assigned anywhere in this notebook (NameError on a fresh kernel).
showImg4_(img1_seg2, z)
showImg4_(img1_seg, z)

In [None]:
def transformSegmentation4(originalImage, originalSegmentation):
    """Resample ``originalImage`` onto the voxel grid of ``originalSegmentation``.

    This is the reverse direction of the other transformSegmentation helpers.
    The segmentation, cast to the image's pixel type, is the fixed input of a
    GEOMETRY-centred ``AffineTransform`` initialisation, and the image is the
    moving input. The image is then linearly resampled onto the segmentation's
    size, origin, spacing and direction. Outside voxels get 0, and the image's
    pixel type is kept.
    """
    segmentation_as_image_type = sitk.Cast(originalSegmentation, originalImage.GetPixelID())
    centring = sitk.CenteredTransformInitializer(
        segmentation_as_image_type,
        originalImage,
        sitk.AffineTransform(3),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )

    grid = originalSegmentation
    return sitk.Resample(
        originalImage,
        grid.GetSize(),
        centring,
        sitk.sitkLinear,
        grid.GetOrigin(),
        grid.GetSpacing(),
        grid.GetDirection(),
        0,
        originalImage.GetPixelID(),
    )

In [None]:
# img1 resampled onto the segmentation grid (affine, centre-aligned).
img13 = transformSegmentation4(img1, img1_seg)

In [None]:
# NOTE(review): `img1_seg2` is not assigned anywhere in this notebook.
z = 60
showImg2_(img12, z)
showImg2_(img13, z)
showImg4_(img1_seg2, z)

In [None]:
def registration(fixedImage, movingImage, numberOfIterations=100,
                 samplingPercentage=0.01, seed=None):
    """Affine-register ``movingImage`` to ``fixedImage`` and return it resampled on the fixed grid.

    Pipeline:
    - GEOMETRY-centred ``AffineTransform(3)`` initialisation;
    - Mattes mutual information (50 bins) on a random ``samplingPercentage``
      of voxels;
    - gradient descent (learning rate 1.0, ``numberOfIterations`` per level)
      with physical-shift scales;
    - three-level pyramid (shrink 4/2/1, smoothing sigmas 2/1/0 in physical
      units);
    - final linear resampling.

    Parameters
    ----------
    fixedImage, movingImage : SimpleITK images
    numberOfIterations : int, default 100
    samplingPercentage : float, default 0.01
        Fraction of voxels sampled for the metric.
    seed : int, optional
        Seed for the random metric sampling. ``None`` keeps SimpleITK's default
        (documented as wall-clock based), so results can differ between runs.
        Pass an int for reproducible registrations.

    Returns
    -------
    SimpleITK image
        ``movingImage`` linearly resampled onto ``fixedImage``'s grid.
    """
    initial_transform = sitk.CenteredTransformInitializer(fixedImage,
                                                          movingImage,
                                                          sitk.AffineTransform(3),
                                                          sitk.CenteredTransformInitializerFilter.GEOMETRY)
    registration_method = sitk.ImageRegistrationMethod()

    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    if seed is None:
        registration_method.SetMetricSamplingPercentage(samplingPercentage)
    else:
        registration_method.SetMetricSamplingPercentage(samplingPercentage, seed)

    registration_method.SetInterpolator(sitk.sitkLinear)

    registration_method.SetOptimizerAsGradientDescent(learningRate=1.0,
                                                      numberOfIterations=numberOfIterations)
    registration_method.SetOptimizerScalesFromPhysicalShift()

    final_transform = sitk.AffineTransform(initial_transform)

    registration_method.SetInitialTransform(final_transform)
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # The metric is evaluated on float32 copies of both images.
    t = registration_method.Execute(sitk.Cast(fixedImage, sitk.sitkFloat32),
                                    sitk.Cast(movingImage, sitk.sitkFloat32))

    resample = sitk.ResampleImageFilter()
    resample.SetReferenceImage(fixedImage)

    resample.SetInterpolator(sitk.sitkLinear)
    resample.SetTransform(t)

    return resample.Execute(movingImage)

In [None]:
def registration2(fixedImage, movingImage):
    """Affine-register ``movingImage`` onto ``fixedImage`` using every voxel for the metric.

    Set-up:
    - GEOMETRY-centred ``AffineTransform(3)`` start;
    - Mattes mutual information with 400 histogram bins and sampling strategy
      NONE (all voxels). The 10 % sampling percentage set below presumably has
      no effect in that mode; TODO confirm;
    - gradient descent (learning rate 1.0, 100 iterations per level) with
      physical-shift scales;
    - pyramid shrink 4/2/1 with smoothing sigmas 2/1/0 in physical units.

    Returns the moving image linearly resampled onto the fixed image's grid.
    """
    centred_start = sitk.CenteredTransformInitializer(
        fixedImage,
        movingImage,
        sitk.AffineTransform(3),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )
    method = sitk.ImageRegistrationMethod()

    # Metric on the full image (strategy options: REGULAR, RANDOM, NONE).
    method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=400)
    method.SetMetricSamplingStrategy(method.NONE)
    method.SetMetricSamplingPercentage(0.1)

    method.SetInterpolator(sitk.sitkLinear)

    method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100)
    method.SetOptimizerScalesFromPhysicalShift()

    working_transform = sitk.AffineTransform(centred_start)

    method.SetInitialTransform(working_transform)
    method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    fixed_f32 = sitk.Cast(fixedImage, sitk.sitkFloat32)
    moving_f32 = sitk.Cast(movingImage, sitk.sitkFloat32)
    optimised = method.Execute(fixed_f32, moving_f32)

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixedImage)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetTransform(optimised)
    return resampler.Execute(movingImage)

In [None]:
# Refine the 1000-iteration result img2_2 with registration2.
# NOTE(review): the next four cells make the same call with the same inputs;
# with sampling strategy NONE they presumably give identical results.
img22_3 = registration2(img1, img2_2)

In [None]:
img23_3 = registration2(img1, img2_2)

In [None]:
img24_3 = registration2(img1, img2_2)

In [None]:
img25_3 = registration2(img1, img2_2)

In [None]:
# NOTE(review): img3_2 is first assigned much further down (HistogramMatching
# cell), so this fails when the notebook is run top to bottom.
img35_2 = registration2(img1, img3_2)

In [None]:
# NOTE(review): the three-argument form needs a showImg3_ that accepts a title.
z = 80
showImg3_(img1, z, "1")
# showImg3_(img25_3, z, "2")
showImg3_(img35_2, z, "2")

In [None]:
img21 = registration(img1, img2)

In [None]:
img22 = registration2(img1, img2)

In [None]:
img23 = registration2(img1, img2)

In [None]:
img24 = registration2(img1, img2)

In [None]:
# NOTE(review): img3 is first read in a later cell.
img31 = registration2(img1, img3)

In [None]:
# NOTE(review): `img25` is not assigned anywhere in this notebook.
z = 60
showImg3_(img1, z, "1")
showImg3_(img25, z, "2")
showImg3_(img31, z, "3")

In [None]:
# Equalised copy of img3 (img3 is first read in a later cell).
# img1_1 = sitk.AdaptiveHistogramEqualization(img1)
# img2_1 = sitk.AdaptiveHistogramEqualization(img2)
img3_1 = sitk.AdaptiveHistogramEqualization(img3)

In [None]:
def registration3(fixedImage, movingImage):
    """Rigidly register ``movingImage`` onto ``fixedImage`` and return it on the fixed grid.

    Set-up:
    - GEOMETRY-centred ``Euler3DTransform`` start;
    - Mattes mutual information (100 bins) with sampling strategy NONE, i.e.
      all voxels (the 0.1 sampling percentage is presumably ignored in that
      mode; TODO confirm);
    - gradient descent (learning rate 1.0, 500 iterations per level) with
      physical-shift scales;
    - pyramid shrink 4/2/1 with smoothing sigmas 2/1/0 in physical units.

    The moving image is linearly resampled with the optimised transform.
    """
    centred_start = sitk.CenteredTransformInitializer(
        fixedImage,
        movingImage,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )
    method = sitk.ImageRegistrationMethod()

    # Metric on the full image (strategy options: REGULAR, RANDOM, NONE).
    method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=100)
    method.SetMetricSamplingStrategy(method.NONE)
    method.SetMetricSamplingPercentage(0.1)

    method.SetInterpolator(sitk.sitkLinear)

    method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=500)
    method.SetOptimizerScalesFromPhysicalShift()

    working_transform = sitk.Euler3DTransform(centred_start)

    method.SetInitialTransform(working_transform)
    method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    fixed_f32 = sitk.Cast(fixedImage, sitk.sitkFloat32)
    moving_f32 = sitk.Cast(movingImage, sitk.sitkFloat32)
    optimised = method.Execute(fixed_f32, moving_f32)

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixedImage)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetTransform(optimised)
    return resampler.Execute(movingImage)

In [None]:
# Rigid registration of the equalised img3_1 onto img1_1.
# NOTE(review): in top-to-bottom order img1_1 is `transformSegmentation2(img1_seg, img1)`
# (img1 resampled onto the segmentation grid), not an equalised img1 -- the
# equalisation line for img1_1 is commented out.
# img21_1 = registration(img1_1, img2_1)
_img32 = registration3(img1_1, img3_1)

In [None]:
z = 90
showImg3_(img1_1, z)
showImg3_(_img32, z)

In [None]:
# NOTE(review): _img21, _img22, _img23, _img25 and _img26 are not assigned
# anywhere in this notebook, so this cell raises NameError on a fresh kernel.
z = 80
plt.figure(figsize=(20,10))

plt.subplot(2, 4, 1)
plt.imshow(sitk.GetArrayViewFromImage(img1_1[:,:,z]), cmap="gray", vmin=0, vmax=4096)

plt.subplot(2, 4, 2)
plt.imshow(sitk.GetArrayViewFromImage(_img21[:,:,z]), cmap="gray", vmin=0, vmax=4096)

plt.subplot(2, 4, 3)
plt.imshow(sitk.GetArrayViewFromImage(_img22[:,:,z]), cmap="gray", vmin=0, vmax=4096)

plt.subplot(2, 4, 4)
plt.imshow(sitk.GetArrayViewFromImage(_img23[:,:,z]), cmap="gray", vmin=0, vmax=4096)

plt.subplot(2, 4, 5)
plt.imshow(sitk.GetArrayViewFromImage(_img25[:,:,z]), cmap="gray", vmin=0, vmax=4096)

plt.subplot(2, 4, 6)
plt.imshow(sitk.GetArrayViewFromImage(_img26[:,:,z]), cmap="gray", vmin=0, vmax=4096)

In [None]:
# Read a third DICOM series with the shared `reader`.
# NOTE(review): several earlier cells already use img3; run this cell first.
dicom_names = reader.GetGDCMSeriesFileNames("../MRI_Data/Baseline/KL0/9011115/20050801/10450912")
reader.SetFileNames(dicom_names)

img3 = reader.Execute()

In [None]:
# NOTE(review): img31_1 ... img34_1 are not assigned anywhere in this notebook,
# so this cell raises NameError on a fresh kernel.
z = 60
plt.figure(figsize=(20,10))

plt.subplot(2, 4, 1)
plt.imshow(sitk.GetArrayViewFromImage(img1_1[:,:,z]), cmap="gray", vmin=0, vmax=4096)

plt.subplot(2, 4, 2)
plt.imshow(sitk.GetArrayViewFromImage(img31_1[:,:,z]), cmap="gray", vmin=0, vmax=4096)

plt.subplot(2, 4, 3)
plt.imshow(sitk.GetArrayViewFromImage(img32_1[:,:,z]), cmap="gray", vmin=0, vmax=4096)

plt.subplot(2, 4, 4)
plt.imshow(sitk.GetArrayViewFromImage(img33_1[:,:,z]), cmap="gray", vmin=0, vmax=4096)

plt.subplot(2, 4, 5)
plt.imshow(sitk.GetArrayViewFromImage(img34_1[:,:,z]), cmap="gray", vmin=0, vmax=4096)

In [None]:
# Match img3's intensity histogram to img1 before registering it.
img3_2 = sitk.HistogramMatching(img3, img1)

In [None]:
# img31_2 = registration(img1, img3_2)
img32_2 = registration2(img1, img3_2)
# img33_2 = registration3(img1, img3_2)

In [None]:
# NOTE(review): `img35_1` is not assigned anywhere in this notebook (the
# commented-out call below is truncated), so this cell raises NameError.
z = 80
plt.figure(figsize=(20,10))
# img35_1 = registration2(img1_1
plt.subplot(2, 4, 1)
plt.imshow(sitk.GetArrayViewFromImage(img1_1[:,:,z]), cmap="gray", vmin=0, vmax=4096)

plt.subplot(2, 4, 2)
plt.imshow(sitk.GetArrayViewFromImage(img35_1[:,:,z]), cmap="gray", vmin=0, vmax=4096)


In [None]:
# Fixed image vs. the histogram-matched, registered img3 (titled windows).
z = 60
showImg3_(img1, z, "1")
showImg3_(img32_2, z, "2")
# showImg3_(img22_2, z, "3")