In [1]:
# Import necessary libraries
import torch
from torchinfo import summary
from torch import nn

In [2]:
# Inputs

# Path to the trained TinyVGG state dict, relative to the notebook's working
# directory. Forward slashes are accepted by both Windows and POSIX, whereas a
# backslash separator only works on Windows (on Linux/macOS it would be read as
# part of a single file name).
Pytorch_file_path = 'model/MNIST_Digit_Detector.pt'

In [3]:
# TinyVGG Model Architecture

class TinyVGG(nn.Module):
    """Small VGG-style CNN: two convolutional blocks followed by a linear classifier.

    Args:
        in_features: Number of input channels (1 for greyscale images).
        out_features: Number of output logits produced by the classifier.
        hidden_units: Number of channels used by every convolution layer.
        image_size: Spatial size of the input images, either an int (square
            images) or a ``(height, width)`` tuple. Defaults to 28, which gives
            the same ``7*7*hidden_units`` classifier input as before, so
            previously saved state dicts still load unchanged.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 hidden_units,
                 image_size=28):
        super().__init__()
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_channels=in_features,
                      out_channels=hidden_units,
                      kernel_size=3,
                      padding=1,
                      stride=1),
            nn.ReLU(),
            # NOTE(review): kernel_size=2 with padding=1 grows each spatial dim
            # by 1 (e.g. 28 -> 29), unlike the other 3x3 convs. It looks like a
            # typo, but it is kept because saved weights were trained with this shape.
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=2,
                      padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        # The classifier input size is derived from the conv stack rather than
        # hard-coded, so it stays correct for any image_size / layer change.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=self._flattened_size(in_features, image_size),
                      out_features=out_features)
        )

    def _flattened_size(self, in_channels, image_size):
        """Return the number of features the conv blocks emit for one image."""
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        dummy = torch.zeros(1, in_channels, *tuple(image_size))
        # Shape-only probe: no gradients needed, and no RNG is consumed, so
        # seeded weight initialisation is unaffected.
        with torch.no_grad():
            features = self.conv_block_2(self.conv_block_1(dummy))
        return features.numel()

    def forward(self, X):
        """Run both conv blocks and the classifier; returns raw logits."""
        X = self.conv_block_1(X)
        X = self.conv_block_2(X)
        X = self.classifier(X)
        return X

# Select the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Instantiate the model with the same hyper-parameters as the saved checkpoint;
# load_state_dict (strict by default) raises if the layer shapes do not match.
model = TinyVGG(in_features=1,
                out_features=10,
                hidden_units=10).to(device)

# Switch to inference mode. TinyVGG currently has no dropout/batch-norm layers,
# so this does not change its outputs, but it keeps inference correct if such
# layers are added later.
model.eval()

# Load the trained weights.
# - weights_only=True restricts unpickling to tensors and plain containers, so a
#   tampered checkpoint file cannot execute arbitrary code.
# - map_location moves tensors saved on another device onto the selected one.
model.load_state_dict(torch.load(f=Pytorch_file_path,
                                 map_location=torch.device(device),
                                 weights_only=True))

<All keys matched successfully>

In [4]:
# Layer-by-layer parameter counts of the loaded model (torchinfo).
summary(model=model)

Layer (type:depth-idx)                   Param #
TinyVGG                                  --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       100
│    └─ReLU: 2-2                         --
│    └─Conv2d: 2-3                       410
│    └─ReLU: 2-4                         --
│    └─MaxPool2d: 2-5                    --
├─Sequential: 1-2                        --
│    └─Conv2d: 2-6                       910
│    └─ReLU: 2-7                         --
│    └─MaxPool2d: 2-8                    --
├─Sequential: 1-3                        --
│    └─Flatten: 2-9                      --
│    └─Linear: 2-10                      4,910
Total params: 6,330
Trainable params: 6,330
Non-trainable params: 0