In [None]:
from torchvision.models import inception_v3 
from torchinfo import summary

# Load Inception v3 with its default pretrained weights, then print a
# per-layer table (shapes, parameter counts, trainable flags) for one RGB image.
inception_model = inception_v3(weights="DEFAULT")

rgb_input_shape = (1, 3, 512, 512)  # (batch, C = 3, H, W)
summary_columns = ["input_size", "output_size", "num_params", "trainable"]

summary(
    inception_model,
    input_size=rgb_input_shape,
    col_names=summary_columns,
    col_width=20,
    row_settings=["var_names"],
)

In [None]:
# Display the model's repr (its nested module tree) as the cell output.
inception_model

In [9]:
# Grab the network's first convolution so its configuration can be inspected.
# NOTE(review): the output recorded for this cell reports 1 input channel, which
# looks like it was captured after the grayscale conv had already been swapped
# in by a later cell -- re-run on a fresh kernel to see the original layer.
old_conv = inception_model.Conv2d_1a_3x3.conv

old_conv

Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)

In [17]:
from torch import nn, inference_mode

# Adapt the pretrained Inception v3 to 1-channel (grayscale) input and a
# 2-class output. These steps must be followed in this exact order.
# -------------------------------------------------

# 1. get the pretrained first conv layer (its weights seed the replacement)
old_conv = inception_model.Conv2d_1a_3x3.conv

# 2. create the new input conv taking 1 channel only (grayscale).
#    Every other hyperparameter is mirrored from the old layer, including
#    whether it has a bias, so the weight/bias copies below always line up.
new_conv = nn.Conv2d(
    in_channels=1,
    out_channels=old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    bias=old_conv.bias is not None,
)

# 3. average the pretrained kernels over the input-channel axis (dim=1) to form
#    a single grayscale kernel per output filter
with inference_mode():
    new_conv.weight[:] = old_conv.weight.mean(dim=1, keepdim=True)
    if old_conv.bias is not None:
        new_conv.bias[:] = old_conv.bias

# 4. replace the input block
inception_model.Conv2d_1a_3x3.conv = new_conv

# Disable the input transform (expects 3 channels). The model's own flag is
# used instead of overriding the method with a lambda, which would leave an
# unpicklable attribute on the module.
inception_model.transform_input = False

# 5. freeze every parameter currently in the model (including the new conv)
inception_model.requires_grad_(False)

# change the output to 2 classes only (healthy, pd)
# the freshly created layer is trainable by default, so this is the only
# unfrozen block
# NOTE(review): the auxiliary classifier head (if enabled) keeps its original
# output size -- confirm whether the training loop uses it.
inception_model.fc = nn.Linear(
    in_features=inception_model.fc.in_features,
    out_features=2,
    bias=True,
)

In [None]:
# Summarise the model again, this time feeding a single-channel input.
summary_options = {
    "input_size": (1, 1, 512, 512),  # (batch, C = 1, H, W)
    "col_names": ["input_size", "output_size", "num_params", "trainable"],
    "col_width": 20,
    "row_settings": ["var_names"],
}
summary(inception_model, **summary_options)

In [None]:
from model_inceptionV3 import create_inception
from torchinfo import summary

# Build the model through the project's factory and print the same per-layer
# table for a single-channel 512x512 input.
# NOTE(review): presumably create_inception() packages the adaptation steps
# performed earlier in this notebook -- confirm in model_inceptionV3.
model = create_inception()

grayscale_input_shape = (1, 1, 512, 512)  # (batch, C = 1, H, W)
table_columns = ["input_size", "output_size", "num_params", "trainable"]

summary(
    model,
    input_size=grayscale_input_shape,
    col_names=table_columns,
    col_width=20,
    row_settings=["var_names"],
)