# 手写实现转置卷积（对应 PyTorch 的 nn.ConvTranspose2d）
来自b站up主deep_thoughts 合集【PyTorch源码教程与前沿人工智能算法复现讲解】

P_25_手写实现nn.TransposedConv转置卷积：

https://www.bilibili.com/video/BV1fM4y1F7j5/?spm_id_from=333.788&vd_source=18e91d849da09d846f771c89a366ed40

torch.nn.Unfold 官方文档：https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html （用于对 input 进行展开，与上一节 step2 中手写的"逐个取 region 区域再展开"的做法相同，但上一节没有考虑通道维度）

转置卷积官方文档：https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

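## step4：通过对kernel进行展开来实现二维卷积，并由此得出转置卷积；不考虑batch和channel维度，假设stride=1

思路：把卷积写成矩阵乘法 $y = K\,\mathrm{vec}(x)$，其中 $K$ 的每一行是补零到输入大小后拉直的 kernel；转置卷积即为 $K^\top y$，它把输出映射回输入的形状。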

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def get_kernel_matrix(kernel, input_size, stride=1):
    """Build the matrix form of a 2-D convolution (no padding, no dilation).

    Each row is ``kernel`` zero-padded to the input's spatial size at one
    sliding-window position and then flattened. Consequently
    ``result @ x.reshape(-1, 1)`` is the flattened ``F.conv2d(x, kernel,
    stride=stride)`` output, and ``result.T @ y`` is the flattened
    ``F.conv_transpose2d(y, kernel, stride=stride)`` output.

    Args:
        kernel: 2-D tensor of shape (kernel_h, kernel_w).
        input_size: (input_h, input_w) of the input feature map.
        stride: step of the sliding window along both spatial axes (>= 1).

    Returns:
        Tensor of shape (out_h * out_w, input_h * input_w), where
        out_h = (input_h - kernel_h) // stride + 1 (likewise for width).
        Rows are ordered row-major over output positions. The dtype is
        ``kernel.dtype`` promoted with the default float dtype (so float32
        kernels give float32, float64 kernels stay float64, integer kernels
        give float32), on ``kernel``'s device.

    Raises:
        ValueError: if ``stride`` < 1 or the kernel is larger than the input.
    """
    kernel_h, kernel_w = kernel.shape
    input_h, input_w = input_size
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")
    if kernel_h > input_h or kernel_w > input_w:
        raise ValueError(
            f"kernel {tuple(kernel.shape)} is larger than input {(input_h, input_w)}"
        )

    out_h = (input_h - kernel_h) // stride + 1
    out_w = (input_w - kernel_w) // stride + 1
    # Promote instead of forcing float32 so float64 kernels keep their precision,
    # while integer kernels still yield a float matrix usable with float inputs.
    result_dtype = torch.promote_types(kernel.dtype, torch.get_default_dtype())
    result = torch.zeros(
        (out_h * out_w, input_h * input_w), dtype=result_dtype, device=kernel.device
    )

    # Window top-left corners, row-major, so row index == flattened output index.
    positions = (
        (i, j)
        for i in range(0, input_h - kernel_h + 1, stride)
        for j in range(0, input_w - kernel_w + 1, stride)
    )
    for row, (i, j) in enumerate(positions):
        # F.pad on a 2-D tensor takes (left, right, top, bottom); this places the
        # kernel at (i, j) inside an input-sized zero canvas.
        padded_kernel = F.pad(kernel, (j, input_w - kernel_w - j, i, input_h - kernel_h - i))
        result[row] = padded_kernel.flatten()
    return result


## 测试1：验证二维卷积

In [19]:
kernel = torch.randn(3, 3)
input = torch.randn(4, 4)  # NOTE: shadows builtin input(); name kept in case later cells use it
kernel_matrix = get_kernel_matrix(kernel, input.shape)  # (out_h*out_w, in_h*in_w) = (4, 16)
# Convolution as one matrix product: (4, 16) @ (16, 1) -> (4, 1).
mm_conv2d_output = kernel_matrix @ input.reshape((-1, 1))
# Reference from the PyTorch API; [None, None] adds the batch and channel dims.
pytorch_conv2d_output = F.conv2d(input[None, None], kernel[None, None])
# Reshape to the spatial size PyTorch reports rather than hard-coding (2, 2).
print(mm_conv2d_output.reshape(pytorch_conv2d_output.shape[-2:]))
print(pytorch_conv2d_output)
assert torch.allclose(
    mm_conv2d_output.reshape(pytorch_conv2d_output.shape[-2:]),
    pytorch_conv2d_output[0, 0],
    atol=1e-5,
)

tensor([[ 1.7583, -3.0868],
        [-2.0347,  0.5580]])
tensor([[[[ 1.7583, -3.0868],
          [-2.0347,  0.5580]]]])


## 测试2：验证二维转置卷积

In [20]:
print(kernel_matrix.shape)  # (out_h*out_w, in_h*in_w)
print(kernel_matrix.transpose(-1, -2).shape)  # (in_h*in_w, out_h*out_w)
# Transposed convolution = multiplication by the transposed kernel matrix:
# (16, 4) @ (4, 1) -> (16, 1), i.e. it maps the conv output back to the input's size.
mm_transposed_conv2d_output = kernel_matrix.transpose(-1, -2) @ mm_conv2d_output
pytorch_transposed_conv2d_output = F.conv_transpose2d(pytorch_conv2d_output, kernel[None, None])
# Reshape to the spatial size PyTorch reports rather than hard-coding (4, 4).
print(mm_transposed_conv2d_output.reshape(pytorch_transposed_conv2d_output.shape[-2:]))
print(pytorch_transposed_conv2d_output)
assert torch.allclose(
    mm_transposed_conv2d_output.reshape(pytorch_transposed_conv2d_output.shape[-2:]),
    pytorch_transposed_conv2d_output[0, 0],
    atol=1e-5,
)

torch.Size([4, 16])
torch.Size([16, 4])
tensor([[ 1.0477, -3.6902,  5.7490, -4.3885],
        [-3.2956,  6.9641, -5.3846,  1.5702],
        [ 0.6556,  0.9386,  1.1903,  0.7276],
        [ 2.0307,  0.0418,  0.4080, -0.1569]])
tensor([[[[ 1.0477, -3.6902,  5.7490, -4.3885],
          [-3.2956,  6.9641, -5.3846,  1.5702],
          [ 0.6556,  0.9386,  1.1903,  0.7276],
          [ 2.0307,  0.0418,  0.4080, -0.1569]]]])
