In [9]:
class GoogLeNet(nn.Module):
    def __init__(self, num_classes=10):  # 修改为 CIFAR-10 的类别数
        super(GoogLeNet, self).__init__()

        # 减小第一个卷积层的stride，保留更多空间信息
        self.conv1 = BasicConv2d(3, 64, kernel_size=3, stride=1, padding=1)  # 修改kernel_size和stride
        self.maxpool1 = nn.MaxPool2d(3, stride=2, padding=1)  # 添加padding=1
        self.conv2 = BasicConv2d(64, 192, kernel_size=1)
        self.conv3 = BasicConv2d(192, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, padding=1)  # 添加padding=1

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, padding=1)  # 添加padding=1

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        # 移除第四个最大池化层，因为特征图已经很小了
        # self.maxpool4 = nn.MaxPool2d(2, stride=2)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # 自适应池化到1x1
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.conv1(x)  # 32x32 -> 32x32 (修改stride后)
        x = self.maxpool1(x)  # 32x32 -> 16x16
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)  # 16x16 -> 8x8

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)  # 8x8 -> 4x4

        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        x = self.inception4e(x)
        # 移除maxpool4
        # x = self.maxpool4(x)  # 4x4 -> 2x2

        x = self.inception5a(x)
        x = self.inception5b(x)

        x = self.avgpool(x)  # 任意尺寸 -> 1x1
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        return x


# -------------------------------
# Step 2: 数据预处理和加载
# -------------------------------

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)


# -------------------------------
# Step 3: 初始化模型、损失函数、优化器
# -------------------------------

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GoogLeNet(num_classes=10).to(device)  # CIFAR-10 只有 10 类
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


# -------------------------------
# Step 4: 训练函数
# -------------------------------

def train(epoch):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    print(f"Epoch {epoch} | Loss: {total_loss:.3f} | Acc: {100.*correct/total:.2f}%")


# -------------------------------
# Step 5: 测试函数
# -------------------------------

def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    print(f"Test Accuracy: {100.*correct/total:.2f}%")


# -------------------------------
# Step 6: 开始训练和测试
# -------------------------------

for epoch in range(1, 6):  # 简单训练 5 个 epochs
    train(epoch)
test()

Files already downloaded and verified
Files already downloaded and verified
Epoch 1 | Loss: 547.272 | Acc: 49.06%
Epoch 2 | Loss: 350.711 | Acc: 68.95%
Epoch 3 | Loss: 274.516 | Acc: 76.09%
Epoch 4 | Loss: 229.703 | Acc: 80.01%
Epoch 5 | Loss: 204.539 | Acc: 82.44%
Test Accuracy: 79.67%


In [10]:
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GoogLeNet().to(device)
summary(model, (3, 224, 224))

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

model = GoogLeNet(num_classes=1000)
total_params = count_parameters(model)
print(f"Total trainable parameters: {total_params / 1e6:.2f} M")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,728
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
       BasicConv2d-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 192, 112, 112]          12,288
       BatchNorm2d-7        [-1, 192, 112, 112]             384
              ReLU-8        [-1, 192, 112, 112]               0
       BasicConv2d-9        [-1, 192, 112, 112]               0
           Conv2d-10        [-1, 192, 112, 112]         331,776
      BatchNorm2d-11        [-1, 192, 112, 112]             384
             ReLU-12        [-1, 192, 112, 112]               0
      BasicConv2d-13        [-1, 192, 112, 112]               0
        MaxPool2d-14          [-1, 192,