## Multiple Input and Output Channels

## Multiple Input Channels

In [5]:
import torch
def corr2d(x, k):
    """Compute the 2D cross-correlation of matrix ``x`` with kernel ``k``.

    The kernel slides over ``x`` with stride 1 and no padding; each output
    element is the sum of the elementwise product of ``k`` and the window of
    ``x`` it covers. The output is allocated with ``torch.zeros`` and no
    explicit dtype, so it uses torch's default floating dtype regardless of
    the input dtype.

    Args:
        x: 2D input tensor of shape (height, width).
        k: 2D kernel tensor of shape (kh, kw).

    Returns:
        Tensor of shape (height - kh + 1, width - kw + 1).
    """
    kh, kw = k.shape
    out_h = x.shape[0] - kh + 1
    out_w = x.shape[1] - kw + 1
    out = torch.zeros((out_h, out_w))
    # Walk the output in row-major order with a single flat index.
    for flat in range(out.numel()):
        r, c = divmod(flat, out_w)
        window = x[r:r + kh, c:c + kw]
        out[r, c] = (window * k).sum()
    return out

def corr2d_multi_in(X, K):
    """Cross-correlate a multi-channel input with a multi-channel kernel.

    Each input channel ``X[c]`` is cross-correlated with its matching kernel
    channel ``K[c]`` via ``corr2d``. The per-channel results are then summed
    into a single output channel.

    Args:
        X: input tensor of shape (channels, height, width).
        K: kernel tensor of shape (channels, kh, kw).

    Returns:
        Tensor of shape (height - kh + 1, width - kw + 1).

    Raises:
        ValueError: if X and K have a different number of channels.
    """
    # zip() would silently drop the surplus channels on a mismatch and
    # return a wrong result, so check the channel counts explicitly.
    if len(X) != len(K):
        raise ValueError(
            f"X has {len(X)} input channels but K has {len(K)} kernel channels"
        )
    # Traverse the channel (0th) dimension of X and K together, then add up.
    return sum(corr2d(x, k) for x, k in zip(X, K))

We can construct the input array X and the kernel array K from the diagram above to validate the output of the multi-input-channel cross-correlation.

In [6]:
# Input with 2 channels; each channel is a 3x3 matrix.
X = torch.tensor([
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]],
    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
])
# Kernel with one 2x2 slice per input channel.
K = torch.tensor([
    [[0, 1], [2, 3]],
    [[1, 2], [3, 4]],
])
corr2d_multi_in(X, K)

tensor([[ 56.,  72.],
        [104., 120.]])

## Multiple Output Channels

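For multiple output channels, we use one multi-input-channel kernel per output channel. We compute a multi-input cross-correlation with each kernel and stack the results along a new leading (output-channel) dimension.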

In [7]:
def corr2d_multi_in_out(X, K):
    """Cross-correlate ``X`` with a kernel that has several output channels.

    ``K`` is iterated along its 0th dimension. Each slice is a full
    multi-input-channel kernel and is cross-correlated with ``X`` via
    ``corr2d_multi_in``. The per-output-channel results are stacked along a
    new leading dimension.

    Args:
        X: input tensor of shape (in_channels, height, width).
        K: kernel tensor of shape (out_channels, in_channels, kh, kw).

    Returns:
        Tensor of shape (out_channels, height - kh + 1, width - kw + 1).
    """
    outputs = tuple(corr2d_multi_in(X, kernel) for kernel in K)
    return torch.stack(outputs)

We construct a convolution kernel with 3 output channels by stacking the kernel array K with K + 1 (one added to each element of K) and K + 2.

In [8]:
# Output channel i uses the original kernel shifted by i (K, K + 1, K + 2).
K = torch.stack([K + offset for offset in range(3)])
K.shape

torch.Size([3, 2, 2, 2])

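Now we perform cross-correlation with both multiple input channels and multiple output channels. The first output channel matches the earlier result of `corr2d_multi_in`.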

In [9]:
# Show the input and kernel shapes (one per line), then the 3-channel output.
print(X.shape, K.shape, sep="\n")
print(corr2d_multi_in_out(X, K))

torch.Size([2, 3, 3])
torch.Size([3, 2, 2, 2])
tensor([[[ 56.,  72.],
         [104., 120.]],

        [[ 76., 100.],
         [148., 172.]],

        [[ 96., 128.],
         [192., 224.]]])


## 1x1 Convolutions

In [10]:
def corr2d_multi_in_out_1x1(X, K):
    """Apply a 1x1 multi-channel convolution as a single matrix product.

    Every pixel's channel vector is mapped independently to the output
    channels. The input is flattened to (in_channels, height * width) and the
    kernel to (out_channels, in_channels); one matrix multiplication (the
    same operation as a fully connected layer) then produces all outputs.

    Args:
        X: input tensor of shape (in_channels, height, width).
        K: kernel whose elements reshape to (out_channels, in_channels),
            e.g. shape (out_channels, in_channels, 1, 1).

    Returns:
        Tensor of shape (out_channels, height, width).
    """
    num_in, height, width = X.shape
    num_out = K.shape[0]
    # Column j of `pixels` is the channel vector of pixel j.
    pixels = X.reshape((num_in, height * width))
    weights = K.reshape((num_out, num_in))
    return (weights @ pixels).reshape((num_out, height, width))

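When the kernel is 1x1, this matrix-multiplication implementation is equivalent to the cross-correlation implementation `corr2d_multi_in_out`. We verify this numerically on random data below.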

In [11]:

# Random 3-channel 3x3 input and a 1x1 kernel mapping 3 -> 2 channels.
X = torch.randn(size=(3, 3, 3))
K = torch.randn(size=(2, 3, 1, 1))
Y1 = corr2d_multi_in_out_1x1(X, K)
Y2 = corr2d_multi_in_out(X, K)
# The two implementations accumulate float32 products differently
# (Python sum over channels vs. torch.matmul) and may round differently.
# A fixed absolute bound on the norm of random data can therefore flake;
# compare elementwise with a relative + absolute tolerance instead.
torch.allclose(Y1, Y2, atol=1e-6)

True