In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader

In [None]:
# Per-channel (mean, std) for Normalize; with ToTensor's [0, 1] output, 0.5/0.5 maps pixels to [-1, 1].
data_statistics = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training-time augmentation: random 32x32 crops of a 4-px reflect-padded image plus horizontal flips.
train_transforms_cifar = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(*data_statistics, inplace=True)
])

# Evaluation transforms: no augmentation, same normalization.
test_transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*data_statistics, inplace=True)
])

# The training pool (later split into train/val) must come from the CIFAR-10 *training* split;
# the test split is held out for the final evaluation only (previously both used train=False).
dataset = torchvision.datasets.CIFAR10(root="data/", download=True, train=True, transform=train_transforms_cifar)
test_dataset = torchvision.datasets.CIFAR10(root="data/", download=True, train=False, transform=test_transforms_cifar)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data/
Files already downloaded and verified


In [None]:
val_ratio = 0.2
# Fixed seed so the train/val partition is identical across kernel restarts.
split_generator = torch.Generator().manual_seed(42)
# Derive the train size from the val size so the two always sum to len(dataset),
# regardless of how float rounding treats val_ratio * len(dataset).
val_size = int(val_ratio * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_split = random_split(dataset, [train_size, val_size], generator=split_generator)

# `dataset` applies random crop/flip augmentation. Validation should see un-augmented images,
# so re-index the same samples through a copy of the same split that uses the test-time transforms.
val_source = torchvision.datasets.CIFAR10(root=dataset.root, train=dataset.train, download=False,
                                          transform=test_transforms_cifar)
val_dataset = torch.utils.data.Subset(val_source, val_split.indices)

batch_size = 32
train_dl = DataLoader(train_dataset, batch_size, shuffle=True, pin_memory=True)
val_dl = DataLoader(val_dataset, batch_size, pin_memory=True)
test_dl = DataLoader(test_dataset, batch_size, pin_memory=True)

In [None]:
def get_default_device():
    """Return the CUDA device when a GPU is usable, otherwise the CPU device."""
    # `is_available` must be *called*: the bare function object is always truthy,
    # which made this return "cuda" even on CPU-only machines.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def to_device(entity, device):
    """Move a tensor/module, or a (possibly nested) list/tuple of them, to ``device``.

    Sequences are returned as lists (tuples are not preserved); leaves are moved
    with ``.to(device, non_blocking=True)``.
    """
    if not isinstance(entity, (list, tuple)):
        return entity.to(device, non_blocking=True)
    moved = []
    for item in entity:
        moved.append(to_device(item, device))
    return moved

class DeviceDataLoader:
    """Wrap a data loader so every batch it yields is first moved to ``device``.

    Attributes:
        dl: the wrapped loader (any iterable with ``len``).
        device: target device handed to ``to_device`` for each batch.
    """

    def __init__(self, dataloader, device):
        self.dl, self.device = dataloader, device

    def __len__(self):
        # Number of batches, delegated to the wrapped loader.
        return len(self.dl)

    def __iter__(self):
        yield from (to_device(batch, self.device) for batch in self.dl)
    
device = get_default_device()
# Re-bind each loader to a wrapper that moves its batches onto `device` as they are yielded.
train_dl, val_dl, test_dl = (DeviceDataLoader(loader, device) for loader in (train_dl, val_dl, test_dl))


In [None]:
import torch.nn as nn
from collections import OrderedDict

def conv_block(in_channels, out_channels, pool=False):
    """Conv3x3 -> BatchNorm -> ReLU, optionally followed by 2x2 max-pooling.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        pool: if True, append ``MaxPool2d(2)``, which halves height and width.

    Returns:
        An ``nn.Sequential`` of the layers above. With ``padding=1`` the conv keeps
        the spatial size, so only the optional pool changes it.
    """
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(inplace=True)]
    if pool:
        layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)


