In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [15]:
class Perceptron(nn.Module):
  def __init__(self, input_size):
    super(Perceptron, self).__init__()
    self.w = nn.Parameter(torch.randn(input_size))
    self.b = nn.Parameter(torch.randn(1))

  def forward(self, x):
    return F.relu(x @ self.w + self.b).squeeze(-1)

In [16]:
class DenseLayer(nn.Module):
  def __init__(self, input_size, output_size):
    super(DenseLayer, self).__init__()
    self.layer = nn.ModuleList([Perceptron(input_size) for i in range(output_size)])

  def forward(self, x):
    return torch.stack([p(x) for p in self.layer], dim=1)

In [17]:
class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )
        self.bias = nn.Parameter(torch.zeros(out_channels))

    def forward(self, x):
      batch, C, H, W = x.shape
      k = self.kernel_size

      H_out = H - k + 1
      W_out = W - k + 1

      y = torch.zeros(batch, self.out_channels, H_out, W_out, device=x.device)

      for b in range(batch):
        for out_ch in range(self.out_channels):
          for i in range(H_out):
            for j in range(W_out):
              val = 0
              for in_ch in range(self.in_channels):
                sub = x[b, in_ch, i:i+k, j:j+k]
                val += torch.sum(sub * self.weight[out_ch, in_ch])
                y[b, out_ch, i, j] = val + self.bias[out_ch]
      return y


In [18]:
class MaxPool(nn.Module):
    def __init__(self, kernel_size, stride=None):
        super().__init__()
        self.kernel_size = kernel_size
        if stride == None:
          self.stride = kernel_size
        else:
          self.stride = stride

    def forward(self, x):
        batch, C, H, W = x.shape
        k = self.kernel_size
        s = self.stride

        H_out = (H - k) // s + 1
        W_out = (W - k) // s + 1

        y = torch.zeros(batch, C, H_out, W_out, device = x.device)

        for b in range(batch):
          for c in range(C):
            for i in range(H_out):
              for j in range(W_out):
                  sub = x[b, c, i*s:i*s+k, j*s:j*s+k]
                  y[b, c, i, j] = torch.max(sub)

        return y


In [19]:
class Model(nn.Module):
  def __init__(self):
    super().__init__()

    self.conv1 = ConvLayer(in_channels = 1, out_channels = 12, kernel_size = 3) # (28 - 3)/1 + 1 = 26. Output: (12, 26, 26)
    self.pool = MaxPool(2, 2)  # (12, 13, 13)
    self.conv2 = ConvLayer(in_channels = 12, out_channels = 24, kernel_size = 4)  # (13-4) / 1 + 1 = 10. Output: (24, 10, 10)

    self.fc1 = DenseLayer(24*5*5, 256)
    self.fc2 = DenseLayer(256, 64)
    self.fc3 = DenseLayer(64, 10)

  def forward(self, x):
    x = self.conv1(x)
    x = F.relu(x)
    x = self.pool(x)

    x = self.conv2(x)
    x = F.relu(x)
    x = self.pool(x)

    x = torch.flatten(x, start_dim = 1)

    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)

    return x

In [20]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms

from torch.utils.data import random_split

transform_modified = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

dataset = MNIST(root = 'data/', download = True, transform = transform_modified, train=True)

train_dataset, test_dataset = random_split(
    dataset,
    [50000, 10000],
    generator=torch.Generator().manual_seed(10)
)

batch_size = 128
train_loader = DataLoader(train_dataset, batch_size, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size, shuffle = False)

In [21]:
mnist_model = Model().to(device)

In [22]:
loss_fn = nn.CrossEntropyLoss()

optimiser = torch.optim.AdamW(mnist_model.parameters(), lr=1e-3)

In [23]:
from tqdm import tqdm

for epoch in range(20):
    mnist_model.train()
    loss_sum = 0

    for images, labels in tqdm(train_loader):
      images = images.to(device)
      labels = labels.to(device)

      output = mnist_model(images)
      loss = loss_fn(output, labels)

      optimiser.zero_grad()
      loss.backward()
      optimiser.step()
      loss_sum += loss.item()

    print(loss_sum / len(train_loader))


  0%|          | 0/391 [02:10<?, ?it/s]


KeyboardInterrupt: 

In [None]:
mnist_model.eval()
correct = 0
total = 0

with torch.no_grad():
  for images, labels in test_loader:
    images = images.to(device)
    labels = labels.to(device)

    outputs = mnist_model(images)
    preds = outputs.argmax(dim=1)

    correct += (preds == labels).sum().item()
    total += labels.size(0)

val_accuracy = correct / total
print("Validation Accuracy:", val_accuracy)