In [9]:
import gin
import torch
from torch import nn
import torch.optim as optim
from torchinfo import summary

In [16]:
# Select the training device: use the CUDA GPU when PyTorch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")


Using cpu device


In [79]:
#@gin.configurable
class CNN(nn.Module):
    """Three-block convolutional classifier for small single-channel images.

    Each conv block is ``Conv2d -> ReLU -> MaxPool2d(2)``. The final feature map
    is flattened and fed to an MLP head ``flat -> 64 -> 32 -> num_classes``.

    Args:
        num_classes: Number of output logits.
        kernel_size: Square kernel size shared by all three conv layers.
        filter1: Output channels of the first conv layer.
        filter2: Output channels of the second conv layer.
        filter3: Output channels of the third conv layer (default 32, the value
            that used to be hard-coded).
        input_size: ``(height, width)`` of the images the model will receive.
            Only used to size the first Linear layer; default ``(28, 28)``.
        in_channels: Number of input image channels (default 1).

    Raises:
        RuntimeError: (raised by PyTorch at construction time) if the conv/pool
            stack shrinks ``input_size`` to an empty feature map for the chosen
            ``kernel_size``.
    """

    def __init__(
        self,
        num_classes: int,
        kernel_size: int,
        filter1: int,
        filter2: int,
        filter3: int = 32,
        input_size: tuple = (28, 28),
        in_channels: int = 1,
    ) -> None:
        super().__init__()

        self.convolutions = nn.Sequential(
            # padding=1 only on the first conv; later convs shrink the map by
            # kernel_size - 1 each.
            nn.Conv2d(in_channels, filter1, kernel_size=kernel_size, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(filter1, filter2, kernel_size=kernel_size, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(filter2, filter3, kernel_size=kernel_size, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Infer the flattened feature count by pushing one dummy image through
        # the conv stack, instead of hard-coding it. The old literal 128 only
        # matched e.g. 28x28 inputs with kernel_size=3 and crashed otherwise.
        # No parameters are created here, so weight initialisation is unchanged.
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, *input_size)
            flat_features = self.convolutions(probe).flatten(1).shape[1]

        self.dense = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch ``x``.

        ``x`` is expected to be ``(batch, in_channels, *input_size)``, matching
        the shape used to size the dense head.
        """
        features = self.convolutions(x)
        return self.dense(features)


In [None]:
# Load gin bindings from model.gin.
# NOTE(review): this cell is marked `In [None]` (never executed in the saved run),
# and the `@gin.configurable` decorator on CNN is commented out. Any bindings in
# model.gin that target CNN would presumably be rejected as unknown — confirm
# before relying on this config.
gin.parse_config_file("model.gin")

In [82]:
# Build the model: 10 output classes, 3x3 kernels, and 32 channels in both the
# first and second conv layers. Show a layer-by-layer summary for a batch of 64
# single-channel 28x28 images.
model = CNN(num_classes=10, kernel_size=3, filter1=32, filter2=32).to(device)
summary(model, input_size=(64, 1, 28, 28))

Layer (type:depth-idx)                   Output Shape              Param #
CNN                                      [64, 10]                  --
├─Sequential: 1-1                        [64, 32, 2, 2]            --
│    └─Conv2d: 2-1                       [64, 32, 28, 28]          320
│    └─ReLU: 2-2                         [64, 32, 28, 28]          --
│    └─MaxPool2d: 2-3                    [64, 32, 14, 14]          --
│    └─Conv2d: 2-4                       [64, 32, 12, 12]          9,248
│    └─ReLU: 2-5                         [64, 32, 12, 12]          --
│    └─MaxPool2d: 2-6                    [64, 32, 6, 6]            --
│    └─Conv2d: 2-7                       [64, 32, 4, 4]            9,248
│    └─ReLU: 2-8                         [64, 32, 4, 4]            --
│    └─MaxPool2d: 2-9                    [64, 32, 2, 2]            --
├─Sequential: 1-2                        [64, 10]                  --
│    └─Flatten: 2-10                     [64, 128]                 --
│    └─L