Skip to content
This repository has been archived by the owner on Mar 17, 2021. It is now read-only.

Commit

Permalink
refactored resampling
Browse files Browse the repository at this point in the history
  • Loading branch information
wyli committed Oct 24, 2017
1 parent 97e12d9 commit 7c4f035
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 161 deletions.
79 changes: 43 additions & 36 deletions niftynet/layer/resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class ResamplerLayer(Layer):
"""
resampling images with given coordinates
resample inputs according to sample_coords
"""

def __init__(self,
Expand Down Expand Up @@ -41,30 +41,27 @@ def layer_op(self, inputs, sample_coords):

def _resample_nearest(self, inputs, sample_coords):
in_size = inputs.get_shape().as_list()
in_spatial_size = in_size[1:-1]
batch_size = in_size[0]
in_spatial_size = in_size[1:-1]

out_size = sample_coords.get_shape().as_list()
out_spatial_size = out_size[1:-1]
out_spatial_rank = infer_spatial_rank(sample_coords)

input_size = tf.reshape(
in_spatial_size, [1] * (len(out_size) - 1) + [-1])
spatial_coords = self.boundary_func(
tf.round(sample_coords), input_size)

tf.round(sample_coords), in_spatial_size)
batch_ids = tf.reshape(
tf.range(batch_size), [batch_size] + [1] * (out_spatial_rank + 1))
batch_ids = tf.tile(batch_ids, [1] + out_spatial_size + [1])
output = tf.gather_nd(inputs,
tf.concat([batch_ids, spatial_coords], -1))
output = tf.gather_nd(
inputs, tf.concat([batch_ids, spatial_coords], -1))

if self.boundary == 'ZERO':
# TODO check border
scale = 1. / tf.to_float(input_size - 1)
scale = 1. / (tf.constant(in_spatial_size, dtype=tf.float32) - 1)
mask = tf.logical_and(
tf.reduce_any(sample_coords > 0,
tf.reduce_all(sample_coords > 0,
axis=-1, keep_dims=True),
tf.reduce_any(scale * sample_coords < 1,
tf.reduce_all(scale * sample_coords < 1,
axis=-1, keep_dims=True))
return output * tf.to_float(mask)
return output
Expand Down Expand Up @@ -104,7 +101,7 @@ def _resample_linear(self, inputs, sample_coords):
batch_ids = tf.tile(batch_ids, [1] + out_spatial_size)
sc = (floor_coords, ceil_coords)

def get_a_corner(bc):
def get_knot(bc):
coord = [sc[c][i] for i, c in enumerate(bc)]
coord = tf.stack([batch_ids] + coord, -1)
return tf.gather_nd(inputs, coord)
Expand All @@ -119,38 +116,32 @@ def _pyramid_combination(samples, w_0, w_1):
binary_neighbour_ids = [
[int(c) for c in format(i, '0%ib' % in_spatial_rank)]
for i in range(2 ** in_spatial_rank)]
samples = [get_a_corner(bc) for bc in binary_neighbour_ids]
samples = [get_knot(bc) for bc in binary_neighbour_ids]
return _pyramid_combination(samples, weight_0, weight_1)

def _resample_bspline(self, inputs, sample_coords):

in_size = inputs.get_shape().as_list()
batch_size = in_size[0]
in_spatial_size = in_size[1:-1]
in_spatial_rank = infer_spatial_rank(inputs)
batch_size = in_size[0]

out_spatial_rank = infer_spatial_rank(sample_coords)
input_size = tf.reshape(
in_spatial_size, [1] * (out_spatial_rank + 1) + [-1])
if in_spatial_rank == 2:
raise NotImplementedError(
'bspline interpolation not implemented for 2d yet')
index_voxel_coords = tf.floor(sample_coords)
floor_coords = tf.floor(sample_coords)

# Compute voxels to use for interpolation
grid = tf.meshgrid([-1, 0, 1, 2],
[-1, 0, 1, 2],
[-1, 0, 1, 2],
indexing='ij')
offset_shape = [1, 4 ** in_spatial_rank] + \
[1] * out_spatial_rank + [in_spatial_rank]
offset_shape = [1, -1] + [1] * out_spatial_rank + [in_spatial_rank]
offsets = tf.reshape(tf.stack(grid, 3), offset_shape)
preboundary_spatial_coords = offsets + \
tf.expand_dims(
tf.cast(index_voxel_coords, tf.int32),
1)
spatial_coords = self.boundary_func(
preboundary_spatial_coords, input_size)
sz = spatial_coords.get_shape().as_list()
spatial_coords = \
offsets + tf.expand_dims(tf.cast(floor_coords, tf.int32), 1)
spatial_coords = self.boundary_func(spatial_coords, in_spatial_size)
knot_size = spatial_coords.get_shape().as_list()

# Compute weights for each voxel
def build_coef(u, d):
Expand All @@ -160,39 +151,55 @@ def build_coef(u, d):
tf.pow(u, 3)]
return tf.concat(coeff_list, d) / 6

weight = tf.reshape(sample_coords - index_voxel_coords,
[batch_size, -1, 3])
weight = tf.reshape(sample_coords - floor_coords, [batch_size, -1, 3])
coef_shape = [batch_size, 1, 1, 1, -1]
Bu = build_coef(tf.reshape(weight[:, :, 0], coef_shape), 1)
Bv = build_coef(tf.reshape(weight[:, :, 1], coef_shape), 2)
Bw = build_coef(tf.reshape(weight[:, :, 2], coef_shape), 3)
all_weights = tf.reshape(Bu * Bv * Bw, [batch_size] + sz[1:-1] + [1])
all_weights = tf.reshape(Bu * Bv * Bw,
[batch_size] + knot_size[1:-1] + [1])
# Gather voxel values and compute weighted sum
batch_coords = tf.reshape(
tf.range(batch_size), [batch_size] + [1] * (len(sz) - 1))
batch_coords = tf.tile(batch_coords, [1] + sz[1:-1] + [1])
tf.range(batch_size), [batch_size] + [1] * (len(knot_size) - 1))
batch_coords = tf.tile(batch_coords, [1] + knot_size[1:-1] + [1])
raw_samples = tf.gather_nd(
inputs, tf.concat([batch_coords, spatial_coords], -1))
return tf.reduce_sum(all_weights * raw_samples, reduction_indices=1)


