# ConvMixer Model Principles and Line-by-Line PyTorch Implementation

From the series *PyTorch Source Code Tutorials and Explanations of Reproduced Cutting-Edge AI Algorithms* by Bilibili uploader deep_thoughts.

P_17_ConvMixer model principles and line-by-line PyTorch implementation:
https://www.bilibili.com/video/BV1K34y1o74P/?spm_id_from=pageDriver&vd_source=18e91d849da09d846f771c89a366ed40

Official `torch.nn.Conv2d` documentation: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d

***Paper***

Patches Are All You Need?

https://openreview.net/pdf?id=TVHS5Y4dNvM

## Model definition code from the paper

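```python
import torch
import torch.nn as nn

def ConvMixer(h, depth, kernel_size=9, patch_size=7, n_classes=1000):
    # ActBn(conv): the given convolution followed by GELU and BatchNorm
    Seq, ActBn = nn.Sequential, lambda x: Seq(x, nn.GELU(), nn.BatchNorm2d(h))
    # Residual: a Sequential subclass built at runtime whose forward returns self[0](x) + x
    Residual = type('Residual', (Seq,), {'forward': lambda self, x: self[0](x) + x})
    return Seq(
        # patch embedding: kernel size == stride == patch_size (non-overlapping patches)
        ActBn(nn.Conv2d(3, h, patch_size, stride=patch_size)),
        # depth x [residual depthwise conv (spatial mixing) -> pointwise 1x1 conv (channel mixing)]
        *[Seq(Residual(ActBn(nn.Conv2d(h, h, kernel_size, groups=h, padding="same"))),
              ActBn(nn.Conv2d(h, h, 1))) for i in range(depth)],
        # classifier head: global average pooling -> flatten -> linear
        nn.AdaptiveAvgPool2d((1,1)), nn.Flatten(), nn.Linear(h, n_classes))
```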
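The paper's definition packs the whole architecture into nested `nn.Sequential` calls. As a reading aid, here is a sketch of a more explicit rewrite, assuming the same layer order as the one-liner: a patch-embedding convolution, `depth` blocks of a residual depthwise convolution followed by a pointwise convolution (every convolution followed by GELU and BatchNorm), then global average pooling and a linear classifier. The names `Residual`, `conv_act_bn` and `ConvMixerExplicit` are hypothetical and not taken from the paper.

```python
import torch
import torch.nn as nn


class Residual(nn.Module):
    """Hypothetical helper: returns fn(x) + x for the wrapped module fn."""

    def __init__(self, fn: nn.Module):
        super().__init__()
        self.fn = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fn(x) + x


def conv_act_bn(conv: nn.Conv2d, h: int) -> nn.Sequential:
    # Hypothetical helper: convolution -> GELU -> BatchNorm, as ActBn does in the paper code.
    return nn.Sequential(conv, nn.GELU(), nn.BatchNorm2d(h))


class ConvMixerExplicit(nn.Module):
    def __init__(self, h, depth, kernel_size=9, patch_size=7, n_classes=1000):
        super().__init__()
        # Patch embedding: kernel size == stride, so each non-overlapping
        # patch_size x patch_size patch becomes one h-dimensional vector.
        self.patch_embed = conv_act_bn(
            nn.Conv2d(3, h, kernel_size=patch_size, stride=patch_size), h)
        blocks = []
        for _ in range(depth):
            blocks.append(nn.Sequential(
                # Depthwise conv (groups=h): mixes spatial positions within each channel.
                Residual(conv_act_bn(
                    nn.Conv2d(h, h, kernel_size, groups=h, padding="same"), h)),
                # Pointwise 1x1 conv: mixes channels at each spatial position.
                conv_act_bn(nn.Conv2d(h, h, kernel_size=1), h),
            ))
        self.blocks = nn.Sequential(*blocks)
        # Head: global average pooling -> flatten -> linear classifier.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(h, n_classes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.blocks(self.patch_embed(x)))


model = ConvMixerExplicit(h=16, depth=1)
out = model(torch.randn(1, 3, 14, 14))
print(out.shape)  # torch.Size([1, 1000])

# Count learnable parameters (BatchNorm running statistics are buffers, not parameters).
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 21048
```

The parameter count follows from the layer shapes: patch embedding 3·16·7·7 + 16 = 2368, depthwise 16·9·9 + 16 = 1312, pointwise 16·16 + 16 = 272, three BatchNorms at 32 each, and the classifier 16·1000 + 1000 = 17000. Since the layers match the one-liner, the paper's `ConvMixer(16, 1)` should report the same total; only the module names in `state_dict()` differ.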

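```python
h = 16      # hidden dimension: number of channels after patch embedding
depth = 1   # number of ConvMixer blocks

image = torch.randn((1,3,14,14))   # one random 3-channel 14x14 image
conv_mixer = ConvMixer(h, depth)   # defaults: kernel_size=9, patch_size=7, n_classes=1000
out = conv_mixer(image)
print(out.size())
```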

```
torch.Size([1, 1000])
```
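To see where `torch.Size([1, 1000])` comes from, the sketch below applies the same layer types one at a time with `h=16`, `patch_size=7`, `kernel_size=9` and records the intermediate shapes. It is a shape-only illustration: GELU and BatchNorm are left out because they do not change shapes, and the weights are random, so only the shapes are meaningful.

```python
import torch
import torch.nn as nn

h, patch_size, kernel_size, n_classes = 16, 7, 9, 1000
x = torch.randn(1, 3, 14, 14)
shapes = {}

# Patch embedding: kernel 7, stride 7 on a 14x14 image -> 2x2 grid of patches, h channels.
x = nn.Conv2d(3, h, patch_size, stride=patch_size)(x)
shapes["patch_embed"] = tuple(x.shape)

# Depthwise conv with padding="same" keeps the 2x2 grid, so the residual add is valid.
x = nn.Conv2d(h, h, kernel_size, groups=h, padding="same")(x) + x
shapes["depthwise_residual"] = tuple(x.shape)

# Pointwise 1x1 conv only mixes channels; the spatial size is unchanged.
x = nn.Conv2d(h, h, 1)(x)
shapes["pointwise"] = tuple(x.shape)

# Global average pooling to 1x1, then flatten to one h-dimensional vector per image.
x = nn.Flatten()(nn.AdaptiveAvgPool2d((1, 1))(x))
shapes["pool_flatten"] = tuple(x.shape)

# Linear classifier head.
x = nn.Linear(h, n_classes)(x)
shapes["classifier"] = tuple(x.shape)

# Expected: (1, 16, 2, 2) three times, then (1, 16), then (1, 1000).
for name, shape in shapes.items():
    print(f"{name:20s} {shape}")
```

Because the patch embedding uses stride equal to its kernel size, an H×W input becomes an (H/p)×(W/p) grid, and every mixer layer keeps that grid; the patch embedding is the only place where spatial resolution drops.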


## Parameter count comparison: depthwise vs. pointwise convolution

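```python
import torch
import torch.nn as nn

# standard 3x3 convolution: every output channel sees every input channel
conv_general = nn.Conv2d(3,3,3,padding="same")
# depthwise 3x3 convolution (groups=3): each channel is filtered separately (spatial mixing)
subconv_space_mixing = nn.Conv2d(3,3,3,groups=3,padding="same")
# pointwise 1x1 convolution: mixes channels at each position (channel mixing)
subconv_channel_mixing = nn.Conv2d(3,3,1)

# each layer has two parameter tensors: weight first, then bias
for p in conv_general.parameters():
    print(torch.numel(p))

for p in subconv_space_mixing.parameters():
    print(torch.numel(p))

for p in subconv_channel_mixing.parameters():
    print(torch.numel(p))
```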

```
81
3
27
3
9
3
```
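The printed numbers are the weight and bias sizes of each layer. A `Conv2d` weight has shape `(C_out, C_in / groups, k, k)` and its bias has `C_out` entries, so with C_in = C_out = 3 and k = 3: the standard convolution has 3·3·3·3 + 3 = 84 parameters, the depthwise one 3·1·3·3 + 3 = 30, and the pointwise one 3·3·1·1 + 3 = 12. Depthwise plus pointwise therefore needs 42 parameters versus 84 here; since the standard weight grows like C²k² while the split version grows like Ck² + C², the ratio approaches k² for large channel counts. A small sketch encoding this formula (the helper names `conv_param_count` and `numel` are hypothetical) and cross-checking it against `nn.Conv2d`:

```python
import torch.nn as nn


def conv_param_count(c_in, c_out, k, groups=1, bias=True):
    # Weight shape is (c_out, c_in // groups, k, k); the bias adds c_out more.
    return c_out * (c_in // groups) * k * k + (c_out if bias else 0)


def numel(module):
    # Total number of learnable parameters in a module.
    return sum(p.numel() for p in module.parameters())


c, k = 3, 3
standard = conv_param_count(c, c, k)             # 81 + 3
depthwise = conv_param_count(c, c, k, groups=c)  # 27 + 3
pointwise = conv_param_count(c, c, 1)            # 9 + 3

# Cross-check the formula against real nn.Conv2d layers.
assert standard == numel(nn.Conv2d(c, c, k, padding="same"))
assert depthwise == numel(nn.Conv2d(c, c, k, groups=c, padding="same"))
assert pointwise == numel(nn.Conv2d(c, c, 1))

print(standard, depthwise + pointwise)  # 84 42
```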
