# DINO

From: https://github.com/facebookresearch/dino

## Classifier


In [1]:
from transformers import AutoFeatureExtractor, ResNetModel, AutoImageProcessor
from PIL import Image
import requests
from torch import nn
import torch
import sys
import monai

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoFeatureExtractor.from_pretrained('Ramos-Ramos/dino-resnet-50')
model = ResNetModel.from_pretrained('Ramos-Ramos/dino-resnet-50')



In [2]:
model

ResNetModel(
  (embedder): ResNetEmbeddings(
    (embedder): ResNetConvLayer(
      (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU()
    )
    (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (encoder): ResNetEncoder(
    (stages): ModuleList(
      (0): ResNetStage(
        (layers): Sequential(
          (0): ResNetBottleNeckLayer(
            (shortcut): ResNetShortCut(
              (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (layer): Sequential(
              (0): ResNetConvLayer(
                (convolution): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalizatio

In [3]:
resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


In [4]:
resnet50

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [5]:
# Print the values of the parameters for the first few layers
for name, param in resnet50.named_parameters():
    if any(layer_name in name for layer_name in ['conv1', 'bn1', 'layer1', 'layer2', 'layer3']):
        print(f"\nParameter: {name}")
        print(param.data)
        sys.exit()


Parameter: conv1.weight
tensor([[[[ 3.9314e-02, -2.3231e-02,  1.9019e-02,  ..., -2.8492e-02,
            9.4900e-03, -1.5267e-01],
          [-6.3298e-02, -1.9383e-01, -3.5079e-02,  ..., -1.1285e-01,
           -7.6729e-02, -1.9084e-01],
          [ 6.4572e-02, -9.0484e-02,  1.3011e-01,  ...,  1.9762e-01,
            2.1275e-01,  5.0856e-02],
          ...,
          [-1.0417e-01, -2.0973e-01,  2.7136e-01,  ...,  7.8400e-01,
            4.2137e-01,  5.4485e-02],
          [-1.8915e-02, -1.2731e-01,  2.9765e-01,  ...,  3.1650e-01,
            6.1463e-02, -1.8677e-01],
          [-9.0737e-02, -2.1856e-01,  1.6286e-01,  ..., -2.8385e-02,
           -1.9972e-01, -3.2438e-01]],

         [[ 6.0889e-02,  1.4994e-01, -6.9268e-02,  ..., -6.9937e-02,
            2.5998e-01,  5.2359e-01],
          [-6.4661e-02,  5.3575e-02, -1.7210e-01,  ..., -4.0584e-01,
            5.9906e-02,  3.2894e-01],
          [ 1.3143e-01,  3.8945e-01,  1.5902e-01,  ..., -6.7193e-01,
            1.2937e-01,  4.7857e-

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
# Find the `avgpool` layer in the model
avgpool_layer = None
num_classes = 3
for layer in resnet50.children():
    if isinstance(layer, nn.AdaptiveAvgPool2d):
        avgpool_layer = layer
        break

if avgpool_layer is not None:
    in_features = resnet50.layer4[-1].conv3.in_channels  # Assuming the last convolutional layer is in layer4
    classification_head = nn.Linear(in_features, num_classes)
else:
    raise RuntimeError("No suitable layer found in the model for adding the classification head.")

in_features

512

In [None]:
# Adding a linear classifier on top of the ResNet50

num_classes = 3

linear_classifier = nn.Sequential(
    nn.Linear(2048, num_classes) # 2048 is hard-coded.
    )

resnet50.fc = linear_classifier
resnet50

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

## SSL

According to the GitHub, this should work:

python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

With our modifications (changing the image and output directory and removing the distributed training):

python main_dino.py --arch resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset_duke_liger_itoju_5StLowQual --output_dir /sddata/projects/SSL/dino/finetuning/full_dataset_duke_liger_itoju_5StLowQual

Note, the lines 145-155 in main_dino.py:

    dataset = datasets.ImageFolder(args.data_path, transform=transform)
    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    print(f"Data loaded: there are {len(dataset)} images.")

will cause an issue. Namely, the dataset.ImageFolder will look for *subfolders*, so it won't find the images. We need to then put them all in a folder.

So, we put the data directory 'full_dataset_duke_liger_itoju_5StLowQual' in another folder called 'full_dataset'.

We then can use: 

python main_dino.py --arch resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset --output_dir /sddata/projects/SSL/dino/finetuning/testing

python main_dino.py --arch resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset --output_dir /sddata/projects/SSL/dino/finetuning/full_dataset_duke_liger_itoju_5StLowQual_w_full_datset_norms

python main_dino.py --arch densenet121 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset --output_dir /sddata/projects/SSL/dino/finetuning/densenet_121_full_dataset_duke_liger_itoju_5StLowQual_w_full_dataset_norms

python main_dino.py --arch dino_resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/data/retina_datasets/diabetic_retinopathy_detection/train --output_dir /sddata/projects/SSL/dino/finetuning/dia_ret_imagenet_norms_correct


Testing!
python main_dino.py --arch dino_resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/data/retina_datasets/diabetic_retinopathy_detection/train --output_dir /sddata/projects/SSL/dino/finetuning/testing

Redoing cervix training:

Note, adding CUDA_VISIBLE_DEVICES=1 before the command line forces it on whichever cuda we want.

python main_dino.py --arch dino_resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset --output_dir /sddata/projects/SSL/dino/finetuning/corrected_resnet50_imagenet_full_dataset_finetuned_imagenet_norms

python main_dino.py --arch dino_resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset --output_dir /sddata/projects/SSL/dino/finetuning/corrected_resnet50_imagenet_full_dataset_finetuned_full_dataset_norms

Densenet121:

CUDA_VISIBLE_DEVICES=1 python main_dino.py --arch densenet121_pt --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset --output_dir /sddata/projects/SSL/dino/finetuning/Cervix/densenet121_full_dataset_duke_liger_itoju_5StLowQual_w_imagenet_norms


Attempting to continue from a checkpoint!:

CUDA_VISIBLE_DEVICES=1 python main_dino.py --arch densenet121_pt --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset --output_dir /sddata/projects/SSL/dino/finetuning/Cervix/densenet121_full_dataset_duke_liger_itoju_5StLowQual_w_imagenet_norms_cont

CUDA_VISIBLE_DEVICES=1 python main_dino.py --arch dino_resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset --output_dir /sddata/projects/SSL/dino/finetuning/Cervix/corrected_resnet50_imagenet_full_dataset_finetuned_full_dataset_norms_cont


CUDA_VISIBLE_DEVICES=1 python main_dino.py --arch dino_resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset --output_dir /sddata/projects/SSL/dino/finetuning/Cervix/corrected_resnet50_imagenet_full_dataset_finetuned_imagenet_norms_cont



CUDA_VISIBLE_DEVICES=1 python main_dino.py --arch densenet121_pt --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset --output_dir /sddata/projects/SSL/dino/finetuning/Cervix/densenet121_full_dataset_duke_liger_itoju_5StLowQual_w_imagenet_norms_monai_version

ViT

CUDA_VISIBLE_DEVICES=1 python main_dino.py --arch dino_vitb16 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/data/retina_datasets/diabetic_retinopathy_detection/train --output_dir /sddata/projects/SSL/dino/finetuning/Diabetic_Retinopathy/vitb8_imagenet_and_dia_ret_pretext_dia_ret_full_dataset_norms




Rerunning things with Dr with train_val only normalizations
CUDA_VISIBLE_DEVICES=1 python main_dino.py --arch dino_resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/data/retina_datasets/diabetic_retinopathy_detection/data_with_labels_train_val --output_dir /sddata/projects/SSL/dino/finetuning/Diabetic_Retinopathy_train_val_only/ResNet50_train_val_only_train_val_norms

python main_dino.py --arch dino_vitb16 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/data/retina_datasets/diabetic_retinopathy_detection/data_with_labels_train_val --output_dir /sddata/projects/SSL/dino/finetuning/Diabetic_Retinopathy_train_val_only/ViTb16_train_val_only_train_val_norms


Cervix with only train/val norms and such

python main_dino.py --arch dino_resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset_duke_liger_train_val --output_dir /sddata/projects/SSL/dino/finetuning/Cervix_train_val_only/ResNet50_train_val_only_train_val_norms

python main_dino.py --arch dino_vitb16 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset_duke_liger_train_val --output_dir /sddata/projects/SSL/dino/finetuning/Cervix_train_val_only/ViTb16_train_val_only_train_val_norms

Cervix with all but testing and train/val norms

python main_dino.py --arch dino_resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset --output_dir //sddata/projects/SSL/dino/finetuning/Cervix_all_but_testing/ReNet50_all_but_testing_all_but_testing_norms

python main_dino.py --arch dino_vitb16 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset --output_dir //sddata/projects/SSL/dino/finetuning/Cervix_all_but_testing/ViTb16_all_but_testing_all_but_testing_norms

python main_dino.py --arch densenet121_pt --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/Cervical_Cancer_Projects/data/full_dataset --output_dir /sddata/projects/SSL/dino/finetuning/Cervix_all_but_testing/DenseNet121_all_but_testing_all_but_testing_norms



# Cervix
python main_dino.py --arch dino_resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/SSL/csvs/datasets/all_cervix_images_up_to_03262024_no_seed_test_test2.csv --output_dir /sddata/projects/SSL/dino/finetuning/newest_runs_04062024/cervix/ResNet50

python main_dino.py --arch densenet121_pt --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/SSL/csvs/datasets/all_cervix_images_up_to_03262024_no_seed_test_test2.csv --output_dir /sddata/projects/SSL/dino/finetuning/newest_runs_04062024/cervix/DenseNet121


# DR
python main_dino.py --arch dino_resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/SSL/csvs/datasets/all_dr_images_no_test.csv --output_dir /sddata/projects/SSL/dino/finetuning/newest_runs_04062024/DR/ResNet50

python main_dino.py --arch dino_vitb16 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /sddata/projects/SSL/csvs/datasets/all_dr_images_no_test.csv --output_dir /sddata/projects/SSL/dino/finetuning/newest_runs_04062024/DR/Vitb16




## Loading in the fine-tuned model

In [None]:
# Getting the fine-tuned model
ft_model_path = '/sddata/projects/SSL/dino/finetuning/full_dataset_duke_liger_itoju_5StLowQual/checkpoint.pth'
checkpoint = torch.load(ft_model_path, map_location='cpu')

print(checkpoint)

{'student': OrderedDict([('module.backbone.conv1.weight', tensor([[[[-2.9915e-02,  2.5150e-02,  7.3270e-03,  ...,  1.0820e-03,
           -4.9218e-02, -2.5151e-03],
          [ 7.0777e-03,  8.8909e-03, -1.3589e-02,  ..., -1.2779e-01,
           -9.8306e-02, -6.6342e-02],
          [-4.6427e-03, -6.5781e-03, -9.8743e-03,  ..., -1.2719e-01,
           -1.3104e-01, -6.6032e-02],
          ...,
          [ 2.5559e-02,  3.1275e-02, -7.2820e-03,  ..., -4.5503e-02,
            1.3178e-02, -2.4226e-02],
          [ 2.6236e-02, -1.7655e-03, -2.0638e-02,  ..., -9.2696e-03,
            9.1214e-04,  2.1262e-02],
          [ 4.8517e-02,  6.1548e-02,  3.5557e-02,  ...,  2.4248e-02,
            7.9648e-02,  4.8634e-02]],

         [[ 2.2679e-02,  4.3330e-02, -6.3474e-03,  ..., -1.9881e-02,
           -3.4238e-02,  3.0735e-02],
          [-9.2467e-03,  4.1746e-02,  6.4840e-03,  ..., -5.1796e-02,
           -9.9418e-02, -9.5535e-03],
          [ 1.4053e-02,  1.1031e-02, -1.1679e-02,  ..., -8.6651e-02,


In [None]:
# Getting the student part

student_state_dict = checkpoint['student']  # Assuming save_dict contains the student's state_dict

In [None]:
# Getting the backbone only (as the student has a DINO head) and replacing the old weights with the fine-tuned backbone

import collections
import torch

# Define a ResNet-50 model without the custom head
backbone_model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
ft_model_path = '/sddata/projects/SSL/dino/finetuning/full_dataset_duke_liger_itoju_5StLowQual/checkpoint.pth'
checkpoint = torch.load(ft_model_path, map_location='cpu')

# Load the student's state_dict into the ResNet-50 model while matching keys
student_state_dict = checkpoint['student']  # Assuming save_dict contains the student's state_dict

# Initialize an empty OrderedDict to hold the backbone state_dict
backbone_finetuned_state_dict = collections.OrderedDict()

# Iterate through the items in student_state_dict and select keys starting with 'module.backbone'
for key, value in student_state_dict.items():
    if key.startswith('module.backbone.'):
        # Remove the 'module.backbone.' prefix to get the corresponding key in the backbone
        new_key = key[len('module.backbone.'):]
        backbone_finetuned_state_dict[new_key] = value

backbone_model.load_state_dict(backbone_finetuned_state_dict, strict=False)


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


<All keys matched successfully>

# Switching out models
## This is ResNet50 to DenseNet121

ResNet50
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    (fc): Identity()
    (head): Identity()
  )
  (head): DINOHead(
    (mlp): Sequential(
      (0): Linear(in_features=2048, out_features=2048, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=2048, out_features=2048, bias=True)
      (3): GELU(approximate='none')
      (4): Linear(in_features=2048, out_features=256, bias=True)
    )
    (last_layer): Linear(in_features=256, out_features=65536, bias=False)
  )
)

DenseNet121
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (classifier): Linear(in_features=1024, out_features=1000, bias=True)
    (fc): Identity()
    (head): Identity()
  )
  (head): DINOHead(
    (mlp): Sequential(
      (0): Linear(in_features=1024, out_features=2048, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=2048, out_features=2048, bias=True)
      (3): GELU(approximate='none')
      (4): Linear(in_features=2048, out_features=256, bias=True)
    )
    (last_layer): Linear(in_features=256, out_features=65536, bias=False)
  )
)

Note that the DenseNet121 has an extra linear layer. So, we need to include that.

Use this:

densenet = getattr(monai.networks.nets, architecture)
model = densenet(spatial_dims=2,
                    in_channels=3,
                    out_channels=output_channels,
                    dropout_prob=float(dropout_rate),
                    pretrained=True)

In [4]:
densenet = getattr(monai.networks.nets, 'densenet121')
model = densenet(spatial_dims=2,
                    in_channels=3,
                    out_channels=3,
                    dropout_prob=float(0.1),
                    pretrained=True)

In [5]:
model

DenseNet121(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (layers): Sequential(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dropout): Dropout2d(p=0.1, inplace=False)
        )
      )
      (denselayer2): _DenseLayer(
  

In [12]:
model.class_layers = nn.Identity()
model

DenseNet121(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (layers): Sequential(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (denselayer2): _DenseLayer(
        (layers): Sequential(
          (norm1): BatchN

In [13]:
for name, param in model.named_parameters():
        param_name = name
        param_data = param.data
        print(f"\nParameter: {param_name}")
        print(param_data)
        break


Parameter: features.conv0.weight
tensor([[[[ 7.8276e-02,  1.4949e-01,  1.6611e-01,  ...,  1.7676e-01,
            1.6588e-01,  1.4101e-01],
          [ 1.7546e-01,  2.4408e-01,  2.5000e-01,  ...,  2.7452e-01,
            2.5245e-01,  2.2199e-01],
          [ 1.2331e-01,  1.6441e-01,  1.4922e-01,  ...,  1.6301e-01,
            1.6191e-01,  1.4061e-01],
          ...,
          [-1.0461e-01, -1.2065e-01, -1.1969e-01,  ..., -1.1355e-01,
           -1.1181e-01, -1.1653e-01],
          [-1.4747e-01, -1.8658e-01, -1.8272e-01,  ..., -2.1694e-01,
           -2.0213e-01, -1.8302e-01],
          [-2.0729e-01, -2.7118e-01, -2.8157e-01,  ..., -2.8711e-01,
           -2.4883e-01, -2.2605e-01]],

         [[ 1.6418e-01,  2.4814e-01,  2.6538e-01,  ...,  2.7358e-01,
            2.5693e-01,  2.2483e-01],
          [ 2.4226e-01,  3.2158e-01,  3.2346e-01,  ...,  3.4569e-01,
            3.1805e-01,  2.8128e-01],
          [ 1.5825e-01,  2.0253e-01,  1.8364e-01,  ...,  1.9484e-01,
            1.8966e-01, 

# Reloading in a Student and Teacher

In [8]:
import torch
import collections

checkpoint_path = '/sddata/projects/SSL/dino/finetuning/Cervix/corrected_resnet50_imagenet_full_dataset_finetuned_full_dataset_norms/checkpoint.pth'
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# Define ResNet-50 models without the custom head
backbone_student_model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
backbone_teacher_model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')

# Load the student's state_dict into the ResNet-50 model while matching keys
student_state_dict = checkpoint['student']
teacher_state_dict = checkpoint['teacher']

# Initialize empty OrderedDicts to hold the backbone state_dicts
backbone_finetuned_state_dict_student = collections.OrderedDict()
backbone_finetuned_state_dict_teacher = collections.OrderedDict()

# Iterate through the items in student_state_dict and select keys starting with 'module.backbone'
for key, value in student_state_dict.items():
    if key.startswith('module.backbone.'):
        # Remove the 'module.backbone.' prefix to get the corresponding key in the backbone
        new_key = key[len('module.backbone.'):]
        backbone_finetuned_state_dict_student[new_key] = value

# Similarly, iterate through the items in teacher_state_dict
for key, value in teacher_state_dict.items():
    if key.startswith('module.backbone.'):
        new_key = key[len('module.backbone.'):]
        backbone_finetuned_state_dict_teacher[new_key] = value

# Load the state_dicts into the respective models
backbone_student_model.load_state_dict(backbone_finetuned_state_dict_student, strict=False)
backbone_teacher_model.load_state_dict(backbone_finetuned_state_dict_teacher, strict=False)

def load_student_teacher_optimizer(checkpoint, optimizier):

    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Define ResNet-50 models without the custom head
    backbone_student_model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
    backbone_teacher_model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')

    # Load the student's state_dict into the ResNet-50 model while matching keys
    student_state_dict = checkpoint['student']
    teacher_state_dict = checkpoint['teacher']

    # Initialize empty OrderedDicts to hold the backbone state_dicts
    backbone_finetuned_state_dict_student = collections.OrderedDict()
    backbone_finetuned_state_dict_teacher = collections.OrderedDict()

    # Iterate through the items in student_state_dict and select keys starting with 'module.backbone'
    for key, value in student_state_dict.items():
        if key.startswith('module.backbone.'):
            # Remove the 'module.backbone.' prefix to get the corresponding key in the backbone
            new_key = key[len('module.backbone.'):]
            backbone_finetuned_state_dict_student[new_key] = value

    # Similarly, iterate through the items in teacher_state_dict
    for key, value in teacher_state_dict.items():
        if key.startswith('module.backbone.'):
            new_key = key[len('module.backbone.'):]
            backbone_finetuned_state_dict_teacher[new_key] = value

    # Load the state_dicts into the respective models
    backbone_student_model.load_state_dict(backbone_finetuned_state_dict_student, strict=False)
    backbone_teacher_model.load_state_dict(backbone_finetuned_state_dict_teacher, strict=False)

    optimizier.load_state_dict(checkpoint['optimizer'])



Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


In [11]:
for name, param in backbone_student_model.named_parameters():
    param_name = name
    param_data = param.data
    print(f"\nPre-trained Parameter: {param_name}")
    print(param_data[0][0][0])
    break


Pre-trained Parameter: conv1.weight
tensor([ 0.0136, -0.0287,  0.0015, -0.0048, -0.0331, -0.0070, -0.1222])


Note, the lines:

# ============ optionally resume training ... ============
    to_restore = {"epoch": 57}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth"),
        run_variables=to_restore,
        student=student,
        teacher=teacher,
        optimizer=optimizer,
        fp16_scaler=fp16_scaler,
        dino_loss=dino_loss,
    )

Allow me to do this in the script

# Adding Dropout Layers to DINO ResNet50

In [6]:
import collections
import torch

def create_dino_resnet50_model():
    # Define a ResNet-50 model without the custom head
    backbone_model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
    ft_model_path = '/sddata/projects/SSL/dino/finetuning/Cervix/corrected_resnet50_imagenet_full_dataset_finetuned_imagenet_norms_cont/checkpoint.pth'
    checkpoint = torch.load(ft_model_path, map_location='cpu')

    # Load the student's state_dict into the ResNet-50 model while matching keys
    student_state_dict = checkpoint['student']  # Assuming save_dict contains the student's state_dict

    # Initialize an empty OrderedDict to hold the backbone state_dict
    backbone_finetuned_state_dict = collections.OrderedDict()

    # Iterate through the items in student_state_dict and select keys starting with 'module.backbone'
    for key, value in student_state_dict.items():
        if key.startswith('module.backbone.'):
            # Remove the 'module.backbone.' prefix to get the corresponding key in the backbone
            new_key = key[len('module.backbone.'):]
            backbone_finetuned_state_dict[new_key] = value

    backbone_model.load_state_dict(backbone_finetuned_state_dict, strict=False)

    model = backbone_model

    return model

In [7]:
model = create_dino_resnet50_model()
model

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [8]:
model0 = create_dino_resnet50_model()

import torch.nn as nn

def add_dropout_after_bottlenecksv0(model, dropout_rate=0.5):
    # Add dropout after each block in layer3 and layer4
    for block in [model.layer1, model.layer2, model.layer3, model.layer4]:
        for bottleneck_block in block:
            bottleneck_block.add_module("dropout", nn.Dropout2d(p=dropout_rate))

add_dropout_after_bottlenecksv0(model0, 0.1)
model0

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (dropout): Dropout2d(p=0.1, inplace=False)
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (dropout): Dropout2d(p=0.1, inplace=False)
    )
    (2): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (dropout): Dropout2d(p=0.1, inplace=False)
    )
  )

In [9]:
model1 = create_dino_resnet50_model()
import torch.nn as nn
from collections import OrderedDict

def add_dropout_after_bottlenecksv1(model, dropout_rate=0.5):
    # Add dropout after each block in layer3 and layer4
    for block in [model.layer1, model.layer2, model.layer3, model.layer4]:
        for bottleneck_block in block:
            # Find the Sequential block at the end of each bottleneck
            sequential_block = None
            for module in bottleneck_block.children():
                if isinstance(module, nn.Sequential):
                    sequential_block = module
                    break

            if sequential_block is not None:
                # Add a unique identifier to each dropout layer
                dropout_layer = nn.Dropout2d(p=dropout_rate)
                layer_name_with_dropout = f"dropout"

                # Add dropout layer to the existing Sequential block
                sequential_block.add_module(layer_name_with_dropout, dropout_layer)

add_dropout_after_bottlenecksv1(model1, 0.1)
model1

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (dropout): Dropout2d(p=0.1, inplace=False)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )

In [12]:
model2 = create_dino_resnet50_model()
import torch.nn as nn
from collections import OrderedDict

def add_dropout_after_bottlenecksv2(model, dropout_rate=0.5):
    # Add dropout after each block in layer3 and layer4
    for block in [model.layer1, model.layer2, model.layer3, model.layer4]:
        for i, bottleneck_block in enumerate(block):
            # Find the Sequential block at the end of each bottleneck
            sequential_block = None
            relu_block = None
            for module in bottleneck_block.children():
                if isinstance(module, nn.Sequential):
                    sequential_block = module
                elif isinstance(module, nn.ReLU):
                    relu_block = module

            if sequential_block is not None:
                # Add dropout layer after the Sequential block
                dropout_layer = nn.Dropout2d(p=dropout_rate)
                layer_name_with_dropout = f"dropout_{i}_seq"
                sequential_block.add_module(layer_name_with_dropout, dropout_layer)

            if relu_block is not None:
                if i == 0:
                    continue
                else:
                    # Add dropout layer after the ReLU block
                    dropout_layer = nn.Dropout2d(p=dropout_rate)
                    layer_name_with_dropout = f"dropout_{i}_relu"
                    bottleneck_block.add_module(layer_name_with_dropout, dropout_layer)

add_dropout_after_bottlenecksv2(model2, 0.1)
model2

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (dropout_0_seq): Dropout2d(p=0.1, inplace=False)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (dropout_1_relu): Dropout2d(p=0.1, inplace=False)
    )
    (2): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (dropout_2_relu): Dropout2d(p=0.1, inplace=False)
    )
  )

In [19]:
from transformers import AutoFeatureExtractor, ResNetModel, AutoImageProcessor
from PIL import Image
img = '/sddata/projects/Fundus_Segmentation/Organized_Datasets/Drishti_Organized/Uncropped/Images/drishtiGS_002.png'
feature_extractor = AutoImageProcessor.from_pretrained('facebook/vit-mae-huge') # Hardcoded here so we get ImageNet features
img = Image.open(img)
tensors = feature_extractor(img, return_tensors = 'pt')
input = tensors['pixel_values']

model0.train()
model1.train()
model2.train()

output0 = model0(input)
print(output0)
output1 = model1(input)
print(output1)
output2 = model2(input)
print(output2)

tensor([[0.0754, 0.1329, 0.0126,  ..., 0.1338, 0.0897, 0.0424]],
       grad_fn=<ReshapeAliasBackward0>)
tensor([[0.1056, 0.2022, 0.0115,  ..., 0.0690, 0.1108, 0.0659]],
       grad_fn=<ReshapeAliasBackward0>)
tensor([[0.0622, 0.1172, 0.0075,  ..., 0.1372, 0.0824, 0.0567]],
       grad_fn=<ReshapeAliasBackward0>)


In [20]:
model2.eval()
model2(input)

tensor([[0.1143, 0.0000, 0.2629,  ..., 0.0000, 0.0167, 0.1099]],
       grad_fn=<ReshapeAliasBackward0>)

# Figuring out Densenet121

In [1]:
import torchvision
import torch
import collections
densenet = torchvision.models.densenet121(pretrained=True, drop_rate = 0.1)
checkpoint = torch.load('/sddata/projects/SSL/dino/finetuning/Cervix/densenet121_full_dataset_duke_liger_itoju_5StLowQual_w_imagenet_norms_cont/checkpoint.pth', map_location='cpu')

# Load the student's state_dict into the ResNet-50 model while matching keys
student_state_dict = checkpoint['student']  # Assuming save_dict contains the student's state_dict

# Initialize an empty OrderedDict to hold the backbone state_dict
backbone_finetuned_state_dict = collections.OrderedDict()

# Iterate through the items in student_state_dict and select keys starting with 'module.backbone'
for key, value in student_state_dict.items():
    if key.startswith('module.backbone.'):
        # Remove the 'module.backbone.' prefix to get the corresponding key in the backbone
        new_key = key[len('module.backbone.'):]
        backbone_finetuned_state_dict[new_key] = value

densenet.load_state_dict(backbone_finetuned_state_dict, strict=False)



_IncompatibleKeys(missing_keys=['classifier.weight', 'classifier.bias'], unexpected_keys=[])

In [2]:
densenet

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu

In [4]:
import torch.nn as nn
linear_classifier = nn.Sequential(
    nn.Linear(2048, 3)
    )

densenet.classifier = linear_classifier
densenet

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu

In [5]:
from functools import reduce
import torch.nn as nn

def add_dropout2d_to_torchvision_densenet(model, dropout):

    # Assuming densenet is your DenseNet model
    all_denselayer_names = []

    # Specify the range for i and j based on the structure of your DenseNet
    num_dense_blocks = 100
    num_dense_layers = 100

    # Iterate through dense blocks and dense layers to find denselayer instances
    for i in range(1, num_dense_blocks + 1):
        for j in range(1, num_dense_layers + 1):
            feature_name = f'features.denseblock{i}.denselayer{j}'
            all_denselayer_names.append(feature_name)

    filtered_denselayer_names = []

    def get_attribute_by_string(model, attribute_string):
        return reduce(getattr, attribute_string.split('.'), model)

    for layer in all_denselayer_names:
        try: 
            really_exists = get_attribute_by_string(model, layer)
            filtered_denselayer_names.append(layer)
        except:
            continue

    for feature_name in filtered_denselayer_names:
        # Use getattr to dynamically access the layer based on the feature name
        denselayer = get_attribute_by_string(model, feature_name)

        # If the conv2_layer is not None, add dropout layer after it
        if denselayer is not None:
            denselayer.add_module("dropout", nn.Dropout2d(p=dropout))


In [6]:
add_dropout2d_to_torchvision_densenet(densenet, 0.1)

In [7]:
densenet

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout2d(p=0.1, inplace=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, a

In [8]:
import os
def save_parameter_weights(model, output_path):
    for name, param in model.named_parameters():
        print(f'\nSwitched-out weights')
        param_name = name
        param_data = param.data
        print(f"Parameter: {param_name}")
        print(str(param_data[0][0][0]))
        # Serialize and write the dictionary to the file)
        dict_to_print = {"Param_Name": param_name, "str(param_data[0][0][0])": str(param_data[0][0][0])}
        # Serialize and write the dictionary to the file
        # save_training_args_path = os.path.join(output_path, 'training_records')
        # with open(save_training_args_path, "w") as json_file:
        #     json.dump(dict_to_print, json_file, indent=4)
        break

save_parameter_weights(densenet, '')


Switched-out weights
Parameter: features.conv0.weight
tensor([0.0626, 0.1331, 0.1496, 0.1519, 0.1549, 0.1408, 0.1192])


In [28]:
densenet_torchvision = torchvision.models.densenet121(pretrained=True)
save_parameter_weights(densenet_torchvision, '')


Switched-out weights
Parameter: features.conv0.weight
tensor([0.0783, 0.1495, 0.1661, 0.1663, 0.1768, 0.1659, 0.1410])




# Checking out some weights

In [31]:
import os
import json
from transformers import ViTFeatureExtractor, ViTMAEForPreTraining, ViTForImageClassification, Dinov2ForImageClassification, AutoImageProcessor


def save_parameter_weights(model, output_path):
    for name, param in model.named_parameters():
        print(f'\nSwitched-out weights')
        param_name = name
        param_data = param.data
        print(f"Parameter: {param_name}")
        print(str(param_data[0][0][0]))
        # Serialize and write the dictionary to the file)
        dict_to_print = {"Param_Name": param_name, "str(param_data[0][0][0])": str(param_data[0][0][0])}
        # Serialize and write the dictionary to the file
        save_training_args_path = os.path.join(output_path, 'training_records')
        with open(save_training_args_path, "w") as json_file:
            json.dump(dict_to_print, json_file, indent=4)
        break

model_clf = ViTForImageClassification.from_pretrained('google/vit-huge-patch14-224-in21k', num_labels=3)
for name, param in model_clf.named_parameters():
    print(f'\nOriginal classification encoder')
    print(f"Parameter: {name}")
    print(param.data)
    break



Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-huge-patch14-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Original classification encoder
Parameter: vit.embeddings.cls_token
tensor([[[ 0.0056,  0.0190, -0.0080,  ..., -0.0285,  0.0882,  0.0180]]])


In [33]:
# Define a ResNet-50 model without the custom head
backbone_model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
ft_model_path = '/sddata/projects/Cervical_Cancer_Projects/models/corrected_resnet50_imagenet_full_dataset_finetuned_full_dataset_norms_cont.pth'
checkpoint = torch.load(ft_model_path, map_location='cpu')

# Load the student's state_dict into the ResNet-50 model while matching keys
student_state_dict = checkpoint['student']  # Assuming save_dict contains the student's state_dict

# Initialize an empty OrderedDict to hold the backbone state_dict
backbone_finetuned_state_dict = collections.OrderedDict()

# Iterate through the items in student_state_dict and select keys starting with 'module.backbone'
for key, value in student_state_dict.items():
    if key.startswith('module.backbone.'):
        # Remove the 'module.backbone.' prefix to get the corresponding key in the backbone
        new_key = key[len('module.backbone.'):]
        backbone_finetuned_state_dict[new_key] = value

backbone_model.load_state_dict(backbone_finetuned_state_dict, strict=False)

for name, param in backbone_model.named_parameters():
    print(f'\nOriginal classification encoder')
    print(f"Parameter: {name}")
    print(param.data)
    break

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main



Original classification encoder
Parameter: conv1.weight
tensor([[[[ 1.3494e-02, -2.8811e-02,  1.4298e-03,  ..., -3.3190e-02,
           -7.0193e-03, -1.2221e-01],
          [-5.9214e-02, -1.5025e-01, -3.5110e-02,  ..., -9.0357e-02,
           -6.5941e-02, -1.4781e-01],
          [ 3.3794e-02, -7.1403e-02,  8.8189e-02,  ...,  1.3720e-01,
            1.4697e-01,  2.9153e-02],
          ...,
          [-8.8219e-02, -1.5534e-01,  1.9226e-01,  ...,  5.5078e-01,
            2.9518e-01,  3.3559e-02],
          [-2.9023e-02, -9.9358e-02,  2.0682e-01,  ...,  2.1564e-01,
            3.7917e-02, -1.3978e-01],
          [-8.2113e-02, -1.6860e-01,  1.0474e-01,  ..., -3.4716e-02,
           -1.5281e-01, -2.4056e-01]],

         [[ 6.2532e-02,  1.1571e-01, -4.2451e-02,  ..., -3.6021e-02,
            1.9196e-01,  3.7871e-01],
          [-2.7000e-02,  4.2847e-02, -1.1685e-01,  ..., -2.7357e-01,
            4.9243e-02,  2.4108e-01],
          [ 1.1311e-01,  2.8563e-01,  1.2277e-01,  ..., -4.5671e-01,
 

In [34]:
# Define a ResNet-50 model without the custom head
backbone_model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
ft_model_path = '/sddata/projects/Cervical_Cancer_Projects/models/corrected_resnet50_imagenet_full_dataset_finetuned_imagenet_norms_cont.pth'
checkpoint = torch.load(ft_model_path, map_location='cpu')

# Load the student's state_dict into the ResNet-50 model while matching keys
student_state_dict = checkpoint['student']  # Assuming save_dict contains the student's state_dict

# Initialize an empty OrderedDict to hold the backbone state_dict
backbone_finetuned_state_dict = collections.OrderedDict()

# Iterate through the items in student_state_dict and select keys starting with 'module.backbone'
for key, value in student_state_dict.items():
    if key.startswith('module.backbone.'):
        # Remove the 'module.backbone.' prefix to get the corresponding key in the backbone
        new_key = key[len('module.backbone.'):]
        backbone_finetuned_state_dict[new_key] = value

backbone_model.load_state_dict(backbone_finetuned_state_dict, strict=False)

for name, param in backbone_model.named_parameters():
    print(f'\nOriginal classification encoder')
    print(f"Parameter: {name}")
    print(param.data)
    break

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main



Original classification encoder
Parameter: conv1.weight
tensor([[[[ 2.1252e-02, -3.0409e-02, -1.5197e-03,  ..., -3.1616e-02,
           -2.8713e-03, -1.1525e-01],
          [-4.9247e-02, -1.4749e-01, -3.4378e-02,  ..., -8.9563e-02,
           -6.3063e-02, -1.4103e-01],
          [ 4.2537e-02, -7.0955e-02,  8.6280e-02,  ...,  1.3145e-01,
            1.4443e-01,  3.1683e-02],
          ...,
          [-8.1950e-02, -1.5899e-01,  1.8550e-01,  ...,  5.3970e-01,
            2.8625e-01,  2.9072e-02],
          [-1.9927e-02, -1.0014e-01,  2.0283e-01,  ...,  2.0756e-01,
            3.2204e-02, -1.4134e-01],
          [-7.1623e-02, -1.6682e-01,  1.0345e-01,  ..., -3.8307e-02,
           -1.5551e-01, -2.4142e-01]],

         [[ 6.0570e-02,  1.1602e-01, -4.4283e-02,  ..., -3.4401e-02,
            1.9209e-01,  3.8108e-01],
          [-3.4660e-02,  4.2324e-02, -1.2011e-01,  ..., -2.7794e-01,
            4.3516e-02,  2.3820e-01],
          [ 1.0485e-01,  2.8302e-01,  1.1765e-01,  ..., -4.6737e-01,
 

In [1]:
import monai
from torch import nn
monai_densenet = getattr(monai.networks.nets, 'densenet121')
student_monai = monai_densenet(spatial_dims=2,
            in_channels=3,
            out_channels=3,
            dropout_prob=float(0),
            pretrained=True)
teacher = monai_densenet(spatial_dims=2,
            in_channels=3,
            out_channels=3,
            dropout_prob=float(0),
            pretrained=True)
# sys.exit()
embed_dim = 1024 # Hardcoding it here
# student_monai.class_layers = nn.Flatten(start_dim=1, end_dim=-1)

'''
Add student.class_layers = Sequential(
    (relu): ReLU(inplace=True)
    (pool): AdaptiveAvgPool2d(output_size=1)
    (flatten): Flatten(start_dim=1, end_dim=-1)
Then add student.classifier = nn.Identity()

Or student.classifier = nn.Linear(1024, 3)
'''
teacher.class_layers = nn.Identity()

student_monai
# for name, param in student_monai.named_parameters():
#     print(f"Layer: {name}, Parameter Shape: {param.shape}")

DenseNet121(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (layers): Sequential(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (denselayer2): _DenseLayer(
        (layers): Sequential(
          (norm1): BatchN

In [6]:
student_monai.class_layers.out = nn.Identity()
student_monai

DenseNet121(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (layers): Sequential(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (denselayer2): _DenseLayer(
        (layers): Sequential(
          (norm1): BatchN

In [38]:
import torchvision
student_tv = torchvision.models.densenet121(pretrained=True)
teacher = torchvision.models.densenet121(pretrained=True)
embed_dim = 1024 # Hardcoding it here
student_tv.classifier = nn.Identity()  # Add the new 'classifier' head
teacher.classifier = nn.Identity()  # Add the new 'classifier' head
# for name, param in student.named_parameters():
#     print(f"Layer: {name}, Parameter Shape: {param.shape}")
student_tv



DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu

In [51]:
img = '/sddata/projects/Fundus_Segmentation/Organized_Datasets/Drishti_Organized/Uncropped/Images/drishtiGS_002.png'
feature_extractor = AutoImageProcessor.from_pretrained('facebook/vit-mae-huge') # Hardcoded here so we get ImageNet features
img = Image.open(img)
tensors = feature_extractor(img, return_tensors = 'pt')
input = tensors['pixel_values']
print(student_monai(input).shape)
print(student_tv(input).shape)

torch.Size([1, 50176])
torch.Size([1, 1024])


# Looking at ViT

In [5]:
# vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
# vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
# vitb16 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
vitb8 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8').to('cuda:0')

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


In [3]:
vitb8

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(8, 8), stride=(8, 8))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (head): Identity()
)

In [4]:
img = '/sddata/projects/Fundus_Segmentation/Organized_Datasets/Drishti_Organized/Uncropped/Images/drishtiGS_002.png'
feature_extractor = AutoImageProcessor.from_pretrained('facebook/vit-mae-huge') # Hardcoded here so we get ImageNet features
img = Image.open(img)
tensors = feature_extractor(img, return_tensors = 'pt')
input = tensors['pixel_values']
print(vitb8(input).shape)
print(vitb8(input).shape)

torch.Size([1, 768])
torch.Size([1, 768])


In [17]:
# Print the values of the parameters for the first few layers to confirm this is correct
for name, param in vitb8.named_parameters():
        param_name = name
        param_data = param.data
        print(f"\nPre-trained Parameter: {param_name}")
        print(param_data[0][0][0])
        break


Pre-trained Parameter: cls_token
tensor(-0.0104)
