In [96]:
import numpy as np
import sys
import torch
import os
from torchvision.models import resnet50, ResNet50_Weights
from transformers import ViTForImageClassification
from transformers import ViTModel
import torch.nn as nn
import timm
import pathlib

# Keep a reference to the original pathlib.PosixPath class so that the
# Windows-compatibility alias applied in the seeding cell can be undone later
# (see the last cell, which assigns it back).
temp = pathlib.PosixPath

from ssl_library.src.pkg.embedder import Embedder
from ssl_library.src.pkg.wrappers import ViTWrapper, Wrapper
# from ssl_library.src.models.encoders.vision_transformer import vit_tiny

from local_python.local_utils import print_parameters, load_headless_model

In [97]:
# `seed` is fed to torch, CUDA and numpy in the next cell.
# `all_checkpoint_paths` collects every checkpoint inspected by
# print_checkpoint_keys(); the reload loop near the end iterates over it.
seed, all_checkpoint_paths = 19, []

In [98]:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# NOTE: ssl_library is not Windows compatible in itself. Presumably some of the
# checkpoints it loads were pickled on Linux and contain pathlib.PosixPath
# objects, which cannot be unpickled on Windows; aliasing PosixPath to
# WindowsPath works around that.
# https://stackoverflow.com/questions/57286486/i-cant-load-my-model-because-i-cant-put-a-posixpath
# Only patch on Windows: on POSIX systems the original PosixPath works as-is,
# and aliasing it to WindowsPath would break pathlib.Path() (WindowsPath cannot
# be instantiated on POSIX).
if os.name == "nt":
    pathlib.PosixPath = pathlib.WindowsPath

In [99]:
def print_checkpoint_keys(model_dir, n=0, depth=2, registry=None):
    """Print a summary of the keys stored in a ``torch.save``-d checkpoint.

    The checkpoint path is also recorded (at most once) in ``registry`` so that
    every inspected checkpoint can be reloaded later in the notebook.

    Args:
        model_dir: Path to a file written with ``torch.save`` whose top-level
            object is a mapping (e.g. a ``state_dict``).
        n: Number of leading keys to print together with their values. Values
            whose ``str`` is longer than 100 characters are replaced by their
            type. ``n <= 0`` prints only the total key count.
        depth: Number of dot-separated key components used to build the
            printed set of key prefixes.
        registry: List collecting inspected checkpoint paths. Defaults to the
            notebook-level ``all_checkpoint_paths``.

    Returns:
        None. Output is printed.
    """
    if registry is None:
        registry = all_checkpoint_paths
    checkpoint = torch.load(model_dir, map_location=torch.device("cpu"))
    checkpoint_keys = list(checkpoint.keys())
    prefixes = sorted({".".join(key.split(".")[:depth]) for key in checkpoint_keys})
    print(f"Key prefixes: {prefixes}")
    # Re-running a notebook cell must not register the same checkpoint twice;
    # otherwise the reload loop at the end of the notebook loads it repeatedly.
    if model_dir not in registry:
        registry.append(model_dir)
    if n <= 0:
        print(f"{len(checkpoint_keys)} keys in total.")
        return
    print(f"{len(checkpoint_keys)} keys in total. First {n} keys: ")
    for key in checkpoint_keys[:n]:
        value = checkpoint[key]
        # Avoid dumping whole tensors: show only the type of long values.
        if len(str(value)) > 100:
            value = type(value)
        print(f"{key}: {value}")

In [100]:
model_dir = "../model_weights/resnet50/ResNet50-PDDD_raw.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model_url = "https://zenodo.org/records/7890438/files/ResNet50-Plant-model-80.pth"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)

