In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time

# --- 1. Setup and Data Preparation ---

# Seed torch's global RNG so weight initialization and data shuffling are
# repeatable run to run. (Bit-for-bit GPU determinism would additionally
# require deterministic cuDNN settings.)
SEED = 42
torch.manual_seed(SEED)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Define a Cutout augmentation class
class Cutout:
    """Cutout augmentation: zero one random square patch of a (C, H, W) tensor.

    The patch is centered on a uniformly random pixel and clipped at the image
    border, so patches near an edge are smaller than ``length x length``. The
    patch spans ``length // 2`` pixels on each side of the center (so an odd
    ``length`` yields a side of ``length - 1``).

    Args:
        length: Nominal side length of the square patch, in pixels.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        """Return a copy of ``img`` with the patch set to zero in every channel.

        Args:
            img: Tensor whose dims 1 and 2 are height and width (C, H, W).

        Returns:
            A new tensor with the same shape and dtype as ``img``; the input
            tensor itself is not modified.
        """
        h, w = img.size(1), img.size(2)
        # Use torch's RNG (not Python's `random`) so torch seeding governs the draw.
        cy = torch.randint(h, (1,)).item()
        cx = torch.randint(w, (1,)).item()
        half = self.length // 2

        # Clip the patch to the image with plain ints; no 0-d tensors needed.
        y1, y2 = max(cy - half, 0), min(cy + half, h)
        x1, x2 = max(cx - half, 0), min(cx + half, w)

        # Zero the region on a copy instead of multiplying by a float32 mask,
        # which would promote lower-precision inputs (e.g. float16) to float32.
        out = img.clone()
        out[:, y1:y2, x1:x2] = 0.
        return out

# Define data transformations
# These are standard stats for CIFAR10 (per-channel mean and std used by Normalize).
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

# Cutout runs after Normalize, so the erased patch is 0 in normalized space,
# which corresponds to the cifar10_mean color in the original pixel space.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
    Cutout(length=16)  # Powerful augmentation
])

# No augmentation at evaluation time: tensor conversion + the same normalization only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std)
])

# Hyperparameters
BATCH_SIZE = 1024  # Use a large batch size for speed
NUM_WORKERS = 8  # Adjust based on your system's capabilities

# Shared DataLoader settings.
# - persistent_workers keeps worker processes alive between epochs instead of
#   re-spawning them each time a loader is iterated (only valid when num_workers > 0).
#   NOTE(review): the repeated "_MultiProcessingDataLoaderIter.__del__ ... can only
#   test a child process" messages in the saved output look tied to per-epoch worker
#   teardown; persistent workers are expected to silence them — confirm on rerun.
# - pin_memory only benefits host-to-CUDA copies, so enable it only on CUDA.
loader_kwargs = dict(
    num_workers=NUM_WORKERS,
    pin_memory=device.type == 'cuda',
    persistent_workers=NUM_WORKERS > 0,
)

# Load CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, **loader_kwargs)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, **loader_kwargs)

# --- 2. The Model Architecture (ResNet9) ---

def conv_block(in_channels, out_channels, pool=False):
    """Build a Conv(3x3, padding=1) -> BatchNorm -> ReLU unit as an ``nn.Sequential``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution (and normalized by BatchNorm).
        pool: When True, append a 2x2 max-pool that downsamples height and width by 2.

    Returns:
        ``nn.Sequential`` with 3 modules, or 4 when ``pool`` is True.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    norm = nn.BatchNorm2d(out_channels)
    act = nn.ReLU(inplace=True)
    stages = [conv, norm, act]
    if pool:
        stages += [nn.MaxPool2d(2)]
    return nn.Sequential(*stages)

