In [4]:
import torch
import torch.nn as nn
import torchvision.models.resnet as resnet

class CifarResNet(nn.Module):
    """Small three-stage ResNet for CIFAR-sized images, built from residual blocks.

    Differences from torchvision's ImageNet ResNet:
    - The stem is a single 3x3 stride-1 convolution, and max-pooling is disabled (``nn.Identity``).
    - There are three stages instead of four.
    - There is no global average pooling: the last feature map is flattened directly into ``fc``.

    Args:
        layers: number of residual blocks in each of the three stages.
        num_classes: number of output logits produced by ``fc``.
        channels: ``planes`` width of each of the three stages.
        input_size: side length of the square inputs the classifier is sized
            for (default 32, i.e. CIFAR). Because features are flattened rather
            than pooled, inputs of any other size will not match ``fc``.
        block: residual block class. ``None`` (the default) means
            ``torchvision.models.resnet.BasicBlock``. It must accept
            ``(inplanes, planes, stride, downsample)`` and expose ``expansion``.
    """

    def __init__(
        self,
        layers=(1, 1, 1),
        num_classes=10,
        channels=(64, 128, 256),
        input_size=32,
        block=None,
    ):
        super().__init__()
        if block is None:
            block = resnet.BasicBlock
        # Running input width consumed and updated by _make_layer.
        self.inplanes = channels[0]

        # CIFAR stem: a 3x3 stride-1 conv keeps full resolution; max-pool is a no-op.
        self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels[0])
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.Identity()

        self.layer1 = self._make_layer(block, channels[0], layers[0])
        self.layer2 = self._make_layer(block, channels[1], layers[1], stride=2)
        self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2)

        # Size the classifier from the real feature-map shape instead of hard-coding it.
        num_features = self._infer_num_features(input_size)
        self.fc = nn.Linear(num_features, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks with output width ``planes * block.expansion``.

        Only the first block may change stride or width. When it does, a 1x1 conv + BatchNorm
        shortcut projects the identity path to the new shape. Updates ``self.inplanes``
        to the stage's output width for the next stage.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def _infer_num_features(self, input_size):
        """Return the flattened feature count for a ``1 x 3 x input_size x input_size`` input.

        The probe is a zero tensor run in eval mode under ``torch.no_grad()``. This keeps
        construction from updating BatchNorm running statistics, building an autograd graph,
        or consuming the global RNG. The module's previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, 3, input_size, input_size)
                return self._forward_features(probe).flatten(1).size(1)
        finally:
            self.train(was_training)

    def _forward_features(self, x):
        """Apply the stem and the three residual stages; returns the un-flattened feature map."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        return self.layer3(x)

    def forward(self, x):
        """Compute class logits of shape ``(batch, num_classes)`` for a batch of images ``x``."""
        x = self._forward_features(x)
        # Flatten everything except the batch dimension (no global pooling).
        x = torch.flatten(x, 1)
        return self.fc(x)

# Instantiate the network with the default CIFAR-10 configuration.
# Seed torch's RNG first so the weight initialisation (and later torch.randn
# calls in this notebook) is reproducible on Restart Kernel -> Run All.
torch.manual_seed(0)
net = CifarResNet()


In [5]:
# Print the module tree to inspect the architecture.
print(net)

# Smoke-test the forward pass on a random batch (batch_size=8, 3x32x32 inputs).
# eval() + no_grad() prevent this check from updating BatchNorm running
# statistics or recording an autograd graph. The model's previous
# train/eval mode is restored afterwards.
input_tensor = torch.randn(8, 3, 32, 32)
was_training = net.training
net.eval()
try:
    with torch.no_grad():
        output = net(input_tensor)
finally:
    net.train(was_training)
assert output.shape == (8, 10), f"unexpected output shape {tuple(output.shape)}"
print(output.shape)  # expected: torch.Size([8, 10])


CifarResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): Identity()
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128,