In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchinfo import summary
import matplotlib.pyplot as plt

In [None]:
# Number of samples per mini-batch, shared by both loaders below.
batch_size = 16

# CIFAR-10 train / test splits, downloaded into ./data on first use and
# converted to tensors by ToTensor.
train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Training: reshuffle every epoch and drop the final short batch so that every
# training batch holds exactly `batch_size` samples.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
# Evaluation: every test sample must be scored exactly once, so nothing is
# dropped; the order is irrelevant for accuracy, so no shuffling either.
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=False)

In [None]:
# Preview the first nine training images in a 3x3 grid, titled with their label index.
fig = plt.figure()
for i in range(9):
    # Fetch the sample once instead of indexing the dataset twice per iteration.
    image, label = train_data[i]
    ax = fig.add_subplot(3, 3, i+1)
    # permute moves the channel axis last: (C, H, W) -> (H, W, C), as imshow expects
    ax.imshow(image.permute(1, 2, 0))
    ax.set_title(label)

# plt.show() works with every backend; Figure.show() is meant for GUI backends
# and only emits a warning under the notebook's inline backend.
plt.show()

# Quick sanity check shown as the cell output: shape and value range of one
# image, plus the labels of one training batch.
train_data[0][0].shape, train_data[0][0].max(), train_data[0][0].min(), next(iter(train_loader))[1]

In [None]:
class ResBlock(nn.Module):
    """Residual block computing ``relu(layer(x) + shortcut(x))``.

    ``layer`` is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm, where
    only the first convolution applies ``stride``. ``shortcut`` is an empty
    ``nn.Sequential`` (identity) unless the channel count changes or
    ``stride`` is greater than 1; in that case it is a strided conv3x3
    followed by BatchNorm so that both branches produce the same shape.

    Args:
        in_channel: number of channels of the input feature map.
        out_channel: number of channels produced by the block.
        stride: stride of the first convolution (and of the shortcut
            convolution, when one is created).
    """

    def __init__(self, in_channel, out_channel, stride=1) -> None:
        super().__init__()

        # Main branch; padding=1 with a 3x3 kernel keeps the spatial size
        # unchanged apart from the effect of `stride`.
        main_branch = [
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
        ]
        self.layer = nn.Sequential(*main_branch)

        # Skip branch: stays empty (identity) when input and output shapes
        # already agree, otherwise projects the input to the output shape.
        skip_branch = []
        needs_projection = stride > 1 or in_channel != out_channel
        if needs_projection:
            skip_branch = [
                nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_channel),
            ]
        self.shortcut = nn.Sequential(*skip_branch)

    def forward(self, x):
        """Return the ReLU of the sum of the main branch and the shortcut."""
        residual = self.layer(x)
        skip = self.shortcut(x)
        return F.relu(residual + skip)


class ResNet(nn.Module):
    """Small ResNet-style classifier for 3-channel images with 10 output classes.

    Architecture: a 3 -> 32 channel conv/BatchNorm/ReLU stem, four stages of two
    blocks each (64, 128, 256, 512 channels; the first block of every stage uses
    stride 2), global average pooling, and a 512 -> 10 linear layer.

    Args:
        ResBlock: block class used to build the stages. It is called as
            ``ResBlock(in_channel, out_channel, stride)`` and must return an
            ``nn.Module``. (The parameter name shadows the module-level class of
            the same name; it is kept for backward compatibility.)
    """

    def __init__(self, ResBlock) -> None:
        super(ResNet, self).__init__()
        # Channel count fed to the next block; make_layer updates it as stages are built.
        self.in_channel = 32
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

        # make_layer arguments are (block, out_channel, stride, num_block).
        self.layer1 = self.make_layer(ResBlock, 64, 2, 2)
        self.layer2 = self.make_layer(ResBlock, 128, 2, 2)
        self.layer3 = self.make_layer(ResBlock, 256, 2, 2)
        self.layer4 = self.make_layer(ResBlock, 512, 2, 2)

        self.fc = nn.Linear(512, 10)

    def make_layer(self, block, out_channel, stride, num_block):
        """Build one stage of ``num_block`` blocks.

        Only the first block uses ``stride``; the remaining blocks use stride 1.
        ``self.in_channel`` is left equal to ``out_channel`` so the next stage
        connects to this one.

        Note the argument order: ``stride`` comes before ``num_block``.
        """
        layers = []
        for i in range(num_block):
            in_stride = stride if i == 0 else 1
            layers.append(block(self.in_channel, out_channel, in_stride))
            self.in_channel = out_channel
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)`` for a batch of images."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling down to 1x1. For 32x32 inputs the feature map
        # here is 2x2, so this equals the former avg_pool2d(out, 2); unlike a
        # fixed 2x2 window it also yields 512 features for other input sizes.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def resnet():
    """Create a new ``ResNet`` that uses ``ResBlock`` as its building block."""
    network = ResNet(ResBlock)
    return network


In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the network and move its parameters to the chosen device.
model = resnet().to(device)

# Layer-by-layer overview for one batch of 3x32x32 images (shown as the cell output).
summary(model, device=device, input_size=(batch_size, 3, 32, 32))

In [None]:
# loss
criterion = nn.CrossEntropyLoss()

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# tensorboard
writer = SummaryWriter('runs/cifar10')

# training
epochs = 10
log_every = 10  # write training statistics once per this many batches
step = 0        # number of log points written so far (TensorBoard x-axis)
for epoch in range(epochs):
    # Training mode: BatchNorm normalises with batch statistics and updates its
    # running averages. Must be re-enabled each epoch because evaluation below
    # switches the model to eval mode.
    model.train()

    # Statistics accumulated over the current logging window.
    running_loss = 0.0
    running_correct = 0
    running_total = 0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # statistics
        _, preds = torch.max(outputs, 1)
        running_loss += loss.item()
        running_correct += (preds == labels).sum().item()
        running_total += labels.size(0)

        # The parentheses matter: `i+1 % 10` parses as `i + (1 % 10)`, which is
        # never 0, so nothing would ever be logged.
        if (i + 1) % log_every == 0:
            step += 1
            train_loss = running_loss / log_every
            train_acc = running_correct / running_total
            writer.add_scalar('training loss', train_loss, step)
            writer.add_scalar('accuracy', train_acc, step)
            print(f'Epoch {epoch + 1} / {epochs}, Training loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}')
            # Start a fresh window.
            running_loss = 0.0
            running_correct = 0
            running_total = 0

    # test accuracy
    # Eval mode: BatchNorm uses its running averages and no longer updates them,
    # so test batches neither depend on each other nor leak into the model.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Report the number of images actually evaluated rather than a hard-coded count.
    print(f'epoch: {epoch+1}, Accuracy of the network on the {total} test images: {100 * correct / total:.4f}')

# Flush pending events to disk and release the event file.
writer.close()