In [1]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torchvision.transforms import Resize, ToTensor
from torch.utils.data import DataLoader

In [25]:
# 设置训练和测试数据集的路径
train_data_path = 'Homework5_dataset/train_data'
test_data_path = 'Homework5_dataset/test_data'

In [26]:
# 定义数据预处理的转换操作
transform = torchvision.transforms.Compose([
    Resize((64, 64)),  # 调整图片大小为120x120
    ToTensor()  # 将图片转换为张量
])

# 加载训练集和测试集
train_dataset = ImageFolder(train_data_path, transform=transform)
test_dataset = ImageFolder(test_data_path, transform=transform)

In [27]:
# 创建数据加载器
batch_size = 64
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [28]:
# 定义卷积神经网络模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)  # 输入通道大小根据调整后的图片大小而变化
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 64 * 8 * 8)  # 输入通道大小根据调整后的图片大小而变化
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 实例化模型
model = CNN()

In [29]:
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [31]:
# 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
num_epochs = 10
best_accuracy = 0.0
patience = 3
early_stop_counter = 0

kernel_sizes = [3, 5]
strides = [1, 2]
paddings = [1, 2]
hidden_numbers = [128, 256]



    for epoch in range(num_epochs):
        # 训练模型
        model.train()
        train_loss = 0.0
        train_correct = 0
        total_train = 0

        for images, labels in train_dataloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            train_loss += loss.item() * images.size(0)

        train_accuracy = train_correct / total_train
        train_loss = train_loss / total_train

        # 测试模型
        model.eval()
        test_loss = 0.0
        test_correct = 0
        total_test = 0

        with torch.no_grad():
            for images, labels in test_dataloader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_test += labels.size(0)
                test_correct += (predicted == labels).sum().item()
                test_loss += loss.item() * images.size(0)

        test_accuracy = test_correct / total_test
        test_loss = test_loss / total_test

        # 输出每个Epoch的训练集和测试集结果
        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f} - Train Acc: {train_accuracy:.4f} - Test Loss: {test_loss:.4f} - Test Acc: {test_accuracy:.4f}")

        # 检查是否提前停止训练
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            early_stop_counter = 0
        else:
            early_stop_counter += 1

            if early_stop_counter >= patience:
                print("Early stopping...")
                break

# 输出模型的最终准确性
print(f"Best Test Accuracy: {best_accuracy:.4f}")

NameError: name 'itertools' is not defined