# 7.4. Multiple Input and Multiple Output Channels

In [ ]:
import torch
from d2l import torch as d2l

## 7.4.1. Multiple Input Channels

https://d2l.ai/_images/conv-multi-in.svg

In [ ]:
# Two-channel 3x3 example input. Channel 0 holds 1..9 and channel 1
# holds 0..8, each filled row by row.
multi_input = torch.stack((torch.arange(1, 10).reshape(3, 3),
                           torch.arange(0, 9).reshape(3, 3)))

In [ ]:
# Two-channel 2x2 example kernel. Channel 0 holds 1..4 and channel 1
# holds 0..3, each filled row by row.
multi_kernel = torch.stack((torch.arange(1, 5).reshape(2, 2),
                            torch.arange(0, 4).reshape(2, 2)))

In [ ]:
def corr2d_multi_in(X, K):
    """Cross-correlate a multi-channel input with a multi-channel kernel.

    ``X`` and ``K`` are walked in lockstep along their 0th (channel)
    dimension. Each (input channel, kernel channel) pair is passed to
    ``d2l.corr2d`` and the per-channel results are added together.
    """
    # Start from the integer 0, exactly as the built-in ``sum`` does.
    total = 0
    for channel, kernel in zip(X, K):
        total = total + d2l.corr2d(channel, kernel)
    return total

In [ ]:
# zip demo: zip packs several iterables into one combined iterator that
# yields tuples holding the i-th element of every input.
a = torch.arange(1, 7)
b = torch.arange(2, 8)
c = torch.arange(3, 9)
abc = zip(a, b, c)
# Each printed item is a 3-tuple with one element taken from a, b and c.
for i in abc:
    print(i)

In [ ]:
# Sum of the per-channel cross-correlations of the example input and
# kernel. Assuming d2l.corr2d is a plain 2-D cross-correlation, the values
# should be [[56, 72], [104, 120]].
corr2d_multi_in(X=multi_input, K=multi_kernel)

## 7.4.2. Multiple Output Channels

In [ ]:
def corr2d_multi_in_out(X, K):
    """Compute a cross-correlation with multiple output channels.

    Every slice of ``K`` along its 0th dimension is applied to the whole
    input ``X`` with ``corr2d_multi_in``. The per-slice results are then
    stacked along a new leading dimension.
    """
    feature_maps = [corr2d_multi_in(X, kernel) for kernel in K]
    return torch.stack(feature_maps, dim=0)

In [ ]:
# Small 2x2 integer matrix [[1, 2], [3, 4]] used by the stack/cat demos.
d = torch.arange(1, 5).reshape(2, 2)
d

In [ ]:
# Adding a scalar is applied element-wise and returns a new tensor.
torch.add(d, 1)

In [ ]:
# stack: join d, d + 1 and d + 2 along a new leading dimension (the default).
stack = torch.stack([d + offset for offset in range(3)])
stack

In [ ]:
# Three 2x2 matrices stacked on a new axis give torch.Size([3, 2, 2]).
stack.size()

In [ ]:
# cat joins along an existing dimension (dim 0 by default): the three
# 2x2 matrices are placed one under another, giving shape (6, 2).
torch.cat([d, d + 1, d + 2], dim=0)

In [ ]:
# Concatenating along dim 1 places the three 2x2 matrices side by side,
# giving shape (2, 6).
torch.cat([d, d + 1, d + 2], dim=1)

In [ ]:
# New axis inserted at position 0: result shape (3, 2, 2).
torch.stack([d, d + 1, d + 2], dim=0)

In [ ]:
# New axis inserted at position 1: result shape (2, 3, 2).
torch.stack([d, d + 1, d + 2], dim=1)

In [ ]:
# New axis inserted at position 2 (the last one): result shape (2, 2, 3).
torch.stack([d, d + 1, d + 2], dim=2)

## 7.4.3. 1X1 Convolutional Layer

https://d2l.ai/_images/conv-1x1.svg

The only computation of the 1X1 convolution occurs on the channel dimension.

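The main use of a 1X1 convolution kernel is to change the number of channels, most commonly to reduce it (dimensionality reduction): it combines the channels at each pixel without looking at neighboring pixels.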

In [ ]:
def corr2d_multi_in_out_1x1(X, K):
    """Apply a 1x1 multi-channel kernel to ``X`` as one matrix product.

    ``X`` is unpacked as ``(c_i, h, w)``. ``K`` is reshaped to
    ``(c_o, c_i)`` with ``c_o = K.shape[0]``, so it must hold exactly
    ``c_o * c_i`` elements (e.g. shape ``(c_o, c_i, 1, 1)``).

    Returns a tensor of shape ``(c_o, h, w)``.
    """
    in_channels, height, width = X.shape
    out_channels = K.shape[0]
    # Flatten the spatial dimensions: each pixel becomes one column of
    # ``in_channels`` values.
    pixels = X.reshape(in_channels, height * width)
    weights = K.reshape(out_channels, in_channels)
    # The same matrix multiplication a fully connected layer would do,
    # applied to every pixel at once.
    return (weights @ pixels).reshape(out_channels, height, width)

In [ ]:
# Walk through the steps of corr2d_multi_in_out_1x1 one at a time on
# random data, printing every intermediate tensor.
mock_x = torch.normal(0, 1, (3, 3, 3))
print(mock_x)
c_i, h, w = mock_x.size()
# Collapse the spatial dimensions: (c_i, h, w) -> (c_i, h * w).
mock_x = mock_x.reshape(c_i, -1)
print(mock_x)

mock_kernel = torch.normal(0, 1, (2, 3, 1, 1))
c_o = mock_kernel.size(0)
print(mock_kernel)
# Drop the two trailing size-1 dimensions: (c_o, c_i, 1, 1) -> (c_o, c_i).
mock_kernel = mock_kernel.reshape(c_o, -1)
print(mock_kernel)

# (c_o, c_i) @ (c_i, h * w) -> (c_o, h * w), then restore the spatial layout.
mock_y = mock_kernel @ mock_x
print(mock_y.reshape(c_o, h, w))

In [ ]:
# torch.normal(mean, std, size): draw a tensor of the given size from a
# normal distribution with the given mean and standard deviation.
# Here: mean 0, std 1, size (3, 3, 3); corr2d_multi_in_out_1x1 unpacks
# this shape as (c_i, h, w).
X = torch.normal(0, 1, (3, 3, 3))
X

In [ ]:
# Random 1x1 kernel of shape (2, 3, 1, 1). corr2d_multi_in_out_1x1 takes
# the leading 2 as c_o and reshapes the kernel to (c_o, c_i) = (2, 3).
K = torch.normal(0, 1, (2, 3, 1, 1))
K

In [ ]:
# 1x1 convolution computed as a single matrix multiplication.
Y1 = corr2d_multi_in_out_1x1(X=X, K=K)
Y1

In [ ]:
# The same convolution computed with the general multi-input,
# multi-output cross-correlation, for comparison with Y1.
Y2 = corr2d_multi_in_out(X=X, K=K)
Y2

In [ ]:
# Both implementations should agree up to floating-point round-off:
# the summed absolute difference must stay below 1e-6.
assert (Y1 - Y2).abs().sum().item() < 1e-6