In [2]:
import copy
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import scipy.optimize
import skimage.data
import skimage.transform

# TODO: figure out a better way of doing this
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
# import register
import utils



In [None]:
im = skimage.data.astronaut()

def apply_affine_tform(params, im):
    scale_x, scale_y, rotation, shear, trans_x, trans_y = params
    
    # re-initialize bad scales
    if scale_x < 0:
        scale_x = [np.abs(np.random.randn())]
        print(scale_x)
    if scale_x < 0:
        scale_x = [np.abs(np.random.randn())]
        print(scale_x)

    moving_tfm = skimage.transform.AffineTransform(
        scale=(scale_x, scale_y),
        rotation=rotation,
        shear=shear,
        translation=(trans_x, trans_y))

    return utils.apply_matrix_tform(im, moving_tfm)
    
def ssd_affine(params, im_fixed, im_moving,
               show_figs=False):
    """SSD is calculated between im_fixed and transformed moving image.

    TODO (nripesh): calculate gradient (numeric or mathematical)
    TODO : callbacks to view progress as images?
    TODO: how to calculate scale - is it simply based on the 
          ratio of the range of params?
    TODO: parallelize the derivative calculation
    """
    # get individual params
    im_moving_regis = apply_affine_tform(params, im_moving)
    
    row, col, channel = im_moving_regis.shape
    ssd = np.sum(np.square(im_fixed - im_moving_regis)) / (
        row * col * channel)
        
    scale_x, scale_y, rotation, shear, trans_x, trans_y = params
    print('scale_x: {:2.2f}, scale_y: {:2.2f}, rotation: {:2.2f}, '.format(
        scale_x, scale_y, rotation) +
          'shear: {:2.2f}, t_x: {:2.2f}, t_y: {:2.2f}, SSD: {:2.3f}'.format(
        shear, trans_x, trans_y, ssd))
    
    if show_figs:
        plt.figure(1)
        plt.title('target (transformed)')
        plt.imshow(im_fixed)
        plt.figure(2)
        plt.title('moving (orig)')
        plt.imshow(im_moving)
        plt.figure(3)
        plt.title('registered')
        plt.imshow(im_moving_regis)
        plt.show()

    return ssd

def gradient_affine(params, im_fixed, im_moving,
             show_figs=False, delta=.1, scale=[1]):
    
    if len(scale) == 1:
        scale = scale * len(params)
        
    row, col, channel = im_fixed.shape
        
    grad = []
    for i in range(len(params)):
        params_i = copy.deepcopy(params)
        params_i[i] += delta * scale[i]
        
        im_moving_regis_i = apply_affine_tform(params_i, im_moving)
        ssd_i = np.sum(np.square(im_fixed - im_moving_regis_i)) / (
            row * col * channel)
        grad.append(ssd_i / delta)
        
    print('grad: {}'.format(grad))
    return grad

def im_register_affine(im_fixed, im_moving, pad_size=None,
                       options={}, show_figs=False, 
                       delta=.1, scale=[1]):
    """For registering using affine transformation.
    """
    if pad_size is None:
        pad_size = (int(im_moving.shape[0] * 2),
                    int(im_moving.shape[1] * 2))

    orig_size = (im_moving.shape[0], im_moving.shape[1])

    im_fixed = utils.resize_image(im_fixed, pad_size)
    im_moving = utils.resize_image(im_moving, pad_size) / 255

    im_f_shape = im_fixed.shape
    im_m_shape = im_moving.shape

    assert im_f_shape == im_m_shape

    # random initialization and assignment
    init_scale = (1., 1.)
    init_rotation = .1
    init_shear = 0.0
    init_translation = (0, 0)
    params_init = (init_scale[0], init_scale[1], init_rotation,
                   init_shear, init_translation[0], init_translation[1])
    
    error_function = lambda x: ssd_affine(
        x, im_fixed, im_moving, show_figs=show_figs)
    
    error_gradient = lambda x: gradient_affine(
        x, im_fixed, im_moving, show_figs=show_figs,
        delta=delta, scale=scale)
    
    method = 'BFGS'
    
    if method == 'Nelder-Mead':
        affine_optimized = scipy.optimize.minimize(
            error_function, params_init, method='Nelder-Mead',
            options=options)
    elif method == 'BFGS':
        affine_optimized = scipy.optimize.minimize(
            error_function, params_init, method='BFGS',
            options=options, jac='3-point')
    
    # apply solved registration
    regis_tfm = skimage.transform.AffineTransform(
        scale=(affine_optimized.x[0], affine_optimized.x[1]),
        rotation=affine_optimized.x[2],
        shear=affine_optimized.x[3],
        translation=(affine_optimized.x[4], affine_optimized.x[5]))
    
    im_registered = utils.apply_matrix_tform(
            im_moving, regis_tfm)

    return utils.resize_image(im_registered, orig_size)