class ResNet9(nn.Module):
    """ResNet-9 style image classifier built from ``conv_block`` units.

    Layout (channels, spatial size for a 32x32 input):
        conv1 (64, 32x32) -> conv2 (128, pool -> 16x16) -> res1 (+ skip)
        -> conv3 (256, pool -> 8x8) -> conv4 (512, pool -> 4x4) -> res2 (+ skip)
        -> global max pool -> Flatten -> Linear(512 -> num_classes)

    The global max pool is ``AdaptiveMaxPool2d(1)``: on the 4x4 map produced by
    a 32x32 input it is equivalent to ``MaxPool2d(4)``, but it also works for
    other input sizes. It has no parameters, so ``state_dict`` keys are unchanged.

    Args:
        in_channels: Number of input image channels.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()

        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))

        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))

        self.classifier = nn.Sequential(
            nn.AdaptiveMaxPool2d(1),  # global max pool -> (N, 512, 1, 1) for any input size
            nn.Flatten(),
            nn.Linear(512, num_classes)
        )

    def forward(self, xb):
        """Map a batch of images to class logits of shape (N, num_classes)."""
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out  # residual: res1 preserves channels and size
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out  # residual: res2 preserves channels and size
        out = self.classifier(out)
        return out

model = ResNet9().to(device)

# --- 3. Training and Evaluation ---

# Hyperparameters for training
EPOCHS = 30
MAX_LR = 0.01
WEIGHT_DECAY = 1e-4

# Loss, optimizer, and scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=MAX_LR, weight_decay=WEIGHT_DECAY)
# The scheduler is stepped once per batch in the training loop, so it needs the
# exact number of batches per epoch.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=MAX_LR, epochs=EPOCHS, steps_per_epoch=len(train_loader)
)

# Automatic Mixed Precision (AMP) gradient scaling. Enabled only when training on
# CUDA; on a CPU device the scaler becomes a pass-through instead of warning.
scaler = torch.amp.GradScaler('cuda', enabled=device.type == 'cuda')

def evaluate(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent.

    Switches the model to eval mode and leaves it there; callers that keep
    training must call ``model.train()`` again.

    Args:
        model: Classifier; outputs are assumed to be (batch, num_classes) logits.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device to move batches to. Defaults to the device of the model's
            first parameter (CPU for a parameterless model).

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0
    # inference_mode: like no_grad, but also skips autograd bookkeeping.
    with torch.inference_mode():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluate() received an empty loader")
    return 100 * correct / total

# Training loop
# AMP is used only on CUDA; on a CPU device autocast is disabled (plain fp32).
amp_enabled = device.type == 'cuda'
start_time = time.perf_counter()  # monotonic clock intended for interval timing

for epoch in range(EPOCHS):
    model.train()  # evaluate() leaves the model in eval mode, so re-enable training each epoch
    for images, labels in train_loader:
        # non_blocking lets host->device copies overlap compute when memory is pinned.
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        # Forward pass with AMP (autocast follows the selected device)
        with torch.amp.autocast(device_type=device.type, enabled=amp_enabled):
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Backward pass with AMP: scale the loss, let the scaler step the
        # optimizer (it skips the step on inf/NaN gradients), then update the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # OneCycleLR is a per-batch schedule, so it advances every iteration.
        scheduler.step()

    # Evaluate at the end of the epoch.
    # NOTE: this uses the CIFAR-10 *test* split; there is no separate validation set.
    val_accuracy = evaluate(model, test_loader)
    print(f"Epoch [{epoch+1}/{EPOCHS}], Validation Accuracy: {val_accuracy:.2f}%")

end_time = time.perf_counter()
print(f"Total Training Time: {(end_time - start_time)/60:.2f} minutes")

# Final check on test set
final_accuracy = evaluate(model, test_loader)
print(f"\nFinal Test Accuracy: {final_accuracy:.2f}%")

Using device: cuda




Epoch [1/30], Validation Accuracy: 41.24%




Epoch [2/30], Validation Accuracy: 54.59%




Epoch [3/30], Validation Accuracy: 61.13%




Epoch [4/30], Validation Accuracy: 67.29%




Epoch [5/30], Validation Accuracy: 60.79%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ef7f4b29e40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process


Epoch [6/30], Validation Accuracy: 53.69%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ef7f4b29e40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7ef7f4b29e40>^
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^    ^^self._shutdown_workers()^
^^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

      File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    if w.is_alive():assert self._parent_pid == os.getpid(), 'can only test a child process'

                ^  ^^Exception ignored in: 

Epoch [7/30], Validation Accuracy: 51.69%




Epoch [8/30], Validation Accuracy: 70.54%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ef7f4b29e40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7ef7f4b29e40> 
 Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ef7f4b29e40>^
    ^Traceback (most recent call last):
^self._shutdown_workers()^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

Epoch [9/30], Validation Accuracy: 77.05%




Epoch [10/30], Validation Accuracy: 79.64%




Epoch [11/30], Validation Accuracy: 82.72%




Epoch [12/30], Validation Accuracy: 81.76%
Epoch [13/30], Validation Accuracy: 72.73%
Epoch [14/30], Validation Accuracy: 86.25%
Epoch [15/30], Validation Accuracy: 82.14%
Epoch [16/30], Validation Accuracy: 87.00%
Epoch [17/30], Validation Accuracy: 85.25%
Epoch [18/30], Validation Accuracy: 86.15%
Epoch [19/30], Validation Accuracy: 89.26%
Epoch [20/30], Validation Accuracy: 88.66%
Epoch [21/30], Validation Accuracy: 89.34%
Epoch [22/30], Validation Accuracy: 89.40%
Epoch [23/30], Validation Accuracy: 91.27%
Epoch [24/30], Validation Accuracy: 91.54%
Epoch [25/30], Validation Accuracy: 92.02%
Epoch [26/30], Validation Accuracy: 91.38%
Epoch [27/30], Validation Accuracy: 92.24%
Epoch [28/30], Validation Accuracy: 91.99%
Epoch [29/30], Validation Accuracy: 92.21%
Epoch [30/30], Validation Accuracy: 92.25%
Total Training Time: 16.15 minutes

Final Test Accuracy: 92.25%
