# 手写实现PyTorch的DilatedConv和GroupConv

来自b站up主deep_thoughts 合集【PyTorch源码教程与前沿人工智能算法复现讲解】

P_27_手写实现PyTorch的DilatedConv和GroupConv：

https://www.bilibili.com/video/BV1UY411W7W8/?spm_id_from=333.788&vd_source=18e91d849da09d846f771c89a366ed40

torch.nn.Conv2d 官方文档：https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d

## 讲解dilation

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

a = torch.randn(7, 7)
print(a)

# A 3x3 window with dilation d takes every d-th element and spans 2*d + 1 cells,
# so slicing [0 : 2*d+1 : d] on both axes picks exactly the sampled positions.
for d in (1, 2, 3):
    span = 2 * d + 1
    print(a[0:span:d, 0:span:d])  # dilation=d

tensor([[ 1.3251,  0.5899, -1.2114, -0.1594, -0.4984, -0.9406, -2.7473],
        [ 0.5222,  0.0237,  0.0589, -1.1198, -0.0831, -1.6083, -0.8463],
        [ 0.3990, -0.3168, -0.2034, -0.1573,  1.0243, -0.4935, -0.0531],
        [ 1.7709, -0.6033,  1.4683,  0.0527, -0.5170,  0.6849,  1.2501],
        [-0.2548, -0.8616,  0.2300,  0.1580,  3.1022,  1.0031, -0.3412],
        [-0.2472,  0.2144,  0.9049, -0.3527,  1.1092,  2.0951,  1.9061],
        [-0.0876,  0.2803,  1.4772, -0.3633, -0.9858, -1.3790,  0.3779]])
tensor([[ 1.3251,  0.5899, -1.2114],
        [ 0.5222,  0.0237,  0.0589],
        [ 0.3990, -0.3168, -0.2034]])
tensor([[ 1.3251, -1.2114, -0.4984],
        [ 0.3990, -0.2034,  1.0243],
        [-0.2548,  0.2300,  3.1022]])
tensor([[ 1.3251, -0.1594, -2.7473],
        [ 1.7709,  0.0527,  1.2501],
        [-0.0876, -0.3633,  0.3779]])


## 讲解group convolution

In [None]:
in_channel, out_channel = 2, 4
groups = 2
# Channels handled by each group; derived (not hard-coded) so they stay
# consistent if in_channel, out_channel or groups are changed.
sub_in_channel, sub_out_channel = in_channel // groups, out_channel // groups

# With groups > 1, channels do not need to be fused across all channels; they are
# fused only within each group, and the per-group results are concatenated at the end.
# A 1x1 (pointwise) convolution is used to fuse information across channels.

## 实现空洞卷积和群卷积(dilation and group)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def matrix_multiplication_for_conv2d_final(input, kernel, bias=0, stride=1, padding=0, dilation=1, groups=1):
    """Loop-based reference 2-D convolution supporting stride, padding, dilation and groups.

    Intended to reproduce ``F.conv2d(input, kernel, bias, stride, padding, dilation, groups)``
    for teaching purposes; it is deliberately written with explicit loops, not for speed.

    Args:
        input: 4-D tensor (bs, in_channel, H, W).
        kernel: 4-D tensor (out_channel, in_channel // groups, kernel_h, kernel_w).
        bias: tensor of shape (out_channel,), ``None`` (no bias), or a Python scalar that is
            added to every output channel (the default 0 means no bias).
        stride, padding, dilation: an int (same value for height and width) or an
            ``(h, w)`` pair. Padding is zero-padding applied symmetrically.
        groups: number of channel groups; must divide both in_channel and out_channel.

    Returns:
        Tensor of shape (bs, out_channel, output_h, output_w) with ``input``'s dtype and device.

    Raises:
        AssertionError: if ``groups`` does not divide both channel counts.
    """
    def _pair(value):
        # Accept either an int or an (h, w) sequence, like F.conv2d does.
        return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

    stride_h, stride_w = _pair(stride)
    pad_h, pad_w = _pair(padding)
    dil_h, dil_w = _pair(dilation)

    if pad_h > 0 or pad_w > 0:
        # F.pad lists the last dim first: (left, right, top, bottom).
        input = F.pad(input, (pad_w, pad_w, pad_h, pad_h))

    bs, in_channel, input_h, input_w = input.shape
    out_channel, _, kernel_h, kernel_w = kernel.shape

    assert out_channel % groups == 0 and in_channel % groups == 0, "groups必须要同时被输入通道数和输出通道数整除！"
    in_per_group = in_channel // groups
    out_per_group = out_channel // groups
    input = input.reshape((bs, groups, in_per_group, input_h, input_w))
    kernel = kernel.reshape((groups, out_per_group, in_per_group, kernel_h, kernel_w))

    # Effective extent of a dilated kernel: k + (k - 1) * (d - 1) == (k - 1) * d + 1.
    eff_h = (kernel_h - 1) * dil_h + 1
    eff_w = (kernel_w - 1) * dil_w + 1

    output_h = (input_h - eff_h) // stride_h + 1
    output_w = (input_w - eff_w) // stride_w + 1
    output = torch.zeros(
        (bs, groups, out_per_group, output_h, output_w),
        dtype=input.dtype,
        device=input.device,
    )

    if bias is None:
        bias = torch.zeros(out_channel, dtype=input.dtype, device=input.device)
    elif not torch.is_tensor(bias):
        # Scalar bias (including the default 0) is broadcast to every output channel.
        bias = torch.full((out_channel,), float(bias), dtype=input.dtype, device=input.device)

    for ind in range(bs):  # iterate over the batch
        for g in range(groups):  # iterate over channel groups
            for oc in range(out_per_group):  # iterate over output channels inside the group
                weight = kernel[g, oc]  # (in_per_group, kernel_h, kernel_w)
                for oh in range(output_h):  # iterate over output rows
                    i = oh * stride_h
                    for ow in range(output_w):  # iterate over output columns
                        j = ow * stride_w
                        # Step by the dilation to pick only the positions the dilated kernel touches;
                        # all input channels of the group are reduced in a single sum.
                        region = input[ind, g, :, i:i + eff_h:dil_h, j:j + eff_w:dil_w]
                        output[ind, g, oc, oh, ow] = torch.sum(region * weight)
                output[ind, g, oc] += bias[g * out_per_group + oc]  # add this channel's bias

    # Merge (groups, out_per_group) back into a single channel axis.
    return output.reshape((bs, out_channel, output_h, output_w))

## 以下为验证测试的代码，验证函数与PyTorch API结果是否一致

In [9]:
kernel_size = 3
bs, in_channel, input_h, input_w = 2, 2, 5, 5
out_channel = 4
groups, dilation, stride, padding = 2, 2, 2, 1

input = torch.randn(bs, in_channel, input_h, input_w)
kernel = torch.randn(out_channel, in_channel//groups, kernel_size, kernel_size)
bias = torch.randn(out_channel)

# Pass identical hyper-parameters to both implementations.
conv_kwargs = dict(bias=bias, padding=padding, stride=stride, dilation=dilation, groups=groups)
pytorch_conv2d_api_output = F.conv2d(input, kernel, **conv_kwargs)  # reference result from the PyTorch API
mm_conv2d_final_output = matrix_multiplication_for_conv2d_final(input, kernel, **conv_kwargs)  # result of the hand-written implementation

# The two implementations accumulate float32 sums in different orders, so use an explicit
# absolute tolerance; allclose's default atol=1e-8 can spuriously fail on near-zero outputs.
flag = torch.allclose(pytorch_conv2d_api_output, mm_conv2d_final_output, atol=1e-5)
print(flag)

True