class ResnetX(nn.Module):
    """Small residual CNN classifier (conv stem, two residual stages, linear head).

    conv2, conv3 and conv4 each halve height and width; res1 and res2 are added back
    onto their inputs (skip connections) in ``forward``. The head global-max-pools
    the 512-channel map to 1x1, so the network accepts any input size that survives
    three 2x halvings, not only 32x32.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        # Child names here determine state_dict keys, so they are kept for checkpoint compatibility.
        self.res1 = nn.Sequential(OrderedDict([("conv1res1", conv_block(128, 128)),
                                               ("conv2res1", conv_block(128, 128))]))

        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))

        # AdaptiveMaxPool2d(1) equals the former MaxPool2d(4) on the 4x4 map produced by
        # 32x32 inputs, but also collapses any other spatial size to 1x1 instead of
        # mismatching Linear(512, ...) or silently ignoring border pixels.
        self.classifier = nn.Sequential(nn.AdaptiveMaxPool2d(1),
                                        nn.Flatten(),
                                        nn.Dropout(0.2),
                                        nn.Linear(512, num_classes))

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.res1(out) + out  # residual connection
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out  # residual connection
        return self.classifier(out)
        

In [None]:
model = ResnetX(3, 10)
# A bare `model` mid-cell is never displayed (only the last expression is), so print it explicitly.
print(model)
# Parameter shapes of the first residual stage: per conv_block, conv weight, conv bias, BN weight, BN bias.
[param.shape for param in model.res1.parameters()]

[torch.Size([128, 128, 3, 3]),
 torch.Size([128]),
 torch.Size([128]),
 torch.Size([128]),
 torch.Size([128, 128, 3, 3]),
 torch.Size([128]),
 torch.Size([128]),
 torch.Size([128])]

In [None]:
def accuracy(logits, labels):
    """Fraction of rows of ``logits`` whose arg-max class index equals ``labels``.

    Returns the fraction as a 0-d tensor built from a Python float (so it lives on the CPU).
    """
    predicted = logits.max(dim=1).indices
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(logits))

def evaluate(model, dl, loss_func):
    """Compute the sample-weighted mean loss and accuracy of ``model`` over ``dl``.

    Switches the model to eval mode and leaves it there; ``train`` calls
    ``model.train()`` at the start of each epoch.

    Args:
        model: module mapping a batch of inputs to class logits.
        dl: iterable of ``(inputs, labels)`` batches, already on the model's device.
        loss_func: callable ``(logits, labels) -> scalar tensor``. Assumed to average
            over the batch (true for the default ``cross_entropy`` used in this notebook).

    Returns:
        ``(avg_loss, avg_acc)``: ``avg_loss`` is a Python float, ``avg_acc`` a 0-d CPU
        tensor. Both are averaged over samples rather than batches, so a short final
        batch is not over-weighted.

    Raises:
        ValueError: if ``dl`` yields no samples.
    """
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in dl:
            logits = model(images)
            batch_n = len(labels)
            # Undo the per-batch mean so every sample carries equal weight.
            total_loss += loss_func(logits, labels).item() * batch_n
            _, predicted = torch.max(logits, dim=1)
            total_correct += (predicted == labels).sum().item()
            total_seen += batch_n
    if total_seen == 0:
        raise ValueError("evaluate() received a dataloader with no samples")
    return total_loss / total_seen, torch.tensor(total_correct / total_seen)

def train(model, train_dl, val_dl, epochs, max_lr, loss_func, optim):
    """Train ``model`` with a one-cycle LR schedule and validate after every epoch.

    Args:
        model: module to train (updated in place).
        train_dl: sized iterable of ``(inputs, labels)`` training batches.
        val_dl: iterable of validation batches, passed to ``evaluate``.
        epochs: number of passes over ``train_dl``.
        max_lr: peak learning rate of ``OneCycleLR``. It is also passed as the
            optimizer's initial lr, which the scheduler then takes over.
        loss_func: callable ``(logits, labels) -> scalar loss tensor``.
        optim: optimizer class, called as ``optim(model.parameters(), max_lr)``.

    Returns:
        A list with one dict per epoch: ``avg_valid_loss`` and ``avg_valid_acc`` (from
        ``evaluate``), ``avg_train_loss`` (float mean of the batch losses) and ``lr``
        (the learning rate used for each batch of that epoch).
    """
    optimizer = optim(model.parameters(), max_lr)
    # One LR cycle spanning every batch of every epoch; stepped once per batch below.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs * len(train_dl))

    results = []
    for _ in range(epochs):
        model.train()
        train_losses = []
        lrs = []
        for images, labels in train_dl:
            logits = model(images)
            loss = loss_func(logits, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Detach so the stored value does not keep a reference into this batch's autograd graph.
            train_losses.append(loss.detach())
            # Record the lr actually used for this batch, before the scheduler advances it.
            lrs.append(optimizer.param_groups[0]["lr"])
            scheduler.step()
        epoch_train_loss = torch.stack(train_losses).mean().item()

        epoch_avg_loss, epoch_avg_acc = evaluate(model, val_dl, loss_func)
        results.append({"avg_valid_loss": epoch_avg_loss, "avg_valid_acc": epoch_avg_acc,
                        "avg_train_loss": epoch_train_loss, "lr": lrs})

    return results

In [None]:
# Training hyper-parameters.
epochs = 50
max_lr = 1e-2  # peak learning rate of the one-cycle schedule used in train()
optim = torch.optim.Adam
loss_func = nn.functional.cross_entropy

# Put the model on the same device the DeviceDataLoaders move batches to.
model = to_device(model, device)

In [None]:
results = train(model, train_dl, val_dl,
                epochs=epochs, max_lr=max_lr, loss_func=loss_func, optim=optim)

In [None]:
# Validation accuracy recorded after each epoch.
for epoch_result in results:
    print(epoch_result["avg_valid_acc"])

In [None]:
def plot(results, pairs):
    """Draw one subplot per entry of ``pairs`` from the per-epoch ``results``.

    Args:
        results: list of per-epoch dicts as returned by ``train``.
        pairs: list of dicts ``{title: [result_key, ...]}``. Each dict becomes one
            subplot titled ``title``, with one labelled line per key. Keys whose
            per-epoch value is a list (e.g. the per-batch ``"lr"`` record) are
            concatenated across epochs, so that line's x axis counts batches.
    """
    import matplotlib.pyplot as plt
    # squeeze=False keeps `axes` 2-D even when len(pairs) == 1, so indexing always works.
    fig, axes = plt.subplots(len(pairs), 1, figsize=(10, 10), squeeze=False)
    for ax, pair in zip(axes[:, 0], pairs):
        for title, graphs in pair.items():
            # set_title/legend are methods; assigning to them (as before) silently drew nothing.
            ax.set_title(title)
            for graph in graphs:
                values = []
                for result in results:
                    value = result[graph]
                    if isinstance(value, (list, tuple)):
                        values.extend(float(v) for v in value)
                    else:
                        values.append(float(value))  # also unwraps 0-d tensors
                ax.plot(values, '-x', label=graph)
            ax.legend()
    fig.tight_layout()

plot(results, [{"Accuracies vs epochs": ["avg_valid_acc"]}, {"Losses vs epochs":["avg_valid_loss", "avg_train_loss"]}, {"Learning rates vs batches": ["lr"]}])

In [None]:
CHECKPOINT_PATH = "cifar10.pth"

# Test accuracy of the trained model.
_, test_acc = evaluate(model, test_dl, loss_func)
print(test_acc)

torch.save(model.state_dict(), CHECKPOINT_PATH)

# A freshly initialised model, for comparison with the trained weights.
model2 = to_device(ResnetX(3, 10), device)
_, test_acc = evaluate(model2, test_dl, loss_func)
print(test_acc)

# map_location lets a checkpoint written on a GPU be restored on a CPU-only machine (and vice versa).
model2.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
_, test_acc = evaluate(model2, test_dl, loss_func)
print(test_acc)