In [None]:
from torchvision.models import densenet201 
from torchinfo import summary

# Load DenseNet-201 with torchvision's "DEFAULT" weights, then print a
# per-layer summary for a single 3-channel 512x512 input.
densenet_model = densenet201(weights="DEFAULT")

summary(
    densenet_model,
    input_size=(1, 3, 512, 512),  # (batch, channels=3, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

In [None]:
# Bare last-line expression: the notebook displays the model's repr here.
densenet_model

In [5]:
from torch import nn, inference_mode

# Adapt the pretrained DenseNet-201 to 1-channel (grayscale) input and a
# 2-class output head.
#
# These steps must be followed in this exact order
# -------------------------------------------------

# 1. get the original (pretrained) input conv2d block
old_conv = densenet_model.features[0]

# 2. create the new conv2d input block that takes 1 channel only (grayscale).
#    Everything except in_channels is read from the layer being replaced, so
#    the new layer always matches the old one's geometry, and it has a bias
#    exactly when the old layer has one (keeps the bias copy in step 3 valid).
new_conv = nn.Conv2d(
    in_channels=1,
    out_channels=old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    dilation=old_conv.dilation,
    bias=old_conv.bias is not None,
)

# 3. average the rgb weights across the input-channel axis (dim=1) to form one
#    grayscale channel; keepdim=True keeps the (out, 1, kH, kW) weight shape.
#    Done with autograd disabled because the parameters are written in place.
with inference_mode():
    new_conv.weight[:] = old_conv.weight.mean(dim=1, keepdim=True)
    if old_conv.bias is not None:
        new_conv.bias[:] = old_conv.bias

# 4. replace the input block
densenet_model.features[0] = new_conv

# 5. freeze model (every parameter currently in it, including the new conv)
for param in densenet_model.parameters():
    param.requires_grad = False

# 6. change the output to 2 classes only (healthy, pd).
#    The replacement layer is created after the freeze, so its freshly
#    initialised parameters keep the default requires_grad=True and it is the
#    only trainable block. in_features is read from the classifier being
#    replaced instead of being hard-coded.
densenet_model.classifier = nn.Linear(
    in_features=densenet_model.classifier.in_features,
    out_features=2,
    bias=True,
)

In [None]:
# Per-layer summary of densenet_model for a single 1-channel 512x512 input.
summary(
    densenet_model,
    input_size=(1, 1, 512, 512),  # (batch, channels=1, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

In [None]:
from Archived_models.model_densenet201 import create_densenet
from torchinfo import summary

# Build the model with create_densenet() (imported from
# Archived_models.model_densenet201) and print its per-layer summary
# for a single 1-channel 512x512 input.
model = create_densenet()

summary(
    model,
    input_size=(1, 1, 512, 512),  # (batch, channels=1, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

In [None]:
from improved_DenseNet201 import create_improved_densenet
from torchinfo import summary

# Build the model with create_improved_densenet() (imported from
# improved_DenseNet201) and print its per-layer summary for a single
# 1-channel 512x512 input. Note: this rebinds the notebook-level name `model`.
model = create_improved_densenet()

summary(
    model,
    input_size=(1, 1, 512, 512),  # (batch, channels=1, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)