In [5]:
# Preprocess CIFAR-10 datasets 
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import os

# load CIFAR-10 datasets
def load_batch(file_path):
    """Read one pickled CIFAR-10 batch file and return its images and labels.

    Args:
        file_path: Path to a batch file whose pickle payload is a dict with
            the byte-string keys ``b'data'`` and ``b'labels'``.

    Returns:
        tuple: ``(data, labels)`` where ``data`` is the ``b'data'`` array
        reshaped to ``(N, 3, 32, 32)`` and cast to ``uint8``, and ``labels``
        is the ``b'labels'`` entry returned unchanged.

    Note:
        ``pickle.load`` can execute arbitrary code while deserializing, so
        this must only be pointed at trusted dataset files.
    """
    with open(file_path, 'rb') as batch_file:
        payload = pickle.load(batch_file, encoding='bytes')

    raw_pixels, labels = payload[b'data'], payload[b'labels']

    # Each flat row becomes one channels-first 3x32x32 image (not 32x32x3).
    images = raw_pixels.reshape(-1, 3, 32, 32).astype(np.uint8)
    return images, labels

# combine all of the data
def load_cifar10_data(data_dir):
    """Load the five training batches and the test batch found in ``data_dir``.

    Args:
        data_dir: Directory containing the files ``data_batch_1`` through
            ``data_batch_5`` and ``test_batch``.

    Returns:
        tuple: ``(train_data, train_labels, test_data, test_labels)``. The
        training images from all five batches are concatenated along the
        first axis, and both label collections are returned as NumPy arrays.
    """
    # Read data_batch_1 .. data_batch_5 in order.
    loaded_batches = [
        load_batch(os.path.join(data_dir, f'data_batch_{batch_no}'))
        for batch_no in range(1, 6)
    ]

    # Stack the image arrays and flatten the per-batch label lists.
    train_data = np.concatenate([images for images, _ in loaded_batches])
    train_labels = np.array(
        [label for _, batch_labels in loaded_batches for label in batch_labels]
    )

    # The held-out split lives in a single file.
    test_path = os.path.join(data_dir, 'test_batch')
    test_data, raw_test_labels = load_batch(test_path)
    test_labels = np.array(raw_test_labels)

    return train_data, train_labels, test_data, test_labels

