In [None]:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from skimage import io
from tqdm import tqdm
import os, sys
import SimpleITK as sitk
from PIL import Image
import cv2
from IPython.display import clear_output

from pathlib import Path

# The project's `library` package (imported below) lives two directory levels above
# the resolved '../src' path; put that directory on the import path and echo it.
PIPELINE_ROOT = Path('..', 'src').resolve().parent.parent
sys.path.append(PIPELINE_ROOT.as_posix())
print(PIPELINE_ROOT)

from library.utilities.utilities_process import get_image_size, read_image

%load_ext autoreload
%autoreload 2

In [None]:
def compare_two_images(img1, img2):
    """Overlay two images on a new 20x20-inch figure for visual comparison.

    img1 is drawn first with the 'Blues' colormap; img2 is drawn on top of it
    with the 'Reds' colormap at 50% opacity. Drawing goes through the pyplot
    state machine and nothing is returned.
    """
    layers = (
        (img1, dict(cmap='Blues')),
        (img2, dict(alpha=0.5, cmap='Reds')),
    )
    plt.figure(figsize=[20, 20])
    for image, style in layers:
        plt.imshow(image, **style)

In [None]:
# Specimen ID; every path below is derived from it.
animal = 'CTB004'
DIR = f'/net/birdstore/Active_Atlas_Data/data_root/pipeline_data/{animal}/preps'
INPUT = os.path.join(os.path.join(DIR, 'CH1'), 'thumbnail_cleaned')  # cleaned CH1 thumbnails

In [None]:
# Zero-padded section numbers; the registration cell reads '<index>.tif' from INPUT.
fixed_index = f'{105:03d}'
moving_index = f'{106:03d}'  # big image

In [None]:
%%time
pixelType = sitk.sitkFloat32
fixed_file = os.path.join(INPUT, f'{fixed_index}.tif')
moving_file = os.path.join(INPUT, f'{moving_index}.tif')
fixed = sitk.ReadImage(fixed_file, pixelType)
moving = sitk.ReadImage(moving_file, pixelType)

# Three-stage elastix registration of `moving` onto `fixed`:
# translation, then rigid, then affine.
elastixImageFilter = sitk.ElastixImageFilter()
elastixImageFilter.SetFixedImage(fixed)
elastixImageFilter.SetMovingImage(moving)
elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap("translation"))
elastixImageFilter.AddParameterMap(sitk.GetDefaultParameterMap("rigid"))
elastixImageFilter.AddParameterMap(sitk.GetDefaultParameterMap("affine"))
elastixImageFilter.SetParameter("NumberOfResolutions", "8")
elastixImageFilter.SetParameter("MaximumNumberOfIterations", "1500")
elastixImageFilter.LogToConsoleOff()
elastixImageFilter.Execute()

# Fetch the per-stage transform maps once; their values are strings (hence float()).
transform_maps = elastixImageFilter.GetTransformParameterMap()
trans_params = transform_maps[0]["TransformParameters"]
rigid_params = transform_maps[1]["TransformParameters"]
affine_params = transform_maps[2]["TransformParameters"]

x1, y1 = (float(v) for v in trans_params)            # translation stage: (tx, ty)
rotation, x2, y2 = (float(v) for v in rigid_params)  # rigid stage: (angle, tx, ty)

# Net shift of "translation, then rigid". Chaining p -> p + t1 with
# p -> R (p - c) + c + t2 gives R (p - c) + c + (R @ t1 + t2): the stage-1
# translation is rotated as well, so x1 + x2 / y1 + y2 is only exact when
# rotation == 0.
# NOTE(review): assumes elastix's default composition (each stage applied after
# the previous one, HowToCombineTransforms = Compose) -- confirm.
xshift = np.cos(rotation) * x1 - np.sin(rotation) * y1 + x2
yshift = np.sin(rotation) * x1 + np.cos(rotation) * y1 + y2

In [None]:
# 3x3 homogeneous matrix of the rigid result: p -> R (p - c) + c + (xshift, yshift),
# rotating by `rotation` radians about the image centre c.
width, height = get_image_size(moving_file)
# Keep the centre in float64: the former float16 cast rounds the centres of large
# images (float16 cannot represent half-pixel values above 1024).
center = np.array([width, height], dtype=float) / 2
# NOTE(review): elastix rotates about its own centre of rotation; using the moving
# image centre here assumes the two coincide -- confirm.
R = np.array(
    [
        [np.cos(rotation), -np.sin(rotation)],
        [np.sin(rotation), np.cos(rotation)],
    ]
)
shift = center + (xshift, yshift) - R @ center
Trigid = np.vstack([np.column_stack([R, shift]), [0, 0, 1]])
print(Trigid)

In [None]:
print(affine_params)
# elastix/ITK store the affine stage as the 2x2 matrix in row-major order followed
# by the translation: (a00, a01, a10, a11, tx, ty) -- not as a 2x3 [A | t] block.
# The shift gets its own names so the rigid-stage xshift/yshift are not clobbered.
r00, r01, r10, r11, affine_xshift, affine_yshift = [float(a) for a in affine_params]
width, height = get_image_size(moving_file)
# float64 centre; a float16 cast would round the centres of large images.
center = np.array([width, height], dtype=float) / 2

