In [6]:
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from PIL import Image
import os

# 数据预处理
import os

class BananaDataset(Dataset):
    def __init__(self, csv_file, base_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.base_dir = "/kaggle/input/banana-classification/valid"  # 图片的基目录
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_filename = self.data.iloc[idx, 0]
        image_path = os.path.join(self.base_dir, image_filename)  # 构造绝对路径
#         print(f"Trying to open image: {image_path}")  # 打印图片路径
        image = Image.open(image_path)
        label = self.data.iloc[idx, 1:].values.astype(float)  # 将标签转换为浮点数
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.float32)

# 转换
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 数据集和数据加载器
dataset = BananaDataset(csv_file='/kaggle/input/banana-classification/valid/_classes.csv',  base_dir="/kaggle/input/banana-classification/valid",transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 模型
model = torchvision.models.resnet50(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, 6)  # 修改最后一层
# 冻结除第3层和第4层以及全连接层之外的所有层
for param in model.parameters():
    param.requires_grad = False

# 仅第3层和第4层以及全连接层可训练
for param in model.layer3.parameters():
    param.requires_grad = True
for param in model.layer4.parameters():
    param.requires_grad = True
for param in model.fc.parameters():
    param.requires_grad = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


# 损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 15 # 示例周期数
for epoch in range(num_epochs):
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels.argmax(dim=1))

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 计算损失和准确率
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.argmax(dim=1)).sum().item()

    # 打印每个周期的损失和准确率
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}, Accuracy: {100 * correct/total:.2f}%')

    # 保存模型
    torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pth')






Epoch [1/15], Loss: 0.8170, Accuracy: 68.77%
Epoch [2/15], Loss: 0.4622, Accuracy: 82.55%
Epoch [3/15], Loss: 0.3780, Accuracy: 86.35%
Epoch [4/15], Loss: 0.2738, Accuracy: 91.14%
Epoch [5/15], Loss: 0.1428, Accuracy: 95.47%
Epoch [6/15], Loss: 0.1786, Accuracy: 93.77%
Epoch [7/15], Loss: 0.2272, Accuracy: 92.65%
Epoch [8/15], Loss: 0.0749, Accuracy: 97.64%
Epoch [9/15], Loss: 0.0552, Accuracy: 98.16%
Epoch [10/15], Loss: 0.0489, Accuracy: 98.62%
Epoch [11/15], Loss: 0.1043, Accuracy: 97.11%
Epoch [12/15], Loss: 0.0556, Accuracy: 98.49%
Epoch [13/15], Loss: 0.0259, Accuracy: 99.15%
Epoch [14/15], Loss: 0.0810, Accuracy: 97.44%
Epoch [15/15], Loss: 0.1162, Accuracy: 95.93%


In [10]:
# 测试模型的函数
def test_model(model, test_loader, device):
    model.eval()  # 设置模型为评估模式
    correct = 0
    total = 0
    with torch.no_grad():  # 在测试阶段不计算梯度
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.argmax(dim=1)).sum().item()

    print(f'Accuracy of the model on the test images: {100 * correct / total}%')

# 创建测试数据加载器（确保使用正确的数据集和数据变换）
test_dataloader = DataLoader(BananaDataset(csv_file='/kaggle/input/banana-classification/valid/_classes.csv', base_dir="/kaggle/input/banana-classification/valid", transform=transform), batch_size=32, shuffle=False)

# 加载最后一个保存的模型并测试
model.load_state_dict(torch.load(f'model_epoch_{10}.pth'))
model.to(device)
test_model(model, test_dataloader, device)

Accuracy of the model on the test images: 98.68766404199475%
