
Unpooling layer in tensorflow #632

Closed
ziky90 opened this issue Apr 29, 2016 · 128 comments · Fixed by #2272
Labels
Feature Request help wanted Needs help as a contribution layers

Comments


ziky90 commented Apr 29, 2016

It would be nice to also have an unpooling layer in TensorFlow, as described in this paper on deconvolution networks: http://cvlab.postech.ac.kr/research/deconvnet/

I did a bit of googling and found that an unpooling layer would also be handy for others:
http://stackoverflow.com/questions/36548736/tensorflow-unpooling


zheng-xq commented May 3, 2016

For deconv, you can use "conv2d_backprop_input" with stride to achieve a similar effect; it is the gradient of the strided conv.
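
For reference, a minimal TF 1.x sketch of that approach (tf.nn.conv2d_transpose is the public wrapper around the conv2d_backprop_input gradient op; the shapes and variable names here are just illustrative):

import tensorflow as tf

# Upsample a [batch, 8, 8, 16] feature map to [batch, 16, 16, 8] with a
# stride-2 transposed convolution ("deconv").
x = tf.placeholder(tf.float32, [None, 8, 8, 16])
filters = tf.get_variable('deconv_filter', [3, 3, 8, 16])  # [h, w, out_ch, in_ch]
output_shape = tf.stack([tf.shape(x)[0], 16, 16, 8])
y = tf.nn.conv2d_transpose(x, filters, output_shape,
                           strides=[1, 2, 2, 1], padding='SAME')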


daeyun commented May 23, 2016

My implementation using tf.reshape and tf.concat:

import math

import numpy as np
import tensorflow as tf


def unpool(value, name='unpool'):
    """N-dimensional version of the unpooling operation from
    https://www.robots.ox.ac.uk/~vgg/rg/papers/Dosovitskiy_Learning_to_Generate_2015_CVPR_paper.pdf

    :param value: A Tensor of shape [b, d0, d1, ..., dn, ch]
    :return: A Tensor of shape [b, 2*d0, 2*d1, ..., 2*dn, ch]
    """
    with tf.name_scope(name) as scope:
        sh = value.get_shape().as_list()
        dim = len(sh[1:-1])
        out = (tf.reshape(value, [-1] + sh[-dim:]))
        for i in range(dim, 0, -1):
            out = tf.concat([out, tf.zeros_like(out)], i)
        out_size = [-1] + [s * 2 for s in sh[1:-1]] + [sh[-1]]
        out = tf.reshape(out, out_size, name=scope)
    return out


def pool(value, name='pool'):
    """Downsampling operation.
    :param value: A Tensor of shape [b, d0, d1, ..., dn, ch]
    :return: A Tensor of shape [b, d0/2, d1/2, ..., dn/2, ch]
    """
    with tf.name_scope(name) as scope:
        sh = value.get_shape().as_list()
        out = value
        for sh_i in sh[1:-1]:
            assert sh_i % 2 == 0
        for i in range(len(sh[1:-1])):
            out = tf.reshape(out, (-1, 2, np.prod(sh[i + 2:])))
            out = out[:, 0, :]
        out_size = [-1] + [math.ceil(s / 2) for s in sh[1:-1]] + [sh[-1]]
        out = tf.reshape(out, out_size, name=scope)
    return out
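
A quick shape check for the two helpers above (a sketch assuming a TF 1.x graph and a 4-D NHWC placeholder; the concrete sizes are only for illustration):

x = tf.placeholder(tf.float32, [None, 8, 8, 3])
up = unpool(x)   # -> [None, 16, 16, 3]: each value in the top-left of a 2x2 block of zeros
down = pool(up)  # -> [None, 8, 8, 3]: keeps every other element per spatial dim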


mwalton commented May 24, 2016

I've been interested in this as well; I'm currently working on 'what-where' / convolutional autoencoders (à la Zhao et al.).

Thanks @daeyun for the code, I've been trying to figure this out myself. Dosovitskiy uses a Kronecker product with a block mask (same shape as the pooling window, all zeros with a 1 in the upper left) to unpool. However, as observed in the paper (Fig. 9), this fails to reconstruct meaningful structure in deeper feature maps. An alternative proposed by Zeiler uses 'switches' (essentially the argmax of the max-pooling operation) to reconstruct using the exact location of the maxima.

I've been playing around with tf.nn.max_pool_with_argmax in an attempt to reproduce the 'switched' unpooling experiments first explored by Zeiler and extended by Zhao.

Any thoughts on how this could be implemented?


girving commented Jun 28, 2016

What's the mathematical definition of unpooling?


ziky90 commented Jun 29, 2016

The unpooling that I had in mind is described here: http://www.matthewzeiler.com/pubs/iccv2011/iccv2011.pdf
and the corresponding Caffe implementation can be found here: https://github.com/HyeonwooNoh/caffe/blob/master/src/caffe/layers/unpooling_layer.cpp
A more formal description is also available in the Torch documentation:
https://github.com/torch/nn/blob/master/doc/convolution.md#spatialmaxunpooling
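
Since @girving asked for the mathematical definition above: max unpooling takes the pooled values y together with the recorded switch locations s (the argmax positions from the forward max pool) and scatters each value back to its original position, leaving zeros everywhere else. As a formula (my paraphrase of the references above, not an official definition):

\mathrm{unpool}(y, s)_i = \sum_j y_j \, [s_j = i]

where [·] is the Iverson bracket (1 if the condition holds, 0 otherwise).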


girving commented Jun 29, 2016

@ziky90 That's the gradient of max pooling, which we already have as an op.


ziky90 commented Jun 29, 2016

@girving Thank you for pointing me at the gradient of max pooling. It's really difficult to find it under that name, though, and it isn't well documented.
Is there a plan to create a separate "layer", for example tf.nn.max_unpool? From my point of view it would be much more intuitive, and together with proper documentation it would be super easy to use.

Btw. it seems that this confuses people and leads them to build custom solutions instead of simply using something like tf.nn.max_unpool. @ppwwyyxx
https://github.com/ppwwyyxx/tensorpack/blob/master/tensorpack/models/pool.py#L66


girving commented Jun 29, 2016

Yes, giving it a name like tf.nn.max_unpool with good documentation would be worthwhile, and we'd be happy to accept PRs.

As a tip for the future, though: this is one advantage of trying to understand the mathematical relationship between different operations. Once you know that unpooling is just the gradient of pooling, it's clear that TensorFlow already implements it, even if the name is different from what one might expect.


dbbert commented Jun 30, 2016

Could you share a code example of how to implement unpooling using the gradient of max pooling?


girving commented Jun 30, 2016

It's currently hidden as gen_nn_ops._max_pool_grad, and is used only from the gradient of max_pool:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_grad.py#L353

There's also gen_nn_ops._max_pool_with_argmax_grad. Unfortunately, both of them take the original input, which means they'd have to be tweaked to serve as unpooling.
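
To make the awkwardness concrete, here is a hedged sketch of wrapping the argmax-based gradient kernel as an "unpool" (assuming the TF 1.x private wrapper gen_nn_ops._max_pool_grad_with_argmax(input, grad, argmax, ksize, strides, padding); note that it still wants the original pre-pooling tensor even though only its shape should matter):

from tensorflow.python.ops import gen_nn_ops

def max_unpool_via_grad(orig_input, values, argmax,
                        ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME'):
    # "values" plays the role of the incoming gradient: it is scattered back to
    # the positions recorded in argmax during the forward max pool.
    # orig_input is only used for its shape, but the op insists on the full tensor.
    return gen_nn_ops._max_pool_grad_with_argmax(orig_input, values, argmax,
                                                 ksize=list(ksize),
                                                 strides=list(strides),
                                                 padding=padding)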

@NickShahML

Any plans to get an unpool layer into TensorFlow? @girving, as you point out, if the gradient operation already exists, then it doesn't seem like much work to get it working?


girving commented Jul 5, 2016

@LeavesBreathe I was wrong initially about how easy it would be, since the gradient operators as written take the original input. Thus, we probably do need a new exposed op, though it may be able to use the same underlying compute kernels (I'm not sure).

@syed-ahmed

Is there any performance gain or loss if one uses the second output of tf.nn.max_pool_with_argmax (the indices of the max pool) along with a tf.map_fn to achieve max unpooling?


girving commented Jul 12, 2016

@syed-ahmed That doesn't work: if you are doing unpooling, you don't start out with an input that you could pass to tf.nn.max_pool_with_argmax.


syed-ahmed commented Jul 12, 2016

@girving Can we not just save the indices from tf.nn.max_pool_with_argmax during downsampling for reuse during upsampling? We would use the saved argmax indices to inform us where we want the input to the corresponding upsample layer to go.


girving commented Jul 13, 2016

@syed-ahmed To clarify, it will work but it's a bit awkward. You can certainly store the indices, but the current MaxPoolGradWithArgmax op also wants the values that you originally passed to max pooling. It should use only the shape from these values, but you still need to pass them in. That's not too horrible when it's used as a gradient (though it's still a memory usage bug), but it is not clean enough to give it a nice name.

The same bug occurred in the initial version of conv_3d, so if someone wants to fix this they can look at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/conv_grad_ops_3d.cc. The code defines a new op that takes the original shape as input rather than the whole original input, and uses the same C++ kernel to implement both of them (with a conditional based on name).

If anyone does this, the new op can be given a nicer name like max_unpool.

@syed-ahmed

@girving Thanks for clarifying! I totally forgot the case about the gradient. I'll try to fix this issue.

@syed-ahmed

Hi @girving, could you please tell me what error would result from the memory usage bug? Just to clarify, is it a bug because it's not best practice, or did you encounter an error in that initial version of conv_3d? I get the following error for the implementation described above with MaxPoolWithArgmax and was wondering if anybody has encountered it before:

E tensorflow/stream_executor/cuda/cuda_driver.cc:1110] failed to synchronize the stop event: CUDA_ERROR_ILLEGAL_ADDRESS
E tensorflow/stream_executor/cuda/cuda_timer.cc:54] Internal: error destroying CUDA event in context 0x69951c0: CUDA_ERROR_ILLEGAL_ADDRESS
E tensorflow/stream_executor/cuda/cuda_timer.cc:59] Internal: error destroying CUDA event in context 0x69951c0: CUDA_ERROR_ILLEGAL_ADDRESS
F tensorflow/stream_executor/cuda/cuda_timer.cc:64] Check failed: start_event_ != nullptr && stop_event_ != nullptr


girving commented Jul 26, 2016

@syed-ahmed It's not an actual error unless you run out of memory. The issue is that if the gradient takes the original input tensor rather than the shape, the original input must be stored for the remainder of the forward pass and the backward pass up to that point. If only the shape is needed, that's a long time to hold onto otherwise unneeded memory.


syed-ahmed commented Jul 27, 2016

@girving Thanks for your reply. I am defining a MaxUnpoolGrad for the corresponding MaxUnpool operation that I have implemented. Following is what I declare as top_offset and bottom_offset for MaxUnpoolGrad:

const int top_offset = params.tensor_in_rows * params.tensor_in_cols * params.depth; 
const int bottom_offset = params.out_height * params.out_width * params.depth;

The corresponding CUDA kernel declared in maxpooling_op_gpu.cu.cc is:

template <typename dtype>
__global__ void MaxUnpoolBackward(const int nthreads, const dtype* top_diff,
                                  const int64* mask, const int top_offset,
                                  const int bottom_offset, dtype* bottom_diff) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int image_id = (index / bottom_offset);
    CudaAtomicAdd(bottom_diff + index, top_diff[mask[index] + image_id * top_offset]);
  }
}

My graph builds but it is when the session runs that I get the following error:

E tensorflow/stream_executor/cuda/cuda_driver.cc:1110] failed to synchronize the stop event: CUDA_ERROR_ILLEGAL_ADDRESS
E tensorflow/stream_executor/cuda/cuda_timer.cc:54] Internal: error destroying CUDA event in context 0x69951c0: CUDA_ERROR_ILLEGAL_ADDRESS
E tensorflow/stream_executor/cuda/cuda_timer.cc:59] Internal: error destroying CUDA event in context 0x69951c0: CUDA_ERROR_ILLEGAL_ADDRESS
F tensorflow/stream_executor/cuda/cuda_timer.cc:64] Check failed: start_event_ != nullptr && stop_event_ != nullptr 

I am also returning the gradient in nn_grad.py like this:

[None, gen_nn_ops._max_unpool_grad(array_ops.shape(op.inputs[1]),
                                   grad,
                                   op.inputs[2],
                                   op.get_attr("ksize"),
                                   op.get_attr("strides"),
                                   padding=op.get_attr("padding")), None]

where:

MaxUnpool
-input0: input_shape
-input1: grad_in
-input2: argmax

I have made sure the max unpooling and its grad operation take an input shape rather than an input 4D tensor. Do you know how to debug these CUDA errors, or any tool that can help find their origin? What do these errors indicate? I read a comment in maxpooling_op_gpu.cu.cc about race conditions. Is it related to this?


girving commented Jul 27, 2016

@syed-ahmed Is it possible to use cuDNN for these operations? Writing them yourself will result in very slow code. The same goes for CPU: it would be better to use existing Eigen code if possible.


syed-ahmed commented Jul 27, 2016

@girving Thank you for your reply. I will try implementing the cuDNN version once I get this CUDA one running. I was able to use cuda-gdb to get a trace of where my error is originating from. Here's the output from cuda-gdb:

CUDA Exception: Warp Out-of-range Address
The exception was triggered at PC 0x7ffe9976c1d0

Program received signal CUDA_EXCEPTION_5, Warp Out-of-range Address.
[Switching focus to CUDA kernel 0, grid 4660, block (172,0,0), thread (256,0,0), device 0, sm 0, warp 40, lane 0]
0x00007ffe9976c218 in void tensorflow::(anonymous namespace)::MaxUnpoolForward<float>(int, float const*, long long const*, int, int, float*)<<<(662,1,1),(1024,1,1)>>> ()

Here's how it is defined in the cu.cc file:

...
template <typename dtype>
__global__ void MaxUnpoolForward(const int nthreads, const dtype* top_diff,
                                const int64* mask, const int top_offset,
                                const int bottom_offset, dtype* bottom_diff) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int image_id = (index / top_offset);
    CudaAtomicAdd(bottom_diff + image_id * bottom_offset + mask[index],
                  top_diff[index]);
  }
}

template <typename dtype>
__global__ void MaxUnpoolBackward(const int nthreads, const dtype* top_diff,
                                  const int64* mask, const int top_offset,
                                  const int bottom_offset, dtype* bottom_diff) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int image_id = (index / bottom_offset);
    CudaAtomicAdd(bottom_diff, top_diff[mask[index] + image_id * top_offset]);
  }
}

#undef CUDA_1D_KERNEL_LOOP
...

I'm kind of lost since I'm a beginner with CUDA. Does anybody have an idea what might be going wrong?


girving commented Jul 27, 2016

It's impossible to debug this without seeing your code. As a wild guess: maybe you are running GPU kernels on Tensor objects stored on the CPU?


syed-ahmed commented Jul 27, 2016

Hi @girving. Sorry for not posting the full code; I didn't want to lengthen this issue. You can review the changes at this link.

I am calling the max unpool like this:

return gen_nn_ops._max_unpool(array_ops.shape(origin_input_tensor), grad,
                              argmax_tensor,
                              ksize=[1, 2, 2, 1], strides=[1, 1, 1, 1],
                              padding="VALID", name=name)

I am not sure whether the origin_input_tensor and argmax_tensor objects are on the CPU or GPU. The cuda-gdb output for MaxUnpoolForward suggests that "this occurs when any thread within a warp accesses an address that is outside the valid range of local or shared memory regions" (see the GPU error reporting documentation).

@syed-ahmed

Also, there is a lot of code duplication in my changes. I can make the unpool op use the same compute kernel; I was just checking whether using the same compute kernel was causing the CUDA error in the version I posted here.


wenouyang commented Jul 28, 2016

In the TensorFlow implementation (https://github.com/MarvinTeichmann/tensorflow-fcn/blob/master/fcn32_vgg.py) of the fully convolutional model (https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf), the author defines the function

def _upscore_layer(self, bottom, shape,
                   num_classes, name, debug,
                   ksize=4, stride=2):
    strides = [1, stride, stride, 1]
    with tf.variable_scope(name):
        in_features = bottom.get_shape()[3].value

        if shape is None:
            # Compute shape out of Bottom
            in_shape = tf.shape(bottom)

            h = ((in_shape[1] - 1) * stride) + 1
            w = ((in_shape[2] - 1) * stride) + 1
            new_shape = [in_shape[0], h, w, num_classes]
        else:
            new_shape = [shape[0], shape[1], shape[2], num_classes]
        output_shape = tf.pack(new_shape)

        logging.debug("Layer: %s, Fan-in: %d" % (name, in_features))
        f_shape = [ksize, ksize, num_classes, in_features]

        # create
        num_input = ksize * ksize * in_features / stride
        stddev = (2 / num_input)**0.5

        weights = self.get_deconv_filter(f_shape)
        deconv = tf.nn.conv2d_transpose(bottom, weights, output_shape,
                                        strides=strides, padding='SAME')

        if debug:
            deconv = tf.Print(deconv, [tf.shape(deconv)],
                              message='Shape of %s' % name,
                              summarize=4, first_n=1)

    _activation_summary(deconv)
    return deconv

It looks like the author just uses tf.nn.conv2d_transpose to do the upsampling. Is my understanding correct?


ziky90 commented Jul 29, 2016

@wenouyang Yes, in the FCN paper (https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf) they only use tf.nn.conv2d_transpose() to perform the upsampling, but there are also other models, mainly for semantic segmentation, that use max unpooling, for example http://arxiv.org/abs/1505.04366.


girving commented Jul 29, 2016

Sorry for the delay, taking a look at your code now.


girving commented Jul 29, 2016

I must not understand your code. How are you doing an effectively 3D unpooling operation (batch, height, width) with a 1D loop that does only one integer division? One integer division is only powerful enough to express a 2D loop.


syed-ahmed commented Jul 29, 2016

@girving I followed the MaxPoolBackward code in maxpooling_op_gpu.cu.cc. I thought the n dimensions of the tensor were taken care of by the following in maxpooling_op.cc, in the LaunchMaxUnpooling function I defined (like LaunchMaxPoolingGradWithArgmax):

const int input_size = params.tensor_in_batch * params.tensor_in_rows *
                           params.tensor_in_cols * params.depth;
const int output_size = params.tensor_in_batch * params.out_height *
                            params.out_width * params.depth;
const int top_offset = params.out_height * params.out_width * params.depth;
const int bottom_offset = params.tensor_in_rows * params.tensor_in_cols * params.depth;

nio1814 referenced this issue in nio1814/tensorflow Feb 16, 2018

XXY0118 commented Feb 20, 2018

@chrisranderson Hi, I'm new to deep learning and TensorFlow. Could you please explain in more detail how to implement unpooling using tensorflow/tensorflow#16885? Thanks a lot!

@chrisranderson

@XXY0118 Copy and paste these lines https://github.com/rayanelleuch/tensorflow/blob/b46d50583d8f4893f1b1d629d0ac9cb2cff580af/tensorflow/contrib/layers/python/layers/layers.py#L2291-L2327, and you should be good to go. I wish GitHub allowed some kind of DM for occasions like this.


apatsekin commented Mar 23, 2018

@daeyun, please swap the parameters in the tf.concat call from:
out = tf.concat(i, [out, tf.zeros_like(out)])
to:
out = tf.concat([out, tf.zeros_like(out)], i)

Other than that works fine for unpooling without positions indices. Thanks!


JFChi commented Jun 29, 2018

Is there a strided version of the unpool function?

@Harshini-Gadige

We are going to close this issue. Feel free to reopen it if you want to contribute and link the PR to it.


jkyl commented Feb 1, 2019

A differentiable and GPU-safe avg_unpool2d implementation is as follows:

def avg_unpool2d(x, factor):
  '''
  Performs "average un-pooling", i.e. nearest neighbor upsampling,
  without the faulty `tf.image.resize_nearest_neighbor` op.
  '''
  # [B, H, W, C] -> [1, H, W, C, B]
  x = tf.transpose(x, [1, 2, 3, 0])
  x = tf.expand_dims(x, 0)
  # Replicate the leading dim factor**2 times, then fold those copies back
  # into the two spatial dims, upsampling H and W by `factor`.
  x = tf.tile(x, [factor**2, 1, 1, 1, 1])
  x = tf.batch_to_space_nd(x, [factor, factor], [[0, 0], [0, 0]])
  # [H*factor, W*factor, C, B] -> [B, H*factor, W*factor, C]
  x = tf.transpose(x[0], [3, 0, 1, 2])
  return x


greydanus commented Mar 26, 2019

I believe that a max_unpool/avg_unpool function would be quite useful. The argument that we should "just use the gradient op" ignores the fact that this makes our code hacky and opaque. Also, there's no official documentation for this approach.

TensorFlow doesn't ask people to implement deconvolution, even though technically it can be expressed as a convolution. Why? It's convenient and it lets researchers focus on more important things. The same goes for unpooling.


alextp commented Mar 27, 2019

@greydanus @jkyl I'd love to approve a PR adding this max_unpool implementation to TF, along with a unit test.

@greydanus

I'm working on a PR + unit test. More to come.

@alextp alextp reopened this Mar 29, 2019

alextp commented Mar 29, 2019

Reopening so @greydanus's PR can close it.

@yselivonchyk

I revisited the implementations in the current thread and found that @rayanelleuch's solution from Oct 24, 2017 works best for me. It works with batches (i.e. the first dimension of the input tensor is None), produces a known output shape, and produces no type errors.

I also added tf.keras layers for MaxPoolingWithArgmax and Unpooling (previously mentioned versions did not work with tf.keras, though they somehow worked with standalone Keras) here: https://github.com/yselivonchyk/Tensorflow_WhatWhereAutoencoder/blob/master/pooling.py
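
For anyone wrapping this as a layer themselves, here is a minimal sketch of a tf.keras unpooling layer built on the same scatter_nd idea (assuming the argmax indices include the batch offset, as in the implementations in this thread; the class name and call signature are illustrative, not the linked repo's API):

import tensorflow as tf

class MaxUnpooling2D(tf.keras.layers.Layer):
    """Scatters pooled values back to their argmax positions, zeros elsewhere."""

    def call(self, inputs):
        pooled, argmax, output_shape = inputs
        flat_size = tf.cast(tf.reduce_prod(output_shape), tf.int64)
        updates = tf.reshape(pooled, [-1])
        indices = tf.expand_dims(tf.reshape(argmax, [-1]), axis=-1)
        flat = tf.scatter_nd(indices, updates, shape=[flat_size])
        return tf.reshape(flat, output_shape)

Usage would look roughly like pooled, argmax = tf.nn.max_pool_with_argmax(x, ksize=2, strides=2, padding='SAME', include_batch_in_index=True) (on recent TF versions), followed by MaxUnpooling2D()([pooled, argmax, tf.shape(x)]).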


Twice22 commented Apr 12, 2019

Hello everybody!

As @Panaetius highlighted, the unpooling layers presented here have a drawback: they don't account for padding, because tf.nn.max_pool_with_argmax does not return a tensor whose size includes the padding (if you use padding='SAME', for example). I have changed the function so that it unpools the tensor to the size of the prev_tensor that we used during tf.nn.max_pool_with_argmax:

def max_unpool(pool, ind, prev_tensor, scope='unpool_2d'):
	"""
	Implement the unpooling operation, as explained here:
	https://stackoverflow.com/questions/36548736/tensorflow-unpooling

	Args:
		pool (tensor): Input tensor of shape (N, H, W, C)
		ind (tensor): Input tensor of shape (N, H, W, C) containing the maximum
			flatten indices (see https://www.tensorflow.org/api_docs/python/tf.nn.max_pool_with_argmax)
		prev_tensor (tensor): previous tensor shape
		scope (str): scope in which to register the operations
	Return:
		ret (tensor): tensor same shape as prev_tensor that corresponds to the "invert" of the
			max pooling operation
	"""
	with tf.variable_scope(scope):
		# input_shape = [N, H, W, C]
		input_shape = tf.shape(pool)
		o_shape = tf.shape(prev_tensor)

		output_shape = [input_shape[0], o_shape[1], o_shape[2], input_shape[3]]

		# N * H * W * C
		flat_input_size = tf.reduce_prod(input_shape)

		# flat output_shape = [N, 4 * H * W * C]
		flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]]

		updates = tf.reshape(pool, [flat_input_size])

		# create the tensor [ [[[1]]], [[[0]]], ..., [[[N-1]]] ]
		batch_range = tf.reshape(
			tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype),
			shape=[input_shape[0], 1, 1, 1])

		# b is a tensor of size (N, H, W, C) whose first element of the batch are 3D-array full of 0
		# second element of the batch are 3D-array full of 1, ...   
		b = tf.ones_like(ind) * batch_range
		b = tf.reshape(b, [flat_input_size, 1])

		# indices = [ [0, ind_1], [0, ind_2], ... [0, ind_k], ..., [N-1, ind_{N*H*W*C}], [N-1, ind_{N*H*W*C-1}] ]
		indices = tf.reshape(ind, [flat_input_size, 1])
		indices = tf.concat([b, indices], axis=-1)

		ret = tf.scatter_nd(indices, updates, shape=tf.cast(flat_output_shape, tf.int64))
		ret = tf.reshape(ret, output_shape)

		set_input_shape = pool.get_shape()
		prev_tensor_shape = prev_tensor.get_shape()

		set_output_shape = [set_input_shape[0], prev_tensor_shape[1], prev_tensor_shape[2], set_input_shape[3]]
		ret.set_shape(set_output_shape)

		return ret

You can use it as follows:

maxpool_layer, maxpool_idx = tf.nn.max_pool_with_argmax(
    your_input,
    [1, 2, 2, 1], [1, 2, 2, 1],
    padding='SAME',
    name="max_pooling_5")

conv_layer = tf.layers.conv2d(
    maxpool_layer,
    filters=4096,
    kernel_size=7,
    name='conv')

deconv_layer = tf.layers.conv2d(
    conv_layer,
    filters=512,
    kernel_size=1,
    kernel_initializer=tf.contrib.layers.xavier_initializer(),
    name="deconv")

unpooling_layer5 = max_unpool(deconv_layer, maxpool_idx, your_input, scope="Unpooling_5")

This implementation works well with padding and doesn't need set_shape() while reading the tf_records, which means that at prediction time you can pass a single image at a time (batch_size=1) with images of totally different sizes and it won't break.

sdmonov referenced this issue in sdmonov/onnx-tensorflow Jul 5, 2019
MaxUnpool is not supported by tensorflow by default. Refer to
https://github.com/tensorflow/tensorflow/issues/2169 for more information.
The current solution uses proposed code from the above issue with modifications
to support for padding and strides.
@dynamicwebpaige
Contributor

This sounds like a great feature! Adding support for unpooling is outside of the scope of TensorFlow Core, but would be a fantastic addition to TensorFlow Addons.

Transferring this issue now; @seanpmorgan for visibility.

@dynamicwebpaige dynamicwebpaige transferred this issue from tensorflow/tensorflow Oct 27, 2019
@seanpmorgan
Member

Thanks for transferring. This seems like a nice fit in addons, though it will need to be converted to fit the Keras Layer API and have appropriate test cases.

@seanpmorgan seanpmorgan added Feature Request help wanted Needs help as a contribution layers labels Oct 27, 2019
@bhack bhack mentioned this issue Jul 3, 2020
@bhack bhack mentioned this issue Dec 7, 2020
@bhack bhack linked a pull request Dec 10, 2020 that will close this issue

NEGU93 commented Mar 18, 2021

Here is my implementation, also posted on Stack Overflow. You should apply the max pooling using tf.nn.max_pool_with_argmax and then pass the resulting argmax to this function:

def unpooling(inputs, output_shape, argmax):
    """
    Performs unpooling, as explained in:
    https://www.oreilly.com/library/view/hands-on-convolutional-neural/9781789130331/6476c4d5-19f2-455f-8590-c6f99504b7a5.xhtml
    :param inputs: Input Tensor.
    :param output_shape: Desired output shape. For example, on 2D unpooling, this should be 4D (because of number of samples and channels).
    :param argmax: Result argmax from tf.nn.max_pool_with_argmax
        https://www.tensorflow.org/api_docs/python/tf/nn/max_pool_with_argmax
    """
    flat_output_shape = tf.cast(tf.reduce_prod(output_shape), tf.int64)

    updates = tf.reshape(inputs, [-1])
    indices = tf.expand_dims(tf.reshape(argmax, [-1]), axis=-1)

    ret = tf.scatter_nd(indices, updates, shape=[flat_output_shape])
    ret = tf.reshape(ret, output_shape)
    return ret
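
A minimal usage sketch (assuming a 4-D NHWC input and argmax indices that include the batch offset, e.g. include_batch_in_index=True on recent TF versions; names and shapes are illustrative):

x = tf.random.normal([2, 8, 8, 3])
pooled, argmax = tf.nn.max_pool_with_argmax(
    x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
    padding='SAME', include_batch_in_index=True)
unpooled = unpooling(pooled, tf.shape(x), argmax)  # -> [2, 8, 8, 3], zeros except at the maxima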

This has a small bug/feature: if argmax has a repeated value, it will perform an addition instead of just putting the value once. Beware of this if the stride is 1. I don't know, however, whether this is desired or not. This behavior was also present in @Twice22's solution, as I based my implementation on his code.
