In [1]:
import numpy as np
import sys
import torch
import os
from torchvision.models import resnet50, ResNet50_Weights
from transformers import ViTForImageClassification
from transformers import ViTModel
import torch.nn as nn
import timm
import pathlib

temp = pathlib.PosixPath

sys.path.append("ssl_library")
from src.pkg.embedder import Embedder
from src.pkg.wrappers import ViTWrapper, Wrapper

sys.path.append("local_python")
from local_utils import print_parameters

In [2]:
seed = 19

In [3]:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# NOTE: ssl_library is not windows compatible in itself
# https://stackoverflow.com/questions/57286486/i-cant-load-my-model-because-i-cant-put-a-posixpath

pathlib.PosixPath = pathlib.WindowsPath

In [49]:
def print_checkpoint_keys(model_dir, n=0, depth=2):
    checkpoint = torch.load(model_dir, map_location=torch.device("cpu"))
    checkpoint_keys = list(checkpoint.keys())
    print(f"Key prefixes: {sorted(set('.'.join(x.split('.')[:depth]) for x in checkpoint_keys))}")
    # print(f"Key prefixes: {set(x.split('.')[0] for x in checkpoint_keys)}")
    if n <= 0:
        print(f"{len(checkpoint_keys)} keys in total.")
        return
    print(f"{len(checkpoint_keys)} keys in total. First {n} keys: ")
    for key in checkpoint_keys[:n]:
        value = checkpoint[key]
        if 100 < len(str(value)):
            value = type(value)
        print(f"{key}: {value}")
    # return checkpoint

In [54]:
model_dir = "../model_weights/resnet50/ResNet50-PDDD_raw.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model_url = "https://zenodo.org/records/7890438/files/ResNet50-Plant-model-80.pth"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)

raw_path = model_dir
model_dir = "../model_weights/resnet50/ResNet50-PDDD_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_pretrained("resnet50_random")
    num_classes_weights = 120
    model = resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes_weights)
    print(f"model.fc.in_features: {model.fc.in_features}")
    checkpoint = torch.load(raw_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint, strict=True)
    print(f"model.fc.out_features: {model.fc.out_features}")
    # NOTE: The Wrapper from the ssl_library adds a prefic to the dictionary keys (replace_ckp_str="model.")
    model = torch.nn.Sequential(*list(model.children())[:-1])
    model = Wrapper(model=model)
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-PDDD_raw.pth already exists
File ../model_weights/resnet50/ResNet50-PDDD_headless.pth already exists
Key prefixes: ['model.0', 'model.1', 'model.4', 'model.5', 'model.6', 'model.7']
318 keys in total.


In [55]:
model_dir = "../model_weights/resnet50/ResNet50-Random_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    # NOTE: uses ResNet50_Weights.IMAGENET1K_V1 instead of default ResNet50_Weights.IMAGENET1K_V2
    model = Embedder.load_pretrained("resnet50_random")
    model.fc = nn.Sequential()
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-Random_headless.pth already exists
Key prefixes: ['model.0', 'model.1', 'model.4', 'model.5', 'model.6', 'model.7']
318 keys in total.


In [56]:
model_dir = "../model_weights/resnet50/ResNet50-ImageNet_1k_SL_V1_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_pretrained("imagenet")
    model.fc = nn.Sequential()
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-ImageNet_1k_SL_V1_headless.pth already exists
Key prefixes: ['model.0', 'model.1', 'model.4', 'model.5', 'model.6', 'model.7']
318 keys in total.


In [57]:
model_dir = "../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_raw.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model_url = "https://github.com/vm02-self-supervised-dermatology/self-supervised-models/tree/main/simclr_imagenet/resnet50_imagenet_bs2k_epochs600.pth.tar"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)

