In [1]:
from torchvision.models import resnet50
# Other backbones are worth trying later: resnet50 has ~25.6M parameters, while
# resnet152 has ~60.2M and is only slightly more accurate.
from torchinfo import summary

# Pretrained ResNet-50 using torchvision's "DEFAULT" weights.
resnet_model = resnet50(weights="DEFAULT")

# Inspect the unmodified network on one 3-channel 512x512 image: (batch, C, H, W).
summary(
    resnet_model,
    input_size=(1, 3, 512, 512),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
ResNet (ResNet)                          [1, 3, 512, 512]     [1, 1000]            --                   True
├─Conv2d (conv1)                         [1, 3, 512, 512]     [1, 64, 256, 256]    9,408                True
├─BatchNorm2d (bn1)                      [1, 64, 256, 256]    [1, 64, 256, 256]    128                  True
├─ReLU (relu)                            [1, 64, 256, 256]    [1, 64, 256, 256]    --                   --
├─MaxPool2d (maxpool)                    [1, 64, 256, 256]    [1, 64, 128, 128]    --                   --
├─Sequential (layer1)                    [1, 64, 128, 128]    [1, 256, 128, 128]   --                   True
│    └─Bottleneck (0)                    [1, 64, 128, 128]    [1, 256, 128, 128]   --                   True
│    │    └─Conv2d (conv1)               [1, 64, 128, 128]    [1, 64, 128, 128]    4,096                True
│    │    └─BatchN

In [None]:
# Display the module tree (repr) of resnet_model as currently defined in the kernel.
resnet_model

In [3]:
from torch import inference_mode, nn, no_grad

# Adapt the pretrained network to single-channel (grayscale) input and a 2-class head.
# These steps must be followed in this exact order: the new stem is built from the
# *original* conv1 before it is replaced, and the whole network is frozen *before*
# the new classifier head is attached, so that only the head remains trainable.
# -------------------------------------------------

# 1. keep a handle on the pretrained input convolution (RGB, 3 input channels)
old_conv = resnet_model.conv1

# 2. build a 1-channel replacement that mirrors the original's other hyper-parameters
#    (out_channels, kernel_size, stride, padding, dilation, bias, device, dtype)
#    instead of hard-coding them. Creating it with a bias exactly when the original
#    has one keeps step 3's bias copy valid.
new_conv = nn.Conv2d(
    1,
    old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    dilation=old_conv.dilation,
    bias=old_conv.bias is not None,
    device=old_conv.weight.device,
    dtype=old_conv.weight.dtype,
)

# 3. collapse the pretrained RGB filters into a single grayscale channel by averaging
#    over the input-channel axis (dim=1 of Conv2d's (out, in, kH, kW) weight).
#    NOTE(review): averaging makes conv1's weighted response to a grayscale image 1/3
#    of the pretrained response to that image replicated to RGB; summing
#    (`.sum(dim=1, keepdim=True)`) would preserve that scale. Kept as mean — confirm.
#    `no_grad` is the idiomatic context for writing parameter values in place.
with no_grad():
    new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
    if old_conv.bias is not None:
        new_conv.bias.copy_(old_conv.bias)

# 4. swap the adapted stem into the network
resnet_model.conv1 = new_conv

# 5. freeze every parameter currently in the model (backbone + new stem).
#    NOTE: requires_grad=False does not stop BatchNorm running statistics from being
#    updated while the model is in train() mode.
for param in resnet_model.parameters():
    param.requires_grad = False

# 6. replace the classifier with a 2-class head (healthy, pd). Newly created modules
#    have requires_grad=True, so this head is the only trainable part.
#    in_features is read from the existing head instead of hard-coding 2048.
resnet_model.fc = nn.Linear(
    in_features=resnet_model.fc.in_features, out_features=2, bias=True
)

In [4]:
# Re-inspect the adapted model on one single-channel 512x512 image. The output
# should now have 2 logits, and only the new fc head should be trainable.
summary(resnet_model, 
        input_size=(1, 1, 512, 512), # (batch, C = 1, H, W)
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
        )

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
ResNet (ResNet)                          [1, 1, 512, 512]     [1, 2]               --                   Partial
├─Conv2d (conv1)                         [1, 1, 512, 512]     [1, 64, 256, 256]    (3,136)              False
├─BatchNorm2d (bn1)                      [1, 64, 256, 256]    [1, 64, 256, 256]    (128)                False
├─ReLU (relu)                            [1, 64, 256, 256]    [1, 64, 256, 256]    --                   --
├─MaxPool2d (maxpool)                    [1, 64, 256, 256]    [1, 64, 128, 128]    --                   --
├─Sequential (layer1)                    [1, 64, 128, 128]    [1, 256, 128, 128]   --                   False
│    └─Bottleneck (0)                    [1, 64, 128, 128]    [1, 256, 128, 128]   --                   False
│    │    └─Conv2d (conv1)               [1, 64, 128, 128]    [1, 64, 128, 128]    (4,096)              False
│    │    

In [1]:
from model_resnet50 import create_resnet
from torchinfo import summary

# Build the model from the project module and inspect it on a single 1-channel
# 224x224 image.
model = create_resnet()

summary_columns = ["input_size", "output_size", "num_params", "trainable"]
summary(
    model,
    input_size=(1, 1, 224, 224),  # (batch, C = 1, H, W)
    col_names=summary_columns,
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
ResNet (ResNet)                          [1, 1, 224, 224]     [1, 2]               --                   Partial
├─Conv2d (conv1)                         [1, 1, 224, 224]     [1, 64, 112, 112]    (3,136)              False
├─BatchNorm2d (bn1)                      [1, 64, 112, 112]    [1, 64, 112, 112]    (128)                False
├─ReLU (relu)                            [1, 64, 112, 112]    [1, 64, 112, 112]    --                   --
├─MaxPool2d (maxpool)                    [1, 64, 112, 112]    [1, 64, 56, 56]      --                   --
├─Sequential (layer1)                    [1, 64, 56, 56]      [1, 256, 56, 56]     --                   False
│    └─Bottleneck (0)                    [1, 64, 56, 56]      [1, 256, 56, 56]     --                   False
│    │    └─Conv2d (conv1)               [1, 64, 56, 56]      [1, 64, 56, 56]      (4,096)              False
│    │    