In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import time

#### Setup Device

In [None]:
# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
# The conditional chain short-circuits, so CUDA is only probed when MPS is absent.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(device)

#### Load and Normalize CIFAR-10 Data

In [None]:
# Number of images per mini-batch, shared by the train and test loaders.
batch_size = 256

# Per-channel (R, G, B) statistics handed to Normalize.
mean = torch.tensor([0.4914, 0.4822, 0.4465])
std = torch.tensor([0.2009, 0.2009, 0.2009])

# Convert images to tensors, then standardise each channel with the values above.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# Both splits are downloaded into ./data on first use and share the same transform.
cifar_train_data = torchvision.datasets.CIFAR10(
    './data', train=True, download=True, transform=transform
)
cifar_test_data = torchvision.datasets.CIFAR10(
    './data', train=False, download=True, transform=transform
)

# Only the training loader shuffles; the test loader keeps the dataset order.
cifar_train_dl = torch.utils.data.DataLoader(
    cifar_train_data,
    batch_size=batch_size,
    shuffle=True,
)
cifar_test_dl = torch.utils.data.DataLoader(
    cifar_test_data,
    batch_size=batch_size,
)

## Define Convolutional Neural Networks

### Low-Tier CNN

In [None]:
class CIFAR_Basic_CNN(nn.Module):
    """Smallest baseline: two 3x3 convolutions followed by a single linear layer.

    The second convolution has stride 2, so the classifier's ``64*16*16`` input
    size corresponds to 3-channel 32x32 images.
    """

    def __init__(self):
        super().__init__()
        conv_in = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        # Stride 2 halves the spatial resolution.
        conv_down = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.features = nn.Sequential(
            conv_in,
            nn.ReLU(),
            conv_down,
            nn.ReLU(),
            nn.Flatten(),
        )
        self.classifier = nn.Linear(64 * 16 * 16, 10)

    def forward(self, x):
        """Return the 10 class logits for the batch ``x``."""
        return self.classifier(self.features(x))

### Mid-Tier CNN 

