In [1]:
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import numpy as np
import pandas as pd
from helper_functions import accuracy_fn
from torchsummary import summary
from safetensors.torch import save_model
import matplotlib.pyplot as plt
import cv2

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
# Preprocessing shared by the train and test MNIST splits.
# The pipeline has a single step: converting each image to a tensor.
_mnist_preprocessing = [
    transforms.ToTensor(),
]
data_transforms = transforms.Compose(_mnist_preprocessing)

In [4]:
def _load_mnist(train: bool) -> datasets.MNIST:
    """Return one MNIST split stored under ./data, downloading it if needed.

    Both splits share the same `data_transforms` pipeline.
    """
    return datasets.MNIST(
        root="data",
        train=train,
        download=True,
        transform=data_transforms,
    )


train_data = _load_mnist(train=True)
test_data = _load_mnist(train=False)

In [5]:
batch_size = 16

# Wrap the MNIST splits in DataLoaders.
# Only the training split is shuffled. The evaluation loop just averages metrics, so
# it needs no random order. A shuffled test loader would also draw a fresh seed from
# torch's global RNG every epoch, perturbing the seeded order of the training batches.
train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [6]:
# Smoke test: fetch a single batch to confirm the DataLoader can iterate the dataset.
first_batch = next(iter(train_dl))
x, y = first_batch

### Model Creation

A Capsule Network (CapsNet) is a model that tries to address the shortcomings of a standard CNN pipeline. In particular, it avoids max pooling. Pooling is lossy because it keeps only the strongest activation in each window, whereas a CapsNet tries to retain as much information from the input image as it can.

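It has its own custom activation function, routing procedure and loss:
- Activation function: the squashing function
- Routing: dynamic routing-by-agreement.
  - This is an iterative procedure, run during the forward pass, that decides how strongly each lower-level capsule sends its output to each higher-level capsule.
  - It is not an optimizer. The weights are still trained with an ordinary gradient-based optimizer (the original paper uses Adam).
- Loss: margin loss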

The squashing function is defined as the following equation $$v_j=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}$$
where $v_j$ is the vector output from capsule $j$ with $s_j$ as its summed input

In [7]:
# squashing function
class Squash(nn.Module):
    """Capsule squashing non-linearity.

    Computes v = ||s||^2 / (1 + ||s||^2) * s / ||s|| along ``dim``. Every vector is
    shrunk to a length in [0, 1) while keeping its direction.

    Args:
        epsilon: small constant added under the square root so that zero-length
            vectors do not cause a division by zero.
        dim: dimension that holds the capsule vector components (default: last).
    """

    def __init__(self, epsilon=1e-8, dim=-1):
        super().__init__()
        self.epsilon = epsilon
        self.dim = dim

    def forward(self, x):
        # Integer tensors would overflow in x**2 (e.g. uint8 16**2 wraps to 0),
        # so promote them to the default floating dtype first.
        if not x.is_floating_point():
            x = x.to(torch.get_default_dtype())
        sj_squared = (x ** 2).sum(dim=self.dim, keepdim=True)
        scale = sj_squared / (1 + sj_squared)
        return scale * (x / torch.sqrt(sj_squared + self.epsilon))

In [8]:
# Sanity-check Squash on a float tensor; integer dtypes can overflow when squared.
s = torch.ones((1, 256, 256), dtype=torch.float32)
sq = Squash()
v = sq(s)
# Each row is a 256-d vector of ones, so ||s||^2 = 256 and the squashed length is 256/257.
assert torch.allclose(v.norm(dim=-1), torch.full((1, 256), 256 / 257))
v

tensor([[[0.0623, 0.0623, 0.0623,  ..., 0.0623, 0.0623, 0.0623],
         [0.0623, 0.0623, 0.0623,  ..., 0.0623, 0.0623, 0.0623],
         [0.0623, 0.0623, 0.0623,  ..., 0.0623, 0.0623, 0.0623],
         ...,
         [0.0623, 0.0623, 0.0623,  ..., 0.0623, 0.0623, 0.0623],
         [0.0623, 0.0623, 0.0623,  ..., 0.0623, 0.0623, 0.0623],
         [0.0623, 0.0623, 0.0623,  ..., 0.0623, 0.0623, 0.0623]]])

In [None]:
class CapsNet(nn.Module):
    """Capsule network for square single-channel images (Sabour et al., 2017).

    Pipeline:
      1. conv1 followed by ReLU.
      2. Primary capsules: the conv2 output is reshaped into 8-D vectors and squashed.
      3. Digit capsules: computed from the primary capsules by dynamic routing-by-agreement.

    ``forward`` returns the length of each digit capsule, with shape
    (batch, num_classes). It can be used directly as class scores
    (``.argmax(dim=1)``) and with a margin loss.

    Args:
        in_channels: number of image channels (1 for MNIST).
        num_classes: number of output (digit) capsules.
        input_size: height/width of the square input images. It determines how many
            primary capsules conv2 produces, so inputs must have this size.
        routing_iters: number of routing-by-agreement iterations (>= 1).
    """

    def __init__(self, in_channels=1, num_classes=10, input_size=28, routing_iters=3):
        super().__init__()
        if routing_iters < 1:
            raise ValueError("routing_iters must be >= 1")

        self.primary_caps_types = 32   # number of primary-capsule channels
        self.primary_caps_dim = 8      # components per primary capsule
        self.digit_caps_dim = 16       # components per digit capsule
        self.routing_iters = routing_iters

        self.conv1 = nn.Conv2d(in_channels, 256, 9)
        self.conv2 = nn.Conv2d(256, self.primary_caps_types * self.primary_caps_dim, 9, 2)
        self.squash = Squash()

        # Spatial size after conv1 (kernel 9, stride 1) and conv2 (kernel 9, stride 2).
        grid = ((input_size - 8) - 9) // 2 + 1
        if grid < 1:
            raise ValueError(f"input_size={input_size} is too small for two 9x9 convolutions")
        self.num_primary_caps = self.primary_caps_types * grid * grid

        # One transformation matrix per (primary capsule, digit capsule) pair.
        self.route_weights = nn.Parameter(
            0.01 * torch.randn(
                1, self.num_primary_caps, num_classes,
                self.digit_caps_dim, self.primary_caps_dim,
            )
        )

    def forward(self, x):
        batch = x.size(0)

        # nn.ReLU is a module class; torch.relu applies the activation directly.
        x = torch.relu(self.conv1(x))

        # Primary capsules: split conv2's channels into `primary_caps_types` groups of
        # `primary_caps_dim` values and treat every spatial position as one capsule.
        u = self.conv2(x)
        u = u.view(batch, self.primary_caps_types, self.primary_caps_dim, u.size(-2), u.size(-1))
        u = u.permute(0, 1, 3, 4, 2).reshape(batch, -1, self.primary_caps_dim)
        u = self.squash(u)                                   # (batch, n_primary, 8)

        # Prediction vectors u_hat[b, i, j] = W[i, j] @ u[b, i] -> (batch, n_primary, n_classes, 16)
        u_hat = (self.route_weights @ u[:, :, None, :, None]).squeeze(-1)

        # Dynamic routing-by-agreement between primary and digit capsules.
        logits = torch.zeros(u_hat.shape[:3], device=u_hat.device, dtype=u_hat.dtype)
        for i in range(self.routing_iters):
            coupling = logits.softmax(dim=2)                 # how much each primary capsule sends to each class
            v = self.squash((coupling.unsqueeze(-1) * u_hat).sum(dim=1))  # (batch, n_classes, 16)
            if i < self.routing_iters - 1:
                # Increase the logits where a prediction agrees with the current output.
                logits = logits + (u_hat * v.unsqueeze(1)).sum(dim=-1)

        # Capsule length is used as the class score.
        return v.norm(dim=-1)

In [None]:
model = CapsNet().to(device)


def margin_loss(lengths, y, m_pos=0.9, m_neg=0.1, lam=0.5):
    """Margin loss for capsule outputs (Sabour et al., 2017).

    The loss penalizes:
      - the correct class's capsule being shorter than ``m_pos``;
      - every other capsule being longer than ``m_neg``, down-weighted by ``lam``.

    Args:
        lengths: capsule lengths of shape (batch, num_classes), i.e. the model output.
        y: integer class labels of shape (batch,).

    Returns:
        Scalar loss averaged over the batch.
    """
    targets = nn.functional.one_hot(y, num_classes=lengths.size(1)).to(lengths.dtype)
    present = targets * torch.clamp(m_pos - lengths, min=0) ** 2
    absent = lam * (1 - targets) * torch.clamp(lengths - m_neg, min=0) ** 2
    return (present + absent).sum(dim=1).mean()


# The training loop below relies on these two names.
loss_fn = margin_loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

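### Model Training

> **Note:** the saved output below predates the CapsNet cells above (they show no execution count), so it was produced by a different model left over in the kernel. Use Restart & Run All to regenerate CapsNet results.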

In [253]:
torch.manual_seed(20)

epochs = 5


def _train_one_epoch(model, dataloader, loss_fn, optimizer, device):
    """Run one optimisation pass over `dataloader` and return (mean loss, mean accuracy)."""
    model.train()
    total_loss, total_acc = 0.0, 0.0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # forward pass
        train_pred = model(X)

        # metrics; .item() turns the loss into a Python float so the running sum
        # does not keep an ever-growing autograd chain alive
        loss = loss_fn(train_pred, y)
        total_loss += loss.item()
        total_acc += accuracy_fn(y_true=y, y_pred=train_pred.argmax(dim=1))

        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return total_loss / len(dataloader), total_acc / len(dataloader)


def _evaluate(model, dataloader, loss_fn, device):
    """Score `model` on `dataloader` without gradients and return (mean loss, mean accuracy)."""
    model.eval()
    total_loss, total_acc = 0.0, 0.0
    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            # forward pass
            test_pred = model(X)

            # metrics
            total_loss += loss_fn(test_pred, y).item()
            total_acc += accuracy_fn(y_true=y, y_pred=test_pred.argmax(dim=1))

    return total_loss / len(dataloader), total_acc / len(dataloader)


for epoch in tqdm(range(epochs)):
    print(f"Epoch: {epoch}\n------")

    # TRAINING
    train_loss, train_acc = _train_one_epoch(model, train_dl, loss_fn, optimizer, device)
    print(f"Train Loss: {train_loss:.4f} | Train Accuracy: {train_acc:.2f}%")

    # TESTING
    test_loss, test_acc = _evaluate(model, test_dl, loss_fn, device)
    print(f"Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.2f}%")

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0
------
Train Loss: 0.0825 | Train Accuracy: 98.14%


 20%|██        | 1/5 [01:18<05:14, 78.68s/it]

Test Loss: 0.0545 | Test Accuracy: 98.71%
Epoch: 1
------
Train Loss: 0.0693 | Train Accuracy: 98.35%


 40%|████      | 2/5 [02:36<03:54, 78.30s/it]

Test Loss: 0.0501 | Test Accuracy: 98.64%
Epoch: 2
------
Train Loss: 0.0622 | Train Accuracy: 98.42%


 60%|██████    | 3/5 [03:56<02:37, 78.77s/it]

Test Loss: 0.0450 | Test Accuracy: 98.75%
Epoch: 3
------
Train Loss: 0.0586 | Train Accuracy: 98.53%


 80%|████████  | 4/5 [05:17<01:19, 79.69s/it]

Test Loss: 0.0430 | Test Accuracy: 98.77%
Epoch: 4
------
Train Loss: 0.0497 | Train Accuracy: 98.69%


100%|██████████| 5/5 [06:33<00:00, 78.71s/it]

Test Loss: 0.0407 | Test Accuracy: 98.79%





In [254]:
# Save under a name that matches the architecture, so the file cannot clobber
# (or be mistaken for) a ResNet checkpoint.
MODEL_PATH = "capsnet_mnist.safetensors"
save_model(model, MODEL_PATH)