In [1]:
from torch.utils.data import Dataset
import pickle
from PIL import Image

class CustomDataset(Dataset):
    """Image dataset backed by pickled CIFAR-10 style batch files.

    Every batch file must unpickle to a dict holding ``b'data'`` (one flat
    3072-value row per image) and ``b'labels'`` (one label per image). All
    batches are loaded eagerly into memory when the dataset is constructed.
    """

    def __init__(self, data_files, transform=None):
        """
        Args:
            data_files (list): List of file paths to the data batches.
            transform (callable, optional): Optional transform applied to the
                PIL image of each sample before it is returned.
        """
        self.data = []
        self.labels = []
        self.transform = transform

        # Load data from all files, preserving the order of `data_files`.
        for file in data_files:
            # NOTE(security): pickle.load can execute arbitrary code while
            # unpickling; only point this at batch files you trust.
            with open(file, 'rb') as f:
                batch = pickle.load(f, encoding='bytes')
            self.data.extend(batch[b'data'])
            self.labels.extend(batch[b'labels'])

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    @staticmethod
    def _to_hwc(flat):
        """Decode one flat 3072-value row into a (32, 32, 3) uint8 array.

        CIFAR-10 python batches store each image channel-planar: 1024 red
        values, then 1024 green, then 1024 blue, each plane in row-major
        order. The row therefore has to be viewed as (channel, row, col) and
        then moved to (row, col, channel); reshaping straight to (32, 32, 3)
        would interleave the planes and scramble the picture.
        """
        planes = flat.reshape(3, 32, 32)
        # order='C' yields a contiguous copy after the axis permutation.
        return planes.transpose(1, 2, 0).astype('uint8', order='C')

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is a PIL image, or whatever ``transform`` returns for it
        when a transform was supplied.
        """
        image = Image.fromarray(self._to_hwc(self.data[idx]))
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VGGNet(nn.Module):
    """VGG-16 style CNN: thirteen 3x3 conv layers in five pooled stages, then an MLP head.

    The head's first ``Linear`` layer takes 512 features, so the flattened
    output of ``features`` must hold 512 values per sample. A 3x32x32 input
    satisfies this: the five stride-2 poolings shrink 32x32 down to 1x1 with
    512 channels.
    """

    def __init__(self, num_classes=10):
        """Build the network.

        Args:
            num_classes (int): Width of the final linear layer (number of logits).
        """
        super().__init__()

        # Integers are conv output widths (each conv is followed by ReLU);
        # 'M' closes a stage with a 2x2 / stride-2 max-pool.
        layout = (
            64, 64, 'M',
            128, 128, 'M',
            256, 256, 256, 'M',
            512, 512, 512, 'M',
            512, 512, 512, 'M',
        )

        stack = []
        channels_in = 3
        for entry in layout:
            if entry == 'M':
                stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            stack.append(nn.Conv2d(channels_in, entry, kernel_size=3, padding=1))
            stack.append(nn.ReLU(inplace=True))
            channels_in = entry
        self.features = nn.Sequential(*stack)

        width = 512
        head = [
            nn.Linear(width, width),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(width, width),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(width, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Run the conv stack, flatten everything but the batch dim, and return the logits."""
        maps = self.features(x)
        flat = torch.flatten(maps, 1)
        return self.classifier(flat)

def evaluate(model, test_loader, device):
    """Compute the model's top-1 accuracy over ``test_loader``, in percent.

    Args:
        model: Module mapping a batch of inputs to per-class scores of shape
            (batch, num_classes).
        test_loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: ``100 * correct / total``, or ``0.0`` when the loader yields no
        samples (instead of raising ZeroDivisionError).

    Note:
        The model is switched to eval mode and left there; callers that keep
        training must call ``model.train()`` again.
    """
    model.eval()  # disable dropout etc. for a deterministic forward pass
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Call the module itself rather than .forward() so registered hooks run.
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return 0.0
    return 100 * correct / total

def train_model(model, train_loader, criterion, optimizer, num_epochs, device,
                test_loader=None, eval_every=5):
    """Train ``model`` for ``num_epochs`` epochs, printing the mean batch loss per epoch.

    Args:
        model: Module to optimise; it is updated in place.
        train_loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss tensor.
        optimizer: Optimizer bound to ``model``'s parameters.
        num_epochs (int): Number of passes over ``train_loader``.
        device: Device every batch is moved to.
        test_loader (optional): When given, accuracy on it is printed
            periodically via ``evaluate``.
        eval_every (int, optional): Evaluate after every ``eval_every``-th
            epoch. Defaults to 5; a falsy value (0 / None) disables
            evaluation.

    Returns:
        None. Progress is reported with ``print``.
    """
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        # Re-enter training mode every epoch: evaluate() leaves the model in eval mode.
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            # Call the module itself rather than .forward() so registered hooks run.
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            num_batches += 1

        # Counting batches (rather than len(train_loader)) keeps an empty loader
        # from raising ZeroDivisionError; the mean is reported as nan in that case.
        mean_loss = epoch_loss / num_batches if num_batches else float('nan')
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}')

        # Periodic evaluation on the held-out loader.
        if test_loader is not None and eval_every and (epoch + 1) % eval_every == 0:
            print(f"Evaluating at epoch {epoch+1}...")
            accuracy = evaluate(model, test_loader, device)
            print(f'Accuracy at epoch {epoch+1}: {accuracy:.2f}%')

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

# --- Configuration ------------------------------------------------------------
# Directory holding the CIFAR-10 python batch files; change this one line to
# point the notebook at a different location.
DATA_DIR = '/Users/mini4chaos/Developer/cifar-10/data'

test_files = [f'{DATA_DIR}/test_batch']
data_files = [f'{DATA_DIR}/data_batch_{i}' for i in range(1, 6)]

# Per-channel statistics passed to Normalize for both splits.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.247, 0.243, 0.261)

SEED = 42
batch_size = 32
num_epochs = 20

# Training pipeline: random augmentation followed by tensor conversion and normalisation.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no random augmentation, so the reported accuracy is
# deterministic and measured on the unmodified test images.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Seed torch's RNG so weight init, shuffling and augmentation are repeatable.
torch.manual_seed(SEED)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Load the datasets.
train_dataset = CustomDataset(data_files, transform=transform)
test_dataset = CustomDataset(test_files, transform=test_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialise the model, loss function and optimizer.
model = VGGNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model.
train_model(model, train_loader, criterion, optimizer, num_epochs, device, test_loader)

Using device: mps