In [None]:
class CIFAR_Medium_CNN(nn.Module):
    """Four 3x3 conv + ReLU stages followed by a two-layer MLP head.

    The last two convolutions use stride 2; the head's ``256*8*8`` input size
    corresponds to 3-channel 32x32 images.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride) of each convolution, in order.
        conv_specs = [(3, 32, 1), (32, 64, 1), (64, 128, 2), (128, 256, 2)]
        stages = []
        for c_in, c_out, step in conv_specs:
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=1))
            stages.append(nn.ReLU())
        stages.append(nn.Flatten())
        self.features = nn.Sequential(*stages)

        self.classifier = nn.Sequential(
            nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        """Return the 10 class logits for the batch ``x``."""
        flat = self.features(x)
        return self.classifier(flat)

#### Mid-Tier CNN (with Batch Norm)

In [None]:
class CIFAR_Medium_BatchNorm_CNN(nn.Module):
    """Four conv + BatchNorm + ReLU stages followed by a two-layer MLP head.

    Same layer widths and strides as the plain mid-tier network, with a
    ``BatchNorm2d`` inserted between every convolution and its ReLU.
    """

    def __init__(self):
        super().__init__()
        channels = [3, 32, 64, 128, 256]
        strides = [1, 1, 2, 2]

        body = []
        for idx, step in enumerate(strides):
            c_in, c_out = channels[idx], channels[idx + 1]
            body += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        body.append(nn.Flatten())
        self.features = nn.Sequential(*body)

        hidden = nn.Linear(256 * 8 * 8, 256)
        logits = nn.Linear(256, 10)
        self.classifier = nn.Sequential(hidden, nn.ReLU(), logits)

    def forward(self, x):
        """Return the 10 class logits for the batch ``x``."""
        return self.classifier(self.features(x))

#### Mid-Tier CNN (with batch norm + more linear layers)

In [None]:
class CIFAR_Medium_BatchNorm_LinearPlus_CNN(nn.Module):
    """Batch-norm mid-tier feature extractor with a deeper four-layer MLP head.

    The head narrows 256*8*8 -> 512 -> 256 -> 128 -> 10 with a ReLU after every
    linear layer except the last.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride) for each conv + BN + ReLU stage.
        conv_plan = ((3, 32, 1), (32, 64, 1), (64, 128, 2), (128, 256, 2))
        extractor = []
        for c_in, c_out, step in conv_plan:
            extractor.extend((
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ))
        extractor.append(nn.Flatten())
        self.features = nn.Sequential(*extractor)

        # Widths of the fully connected head, from the flattened features to the logits.
        widths = [256 * 8 * 8, 512, 256, 128, 10]
        head = []
        for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
            head.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        head.append(nn.Linear(widths[-2], widths[-1]))  # final layer emits raw logits
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the 10 class logits for the batch ``x``."""
        return self.classifier(self.features(x))

#### Dataset Distillation ConvNet

In [None]:
class ConvNet(nn.Module):
    """GroupNorm / average-pooling ConvNet with a single linear classifier.

    Spatial sizes for a 32x32 input: 36 -> 18 -> 9 -> 4 -> 2, so the extractor
    ends with a 512-channel 2x2 map (2048 values once flattened).
    """

    def __init__(self):
        super().__init__()
        layers = [
            # Stage 1: padding 3 with a 3x3 kernel grows 32x32 to 36x36; pooling gives 18x18.
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=3),
            nn.GroupNorm(4, 32, eps=1e-05, affine=True),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # Stage 2: 18x18 -> 9x9.
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(4, 64, eps=1e-05, affine=True),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # Stage 3: 9x9 -> 4x4 (pooling floors the odd size).
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(4, 128, eps=1e-05, affine=True),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # Stage 4: 4x4 -> 2x2.
            # NOTE(review): these two convolutions run back to back with no
            # normalisation or activation between them, unlike the earlier
            # stages -- confirm this is intended.
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(2, 512, eps=1e-05, affine=True),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        ]
        self.features = nn.Sequential(*layers)
        # 512 channels on a 2x2 map = 2048 inputs.
        self.classifier = nn.Linear(512 * 2 * 2, 10, bias=True)

    def forward(self, x):
        """Return the 10 class logits for the batch ``x``."""
        feature_map = self.features(x)
        flat = feature_map.view(feature_map.shape[0], -1)  # keep the batch axis, flatten the rest
        return self.classifier(flat)

In [None]:
class ConvNet_Simple(nn.Module):
    """Three identical-width (128) conv + GroupNorm + ReLU + AvgPool stages and a linear head.

    Every GroupNorm uses 128 groups for 128 channels, i.e. one channel per group.
    Spatial sizes for a 32x32 input: 36 -> 18 -> 9 -> 4, giving 128*4*4 = 2048
    flattened features.
    """

    def __init__(self):
        super().__init__()
        width = 128
        # (in_channels, padding) per stage; only the first stage uses padding 3.
        stage_specs = [(3, 3), (width, 1), (width, 1)]

        blocks = []
        for c_in, pad in stage_specs:
            blocks += [
                nn.Conv2d(c_in, width, kernel_size=3, stride=1, padding=pad),
                nn.GroupNorm(width, width, eps=1e-05, affine=True),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            ]
        self.features = nn.Sequential(*blocks)
        self.classifier = nn.Linear(2048, 10, bias=True)

    def forward(self, x):
        """Return the 10 class logits for the batch ``x``."""
        pooled = self.features(x)
        return self.classifier(pooled.view(pooled.shape[0], -1))

### High-Tier CNN

In [None]:
class CIFAR_High_CNN(nn.Module):
    """Batch-norm conv extractor with the deepest MLP head in this notebook.

    The head narrows 256*8*8 -> 512 -> 256 -> 128 -> 64 -> 10, with a ReLU
    after every linear layer except the last.
    """

    @staticmethod
    def _conv_stage(c_in, c_out, stride):
        """Build one 3x3 conv -> BatchNorm -> ReLU stage as a list of layers."""
        return [
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
        ]

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            *self._conv_stage(3, 32, 1),
            *self._conv_stage(32, 64, 1),
            *self._conv_stage(64, 128, 2),   # stride 2 halves the resolution
            *self._conv_stage(128, 256, 2),  # and again
            nn.Flatten(),
        )

        hidden_sizes = (512, 256, 128, 64)
        head = []
        fan_in = 256 * 8 * 8
        for fan_out in hidden_sizes:
            head.append(nn.Linear(fan_in, fan_out))
            head.append(nn.ReLU())
            fan_in = fan_out
        head.append(nn.Linear(fan_in, 10))  # final layer emits raw logits
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the 10 class logits for the batch ``x``."""
        return self.classifier(self.features(x))

## Train and test functions