def _boundary_replicate(sample_coords, input_size):
sample_coords = tf.cast(sample_coords, COORDINATES_TYPE)
sample_coords, input_size = _param_type_and_shape(sample_coords, input_size)
return tf.maximum(tf.minimum(sample_coords, input_size - 1), 0)


def _boundary_circular(sample_coords, input_size):
sample_coords = tf.cast(sample_coords, COORDINATES_TYPE)
sample_coords, input_size = _param_type_and_shape(sample_coords, input_size)
return tf.mod(tf.mod(sample_coords, input_size) + input_size, input_size)


def _boundary_symmetric(sample_coords, input_size):
sample_coords = tf.cast(sample_coords, COORDINATES_TYPE)
sample_coords, input_size = _param_type_and_shape(sample_coords, input_size)
circular_size = input_size + input_size - 2
return (input_size - 1) - tf.abs(
(input_size - 1) - _boundary_circular(sample_coords, circular_size))


def _param_type_and_shape(sample_coords, input_size):
sample_coords = tf.cast(sample_coords, COORDINATES_TYPE)
try:
input_size = tf.constant(input_size, dtype=COORDINATES_TYPE)
except TypeError:
pass
# try: # broadcasting input_size to match the shape of coordinates
# if len(input_size) > 1:
# broadcasting_shape = [1] * (infer_spatial_rank(sample_coords) + 1)
# input_size = tf.reshape(input_size, broadcasting_shape + [-1])
# except (TypeError, AssertionError):
# # do nothing
# pass
return sample_coords, input_size


SUPPORTED_INTERPOLATION = {'BSPLINE', 'LINEAR', 'NEAREST'}

SUPPORTED_BOUNDARY = {
Expand Down

0 comments on commit 7c4f035

Please sign in to comment.