In [23]:
import torch
from torch import nn
from torchinfo import summary

In [37]:
class VGG(nn.Module):
    """Small VGG-style CNN classifier: four conv stages followed by a linear head.

    Each stage is ``Conv2d(kernel 3, stride 1, padding 2) -> ReLU -> MaxPool2d(2)``
    with channel widths 16 -> 32 -> 64 -> 128. Note that ``padding=2`` with a 3x3
    kernel *grows* each spatial dimension by 2 before pooling (``padding=1`` would
    preserve it); it is kept as-is so layer shapes, and therefore saved
    ``state_dict`` checkpoints, stay compatible.

    Args:
        in_channels: Number of channels in the input tensor. Default ``1``.
        num_classes: Number of output units. Default ``5``.
        input_size: ``(height, width)`` of the inputs the model will receive. It is
            used only to size the linear layer. The default ``(128, 87)`` yields the
            original ``128 * 9 * 7`` flattened features.
        apply_softmax: If ``True`` (default), ``forward`` returns softmax
            probabilities. Set ``False`` to return raw logits, which is what
            ``nn.CrossEntropyLoss`` expects (it applies log-softmax internally, so
            feeding it probabilities would apply softmax twice).
    """

    def __init__(self, in_channels=1, num_classes=5, input_size=(128, 87), apply_softmax=True):
        super().__init__()
        self.apply_softmax = apply_softmax

        # Attribute names and per-stage Sequential layout are kept identical to the
        # original so state_dict keys (e.g. "conv1.0.weight") do not change.
        self.conv1 = self._conv_block(in_channels, 16)
        self.conv2 = self._conv_block(16, 32)
        self.conv3 = self._conv_block(32, 64)
        self.conv4 = self._conv_block(64, 128)

        self.flatten = nn.Flatten()
        # Derive the flattened feature count instead of hard-coding 128 * 9 * 7,
        # so the head stays correct for other input sizes / channel counts.
        self.linear = nn.Linear(self._flattened_size(in_channels, input_size), num_classes)
        self.softmax = nn.Softmax(dim=1)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Build one Conv2d(3x3, stride 1, padding 2) -> ReLU -> MaxPool2d(2) stage."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def _features(self, x):
        """Run the four conv stages (everything before the flatten/linear head)."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return x

    def _flattened_size(self, in_channels, input_size):
        """Return the number of features the conv stack produces for one input of ``input_size``."""
        height, width = input_size
        # A zero tensor is used so no RNG state is consumed (weight init is unaffected).
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, height, width)
            return self._features(dummy).numel()

    def forward(self, x):
        """Map a batch of shape ``(N, in_channels, H, W)`` to ``(N, num_classes)``.

        Returns softmax probabilities (each row sums to 1) when ``apply_softmax`` is
        ``True``, otherwise raw logits. ``H`` and ``W`` should match the
        ``input_size`` the model was built with; otherwise the flattened size
        generally differs and ``self.linear`` raises a shape-mismatch error.
        """
        x = self._features(x)
        x = self.flatten(x)
        x = self.linear(x)
        if self.apply_softmax:
            x = self.softmax(x)
        return x


In [38]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

vgg = VGG()
# A batch of one single-channel 128x87 input: the spatial size VGG's linear head is sized for.
summary(vgg, (1, 1, 128, 87), device=device)

Using device: cuda


Layer (type:depth-idx)                   Output Shape              Param #
VGG                                      [1, 5]                    --
├─Sequential: 1-1                        [1, 16, 65, 44]           --
│    └─Conv2d: 2-1                       [1, 16, 130, 89]          160
│    └─ReLU: 2-2                         [1, 16, 130, 89]          --
│    └─MaxPool2d: 2-3                    [1, 16, 65, 44]           --
├─Sequential: 1-2                        [1, 32, 33, 23]           --
│    └─Conv2d: 2-4                       [1, 32, 67, 46]           4,640
│    └─ReLU: 2-5                         [1, 32, 67, 46]           --
│    └─MaxPool2d: 2-6                    [1, 32, 33, 23]           --
├─Sequential: 1-3                        [1, 64, 17, 12]           --
│    └─Conv2d: 2-7                       [1, 64, 35, 25]           18,496
│    └─ReLU: 2-8                         [1, 64, 35, 25]           --
│    └─MaxPool2d: 2-9                    [1, 64, 17, 12]           --
├─Seque