In [26]:
import pandas as pd
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

class AnimalDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.data = []

        # 遍历文件夹下的所有图片
        for fname in os.listdir(img_dir):
            if fname.endswith('.jpg') or fname.endswith('.png'):
                # 根据文件名前缀提取标签
                if fname.startswith('cat'):
                    label = 0
                elif fname.startswith('dog'):
                    label = 1
                else:
                    continue
                self.data.append((fname, label))

    def __getitem__(self, idx):
        fname, label = self.data[idx] # 获取图片名和标签
        img_path = os.path.join(self.img_dir, fname) # 构建图片路径
        image = Image.open(img_path).convert('RGB')  # 转为RGB # 读取图片
        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor()
            ])(image)
        return image, label

    def __len__(self):
        return len(self.data)


class TestDataset(Dataset):
    def __init__(self, img_dir, csv_path, transform=None):
        self.img_dir = img_dir
        self.transform = transform

        # 读取CSV文件
        df = pd.read_csv(csv_path)

        # 保存每张图片的id和label
        self.data = [(str(row['id']) + '.jpg', int(row['label'])) for _, row in df.iterrows()]

    def __getitem__(self, idx):
        fname, label = self.data[idx]
        img_path = os.path.join(self.img_dir, fname)
        image = Image.open(img_path).convert('RGB')  # RGB彩色图像
        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor()
            ])(image)
        return image, label

    def __len__(self):
        return len(self.data)


In [27]:
# 定义图像预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# 加载数据集
train_dataset = AnimalDataset('data/dog_and_cat/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataset = TestDataset('data/dog_and_cat/test', csv_path='data/dog_and_cat/sampleSubmission.csv', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# # 取一批数据看看
# images, labels = next(iter(train_loader))
# print(images)   # torch.Size([32, 3, 28, 28])
# print(labels[:10])

# 取一批数据看看
images, labels = next(iter(test_loader))
print(images)
print(labels[:10])

tensor([[[[0.1529, 0.1451, 0.1569,  ..., 0.3333, 0.2784, 0.1882],
          [0.1647, 0.1490, 0.1373,  ..., 0.3333, 0.2745, 0.1882],
          [0.1529, 0.1255, 0.1373,  ..., 0.3451, 0.2902, 0.2000],
          ...,
          [0.6824, 0.7098, 0.7176,  ..., 0.5373, 0.5373, 0.5098],
          [0.6431, 0.6941, 0.7216,  ..., 0.4667, 0.3804, 0.3882],
          [0.6667, 0.6784, 0.7098,  ..., 0.6392, 0.4667, 0.4157]],

         [[0.2588, 0.2471, 0.2588,  ..., 0.4627, 0.4157, 0.3333],
          [0.2706, 0.2510, 0.2392,  ..., 0.4549, 0.4078, 0.3255],
          [0.2588, 0.2275, 0.2392,  ..., 0.4510, 0.4000, 0.3216],
          ...,
          [0.6980, 0.7255, 0.7333,  ..., 0.5216, 0.5137, 0.4824],
          [0.6588, 0.7059, 0.7373,  ..., 0.4588, 0.3686, 0.3725],
          [0.6784, 0.6902, 0.7216,  ..., 0.6078, 0.4353, 0.3804]],

         [[0.3765, 0.3804, 0.4000,  ..., 0.8392, 0.7804, 0.6784],
          [0.3882, 0.3843, 0.3804,  ..., 0.8353, 0.7725, 0.6745],
          [0.3765, 0.3608, 0.3804,  ..., 0

In [28]:
from matplotlib import pyplot as plt
from utils.accuracy import evaluate_accuracy
import torch
from torch import nn
from utils.init import init_weights
from models.alexnet import AlexNet

# 定义网络
net = AlexNet()
net.apply(init_weights)

# 定义损失函数和优化器
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# 训练
epochs = 10
train_losses, train_accs, test_accs = [], [], []
for epoch in range(1, epochs+1):
    net.train()
    total_loss, total_acc, total_count = 0, 0, 0
    for X, y in train_loader:
        y_hat = net(X)
        l = loss(y_hat, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()


        total_loss += l.item() * y.numel()
        total_acc += (y_hat.argmax(dim=1) == y).sum().item()
        total_count += y.numel()

    train_acc = total_acc / total_count
    train_loss = total_loss / total_count
    test_acc = evaluate_accuracy(net, test_loader)

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    test_accs.append(test_acc)
    print(f'Epoch {epoch}, Loss {train_loss:.4f}, Acc {train_acc:.4f}, Test Acc {test_acc:.4f}')

plt.figure(figsize=[12, 5])
plt.subplot(1, 2, 1)
plt.plot(range(1, epochs+1), train_losses, 'o-', label='Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(range(1, epochs+1), train_accs, 'o-', label='Training Accuracy')
plt.plot(range(1, epochs+1), test_accs, 'o-', label='Test Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Accuracy Curve')
plt.legend()
plt.show()



KeyboardInterrupt: 