# Conv2D 的实现

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## 卷积相关的输入参数

In [2]:
# Hyper-parameters shared by every Conv2D implementation in this section.
batch_size = 1
input_w, input_h = 10, 10
kernel_size = 3
in_channel, out_channel = 2, 4
padding = 1
stride = 2
dilation = 2
groups = 1

# NOTE(review): the spatial size is passed as (input_w, input_h), i.e. the last
# two dims are (W, H) rather than the usual (H, W); harmless here since both are 10.
input_tensor = torch.randn(batch_size, in_channel, input_w, input_h)
# Weight in F.conv2d layout: (C_out, C_in // groups, K, K).
kernel = torch.randn(out_channel, in_channel // groups, kernel_size, kernel_size)
bias = torch.randn(out_channel)

## Pytorch 算子调用

In [3]:
# Reference result from PyTorch's built-in operator; the hand-written
# implementations below are compared against `output_tensor_torch`.
output_tensor_torch = F.conv2d(
    input_tensor,
    kernel,
    bias,
    stride=stride,
    padding=padding,
    dilation=dilation,
    groups=groups,
)

In [4]:
# Display the PyTorch reference result (rendered below as the cell output).
output_tensor_torch

tensor([[[[ 5.5165,  1.5675,  4.4946,  3.6341],
          [-1.3439,  2.2366,  9.1844,  4.2583],
          [ 1.1333,  1.0637, 11.7791,  8.8648],
          [ 3.2420,  3.1308,  0.2205,  4.3939]],

         [[ 0.7311,  2.7015,  1.6318,  1.2317],
          [-0.4422,  3.3261, -6.8021,  1.3115],
          [-1.0562, -3.0663, -1.3042, -0.7498],
          [ 0.7304, -1.6261,  1.4370, -0.0572]],

         [[ 1.9504,  4.7555,  1.8208,  0.0887],
          [-1.7191, -3.8353,  4.9241,  3.5999],
          [-0.2118,  4.1998,  3.5847,  5.6590],
          [ 4.1347,  3.7212,  3.9359, -2.5025]],

         [[ 6.0964,  9.5853,  5.6209,  0.0882],
          [-3.4786, -2.2103,  4.6434,  2.4786],
          [ 0.5002,  0.0610,  0.4057,  9.4221],
          [ 4.6218,  1.4769, -6.3351, -7.0424]]]])

## 单通道 Conv 实现

In [5]:
def single_channel_conv(
    input_tensor, kernel, bias=None, stride=1, padding=0, dilation=1
):
    """Naive 2-D cross-correlation of one input channel with one kernel.

    Computes the same thing as ``F.conv2d`` for a single (H, W) plane and a
    single (K_h, K_w) kernel, using an explicit sliding window.

    Args:
        input_tensor: 2-D tensor of shape (H, W).
        kernel: 2-D tensor of shape (K_h, K_w).
        bias: optional scalar / one-element tensor added to every output
            position; treated as 0 when ``None``.
        stride: step of the sliding window (same for both axes).
        padding: number of zero rows/columns added on every side.
        dilation: spacing between kernel taps.

    Returns:
        2-D tensor of shape (H_out, W_out), with the dtype and device of
        ``input_tensor``.
    """
    assert input_tensor.ndim == 2 and kernel.ndim == 2

    # Zero-pad all four borders.
    if padding > 0:
        input_tensor = F.pad(input_tensor, (padding, padding, padding, padding))
    if bias is None:
        bias = input_tensor.new_zeros((1,))

    input_h, input_w = input_tensor.shape
    kernel_h, kernel_w = kernel.shape

    # Extent covered by the dilated kernel: (k - 1) * d + 1
    dilated_win_h = (kernel_h - 1) * dilation + 1
    dilated_win_w = (kernel_w - 1) * dilation + 1

    output_h = (input_h - dilated_win_h) // stride + 1
    output_w = (input_w - dilated_win_w) // stride + 1

    # Allocate with the input's dtype/device; a plain torch.zeros(...) would be
    # float32 on CPU and silently truncate float64 results.
    output_tensor = input_tensor.new_zeros((output_h, output_w))

    for i in range(output_h):
        for j in range(output_w):
            # Taps spaced `dilation` apart inside the window: exactly
            # kernel_h x kernel_w values, aligned with `kernel`.
            input_slice = input_tensor[
                i * stride : i * stride + dilated_win_h : dilation,
                j * stride : j * stride + dilated_win_w : dilation,
            ]
            output_tensor[i, j] = torch.sum(input_slice * kernel) + bias

    return output_tensor

## 多通道 Conv 实现

In [6]:
# 2-D convolution over multiple input channels
def multi_input_channel_conv(
    input_tensor, kernels, bias=None, stride=1, padding=0, dilation=1
):
    """Convolve a (C_in, H, W) input with a (C_in, K_h, K_w) stack of kernels.

    Channel c of the input is convolved with kernel c via
    ``single_channel_conv`` (no bias). The per-channel maps are summed into
    one 2-D map, and ``bias`` (if given) is added to it once.
    """
    assert input_tensor.ndim == 3 and kernels.ndim == 3

    per_channel_maps = []
    for channel_plane, channel_kernel in zip(input_tensor, kernels):
        response = single_channel_conv(
            channel_plane, channel_kernel, None, stride, padding, dilation
        )
        per_channel_maps.append(response)

    summed = torch.stack(per_channel_maps).sum(dim=0)

    # The bias is applied after summation so it is counted only once.
    if bias is not None:
        summed += bias

    return summed


# 2-D convolution producing multiple output channels (with optional grouping)
def multi_output_channel_conv(
    input_tensor, kernels, bias=None, stride=1, padding=0, dilation=1, groups=1
):
    """Convolve one (C_in, H, W) sample with (C_out, C_in // groups, K, K) kernels.

    Each output channel is computed by ``multi_input_channel_conv`` over the
    slice of input channels that belongs to its group.

    Args:
        input_tensor: 3-D tensor (C_in, H, W).
        kernels: 4-D tensor (C_out, C_in // groups, K_h, K_w).
        bias: optional per-output-channel bias of length C_out; ``None`` means
            no bias.
        stride, padding, dilation: forwarded to the per-channel convolution.
        groups: number of channel groups; must divide both C_in and C_out.

    Returns:
        3-D tensor (C_out, H_out, W_out).
    """
    assert input_tensor.ndim == 3 and kernels.ndim == 4

    out_channels = kernels.size(0)
    in_channels = input_tensor.size(0)

    # Both channel counts must split evenly into the groups.
    assert out_channels % groups == 0 and in_channels % groups == 0

    # Number of input / output channels per group.
    input_group_size = in_channels // groups
    output_group_size = out_channels // groups

    # Without this check, a mismatched kernel would be silently truncated by
    # the zip() inside the per-channel convolution.
    assert kernels.size(1) == input_group_size

    # Input-channel range that feeds the given output channel.
    def get_input_group(index):
        group_start = (index // output_group_size) * input_group_size
        return slice(group_start, group_start + input_group_size)

    # `bias` defaults to None; zip(kernels, None) would raise a TypeError.
    bias_per_channel = [None] * out_channels if bias is None else bias

    output_tensor = torch.stack(
        [
            multi_input_channel_conv(
                input_tensor[get_input_group(output_index)],  # this group's inputs
                kernel,  # kernel stack of this output channel
                bias_channel,  # bias of this output channel (or None)
                stride,
                padding,
                dilation,
            )
            for output_index, (kernel, bias_channel) in enumerate(
                zip(kernels, bias_per_channel)
            )
        ]
    )

    return output_tensor


# Batched 2-D convolution with multiple output channels
def batch_multi_channel_conv(
    input_batch, kernels, bias=None, stride=1, padding=0, dilation=1, groups=1
):
    """Apply ``multi_output_channel_conv`` to every sample of an NCHW batch.

    Args:
        input_batch: 4-D tensor (N, C_in, H, W).
        kernels: 4-D tensor (C_out, C_in // groups, K_h, K_w).
        bias, stride, padding, dilation, groups: forwarded unchanged to
            ``multi_output_channel_conv``.

    Returns:
        4-D tensor (N, C_out, H_out, W_out): the per-sample results stacked
        along a new leading dimension.
    """
    assert input_batch.ndim == 4 and kernels.ndim == 4

    per_sample_outputs = []
    for sample in input_batch:
        per_sample_outputs.append(
            multi_output_channel_conv(
                sample, kernels, bias, stride, padding, dilation, groups
            )
        )

    return torch.stack(per_sample_outputs)

In [7]:
# Run the loop-based implementation and compare it with the PyTorch reference.
output_tensor1 = batch_multi_channel_conv(
    input_tensor,
    kernel,
    bias,
    stride=stride,
    padding=padding,
    dilation=dilation,
    groups=groups,
)
print(output_tensor1)
# NOTE(review): a mean of signed differences can cancel out; torch.allclose
# (used in later sections) is a stricter check.
print("mean diff: ", torch.mean(output_tensor1 - output_tensor_torch))

tensor([[[[ 5.5165,  1.5675,  4.4946,  3.6341],
          [-1.3439,  2.2366,  9.1844,  4.2583],
          [ 1.1333,  1.0637, 11.7791,  8.8648],
          [ 3.2420,  3.1308,  0.2205,  4.3939]],

         [[ 0.7311,  2.7015,  1.6318,  1.2317],
          [-0.4422,  3.3261, -6.8021,  1.3115],
          [-1.0562, -3.0663, -1.3042, -0.7498],
          [ 0.7304, -1.6261,  1.4370, -0.0572]],

         [[ 1.9504,  4.7555,  1.8208,  0.0887],
          [-1.7191, -3.8353,  4.9241,  3.5999],
          [-0.2118,  4.1998,  3.5847,  5.6590],
          [ 4.1347,  3.7212,  3.9359, -2.5025]],

         [[ 6.0964,  9.5853,  5.6209,  0.0882],
          [-3.4786, -2.2103,  4.6434,  2.4786],
          [ 0.5002,  0.0610,  0.4057,  9.4221],
          [ 4.6218,  1.4769, -6.3351, -7.0424]]]])
mean diff:  tensor(-1.0245e-08)


## 多通道的另一种实现

将多通道输入看成在每个 `(i,j)` 位置上都是一个 $\mathbb{R}^{C_{in}}$ 的向量；将多通道的权重看成在每个位置上都是一个 $\mathbb{R}^{C_{out}\times C_{in}}$ 的二维矩阵。那么在滑动窗口内，原来的点积就变成了矩阵乘法，每个位置计算得到一个 $\mathbb{R}^{C_{out}}$ 的向量，然后再将窗口内所有的向量相加。

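该实现方案在函数内部将输入 Tensor 转换为 $N\times H\times W\times C_{in}$ 布局，将 weight 转换为 $K\times K\times C_{in}\times C_{out}$ 布局（函数本身仍接收 PyTorch 默认的 $N\times C_{in}\times H\times W$ 输入与 $C_{out}\times C_{in}\times K\times K$ 权重）。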

该实现暂未支持分组卷积。


In [8]:
def multi_channel_conv2d_flat(
    input_tensor, kernel, bias=None, stride=1, padding=0, dilation=1
):
    """Multi-channel 2-D convolution that treats every pixel as a C_in vector.

    At each window position, every tap multiplies a C_in vector by a
    (C_in, C_out) matrix. All taps in the window are contracted in one einsum.

    Args:
        input_tensor: (N, C_in, H, W) tensor.
        kernel: (C_out, C_in, K_h, K_w) tensor (groups are not supported).
        bias: optional (C_out,) tensor; ``None`` means no bias.
        stride, padding, dilation: as in ``F.conv2d`` (ints, same on both axes).

    Returns:
        (N, C_out, H_out, W_out) tensor with the dtype/device of the input.
    """
    assert input_tensor.ndim == 4 and kernel.ndim == 4

    # (N, C, H, W) -> (N, H, W, C)
    input_tensor = torch.permute(input_tensor, (0, 2, 3, 1))
    # (C_out, C_in, K_h, K_w) -> (K_h, K_w, C_in, C_out)
    kernel = torch.permute(kernel, (2, 3, 1, 0))

    if padding > 0:
        # F.pad lists pads from the last dim backwards: C, W, H, N.
        input_tensor = F.pad(
            input_tensor, (0, 0, padding, padding, padding, padding, 0, 0)
        )

    if bias is None:
        bias = input_tensor.new_zeros((1,))

    batch_size, input_h, input_w, _ = input_tensor.shape
    kernel_h, kernel_w, _, out_channel = kernel.shape

    # Extent covered by the dilated kernel: (k - 1) * d + 1
    dilated_win_h = (kernel_h - 1) * dilation + 1
    dilated_win_w = (kernel_w - 1) * dilation + 1

    output_h = (input_h - dilated_win_h) // stride + 1
    output_w = (input_w - dilated_win_w) // stride + 1

    # Keep the input's dtype/device instead of the default float32 CPU tensor.
    output_tensor = input_tensor.new_zeros(
        (batch_size, output_h, output_w, out_channel)
    )

    for i in range(output_h):
        for j in range(output_w):
            input_slice = input_tensor[
                :,
                i * stride : i * stride + dilated_win_h : dilation,
                j * stride : j * stride + dilated_win_w : dilation,
            ]
            # (N, K_h, K_w, C_in) x (K_h, K_w, C_in, C_out) -> (N, C_out):
            # one contraction over window rows, columns and input channels.
            output_tensor[:, i, j] = torch.einsum("bhwc,hwco->bo", input_slice, kernel)

    # Bias broadcasts over the trailing C_out axis of the NHWC result.
    output_tensor += bias
    return torch.permute(output_tensor, (0, 3, 1, 2))

In [9]:
# Compare the vector/matrix-per-pixel implementation with the reference.
output_tensor2 = multi_channel_conv2d_flat(
    input_tensor, kernel, bias, stride=stride, padding=padding, dilation=dilation
)
print(output_tensor2)
print("mean diff: ", torch.mean(output_tensor2 - output_tensor_torch))

tensor([[[[ 5.5165,  1.5675,  4.4946,  3.6341],
          [-1.3439,  2.2366,  9.1844,  4.2583],
          [ 1.1333,  1.0637, 11.7791,  8.8648],
          [ 3.2420,  3.1308,  0.2205,  4.3939]],

         [[ 0.7311,  2.7015,  1.6318,  1.2317],
          [-0.4422,  3.3261, -6.8021,  1.3115],
          [-1.0562, -3.0663, -1.3042, -0.7498],
          [ 0.7304, -1.6261,  1.4370, -0.0572]],

         [[ 1.9504,  4.7555,  1.8208,  0.0887],
          [-1.7191, -3.8353,  4.9241,  3.5999],
          [-0.2118,  4.1998,  3.5847,  5.6590],
          [ 4.1347,  3.7212,  3.9359, -2.5025]],

         [[ 6.0964,  9.5853,  5.6209,  0.0882],
          [-3.4786, -2.2103,  4.6434,  2.4786],
          [ 0.5002,  0.0610,  0.4057,  9.4221],
          [ 4.6218,  1.4769, -6.3351, -7.0424]]]])
mean diff:  tensor(-2.0489e-08)


## 将卷积转换为1x1卷积核的矩阵乘法

依次计算卷积核每个平面位置 $(i,j)$ 上的 $1\times 1$ 卷积核在每个滑动窗口上与输入 Tensor 的乘法响应，每个响应都是一个与输出 Tensor 尺寸相同的 Tensor；将所有 $K\times K$ 个响应累加（再加上偏置）即得到最终输出。

In [10]:
def conv_matrix_multiplication(
    input_tensor, kernel, bias=None, stride=1, padding=0, dilation=1, groups=1
):
    """2-D convolution as a sum of K_h * K_w 1x1 convolutions (matrix products).

    For each kernel tap (i, j), the strided sub-grid of the input touched by
    that tap, of shape (N, H_out, W_out, C_in), is multiplied by the tap's
    (C_in, C_out) weight matrix. The K_h * K_w partial results are
    accumulated.

    Args:
        input_tensor: (N, C_in, H, W) tensor.
        kernel: (C_out, C_in, K_h, K_w) tensor.
        bias: optional (C_out,) tensor; ``None`` means no bias.
        stride, padding, dilation: as in ``F.conv2d`` (ints).
        groups: only 1 is supported.

    Returns:
        (N, C_out, H_out, W_out) tensor with the input's dtype/device.

    Raises:
        NotImplementedError: if ``groups > 1``.
    """
    assert input_tensor.ndim == 4 and kernel.ndim == 4
    if groups > 1:
        # NotImplemented is a constant, not an exception: `raise NotImplemented`
        # would itself fail with a TypeError.
        raise NotImplementedError(
            "conv_matrix_multiplication does not support groups > 1"
        )
    if padding > 0:
        input_tensor = F.pad(input_tensor, (padding, padding, padding, padding))
    if bias is None:
        bias = input_tensor.new_zeros((kernel.size(0),))

    # NCHW -> NHWC so the channel axis is last and can be matrix-multiplied.
    input_tensor = torch.permute(input_tensor, (0, 2, 3, 1)).contiguous()
    batch_size, input_h, input_w, _ = input_tensor.shape
    # (C_out, C_in, K_h, K_w) -> (K_h, K_w, C_in, C_out)
    kernel = torch.permute(kernel, (2, 3, 1, 0))
    kernel_h, kernel_w, _, out_channel = kernel.shape

    # Extent covered by the dilated kernel: (k - 1) * d + 1
    dilated_win_h = (kernel_h - 1) * dilation + 1
    dilated_win_w = (kernel_w - 1) * dilation + 1

    output_h = (input_h - dilated_win_h) // stride + 1
    output_w = (input_w - dilated_win_w) // stride + 1

    output_tensor = input_tensor.new_zeros(
        (batch_size, output_h, output_w, out_channel)
    )
    for i in range(kernel_h):
        for j in range(kernel_w):
            # Pixels hit by tap (i, j) in every window: start at the tap's
            # dilated offset and step by `stride` -> exactly H_out x W_out.
            input_slice = input_tensor[
                :,
                i * dilation : i * dilation + output_h * stride : stride,
                j * dilation : j * dilation + output_w * stride : stride,
            ]
            # (N, H_out, W_out, C_in) @ (C_in, C_out) -> (N, H_out, W_out, C_out)
            output_tensor += input_slice @ kernel[i, j]
    output_tensor += bias
    return output_tensor.permute((0, 3, 1, 2))

In [11]:
# Compare the "sum of 1x1 convolutions" implementation with the reference.
output_tensor3 = conv_matrix_multiplication(
    input_tensor, kernel, bias, stride=stride, padding=padding, dilation=dilation
)
print(output_tensor3)
print("mean diff: ", torch.mean(output_tensor3 - output_tensor_torch))

tensor([[[[ 5.5165,  1.5675,  4.4946,  3.6341],
          [-1.3439,  2.2366,  9.1844,  4.2583],
          [ 1.1333,  1.0637, 11.7791,  8.8648],
          [ 3.2420,  3.1308,  0.2205,  4.3939]],

         [[ 0.7311,  2.7015,  1.6318,  1.2317],
          [-0.4422,  3.3261, -6.8021,  1.3115],
          [-1.0562, -3.0663, -1.3042, -0.7498],
          [ 0.7304, -1.6261,  1.4370, -0.0572]],

         [[ 1.9504,  4.7555,  1.8208,  0.0887],
          [-1.7191, -3.8353,  4.9241,  3.5999],
          [-0.2118,  4.1998,  3.5847,  5.6590],
          [ 4.1347,  3.7212,  3.9359, -2.5025]],

         [[ 6.0964,  9.5853,  5.6209,  0.0882],
          [-3.4786, -2.2103,  4.6434,  2.4786],
          [ 0.5002,  0.0610,  0.4057,  9.4221],
          [ 4.6218,  1.4769, -6.3351, -7.0424]]]])
mean diff:  tensor(-2.0489e-08)


## 卷积的 im2col 的实现

暂不支持分组卷积的功能。

In [12]:
def conv_im2col(
    input_tensor, kernel, bias=None, stride=1, padding=0, dilation=1, groups=1
):
    """2-D convolution via im2col: gather each window into a row, then one GEMM.

    Args:
        input_tensor: (N, C_in, H, W) tensor.
        kernel: (C_out, C_in, K_h, K_w) tensor.
        bias: optional (C_out,) tensor; ``None`` means no bias.
        stride, padding, dilation: as in ``F.conv2d`` (ints).
        groups: only 1 is supported.

    Returns:
        (N, C_out, H_out, W_out) tensor.

    Raises:
        NotImplementedError: if ``groups > 1``.
    """
    assert input_tensor.ndim == 4 and kernel.ndim == 4
    if groups > 1:
        # NotImplemented is a constant, not an exception: `raise NotImplemented`
        # would itself fail with a TypeError.
        raise NotImplementedError("conv_im2col does not support groups > 1")
    if padding > 0:
        input_tensor = F.pad(input_tensor, (padding, padding, padding, padding))
    if bias is None:
        bias = input_tensor.new_zeros((kernel.size(0),))

    # NCHW -> NHWC; contiguous() so the stride arithmetic below is valid.
    input_tensor = torch.permute(input_tensor, (0, 2, 3, 1)).contiguous()
    batch_size, input_h, input_w, in_channel = input_tensor.shape
    out_channel, _, kernel_h, kernel_w = kernel.shape
    output_h = (input_h - (kernel_h - 1) * dilation - 1) // stride + 1
    output_w = (input_w - (kernel_w - 1) * dilation - 1) // stride + 1

    n_stride, h_stride, w_stride, c_stride = input_tensor.stride()
    inner_dim = kernel_h * kernel_w * in_channel
    # 6-D view (N, H_out, W_out, K_h, K_w, C_in) over the same storage:
    # output positions advance by `stride` pixels, kernel taps by `dilation`.
    patches = input_tensor.as_strided(
        (batch_size, output_h, output_w, kernel_h, kernel_w, in_channel),
        (
            n_stride,
            h_stride * stride,
            w_stride * stride,
            h_stride * dilation,
            w_stride * dilation,
            c_stride,
        ),
    )
    # One row per output pixel, columns ordered (K_h, K_w, C_in).
    cols = patches.reshape(-1, inner_dim)
    # (C_out, C_in, K_h, K_w) -> (K_h, K_w, C_in, C_out) -> (K_h*K_w*C_in, C_out),
    # matching the column order of `cols`.
    weight_matrix = kernel.permute(2, 3, 1, 0).reshape(inner_dim, -1)
    out = cols @ weight_matrix + bias
    return out.reshape(batch_size, output_h, output_w, out_channel).permute(0, 3, 1, 2)

In [13]:
# Compare the im2col + GEMM implementation with the reference.
output_tensor4 = conv_im2col(
    input_tensor, kernel, bias, stride=stride, padding=padding, dilation=dilation
)
print(output_tensor4)
print("mean diff: ", torch.mean(output_tensor4 - output_tensor_torch))

tensor([[[[ 5.5165,  1.5675,  4.4946,  3.6341],
          [-1.3439,  2.2366,  9.1844,  4.2583],
          [ 1.1333,  1.0637, 11.7791,  8.8648],
          [ 3.2420,  3.1308,  0.2205,  4.3939]],

         [[ 0.7311,  2.7015,  1.6318,  1.2317],
          [-0.4422,  3.3261, -6.8021,  1.3115],
          [-1.0562, -3.0663, -1.3042, -0.7498],
          [ 0.7304, -1.6261,  1.4370, -0.0572]],

         [[ 1.9504,  4.7555,  1.8208,  0.0887],
          [-1.7191, -3.8353,  4.9241,  3.5999],
          [-0.2118,  4.1998,  3.5847,  5.6590],
          [ 4.1347,  3.7212,  3.9359, -2.5025]],

         [[ 6.0964,  9.5853,  5.6209,  0.0882],
          [-3.4786, -2.2103,  4.6434,  2.4786],
          [ 0.5002,  0.0610,  0.4057,  9.4221],
          [ 4.6218,  1.4769, -6.3351, -7.0424]]]])
mean diff:  tensor(-5.2154e-08)


## 转置卷积

In [14]:
import torch
import torch.nn.functional as F

# New hyper-parameters for the transposed-convolution section; they overwrite
# the globals used above.
input_h, input_w = 2, 2
in_channel, out_channel = 4, 8
batch_size = 1
kernel_size = 3
padding = 0
stride = 1
dilation = 1
groups = 2

input_tensor = torch.randn((batch_size, in_channel, input_h, input_w))
# The weight is defined in terms of the forward conv's parameters:
# shape (C_in, C_out // groups, K, K), the layout F.conv_transpose2d expects.
weight = torch.randn((in_channel, out_channel // groups, kernel_size, kernel_size))
bias = None

# Reference result from PyTorch's built-in transposed convolution.
output_tensor_torch = F.conv_transpose2d(
    input_tensor,
    weight,
    bias,
    stride=stride,
    padding=padding,
    dilation=dilation,
    groups=groups,
)
print(output_tensor_torch.shape)
print(output_tensor_torch)

torch.Size([1, 8, 4, 4])
tensor([[[[-1.3045,  1.7306,  1.2207, -1.9655],
          [-1.7244,  2.4155,  0.9919,  1.1436],
          [ 0.0267, -1.2458, -2.0820,  0.1620],
          [-0.2911, -0.3754,  3.1075, -1.3857]],

         [[-1.5654, -2.2614, -0.1039,  1.5502],
          [-2.1295,  0.8266,  2.1824,  1.8024],
          [-1.3203,  3.8147,  1.0539, -1.0710],
          [-1.7585,  1.0746,  0.4378, -1.0524]],

         [[ 1.4943, -1.5160,  1.1899, -0.8850],
          [-1.2057, -0.1939,  1.3018, -2.5872],
          [-1.8554,  2.1404,  0.8853, -1.3569],
          [-0.2891, -1.4639,  1.7653,  1.2461]],

         [[ 0.8424,  0.4848,  2.1443,  1.4764],
          [-0.5719, -3.7087,  4.5201, -1.9685],
          [ 2.0317, -5.7544,  4.1698, -2.6215],
          [ 1.9188, -2.2231,  1.4131, -0.4709]],

         [[-0.6279, -1.1798,  0.9922, -0.3383],
          [ 0.4551,  0.4996, -0.9740, -0.5333],
          [ 2.1379,  0.2680,  1.1029,  0.3916],
          [-0.9653, -0.3141,  0.5062,  0.1220]],

     

In [15]:
def single_channel_transposed_conv(
    input_tensor, kernel, bias=None, stride=1, padding=0, dilation=1
):
    """Single-channel transposed convolution by scattering scaled kernels.

    Each input pixel (i, j) adds ``input[i, j] * kernel`` into the output,
    starting at (i * stride, j * stride), with taps spaced ``dilation``
    apart. ``padding`` rows/columns are then cropped from every border, as
    ``F.conv_transpose2d`` does.

    Args:
        input_tensor: 2-D tensor (H, W).
        kernel: 2-D tensor (K_h, K_w).
        bias: optional scalar / one-element tensor added to every output
            position; treated as 0 when ``None``.
        stride, padding, dilation: as in ``F.conv_transpose2d`` (ints).

    Returns:
        2-D tensor of shape ((H - 1) * stride + (K_h - 1) * dilation + 1
        - 2 * padding, likewise for W).
    """
    assert input_tensor.ndim == 2 and kernel.ndim == 2

    if bias is None:
        bias = input_tensor.new_zeros((1,))

    input_h, input_w = input_tensor.shape
    kernel_h, kernel_w = kernel.shape

    # Extent covered by the dilated kernel: (k - 1) * d + 1
    dilated_win_h = (kernel_h - 1) * dilation + 1
    dilated_win_w = (kernel_w - 1) * dilation + 1

    # Scatter into the full, un-cropped output first. Allocating only the
    # cropped size would truncate the slices below whenever padding > 0.
    full_h = (input_h - 1) * stride + dilated_win_h
    full_w = (input_w - 1) * stride + dilated_win_w
    full_output = input_tensor.new_zeros((full_h, full_w))

    for i in range(input_h):
        for j in range(input_w):
            # View of the output positions reached from input pixel (i, j).
            output_slice = full_output[
                i * stride : i * stride + dilated_win_h : dilation,
                j * stride : j * stride + dilated_win_w : dilation,
            ]
            # In-place add writes through the view into `full_output`.
            output_slice += input_tensor[i, j] * kernel

    # Remove `padding` from every border.
    output_tensor = full_output[
        padding : full_h - padding, padding : full_w - padding
    ].contiguous()
    output_tensor += bias
    return output_tensor

In [16]:
# Demo of the scatter-based single-channel transposed conv on channel 0.
# NOTE(review): `kernel` here is still the tensor created in the first section
# (not `weight`); kernel[0, 0] just serves as an example 3x3 kernel.
# `bias` is None at this point, and `outptu_tensor` (sic) is the name used by
# later cells.
outptu_tensor = single_channel_transposed_conv(
    input_tensor[0, 0], kernel[0, 0], bias, stride, padding, dilation
)
print(outptu_tensor.shape)
print(outptu_tensor)

torch.Size([4, 4])
tensor([[-0.0991,  0.4821,  0.3528,  0.3290],
        [-0.4932, -0.0368,  0.7613,  2.8093],
        [-0.2827, -0.7647, -0.3647,  3.1284],
        [-0.3378, -0.9811, -0.7781,  0.8926]])


In [14]:
def conv_transposed2d(
    input_tensor: torch.Tensor,
    kernel: torch.Tensor,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    """Transposed 2-D convolution as zero-insertion followed by ``F.conv2d``.

    Matches ``F.conv_transpose2d(input_tensor, kernel, bias, stride=stride,
    padding=padding, dilation=dilation, groups=groups)`` whenever
    ``padding <= dilation * (K - 1)`` on both axes (no ``output_padding``).

    Args:
        input_tensor: (N, C_in, H, W) tensor.
        kernel: weight in ``conv_transpose2d`` layout,
            (C_in, C_out // groups, K_h, K_w).
        bias: optional (C_out,) tensor.
        stride, padding, dilation, groups: as in ``F.conv_transpose2d``.

    Returns:
        (N, C_out, H_out, W_out) with
        H_out = (H - 1) * stride - 2 * padding + dilation * (K_h - 1) + 1
        (likewise for W).
    """
    batch_size, in_channel, input_h, input_w = input_tensor.shape
    # The first two weight dims are (C_in, C_out // groups); only K is needed.
    _, _, kernel_h, kernel_w = kernel.shape

    # 1. Insert (stride - 1) zeros between neighbouring input pixels,
    #    keeping the input's dtype/device.
    upsampled_input_h = (input_h - 1) * stride + 1
    upsampled_input_w = (input_w - 1) * stride + 1
    upsampled_tensor = input_tensor.new_zeros(
        (batch_size, in_channel, upsampled_input_h, upsampled_input_w)
    )
    upsampled_tensor[:, :, ::stride, ::stride] = input_tensor

    # 2. Rotate every kernel by 180 degrees and rearrange it into conv2d layout.
    flipped_kernel = torch.flip(kernel, dims=(2, 3))
    # Split along C_in into `groups` chunks, e.g. [4, 4, 3, 3] -> 2 x [2, 4, 3, 3] ...
    group_kernels = torch.chunk(flipped_kernel, groups)
    # ... concatenate them along the output-channel axis -> [2, 8, 3, 3], then
    # swap the first two axes -> (C_out, C_in // groups, K_h, K_w) for conv2d.
    conv_kernel = torch.cat(group_kernels, dim=1).permute(1, 0, 2, 3)

    # 3. "Full" padding (dilated extent - 1) minus the requested padding,
    #    computed per axis so non-square kernels get the right width.
    pad_h = (kernel_h - 1) * dilation - padding
    pad_w = (kernel_w - 1) * dilation - padding

    return F.conv2d(
        upsampled_tensor,
        conv_kernel,
        bias,
        stride=1,
        padding=(pad_h, pad_w),
        dilation=dilation,
        groups=groups,
    )

In [18]:
# Compare the zero-insertion + conv2d implementation with F.conv_transpose2d.
outptu_tensor = conv_transposed2d(
    input_tensor, weight, bias, stride, padding, dilation, groups
)
print(outptu_tensor.shape)
print(outptu_tensor)
print("match: ", torch.allclose(output_tensor_torch, outptu_tensor))

torch.Size([1, 8, 4, 4])
tensor([[[[-1.3045,  1.7306,  1.2207, -1.9655],
          [-1.7244,  2.4155,  0.9919,  1.1436],
          [ 0.0267, -1.2458, -2.0820,  0.1620],
          [-0.2911, -0.3754,  3.1075, -1.3857]],

         [[-1.5654, -2.2614, -0.1039,  1.5502],
          [-2.1295,  0.8266,  2.1824,  1.8024],
          [-1.3203,  3.8147,  1.0539, -1.0710],
          [-1.7585,  1.0746,  0.4378, -1.0524]],

         [[ 1.4943, -1.5160,  1.1899, -0.8850],
          [-1.2057, -0.1939,  1.3018, -2.5872],
          [-1.8554,  2.1404,  0.8853, -1.3569],
          [-0.2891, -1.4639,  1.7653,  1.2461]],

         [[ 0.8424,  0.4848,  2.1443,  1.4764],
          [-0.5719, -3.7087,  4.5201, -1.9685],
          [ 2.0317, -5.7544,  4.1698, -2.6215],
          [ 1.9188, -2.2231,  1.4131, -0.4709]],

         [[-0.6279, -1.1798,  0.9922, -0.3383],
          [ 0.4551,  0.4996, -0.9740, -0.5333],
          [ 2.1379,  0.2680,  1.1029,  0.3916],
          [-0.9653, -0.3141,  0.5062,  0.1220]],

     

## 卷积的梯度计算

### Pytorch的反向传播

In [40]:
import torch
import torch.nn.functional as F

# Hyper-parameters for the gradient section (they overwrite the globals above).
# With H = W = 11, (11 + 2*1 - 5) % 2 == 0, so the strided windows tile the
# padded input exactly.
batch_size = 1
input_w, input_h = 11, 11
kernel_size = 3
in_channel, out_channel = 4, 8
padding = 1
stride = 2
dilation = 2
groups = 1

# NOTE(review): size passed as (..., input_w, input_h); fine because W == H.
input_tensor = torch.randn(batch_size, in_channel, input_w, input_h, requires_grad=True)
kernel = torch.randn(
    out_channel, in_channel // groups, kernel_size, kernel_size, requires_grad=True
)
bias = torch.randn(out_channel, requires_grad=True)

output_tensor = F.conv2d(
    input_tensor,
    kernel,
    bias,
    stride=stride,
    padding=padding,
    dilation=dilation,
    groups=groups,
)
# Keep .grad on this non-leaf tensor so dL/dY can be fed to the manual backward.
output_tensor.retain_grad()
# With loss = sum(Y), dL/dY is a tensor of ones.
loss = output_tensor.sum()
loss.backward()
# Autograd reference gradients.
output_tensor_grad = output_tensor.grad
input_tensor_grad = input_tensor.grad
kernel_grad = kernel.grad
bias_grad = bias.grad

### 使用卷积来实现反向传播


未实现 groups 的功能

In [46]:
def conv2d_backward(
    output_grad,
    input_tensor,
    kernel,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    """Gradients of ``F.conv2d`` w.r.t. input, weight and bias, using conv2d only.

    Args:
        output_grad: dL/dY, shape (N, C_out, H_out, W_out).
        input_tensor: forward input X, shape (N, C_in, H, W).
        kernel: forward weight, shape (C_out, C_in, K_h, K_w).
        bias: unused; accepted so the call mirrors the forward signature.
        stride, padding, dilation: the forward-pass hyper-parameters (ints).
        groups: only 1 is supported.

    Returns:
        ``(input_grad, (weight_grad, bias_grad))`` shaped like ``input_tensor``,
        ``kernel`` and ``(C_out,)`` respectively.

    Raises:
        NotImplementedError: if ``groups != 1``.
    """
    if groups != 1:
        raise NotImplementedError("conv2d_backward only supports groups=1")

    batch_size, out_channel, output_h, output_w = output_grad.shape
    input_h, input_w = input_tensor.shape[-2:]
    kernel_h, kernel_w = kernel.shape[-2:]

    dilated_h = (kernel_h - 1) * dilation + 1
    dilated_w = (kernel_w - 1) * dilation + 1

    # Bottom/right rows/columns of the padded input that no forward window
    # reached; non-zero when (H + 2p - dilated_k) is not a multiple of stride.
    remainder_h = (input_h + 2 * padding - dilated_h) % stride
    remainder_w = (input_w + 2 * padding - dilated_w) % stride

    # Insert (stride - 1) zeros between gradient pixels and append the
    # remainder as trailing zeros, so both convolutions below produce
    # full-size results (H x W for dL/dX, dilated K for dL/dW).
    upsampled_grad_h = (output_h - 1) * stride + 1 + remainder_h
    upsampled_grad_w = (output_w - 1) * stride + 1 + remainder_w
    upsampled_grad_tensor = output_grad.new_zeros(
        (batch_size, out_channel, upsampled_grad_h, upsampled_grad_w)
    )
    # remainder < stride, so ::stride still selects exactly H_out x W_out slots.
    upsampled_grad_tensor[:, :, ::stride, ::stride] = output_grad

    # dL/dX: "full" convolution of the upsampled gradient with the kernel
    # rotated by 180 degrees; (C_out, C_in, K, K) -> (C_in, C_out, K, K).
    flipped_kernel = torch.flip(kernel, dims=(2, 3)).permute(1, 0, 2, 3)
    input_grad = F.conv2d(
        upsampled_grad_tensor,
        flipped_kernel,
        bias=None,
        stride=1,
        padding=(dilated_h - 1 - padding, dilated_w - 1 - padding),
        dilation=dilation,
    )

    # dL/dW: correlate X with the upsampled gradient, using the batch axis as
    # the reduction ("channel") axis of the convolution.
    weight_grad = F.conv2d(
        input_tensor.permute(1, 0, 2, 3),  # (N, C_in, H, W) -> (C_in, N, H, W)
        upsampled_grad_tensor.permute(1, 0, 2, 3),  # (N, C_out, ., .) -> (C_out, N, ., .)
        bias=None,
        stride=1,
        padding=padding,
        dilation=1,
    ).permute(1, 0, 2, 3)  # (C_in, C_out, ., .) -> (C_out, C_in, ., .)
    # The result spans the dilated kernel extent; keep only the real taps.
    weight_grad = weight_grad[:, :, ::dilation, ::dilation]

    bias_grad = torch.sum(output_grad, dim=(0, 2, 3))

    return input_grad, (weight_grad, bias_grad)

In [47]:
# Check the conv-based backward pass against the autograd reference gradients.
# The bias argument is None: it does not affect any of the gradients computed.
input_tensor_grad_1, (kernel_grad_1, bias_grad_1) = conv2d_backward(
    output_tensor_grad,
    input_tensor,
    kernel,
    None,
    stride=stride,
    padding=padding,
    dilation=dilation,
    groups=groups,
)
print(torch.allclose(input_tensor_grad, input_tensor_grad_1))
print(torch.allclose(kernel_grad, kernel_grad_1))
print(torch.allclose(bias_grad, bias_grad_1))

True
True
True


### 使用转置卷积与矩阵乘法

未实现 groups 的功能

In [48]:
def conv2d_backward_use_transposed(
    output_grad: torch.Tensor,
    input_tensor,
    kernel,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
):
    """Gradients of ``F.conv2d`` via ``F.conv_transpose2d`` and an unfold GEMM.

    Args:
        output_grad: dL/dY, shape (N, C_out, H_out, W_out).
        input_tensor: forward input X, shape (N, C_in, H, W).
        kernel: forward weight, shape (C_out, C_in // groups, K_h, K_w).
        bias: unused. The bias does not contribute to dL/dX, so it must not
            be passed to the transposed convolution.
        stride, padding, dilation, groups: the forward-pass hyper-parameters.

    Returns:
        ``(input_grad, (weight_grad, bias_grad))`` shaped like ``input_tensor``,
        ``kernel`` and ``(C_out,)`` respectively.
    """
    input_h, input_w = input_tensor.shape[-2:]
    out_channel, _, kernel_h, kernel_w = kernel.shape

    dilated_h = (kernel_h - 1) * dilation + 1
    dilated_w = (kernel_w - 1) * dilation + 1
    # Rows/columns the forward windows skipped at the bottom/right. The
    # transposed conv needs them as output_padding to reproduce H x W exactly.
    output_padding = (
        (input_h + 2 * padding - dilated_h) % stride,
        (input_w + 2 * padding - dilated_w) % stride,
    )

    # dL/dX: transposed convolution of dL/dY with the forward weight.
    input_grad = F.conv_transpose2d(
        output_grad,
        kernel,
        None,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        dilation=dilation,
    )

    # F.unfold returns (N, C_in * K_h * K_w, L) with L = H_out * W_out;
    # each column holds one window, channel-major then (K_h, K_w).
    input_unfold = F.unfold(
        input_tensor,
        (kernel_h, kernel_w),
        dilation=dilation,
        padding=padding,
        stride=stride,
    )
    batch_size, patch_dim, num_windows = input_unfold.shape
    # (N, C_in*K*K, L) -> (N*L, groups, C_in/groups*K*K): each group's input
    # channels are a contiguous block of the patch vector.
    input_rows = input_unfold.permute(0, 2, 1).reshape(
        batch_size * num_windows, groups, patch_dim // groups
    )
    # (N, C_out, H_out, W_out) -> (N*L, groups, C_out/groups), with rows in
    # the same (n, h, w) window order as `input_rows`.
    grad_rows = output_grad.permute(0, 2, 3, 1).reshape(
        batch_size * num_windows, groups, out_channel // groups
    )
    # Per group: sum over all windows of grad (outer product) patch.
    weight_grad = torch.einsum("lgi,lgo->goi", input_rows, grad_rows)
    # (groups, C_out/groups, C_in/groups*K*K) -> (C_out, C_in/groups, K_h, K_w)
    weight_grad = weight_grad.reshape(out_channel, -1, kernel_h, kernel_w)

    bias_grad = torch.sum(output_grad, dim=(0, 2, 3))
    return input_grad, (weight_grad, bias_grad)

In [49]:
# Check the transposed-conv + unfold backward pass against the autograd
# reference gradients (bias passed as None).
input_tensor_grad_2, (kernel_grad_2, bias_grad_2) = conv2d_backward_use_transposed(
    output_tensor_grad,
    input_tensor,
    kernel,
    None,
    stride=stride,
    padding=padding,
    dilation=dilation,
    groups=groups,
)
print(torch.allclose(input_tensor_grad, input_tensor_grad_2))
print(torch.allclose(kernel_grad, kernel_grad_2))
print(torch.allclose(bias_grad, bias_grad_2))

True
True
True
