In [13]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
import torch.nn as nn
import torch.optim as optim
from torchvision import models

In [14]:
# 1. Data preprocessing
# Values a reader is most likely to tune, gathered in one place.
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32
TRAIN_FRACTION = 0.8
SPLIT_SEED = 42  # fixes the train/validation membership across kernel restarts

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),  # resize every image to the ResNet input size
    transforms.ToTensor(),  # convert the image to a Tensor
    # Per-channel normalization; these are the ImageNet statistics documented
    # for torchvision's pretrained classification models.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 2. Load the dataset
train_data_path = './train_data'
# Keep the un-split ImageFolder under its own name so that `train_dataset`
# always refers to the training subset, never to the full dataset.
full_dataset = ImageFolder(root=train_data_path, transform=transform)
# One class per sub-directory of train_data_path.
num_classes = len(full_dataset.classes)

# Hold out part of the labelled data as a validation set.
train_size = int(TRAIN_FRACTION * len(full_dataset))
val_size = len(full_dataset) - train_size
# A dedicated, seeded generator makes the split reproducible without
# touching the global RNG state.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# 3. Define the ResNet model
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in
# favour of `weights=...`; kept here so the cell still runs on older versions.
model = models.resnet18(pretrained=True)  # pretrained ResNet18 backbone
# To try a deeper backbone, use models.resnet50(pretrained=True) instead.
num_ftrs = model.fc.in_features
# Replace the final fully connected layer so its output size matches the
# number of classes found in the dataset (instead of a hard-coded count).
model.fc = nn.Linear(num_ftrs, num_classes)

# 4. Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)



In [15]:
# 5. Train the model
num_epochs = 3
# Follow whatever device the model's parameters currently live on, so the
# loop keeps working unchanged if the model is moved to an accelerator.
device = next(model.parameters()).device
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Reported loss is the mean of the per-batch losses for this epoch.
    print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')

    # Validation: eval mode, gradients disabled.
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            # Predicted class = index of the largest logit in each row.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy on validation set: {100 * correct / total}%')

Epoch 1, Loss: 0.9421532029906908
Accuracy on validation set: 88.29787234042553%
Epoch 2, Loss: 0.20529291282097498
Accuracy on validation set: 93.61702127659575%
Epoch 3, Loss: 0.07134090705464284
Accuracy on validation set: 100.0%


In [16]:
# Persist only the model's parameters (its state_dict), not the whole module.
torch.save(obj=model.state_dict(), f='model_weights_resnet18.pth')

In [None]:
# 6. Evaluate the model on the test set
test_data_path = './test_data'
test_dataset = ImageFolder(root=test_data_path, transform=transform)

# ImageFolder derives label indices from the sub-directory names of the
# folder it is given. If the test folder's class directories differ from the
# training ones, the indices no longer line up and the accuracy below would
# be silently wrong, so check the mapping first.
# `train_dataset` is the Subset produced by random_split; the ImageFolder it
# wraps is reachable through `.dataset`.
train_class_to_idx = getattr(train_dataset, 'dataset', train_dataset).class_to_idx
assert test_dataset.class_to_idx == train_class_to_idx, (
    'test_data and train_data class folders differ: '
    f'{test_dataset.class_to_idx} vs {train_class_to_idx}'
)

test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Follow whatever device the model's parameters currently live on.
device = next(model.parameters()).device

model.eval()
total = 0
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        # Predicted class = index of the largest logit in each row.
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy on test set: {100 * correct / total}%')