# initialization
true_scale = (1.3, .8)
true_rotation = 0
true_shear = 0
true_translation = (100, -100)

aff_tfm = skimage.transform.AffineTransform(
    scale=true_scale,
    rotation=true_rotation,
    shear=true_shear,
    translation=true_translation)

PADDED_IM_SIZE = (800, 800) 
im_fixed = utils.apply_matrix_tform(
    utils.resize_image(im, PADDED_IM_SIZE), aff_tfm)
im_moving = utils.resize_image(im, PADDED_IM_SIZE)

# optimization

# scaling based on maximum parameter values
scale = [3, 3, 2*3.14, 1, 100, 100]

delta = .1

im_registered = im_register_affine(
    im_fixed, im_moving, pad_size=PADDED_IM_SIZE,
    options={'maxiter': 200}, show_figs=False,
    delta=delta, scale=scale
)

plt.figure(1)
plt.title('target (transformed)')
plt.imshow(im_fixed)
plt.figure(2)
plt.title('moving (orig)')
plt.imshow(im_moving)
plt.figure(3)
plt.title('registered')
plt.imshow(im_registered)
plt.show()

  "options for 'jac'. Using '2-point' instead." % jac)


scale_x: 1.00, scale_y: 1.00, rotation: 0.10, shear: 0.00, t_x: 0.00, t_y: 0.00, SSD: 0.108
scale_x: 1.00, scale_y: 1.00, rotation: 0.10, shear: 0.00, t_x: 0.00, t_y: 0.00, SSD: 0.108
scale_x: 1.00, scale_y: 1.00, rotation: 0.10, shear: 0.00, t_x: 0.00, t_y: 0.00, SSD: 0.108
scale_x: 1.00, scale_y: 1.00, rotation: 0.10, shear: 0.00, t_x: 0.00, t_y: 0.00, SSD: 0.108
scale_x: 1.00, scale_y: 1.00, rotation: 0.10, shear: 0.00, t_x: 0.00, t_y: 0.00, SSD: 0.108
scale_x: 1.00, scale_y: 1.00, rotation: 0.10, shear: 0.00, t_x: 0.00, t_y: 0.00, SSD: 0.108
scale_x: 1.00, scale_y: 1.00, rotation: 0.10, shear: 0.00, t_x: 0.00, t_y: 0.00, SSD: 0.108
scale_x: 1.00, scale_y: 1.00, rotation: 0.10, shear: 0.00, t_x: 0.00, t_y: 0.00, SSD: 0.108
scale_x: 0.96, scale_y: 0.97, rotation: 0.12, shear: 0.04, t_x: 0.00, t_y: -0.00, SSD: 0.104
scale_x: 0.96, scale_y: 0.97, rotation: 0.12, shear: 0.04, t_x: 0.00, t_y: -0.00, SSD: 0.104
scale_x: 0.96, scale_y: 0.97, rotation: 0.12, shear: 0.04, t_x: 0.00, t_y: -0.