raw_path = model_dir
model_dir = "../model_weights/resnet50/ResNet50-PDDD_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    # The raw PDDD checkpoint is loaded strictly into a torchvision ResNet-50
    # whose `fc` head has 120 outputs, so rebuild exactly that architecture.
    num_classes_weights = 120
    model = resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes_weights)
    print(f"model.fc.in_features: {model.fc.in_features}")
    # NOTE(review): torch.load unpickles a file downloaded from the internet;
    # consider weights_only=True if the installed torch version supports it.
    checkpoint = torch.load(raw_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint, strict=True)
    print(f"model.fc.out_features: {model.fc.out_features}")
    # Drop the last child module (the `fc` head) and keep the rest of the backbone.
    model = torch.nn.Sequential(*list(model.children())[:-1])
    # NOTE: The Wrapper from the ssl_library adds a prefix to the dictionary keys (replace_ckp_str="model.")
    model = Wrapper(model=model)
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-PDDD_raw.pth already exists
File ../model_weights/resnet50/ResNet50-PDDD_headless.pth already exists
Key prefixes: ['model.0', 'model.1', 'model.4', 'model.5', 'model.6', 'model.7']
318 keys in total.


In [101]:
model_dir = "../model_weights/resnet50/ResNet50-Random_headless.pth"
if not os.path.exists(model_dir):
    # NOTE(review): 'resnet50_random' presumably yields randomly initialised
    # weights (a no-pretraining baseline); the ImageNet V1-vs-V2 weights remark
    # concerns the 'imagenet' cell, not this one — confirm.
    model = Embedder.load_pretrained("resnet50_random")
    # NOTE(review): the saved keys carry no `fc` entries; if the loaded model is
    # already headless, this just attaches an empty module — confirm.
    model.fc = nn.Sequential()
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
else:
    print(f"File {model_dir} already exists")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-Random_headless.pth already exists
Key prefixes: ['model.0', 'model.1', 'model.4', 'model.5', 'model.6', 'model.7']
318 keys in total.


In [102]:
model_dir = "../model_weights/resnet50/ResNet50-ImageNet_1k_SL_V1_headless.pth"
if not os.path.exists(model_dir):
    # NOTE: presumably uses ResNet50_Weights.IMAGENET1K_V1 instead of the default
    # ResNet50_Weights.IMAGENET1K_V2 (hence the "_V1" file name) — confirm.
    model = Embedder.load_pretrained("imagenet")
    model.fc = nn.Sequential()  # replace the classifier head by an empty module
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
else:
    print(f"File {model_dir} already exists")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-ImageNet_1k_SL_V1_headless.pth already exists
Key prefixes: ['model.0', 'model.1', 'model.4', 'model.5', 'model.6', 'model.7']
318 keys in total.


In [103]:
model_dir = "../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_raw.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    # Use the GitHub `/raw/` endpoint (as the other downloads do): a `/tree/` URL
    # serves an HTML page, not the checkpoint file.
    model_url = "https://github.com/vm02-self-supervised-dermatology/self-supervised-models/raw/main/simclr_imagenet/resnet50_imagenet_bs2k_epochs600.pth.tar"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)

raw_path = model_dir
model_dir = "../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    # NOTE: This model is not wrapped in ssl_library!
    # Therefore it gets wrapped here to use common dictionary keys (replace_ckp_str="model.")
    model = Embedder.load_simclr_imagenet(raw_path)
    model = Wrapper(model=model)
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_raw.pth already exists
File ../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_headless.pth already exists
Key prefixes: ['model.0', 'model.1', 'model.4', 'model.5', 'model.6', 'model.7']
318 keys in total.


In [104]:
model_dir = "../model_weights/resnet50/ResNet50-Derma_SSL_SimCLR_raw.pth"
if not os.path.exists(model_dir):
    model_url = "https://github.com/vm02-self-supervised-dermatology/self-supervised-models/raw/main/simclr/checkpoint-epoch100.pth"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)
else:
    print(f"File {model_dir} already exists")

