In [None]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as func
import torch
from math import floor
from torchvision import datasets
import torchvision.transforms as transforms
from torch import optim
import time
import matplotlib
from matplotlib import pyplot as plt

In [None]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [None]:
# Convert images to tensors and standardise them with a fixed mean / std.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set
# mean and standard deviation — confirm if the dataset ever changes.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Both splits are downloaded into ./data on first use and share the transform.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Only the training data needs shuffling; evaluation iterates in a fixed
# order so that repeated evaluations see identical batches.
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

100%|██████████| 9.91M/9.91M [00:00<00:00, 18.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 495kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.70MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 12.0MB/s]


In [None]:
class Convolution(nn.Module):
    """2-D convolution built from ``unfold`` plus one matrix multiplication.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of filters, i.e. channels of the output.
        kernel_shape: kernel size, a single number or a ``(height, width)`` pair.
        padding: zero padding per side, a single number or a ``(height, width)`` pair.
        dilation: spacing between kernel taps, a single number or a
            ``(height, width)`` pair.
        bias: when true, a learnable per-output-channel bias is added.
        stride: step of the sliding window, a single number or a
            ``(height, width)`` pair.

    Attributes:
        weights: parameter of shape ``(out_channels, in_channels, kh, kw)``,
            drawn from a standard normal distribution.
        bias: parameter of shape ``(out_channels,)`` or ``None``.
    """

    def __init__(self, in_channels, out_channels, kernel_shape, padding=0, dilation=1, bias=True, stride=1):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Every geometry argument is stored as a (height, width) pair so the
        # output-size arithmetic in forward() can index them uniformly.
        self.kernel_shape = self._as_pair(kernel_shape)
        self.padding = self._as_pair(padding)
        self.dilation = self._as_pair(dilation)
        self.stride = self._as_pair(stride)

        self.weights = nn.Parameter(torch.randn(out_channels, in_channels, *self.kernel_shape))
        self.bias = None
        if bias:
            self.bias = nn.Parameter(torch.randn(out_channels))

    @staticmethod
    def _as_pair(value):
        """Duplicate a scalar into ``(value, value)``; pass sequences through unchanged."""
        if isinstance(value, (int, float)):
            return (value, value)
        return value

    def _output_size(self, size, axis):
        """Return the output extent along ``axis`` (0 = height, 1 = width)."""
        effective_kernel = self.dilation[axis] * (self.kernel_shape[axis] - 1) + 1
        return (size + 2 * self.padding[axis] - effective_kernel) // self.stride[axis] + 1

    def forward(self, x):
        """Convolve ``x`` of shape ``[batch_size, in_channels, h, w]``.

        Returns:
            Tensor of shape ``[batch_size, out_channels, out_h, out_w]``.
        """
        batch_size = x.shape[0]
        h, w = x.shape[2:]
        # windows: [batch_size, in_channels * kh * kw, out_h * out_w]
        windows = func.unfold(x, self.kernel_shape, dilation=self.dilation, padding=self.padding, stride=self.stride)
        # One matmul applies every filter to every window.
        output = torch.matmul(self.weights.view(self.out_channels, -1), windows)
        output_height = self._output_size(h, 0)
        output_width = self._output_size(w, 1)
        # The window axis is already in row-major spatial order, so a reshape
        # restores the 2-D layout.
        output = output.reshape(batch_size, self.out_channels, output_height, output_width)
        if self.bias is not None:
            output = output + self.bias.view(1, self.out_channels, 1, 1)
        return output