scale_x: 0.81, scale_y: 0.98, rotation: 0.17, shear: 0.20, t_x: 0.00, t_y: -0.01, SSD: 0.101
scale_x: 0.81, scale_y: 0.98, rotation: 0.17, shear: 0.20, t_x: 0.00, t_y: -0.01, SSD: 0.101
scale_x: 0.81, scale_y: 0.98, rotation: 0.17, shear: 0.20, t_x: 0.00, t_y: -0.01, SSD: 0.101
scale_x: 0.81, scale_y: 0.98, rotation: 0.17, shear: 0.20, t_x: 0.00, t_y: -0.01, SSD: 0.101
scale_x: 0.82, scale_y: 0.97, rotation: 0.17, shear: 0.20, t_x: 0.00, t_y: -0.01, SSD: 0.101
scale_x: 0.82, scale_y: 0.97, rotation: 0.17, shear: 0.20, t_x: 0.00, t_y: -0.01, SSD: 0.101
scale_x: 0.82, scale_y: 0.97, rotation: 0.17, shear: 0.20, t_x: 0.00, t_y: -0.01, SSD: 0.101
scale_x: 0.82, scale_y: 0.97, rotation: 0.17, shear: 0.20, t_x: 0.00, t_y: -0.01, SSD: 0.101
scale_x: 0.82, scale_y: 0.97, rotation: 0.17, shear: 0.20, t_x: 0.00, t_y: -0.01, SSD: 0.101
scale_x: 0.82, scale_y: 0.97, rotation: 0.17, shear: 0.20, t_x: 0.00, t_y: -0.01, SSD: 0.101
scale_x: 0.82, scale_y: 0.97, rotation: 0.17, shear: 0.20, t_x: 0.00, 

scale_x: 0.83, scale_y: 0.97, rotation: 0.21, shear: 0.15, t_x: 0.01, t_y: -0.03, SSD: 0.101
scale_x: 0.83, scale_y: 0.97, rotation: 0.21, shear: 0.15, t_x: 0.01, t_y: -0.03, SSD: 0.101
scale_x: 0.83, scale_y: 0.97, rotation: 0.21, shear: 0.15, t_x: 0.01, t_y: -0.03, SSD: 0.101
scale_x: 0.83, scale_y: 0.97, rotation: 0.21, shear: 0.15, t_x: 0.01, t_y: -0.03, SSD: 0.101
scale_x: 0.83, scale_y: 0.97, rotation: 0.21, shear: 0.15, t_x: 0.01, t_y: -0.03, SSD: 0.101
scale_x: 0.83, scale_y: 0.97, rotation: 0.21, shear: 0.15, t_x: 0.01, t_y: -0.03, SSD: 0.101
scale_x: 0.83, scale_y: 0.97, rotation: 0.21, shear: 0.15, t_x: 0.01, t_y: -0.03, SSD: 0.101
scale_x: 0.83, scale_y: 0.97, rotation: 0.21, shear: 0.15, t_x: 0.01, t_y: -0.03, SSD: 0.101
scale_x: 0.83, scale_y: 0.97, rotation: 0.20, shear: 0.16, t_x: 0.01, t_y: -0.02, SSD: 0.101
scale_x: 0.83, scale_y: 0.97, rotation: 0.20, shear: 0.16, t_x: 0.01, t_y: -0.02, SSD: 0.101
scale_x: 0.83, scale_y: 0.97, rotation: 0.20, shear: 0.16, t_x: 0.01, 

scale_x: 0.82, scale_y: 0.97, rotation: 0.22, shear: 0.15, t_x: 0.02, t_y: -0.05, SSD: 0.101
scale_x: 0.82, scale_y: 0.97, rotation: 0.22, shear: 0.15, t_x: 0.02, t_y: -0.05, SSD: 0.101
scale_x: 0.82, scale_y: 0.97, rotation: 0.22, shear: 0.15, t_x: 0.02, t_y: -0.05, SSD: 0.101
scale_x: 0.82, scale_y: 0.97, rotation: 0.22, shear: 0.15, t_x: 0.02, t_y: -0.05, SSD: 0.101
scale_x: 0.82, scale_y: 0.98, rotation: 0.22, shear: 0.14, t_x: 0.02, t_y: -0.05, SSD: 0.101
scale_x: 0.82, scale_y: 0.98, rotation: 0.22, shear: 0.14, t_x: 0.02, t_y: -0.05, SSD: 0.101
scale_x: 0.82, scale_y: 0.98, rotation: 0.22, shear: 0.14, t_x: 0.02, t_y: -0.05, SSD: 0.101
scale_x: 0.82, scale_y: 0.98, rotation: 0.22, shear: 0.14, t_x: 0.02, t_y: -0.05, SSD: 0.101
scale_x: 0.82, scale_y: 0.98, rotation: 0.22, shear: 0.14, t_x: 0.02, t_y: -0.05, SSD: 0.101
scale_x: 0.82, scale_y: 0.98, rotation: 0.22, shear: 0.14, t_x: 0.02, t_y: -0.05, SSD: 0.101
scale_x: 0.82, scale_y: 0.98, rotation: 0.22, shear: 0.14, t_x: 0.02, 

