In [1]:
from matplotlib import pyplot as plt
import numpy as np
import collections

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms

import datetime
import random

torch.set_printoptions(edgeitems=2)
torch.manual_seed(123)

<torch._C.Generator at 0x7f50701c4e90>

In [127]:
class_names = ['airplane','automobile','bird','cat','deer',
               'dog','frog','horse','ship','truck']

data_path = '../data-unversioned/p1ch6/'
cifar10 = datasets.CIFAR10(
    data_path, train=True, download=True,
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4915, 0.4823, 0.4468),
                             (0.2470, 0.2435, 0.2616))
    ]))

cifar10_val = datasets.CIFAR10(
    data_path, train=False, download=True,
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4915, 0.4823, 0.4468),
                             (0.2470, 0.2435, 0.2616))
    ]))

Files already downloaded and verified
Files already downloaded and verified


In [128]:
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']
cifar2 = [(img, label_map[label])
          for img, label in cifar10
          if label in [0, 2]]
cifar2_val = [(img, label_map[label])
              for img, label in cifar10_val
              if label in [0, 2]]


In [134]:
class ResBlock(nn.Module):
    def __init__(self, n_chans, widen_factor):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(n_chans * widen_factor, n_chans * widen_factor, kernel_size=3,
                               padding=1, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(num_features=n_chans * widen_factor)
        self.conv2 = nn.Conv2d(n_chans * widen_factor, n_chans * widen_factor, kernel_size=3,
                               padding=1, bias=False)
        self.batch_norm2 = nn.BatchNorm2d(num_features=n_chans * widen_factor)
        self.relu = nn.ReLU(inplace=True)

        torch.nn.init.kaiming_normal_(self.conv1.weight,
                                      nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.conv2.weight,
                                      nonlinearity='relu')
        torch.nn.init.constant_(self.batch_norm1.weight, 0.5)
        torch.nn.init.constant_(self.batch_norm2.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm1.bias)
        torch.nn.init.zeros_(self.batch_norm2.bias)

    def forward(self, x):
        out = self.batch_norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.batch_norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out += x
        return out


class NetResDeep(nn.Module):
    def __init__(self, width, n_blocks):
        super(NetResDeep, self).__init__()
        self.width = width
        self.conv1 = nn.Conv2d(3, 16 * width, kernel_size=3, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(num_features=16 * width)
        resblocks = []
        for _ in range(n_blocks):
            resblocks.append(ResBlock(16, widen_factor=width))
        self.resblocks = nn.Sequential(*resblocks)
        self.fc1 = nn.Linear(8 * 8 * 16 * width, 500)
        self.fc2 = nn.Linear(500, 2)

        torch.nn.init.kaiming_normal_(self.conv1.weight,
                                      nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.fc1.weight,
                                      nonlinearity='relu')
        torch.nn.init.constant_(self.batch_norm1.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm1.bias)

    def forward(self, x):
        out = self.conv1(x)  # Eingabe durch Conv1 leiten
        out = self.batch_norm1(out)  # Batch-Normalization auf Conv1-Ausgabe anwenden
        out = F.relu(out)  # ReLU-Aktivierung auf Conv1-Ausgabe anwenden
        out = F.avg_pool2d(out, 2)  # Average-Pooling auf Conv1-Ausgabe mit Fenstergröße 2x2 anwenden
        out = self.resblocks(out)  # Ausgabe von Max-Pooling durch die Sequenz der Residual-Blöcke leiten
        # out = self.batch_norm1(out)  # Batch-Normalization auf Conv1-Ausgabe anwenden
        # out = F.relu(out)  # ReLU-Aktivierung auf Conv1-Ausgabe anwenden
        out = F.avg_pool2d(out, 2)
        out = out.view(-1, 8 * 8 * 16 * self.width)  # Ausgabe der letzten Max-Pooling-Schicht in einen Vektor umformen
        out = F.relu(self.fc1(out))  # Ausgabe der linearen Schicht fc1 durch ReLU-Aktivierung leiten
        out = self.fc2(out)  # Ausgabe der linearen Schicht fc2 erhalten
        return out


In [130]:
# define the training loop
def training_loop(n_epochs, optimizer, model, loss_fn, train_loader):
    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0

        for imgs, labels in train_loader:
            imgs = imgs.to(device=device)  # <1>
            labels = labels.to(device=device)
            outputs = model(imgs)
            outputs = outputs[:len(labels)]
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_train += loss.item()

        if epoch == 1 or epoch % 10 == 0:
            print('{} Epoch {}, Training loss {}'.format(
                datetime.datetime.now(), epoch,
                loss_train / len(train_loader)))

In [131]:
# validates the accuracy of a given model on training and validation data
# and stores the results in a dictionary
def validate(model, train_loader, val_loader):
    accdict = {}
    for name, loader in [("train", train_loader), ("val", val_loader)]:
        correct = 0
        total = 0

        with torch.no_grad():
            for imgs, labels in loader:
                imgs = imgs.to(device=device)  # <1>
                labels = labels.to(device=device)
                outputs = model(imgs)
                outputs = outputs[:len(labels)]
                _, predicted = torch.max(outputs, dim=1) # <1>
                total += labels.shape[0]
                correct += int((predicted == labels).sum())

        print("Accuracy {}: {:.2f}".format(name , correct / total))
        accdict[name] = correct / total
    return accdict

In [132]:
device = torch.device('cuda')

In [135]:
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)
all_acc_dict = collections.OrderedDict()

model = NetResDeep(width=2, n_blocks=14).to(device=device)
optimizer = optim.SGD(model.parameters(), lr=3e-3)
loss_fn = nn.CrossEntropyLoss()

training_loop(
    n_epochs = 20,
    optimizer = optimizer,
    model = model,
    loss_fn = loss_fn,
    train_loader = train_loader,
)
all_acc_dict["res deep"] = validate(model, train_loader, val_loader)
print("Number of free parameters in the model:",sum(p.numel() for p in model.parameters()))


2023-06-14 14:58:30.118135 Epoch 1, Training loss 0.5484911881055042
2023-06-14 14:58:54.086111 Epoch 10, Training loss 0.28675311167908324
2023-06-14 14:59:20.708086 Epoch 20, Training loss 0.18180611952664746
Accuracy train: 0.93
Accuracy val: 0.85
Number of free parameters in the model: 1286302
