In [1]:
import torch
from torch import nn


In [2]:
def corr2d(X, K):
    """Compute the 2-D cross-correlation of ``X`` with kernel ``K`` ("valid" mode).

    Args:
        X: 2-D input tensor of shape (H, W).
        K: 2-D kernel tensor of shape (h, w); requires h <= H and w <= W.

    Returns:
        Tensor of shape (H - h + 1, W - w + 1), allocated on ``X``'s device.
        Integer inputs yield the default float dtype (as before), while
        floating inputs wider than the default (e.g. float64) keep their
        precision instead of being truncated.
    """
    h, w = K.shape
    # Promote to at least the default float dtype: int/float32 inputs still give
    # float32, but float64 inputs are no longer silently downcast.
    out_dtype = torch.promote_types(torch.result_type(X, K), torch.get_default_dtype())
    # Allocate on the input's device so GPU inputs are not written into a CPU buffer.
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1), dtype=out_dtype, device=X.device)
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # Multiply the h x w window element-wise by K and reduce to a scalar.
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y

In [3]:
# 3x3 int64 input [[0, 1, 2], [3, 4, 5], [6, 7, 8]], laid out row-major from a range.
X = torch.arange(9).reshape(3, 3)

In [4]:
# 2x2 int64 kernel [[0, 1], [2, 3]], laid out row-major from a range.
K = torch.arange(4).reshape(2, 2)

In [5]:
# Slide the 2x2 kernel K over the 3x3 input X; the result Y is 2x2.
Y = corr2d(X, K)

In [6]:
# Display the cross-correlation result.
Y

tensor([[19., 25.],
        [37., 43.]])

In [7]:
class Conv2D(nn.Module):
    """Minimal single-channel 2-D convolution layer built on ``corr2d``.

    Args:
        kernel_size: ``(h, w)`` pair giving the kernel shape, or a single int
            ``k`` as shorthand for a square ``(k, k)`` kernel.

    Raises:
        ValueError: if ``kernel_size`` does not describe exactly two dimensions.
    """

    def __init__(self, kernel_size):
        super().__init__()
        # Accept an int like nn.Conv2d does; torch.rand(k) would otherwise build a
        # 1-D weight that corr2d cannot unpack into (h, w).
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        kernel_size = tuple(kernel_size)
        if len(kernel_size) != 2:
            raise ValueError(f"kernel_size must be an int or an (h, w) pair, got {kernel_size}")
        # Kernel initialised uniformly in [0, 1); the scalar bias starts at 0.
        self.weight = nn.Parameter(torch.rand(kernel_size))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, X):
        """Cross-correlate the 2-D input ``X`` with the kernel and add the bias."""
        return corr2d(X, self.weight) + self.bias

In [9]:
# Seed the RNG so the randomly initialised kernel (and the output below) is reproducible.
torch.manual_seed(0)
layer = Conv2D((3, 3))

In [10]:
# Forward pass: a 3x3 kernel on the 3x3 input leaves a single (1x1) output.
layer(X)

tensor([[17.4408]], grad_fn=<AddBackward0>)

In [11]:
def corr2d_multi_in(X, K):
    """Multi-input-channel cross-correlation.

    Cross-correlates each input channel of ``X`` with the matching channel of
    ``K`` via ``corr2d`` and sums the per-channel maps into one 2-D output.

    Args:
        X: input, expected shape (C, H, W).
        K: kernel, expected shape (C, h, w); must have as many channels as X.

    Returns:
        Tensor of shape (H - h + 1, W - w + 1).

    Raises:
        ValueError: if X and K have different channel counts. Without this
            check ``zip`` would silently drop the extra channels.
    """
    if len(X) != len(K):
        raise ValueError(
            f"channel mismatch: input has {len(X)} channels, kernel has {len(K)}"
        )
    return sum(corr2d(x, k) for x, k in zip(X, K))

In [16]:
# Two-channel 3x3 int64 input: channel 0 holds 0..8, channel 1 holds 1..9.
# Note: this rebinds X, replacing the single-channel input used above.
X = torch.stack((
    torch.arange(0, 9).reshape(3, 3),
    torch.arange(1, 10).reshape(3, 3),
))

In [17]:
# Expect (channels, height, width) = (2, 3, 3).
X.shape

torch.Size([2, 3, 3])

In [18]:
# Two-channel 2x2 int64 kernel: channel 0 is [[0, 1], [2, 3]], channel 1 is [[1, 2], [3, 4]].
K = torch.stack((
    torch.arange(0, 4).reshape(2, 2),
    torch.arange(1, 5).reshape(2, 2),
))

In [19]:
# Expect (channels, kernel_height, kernel_width) = (2, 2, 2).
K.shape

torch.Size([2, 2, 2])

In [20]:
# Per-channel cross-correlations summed into a single 2x2 output map.
corr2d_multi_in(X,K)

tensor([[ 56.,  72.],
        [104., 120.]])

In [21]:
# Look up torch.stack, used below to add a leading output-channel dimension.
# NOTE(review): interactive lookup; consider dropping this cell from the final notebook.
help(torch.stack)

Help on built-in function stack in module torch:

stack(...)
    stack(tensors, dim=0, *, out=None) -> Tensor
    
    Concatenates a sequence of tensors along a new dimension.
    
    All tensors need to be of the same size.
    
    Arguments:
        tensors (sequence of Tensors): sequence of tensors to concatenate
        dim (int): dimension to insert. Has to be between 0 and the number
            of dimensions of concatenated tensors (inclusive)
    
    Keyword args:
        out (Tensor, optional): the output tensor.



In [22]:
# Build a 4-D kernel (out_channels, in_channels, h, w) = (3, 2, 2, 2) by stacking
# K, K + 1 and K + 2 along a new leading output-channel axis.
# This cell rebinds K from its own value: a second run would stack the already
# 4-D K into a 5-D tensor, so refuse to run unless K is still the 3-D kernel.
assert K.dim() == 3, "K is already stacked; re-run the cell that defines the 3-D K first"
K = torch.stack((K, K + 1, K + 2), dim=0)

In [24]:
# Expect (out_channels, in_channels, kernel_height, kernel_width) = (3, 2, 2, 2).
K.shape

torch.Size([3, 2, 2, 2])

In [25]:
def corr2d_multi_in_out(X, K):
    """Multi-input, multi-output-channel cross-correlation.

    Each slice ``K[o]`` along the first axis is a multi-input-channel kernel.
    It is applied to the whole of ``X`` with ``corr2d_multi_in``, and the
    resulting 2-D maps are stacked along a new leading output-channel axis.
    """
    per_output_maps = []
    for kernel_o in K:
        per_output_maps.append(corr2d_multi_in(X, kernel_o))
    return torch.stack(per_output_maps, dim=0)

In [26]:
# One 2x2 map per output channel, stacked to shape (3, 2, 2).
corr2d_multi_in_out(X,K)

tensor([[[ 56.,  72.],
         [104., 120.]],

        [[ 76., 100.],
         [148., 172.]],

        [[ 96., 128.],
         [192., 224.]]])