In [None]:
class MaxPool(nn.Module):
    """2-D max pooling built from ``unfold`` and a per-window maximum.

    Args:
        channels: expected number of input channels. Kept for interface
            compatibility; ``forward`` reads the channel count from its input.
        kernel_shape: pooling window, a single number or a ``(height, width)`` pair.
        padding: padding per side, a single number or a ``(height, width)`` pair.
            Padded positions are filled with ``-inf`` so they can never be
            selected as a maximum.
        stride: step of the pooling window, a single number or a
            ``(height, width)`` pair.
    """

    def __init__(self, channels, kernel_shape, padding=0, stride=1):
        super().__init__()
        self.channels = channels
        self.kernel_shape = kernel_shape
        self.padding = padding
        self.dilation = 1
        self.stride = stride
        if isinstance(self.padding, (int, float)):
            self.padding = (self.padding, self.padding)
        if isinstance(self.stride, (int, float)):
            self.stride = (self.stride, self.stride)
        if isinstance(self.kernel_shape, (int, float)):
            self.kernel_shape = (self.kernel_shape, self.kernel_shape)

    def forward(self, x):
        """Pool ``x`` of shape ``[batch_size, channels, h, w]``.

        Returns:
            Tensor of shape ``[batch_size, channels, out_h, out_w]``.
        """
        batch_size, channels, h, w = x.shape
        pad_h, pad_w = self.padding
        if pad_h or pad_w:
            # unfold() can only zero-pad, which would win the max over any
            # border window of negative values. Pad explicitly with -inf.
            x = func.pad(x, (pad_w, pad_w, pad_h, pad_h), value=float("-inf"))
        # windows: [batch_size, channels * kh * kw, out_h * out_w]
        windows = func.unfold(x, self.kernel_shape, stride=self.stride)
        window_size = self.kernel_shape[0] * self.kernel_shape[1]
        # Split the channel and in-window axes, then reduce over the window.
        grouped = windows.reshape(batch_size, channels, window_size, windows.shape[-1])
        output, _ = torch.max(grouped, dim=2, keepdim=False)
        output_height = (h + 2 * pad_h - self.dilation * (self.kernel_shape[0] - 1) - 1) // self.stride[0] + 1
        output_width = (w + 2 * pad_w - self.dilation * (self.kernel_shape[1] - 1) - 1) // self.stride[1] + 1
        return output.reshape(batch_size, channels, output_height, output_width)

In [None]:
class ConvNet2(nn.Module):
    """CNN classifier assembled from the custom ``Convolution`` / ``MaxPool`` layers.

    Two conv -> ReLU -> pool stages feed a small fully connected head with
    ten outputs. The first ``Linear`` layer is sized for ``32 * 7 * 7``
    features, i.e. a 1x28x28 input after two stride-2 poolings.
    """

    def __init__(self):
        super().__init__()
        conv_stages = [
            Convolution(1, 16, 5, stride=1, padding=2, bias=True),
            nn.ReLU(),
            MaxPool(16, kernel_shape=2, stride=2),
            Convolution(16, 32, 3, stride=1, padding=1, bias=True),
            nn.ReLU(),
            MaxPool(32, kernel_shape=2, stride=2),
        ]
        dense_head = [
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 32, bias=True),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 10, bias=True),
        ]
        # Unpacking keeps the original positional indices inside the Sequential.
        self.layers = nn.Sequential(*conv_stages, *dense_head)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, X):
        """Return the raw class scores (logits) for a batch ``X``."""
        return self.layers(X)

    def predict(self, X):
        """Return the index of the most probable class for each sample in ``X``."""
        class_probabilities = self.softmax(self.forward(X))
        return class_probabilities.argmax(dim=1)

In [None]:
def train(model, name, optimizer, loss_fn, epochs=10, device="cpu",
          train_loader=None, test_loader=None, save_dir="/content"):
    """Train ``model`` and evaluate it after every epoch.

    Args:
        model: network to optimise; must provide ``predict(batch)`` returning
            predicted class indices.
        name: stem of the checkpoint file, written as ``<save_dir>/<name>.pth``.
        optimizer: optimiser already bound to ``model``'s parameters.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar loss.
        epochs: number of passes over the training data.
        device: device the batches are moved to before use.
        train_loader: iterable of ``(inputs, labels)`` batches supporting
            ``len()``; defaults to the notebook-level ``trainloader``.
        test_loader: iterable of ``(images, labels)`` batches; defaults to the
            notebook-level ``testloader``.
        save_dir: directory for the best checkpoint, created if missing.

    Returns:
        ``(train_loss, val_acc)``: per-epoch mean training loss and per-epoch
        evaluation accuracy.
    """
    import os  # local import so the cell does not rely on the import cell

    if train_loader is None:
        train_loader = trainloader
    if test_loader is None:
        test_loader = testloader

    was_training = model.training
    val_best = 0
    train_loss, val_acc = [], []
    for epoch in range(epochs):
        # Training mode: layers such as BatchNorm use batch statistics and
        # update their running estimates.
        model.train()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader, 0):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if ((i+1)%100 == 0):
                print("       ", i+1, "Train batch")
        epoch_loss = running_loss/len(train_loader)
        train_loss.append(epoch_loss)

        # Evaluation mode: freeze running statistics so the test data neither
        # alters the model nor is normalised with its own batch statistics.
        model.eval()
        with torch.inference_mode():
            corr = overall = 0
            for images, truth in test_loader:
                images = images.to(device)
                truth = truth.to(device)
                output = model.predict(images)
                corr += (output == truth).sum().item()
                overall += len(truth)
            accuracy = corr/overall
            print(f"Epoch {epoch+1} Test Accuracy: {accuracy}")
            val_acc.append(accuracy)
            if accuracy > val_best:
                val_best = accuracy
                os.makedirs(save_dir, exist_ok=True)
                torch.save(model.state_dict(), os.path.join(save_dir, f'{name}.pth'))
                print('Saved new best')
        print(f"Epoch: {epoch+1}, loss: {epoch_loss}\n")

    # Hand the model back in the mode the caller had it in.
    model.train(was_training)
    return train_loss, val_acc

