In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from tqdm import tqdm

In [None]:
class BasicBlock(nn.Module):
    """Residual block: conv3x3-BN-ReLU-conv3x3-BN, add the shortcut, then ReLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both 3x3 convolutions.
        stride: stride of the first convolution (values > 1 shrink the spatial size).
        downsample: optional module applied to the block input to build the shortcut.
            It must be supplied whenever ``stride != 1`` or ``in_channels != out_channels``;
            otherwise the residual addition fails on a shape mismatch.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Convolutions are bias-free because each is immediately followed by BatchNorm,
        # which supplies its own learned shift.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Main path; the in-place ReLU only touches intermediate tensors, never ``x``.
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # Shortcut path: identity unless a projection was provided.
        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)

In [None]:
class ResNet20(nn.Module):
    """ResNet-20: a 3x3 stem followed by three stages of three ``BasicBlock``s.

    Stage widths are 16, 32 and 64 channels. The first block of stages 2 and 3 uses
    stride 2 together with a 1x1 projection shortcut. Features are globally
    average-pooled and mapped to ``num_classes`` logits by a linear layer.

    Args:
        num_classes: size of the output logit vector.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Stem: 3 input channels -> 16, spatial size preserved.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1: 16 channels, no resolution change, identity shortcuts.
        self.block1_1 = BasicBlock(16, 16)
        self.block1_2 = BasicBlock(16, 16)
        self.block1_3 = BasicBlock(16, 16)

        # Stage 2: 16 -> 32 channels; the first block halves the spatial size.
        self.block2_1 = BasicBlock(16, 32, stride=2, downsample=self._projection(16, 32))
        self.block2_2 = BasicBlock(32, 32)
        self.block2_3 = BasicBlock(32, 32)

        # Stage 3: 32 -> 64 channels; the first block halves the spatial size.
        self.block3_1 = BasicBlock(32, 64, stride=2, downsample=self._projection(32, 64))
        self.block3_2 = BasicBlock(64, 64)
        self.block3_3 = BasicBlock(64, 64)

        # Head: global average pool, then a linear classifier.
        self.avgpooling = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    @staticmethod
    def _projection(in_channels, out_channels):
        """1x1 stride-2 conv + BatchNorm shortcut that matches a downsampling block's output."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        features = self.relu(self.bn1(self.conv1(x)))

        residual_blocks = (
            self.block1_1, self.block1_2, self.block1_3,
            self.block2_1, self.block2_2, self.block2_3,
            self.block3_1, self.block3_2, self.block3_3,
        )
        for block in residual_blocks:
            features = block(features)

        pooled = self.avgpooling(features).flatten(1)
        return self.fc(pooled)

In [None]:
def train(model, train_loader, criterion, optimizer, scheduler, device, epochs=10):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    The scheduler is stepped once per epoch. A progress bar shows the running mean
    loss and accuracy, and a summary line is printed after every epoch.

    Args:
        model: network to optimize; put into training mode at the start of every epoch.
        train_loader: iterable of ``(images, labels)`` batches; must support ``len()``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler; ``step()`` is called after each epoch.
        device: device the batches are moved to.
        epochs: number of epochs to run.
    """
    for epoch in range(epochs):
        # Re-enter training mode every epoch so an evaluate() call (which switches
        # to eval mode) between epochs cannot leave BatchNorm frozen.
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        # LR actually used during this epoch; read before the scheduler advances it,
        # otherwise the summary would show the next epoch's LR.
        epoch_lr = scheduler.get_last_lr()[0]

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            num_batches += 1
            running_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Count batches ourselves: tqdm only syncs ``pbar.n`` when it redraws,
            # so dividing by ``pbar.n + 1`` uses a stale, too-small denominator.
            current_acc = 100 * correct / total
            pbar.set_postfix({'loss': running_loss / num_batches, 'acc': f'{current_acc:.2f}%'})

        scheduler.step()
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        avg_loss = running_loss / max(num_batches, 1)
        epoch_acc = 100 * correct / max(total, 1)
        print(f"Epoch [{epoch+1}/{epochs}], Average Loss: {avg_loss:.4f}, Accuracy: {epoch_acc:.2f}%, Current LR: {epoch_lr}")

def evaluate(model, test_loader, device):
    """Measure top-1 accuracy of ``model`` on ``test_loader``.

    Switches the model to eval mode (and leaves it there), runs without gradient
    tracking, prints the result and returns it.

    Returns:
        The percentage of correctly classified samples.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(test_loader, desc="Evaluating"):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            predictions = model(batch_images).argmax(dim=1)
            n_seen += batch_labels.size(0)
            n_correct += predictions.eq(batch_labels).sum().item()

    accuracy = 100 * n_correct / n_seen
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [None]:
# CIFAR-10 per-channel normalization constants shared by the train and test pipelines.
# NOTE(review): this std triple is the widely copied (0.2023, 0.1994, 0.2010); the
# per-channel std of the raw training pixels is reportedly closer to
# (0.2470, 0.2435, 0.2616). Kept as-is so results stay comparable with earlier runs.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop + horizontal flip augmentation, then normalize.
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: deterministic (no random augmentation), same normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])


class Args:
    """Hyper-parameters and run settings for this notebook."""
    epochs = 30
    batch_size = 32
    lr = 0.1
    momentum = 0.9
    weight_decay = 1e-4
    num_workers = 2
    device = "cuda"  # "cuda", "mps" or "cpu"; falls back to CPU if unavailable
    use_compile = False
    save_path = "resnet20_cifar10.pth"
    seed = 0  # seed for torch's global RNG, set below

args = Args()

# Seed before building loaders and the model so runs are repeatable
# (GPU kernels may still introduce some nondeterminism).
torch.manual_seed(args.seed)

train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
# The test split must not be randomly augmented: random crops/flips at evaluation
# time make the reported accuracy noisy and biased low.
test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)

# DataLoader rejects persistent_workers=True when num_workers == 0, so only
# enable it when worker processes are actually used.
use_persistent_workers = args.num_workers > 0

train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    persistent_workers=use_persistent_workers,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    persistent_workers=use_persistent_workers,
)

# Resolve the requested device, falling back to CPU when it is not available.
device_name = args.device.lower()

if device_name == "cuda" and torch.cuda.is_available():
    device = torch.device("cuda")
elif device_name == "mps" and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

model = ResNet20(num_classes=10).to(device)

if args.use_compile:
    if device.type == "mps":
        print("torch.compile is not supported on MPS — skipping.")
    else:
        model = torch.compile(model)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

# Decay the LR 10x after one third and two thirds of training. This gives
# epochs 10 and 20 for the default 30 epochs and keeps the schedule's shape when
# args.epochs changes. Milestone 0 is dropped: MultiStepLR would apply it
# immediately at construction.
lr_milestones = sorted({m for m in (args.epochs // 3, 2 * args.epochs // 3) if m > 0})
scheduler = MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.1)

print("Starting training...")
train(model, train_loader, criterion, optimizer, scheduler, device, epochs=args.epochs)

print("Starting evaluation...")
test_accuracy = evaluate(model, test_loader, device)

print(f"Saving model to {args.save_path}")
# Save the underlying network's state_dict instead of pickling the whole module.
# torch.compile wraps the model (its state_dict keys gain an "_orig_mod." prefix),
# and torch.load's default weights_only=True (PyTorch >= 2.6) refuses to unpickle
# custom nn.Module classes.
# Reload with: net = ResNet20(); net.load_state_dict(torch.load(args.save_path))
base_model = getattr(model, "_orig_mod", model)
torch.save(base_model.state_dict(), args.save_path)
print("Model saved.")