raw_path = model_dir
model_dir = "../model_weights/resnet50/ResNet50-Derma_SSL_SimCLR_headless.pth"
if not os.path.exists(model_dir):
    # The SimCLR backbone is not wrapped by ssl_library, so wrap it here to get
    # the same "model."-prefixed state-dict keys as the other ResNet-50 exports.
    model = Wrapper(model=Embedder.load_simclr(raw_path))
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
else:
    print(f"File {model_dir} already exists")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-Derma_SSL_SimCLR_raw.pth already exists
File ../model_weights/resnet50/ResNet50-Derma_SSL_SimCLR_headless.pth already exists
Key prefixes: ['model.0', 'model.1', 'model.4', 'model.5', 'model.6', 'model.7']
318 keys in total.


In [105]:
model_dir = "../model_weights/vit_t16_v2/ViT_T16-ImageNet_AugReg_headless.pth"
if not os.path.exists(model_dir):
    # timm ViT-Tiny/16 with pretrained weights; the classifier is replaced by an
    # empty Sequential (pass-through) and the model is then wrapped so its
    # state-dict keys get the "model." prefix.
    model = timm.create_model("vit_tiny_patch16_224", pretrained=True)
    model.head = nn.Sequential()
    print_parameters(model)  # 5524416
    model = ViTWrapper(model)
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
else:
    print(f"File {model_dir} already exists")
print_checkpoint_keys(model_dir)

File ../model_weights/vit_t16_v2/ViT_T16-ImageNet_AugReg_headless.pth already exists
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.


In [106]:
model_dir = "../model_weights/vit_t16_v1/ViT_T16-ImageNet_1k_SL_WinKawaks_headless.pth"
if not os.path.exists(model_dir):
    # Supervised ImageNet ViT-Tiny (WinKawaks/vit-tiny-patch16-224 via ssl_library).
    model = Embedder.load_pretrained("imagenet_vit_tiny")
    model.head = nn.Sequential()
    print_parameters(model)  # 5'561'472
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
else:
    print(f"File {model_dir} already exists")
print_checkpoint_keys(model_dir)
# TODO: teacher? (unclear whether a teacher variant should also be exported here)

File ../model_weights/vit_t16_v1/ViT_T16-ImageNet_1k_SL_WinKawaks_headless.pth already exists
Key prefixes: ['model.embeddings', 'model.encoder', 'model.layernorm', 'model.pooler']
200 keys in total.


In [107]:
model_dir = "../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_raw.pth"
if not os.path.exists(model_dir):
    model_url = "https://github.com/vm02-self-supervised-dermatology/self-supervised-models/raw/main/imagenet_dino/checkpoint-epoch100.pth"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)
else:
    print(f"File {model_dir} already exists")

raw_path = model_dir

# Export 1: load_dino without an explicit model_key (presumably the student
# network — confirm against Embedder.load_dino).
model_dir = "../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_headless.pth"
if not os.path.exists(model_dir):
    model = Embedder.load_dino(raw_path)
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
else:
    print(f"File {model_dir} already exists")
print_checkpoint_keys(model_dir)

# Export 2: the "teacher" weights from the same raw checkpoint.
model_dir = "../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_teacher_headless.pth"
if not os.path.exists(model_dir):
    model = Embedder.load_dino(raw_path, model_key="teacher")
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
else:
    print(f"File {model_dir} already exists")
print_checkpoint_keys(model_dir)

File ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_raw.pth already exists
File ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_headless.pth already exists
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.
File ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_teacher_headless.pth already exists
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.


In [108]:
# NOTE(review): this raw checkpoint has no download step in the notebook (and the
# " 1" in its name looks like a duplicated download) — document its provenance.
raw_path = "../model_weights/vit_t16_v2/model_best 1.pth"

# Export 1: load_dino without an explicit model_key (presumably the student).
model_dir = "../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_headless.pth"
if not os.path.exists(model_dir):
    model = Embedder.load_dino(raw_path)
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
else:
    print(f"File {model_dir} already exists")
print_checkpoint_keys(model_dir)

# Export 2: the "teacher" weights from the same raw checkpoint.
model_dir = "../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_teacher_headless.pth"
if not os.path.exists(model_dir):
    model = Embedder.load_dino(raw_path, model_key="teacher")
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
else:
    print(f"File {model_dir} already exists")
