In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# 确保使用独显
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 数据预处理
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 加载数据集
data_dir = r"L:\常惠林\萎凋\自然萎凋\分类"
full_dataset = datasets.ImageFolder(data_dir, transform=train_transform)

# 6:2:2 划分
total_size = len(full_dataset)
train_size = int(total_size * 0.6)
val_size = int(total_size * 0.2)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])

# 修改验证集和测试集的 transform
val_dataset.dataset.transform = val_test_transform
test_dataset.dataset.transform = val_test_transform

# DataLoader
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# MobileNet 模型
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(DepthwiseSeparableConv, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

class MobileNet(nn.Module):
    def __init__(self, num_classes=3):
        super(MobileNet, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = DepthwiseSeparableConv(32, 64, stride=1)
        self.conv3 = DepthwiseSeparableConv(64, 128, stride=2)
        self.conv4 = DepthwiseSeparableConv(128, 128, stride=1)
        self.conv5 = DepthwiseSeparableConv(128, 256, stride=2)
        self.conv6 = DepthwiseSeparableConv(256, 256, stride=1)
        self.conv7 = DepthwiseSeparableConv(256, 512, stride=2)
        self.conv8 = DepthwiseSeparableConv(512, 512, stride=1)
        self.conv9 = DepthwiseSeparableConv(512, 512, stride=1)
        self.conv10 = DepthwiseSeparableConv(512, 512, stride=1)
        self.conv11 = DepthwiseSeparableConv(512, 512, stride=1)
        self.conv12 = DepthwiseSeparableConv(512, 1024, stride=2)
        self.conv13 = DepthwiseSeparableConv(1024, 1024, stride=1)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.conv8(x)
        x = self.conv9(x)
        x = self.conv10(x)
        x = self.conv11(x)
        x = self.conv12(x)
        x = self.conv13(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# 训练过程
def train_model(model, criterion, optimizer, num_epochs=50, save_path="best_model.pth"):
    best_acc = 0.0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / train_size
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

        # 计算验证集准确率
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (preds == labels).sum().item()

        val_acc = correct / total
        print(f"Validation Accuracy: {val_acc:.4f}")

        # 保存最优模型
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), save_path)
            print(f"Best model saved with accuracy: {best_acc:.4f}")

    print(f"Best Validation Accuracy: {best_acc:.4f}")

# 测试过程
def test_model(model, model_path="best_model.pth"):
    # 加载最优模型
    model.load_state_dict(torch.load(model_path))
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

    test_acc = correct / total
    print(f"Final Test Accuracy: {test_acc:.4f}")

# 运行
model = MobileNet(num_classes=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

train_model(model, criterion, optimizer, num_epochs=25)
test_model(model)


Using device: cuda:0
Epoch 1/25, Loss: 1.1333
Validation Accuracy: 0.3936
Best model saved with accuracy: 0.3936
Epoch 2/25, Loss: 1.0964
Validation Accuracy: 0.3936
Epoch 3/25, Loss: 1.0966
Validation Accuracy: 0.3936
Epoch 4/25, Loss: 1.0966
Validation Accuracy: 0.3936
Epoch 5/25, Loss: 1.0962
Validation Accuracy: 0.3936
Epoch 6/25, Loss: 1.0964
Validation Accuracy: 0.3936
Epoch 7/25, Loss: 1.0953
Validation Accuracy: 0.3936
Epoch 8/25, Loss: 1.0952
Validation Accuracy: 0.3936
Epoch 9/25, Loss: 1.0948
Validation Accuracy: 0.3936
Epoch 10/25, Loss: 1.0959
Validation Accuracy: 0.3936
Epoch 11/25, Loss: 1.1052
Validation Accuracy: 0.4078
Best model saved with accuracy: 0.4078
Epoch 12/25, Loss: 1.0956
Validation Accuracy: 0.3286
Epoch 13/25, Loss: 1.0862
Validation Accuracy: 0.4541
Best model saved with accuracy: 0.4541
Epoch 14/25, Loss: 1.0503
Validation Accuracy: 0.3936
Epoch 15/25, Loss: 0.9866
Validation Accuracy: 0.5494
Best model saved with accuracy: 0.5494
Epoch 16/25, Loss: 0.9

  model.load_state_dict(torch.load(model_path))


Final Test Accuracy: 0.5907
