In [1]:
from helper_functions import train, predict
from dataset import train_cats_dogs, test_cats_dogs
import torch
from torch import nn, optim
import torch.nn.functional as F

In [27]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional image classifier.

    Five convolutional layers (interleaved with ReLU and three max-pooling
    stages) feed a three-layer fully connected head. The head ends in
    ``nn.Softmax(dim=1)``, so ``forward`` returns per-class probabilities
    rather than raw logits.

    NOTE(review): if the training helper pairs this model with
    ``nn.CrossEntropyLoss`` (which applies log-softmax itself), softmax would
    effectively be applied twice -- confirm against ``helper_functions.train``.

    Args:
        num_classes: Width of the final linear layer, i.e. the number of
            probabilities returned per image. Defaults to 2, the value that
            was previously hard-coded.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1.
    """

    def __init__(self, num_classes: int = 2):
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")

        # Layer order and Sequential indices are kept stable so existing
        # state_dict keys (conv_block.0, classification_head.4, ...) still load.
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classification_head = nn.Sequential(
            # 9216 = 256 * 6 * 6, the flattened conv_block output for a
            # 3 x 227 x 227 input image.
            nn.Linear(in_features=9216, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=num_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities for one image or a batch of images.

        Args:
            x: Either a single unbatched image or a 4-D batch of images; the
                fully connected head expects the conv features of each image
                to flatten to 9216 values.

        Returns:
            A 2-D tensor with one row per image and ``num_classes`` columns.
            An unbatched input yields a single row.
        """
        features = self.conv_block(x)
        # An unbatched image leaves conv_block without a batch axis; add one
        # so the flatten below always keeps dimension 0 as the batch.
        if features.dim() != 4:
            features = features.unsqueeze(dim=0)
        flat = torch.flatten(features, start_dim=1)
        return self.classification_head(flat)

# Smoke test: push the first test-set sample through a freshly initialised
# (untrained) network to confirm the shapes line up end to end.
test_samples = iter(test_cats_dogs)
image, labels = next(test_samples)
model = AlexNet()
print(f"Image Shape: {image.shape}")
prediction = predict(model, image)
print(f"Prediction: {prediction}")

Image Shape: torch.Size([3, 227, 227])
Prediction: tensor([[0.5008, 0.4992]])


In [29]:
from torchinfo import summary

# Layer-by-layer report (shapes, parameter counts, trainability) for a single
# 3 x 227 x 227 input; rows are labelled with each module's variable name.
summary(
    model,
    input_size=(3, 227, 227),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    row_settings=["var_names"],
)

Layer (type (var_name))                  Input Shape               Output Shape              Param #                   Trainable
AlexNet (AlexNet)                        [3, 227, 227]             [1, 2]                    --                        True
├─Sequential (conv_block)                [3, 227, 227]             [256, 6, 6]               --                        True
│    └─Conv2d (0)                        [3, 227, 227]             [96, 55, 55]              34,944                    True
│    └─ReLU (1)                          [96, 55, 55]              [96, 55, 55]              --                        --
│    └─MaxPool2d (2)                     [96, 55, 55]              [96, 27, 27]              --                        --
│    └─Conv2d (3)                        [96, 27, 27]              [256, 27, 27]             614,656                   True
│    └─ReLU (4)                          [256, 27, 27]             [256, 27, 27]             --                        --
│    └─Ma