print_checkpoint_keys(model_dir)

File ../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_headless.pth already exists
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.
File ../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_teacher_headless.pth already exists
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.


In [109]:
model_dir = "../model_weights/vit_t16_v2/ViT_T16-Random_headless.pth"
if not os.path.exists(model_dir):
    # Baseline ViT-Tiny from ssl_library ('vit_tiny_random', presumably randomly
    # initialised), with its classifier head replaced by an empty module.
    model = Embedder.load_pretrained("vit_tiny_random")
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
else:
    print(f"File {model_dir} already exists")
print_checkpoint_keys(model_dir)

File ../model_weights/vit_t16_v2/ViT_T16-Random_headless.pth already exists
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.


In [110]:
model_dir = "../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_raw.pth"
if not os.path.exists(model_dir):
    model_url = "https://github.com/vm02-self-supervised-dermatology/self-supervised-models/raw/main/dino/checkpoint-epoch100.pth"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)
else:
    print(f"File {model_dir} already exists")

raw_path = model_dir

# Export 1: load_dino without an explicit model_key (presumably the student).
model_dir = "../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_headless.pth"
if not os.path.exists(model_dir):
    model = Embedder.load_dino(raw_path)
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
else:
    print(f"File {model_dir} already exists")
print_checkpoint_keys(model_dir)

# Export 2: the "teacher" weights from the same raw checkpoint.
model_dir = "../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_teacher_headless.pth"
if not os.path.exists(model_dir):
    model = Embedder.load_dino(raw_path, model_key="teacher")
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
else:
    print(f"File {model_dir} already exists")
print_checkpoint_keys(model_dir)

File ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_raw.pth already exists
File ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_headless.pth already exists
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.
File ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_teacher_headless.pth already exists
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.


In [114]:
# Smoke test: load every checkpoint collected above twice — once wrapped
# (keeping the "model." key prefix) and once unwrapped with the prefix ignored.
for checkpoint_path in all_checkpoint_paths:
    for wrapped_flag, ignore_prefix_flag in ((True, False), (False, True)):
        _ = load_headless_model(
            checkpoint_path, wrapped=wrapped_flag, ignore_key_prefix=ignore_prefix_flag
        )

Loading model with architecture 'resnet50' from ../model_weights/resnet50/ResNet50-PDDD_headless.pth
Loading model with architecture 'resnet50' from ../model_weights/resnet50/ResNet50-PDDD_headless.pth
Ignoring prefix 'model.'
Loading model with architecture 'resnet50' from ../model_weights/resnet50/ResNet50-Random_headless.pth
Loading model with architecture 'resnet50' from ../model_weights/resnet50/ResNet50-Random_headless.pth
Ignoring prefix 'model.'
Loading model with architecture 'resnet50' from ../model_weights/resnet50/ResNet50-ImageNet_1k_SL_V1_headless.pth
Loading model with architecture 'resnet50' from ../model_weights/resnet50/ResNet50-ImageNet_1k_SL_V1_headless.pth
Ignoring prefix 'model.'
Loading model with architecture 'resnet50' from ../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_headless.pth
Loading model with architecture 'resnet50' from ../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_headless.pth
Ignoring prefix 'model.'
Loading model with archite

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading model with architecture 'vit_t16_v1' from ../model_weights/vit_t16_v1/ViT_T16-ImageNet_1k_SL_WinKawaks_headless.pth


Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Ignoring prefix 'model.'
Loading model with architecture 'vit_t16_v2' from ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_headless.pth
Loading model with architecture 'vit_t16_v2' from ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_headless.pth
Ignoring prefix 'model.'
Loading model with architecture 'vit_t16_v2' from ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_teacher_headless.pth
Loading model with architecture 'vit_t16_v2' from ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_teacher_headless.pth
Ignoring prefix 'model.'
Loading model with architecture 'vit_t16_v2' from ../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_headless.pth
Loading model with architecture 'vit_t16_v2' from ../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_headless.pth
Ignoring prefix 'model.'
Loading model with architecture 'vit_t16_v2' from ../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_teacher_headless.pth
Loading model with architecture 'vit_t16_v2' from ../m