scale_x: 0.85, scale_y: 0.99, rotation: 0.32, shear: 0.05, t_x: 0.05, t_y: -0.18, SSD: 0.100
scale_x: 0.85, scale_y: 0.99, rotation: 0.32, shear: 0.05, t_x: 0.05, t_y: -0.18, SSD: 0.100
scale_x: 0.85, scale_y: 0.99, rotation: 0.32, shear: 0.05, t_x: 0.05, t_y: -0.18, SSD: 0.100
scale_x: 0.85, scale_y: 0.99, rotation: 0.32, shear: 0.05, t_x: 0.05, t_y: -0.18, SSD: 0.100
scale_x: 0.85, scale_y: 0.99, rotation: 0.32, shear: 0.05, t_x: 0.05, t_y: -0.18, SSD: 0.100
scale_x: 0.85, scale_y: 0.99, rotation: 0.32, shear: 0.05, t_x: 0.05, t_y: -0.18, SSD: 0.100
scale_x: 0.85, scale_y: 0.99, rotation: 0.32, shear: 0.05, t_x: 0.05, t_y: -0.18, SSD: 0.100
scale_x: 0.85, scale_y: 0.99, rotation: 0.32, shear: 0.05, t_x: 0.05, t_y: -0.18, SSD: 0.100
scale_x: 0.85, scale_y: 0.99, rotation: 0.33, shear: 0.03, t_x: 0.06, t_y: -0.20, SSD: 0.100
scale_x: 0.85, scale_y: 0.99, rotation: 0.33, shear: 0.03, t_x: 0.06, t_y: -0.20, SSD: 0.100
scale_x: 0.85, scale_y: 0.99, rotation: 0.33, shear: 0.03, t_x: 0.06, 

scale_x: 0.85, scale_y: 0.98, rotation: 0.34, shear: 0.01, t_x: 0.06, t_y: -0.23, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.34, shear: 0.01, t_x: 0.06, t_y: -0.23, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.34, shear: 0.01, t_x: 0.06, t_y: -0.23, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.34, shear: 0.01, t_x: 0.06, t_y: -0.23, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.34, shear: 0.00, t_x: 0.06, t_y: -0.24, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.34, shear: 0.00, t_x: 0.06, t_y: -0.24, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.34, shear: 0.00, t_x: 0.06, t_y: -0.24, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.34, shear: 0.00, t_x: 0.06, t_y: -0.24, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.34, shear: 0.00, t_x: 0.06, t_y: -0.24, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.34, shear: 0.00, t_x: 0.06, t_y: -0.24, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.34, shear: 0.00, t_x: 0.06, 

scale_x: 0.85, scale_y: 0.97, rotation: 0.35, shear: -0.03, t_x: 0.06, t_y: -0.29, SSD: 0.100
scale_x: 0.85, scale_y: 0.97, rotation: 0.35, shear: -0.03, t_x: 0.06, t_y: -0.29, SSD: 0.100
scale_x: 0.85, scale_y: 0.97, rotation: 0.35, shear: -0.03, t_x: 0.06, t_y: -0.29, SSD: 0.100
scale_x: 0.85, scale_y: 0.97, rotation: 0.35, shear: -0.03, t_x: 0.06, t_y: -0.29, SSD: 0.100
scale_x: 0.85, scale_y: 0.97, rotation: 0.35, shear: -0.03, t_x: 0.06, t_y: -0.29, SSD: 0.100
scale_x: 0.85, scale_y: 0.97, rotation: 0.35, shear: -0.03, t_x: 0.06, t_y: -0.29, SSD: 0.100
scale_x: 0.85, scale_y: 0.97, rotation: 0.35, shear: -0.03, t_x: 0.06, t_y: -0.29, SSD: 0.100
scale_x: 0.85, scale_y: 0.97, rotation: 0.35, shear: -0.03, t_x: 0.06, t_y: -0.29, SSD: 0.100
scale_x: 0.85, scale_y: 0.97, rotation: 0.35, shear: -0.03, t_x: 0.06, t_y: -0.29, SSD: 0.100
scale_x: 0.85, scale_y: 0.97, rotation: 0.35, shear: -0.03, t_x: 0.06, t_y: -0.29, SSD: 0.100
scale_x: 0.85, scale_y: 0.97, rotation: 0.35, shear: -0.03, 

scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.06, t_y: -0.34, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.06, t_y: -0.34, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.06, t_y: -0.34, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.06, t_y: -0.34, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.06, t_y: -0.34, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.06, t_y: -0.34, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.06, t_y: -0.34, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.06, t_y: -0.34, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.06, t_y: -0.34, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.06, t_y: -0.34, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, 

scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.09, t_y: -0.79, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.09, t_y: -0.79, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.09, t_y: -0.79, SSD: 0.100
scale_x: 0.85, scale_y: 0.98, rotation: 0.36, shear: -0.06, t_x: 0.09, t_y: -0.79, SSD: 0.100
scale_x: 0.86, scale_y: 0.99, rotation: 0.37, shear: -0.08, t_x: 0.12, t_y: -1.05, SSD: 0.100
scale_x: 0.86, scale_y: 0.99, rotation: 0.37, shear: -0.08, t_x: 0.12, t_y: -1.05, SSD: 0.100
scale_x: 0.86, scale_y: 0.99, rotation: 0.37, shear: -0.08, t_x: 0.12, t_y: -1.05, SSD: 0.100
scale_x: 0.86, scale_y: 0.99, rotation: 0.37, shear: -0.08, t_x: 0.12, t_y: -1.05, SSD: 0.100
scale_x: 0.86, scale_y: 0.99, rotation: 0.37, shear: -0.08, t_x: 0.12, t_y: -1.05, SSD: 0.100
scale_x: 0.86, scale_y: 0.99, rotation: 0.37, shear: -0.08, t_x: 0.12, t_y: -1.05, SSD: 0.100
scale_x: 0.86, scale_y: 0.99, rotation: 0.37, shear: -0.08, 

scale_x: 0.89, scale_y: 1.00, rotation: 0.37, shear: -0.11, t_x: 0.30, t_y: -3.37, SSD: 0.099
scale_x: 0.89, scale_y: 1.00, rotation: 0.37, shear: -0.11, t_x: 0.30, t_y: -3.37, SSD: 0.099
scale_x: 0.89, scale_y: 1.00, rotation: 0.37, shear: -0.11, t_x: 0.30, t_y: -3.37, SSD: 0.099
scale_x: 0.89, scale_y: 1.00, rotation: 0.37, shear: -0.11, t_x: 0.30, t_y: -3.37, SSD: 0.099
scale_x: 0.90, scale_y: 0.99, rotation: 0.37, shear: -0.13, t_x: 0.39, t_y: -4.42, SSD: 0.099
scale_x: 0.90, scale_y: 0.99, rotation: 0.37, shear: -0.13, t_x: 0.39, t_y: -4.42, SSD: 0.099
scale_x: 0.90, scale_y: 0.99, rotation: 0.37, shear: -0.13, t_x: 0.39, t_y: -4.42, SSD: 0.099
scale_x: 0.90, scale_y: 0.99, rotation: 0.37, shear: -0.13, t_x: 0.39, t_y: -4.42, SSD: 0.099
scale_x: 0.90, scale_y: 0.99, rotation: 0.37, shear: -0.13, t_x: 0.39, t_y: -4.42, SSD: 0.099
scale_x: 0.90, scale_y: 0.99, rotation: 0.37, shear: -0.13, t_x: 0.39, t_y: -4.42, SSD: 0.099
scale_x: 0.90, scale_y: 0.99, rotation: 0.37, shear: -0.13, 