# CNN from scratch

In [10]:
# show version and all available gpu devices
import torch
print(torch.__version__)
if torch.cuda.is_available():
    for idx in range(torch.cuda.device_count()):
        print(torch.cuda.get_device_name(idx))
else:
    print('No GPU available')

2.5.1
NVIDIA GeForce GTX 1050


In [11]:
from torch import nn


class MyConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='same', dilation=1, groups=1, bias=True):
        super(MyConv2d, self).__init__()        
        self.stride = (stride, stride) if isinstance(stride, int) else tuple(stride)
        self.dilation = (dilation, dilation) if isinstance(dilation, int) else tuple(dilation)
        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = self.check_params(groups)
        self.bias = bias

        if isinstance(padding, str):
            if padding.lower() == 'same':
                self.padding = tuple((s * (k - 1) // 2 for s, k in zip(self.stride, self.kernel_size)))
            elif padding.lower() == 'valid':
                self.padding = (0, 0)
            else:
                raise ValueError('Padding must be "same", "valid", or an integer.')
        else:
            self.padding = (padding, padding) if isinstance(padding, int) else tuple(padding)

        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels // groups, kernel_size, kernel_size)
        )
        self.bias = nn.Parameter(torch.randn(out_channels)) if bias else None
        
    def forward(self, x):
        # return self.my_forward(x) # a lot slower since it's not optimized, doesn't exploit parallelism
        return self.fold_forward(x)  # Optimized version using matrix multiplication and unfolding with parallelism
        # return self.default_forward(x) # Default PyTorch implementation
    
    def my_forward(self, x):
        batch_size, in_channels, height, width = x.size()

        # Check dimensions
        assert in_channels == self.in_channels, \
            f'Expected input with {self.in_channels} channels, but got {in_channels} channels'

        # Calculate output dimensions
        out_height = (height + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
        out_width = (width + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1

        # Apply padding if needed
        if self.padding != (0, 0):
            x = nn.functional.pad(x, (self.padding[1], self.padding[1], self.padding[0], self.padding[0]))

        # Initialize the output tensor
        out = torch.zeros((batch_size, self.out_channels, out_height, out_width), device=x.device)

        # Perform the convolution operation
        out_channels_per_group = self.out_channels // self.groups
        in_channels_per_group = self.in_channels // self.groups

        for g in range(self.groups):
            for i in range(out_height):
                for j in range(out_width):
                    h_start = i * self.stride[0]
                    h_end = h_start + self.kernel_size[0] * self.dilation[0]
                    w_start = j * self.stride[1]
                    w_end = w_start + self.kernel_size[1] * self.dilation[1]

                    x_slice = x[:, g * in_channels_per_group:(g + 1) * in_channels_per_group, h_start:h_end:self.dilation[0], w_start:w_end:self.dilation[1]]

                    out[:, g * out_channels_per_group:(g + 1) * out_channels_per_group, i, j] = \
                        torch.einsum('bijk,oijk->bo', x_slice, self.weight[g * out_channels_per_group:(g + 1) * out_channels_per_group]) + \
                        self.bias[g * out_channels_per_group:(g + 1) * out_channels_per_group]

        return out

    def fold_forward(self, x):
        """Reproduces F.conv2d with the specified parameters without the explicit call to the function."""
        n_batch, in_channels, in_height, in_width = x.size()
        out_height = (in_height + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
        out_width = (in_width + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1
    
        # Pad the input tensor
        if self.padding != (0, 0):
            x = nn.functional.pad(x, (self.padding[1], self.padding[1], self.padding[0], self.padding[0]))
    
        # Unfold the input tensor to apply the convolution operation
        x = torch.nn.functional.unfold(x, self.kernel_size, dilation=self.dilation, stride=self.stride)
    
        # Reshape the weights
        weight = self.weight.view(self.weight.size(0), -1)
    
        # Perform matrix multiplication
        x = torch.matmul(x.permute(0, 2, 1), weight.t())
        
        # Reshape the output tensor
        x = x.view(n_batch, self.out_channels, out_height, out_width)
    
        return x + self.bias.view(1, -1, 1, 1) if self.bias is not None else x


    def default_forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    def check_params(self, groups):
        """Checks the parameters of the Conv2d layer to ensure they are valid."""
        if groups <= 0:
            raise ValueError('groups must be a positive integer')
        if self.in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if self.out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        return groups

In [12]:
import torch
from torch import nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
import torchmetrics

class MNISTModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.mnist_val = None
        self.mnist_train = None
        self.mnist_test = None
        self.convBlock1 = nn.Sequential(
            MyConv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.convBlock2 = nn.Sequential(
            MyConv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc_out = nn.Linear(64 * 7 * 7, 10)
        
        # Initialize metrics
        self.train_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)
        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)
        self.test_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)

    def forward(self, x):
        x = self.convBlock1(x)
        x = self.convBlock2(x)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        return self.fc_out(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        
        # Calculate and log accuracy
        preds = torch.argmax(logits, dim=1)
        acc = self.train_accuracy(preds, y)
        self.log('train_loss', loss)
        self.log('train_acc', acc, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        
        # Calculate and log accuracy
        preds = torch.argmax(logits, dim=1)
        acc = self.val_accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        
        # Calculate and log accuracy
        preds = torch.argmax(logits, dim=1)
        acc = self.test_accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def prepare_data(self):
        # Download and prepare the MNIST dataset
        datasets.MNIST('../data/raw', train=True, download=True)
        datasets.MNIST('../data/raw', train=False, download=True)

    def setup(self, stage=None):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        if stage == 'fit' or stage is None:
            mnist_full = datasets.MNIST('../data/raw', train=True, transform=transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = datasets.MNIST('../data/raw', train=False, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=256, num_workers=8, shuffle=True, persistent_workers=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=128, num_workers=8, shuffle=False, persistent_workers=True)
    
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=128, num_workers=8, shuffle=False, persistent_workers=True)

In [13]:
# EarlyStopping callback
early_stopping = EarlyStopping(
    monitor='val_loss',  # Metric to monitor
    patience=3,  # Number of epochs with no improvement after which training will be stopped
    verbose=True,  # Print messages when early stopping is triggered
    mode='min'  # Minimize the monitored metric
)

# Trainer with EarlyStopping callback
trainer = Trainer(max_epochs=50, callbacks=[early_stopping], accelerator='gpu', devices=1, default_root_dir='../models/cnn-scratch')
model = MNISTModel()
trainer.fit(model)

# Get best weights from checkpoint
best_model_path = trainer.checkpoint_callback.best_model_path
model = MNISTModel.load_from_checkpoint(best_model_path)

# Test the model
trainer.test(model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | convBlock1     | Sequential         | 320    | train
1 | convBlock2     | Sequential         | 18.5 K | train
2 | fc_out         | Linear             | 31.4 K | train
3 | train_accuracy | MulticlassAccuracy | 0      | train
4 | val_accuracy   | MulticlassAccuracy | 0      | train
5 | test_accuracy  | MulticlassAccuracy | 0      | train
--------------------------------------------------------------
50.2 K    Trainable params
0         Non-trainable params
50.2 K    Total params
0.201     Total estimated model params size (MB)
12        Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 215/215 [00:08<00:00, 24.15it/s, v_num=1, train_acc=0.875]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   2%|▎         | 1/40 [00:00<00:01, 27.70it/s][A
Validation DataLoader 0:   5%|▌         | 2/40 [00:00<00:01, 31.37it/s][A
Validation DataLoader 0:   8%|▊         | 3/40 [00:00<00:00, 39.20it/s][A
Validation DataLoader 0:  10%|█         | 4/40 [00:00<00:00, 43.74it/s][A
Validation DataLoader 0:  12%|█▎        | 5/40 [00:00<00:00, 45.61it/s][A
Validation DataLoader 0:  15%|█▌        | 6/40 [00:00<00:00, 49.07it/s][A
Validation DataLoader 0:  18%|█▊        | 7/40 [00:00<00:00, 50.56it/s][A
Validation DataLoader 0:  20%|██        | 8/40 [00:00<00:00, 52.08it/s][A
Validation DataLoader 0:  22%|██▎       | 9/40 [00:00<00:00, 52.39it/s][A
Validation DataLoader 0:  25%|██▌       | 10/40 [00:00<00:00, 52.41it/

Metric val_loss improved. New best score: 1.530


Epoch 1: 100%|██████████| 215/215 [00:08<00:00, 24.14it/s, v_num=1, train_acc=0.921, val_loss=1.530, val_acc=0.912]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   2%|▎         | 1/40 [00:00<00:00, 43.99it/s][A
Validation DataLoader 0:   5%|▌         | 2/40 [00:00<00:00, 51.36it/s][A
Validation DataLoader 0:   8%|▊         | 3/40 [00:00<00:00, 53.36it/s][A
Validation DataLoader 0:  10%|█         | 4/40 [00:00<00:00, 55.17it/s][A
Validation DataLoader 0:  12%|█▎        | 5/40 [00:00<00:00, 56.47it/s][A
Validation DataLoader 0:  15%|█▌        | 6/40 [00:00<00:00, 57.45it/s][A
Validation DataLoader 0:  18%|█▊        | 7/40 [00:00<00:00, 58.30it/s][A
Validation DataLoader 0:  20%|██        | 8/40 [00:00<00:00, 58.86it/s][A
Validation DataLoader 0:  22%|██▎       | 9/40 [00:00<00:00, 53.87it/s][A
Validation DataLoader 0:  25%|██▌      

Metric val_loss improved by 0.696 >= min_delta = 0.0. New best score: 0.834


Epoch 2: 100%|██████████| 215/215 [00:08<00:00, 24.13it/s, v_num=1, train_acc=0.940, val_loss=0.834, val_acc=0.935]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   2%|▎         | 1/40 [00:00<00:01, 36.53it/s][A
Validation DataLoader 0:   5%|▌         | 2/40 [00:00<00:00, 48.49it/s][A
Validation DataLoader 0:   8%|▊         | 3/40 [00:00<00:00, 53.97it/s][A
Validation DataLoader 0:  10%|█         | 4/40 [00:00<00:00, 55.45it/s][A
Validation DataLoader 0:  12%|█▎        | 5/40 [00:00<00:00, 57.48it/s][A
Validation DataLoader 0:  15%|█▌        | 6/40 [00:00<00:00, 58.49it/s][A
Validation DataLoader 0:  18%|█▊        | 7/40 [00:00<00:00, 56.93it/s][A
Validation DataLoader 0:  20%|██        | 8/40 [00:00<00:00, 57.20it/s][A
Validation DataLoader 0:  22%|██▎       | 9/40 [00:00<00:00, 53.23it/s][A
Validation DataLoader 0:  25%|██▌      

Metric val_loss improved by 0.282 >= min_delta = 0.0. New best score: 0.552


Epoch 5: 100%|██████████| 215/215 [00:09<00:00, 23.80it/s, v_num=1, train_acc=0.963, val_loss=0.552, val_acc=0.947]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   2%|▎         | 1/40 [00:00<00:00, 70.02it/s][A
Validation DataLoader 0:   5%|▌         | 2/40 [00:00<00:00, 73.40it/s][A
Validation DataLoader 0:   8%|▊         | 3/40 [00:00<00:00, 75.91it/s][A
Validation DataLoader 0:  10%|█         | 4/40 [00:00<00:00, 74.62it/s][A
Validation DataLoader 0:  12%|█▎        | 5/40 [00:00<00:00, 75.56it/s][A
Validation DataLoader 0:  15%|█▌        | 6/40 [00:00<00:00, 75.82it/s][A
Validation DataLoader 0:  18%|█▊        | 7/40 [00:00<00:00, 74.67it/s][A
Validation DataLoader 0:  20%|██        | 8/40 [00:00<00:00, 74.59it/s][A
Validation DataLoader 0:  22%|██▎       | 9/40 [00:00<00:00, 69.14it/s][A
Validation DataLoader 0:  25%|██▌      

Metric val_loss improved by 0.064 >= min_delta = 0.0. New best score: 0.488


Epoch 8: 100%|██████████| 215/215 [00:09<00:00, 23.61it/s, v_num=1, train_acc=0.981, val_loss=0.488, val_acc=0.956]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   2%|▎         | 1/40 [00:00<00:01, 27.25it/s][A
Validation DataLoader 0:   5%|▌         | 2/40 [00:00<00:01, 26.76it/s][A
Validation DataLoader 0:   8%|▊         | 3/40 [00:00<00:01, 32.94it/s][A
Validation DataLoader 0:  10%|█         | 4/40 [00:00<00:01, 34.38it/s][A
Validation DataLoader 0:  12%|█▎        | 5/40 [00:00<00:00, 38.39it/s][A
Validation DataLoader 0:  15%|█▌        | 6/40 [00:00<00:00, 41.23it/s][A
Validation DataLoader 0:  18%|█▊        | 7/40 [00:00<00:00, 43.49it/s][A
Validation DataLoader 0:  20%|██        | 8/40 [00:00<00:00, 45.49it/s][A
Validation DataLoader 0:  22%|██▎       | 9/40 [00:00<00:00, 46.52it/s][A
Validation DataLoader 0:  25%|██▌      

Metric val_loss improved by 0.035 >= min_delta = 0.0. New best score: 0.452


Epoch 10: 100%|██████████| 215/215 [00:08<00:00, 24.14it/s, v_num=1, train_acc=0.954, val_loss=0.452, val_acc=0.960]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   2%|▎         | 1/40 [00:00<00:01, 30.54it/s][A
Validation DataLoader 0:   5%|▌         | 2/40 [00:00<00:00, 39.08it/s][A
Validation DataLoader 0:   8%|▊         | 3/40 [00:00<00:00, 41.42it/s][A
Validation DataLoader 0:  10%|█         | 4/40 [00:00<00:00, 46.04it/s][A
Validation DataLoader 0:  12%|█▎        | 5/40 [00:00<00:00, 47.77it/s][A
Validation DataLoader 0:  15%|█▌        | 6/40 [00:00<00:00, 49.77it/s][A
Validation DataLoader 0:  18%|█▊        | 7/40 [00:00<00:00, 50.91it/s][A
Validation DataLoader 0:  20%|██        | 8/40 [00:00<00:00, 52.54it/s][A
Validation DataLoader 0:  22%|██▎       | 9/40 [00:00<00:00, 52.63it/s][A
Validation DataLoader 0:  25%|██▌     

Metric val_loss improved by 0.028 >= min_delta = 0.0. New best score: 0.424


Epoch 11: 100%|██████████| 215/215 [00:09<00:00, 23.84it/s, v_num=1, train_acc=0.977, val_loss=0.424, val_acc=0.963]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   2%|▎         | 1/40 [00:00<00:00, 39.98it/s][A
Validation DataLoader 0:   5%|▌         | 2/40 [00:00<00:00, 51.77it/s][A
Validation DataLoader 0:   8%|▊         | 3/40 [00:00<00:00, 55.79it/s][A
Validation DataLoader 0:  10%|█         | 4/40 [00:00<00:00, 52.46it/s][A
Validation DataLoader 0:  12%|█▎        | 5/40 [00:00<00:00, 54.98it/s][A
Validation DataLoader 0:  15%|█▌        | 6/40 [00:00<00:00, 55.87it/s][A
Validation DataLoader 0:  18%|█▊        | 7/40 [00:00<00:00, 57.47it/s][A
Validation DataLoader 0:  20%|██        | 8/40 [00:00<00:00, 57.78it/s][A
Validation DataLoader 0:  22%|██▎       | 9/40 [00:00<00:00, 55.61it/s][A
Validation DataLoader 0:  25%|██▌     

Metric val_loss improved by 0.071 >= min_delta = 0.0. New best score: 0.353


Epoch 14: 100%|██████████| 215/215 [00:09<00:00, 23.65it/s, v_num=1, train_acc=0.977, val_loss=0.353, val_acc=0.967]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   2%|▎         | 1/40 [00:00<00:01, 32.15it/s][A
Validation DataLoader 0:   5%|▌         | 2/40 [00:00<00:00, 41.06it/s][A
Validation DataLoader 0:   8%|▊         | 3/40 [00:00<00:00, 47.05it/s][A
Validation DataLoader 0:  10%|█         | 4/40 [00:00<00:00, 50.04it/s][A
Validation DataLoader 0:  12%|█▎        | 5/40 [00:00<00:00, 51.74it/s][A
Validation DataLoader 0:  15%|█▌        | 6/40 [00:00<00:00, 54.11it/s][A
Validation DataLoader 0:  18%|█▊        | 7/40 [00:00<00:00, 54.49it/s][A
Validation DataLoader 0:  20%|██        | 8/40 [00:00<00:00, 54.71it/s][A
Validation DataLoader 0:  22%|██▎       | 9/40 [00:00<00:00, 52.89it/s][A
Validation DataLoader 0:  25%|██▌     

Monitored metric val_loss did not improve in the last 3 records. Best score: 0.353. Signaling Trainer to stop.


Epoch 16: 100%|██████████| 215/215 [00:09<00:00, 22.22it/s, v_num=1, train_acc=0.981, val_loss=0.380, val_acc=0.966]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 79/79 [00:01<00:00, 66.24it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9653000235557556
        test_loss           0.41163358092308044
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.41163358092308044, 'test_acc': 0.9653000235557556}]

In [15]:
%reload_ext tensorboard
%tensorboard --logdir=../models/cnn-scratch/lightning_logs/

Reusing TensorBoard on port 6007 (pid 11622), started 0:00:28 ago. (Use '!kill 11622' to kill it.)