In [None]:
from torchvision.models import mobilenet_v3_large
from torchinfo import summary

# Load MobileNetV3-Large with torchvision's "DEFAULT" pretrained weights.
mobilenet_model = mobilenet_v3_large(weights="DEFAULT")

# Inspect the unmodified architecture using a standard 3-channel (RGB) 224x224 input.
summary(
    mobilenet_model,
    input_size=(1, 3, 224, 224),  # (batch, channels=3, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    row_settings=["var_names"],
    col_width=20,
)

In [None]:
# Display the full module tree of the pretrained model; the layer indices used in the
# next cell (features[0][0] for the first Conv2d, classifier[3] for the final Linear)
# can be read off this printout.
mobilenet_model

In [4]:
from torch import nn, inference_mode, no_grad

# Adapt the pretrained MobileNetV3 to 1-channel (grayscale) input and a single-logit head.
# These steps must be followed in this exact order
# -------------------------------------------------

# 1. Get the pretrained weights of the original input convolution
#    (the Conv2d inside the first Conv2dNormActivation block).
old_conv = mobilenet_model.features[0][0]

# 2. Create the replacement input conv that takes 1 channel only (grayscale).
#    Every other hyperparameter is copied from the pretrained layer rather than hard-coded,
#    so the output still matches the BatchNorm that follows it. `bias` mirrors the original
#    so the bias copy in step 3 has a tensor to write into.
new_conv = nn.Conv2d(
    in_channels=1,
    out_channels=old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    dilation=old_conv.dilation,
    bias=old_conv.bias is not None,
)

# 3. Average the RGB filters across the input-channel axis (dim=1 of the Conv2d weight,
#    laid out as (out_channels, in_channels, kH, kW)) to form one grayscale filter.
#    no_grad() is the standard context for in-place parameter initialisation.
#    NOTE(review): averaging makes the response to a gray image ~1/3 of what the original
#    conv gives for the same image replicated to RGB. The following BatchNorm's running
#    statistics were estimated at the RGB scale. Summing instead of averaging is the common
#    alternative that preserves that scale — consider it.
with no_grad():
    new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
    if old_conv.bias is not None:
        new_conv.bias.copy_(old_conv.bias)

# 4. Replace the input conv.
mobilenet_model.features[0][0] = new_conv

# 5. Freeze every parameter currently in the model (this includes the new grayscale conv,
#    since it was installed in step 4).
mobilenet_model.requires_grad_(False)

# 6. Binary classification (healthy vs. pd): replace the final Linear with a single-logit
#    output (out_features=1, not 2 class scores). Presumably trained with a sigmoid /
#    BCEWithLogitsLoss — confirm against the training code.
#    The layer is created after the freeze, so its parameters are the only trainable ones.
mobilenet_model.classifier[3] = nn.Linear(
    in_features=mobilenet_model.classifier[3].in_features,
    out_features=1,
    bias=True,
)

In [None]:
# Re-inspect the modified network with a 1-channel 512x512 input: confirms the grayscale
# input conv accepts it, and the "trainable" column shows which layers remain trainable.
summary(
    mobilenet_model,
    input_size=(1, 1, 512, 512),  # (batch, channels=1, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    row_settings=["var_names"],
    col_width=20,
)

In [None]:
from model_mobilenetV3 import create_mobilenetv3
from torchinfo import summary

# Build the model through the project's factory function and summarise it with the same
# 1-channel 512x512 input (presumably to compare against the hand-modified model above).
model = create_mobilenetv3()
summary(
    model,
    input_size=(1, 1, 512, 512),  # (batch, channels=1, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    row_settings=["var_names"],
    col_width=20,
)