In [3]:
from torchvision.models import vgg19_bn # may try more vgg archs
from torchinfo import summary

# Instantiate VGG-19 (batch-norm variant) with weights="DEFAULT", then show a
# per-layer summary for a single 3-channel 512x512 input.
vgg_model = vgg19_bn(weights="DEFAULT")
summary(
    model=vgg_model,
    input_size=(1, 3, 512, 512),  # (batch, C = 3, H, W)
    row_settings=["var_names"],
    col_width=20,
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
VGG (VGG)                                [1, 3, 512, 512]     [1, 1000]            --                   True
├─Sequential (features)                  [1, 3, 512, 512]     [1, 512, 16, 16]     --                   True
│    └─Conv2d (0)                        [1, 3, 512, 512]     [1, 64, 512, 512]    1,792                True
│    └─BatchNorm2d (1)                   [1, 64, 512, 512]    [1, 64, 512, 512]    128                  True
│    └─ReLU (2)                          [1, 64, 512, 512]    [1, 64, 512, 512]    --                   --
│    └─Conv2d (3)                        [1, 64, 512, 512]    [1, 64, 512, 512]    36,928               True
│    └─BatchNorm2d (4)                   [1, 64, 512, 512]    [1, 64, 512, 512]    128                  True
│    └─ReLU (5)                          [1, 64, 512, 512]    [1, 64, 512, 512]    --                   --
│    └─MaxPool2d (

In [None]:
# Bare last expression: the notebook displays the model's repr as the cell output.
vgg_model

In [7]:
from torch import nn, inference_mode

# Adapt the pretrained VGG to grayscale (1-channel) input and a 2-class head.
# These steps must be followed in this exact order
# -------------------------------------------------

# 1. get the original (pretrained) conv2d input block
old_conv = vgg_model.features[0]

# 2. create the new conv2d input block: it takes 1 channel only (grayscale);
#    every other hyper-parameter mirrors the layer being replaced instead of
#    being hard-coded. A bias is created only if the original layer has one, so
#    a bias-free layer is never replaced by one with a randomly initialised bias.
new_conv = nn.Conv2d(
    in_channels=1,
    out_channels=old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    dilation=old_conv.dilation,
    padding_mode=old_conv.padding_mode,
    bias=old_conv.bias is not None,
)

# 3. average the rgb weights across the channel dim to form one grayscale channel.
#    .detach() keeps the copy out of the autograd graph without entering
#    inference_mode(): tensors allocated inside inference mode are "inference
#    tensors", which cannot be updated in place or switched back to
#    requires_grad=True outside of it (that would block a later
#    load_state_dict() or unfreezing of this layer).
new_conv.weight = nn.Parameter(old_conv.weight.detach().mean(dim=1, keepdim=True))
if old_conv.bias is not None:
    # clone so the new layer owns its bias rather than aliasing old_conv's Parameter
    new_conv.bias = nn.Parameter(old_conv.bias.detach().clone())

# 4. replace the input block
vgg_model.features[0] = new_conv

# 5. freeze model (every parameter, including the new input conv)
# NOTE: this only disables gradient updates; BatchNorm running statistics are
# still updated whenever the model is in train() mode.
for param in vgg_model.parameters():
    param.requires_grad = False

# 6. change the output to 2 classes only (healthy, pd)
#    the freshly created Linear is trainable by default, so this is the only
#    unfrozen block; in_features is read from the layer being replaced
old_head = vgg_model.classifier[-1]
vgg_model.classifier[-1] = nn.Linear(in_features=old_head.in_features, out_features=2, bias=True)

In [None]:
# Per-layer summary of the adapted model for one single-channel 224x224 input;
# the "trainable" column shows which blocks are frozen.
summary(
    model=vgg_model,
    input_size=(1, 1, 224, 224),  # (batch, C = 1, H, W)
    row_settings=["var_names"],
    col_width=20,
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
VGG (VGG)                                [1, 1, 224, 224]     [1, 2]               --                   Partial
├─Sequential (features)                  [1, 1, 224, 224]     [1, 512, 7, 7]       --                   False
│    └─Conv2d (0)                        [1, 1, 224, 224]     [1, 64, 224, 224]    (640)                False
│    └─BatchNorm2d (1)                   [1, 64, 224, 224]    [1, 64, 224, 224]    (128)                False
│    └─ReLU (2)                          [1, 64, 224, 224]    [1, 64, 224, 224]    --                   --
│    └─Conv2d (3)                        [1, 64, 224, 224]    [1, 64, 224, 224]    (36,928)             False
│    └─BatchNorm2d (4)                   [1, 64, 224, 224]    [1, 64, 224, 224]    (128)                False
│    └─ReLU (5)                          [1, 64, 224, 224]    [1, 64, 224, 224]    --                   --
│    └─Max

In [4]:
from model_vgg19_bn import create_vgg
from torchinfo import summary

# Build the model through the project helper and summarise it for one
# single-channel 224x224 input.
# NOTE(review): create_vgg lives in model_vgg19_bn (not visible here); presumably
# it packages the grayscale / 2-class adaptation done by hand in this notebook - confirm.
model = create_vgg()
summary(
    model=model,
    input_size=(1, 1, 224, 224),  # (batch, C = 1, H, W)
    row_settings=["var_names"],
    col_width=20,
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
VGG (VGG)                                [1, 1, 224, 224]     [1, 2]               --                   Partial
├─Sequential (features)                  [1, 1, 224, 224]     [1, 512, 7, 7]       --                   False
│    └─Conv2d (0)                        [1, 1, 224, 224]     [1, 64, 224, 224]    (640)                False
│    └─BatchNorm2d (1)                   [1, 64, 224, 224]    [1, 64, 224, 224]    (128)                False
│    └─ReLU (2)                          [1, 64, 224, 224]    [1, 64, 224, 224]    --                   --
│    └─Conv2d (3)                        [1, 64, 224, 224]    [1, 64, 224, 224]    (36,928)             False
│    └─BatchNorm2d (4)                   [1, 64, 224, 224]    [1, 64, 224, 224]    (128)                False
│    └─ReLU (5)                          [1, 64, 224, 224]    [1, 64, 224, 224]    --                   --
│    └─Max