In [None]:
# Build the custom-layer network on the selected device, then pair it with
# a cross-entropy loss and an Adam optimiser using Adam's default settings.
model = ConvNet2().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# Number of training batches per epoch, followed by the 20-epoch training run.
print(len(trainloader))
train_loss, val_acc = train(model, 'conv2', optimizer, loss_fn, epochs=20, device=device)

938
        100 Train batch
        200 Train batch
        300 Train batch
        400 Train batch
        500 Train batch
        600 Train batch
        700 Train batch
        800 Train batch
        900 Train batch
Epoch 1 Test Accuracy: 0.9824
Saved new best
Epoch: 1, loss: 0.23272685066008492

        100 Train batch
        200 Train batch
        300 Train batch
        400 Train batch
        500 Train batch
        600 Train batch
        700 Train batch
        800 Train batch
        900 Train batch
Epoch 2 Test Accuracy: 0.9841
Saved new best
Epoch: 2, loss: 0.05144226897472162

        100 Train batch
        200 Train batch
        300 Train batch
        400 Train batch
        500 Train batch
        600 Train batch
        700 Train batch
        800 Train batch
        900 Train batch
Epoch 3 Test Accuracy: 0.9861
Saved new best
Epoch: 3, loss: 0.03429233625658782

        100 Train batch
        200 Train batch
        300 Train batch
        400 Train batch
      

In [None]:
# Figure 1: mean training loss per epoch. The x-axis is 1-based so it lines
# up with the "Epoch N" lines printed during training.
plt.figure(1)
plt.plot(range(1, len(train_loss) + 1), train_loss)
plt.title("Train loss over epochs")
plt.xlabel("Epoch")
plt.ylabel("Mean training loss")

# Figure 2: evaluation accuracy measured after each epoch.
plt.figure(2)
plt.plot(range(1, len(val_acc) + 1), val_acc)
plt.title("Val Accuracy over epochs")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.show()

In [None]:
"""
How the model would be implemented using nn.Conv2d and nn.MaxPool2d
"""
class ConvNet2_speedup(nn.Module):
    """The same architecture expressed with PyTorch's built-in layers.

    Uses ``nn.Conv2d`` and ``nn.MaxPool2d`` for the two conv -> ReLU -> pool
    stages, followed by a fully connected head with ten outputs. The first
    ``Linear`` layer is sized for ``32 * 7 * 7`` features, i.e. a 1x28x28
    input after two stride-2 poolings.
    """

    def __init__(self):
        super().__init__()
        conv_stages = [
            nn.Conv2d(1, 16, 5, stride=1, padding=2, bias=True),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(16, 32, 3, stride=1, padding=1, bias=True),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
        ]
        dense_head = [
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 32, bias=True),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 10, bias=True),
        ]
        # Unpacking keeps the original positional indices (and therefore the
        # state_dict keys) inside the Sequential.
        self.layers = nn.Sequential(*conv_stages, *dense_head)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, X):
        """Return the raw class scores (logits) for a batch ``X``."""
        return self.layers(X)

    def predict(self, X):
        """Return the index of the most probable class for each sample in ``X``."""
        class_probabilities = self.softmax(self.forward(X))
        return class_probabilities.argmax(dim=1)


Benchmarking the two models shows that this built-in version takes 45 seconds per epoch, versus 108 seconds per epoch for our custom implementation.