In [2]:
# 加载库
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Preprocessing: resize every image to 128x128, convert it to a tensor, then
# normalize each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Custom dataset class
class GarbageDataset(Dataset):
    """Labelled image dataset stored as ``root_dir/<label>/<image file>``.

    Class folders are named with the integer label (``0``, ``1``, ...).
    Missing class folders are reported on stdout and skipped.

    Args:
        root_dir: Directory that contains one sub-directory per class label.
        transform: Optional callable applied to each loaded RGB image.
        num_classes: Number of label folders to scan, i.e. labels
            ``0 .. num_classes - 1``. Defaults to 40.
    """

    def __init__(self, root_dir, transform=None, num_classes=40):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        for label in range(num_classes):
            label_dir = os.path.join(self.root_dir, str(label))
            if not os.path.isdir(label_dir):
                print(f"Directory not found: {label_dir}")
                continue
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping reproducible across runs and machines.
            for img_name in sorted(os.listdir(label_dir)):
                img_path = os.path.join(label_dir, img_name)
                # Only regular files are samples; nested directories are ignored.
                if os.path.isfile(img_path):
                    self.image_paths.append(img_path)
                    self.labels.append(label)

        print(f"Total images found: {len(self.image_paths)}")

    def __len__(self):
        """Return the number of image files collected."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is opened with PIL, converted to RGB, and passed through
        ``transform`` when one was supplied.
        """
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label

# Load the training data: batches of 32 samples, reshuffled every epoch.
train_dataset = GarbageDataset('train/train', transform)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)

# Model definition
class SimpleCNN(nn.Module):
    """Small CNN: three conv + ReLU + 2x2 max-pool stages, then two linear layers.

    The convolutions preserve spatial size (3x3 kernel, padding 1) and each
    pooling step halves it, so ``fc1`` (128 * 16 * 16 input features) requires
    inputs of shape ``(N, 3, 128, 128)``.

    Args:
        num_classes: Size of the output logit vector. Defaults to 40.
    """

    def __init__(self, num_classes=40):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # Parameter-free, so it adds no state_dict entries and can be shared
        # by all three stages instead of being rebuilt on every forward call.
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for a batch ``x``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce
                128 * 16 * 16 features per sample.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        # Flatten everything except the batch dimension. Unlike view(-1, ...),
        # this never folds surplus spatial elements into the batch axis, so a
        # wrongly sized input fails loudly in fc1 instead of changing N.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Instantiate the model, loss function and optimizer.
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = SimpleCNN()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Train the model
num_epochs = 1  # kept small while debugging
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-batch losses for this epoch
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean loss per batch for the epoch.
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

# Save the trained weights
torch.save(model.state_dict(), 'best_model.pth')

# Dataset for the unlabelled test images
class TestDataset(Dataset):
    """Images listed one per line in a text file, resolved against ``root_dir``.

    Args:
        root_dir: Directory the listed relative paths are joined onto.
        file_path: Text file with one image path per line. Surrounding
            whitespace is stripped and blank lines are ignored.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, root_dir, file_path, transform=None):
        self.root_dir = root_dir
        self.file_path = file_path
        self.transform = transform
        with open(file_path, 'r') as f:
            stripped_lines = (line.strip() for line in f)
            # A blank (e.g. trailing) line would become an empty entry that
            # resolves to root_dir itself and makes Image.open fail later.
            self.image_paths = [entry for entry in stripped_lines if entry]

    def __len__(self):
        """Return the number of listed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, listed_path)`` for the entry at position ``idx``."""
        img_path = os.path.join(self.root_dir, self.image_paths[idx])
        image = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.image_paths[idx]

# Load the test data; order is kept (no shuffling) so predictions line up
# with the order of the listing file.
test_dataset = TestDataset('test/test', 'test/testpath.txt', transform)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=32)

# Load the trained weights. map_location remaps the stored tensors onto the
# current device, so a checkpoint written on a GPU still loads on a CPU-only host.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
# Switch to evaluation mode before running inference.
model.eval()

# Predict a label for every test image
predictions = []  # (listed image path, predicted class index) pairs
with torch.no_grad():  # inference only, no gradients needed
    for inputs, image_names in test_loader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        # Index of the largest logit along the class dimension.
        predicted = torch.max(outputs, dim=1).indices
        predicted_labels = predicted.cpu().tolist()
        for name, predicted_label in zip(image_names, predicted_labels):
            predictions.append((name, predicted_label))
        
# Save the predictions: one predicted label per line, in test-list order.
# Only the label is written; the image name is not part of the output.
with open('result.csv', 'w') as f:
    f.writelines(f'{label}\n' for _, label in predictions)
        

Total images found: 14402
Epoch [1/1], Loss: 3.1512