raw_path = model_dir
model_dir = "../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    # NOTE: This model is not wrapped in ssl_library!
    # Therefore it get wrapped here to use common dictionary keys (replace_ckp_str="model.")
    model = Embedder.load_simclr_imagenet(raw_path)
    model = Wrapper(model=model)
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_raw.pth already exists
File ../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_headless.pth already exists
Key prefixes: ['model.0', 'model.1', 'model.4', 'model.5', 'model.6', 'model.7']
318 keys in total.


In [58]:
model_dir = "../model_weights/resnet50/ResNet50-Derma_SSL_SimCLR_raw.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model_url = "https://github.com/vm02-self-supervised-dermatology/self-supervised-models/raw/main/simclr/checkpoint-epoch100.pth"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)

raw_path = model_dir
model_dir = "../model_weights/resnet50/ResNet50-Derma_SSL_SimCLR_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    # NOTE: This model is not wrapped in ssl_library!
    # Therefore it get wrapped here to use common dictionary keys (replace_ckp_str="model.")
    model = Embedder.load_simclr(raw_path)
    model = Wrapper(model=model)
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-Derma_SSL_SimCLR_raw.pth already exists
File ../model_weights/resnet50/ResNet50-Derma_SSL_SimCLR_headless.pth already exists
Key prefixes: ['model.0', 'model.1', 'model.4', 'model.5', 'model.6', 'model.7']
318 keys in total.


In [59]:
model = resnet50(weights=None)
model = torch.nn.Sequential(*list(model.children())[:-1])
model = Wrapper(model=model)

checkpoint = torch.load(model_dir, map_location=torch.device("cpu"))
# msg = 
model.load_state_dict(checkpoint, strict=True)
# msg.unexpected_keys

<All keys matched successfully>

In [68]:
model_dir = "../model_weights/vit_t16_v1/ViT_T16-ImageNet_1k_SL_WinKawaks_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_pretrained("imagenet_vit_tiny")
    model.head = nn.Sequential()
    print_parameters(model)  # 5'561'472
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)
# TODO: teacher?

File ../model_weights/vit_t16_v1/ViT_T16-ImageNet_1k_SL_WinKawaks_headless.pth already exists
Key prefixes: ['model.embeddings', 'model.encoder', 'model.layernorm', 'model.pooler']
200 keys in total.


In [70]:
model_dir = "../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_raw.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model_url = "https://github.com/vm02-self-supervised-dermatology/self-supervised-models/raw/main/imagenet_dino/checkpoint-epoch100.pth"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)
print_checkpoint_keys(model_dir)

raw_path = model_dir
model_dir = "../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_dino(raw_path)
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

model_dir = "../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_teacher_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_dino(raw_path, model_key="teacher")
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_raw.pth already exists


