![](https://i.imgur.com/rCyvEOB.png)

![](https://i.imgur.com/QEpwWGP.png)

In [1]:
import torch

In [7]:
class MyModule(torch.nn.Module):
    """Small CNN mapping a batch of 3@32x32 inputs to 10 scores per sample.

    Three (5x5 convolution -> 2x2 max-pool) stages are followed by a flatten
    step and two linear layers. Every layer is registered as a named
    attribute, and no activation function is applied between layers.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # Stage 0: 3@32x32 -> 32@32x32 -> 32@16x16
        self.conv_00 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
        self.maxpool_00 = nn.MaxPool2d(kernel_size=2)

        # Stage 1: 32@16x16 -> 32@16x16 -> 32@8x8
        self.conv_01 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
        self.maxpool_01 = nn.MaxPool2d(kernel_size=2)

        # Stage 2: 32@8x8 -> 64@8x8 -> 64@4x4
        self.conv_02 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
        self.maxpool_02 = nn.MaxPool2d(kernel_size=2)

        # Head: 64@4x4 -> 1024 -> 64 -> 10
        self.flatten = nn.Flatten()
        self.linear_00 = nn.Linear(in_features=1024, out_features=64)
        self.linear_01 = nn.Linear(in_features=64, out_features=10)

    def forward(self, inputs):
        """Apply every layer in registration order and return the final tensor."""
        pipeline = (
            self.conv_00,
            self.maxpool_00,
            self.conv_01,
            self.maxpool_01,
            self.conv_02,
            self.maxpool_02,
            self.flatten,
            self.linear_00,
            self.linear_01,
        )
        outputs = inputs
        for layer in pipeline:
            outputs = layer(outputs)
        return outputs

# Instantiate the network and print its layer structure.
module = MyModule()
print(module)

MyModule(
  (conv_00): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool_00): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv_01): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool_01): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv_02): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool_02): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_00): Linear(in_features=1024, out_features=64, bias=True)
  (linear_01): Linear(in_features=64, out_features=10, bias=True)
)


In [8]:
# Smoke-test the model class with a dummy all-ones tensor of shape (64, 3, 32, 32).
demo_inputs = torch.ones((64, 3, 32, 32))
demo_inputs.shape

torch.Size([64, 3, 32, 32])

In [10]:
# Run a forward pass on the dummy batch and display the output shape.
demo_outputs = module(demo_inputs)
demo_outputs.shape
# With inputs of shape [64, 3, 32, 32], the outputs should have shape [64, 10].

torch.Size([64, 10])

In [11]:
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

# Use Sequential to simplify the network definition
class MyModule(torch.nn.Module):
    """CNN for 3@32x32 inputs, expressed as one ``torch.nn.Sequential`` pipeline.

    The whole layer stack lives in ``self.model``; ``forward`` simply delegates
    to it. The output has 10 values per sample.
    """

    def __init__(self):
        super().__init__()

        # Layers are listed in execution order; shapes are channels@height x width.
        layers = [
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),   # 3@32x32  -> 32@32x32
            MaxPool2d(kernel_size=2),                                           # 32@32x32 -> 32@16x16
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),  # 32@16x16 -> 32@16x16
            MaxPool2d(kernel_size=2),                                           # 32@16x16 -> 32@8x8
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),  # 32@8x8   -> 64@8x8
            MaxPool2d(kernel_size=2),                                           # 64@8x8   -> 64@4x4
            Flatten(),                                                          # 64@4x4   -> 1024
            Linear(in_features=1024, out_features=64),                          # 1024     -> 64
            Linear(in_features=64, out_features=10),                            # 64       -> 10
        ]
        self.model = torch.nn.Sequential(*layers)

    def forward(self, inputs):
        """Feed ``inputs`` through the sequential pipeline and return the result."""
        return self.model(inputs)

# Instantiate the Sequential-based network and print its layer structure.
module = MyModule()
print(module)

MyModule(
  (model): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)


In [12]:
# Run a forward pass on the dummy batch and display the output shape.
demo_outputs = module(demo_inputs)
demo_outputs.shape
# With inputs of shape [64, 3, 32, 32], the outputs should have shape [64, 10].

torch.Size([64, 10])

In [13]:
# Visualize the network structure with TensorBoard.
from torch.utils.tensorboard import SummaryWriter

# Create a writer that is given './logs' as its log directory.
writer = SummaryWriter('./logs')

In [14]:
# Record the model graph, using the demo batch as the example input.
# try/finally ensures the writer is closed even if add_graph raises,
# so the event-file writer is never left open.
try:
    writer.add_graph(module, demo_inputs)
finally:
    writer.close()
# To view the graph, run in a shell: tensorboard --logdir='logs'