In [None]:
def train(model, num_epochs, lr, train_dl, test_dl, optimizer="SGD", scheduler=False):
    """Train ``model`` with cross-entropy loss, printing test accuracy after every epoch.

    Args:
        model: ``nn.Module`` to optimise; it is moved to the notebook-level ``device``.
        num_epochs: Number of full passes over ``train_dl``.
        lr: Initial learning rate handed to the optimizer.
        train_dl: Iterable yielding ``(inputs, labels)`` training batches.
        test_dl: Iterable yielding ``(inputs, labels)`` batches, forwarded to ``test``.
        optimizer: ``"SGD"`` selects SGD with momentum 0.9; any other value selects Adam.
        scheduler: When truthy, the learning rate follows a cosine annealing
            schedule spanning ``num_epochs`` epochs, stepped once per epoch.

    Returns:
        None. Per-epoch timings/accuracies and a final summary are printed.
    """
    model.to(device)

    if optimizer == "SGD":
        opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    else:
        opt = torch.optim.Adam(model.parameters(), lr=lr)

    # CosineAnnealingLR has no default for T_max, so it must be given explicitly.
    # Anneal over the whole run; max(1, ...) keeps the schedule valid for 0 epochs.
    lr_scheduler = None
    if scheduler:
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, num_epochs))

    # Classification tasks should use cross entropy
    loss_function = nn.CrossEntropyLoss()

    test_accuracy = None
    train_start_time = time.time()
    for epoch in range(1, num_epochs + 1):
        # test() switches the model to eval mode, so re-enable training mode every epoch.
        model.train()
        epoch_start_time = time.time()
        for X, y in train_dl:
            X, y = X.to(device), y.to(device)

            predictions = model(X)
            loss = loss_function(predictions, y)

            opt.zero_grad()
            loss.backward()
            opt.step()

        # Advance the schedule once per epoch, after that epoch's optimizer steps.
        if lr_scheduler is not None:
            lr_scheduler.step()

        test_accuracy = test(model, test_dl)
        epoch_total_time = time.time() - epoch_start_time
        print(f"Epoch {epoch:>03} took {epoch_total_time:.2f}s Test Acc: {test_accuracy:.4f}")

    total_train_time = time.time() - train_start_time
    # The last epoch already evaluated the final weights; only run a separate
    # evaluation when no epoch was executed.
    if test_accuracy is None:
        final_test_accuracy = test(model, test_dl)
    else:
        final_test_accuracy = test_accuracy
    print(f'Training for {num_epochs} complete.')
    print(f'Model achieved final test accuracy of {final_test_accuracy:.4f}')
    print(f'Training took {total_train_time:.2f} seconds.')

def test(model, test_dl):
    """Return the fraction of samples in ``test_dl`` that ``model`` classifies correctly.

    The model is moved to the notebook-level ``device`` and left in eval mode.

    Args:
        model: ``nn.Module`` producing one score per class for each sample.
        test_dl: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        float: number of correct predictions divided by number of samples seen.
    """
    model.to(device)
    model.eval()

    n_correct, n_seen = 0, 0

    # Evaluation only: skip gradient bookkeeping.
    with torch.no_grad():
        for batch_inputs, batch_labels in test_dl:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            # The predicted class is the index of the largest score along dim 1.
            predicted = torch.max(model(batch_inputs), 1).indices

            n_correct += (predicted == batch_labels).sum().item()
            n_seen += predicted.shape[0]

    return n_correct / n_seen

In [None]:
# Low-tier baseline: 20 epochs at lr 0.05 with the default optimizer settings.
train(
    model=CIFAR_Basic_CNN(),
    num_epochs=20,
    lr=0.05,
    train_dl=cifar_train_dl,
    test_dl=cifar_test_dl,
)

In [None]:
# Mid-tier network without normalisation: 20 epochs at lr 0.05.
train(
    model=CIFAR_Medium_CNN(),
    num_epochs=20,
    lr=0.05,
    train_dl=cifar_train_dl,
    test_dl=cifar_test_dl,
)

In [None]:
# Mid-tier network with batch norm: 30 epochs at lr 0.05.
train(
    model=CIFAR_Medium_BatchNorm_CNN(),
    num_epochs=30,
    lr=0.05,
    train_dl=cifar_train_dl,
    test_dl=cifar_test_dl,
)

In [None]:
# Mid-tier network with batch norm and the deeper linear head: 30 epochs at lr 0.05.
train(
    model=CIFAR_Medium_BatchNorm_LinearPlus_CNN(),
    num_epochs=30,
    lr=0.05,
    train_dl=cifar_train_dl,
    test_dl=cifar_test_dl,
)

In [None]:
# High-tier network: 30 epochs at lr 0.05.
train(
    model=CIFAR_High_CNN(),
    num_epochs=30,
    lr=0.05,
    train_dl=cifar_train_dl,
    test_dl=cifar_test_dl,
)

In [None]:
# GroupNorm / AvgPool ConvNet: 30 epochs at lr 0.05.
train(
    model=ConvNet(),
    num_epochs=30,
    lr=0.05,
    train_dl=cifar_train_dl,
    test_dl=cifar_test_dl,
)

In [None]:
# Uniform-width (128) ConvNet variant: 30 epochs at lr 0.05.
train(
    model=ConvNet_Simple(),
    num_epochs=30,
    lr=0.05,
    train_dl=cifar_train_dl,
    test_dl=cifar_test_dl,
)