[32m2024-08-29 04:06:02.688[0m | [1mINFO    [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m659[0m - [1mFound checkpoint at ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_raw.pth[0m


Key prefixes: ['arch', 'config', 'epoch', 'loss', 'optimizer', 'student', 'teacher']
7 keys in total.
File ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_headless.pth already exists
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.
ckp_path: ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_raw.pth
kwargs keys: []
base_model: vit_tiny
Keys to load dummy model: ['out_dim', 'emb_dim', 'base_model', 'model_type', 'use_bn_in_head', 'norm_last_layer', 'student', 'teacher', 'eval']


[32m2024-08-29 04:06:03.016[0m | [1mINFO    [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m659[0m - [1mFound checkpoint at ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_raw.pth[0m
[32m2024-08-29 04:06:03.128[0m | [34m[1mDEBUG   [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m680[0m - [34m[1m=> Found `backbone.` in teacher, trying to transform.[0m
[32m2024-08-29 04:06:03.139[0m | [34m[1mDEBUG   [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m694[0m - [34m[1m=> loaded 'teacher' from checkpoint '../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_raw.pth' with msg _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])[0m


Given keys: ['teacher']
Assigned empty model to key: teacher
kwargs keys: ['teacher']
Loading key: teacher
Key size: 158
Key size: 158


[32m2024-08-29 04:06:03.255[0m | [1mINFO    [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m659[0m - [1mFound checkpoint at ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_raw.pth[0m
[32m2024-08-29 04:06:03.372[0m | [34m[1mDEBUG   [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m694[0m - [34m[1m=> loaded 'teacher' from checkpoint '../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_raw.pth' with msg _IncompatibleKeys(missing_keys=['mlp.0.weight', 'mlp.0.bias', 'mlp.2.weight', 'mlp.2.bias', 'mlp.4.weight', 'mlp.4.bias', 'last_layer.weight_g', 'last_layer.weight_v'], unexpected_keys=['backbone.cls_token', 'backbone.pos_embed', 'backbone.patch_embed.proj.weight', 'backbone.patch_embed.proj.bias', 'backbone.blocks.0.norm1.weight', 'backbone.blocks.0.norm1.bias', 'backbone.blocks.0.attn.qkv.weight', 'backbone.blocks.0.attn.qkv.bias', 'backbone.blocks.0.attn.proj.weight', 'backbone.blocks.0.attn.proj.bias', 'backbone.blo

kwargs keys: ['teacher']
Loading key: teacher
Key size: 158
Total parameters: 5524416
Required parameters: 5524416 
File ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_teacher_headless.pth saved
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.


In [12]:
model_dir = "../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_raw.pth"
config = print_checkpoint_keys(model_dir, 7)
# config["model"].keys()
# "configs" in config["model"].keys()
# config["model"]["student"]
# model = Embedder.load_dino(raw_path)
# assert config["student"].keys() == config["teacher"].keys()
# print(config["config"]["model"]["student"])
# print(config["config"]["model"]["teacher"])
raw_path = model_dir
debug = True

# model = Embedder.load_dino(raw_path, model_key="student", debug=debug)
# print(model.model.blocks[0].mlp.fc1.weight[0][:5].data) #tensor([-0.0018, -0.0013, -0.0218, -0.0328,  0.0131])
# print(model.model.head[0]) #


# model = Embedder.load_dino(raw_path, model_key="teacher", debug=debug)
# print(model.model.blocks[0].mlp.fc1.weight[0][:5].data) # tensor([-0.0020, -0.0020, -0.0233, -0.0349,  0.0140])

# Embedder.restart_from_checkpoint(
#     ckp_path,
#     student=None,
#     replace_ckp_str="head.",
#     hide_logs=False,
# )

7 keys in total. First 7 keys: 
arch: DistributedDataParallel
epoch: 100
student: <class 'collections.OrderedDict'>
teacher: <class 'collections.OrderedDict'>
optimizer: <class 'dict'>
config: <class 'dict'>
loss: OrderedDict([('center', tensor([[-0.2820, -0.2820, -0.2820,  ..., -0.2820, -0.2820, -0.2820]]))])


In [60]:
# [x for x in dir(msg) if not x.startswith("_")]
# [x for x in dir(checkpoint) if not x.startswith("_")]
# [x for x in dir(model) if  x.startswith("_")]

In [21]:
# config["config"]
# [x for x in config["config"].keys() if "head" in x]
# [x for x in config["config"]["model"].keys() if "head" in x]

['use_bn_in_head']

In [15]:
# [x for x in config["student"].keys() if x.startswith("head")]

['head.mlp.0.weight',
 'head.mlp.0.bias',
 'head.mlp.2.weight',
 'head.mlp.2.bias',
 'head.mlp.4.weight',
 'head.mlp.4.bias',
 'head.last_layer.weight_g',
 'head.last_layer.weight_v']

In [16]:
# [x for x in config["teacher"].keys() if x.startswith("head")]

['head.mlp.0.weight',
 'head.mlp.0.bias',
 'head.mlp.2.weight',
 'head.mlp.2.bias',
 'head.mlp.4.weight',
 'head.mlp.4.bias',
 'head.last_layer.weight_g',
 'head.last_layer.weight_v']

In [72]:
model_dir = "../model_weights/vit_t16_v3/ViT_T16-ImageNet_AugReg_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = timm.create_model("vit_tiny_patch16_224", pretrained=True)
    model = ViTWrapper(model)
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416 -> 5'717'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/vit_t16_v3/ViT_T16-ImageNet_AugReg_headless.pth already exists
Key prefixes: ['model.blocks', 'model.cls_token', 'model.head', 'model.norm', 'model.patch_embed', 'model.pos_embed']
152 keys in total.


In [77]:
raw_path = "../model_weights/vit_t16_v2/model_best 1.pth"
print_checkpoint_keys(raw_path)

model_dir = "../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_dino(raw_path)
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

model_dir = "../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_teacher_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_dino(raw_path, model_key="teacher")
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)


[32m2024-08-29 04:11:55.686[0m | [1mINFO    [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m659[0m - [1mFound checkpoint at ../model_weights/vit_t16_v2/model_best 1.pth[0m


Key prefixes: ['arch', 'config', 'epoch', 'loss', 'optimizer', 'student', 'teacher']
7 keys in total.
File ../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_headless.pth already exists
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.
ckp_path: ../model_weights/vit_t16_v2/model_best 1.pth


[32m2024-08-29 04:11:56.021[0m | [1mINFO    [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m659[0m - [1mFound checkpoint at ../model_weights/vit_t16_v2/model_best 1.pth[0m


kwargs keys: []
base_model: vit_tiny
Keys to load dummy model: ['out_dim', 'emb_dim', 'base_model', 'model_type', 'use_bn_in_head', 'norm_last_layer', 'student', 'teacher', 'eval']
Given keys: ['teacher']
Assigned empty model to key: teacher


[32m2024-08-29 04:11:56.127[0m | [34m[1mDEBUG   [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m680[0m - [34m[1m=> Found `backbone.` in teacher, trying to transform.[0m
[32m2024-08-29 04:11:56.137[0m | [34m[1mDEBUG   [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m694[0m - [34m[1m=> loaded 'teacher' from checkpoint '../model_weights/vit_t16_v2/model_best 1.pth' with msg _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])[0m
[32m2024-08-29 04:11:56.229[0m | [1mINFO    [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m659[0m - [1mFound checkpoint at ../model_weights/vit_t16_v2/model_best 1.pth[0m


kwargs keys: ['teacher']
Loading key: teacher
Key size: 158
Key size: 158
kwargs keys: ['teacher']
Loading key: teacher
Key size: 158


[32m2024-08-29 04:11:56.324[0m | [34m[1mDEBUG   [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m694[0m - [34m[1m=> loaded 'teacher' from checkpoint '../model_weights/vit_t16_v2/model_best 1.pth' with msg _IncompatibleKeys(missing_keys=['mlp.0.weight', 'mlp.0.bias', 'mlp.2.weight', 'mlp.2.bias', 'mlp.4.weight', 'mlp.4.bias', 'last_layer.weight_g', 'last_layer.weight_v'], unexpected_keys=['backbone.cls_token', 'backbone.pos_embed', 'backbone.patch_embed.proj.weight', 'backbone.patch_embed.proj.bias', 'backbone.blocks.0.norm1.weight', 'backbone.blocks.0.norm1.bias', 'backbone.blocks.0.attn.qkv.weight', 'backbone.blocks.0.attn.qkv.bias', 'backbone.blocks.0.attn.proj.weight', 'backbone.blocks.0.attn.proj.bias', 'backbone.blocks.0.norm2.weight', 'backbone.blocks.0.norm2.bias', 'backbone.blocks.0.mlp.fc1.weight', 'backbone.blocks.0.mlp.fc1.bias', 'backbone.blocks.0.mlp.fc2.weight', 'backbone.blocks.0.mlp.fc2.bias', 'backbone.blocks.1.norm1.weight', 'backbone.bloc

Total parameters: 5524416
Required parameters: 5524416 
File ../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_teacher_headless.pth saved
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.


In [79]:
model_dir = "../model_weights/vit_t16_v2/ViT_T16-Random_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_pretrained("vit_tiny_random")
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/vit_t16_v2/ViT_T16-Random_headless.pth already exists
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.


In [82]:
model_dir = "../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_raw.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model_url = "https://github.com/vm02-self-supervised-dermatology/self-supervised-models/raw/main/dino/checkpoint-epoch100.pth"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)
print_checkpoint_keys(model_dir)

raw_path = model_dir
model_dir = "../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_dino(raw_path)
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

model_dir = "../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_teacher_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_dino(raw_path, model_key="teacher")
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

[32m2024-08-29 04:14:57.518[0m | [1mINFO    [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m659[0m - [1mFound checkpoint at ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_raw.pth[0m


File ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_raw.pth already exists
Key prefixes: ['arch', 'config', 'epoch', 'loss', 'optimizer', 'student', 'teacher']
7 keys in total.
File ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_headless.pth already exists
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.
ckp_path: ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_raw.pth


[32m2024-08-29 04:14:57.713[0m | [1mINFO    [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m659[0m - [1mFound checkpoint at ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_raw.pth[0m


kwargs keys: []
base_model: vit_tiny
Keys to load dummy model: ['out_dim', 'emb_dim', 'base_model', 'model_type', 'use_bn_in_head', 'norm_last_layer', 'student', 'teacher', 'eval']
Given keys: ['teacher']
Assigned empty model to key: teacher


[32m2024-08-29 04:14:57.961[0m | [34m[1mDEBUG   [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m680[0m - [34m[1m=> Found `backbone.` in teacher, trying to transform.[0m
[32m2024-08-29 04:14:57.975[0m | [34m[1mDEBUG   [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m694[0m - [34m[1m=> loaded 'teacher' from checkpoint '../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_raw.pth' with msg _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])[0m
[32m2024-08-29 04:14:58.079[0m | [1mINFO    [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m659[0m - [1mFound checkpoint at ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_raw.pth[0m


kwargs keys: ['teacher']
Loading key: teacher
Key size: 158
Key size: 158
kwargs keys: ['teacher']
Loading key: teacher
Key size: 158


[32m2024-08-29 04:14:58.170[0m | [34m[1mDEBUG   [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m694[0m - [34m[1m=> loaded 'teacher' from checkpoint '../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_raw.pth' with msg _IncompatibleKeys(missing_keys=['mlp.0.weight', 'mlp.0.bias', 'mlp.2.weight', 'mlp.2.bias', 'mlp.4.weight', 'mlp.4.bias', 'last_layer.weight_g', 'last_layer.weight_v'], unexpected_keys=['backbone.cls_token', 'backbone.pos_embed', 'backbone.patch_embed.proj.weight', 'backbone.patch_embed.proj.bias', 'backbone.blocks.0.norm1.weight', 'backbone.blocks.0.norm1.bias', 'backbone.blocks.0.attn.qkv.weight', 'backbone.blocks.0.attn.qkv.bias', 'backbone.blocks.0.attn.proj.weight', 'backbone.blocks.0.attn.proj.bias', 'backbone.blocks.0.norm2.weight', 'backbone.blocks.0.norm2.bias', 'backbone.blocks.0.mlp.fc1.weight', 'backbone.blocks.0.mlp.fc1.bias', 'backbone.blocks.0.mlp.fc2.weight', 'backbone.blocks.0.mlp.fc2.bias', 'backbone.blocks.1.norm1.weight', 

Total parameters: 5524416
Required parameters: 5524416 
File ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_teacher_headless.pth saved
Key prefixes: ['model.blocks', 'model.cls_token', 'model.norm', 'model.patch_embed', 'model.pos_embed']
150 keys in total.


In [83]:
pathlib.PosixPath = temp  # revert back to Linux