# 卷积层实现

In [2]:
from torch import nn
import torch


class Conv2D(nn.Module):
    """
    Naive 2-D convolution layer implemented with an explicit sliding window.

    For every output position the layer multiplies the current input window by
    each filter and sums over (input channel, kernel row, kernel column). The
    kernel is not flipped, so this is the cross-correlation that deep-learning
    libraries call "convolution". It has no bias, and dilation and groups are
    fixed at 1. Intended for illustration; it is far slower than ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        """
        Store the layer hyper-parameters and create the weight parameter.

        Args:
            in_channels: number of channels the input tensor must have.
            out_channels: number of filters, i.e. channels of the output.
            kernel_size: int or (kernel_H, kernel_W).
            stride: int or (stride_H, stride_W); defaults to 1.
            padding: int or (padding_H, padding_W) zeros added on each side
                of the spatial dimensions; defaults to 0.
        """
        # nn.Module.__init__ must run before any nn.Parameter is assigned:
        # it creates the registries that parameter assignment writes into.
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = self._pair(kernel_size)
        self.stride = self._pair(stride)
        self.padding = self._pair(padding)

        # Layout (out_channels, in_channels, kernel_H, kernel_W), standard-normal init.
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, *self.kernel_size)
        )

    @staticmethod
    def _pair(value):
        """Return *value* as a tuple: tuples/lists are converted, scalars duplicated."""
        return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

    def forward(self, X):
        """
        Apply the convolution.

        Args:
            X: input of shape (batch_size, in_channels, height, width).

        Returns:
            Tensor of shape (batch_size, out_channels, out_H, out_W), where
            out_H = (height + 2 * padding_H - kernel_H) // stride_H + 1 (and
            likewise for the width). Its dtype is the promotion of X's dtype
            and the weight's dtype.

        Raises:
            ValueError: if X is not 4-D, if its channel count differs from
                ``in_channels``, or if the padded input is smaller than the kernel.
        """
        if X.dim() != 4:
            raise ValueError(
                f"expected a 4-D input (batch, channels, height, width), got shape {tuple(X.shape)}"
            )
        batch, in_C, in_H, in_W = X.shape
        if in_C != self.in_channels:
            # Checked explicitly: a 1-channel input would otherwise broadcast
            # silently against a multi-channel weight.
            raise ValueError(f"expected {self.in_channels} input channels, got {in_C}")

        kernel_H, kernel_W = self.kernel_size
        stride_H, stride_W = self.stride
        padding_H, padding_W = self.padding

        padded_H = in_H + 2 * padding_H
        padded_W = in_W + 2 * padding_W
        if padded_H < kernel_H or padded_W < kernel_W:
            raise ValueError(
                f"kernel size {self.kernel_size} is larger than the padded input ({padded_H}, {padded_W})"
            )

        # Zero padding; new_zeros keeps X's dtype and device.
        padded = X.new_zeros(batch, in_C, padded_H, padded_W)
        padded[:, :, padding_H:padding_H + in_H, padding_W:padding_W + in_W] = X

        out_H = (padded_H - kernel_H) // stride_H + 1
        out_W = (padded_W - kernel_W) // stride_W + 1
        out = torch.zeros(
            batch,
            self.out_channels,
            out_H,
            out_W,
            dtype=torch.promote_types(X.dtype, self.weight.dtype),
            device=X.device,
        )

        # (out_C, in_C, kH, kW) -> (1, out_C, in_C, kH, kW): broadcasts against a
        # window of shape (batch, 1, in_C, kH, kW), covering every batch element
        # and every filter in one product.
        weight = self.weight.unsqueeze(0)
        for i in range(out_H):
            # Windows advance by the stride, not by the kernel size.
            h_start = i * stride_H
            h_end = h_start + kernel_H
            for j in range(out_W):
                w_start = j * stride_W
                w_end = w_start + kernel_W
                window = padded[:, :, h_start:h_end, w_start:w_end].unsqueeze(1)
                # Sum over (in_C, kH, kW) -> shape (batch, out_C).
                out[:, :, i, j] = (window * weight).sum(dim=(2, 3, 4))

        return out