In [2]:
from mxnet import cpu
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn

import mxnet as mx
import mxnet.ndarray as nd
from mxnet import gluon

In [None]:
class AlexNet(HybridBlock):
    r"""AlexNet model from the `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Parameters
    ----------
    classes : int, default 1000
        Number of classes for the output layer.
    """
    def __init__(self, classes=1000, **kwargs):
        super(AlexNet, self).__init__(**kwargs)
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')
            with self.features.name_scope():
                # Every layer is created inside the feature name scope, in the
                # same order as it appears in the stack, then appended in turn.
                layers = (
                    # Convolution / pooling trunk.
                    nn.Conv2D(64, kernel_size=11, strides=4, padding=2, activation='relu'),
                    nn.MaxPool2D(pool_size=3, strides=2),
                    nn.Conv2D(192, kernel_size=5, padding=2, activation='relu'),
                    nn.MaxPool2D(pool_size=3, strides=2),
                    nn.Conv2D(384, kernel_size=3, padding=1, activation='relu'),
                    nn.Conv2D(256, kernel_size=3, padding=1, activation='relu'),
                    nn.Conv2D(256, kernel_size=3, padding=1, activation='relu'),
                    nn.MaxPool2D(pool_size=3, strides=2),
                    # Fully connected layers that precede the classifier.
                    nn.Flatten(),
                    nn.Dense(4096, activation='relu'),
                    nn.Dropout(0.5),
                    nn.Dense(4096, activation='relu'),
                    nn.Dropout(0.5),
                )
                for layer in layers:
                    self.features.add(layer)

            # Classifier head: one output unit per class.
            self.output = nn.Dense(classes)

    def hybrid_forward(self, F, x):
        """Run the feature stack, then the classifier head."""
        return self.output(self.features(x))


In [None]:
net = AlexNet(classes=1000)  # same as the default: a 1000-way classifier

In [None]:
# Count of direct child blocks registered on `net` (reads the private
# `_children` container; expected to be features + output — TODO confirm).
num_children = len(net._children)
num_children

In [None]:
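# NOTE: superseded draft, kept for reference only; the active ReversibleConv2D
# (a gluon.Block wrapping nn.Conv2D) is defined in the next cell. Known bugs in
# this draft: `this.in_channels` should be `self.in_channels`, and
# `super(...).hybrid_forward(self, F, ...)` passes `self` twice.
# class ReversibleConv2D(nn.Conv2D):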
#     def __init__(self, channels, kernel_size, strides=(1, 1), padding=(0, 0),
#                 dilation=(1, 1), groups=1, layout='NCHW',
#                 activation=None, use_bias=True, weight_initializer=None,
#                 bias_initializer='zeros', in_channels=0, **kwargs):
#         super(ReversibleConv2D, self).__init__(channels, kernel_size, strides, padding,
#             dilation, groups, layout, activation, use_bias, weight_initializer,
#             bias_initializer, in_channels, **kwargs)
        
#     def hybrid_forward(self, F, x, weight, bias=None):
#         this.in_channels = x.shape[1]
#         return super(ReversibleConv2D, self).hybrid_forward(self, F, x, weight, bias)
    
#     def reverse(self, y):
# #        print (type(self.weight.data()))
#         return nd.Deconvolution(data=y, weight=self.weight.data(), bias=self.bias.data(),
#                                 kernel=self.weight.data().shape,
#                                 stride=self._kwargs['stride'], dilate=self._kwargs['dilate'],
#                                 pad=self._kwargs['pad'], num_filter=self._in_channels, 
#                                 no_bias=self._kwargs['no_bias'])

In [None]:
class ReversibleConv2D(gluon.Block):
    """A 2-D convolution that can also project its output back through the
    transposed convolution (``nd.Deconvolution``) using the same weights.

    Constructor arguments mirror ``nn.Conv2D`` and configure the wrapped
    convolution. ``**kwargs`` are the ``Block`` keyword arguments (e.g.
    ``prefix``, ``params``) and are consumed by this wrapper only.
    """
    def __init__(self, channels, kernel_size, strides=(1, 1), padding=(0, 0),
                 dilation=(1, 1), groups=1, layout='NCHW',
                 activation=None, use_bias=True, weight_initializer=None,
                 bias_initializer='zeros', in_channels=0, **kwargs):
        super(ReversibleConv2D, self).__init__(**kwargs)
        # Create the child inside this block's name scope so its parameters are
        # nested under this block, instead of re-applying `prefix`/`params`
        # from kwargs to the child as well.
        with self.name_scope():
            self.conv2d = nn.Conv2D(channels, kernel_size, strides=strides,
                                    padding=padding, dilation=dilation,
                                    groups=groups, layout=layout,
                                    activation=activation, use_bias=use_bias,
                                    weight_initializer=weight_initializer,
                                    bias_initializer=bias_initializer,
                                    in_channels=in_channels)

    def forward(self, x):
        """Apply the wrapped convolution.

        Also records ``self.in_channels = x.shape[1]`` (the channel axis for
        NCHW input; for other layouts this index may not be channels — TODO
        confirm if non-NCHW layouts are used).
        """
        self.in_channels = x.shape[1]
        return self.conv2d(x)

    def reverse(self, y):
        """Map ``y`` back toward the input space with the transposed convolution.

        Reuses the wrapped convolution's weight, kernel size, stride, dilation,
        padding, group count and layout. The convolution's bias is NOT used:
        its shape is (out_channels,), while a transposed conv's bias would need
        shape (in_channels,). Any activation applied by the convolution is not
        inverted either.

        Parameters
        ----------
        y : NDArray
            Tensor with the convolution's output channel count, e.g. the
            result of ``forward``.

        Returns
        -------
        NDArray
            Tensor with the convolution's input channel count. With strides > 1
            the spatial size may be smaller than the original input, because no
            ``adj``/``target_shape`` is applied.
        """
        conv = self.conv2d
        # Fetch the weight on y's device; parameters may exist only on a GPU.
        weight = conv.weight.data(y.context)
        num_group = conv._kwargs['num_group']
        # The conv weight is (out, in/groups, kh, kw). Deconvolution reads the
        # same tensor as (data_channels, num_filter/groups, kh, kw), so it must
        # reconstruct in/groups * groups == in channels.
        num_filter = weight.shape[1] * num_group
        return nd.Deconvolution(data=y, weight=weight,
                                kernel=weight.shape[2:],
                                stride=conv._kwargs['stride'],
                                dilate=conv._kwargs['dilate'],
                                pad=conv._kwargs['pad'],
                                num_filter=num_filter,
                                num_group=num_group,
                                layout=conv._kwargs['layout'],
                                no_bias=True)

In [None]:
# Prefer the first GPU, but fall back to the CPU when no usable GPU is present
# so this cell also runs on CPU-only machines.
try:
    model_ctx = mx.gpu(0)
    mx.nd.zeros((1,), ctx=model_ctx).wait_to_read()
except mx.base.MXNetError:
    model_ctx = mx.cpu()

x = mx.nd.ones((1, 3, 20, 20))

# A 10-filter 3x3 convolution applied to a 3-channel input.
net = ReversibleConv2D(10, kernel_size=3)
net.collect_params().initialize(mx.init.Normal(sigma=.01), ctx=model_ctx)

y = net(x.as_in_context(model_ctx))

print(y.shape)

net.reverse(y)

In [None]:
# Scratch tuple; appears to mirror a (10, 3, 3, 3) conv weight shape used
# elsewhere in this notebook — used only to try out slicing.
a = tuple([10, 3, 3, 3])

In [None]:
a[slice(1, None)]  # everything after the leading dimension

In [10]:
# Sanity check of the conv -> transposed-conv pairing: a (10, 3, 3, 3) weight
# maps 3 -> 10 channels in Convolution, and the same tensor read by
# Deconvolution with num_filter=3 maps 10 -> 3 channels back.
mx.random.seed(0)  # make the random tensors (and printed values) reproducible

input_4x4 = mx.nd.normal(shape=[1, 3, 4, 4])

print(input_4x4)

kernel_3x3 = mx.nd.normal(shape=[10, 3, 3, 3])
conv = mx.nd.Convolution(data=input_4x4, kernel=(3, 3), pad=(1, 1), weight=kernel_3x3,
                         num_filter=10, no_bias=True)
print(conv.shape)
# pad=1 with a 3x3 kernel keeps the 4x4 spatial size.
assert conv.shape == (1, 10, 4, 4)

transpose = mx.nd.Deconvolution(data=conv, kernel=(3, 3), pad=(1, 1), weight=kernel_3x3,
                                num_filter=3, no_bias=True)
print(transpose.shape)
# The transposed conv restores the input's shape (not its values).
assert transpose.shape == input_4x4.shape

print(transpose)


[[[[ 1.29858637  0.57342744 -0.34835348 -0.58361858]
   [-2.0417614  -0.1796914  -1.2405448  -0.7325387 ]
   [-3.06886029  1.31733298 -1.6384114   0.93694556]
   [ 0.592933   -0.50639272 -0.68630707  0.1130666 ]]

  [[-0.22556685  0.18525851  1.22298074  0.55320805]
   [ 0.25945112 -0.00414529 -0.22740492  0.41514722]
   [ 1.91706145 -1.43812132  0.43687493  0.35308638]
   [ 0.74126935  2.73936391 -0.65508407  2.11390615]]

  [[ 0.54437113 -0.9276405   0.74472493 -0.6918596 ]
   [ 0.2654874  -0.11799012  1.21965587  0.18523361]
   [ 1.62041867 -0.56247813  0.39254659 -0.9529047 ]
   [ 0.05601164 -0.66915834  1.03501856  0.29303032]]]]
<NDArray 1x3x4x4 @cpu(0)>
(1, 10, 4, 4)
(1, 3, 4, 4)

[[[[  55.67409134   -7.4365387     7.69010496  -18.32086182]
   [-103.55095673   -0.908602    -26.96941376  -44.83856201]
   [ -70.33335876  -12.73396397  -19.10709381   20.91474342]
   [  30.7226162   -30.73569298  -48.73796463   12.72243881]]

  [[ -17.39299774  -11.26985931   44.06104279    9.84715