A = np.array(
    [
        [r00, r01],
        [r10, r11],
    ]
)
# Homogeneous form of p -> A (p - c) + c + t for the affine stage only; the
# translation and rigid stages found before it are not folded in here.
# NOTE(review): elastix applies A about its own centre of rotation; using the
# moving image centre assumes the two coincide -- confirm.
shift = center + (affine_xshift, affine_yshift) - A @ center
Taffine = np.vstack([np.column_stack([A, shift]), [0, 0, 1]])
print(Taffine)

In [None]:
# Uncentred 3x3 form [A | t] of the affine stage, kept for comparison only.
# elastix/ITK lay the parameters out as the 2x2 matrix (row-major) followed by
# (tx, ty), so the flat vector must not be reshaped straight to 2x3.
transform_parameters = np.array(affine_params, dtype=float)
affine_rotation = np.eye(3)
affine_rotation[:2, :2] = transform_parameters[:4].reshape(2, 2)
affine_rotation[:2, 2] = transform_parameters[4:]
# This matrix ignores the centre of rotation, so it is the full affine map only if
# that centre is the origin; Taffine is therefore deliberately not overwritten here.

In [None]:
# Display the 3x3 homogeneous affine matrix Taffine as this cell's output.
Taffine

In [None]:
# Display the 3x3 homogeneous rigid matrix Trigid as this cell's output.
Trigid

In [None]:
def _warp_with_pil(path, T):
    """Return the image at `path` resampled by the 3x3 homogeneous matrix `T` via PIL.

    The first six entries of T (its top two rows) are passed as PIL's AFFINE
    coefficients. Nearest-neighbour sampling is used and the output keeps the
    input size. The result is returned as a numpy array.
    NOTE(review): PIL interprets the coefficients as the output->input pixel
    mapping -- confirm this matches the direction of the elastix transforms.
    """
    # The context manager closes the file handle that Image.open keeps open.
    with Image.open(path) as im:
        warped = im.transform(im.size, Image.Transform.AFFINE, T.flatten()[:6],
                              resample=Image.Resampling.NEAREST)
    return np.array(warped)


rigidimg = _warp_with_pil(moving_file, Trigid)
affineimg = _warp_with_pil(moving_file, Taffine)

In [None]:
from scipy import ndimage as ndi

# Cross-check of the rigid warp with scipy. ndi.affine_transform works in array-index
# order (row, col) = (y, x), whereas Trigid is written for (x, y) points, so conjugate
# it with the axis swap first.
img = io.imread(moving_file)
swap_xy = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]], dtype=float)
Trigid_rowcol = swap_xy @ Trigid @ swap_xy
# Stored under its own name: writing it to `affineimg` clobbered the affine warp that
# the comparison figure plots.
rigidimg_scipy = ndi.affine_transform(img, Trigid_rowcol)

In [None]:
# Side-by-side view: fixed image, the two warped moving images, and the raw moving image.
panels = [
    ('fixed', sitk.GetArrayFromImage(fixed)),
    ('rigidimg', rigidimg),
    ('affineimg', affineimg),
    ('moving', sitk.GetArrayFromImage(moving)),
]
fig, ax = plt.subplots(1, len(panels), figsize=(15, 15))
for axis, (title, image) in zip(ax, panels):
    axis.grid()
    axis.imshow(image, cmap="gray")
    axis.set_title(title)

In [None]:
# Reference (fixed) image on its own, gridded, in an 8x6-inch figure.
fig = plt.figure(figsize=(8, 6))
fig.gca().set_title('fixed image')
fig.gca().grid()
plt.imshow(sitk.GetArrayFromImage(fixed), cmap='gray')

In [None]:
# `resultImage` was never assigned in any earlier cell; take elastix's registered
# output (presumably the moving image after the full translation/rigid/affine chain).
resultImage = elastixImageFilter.GetResultImage()
fig = plt.figure(figsize=(8, 6))
plt.title('result image')
plt.grid()
plt.imshow(sitk.GetArrayFromImage(resultImage), cmap='gray')

In [None]:
# Minimal SimpleITK example: a 2-D translation by 2 along each axis.
dimension = 2
offset = [2 for _ in range(dimension)]
translation = sitk.TranslationTransform(dimension, offset)
print(translation)

In [None]:
# Rotation angle (radians) and the (xshift, yshift) pair used by the Euler2D demo
# below. The former `R,x,y` referenced undefined names, and R is a 2x2 matrix, not an angle.
rotation, xshift, yshift

In [None]:
# Round-trip a point through a 2-D Euler (rotation + translation) transform and its
# inverse; the final print should reproduce the original point.
point = [10, 11]
rotation2D = sitk.Euler2DTransform()
rotation2D.SetTranslation((xshift, yshift))
rotation2D.SetAngle(rotation)  # scalar angle in radians, not the 2x2 matrix R
print(f'original point: {point}')
transformed_point = rotation2D.TransformPoint(point)
rotation_inverse = rotation2D.GetInverse()
print(f'transformed point: {transformed_point}')
print(f'back to original: {rotation_inverse.TransformPoint(transformed_point)}')

In [None]:
# `final_transform` is not defined in any earlier cell, so use elastix's own output:
# the moving image already resampled onto the fixed image's grid by the registration.
moving_resampled = elastixImageFilter.GetResultImage()

# One large figure per image: fixed, registered result, and the original moving image.
for title, image in (
    ('fixed image', fixed),
    ('resampled image', moving_resampled),
    ('moving image', moving),
):
    fig = plt.figure(figsize=(15, 8))
    plt.title(title)
    plt.grid()
    plt.imshow(sitk.GetArrayViewFromImage(image), cmap='gray')
    