In [246]:
# Compare parameter layouts of two ViT-Tiny variants: the HF ViTModel-based
# WinKawaks ImageNet model and ssl_library's 'vit_tiny_random' model.
# params1 / params2 map parameter name -> number of elements (numel).
model = Embedder.load_pretrained("imagenet_vit_tiny")
model.head = nn.Sequential()
print_parameters(model)
params1 = {name: weight.numel() for name, weight in model.named_parameters()}

model = Embedder.load_pretrained("vit_tiny_random")
model.head = nn.Sequential()
print_parameters(model)
params2 = {name: weight.numel() for name, weight in model.named_parameters()}

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Number of entries: 200
Total parameters: 5561472
Number of entries: 150
Total parameters: 5524416


In [263]:
# Group parameter names by element count: sizedictN maps numel -> set of names.
keys1 = set(params1)
sizedict1 = {}
for name, size in params1.items():
    sizedict1.setdefault(size, set()).add(name)

keys2 = set(params2)
sizedict2 = {}
for name, size in params2.items():
    sizedict2.setdefault(size, set()).add(name)

print(f"len1: {len(keys1)}, len2: {len(keys2)}")

len1: 200, len2: 150


In [267]:
# Remove every size bucket that holds the same number of parameters in both
# models (treated as matched); report buckets whose counts differ.
for size, names2 in sizedict2.items():
    names1 = sizedict1.get(size, set())
    if len(names1) == len(names2):
        keys1 = keys1 - names1
        keys2 = keys2 - names2
    else:
        print(f"Size({size}): {len(names1)} != {len(names2)}")

print(f"len1: {len(keys1)}, len2: {len(keys2)}")

Size(192): 113 != 76
Size(110592): 0 != 12
Size(576): 0 != 12
Size(36864): 49 != 12
len1: 162, len2: 112


In [280]:
# Restore named_parameters() order for the names that are still unmatched.
keys1 = [name for name in params1 if name in keys1]
keys2 = [name for name in params2 if name in keys2]

# Greedy two-pointer pass: pair names with equal element counts whose last
# dotted component (e.g. weight / bias) also agrees; matched pairs are removed.
# NOTE(review): the pointer moves look like they assume size-sorted lists, but
# both lists follow named_parameters() order — some matches may be missed.
i1 = i2 = 0
while i1 < len(keys1) and i2 < len(keys2):
    name1, name2 = keys1[i1], keys2[i2]
    size1, size2 = params1[name1], params2[name2]
    if size1 < size2:
        i2 += 1
    elif size1 > size2:
        i1 += 1
    elif name1.rsplit(".", 1)[-1] == name2.rsplit(".", 1)[-1]:
        print(f"{name1} matches {name2}")
        del keys1[i1]
        del keys2[i2]
    else:
        i1 += 1
        i2 += 1

print(f"len1: {len(keys1)}, len2: {len(keys2)}")

len1: 127, len2: 77


In [284]:
# Show which module groups (first `depth` dotted components) remain unmatched.
depth = 3
for remaining_keys in (keys1, keys2):
    print(f"Key prefixes: {sorted({'.'.join(x.split('.')[:depth]) for x in remaining_keys})}")

Key prefixes: ['model.encoder', 'model.layernorm', 'model.pooler']
Key prefixes: ['model.blocks', 'model.norm']


In [83]:
# Restore the original pathlib.PosixPath class saved as `temp` in the imports
# cell, undoing the Windows-compatibility alias from the seeding cell.
# NOTE(review): this cell shows an earlier execution count (83) than the cells
# above; run it only after all checkpoint loading is done.
pathlib.PosixPath = temp  # revert to the original PosixPath class