In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from PIL import Image
import os
import torch

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Custom dataset: one sub-directory of ``root_dir`` per class.
class EnhancedDataset(Dataset):
    """Image classification dataset read from a directory tree.

    Each immediate sub-directory of ``root_dir`` is one class; every regular
    file inside it is one sample of that class. Class indices follow the
    sorted order of the sub-directory names.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.

    Attributes set by ``__init__``:
        classes: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to integer label.
        images: List of ``(image_path, label)`` tuples.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Only directories are classes. A stray file in root_dir would
        # otherwise become a bogus class and make os.listdir() fail below.
        self.classes = sorted(
            entry
            for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.images = self._load_images()

    def _load_images(self):
        """Return ``(path, label)`` for every regular file in each class directory."""
        images = []
        for cls in self.classes:
            cls_dir = os.path.join(self.root_dir, cls)
            label = self.class_to_idx[cls]
            # Sort so the sample order does not depend on the order in which
            # the file system happens to list entries.
            for img_name in sorted(os.listdir(cls_dir)):
                img_path = os.path.join(cls_dir, img_name)
                # Skip nested directories; they cannot be opened as images.
                if os.path.isfile(img_path):
                    images.append((img_path, label))
        return images

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``."""
        img_path, label = self.images[idx]
        # Close the underlying file as soon as the pixels are converted;
        # convert() returns a new image that stays valid after the close.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# CBAM module definitions
class ChannelAttention(nn.Module):
    """Channel-attention gate of a CBAM block.

    The input is reduced to one value per channel twice (global max pooling
    and global average pooling). Both results go through the same two-layer
    1x1-convolution bottleneck; the outputs are summed and squashed with a
    sigmoid.

    Args:
        channel: Number of channels of the input feature map.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``channel // reduction``, clamped to at least 1.

    ``forward`` returns only the attention weights, one per channel with
    spatial size 1x1; the caller multiplies them onto the feature map.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Without the clamp, channel < reduction yields a zero-width
        # bottleneck, i.e. an attention branch with no usable weights.
        hidden = max(1, channel // reduction)
        self.se = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channel, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel attention weights in ``[0, 1]`` for ``x``."""
        max_result = self.maxpool(x)
        avg_result = self.avgpool(x)
        # The same bottleneck (shared weights) processes both descriptors.
        max_out = self.se(max_result)
        avg_out = self.se(avg_result)
        output = self.sigmoid(max_out + avg_out)
        return output

class SpatialAttention(nn.Module):
    """Spatial-attention gate of a CBAM block.

    The input is collapsed along dim 1 with a max and with a mean; the two
    single-channel maps are stacked and passed through one convolution
    (2 -> 1 channels) followed by a sigmoid.

    Args:
        kernel_size: Size of the convolution kernel; padding is
            ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            padding=half_kernel,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a single-channel attention map for ``x``."""
        # Two descriptors per location: strongest and average response
        # across dim 1 (the channel dim for NCHW input).
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        stacked = torch.cat((channel_max, channel_mean), dim=1)
        return self.sigmoid(self.conv(stacked))

