Skip to content

Commit

Permalink
neural: ImageAlign/UpConvMerge -> AutoMerge
Browse files Browse the repository at this point in the history
Change the name of ImageAlign/UpConvMerge to the more
fitting "AutoMerge". The UpConvMerge alias still works
for (partial) backwards compatibility.
Parameter names that refer to the possibly inserted UpConv node
now have a "u_" prefix to make them better discernible.

The auto-UpConv feature is now optional: u_hi_res_n_f
only is required if UpConv is actually used.
disable_upconv=True prevents all automatic UpConvs.

Updated the docstring accordingly.
  • Loading branch information
mdraw committed Aug 15, 2017
1 parent 71914ce commit 45c507e
Showing 1 changed file with 54 additions and 38 deletions.
92 changes: 54 additions & 38 deletions elektronn2/neuromancer/neural.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

__all__ = ['Perceptron', 'Conv', 'UpConv', 'Crop', 'LSTM',
'FragmentsToDense', 'Pool', 'Dot', 'FaithlessMerge',
'GRU', 'LRN', 'ImageAlign', 'UpConvMerge']
'GRU', 'LRN', 'AutoMerge', 'UpConvMerge']

################################################################################

Expand Down Expand Up @@ -1176,20 +1176,24 @@ def _calc_comp_cost(self):
self.computational_cost = 0


def ImageAlign(hi_res, lo_res, hig_res_n_f,
activation_func='relu', identity_init=True,
batch_normalisation=False, dropout_rate=0,
name="upconv", print_repr=True, w=None, b=None, gamma=None,
mean=None, std=None, gradnet_mode=None, merge_mode='concat'):
def AutoMerge(hi_res, lo_res, u_hi_res_n_f=None, merge_mode='concat',
disable_upconv=False, name='merge', print_repr=True,
# "u_" Parameters for a possible UpConv constructor:
u_activation_func='relu', u_identity_init=True,
u_batch_normalisation=False, u_dropout_rate=0, u_name='upconv',
u_w=None, u_b=None, u_gamma=None, u_mean=None, u_std=None,
u_gradnet_mode=None):
"""
Try to automatically align and concatenate a high-res and a low-res
convolution output of two branches of a CNN by applying UpConv and Crop to
Merge two network branches by automatic cropping and upconvolutions.
Try to automatically align and merge a high-res and a low-res
(convolution) output of two branches of a CNN by applying UpConv and Crop to
make their shapes and strides compatible.
UpConv is used if the low-res Node's strides are at least twice as large
as the strides of the high-res Node in any dimension.
This function can be used to simplify creation of e.g. architectures similar to
U-Net (see https://arxiv.org/abs/1505.04597).
U-Net (see https://arxiv.org/abs/1505.04597) or skip-connections.
If a ValueError that the shapes cannot be aligned is thrown,
you can try changing the filter shapes and pooling factors of the
Expand All @@ -1205,56 +1209,68 @@ def ImageAlign(hi_res, lo_res, hig_res_n_f,
Parent Node with high resolution output.
lo_res: Node
Parent Node with low resolution output.
hig_res_n_f: int
merge_mode: str
How the merging should be performed. Available options:
'concat' (default): Merge with a ``Concat`` Node.
'add': Merge with an ``Add`` Node.
name: str
Name of the final merge node.
print_repr: bool
Whether to print the node representation upon initialisation.
disable_upconv: bool
If True, no automatic upconvolutions are performed to match strides.
u_hi_res_n_f: int
Number of filters for the aligning UpConv.
activation_func: str
u_activation_func: str
(passed to new UpConv if required).
identity_init: bool
u_identity_init: bool
(passed to new UpConv if required).
batch_normalisation: bool
u_batch_normalisation: bool
(passed to new UpConv if required).
dropout_rate: float
u_dropout_rate: float
(passed to new UpConv if required).
name: str
u_name: str
Name of the intermediate UpConv node if required.
print_repr: bool
Whether to print the node representation upon initialisation.
w
u_w
(passed to new UpConv if required).
b
u_b
(passed to new UpConv if required).
gamma
u_gamma
(passed to new UpConv if required).
mean
u_mean
(passed to new UpConv if required).
std
u_std
(passed to new UpConv if required).
gradnet_mode
u_gradnet_mode
(passed to new UpConv if required).
merge_mode: str
How the merging should be performed. Available options:
'concat' (default): Merge with a ``Concat`` Node.
'add': Merge with an ``Add`` Node.
Returns
-------
Concat
Concat Node that merges the aligned high-res and low-res outputs.
Concat or Add
``Concat`` or ``Add`` node (depending on ``merge_mode``)
that merges the aligned high-res and low-res outputs.
"""
###TODO exchange UpConv and Crop to save computation in some cases
# TODO: Automatically determine which one is hi or lo res.
# TODO: Make concept of resolutions optional (This op can also be just used for auto-cropping)
# TODO: Bundle "u_" parameters for UpConv to a single dict to clean up the signature?

sh_hi = hi_res.shape
sh_lo = lo_res.shape
assert len(sh_hi)==len(sh_lo)
assert sh_hi.spatial_axes == sh_lo.spatial_axes

unpool = sh_lo.strides // sh_hi.strides
if np.any(unpool>1):
lo_res = UpConv(lo_res, hig_res_n_f, unpool,
activation_func=activation_func, identity_init=identity_init,
batch_normalisation=batch_normalisation, dropout_rate=dropout_rate,
name=name, print_repr=print_repr, w=w, b=b, gamma=gamma,
mean=mean, std=std, gradnet_mode=gradnet_mode)
if np.any(unpool > 1) and not disable_upconv:
if u_hi_res_n_f is None:
raise ValueError('AutoMerge is trying to insert an UpConv node, but'
'u_hi_res_n_f is not defined. Please set it to the'
'desired number of features to be used for UpConv.')
lo_res = UpConv(lo_res, u_hi_res_n_f, unpool,
activation_func=u_activation_func, identity_init=u_identity_init,
batch_normalisation=u_batch_normalisation, dropout_rate=u_dropout_rate,
name=u_name, print_repr=print_repr, w=u_w, b=u_b, gamma=u_gamma,
mean=u_mean, std=u_std, gradnet_mode=u_gradnet_mode)

# No both have same stride
# Shapes may have changed
Expand All @@ -1281,15 +1297,15 @@ def ImageAlign(hi_res, lo_res, hig_res_n_f,
hi_res = Crop(hi_res, crop_hi, print_repr=True)

if merge_mode == 'concat':
out = Concat((lo_res, hi_res), axis='f', name='concat_merge', print_repr=True)
out = Concat((lo_res, hi_res), axis='f', name=name, print_repr=True)
elif merge_mode == 'add':
out = Add(lo_res, hi_res, name='add_merge', print_repr=True)
out = Add(lo_res, hi_res, name=name, print_repr=True)
else:
raise ValueError('Invalid "merge_mode". Should be "add" or "concat".')

return out

UpConvMerge = ImageAlign
UpConvMerge = AutoMerge

class Pool(Node):
"""
Expand Down

0 comments on commit 45c507e

Please sign in to comment.