In [12]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [13]:
class DonutConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation)

        self.pad_size = ((kernel_size - 1) * dilation) // 2

    def forward(self, x):
        x = torch.cat([x, x[:, :, :, :self.pad_size]], dim=3)
        x = torch.cat([x[:, :, :, -self.pad_size:], x], dim=3)
        x = torch.cat([x, x[:, :, :self.pad_size, :]], dim=2)
        x = torch.cat([x[:, :, -self.pad_size:, :], x], dim=2)
        return self.conv(x)

class DonutCNN(nn.Module):
    def __init__(self, img_size=28, num_classes=10):
        super().__init__()

        self.img_size = img_size
        self.num_classes = num_classes

        self.conv1 = DonutConv2d(1, 6, kernel_size=5) 
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = DonutConv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * ((img_size // 4) ** 2), 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * ((self.img_size // 4) ** 2))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
class CNN(nn.Module):
    def __init__(self, img_size=28, num_classes=10):
        super().__init__()

        self.img_size = img_size
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding='same') 
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, padding='same')
        self.fc1 = nn.Linear(16 * ((img_size // 4) ** 2), 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * ((self.img_size // 4) ** 2))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [14]:
def train(net, trainloader, optimizer, criterion, epochs=5):
    net.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 200 == 199:
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 200:.3f}')
                running_loss = 0.0
    print('Finished Training')

def evaluate(net, testloader):
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

In [15]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
testset = torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)

net = DonutCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

train(net, trainloader, optimizer, criterion)
evaluate(net, testloader)

[1,   200] loss: 2.297
[1,   400] loss: 2.277
[1,   600] loss: 2.187
[1,   800] loss: 1.310
[1,  1000] loss: 0.607
[1,  1200] loss: 0.444
[1,  1400] loss: 0.326
[1,  1600] loss: 0.298
[1,  1800] loss: 0.259
[2,   200] loss: 0.226
[2,   400] loss: 0.216
[2,   600] loss: 0.197
[2,   800] loss: 0.183
[2,  1000] loss: 0.170
[2,  1200] loss: 0.150
[2,  1400] loss: 0.141
[2,  1600] loss: 0.137
[2,  1800] loss: 0.136
[3,   200] loss: 0.122
[3,   400] loss: 0.119
[3,   600] loss: 0.101
[3,   800] loss: 0.103
[3,  1000] loss: 0.102
[3,  1200] loss: 0.099
[3,  1400] loss: 0.108
[3,  1600] loss: 0.089
[3,  1800] loss: 0.085
[4,   200] loss: 0.079
[4,   400] loss: 0.075
[4,   600] loss: 0.089
[4,   800] loss: 0.073
[4,  1000] loss: 0.079
[4,  1200] loss: 0.075
[4,  1400] loss: 0.065
[4,  1600] loss: 0.073
[4,  1800] loss: 0.081
[5,   200] loss: 0.059
[5,   400] loss: 0.067
[5,   600] loss: 0.057
[5,   800] loss: 0.064
[5,  1000] loss: 0.071
[5,  1200] loss: 0.064
[5,  1400] loss: 0.061
[5,  1600] 