class CBAMBlock(nn.Module):
    """CBAM block: channel attention followed by spatial attention.

    Args:
        channel: Number of channels of the input feature map.
        reduction: Reduction ratio forwarded to ``ChannelAttention``.
        kernel_size: Kernel size forwarded to ``SpatialAttention``.
    """

    def __init__(self, channel=512, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channel=channel, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        """Re-weight ``x`` by channel attention, then by spatial attention."""
        channel_weights = self.ca(x)
        channel_refined = x * channel_weights
        # The spatial mask is computed from the channel-refined features,
        # not from the raw input.
        spatial_mask = self.sa(channel_refined)
        return channel_refined * spatial_mask

# Classification model with CBAM attention
class CBAMClassifier(nn.Module):
    """Small CNN classifier with CBAM blocks between its conv stages.

    Feature extractor (``self.base``): three conv stages of 64, 128 and 256
    channels, each 3x3 conv -> batch norm -> ReLU. The first two stages are
    followed by 2x2 max pooling and a CBAM block; the last one is followed
    by global average pooling.

    Head (``self.classifier``): 256 -> 128 -> ``num_classes`` with ReLU and
    dropout (p=0.5) in between.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, num_classes, in_channels=3):
        super().__init__()

        # Convolutional feature extractor. Layers are appended in the same
        # order as the Sequential indices 0..13, so state_dict keys are stable.
        stage_widths = (64, 128, 256)
        last_stage = len(stage_widths) - 1
        layers = []
        prev_width = in_channels
        for stage, width in enumerate(stage_widths):
            layers.append(nn.Conv2d(prev_width, width, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(width))
            layers.append(nn.ReLU())
            if stage == last_stage:
                # Collapse the spatial dims to 1x1 for the linear head.
                layers.append(nn.AdaptiveAvgPool2d(1))
            else:
                layers.append(nn.MaxPool2d(2))
                layers.append(CBAMBlock(channel=width, reduction=16))
            prev_width = width
        self.base = nn.Sequential(*layers)

        # Classification head
        hidden_fc = nn.Linear(256, 128)
        activation = nn.ReLU()
        dropout = nn.Dropout(0.5)
        output_fc = nn.Linear(128, num_classes)
        self.classifier = nn.Sequential(hidden_fc, activation, dropout, output_fc)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        features = self.base(x)
        # (N, 256, 1, 1) -> (N, 256)
        flat = features.view(features.size(0), -1)
        return self.classifier(flat)


# Training configuration and entry point
def _evaluate(model, loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    Puts the model in eval mode and runs without gradient tracking. Returns
    0.0 when the loader yields no samples instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


def train_model(
    root_dir=r"L:\常惠林\萎凋\自然萎凋\分类",
    num_epochs=100,
    batch_size=64,
    checkpoint_path="best_model.pth",
):
    """Train a ``CBAMClassifier`` and report accuracy on a held-out test split.

    The dataset under ``root_dir`` is split 60/20/20 into train, validation
    and test subsets. After every epoch the model is scored on the
    validation subset, and the weights with the best validation accuracy so
    far are written to ``checkpoint_path``. After training, those weights
    are reloaded and scored on the test subset.

    Args:
        root_dir: Dataset root with one sub-directory per class.
        num_epochs: Number of training epochs.
        batch_size: Batch size for all three data loaders.
        checkpoint_path: File the best model's ``state_dict`` is saved to.

    Raises:
        ValueError: If the dataset is too small to produce a non-empty
            training split.
    """
    # Use the GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Preprocessing: fixed 224x224 input, converted to a tensor.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Load the dataset.
    dataset = EnhancedDataset(root_dir=root_dir, transform=transform)

    # Split 6:2:2 into train / validation / test; the test split takes the
    # rounding remainder so the three sizes always sum to len(dataset).
    train_size = int(0.6 * len(dataset))
    val_size = int(0.2 * len(dataset))
    test_size = len(dataset) - train_size - val_size
    if train_size == 0:
        raise ValueError(
            f"Dataset at {root_dir!r} has {len(dataset)} image(s); "
            "at least 2 are needed for a non-empty training split."
        )
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size]
    )

    # Data loaders; only the training data is shuffled.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Build the model on the selected device.
    model = CBAMClassifier(num_classes=len(dataset.classes)).to(device)

    # Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; the reload after training then never hits a missing or
    # stale file, even if validation accuracy stays at 0%.
    best_acc = -1.0
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Validation phase.
        val_acc = _evaluate(model, val_loader, device)
        # The denominator is the real epoch count, not a hard-coded number.
        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {running_loss/len(train_loader):.4f} | Val Acc: {val_acc:.2f}%")

        # Keep the best model seen so far.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

    # Test phase: reload the best checkpoint. weights_only=True restricts
    # unpickling to tensors/state dicts; map_location keeps the load valid
    # when the checkpoint was written on a different device.
    model.load_state_dict(
        torch.load(checkpoint_path, map_location=device, weights_only=True)
    )
    test_acc = _evaluate(model, test_loader, device)
    print(f"Test Accuracy: {test_acc:.2f}%")

# Start training only when this file is run directly (not when imported).
if __name__ == "__main__":
    train_model()

Using device: cuda:0
Using device: cuda
Epoch 1/50 | Loss: 0.8863 | Val Acc: 31.06%
Epoch 2/50 | Loss: 0.7187 | Val Acc: 49.50%
Epoch 3/50 | Loss: 0.6853 | Val Acc: 67.13%
Epoch 4/50 | Loss: 0.6368 | Val Acc: 39.68%
Epoch 5/50 | Loss: 0.5389 | Val Acc: 47.70%
Epoch 6/50 | Loss: 0.5415 | Val Acc: 49.50%
Epoch 7/50 | Loss: 0.5354 | Val Acc: 48.90%
Epoch 8/50 | Loss: 0.4495 | Val Acc: 65.33%
Epoch 9/50 | Loss: 0.4067 | Val Acc: 38.28%
Epoch 10/50 | Loss: 0.3529 | Val Acc: 44.29%
Epoch 11/50 | Loss: 0.3105 | Val Acc: 33.67%
Epoch 12/50 | Loss: 0.3194 | Val Acc: 40.48%
Epoch 13/50 | Loss: 0.2412 | Val Acc: 75.55%
Epoch 14/50 | Loss: 0.2649 | Val Acc: 31.06%
Epoch 15/50 | Loss: 0.2072 | Val Acc: 74.15%
Epoch 16/50 | Loss: 0.1925 | Val Acc: 74.35%
Epoch 17/50 | Loss: 0.1445 | Val Acc: 67.54%
Epoch 18/50 | Loss: 0.1986 | Val Acc: 56.31%
Epoch 19/50 | Loss: 0.1258 | Val Acc: 56.11%
Epoch 20/50 | Loss: 0.1754 | Val Acc: 33.27%
Epoch 21/50 | Loss: 0.1332 | Val Acc: 54.51%
Epoch 22/50 | Loss: 0.09

  model.load_state_dict(torch.load("best_model.pth"))


Test Accuracy: 77.80%
