Skip to content

Commit

Permalink
Merge pull request dipy#1475 from Borda/refactor_reg_iter
Browse files Browse the repository at this point in the history
Refactor demon registration - _iterate
  • Loading branch information
skoudoro committed Mar 27, 2018
2 parents c2fc04a + 293878f commit 9adea4a
Showing 1 changed file with 67 additions and 39 deletions.
106 changes: 67 additions & 39 deletions dipy/align/imwarp.py
Expand Up @@ -1208,13 +1208,7 @@ def _iterate(self):
fw_step = np.array(self.metric.compute_forward())

# set zero displacements at the boundary
fw_step[0, ...] = 0
fw_step[:, 0, ...] = 0
fw_step[-1, ...] = 0
fw_step[:, -1, ...] = 0
if(self.dim == 3):
fw_step[:, :, 0, ...] = 0
fw_step[:, :, -1, ...] = 0
fw_step = self.__set_no_boundary_displacement(fw_step)

# Normalize the forward step
nrm = np.sqrt(np.sum((fw_step/current_disp_spacing)**2, -1)).max()
Expand All @@ -1234,10 +1228,7 @@ def _iterate(self):
bw_step = np.array(self.metric.compute_backward())

# set zero displacements at the boundary
bw_step[0, ...] = 0
bw_step[:, 0, ...] = 0
if(self.dim == 3):
bw_step[:, :, 0, ...] = 0
bw_step = self.__set_no_boundary_displacement(bw_step)

# Normalize the backward step
nrm = np.sqrt(np.sum((bw_step/current_disp_spacing) ** 2, -1)).max()
Expand All @@ -1264,45 +1255,82 @@ def _iterate(self):

self.energy_list.append(fw_energy + bw_energy)

self.__invert_models(current_disp_world2grid, current_disp_spacing)

# Free resources no longer needed to compute the forward and backward
# steps
if self.callback is not None:
self.callback(self, RegistrationStages.ITER_END)
self.metric.free_iteration()

return der

def __set_no_boundary_displacement(self, step):
""" set zero displacements at the boundary
Parameters
----------
step : array, ndim 2 or 3
displacements field
Returns
-------
step : array, ndim 2 or 3
displacements field
"""
step[0, ...] = 0
step[:, 0, ...] = 0
step[-1, ...] = 0
step[:, -1, ...] = 0
if self.dim == 3:
step[:, :, 0, ...] = 0
step[:, :, -1, ...] = 0
return step

def __invert_models(self, current_disp_world2grid, current_disp_spacing):
"""Converting static - moving models in both direction.
Parameters
----------
current_disp_world2grid : array, shape (3, 3) or (4, 4)
the space-to-grid transformation associated to the displacement field
d (transforming physical space coordinates to voxel coordinates of the
displacement field grid)
current_disp_spacing :array, shape (2,) or (3,)
the spacing between voxels (voxel size along each axis)
"""

# Invert the forward model's forward field
self.static_to_ref.backward = np.array(
self.invert_vector_field(
self.static_to_ref.forward,
current_disp_world2grid,
current_disp_spacing,
self.inv_iter, self.inv_tol, self.static_to_ref.backward))
self.invert_vector_field(self.static_to_ref.forward,
current_disp_world2grid,
current_disp_spacing,
self.inv_iter, self.inv_tol,
self.static_to_ref.backward))

# Invert the backward model's forward field
self.moving_to_ref.backward = np.array(
self.invert_vector_field(
self.moving_to_ref.forward,
current_disp_world2grid,
current_disp_spacing,
self.inv_iter, self.inv_tol, self.moving_to_ref.backward))
self.invert_vector_field(self.moving_to_ref.forward,
current_disp_world2grid,
current_disp_spacing,
self.inv_iter, self.inv_tol,
self.moving_to_ref.backward))

# Invert the forward model's backward field
self.static_to_ref.forward = np.array(
self.invert_vector_field(
self.static_to_ref.backward,
current_disp_world2grid,
current_disp_spacing,
self.inv_iter, self.inv_tol, self.static_to_ref.forward))
self.invert_vector_field(self.static_to_ref.backward,
current_disp_world2grid,
current_disp_spacing,
self.inv_iter, self.inv_tol,
self.static_to_ref.forward))

# Invert the backward model's backward field
self.moving_to_ref.forward = np.array(
self.invert_vector_field(
self.moving_to_ref.backward,
current_disp_world2grid,
current_disp_spacing,
self.inv_iter, self.inv_tol, self.moving_to_ref.forward))

# Free resources no longer needed to compute the forward and backward
# steps
if self.callback is not None:
self.callback(self, RegistrationStages.ITER_END)
self.metric.free_iteration()

return der
self.invert_vector_field(self.moving_to_ref.backward,
current_disp_world2grid,
current_disp_spacing,
self.inv_iter, self.inv_tol,
self.moving_to_ref.forward))

def _approximate_derivative_direct(self, x, y):
r"""Derivative of the degree-2 polynomial fit of the given x, y pairs
Expand Down

0 comments on commit 9adea4a

Please sign in to comment.