In [1]:
from mxnet import cpu
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn

import mxnet as mx

In [2]:
class AlexNet(HybridBlock):
    r"""AlexNet model from the `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    The network is composed of two child blocks: ``features`` (the
    convolution / max-pooling / flatten / fully-connected trunk, with dropout
    after each 4096-unit dense layer) and ``output`` (the final ``Dense``
    layer producing ``classes`` outputs).

    Parameters
    ----------
    classes : int, default 1000
        Number of classes for the output layer.
    **kwargs
        Passed unchanged to the ``HybridBlock`` constructor.
    """

    def __init__(self, classes=1000, **kwargs):
        super(AlexNet, self).__init__(**kwargs)
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')
            with self.features.name_scope():
                # A list literal evaluates its elements left-to-right, so the
                # layers are instantiated in exactly the original order and the
                # auto-generated names inside this name scope stay the same.
                trunk = [
                    nn.Conv2D(64, kernel_size=11, strides=4, padding=2,
                              activation='relu'),
                    nn.MaxPool2D(pool_size=3, strides=2),
                    nn.Conv2D(192, kernel_size=5, padding=2,
                              activation='relu'),
                    nn.MaxPool2D(pool_size=3, strides=2),
                    nn.Conv2D(384, kernel_size=3, padding=1,
                              activation='relu'),
                    nn.Conv2D(256, kernel_size=3, padding=1,
                              activation='relu'),
                    nn.Conv2D(256, kernel_size=3, padding=1,
                              activation='relu'),
                    nn.MaxPool2D(pool_size=3, strides=2),
                    nn.Flatten(),
                    nn.Dense(4096, activation='relu'),
                    nn.Dropout(0.5),
                    nn.Dense(4096, activation='relu'),
                    nn.Dropout(0.5),
                ]
                for layer in trunk:
                    self.features.add(layer)

            self.output = nn.Dense(classes)

    def hybrid_forward(self, F, x):
        """Run ``x`` through the feature trunk, then through the output layer."""
        return self.output(self.features(x))


In [3]:
net = AlexNet(classes=1000)  # same as the constructor's default head size

In [4]:
# Count the direct child blocks registered on AlexNet (`features` and `output`).
# NOTE(review): `_children` is a private Gluon attribute -- fine for
# exploration, but not a stable public API.
len(net._children)

2

In [15]:
class ReversibleConv2D(nn.Conv2D):
    """An ``nn.Conv2D`` intended to also expose a ``reverse`` operation.

    The constructor takes exactly the same arguments as ``nn.Conv2D`` and
    forwards all of them unchanged; see that class for their meaning.

    NOTE(review): ``reverse`` is still a stub -- it only prints the layer's
    channel count and does not compute anything from its input.
    """

    def __init__(self, channels, kernel_size, strides=(1, 1), padding=(0, 0),
                 dilation=(1, 1), groups=1, layout='NCHW',
                 activation=None, use_bias=True, weight_initializer=None,
                 bias_initializer='zeros', in_channels=0, **kwargs):
        # Forward by keyword rather than by position so that this wrapper
        # cannot silently mis-assign arguments if the parent's positional
        # parameter order ever differs from the one duplicated here.
        super(ReversibleConv2D, self).__init__(
            channels=channels, kernel_size=kernel_size, strides=strides,
            padding=padding, dilation=dilation, groups=groups, layout=layout,
            activation=activation, use_bias=use_bias,
            weight_initializer=weight_initializer,
            bias_initializer=bias_initializer, in_channels=in_channels,
            **kwargs)

    def reverse(self, y):
        """Placeholder for the inverse of this layer's forward pass.

        Parameters
        ----------
        y : NDArray
            Presumably the output of calling this layer -- TODO confirm.
            Currently unused.

        TODO: implement the actual reverse computation; for now this only
        prints ``self._channels`` and returns ``None``.
        """
        print(self._channels)

In [16]:
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# top-to-bottom on machines without a usable GPU. Allocating a tiny array and
# syncing with asnumpy() forces any GPU allocation error to surface here.
try:
    mx.nd.zeros((1,), ctx=mx.gpu(0)).asnumpy()
    model_ctx = mx.gpu(0)
except mx.base.MXNetError:
    model_ctx = cpu()

net = ReversibleConv2D(256, kernel_size=3)
net.collect_params().initialize(mx.init.Normal(sigma=.01), ctx=model_ctx)

x = mx.nd.ones((1, 3, 10, 10))

In [17]:
# Copy the input onto the model's device before running the layer.
x_on_device = x.as_in_context(model_ctx)
y = net(x_on_device)

In [18]:
# `reverse` is still a stub: it ignores `y` and only prints the layer's
# channel count (256 for the layer built above).
net.reverse(y)

256
