In [7]:
import itk
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

import icon_registration.pretrained_models
import icon_registration.network_wrappers

import icon_registration.visualize

def register_pair(image_A, image_B):
    """Assemble a composite ITK transform for the pair using an all-zero displacement field.

    No registration network is evaluated here: the displacement field is a
    buffer of zeros, sandwiched between the resampling transform built for
    ``image_A`` and the inverse of the one built for ``image_B``.

    Parameters
    ----------
    image_A, image_B : itk.Image
        Images handed to ``resampling_transform``.
        NOTE(review): presumably 3-D, since the grid below is hard-coded as 3-D.

    Returns
    -------
    tuple
        ``(phi_AB_itk, None)`` where ``phi_AB_itk`` is an
        ``itk.CompositeTransform[itk.D, 3]``; the second slot is always None.
    """
    # Working grid as passed to resampling_transform; the numpy buffer uses the
    # reversed order plus a trailing axis holding the 3 vector components.
    grid_size = [192, 192, 80]

    displacement = itk.DisplacementFieldTransform[(itk.D, 3)].New()

    zeros = np.zeros((*reversed(grid_size), 3))
    print(zeros.shape)

    displacement.SetDisplacementField(array_to_vector_image(zeros))
    displacement.SetDebug(True)

    to_aligned = resampling_transform(image_A, grid_size)
    from_aligned = resampling_transform(image_B, grid_size).GetInverseTransform()

    phi_AB_itk = itk.CompositeTransform[itk.D, 3].New()
    # Prepend order matters for the composite: from_aligned is prepended first
    # and to_aligned last.
    for stage in (from_aligned, displacement, to_aligned):
        phi_AB_itk.PrependTransform(stage)

    return phi_AB_itk, None

def array_to_vector_image(array):
    """Wrap a numpy displacement array as an ITK image of 3-vectors without copying.

    Parameters
    ----------
    array : np.ndarray
        float64 array with four axes whose LAST axis holds the 3 vector
        components, e.g. ``(80, 192, 192, 3)``. The array must match
        ``itk.Vector[itk.D, 3]`` pixels, hence the dtype/shape checks below.

    Returns
    -------
    itk.Image[itk.Vector[itk.D, 3], 3]
        An image *view* onto ``array``. No spacing, origin or direction is
        set here, so the image carries ITK's defaults (the original author
        notes this yields [1, 1, 1] spacing).

    Raises
    ------
    AssertionError
        If ``array`` is not a numpy array.
    ValueError
        If ``array`` is not 4-D with a trailing axis of length 3, or is not
        float64.
    """
    assert isinstance(array, np.ndarray)

    # Validate before touching ITK so a bad buffer gives a clear message.
    if array.ndim != 4 or array.shape[-1] != 3:
        raise ValueError(
            "expected an array of shape (D, H, W, 3), got %s" % (array.shape,)
        )
    if array.dtype != np.float64:
        raise ValueError(
            "expected a float64 array to match itk.Vector[itk.D, 3], got %s" % (array.dtype,)
        )

    PixelType = itk.Vector[itk.D, 3]
    ImageType = itk.Image[PixelType, 3]

    # The second parameter of GetImageViewFromArray is the `is_vector` flag.
    # The previous code passed `array.shape[:-1]` there, which only worked
    # because a non-empty tuple is truthy; pass the flag explicitly instead.
    return itk.PyBuffer[ImageType].GetImageViewFromArray(array, is_vector=True)