class CIFAR10Dataset(Dataset):
    """Map-style dataset over in-memory CIFAR-10 image and label arrays.

    Each item is an ``(image, label)`` pair. The stored image is viewed as
    ``(3, 32, 32)`` and reordered to height x width x channel before the
    optional ``transform`` is applied.
    """

    def __init__(self, data, labels, transform=None):
        """Store the arrays and the optional per-image transform.

        Args:
            data: Indexable collection of images; each entry must be
                reshapeable to ``(3, 32, 32)``.
            labels: Indexable collection of labels aligned with ``data``.
            transform: Optional callable applied to each 32x32x3 image.
        """
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of stored images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at position ``idx``."""
        # Channels-first (3, 32, 32) -> channels-last (32, 32, 3).
        hwc_image = self.data[idx].reshape(3, 32, 32).transpose((1, 2, 0))
        target = self.labels[idx]

        if not self.transform:
            return hwc_image, target
        return self.transform(hwc_image), target


# Preprocessing and data augmentation.
# Per-channel mean / std handed to Normalize; shared by both pipelines so the
# train and test splits are normalized identically.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random flip + padded random crop, then tensor + normalize.
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: identical conversion and normalization but WITHOUT the
# random augmentations, so test metrics are deterministic and measured on the
# unmodified images.
test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Kept so any code that still refers to `transform` gets the training pipeline.
transform = train_transform

# Dataset location; set the CIFAR10_DATA_DIR environment variable to override
# the machine-specific default path.
data_dir = os.environ.get(
    "CIFAR10_DATA_DIR",
    "D:/Study/7318/a2/Realm-of-Deep-Learning/Realm-of-Deep-Learning/datasets/cifar-10-batches-py",
)

train_data, train_labels, test_data, test_labels = load_cifar10_data(data_dir)

train_dataset = CIFAR10Dataset(train_data, train_labels, transform=train_transform)
test_dataset = CIFAR10Dataset(test_data, test_labels, transform=test_transform)

# Shuffle only the training split; keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


In [9]:
# Train a ResNet-18 model on the CIFAR-10 dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Randomly initialised ResNet-18. Calling the constructor without arguments
# gives untrained weights on both old and new torchvision releases and avoids
# the deprecated `pretrained=` keyword.
model = models.resnet18()
# Replace the final fully connected layer so it emits 10 scores (one per
# CIFAR-10 class).
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# train model
def train_model(model, train_loader, criterion, optimizer, num_epochs=10, device=None):
    """Train ``model`` for ``num_epochs`` passes over ``train_loader``.

    After every epoch one line is printed with the sample-weighted mean loss
    and the training accuracy for that epoch.

    Args:
        model: Module mapping a batch of inputs to per-class scores of shape
            ``(batch, num_classes)``.
        train_loader: Iterable yielding ``(images, labels)`` tensor pairs.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        num_epochs: Number of full passes over ``train_loader``.
        device: Device each batch is moved to. When ``None`` (default) the
            device of the model's first parameter is used (CPU if the model
            has no parameters), so the function does not depend on a
            module-level ``device`` variable.

    Returns:
        None. The model is updated in place.

    Raises:
        ValueError: If an epoch yields no samples, since the epoch metrics
            would otherwise divide by zero.
    """
    if device is None:
        # Follow the model: batches must live where the weights live.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device).long()

            # Forward pass.
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch figure is a
            # per-sample mean (assumes the criterion averages over the batch).
            running_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        if total == 0:
            raise ValueError("train_loader yielded no samples; cannot compute epoch metrics")

        epoch_loss = running_loss / total
        accuracy = correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}")

# Run the 30-epoch training schedule on the training loader.
train_model(model, train_loader, criterion, optimizer, num_epochs=30)

# Persist only the learned parameters (the state dict), not the whole module.
model_save_path = "resnet18_cifar10_epoch30_lr001.pth"
trained_weights = model.state_dict()
torch.save(trained_weights, model_save_path)
print("model saved")

Epoch [1/30], Loss: 1.8084, Accuracy: 0.3413
Epoch [2/30], Loss: 1.3975, Accuracy: 0.4873
Epoch [3/30], Loss: 1.1830, Accuracy: 0.5760
Epoch [4/30], Loss: 1.0412, Accuracy: 0.6315
Epoch [5/30], Loss: 0.9576, Accuracy: 0.6632
Epoch [6/30], Loss: 0.8747, Accuracy: 0.6948
Epoch [7/30], Loss: 0.8163, Accuracy: 0.7166
Epoch [8/30], Loss: 0.7782, Accuracy: 0.7306
Epoch [9/30], Loss: 0.7287, Accuracy: 0.7474
Epoch [10/30], Loss: 0.6961, Accuracy: 0.7605
Epoch [11/30], Loss: 0.6796, Accuracy: 0.7657
Epoch [12/30], Loss: 0.6461, Accuracy: 0.7787
Epoch [13/30], Loss: 0.6183, Accuracy: 0.7863
Epoch [14/30], Loss: 0.6045, Accuracy: 0.7937
Epoch [15/30], Loss: 0.5851, Accuracy: 0.7972
Epoch [16/30], Loss: 0.5674, Accuracy: 0.8046
Epoch [17/30], Loss: 0.5455, Accuracy: 0.8114
Epoch [18/30], Loss: 0.5298, Accuracy: 0.8181
Epoch [19/30], Loss: 0.5241, Accuracy: 0.8212
Epoch [20/30], Loss: 0.5023, Accuracy: 0.8274
Epoch [21/30], Loss: 0.4939, Accuracy: 0.8316
Epoch [22/30], Loss: 0.4876, Accuracy: 0.83