# 实现卷积核的多通道输入和输出

In [2]:
import mxnet as mx
from mxnet import autograd,nd
from mxnet import gluon
from mxnet.gluon import nn

In [3]:
import sys
sys.path.append('..')

## 多输入通道

In [4]:
import gluonbook as gb

In [5]:
def corr_2d_multi_in(X,K):
    #对卷积核输出通道的每一维遍历计算，最后将结果相加
    return nd.add_n(*[gb.corr2d(x,k) for x,k in zip(X,K)])

In [6]:
X = nd.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
[[1, 2, 3], [4, 5, 6], [7, 8, 9]]],ctx=mx.gpu())
K = nd.array([[[0, 1], [2, 3]], [[1, 2], [3, 4]]],ctx=mx.gpu())

In [7]:
corr_2d_multi_in(X,K)


[[ 56.  72.]
 [104. 120.]]
<NDArray 2x2 @cpu(0)>

## 多输出通道
## 通过计算每一个输出通道的值，最后将输出通道在输出维度进行合并得到

In [8]:
def corr2d_multi_out(X,K):
    return nd.stack(*[corr_2d_multi_in(X,k) for k in K])

In [10]:
K = nd.stack(K,K+1,K+2)

In [11]:
corr2d_multi_out(X,K)


[[[ 56.  72.]
  [104. 120.]]

 [[ 76. 100.]
  [148. 172.]]

 [[ 96. 128.]
  [192. 224.]]]
<NDArray 3x2x2 @cpu(0)>

## 下面实现1 * 1 的卷积层

In [29]:
def corr2d_mulit_in_out_1x1(X,K):
    c_i,h,w = X.shape
    c_o = K.shape[0]
    X = X.reshape((c_i,h*w))
    K = K.reshape((c_o,c_i))
    Y = nd.dot(K,X)
    return Y.reshape((c_o,h,w))

In [30]:
X = nd.random.uniform(shape=(3, 3, 3),ctx=mx.gpu())
K = nd.random.uniform(shape=(2, 3, 1, 1),ctx = mx.gpu())

In [31]:
with mx.Context(mx.gpu()):
    Y1 = corr2d_mulit_in_out_1x1(X, K)
    Y2 = corr2d_multi_out(X, K)
    (Y1 - Y2).norm().asscalar() < 1e-6

In [32]:
(Y1 - Y2).norm().asscalar() < 1e-6

True

## 定义一个多输入多输出通道的卷积层

In [50]:
class Conv2D_mulit(nn.Block):
    def __init__(self,kernel_size,ch_in,ch_out,**kargs):
        super(Conv2D_mulit,self).__init__(**kargs)
        self.weight = self.params.get('weights',shape=(ch_out,ch_in)+kernel_size)
        self.bias  = self.params.get('bias',shape=(ch_out,ch_in,))
        
    def forward(self,X):
        #下面进行多通道计算
        Y = corr2d_multi_out(X,self.weight.data())
        return nd.relu(Y+self.bias.data(),ctx=mx.gpu())

In [51]:
conv2d_multi = Conv2D_mulit(kernel_size=(3,3),ch_in=3,ch_out=2)

In [52]:
conv2d_multi.params

conv2d_mulit3_ (
  Parameter conv2d_mulit3_weights (shape=(2, 3, 3, 3), dtype=<class 'numpy.float32'>)
  Parameter conv2d_mulit3_bias (shape=(2, 3), dtype=<class 'numpy.float32'>)
)

In [53]:
conv2d_multi.initialize(ctx=mx.gpu())

In [55]:
with mx.Context(mx.gpu()):
    print(conv2d_multi(X))


[[[0.         0.07260251 0.06445698]
  [0.         0.01962225 0.0287397 ]]

 [[0.         0.00455641 0.        ]
  [0.         0.         0.        ]]]
<NDArray 2x2x3 @gpu(0)>
