In [None]:
from torchvision.models import vit_b_16
from torchinfo import summary

# Pretrained ViT-B/16, loaded with the weights torchvision registers as "DEFAULT".
vit_model = vit_b_16(weights="DEFAULT")

# Layer-by-layer overview of the stock model for one RGB 224x224 image;
# the input shape is (batch, C=3, H, W).
summary(
    vit_model,
    row_settings=["var_names"],
    col_width=20,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    input_size=(1, 3, 224, 224),
)

In [None]:
vit_model # display the architecture to locate the modules swapped below: `conv_proj` (input) and `heads` (output)

In [None]:
from torch import inference_mode, nn, no_grad

# Adapt the pretrained ViT-B/16 to 1-channel (grayscale) input and a
# NUM_CLASSES-way head. These steps must be followed in this exact order.
# NOTE(review): `model_Vit_b16.create_vit` (used further down) presumably
# builds the same model -- keep the two in sync.
# -------------------------------------------------------------------------

NUM_CLASSES = 2  # healthy, pd

# 1. get the pretrained weights of the original conv2d block (input / patch embedding)
old_conv = vit_model.conv_proj

# 2. create the new conv2d input block taking 1 channel only (grayscale).
#    Everything else (embedding width, patch kernel, stride, padding, bias,
#    device, dtype) is copied from the pretrained conv instead of being
#    hard-coded, so the patch grid matches what the frozen encoder expects.
new_conv = nn.Conv2d(
    in_channels=1,
    out_channels=old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    bias=old_conv.bias is not None,
    device=old_conv.weight.device,
    dtype=old_conv.weight.dtype,
)

# 3. collapse the RGB filters into one grayscale filter by SUMMING over the
#    input-channel axis (dim=1 of the conv weight). A grayscale image seen as
#    R=G=B produces sum_c(w_c) * x in the original conv, so summing keeps the
#    pretrained activation scale; averaging would shrink the patch embeddings
#    ~3x relative to the frozen class token / position embeddings.
#    `no_grad` + `copy_` is the idiomatic way to overwrite parameter values.
with no_grad():
    new_conv.weight.copy_(old_conv.weight.sum(dim=1, keepdim=True))
    if new_conv.bias is not None:
        new_conv.bias.copy_(old_conv.bias)

# 4. replace the input block
vit_model.conv_proj = new_conv

# 5. freeze the whole model (this includes the new input block)
for param in vit_model.parameters():
    param.requires_grad = False

# 6. change the output to NUM_CLASSES classes (healthy, pd). The new Linear is
#    created after the freeze, so it is the only trainable block. Its input
#    width is the patch-embedding width (768 for ViT-B/16).
vit_model.heads = nn.Sequential(
    nn.Linear(in_features=old_conv.out_channels, out_features=NUM_CLASSES)
)

In [None]:
# Summary of the adapted model for one grayscale 224x224 image; the input
# shape is (batch, C=1, H, W). After the freeze cell above, only the `heads`
# rows are expected to show as trainable.
summary(
    vit_model,
    row_settings=["var_names"],
    col_width=20,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    input_size=(1, 1, 224, 224),
)

In [19]:
from torchvision.models import ViT_B_16_Weights
ViT_B_16_Weights.DEFAULT.transforms() # preprocessing the pretrained weights expect. NOTE(review): the mean/std it shows are 3-channel (RGB) values, but the adapted model takes 1-channel input -- confirm the grayscale data pipeline uses single-channel normalization.

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

In [None]:
from model_Vit_b16 import create_vit
from torchinfo import summary

# Build the model from the project module and summarize it for one grayscale
# 224x224 image; the input shape is (batch, C=1, H, W).
model = create_vit()

summary(
    model,
    row_settings=["var_names"],
    col_width=20,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    input_size=(1, 1, 224, 224),
)