In [None]:
import os
import pickle
import tarfile
import numpy as np

def extract_cifar10(file_path, extract_path="./cifar-10-batches-py"):
    """Unpack a gzipped CIFAR-10 tarball unless ``extract_path`` already exists.

    The archive is extracted into the *parent* directory of ``extract_path``,
    so that the archive's top-level folder ends up at ``extract_path``.
    NOTE(review): this assumes the top-level folder inside the archive has the
    same name as the last component of ``extract_path`` -- confirm for
    non-standard archives.

    Args:
        file_path: Path to the ``.tar.gz`` archive.
        extract_path: Directory expected to hold the extracted batch files.

    Returns:
        ``extract_path``, unchanged.
    """
    if not os.path.exists(extract_path):
        # normpath drops a trailing separator so dirname yields the real parent;
        # an empty parent means "current directory" (the default-argument case).
        parent_dir = os.path.dirname(os.path.normpath(extract_path)) or "."
        with tarfile.open(file_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # Reject absolute paths / ".." members when the runtime supports filters.
                tar.extractall(path=parent_dir, filter="data")
            else:
                tar.extractall(path=parent_dir)
    return extract_path

def load_batch(batch_file):
    """Read one pickled batch file and return ``(images, labels)``.

    The file must unpickle to a mapping with ``'data'`` and ``'labels'``
    entries. ``'data'`` is reshaped to ``(N, 3, 32, 32)``, cast to ``uint8``
    and then moved to channels-last layout.

    Args:
        batch_file: Path to the pickled batch.

    Returns:
        A tuple ``(images, labels)`` where ``images`` is a ``uint8`` array of
        shape ``(N, 32, 32, 3)`` and ``labels`` is the stored ``'labels'``
        entry as-is.
    """
    # SECURITY: pickle.load can execute arbitrary code; only open trusted files.
    with open(batch_file, 'rb') as handle:
        batch = pickle.load(handle, encoding='latin1')

    labels = batch['labels']
    channels_first = batch['data'].reshape(-1, 3, 32, 32).astype(np.uint8)
    images = channels_first.transpose(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
    return images, labels


In [14]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CIFAR10Dataset(Dataset):
    """Map-style dataset pairing an indexable image store with its labels.

    Args:
        data: Indexable collection of images (kept as passed, not copied).
        labels: Indexable collection of labels; its length defines the
            dataset size.
        transform: Optional callable applied to each image on access.
    """

    def __init__(self, data, labels, transform=None):
        self.transform = transform
        self.labels = labels
        self.data = data  # stored untouched; conversion is left to `transform`

    def __len__(self):
        """Number of samples, taken from the label collection."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, transforming the image if configured."""
        sample, target = self.data[idx], self.labels[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target


def prepare_dataloaders(data_path, batch_size=128):
    """Build the training and test ``DataLoader`` objects from extracted batch files.

    Reads ``data_batch_1`` .. ``data_batch_5`` (training) and ``test_batch``
    (test) from ``data_path`` via ``load_batch``.

    Args:
        data_path: Directory containing the batch files.
        batch_size: Mini-batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    # Gather the five training batches.
    image_chunks = []
    train_labels = []
    for batch_idx in range(1, 6):
        batch_path = os.path.join(data_path, f"data_batch_{batch_idx}")
        images, labels = load_batch(batch_path)
        image_chunks.append(images)
        train_labels.extend(labels)
    train_images = np.concatenate(image_chunks)

    # Single held-out test batch.
    test_images, test_labels = load_batch(os.path.join(data_path, "test_batch"))

    # load_batch yields uint8 channels-last images; ToTensor is what converts
    # them to float channels-first tensors scaled to [0, 1]. No normalization
    # is applied on top of that.
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_loader = DataLoader(
        CIFAR10Dataset(train_images, train_labels, to_tensor),
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        CIFAR10Dataset(test_images, test_labels, to_tensor),
        batch_size=batch_size,
        shuffle=False,
    )
    return train_loader, test_loader


In [15]:
import torch.nn as nn
import torch.optim as optim

class SimpleVGG(nn.Module):
    """Small VGG-style CNN: two conv-conv-pool stages followed by a two-layer MLP.

    The classifier's first linear layer expects ``8 * 8 * 128`` features, i.e.
    a 32x32 input after the two 2x2 max-pools. The final layer emits 10 logits.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 3 -> 64 -> 64 channels, then halve the spatial size.
        # Stage 2: 64 -> 128 -> 128 channels, then halve again.
        conv_layers = [
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self.features = nn.Sequential(*conv_layers)

        head_layers = [
            nn.Flatten(),
            nn.Linear(8 * 8 * 128, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the convolutional stack, then the classifier head; returns logits."""
        return self.classifier(self.features(x))

def train_model(model, train_loader, test_loader, device, epochs=10):
    """Train ``model`` with Adam and cross-entropy, evaluating after every epoch.

    Args:
        model: Network to optimise; it is moved to ``device`` first.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        test_loader: Iterable passed to ``test_model`` after each epoch.
        device: Device on which batches and the model are placed.
        epochs: Number of passes over ``train_loader``.

    Prints the per-epoch loss, which is the *sum* of the batch losses (not the
    mean), followed by whatever ``test_model`` reports.
    """
    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=1e-3)

    for epoch_number in range(1, epochs + 1):
        # test_model() switches to eval mode, so re-enable training mode each epoch.
        model.train()
        running_loss = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            adam.zero_grad()
            batch_loss = loss_fn(model(inputs), labels)
            batch_loss.backward()
            adam.step()

            running_loss += batch_loss.item()

        print(f"Epoch {epoch_number}, Loss: {running_loss:.4f}")
        test_model(model, test_loader, device)

def test_model(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, targets in loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            _, preds = torch.max(outputs, 1)
            correct += (preds == targets).sum().item()
            total += targets.size(0)
    print(f"Test Accuracy: {100 * correct / total:.2f}%")


In [16]:
# Seed torch's global RNG before anything random happens so that weight
# initialisation and the training loader's shuffling repeat across
# "Restart & Run All". GPU kernels may still introduce small nondeterminism.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths: the archive is expected next to the notebook (relative path).
data_dir = extract_cifar10("cifar-10-python.tar.gz")
train_loader, test_loader = prepare_dataloaders(data_dir)

# Model Training (prints loss and test accuracy after every epoch)
model = SimpleVGG()
train_model(model, train_loader, test_loader, device)


Epoch 1, Loss: 595.9320
Test Accuracy: 59.03%
Epoch 2, Loss: 403.8946
Test Accuracy: 64.70%
Epoch 3, Loss: 323.1286
Test Accuracy: 71.90%
Epoch 4, Loss: 260.7529
Test Accuracy: 74.30%
Epoch 5, Loss: 211.1789
Test Accuracy: 74.33%
Epoch 6, Loss: 168.5321
Test Accuracy: 75.02%
Epoch 7, Loss: 126.6453
Test Accuracy: 74.94%
Epoch 8, Loss: 90.5166
Test Accuracy: 75.34%
Epoch 9, Loss: 62.8233
Test Accuracy: 74.08%
Epoch 10, Loss: 50.9830
Test Accuracy: 74.73%