In [8]:
def resampling_transform(image, shape):
    """Build a matrix-offset transform relating a ``shape``-sized grid to ``image``.

    A zero-filled placeholder image of size ``shape`` is used as the fixed
    image of an ``itk.CenteredTransformInitializer`` with ``image`` as the
    moving image. The resulting transform's matrix is then replaced by
    ``image.GetDirection() @ diag(spacing[i] * input_size[i] / shape[i])``.

    Parameters
    ----------
    image : itk.Image
        Image supplying the pixel type, spacing, size and direction.
    shape : sequence of int
        Target grid size, indexed in the same axis order as
        ``image.GetSpacing()``; it is reversed when building the numpy
        placeholder array.

    Returns
    -------
    An ``itk.MatrixOffsetTransformBase`` when ``len(shape)`` is 2 or 3.
    """
    ndim = len(shape)

    # Concrete itk.Image type of the input, rebuilt from its template info.
    template_info = itk.template(image)
    image_type = template_info[0][template_info[1]]

    # Placeholder with the requested size and the same pixel dtype as the input.
    pixel_dtype = itk.array_from_image(image).dtype
    placeholder = itk.image_from_array(np.zeros(tuple(reversed(shape)), dtype=pixel_dtype))

    transform_type = (
        itk.MatrixOffsetTransformBase[itk.D, 2, 2]
        if ndim == 2
        else itk.VersorRigid3DTransform[itk.D]
    )

    initializer = itk.CenteredTransformInitializer[transform_type, image_type, image_type].New()
    initializer.SetFixedImage(placeholder)
    initializer.SetMovingImage(image)
    transform = transform_type.New()
    initializer.SetTransform(transform)
    initializer.InitializeTransform()

    if ndim == 3:
        # Carry the initialised center and offset over to a general
        # matrix-offset transform so an arbitrary matrix can be set below.
        general = itk.MatrixOffsetTransformBase[itk.D, 3, 3].New()
        general.SetCenter(transform.GetCenter())
        general.SetOffset(transform.GetOffset())
        transform = general

    matrix = np.array(transform.GetMatrix())

    spacing = image.GetSpacing()
    input_size = image.GetLargestPossibleRegion().GetSize()
    for axis in range(ndim):
        # Input spacing scaled by the ratio of input size to target size.
        matrix[axis, axis] = spacing[axis] * (input_size[axis] / shape[axis])

    transform.SetMatrix(itk.matrix_from_array(image.GetDirection() @ matrix))
    return transform



In [None]:
import itk
import icon_registration.test_utils
import icon_registration.itk_wrapper

outdir = "/home/hastings/blog/_assets/ICON_test/"

icon_registration.test_utils.download_test_data()

# Both knee volumes live in the same test-data subdirectory.
knee_dir = icon_registration.test_utils.TEST_DATA_DIR / "knees_diverse_sizes"

# Alternative for image_B: "9126260_20060921_SAG_3D_DESS_LEFT_11309302_image.nii.gz"
image_B = itk.imread(str(knee_dir / "9487462_20081003_SAG_3D_DESS_RIGHT_11495603_image.nii.gz"))
image_A = itk.imread(str(knee_dir / "9225063_20090413_SAG_3D_DESS_RIGHT_12784112_image.nii.gz"))

# Report sizes first, then spacings, for A and B in turn.
for img in (image_A, image_B):
    print(img.GetLargestPossibleRegion().GetSize())
for img in (image_A, image_B):
    print(img.GetSpacing())

phi_AB, phi_BA = register_pair(image_A, image_B)

assert isinstance(phi_AB, itk.CompositeTransform)
interpolator = itk.LinearInterpolateImageFunction.New(image_A)

# Resample image_A through phi_AB onto image_B's grid (size, spacing,
# direction and origin all taken from image_B). An earlier variant of this
# call resampled onto a fixed size=[192, 192, 80] grid instead.
resample_kwargs = {
    "transform": phi_AB,
    "interpolator": interpolator,
    "size": itk.size(image_B),
    "output_spacing": itk.spacing(image_B),
    "output_direction": image_B.GetDirection(),
    "output_origin": image_B.GetOrigin(),
}

print("pre segfault")
warped_image_A = itk.resample_image_filter(image_A, **resample_kwargs)
print("post_segfault")

# Checkerboard of the warped image against image_B, showing slice index 40
# along the first array axis.
checkerboard = itk.checker_board_image_filter(warped_image_A, image_B)
plt.imshow(np.array(checkerboard)[40])
plt.colorbar()
plt.show()