# 实现模型 —— CIFAR10
模型结构如图所示：

<img src="https://yongruizhang-image.oss-cn-chengdu.aliyuncs.com/img/CIFAR10_model.png" width="400px"/>

卷积前后尺寸计算公式：

<img alt="截屏2024-03-31 11.28.50" src="https://yongruizhang-image.oss-cn-chengdu.aliyuncs.com/img/%E6%88%AA%E5%B1%8F2024-03-31%2011.28.50.png" width="400px"/>

# 常规方式定义网络

In [18]:
import torch
from torch import nn
from torch.nn import Conv2d, Sequential, MaxPool2d, Flatten, Linear

class MY_CIFAR10(nn.Module):
    """CIFAR-10 classifier assembled from individually named layers.

    Three 5x5 convolutions, each followed by 2x2 max-pooling, feed a
    two-layer linear head that emits 10 class scores. No activation
    functions are applied between the layers.
    """

    def __init__(self):
        super().__init__()
        # A 5x5 kernel with padding=2 (stride and dilation left at their
        # default of 1) keeps the spatial size, so a 32x32 input stays 32x32.
        self.conv1 = Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
        self.maxpool1 = MaxPool2d(kernel_size=2)
        # Same padding reasoning as conv1.
        self.conv2 = Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
        self.maxpool2 = MaxPool2d(kernel_size=2)
        self.conv3 = Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
        # For a 32x32 input the feature map is now 64 x 4 x 4 = 1024 values.
        self.maxpool3 = MaxPool2d(kernel_size=2)
        # Collapse every non-batch dimension into one vector of 1024 values.
        self.flatten = Flatten()
        # This intermediate layer is not drawn in the architecture figure,
        # but it follows from the 1024-value vector and the 64-unit layer.
        self.linear1 = Linear(in_features=1024, out_features=64)
        # One output score per class (10 classes).
        self.linear2 = Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        pipeline = (
            self.conv1,
            self.maxpool1,
            self.conv2,
            self.maxpool2,
            self.conv3,
            self.maxpool3,
            self.flatten,
            self.linear1,
            self.linear2,
        )
        for layer in pipeline:
            x = layer(x)
        return x

## 实例化网络

In [13]:
# Create an instance of the layer-by-layer network defined above.
my_cifar10 = MY_CIFAR10()

In [14]:
# Printing an nn.Module lists its registered sub-modules and their settings.
print(my_cifar10)

MY_CIFAR10(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=1024, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=10, bias=True)
)


In [15]:
# All-ones dummy batch: batch_size=64, channels=3, height=32, width=32.
input = torch.ones(64, 3, 32, 32)
# Forward pass; the printed size confirms one row of 10 scores per sample.
output = my_cifar10(input)
print(output.size())

torch.Size([64, 10])


# 使用Sequential定义网络
使用Sequential会使得代码简洁很多，方便写也方便读

In [25]:
import torch
from torch import nn
from torch.nn import Conv2d, Sequential, MaxPool2d, Flatten, Linear

class MY_CIFAR10_Sequential(nn.Module):
    """CIFAR-10 classifier whose layers are packed into one ``Sequential``.

    The stack is three 5x5 convolutions (padding 2), each followed by 2x2
    max-pooling, then a flatten step and two linear layers ending in 10
    outputs. The whole stack is stored as ``self.model1``.
    """

    def __init__(self):
        super().__init__()
        # Convolution / pooling stages, created in the order they are applied.
        feature_layers = [
            Conv2d(3, 32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(32, 32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(32, 64, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
        ]
        # Flatten to a 1024-value vector, then reduce to 64 and finally 10.
        head_layers = [
            Flatten(),
            Linear(in_features=1024, out_features=64),
            Linear(in_features=64, out_features=10),
        ]
        self.model1 = Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.model1(x)

In [26]:
# Create the Sequential-based network and print its layer structure.
my_cifar10_sequential = MY_CIFAR10_Sequential()
print(my_cifar10_sequential)

MY_CIFAR10_Sequential(
  (model1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)


In [27]:
# All-ones dummy batch: batch_size=64, channels=3, height=32, width=32.
# The name `input` is kept because the TensorBoard cell below reuses it.
input = torch.ones((64, 3, 32, 32))
# Bind the result to `output` so the shape printed next comes from this
# forward pass through the Sequential model.
output = my_cifar10_sequential(input)
print(output.shape)

torch.Size([64, 10])


# 使用TensorBoard可视化
如下代码执行后，会在本文件夹下创建一个名为 log_my_cifar10 的文件夹。
然后在本文件夹的路径下打开终端，输入命令：`tensorboard --logdir=log_my_cifar10` 即可打开tensorboard

In [28]:
from torch.utils.tensorboard import SummaryWriter

# Write the Sequential model's computation graph, traced with the dummy
# batch, into the "log_my_cifar10" directory for TensorBoard. The context
# manager closes the writer even if add_graph raises.
with SummaryWriter("log_my_cifar10") as writer:
    writer.add_graph(my_cifar10_sequential, input)