In [2]:
import numpy as np
import sys
import torch
import os
from torchvision.models import resnet50, ResNet50_Weights
from transformers import ViTForImageClassification
from transformers import ViTModel
import torch.nn as nn
import timm
import pathlib

temp = pathlib.PosixPath

sys.path.append("ssl_library")
from src.pkg.embedder import Embedder
from src.pkg.wrappers import ViTWrapper, Wrapper

sys.path.append("local_python")
from local_utils import print_parameters

In [3]:
seed = 19

In [71]:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# NOTE: ssl_library is not windows compatible in itself
# https://stackoverflow.com/questions/57286486/i-cant-load-my-model-because-i-cant-put-a-posixpath

pathlib.PosixPath = pathlib.WindowsPath

In [108]:
def print_checkpoint_keys(model_dir, n=5):
    checkpoint_keys = list(
        torch.load(model_dir, map_location=torch.device("cpu")).keys()
    )
    print(
        f"{len(checkpoint_keys)} keys in total. First {n} keys: {checkpoint_keys[:n]}"
    )

In [127]:
model_dir = "../model_weights/resnet50/ResNet50-PDDD_raw.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model_url = "https://zenodo.org/records/7890438/files/ResNet50-Plant-model-80.pth"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)

raw_path = model_dir
model_dir = "../model_weights/resnet50/ResNet50-PDDD_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_pretrained("resnet50_random")
    num_classes_weights = 120

    model = resnet50(weights=None)

    model.fc = nn.Linear(model.fc.in_features, num_classes_weights)

    print(f"model.fc.in_features: {model.fc.in_features}")

    checkpoint = torch.load(raw_path, map_location=torch.device("cpu"))

    model.load_state_dict(checkpoint, strict=True)

    print(f"model.fc.out_features: {model.fc.out_features}")

    # NOTE: The Wrapper from the ssl_library adds a prefic to the dictionary keys (replace_ckp_str="model.")
    model = torch.nn.Sequential(*list(model.children())[:-1])
    model = Wrapper(model=model)

    print_parameters(model)  # 23'508'032

    torch.save(model.state_dict(), model_dir)

    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-PDDD_raw.pth already exists
File ../model_weights/resnet50/ResNet50-PDDD_headless.pth already exists
318 keys in total. First 5 keys: ['model.0.weight', 'model.1.weight', 'model.1.bias', 'model.1.running_mean', 'model.1.running_var']


In [129]:
model_dir = "../model_weights/resnet50/ResNet50-Random_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    # NOTE: uses ResNet50_Weights.IMAGENET1K_V1 instead of default ResNet50_Weights.IMAGENET1K_V2
    model = Embedder.load_pretrained("resnet50_random")
    model.fc = nn.Sequential()
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-Random_headless.pth already exists
318 keys in total. First 5 keys: ['model.0.weight', 'model.1.weight', 'model.1.bias', 'model.1.running_mean', 'model.1.running_var']


In [131]:
model_dir = "../model_weights/resnet50/ResNet50-ImageNet_1k_SL_V1_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_pretrained("imagenet")
    model.fc = nn.Sequential()
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-ImageNet_1k_SL_V1_headless.pth already exists
318 keys in total. First 5 keys: ['model.0.weight', 'model.1.weight', 'model.1.bias', 'model.1.running_mean', 'model.1.running_var']


In [142]:
model_dir = "../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_raw.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model_url = "https://github.com/vm02-self-supervised-dermatology/self-supervised-models/tree/main/simclr_imagenet/resnet50_imagenet_bs2k_epochs600.pth.tar"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)

raw_path = model_dir
model_dir = "../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    # NOTE: This model is not wrapped in ssl_library!
    # Therefore it get wrapped here to use common dictionary keys (replace_ckp_str="model.")
    model = Embedder.load_simclr_imagenet(raw_path)
    model = Wrapper(model=model)
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_raw.pth already exists


[32m2024-08-18 19:28:00.077[0m | [34m[1mDEBUG   [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m670[0m - [34m[1m=> Found `convnet.` in state_dict, trying to transform.[0m
[32m2024-08-18 19:28:00.101[0m | [34m[1mDEBUG   [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m683[0m - [34m[1m=> loaded 'state_dict' from checkpoint '../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_raw.pth' with msg _IncompatibleKeys(missing_keys=[], unexpected_keys=['projection.fc1.weight', 'projection.bn1.weight', 'projection.bn1.bias', 'projection.bn1.running_mean', 'projection.bn1.running_var', 'projection.bn1.num_batches_tracked', 'projection.fc2.weight', 'projection.bn2.weight', 'projection.bn2.bias', 'projection.bn2.running_mean', 'projection.bn2.running_var', 'projection.bn2.num_batches_tracked'])[0m


Total parameters: 23508032
Required parameters: 23508032 
File ../model_weights/resnet50/ResNet50-ImageNet_1k_SSL_SimCLR_headless.pth saved
318 keys in total. First 5 keys: ['model.0.weight', 'model.1.weight', 'model.1.bias', 'model.1.running_mean', 'model.1.running_var']


In [144]:
model_dir = "../model_weights/resnet50/ResNet50-Derma_SSL_SimCLR_raw.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model_url = "https://github.com/vm02-self-supervised-dermatology/self-supervised-models/raw/main/simclr/checkpoint-epoch100.pth"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)

raw_path = model_dir
model_dir = "../model_weights/resnet50/ResNet50-Derma_SSL_SimCLR_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    # NOTE: This model is not wrapped in ssl_library!
    # Therefore it get wrapped here to use common dictionary keys (replace_ckp_str="model.")
    model = Embedder.load_simclr(raw_path)
    model = Wrapper(model=model)
    print_parameters(model)  # 23'508'032
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/resnet50/ResNet50-Derma_SSL_SimCLR_raw.pth already exists


[32m2024-08-18 19:29:17.347[0m | [34m[1mDEBUG   [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m670[0m - [34m[1m=> Found `model.` in state_dict, trying to transform.[0m
[32m2024-08-18 19:29:17.369[0m | [34m[1mDEBUG   [0m | [36msrc.pkg.embedder[0m:[36mrestart_from_checkpoint[0m:[36m683[0m - [34m[1m=> loaded 'state_dict' from checkpoint '../model_weights/resnet50/ResNet50-Derma_SSL_SimCLR_raw.pth' with msg _IncompatibleKeys(missing_keys=[], unexpected_keys=['dense1.weight', 'dense1.bias', 'dense2.weight', 'dense2.bias'])[0m


Total parameters: 23508032
Required parameters: 23508032 
File ../model_weights/resnet50/ResNet50-Derma_SSL_SimCLR_headless.pth saved
318 keys in total. First 5 keys: ['model.0.weight', 'model.1.weight', 'model.1.bias', 'model.1.running_mean', 'model.1.running_var']


In [145]:
model = resnet50(weights=None)
model = torch.nn.Sequential(*list(model.children())[:-1])
model = Wrapper(model=model)

checkpoint = torch.load(model_dir, map_location=torch.device("cpu"))
model.load_state_dict(checkpoint, strict=True)

<All keys matched successfully>

In [135]:
model_dir = "../model_weights/vit_t16_v1/ViT_T16-ImageNet_1k_SL_WinKawaks_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_pretrained("imagenet_vit_tiny")
    model.head = nn.Sequential()
    print_parameters(model)  # 5'561'472
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/vit_t16_v1/ViT_T16-ImageNet_1k_SL_WinKawaks_headless.pth already exists
200 keys in total. First 5 keys: ['model.embeddings.cls_token', 'model.embeddings.position_embeddings', 'model.embeddings.patch_embeddings.projection.weight', 'model.embeddings.patch_embeddings.projection.bias', 'model.encoder.layer.0.attention.attention.query.weight']


In [114]:
model_dir = "../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_raw.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model_url = "https://github.com/vm02-self-supervised-dermatology/self-supervised-models/raw/main/imagenet_dino/checkpoint-epoch100.pth"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)

raw_path = model_dir
model_dir = "../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_dino(raw_path)
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_raw.pth already exists
File ../model_weights/vit_t16_v2/ViT_T16-ImageNet_1k_SSL_Dino_headless.pth already exists
150 keys in total. First 5 keys: ['model.cls_token', 'model.pos_embed', 'model.patch_embed.proj.weight', 'model.patch_embed.proj.bias', 'model.blocks.0.norm1.weight']


In [5]:
model_dir = "../model_weights/vit_t16_v3/ViT_T16-ImageNet_AugReg_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = timm.create_model("vit_tiny_patch16_224", pretrained=True)
    model = ViTWrapper(model)
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416 -> 5'717'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")

Total parameters: 5717416
Required parameters: 5717416 
File ../model_weights/vit_t16_v3/ViT_T16-ImageNet_AugReg_headless.pth saved


In [115]:
raw_path = "../model_weights/vit_t16_v2/model_best 1.pth"
model_dir = "../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_dino(raw_path)
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/vit_t16_v2/ViT_T16-Plant_SSL_Dino_headless.pth already exists
150 keys in total. First 5 keys: ['model.cls_token', 'model.pos_embed', 'model.patch_embed.proj.weight', 'model.patch_embed.proj.bias', 'model.blocks.0.norm1.weight']


In [116]:
model_dir = "../model_weights/vit_t16_v2/ViT_T16-Random_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_pretrained("vit_tiny_random")
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/vit_t16_v2/ViT_T16-Random_headless.pth already exists
150 keys in total. First 5 keys: ['model.cls_token', 'model.pos_embed', 'model.patch_embed.proj.weight', 'model.patch_embed.proj.bias', 'model.blocks.0.norm1.weight']


In [117]:
model_dir = "../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_raw.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model_url = "https://github.com/vm02-self-supervised-dermatology/self-supervised-models/raw/main/dino/checkpoint-epoch100.pth"
    torch.hub.download_url_to_file(model_url, model_dir, progress=True)


raw_path = model_dir
model_dir = "../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_headless.pth"
if os.path.exists(model_dir):
    print(f"File {model_dir} already exists")
else:
    model = Embedder.load_dino(raw_path)
    model.head = nn.Sequential()
    print_parameters(model)  # 5'524'416
    torch.save(model.state_dict(), model_dir)
    print(f"File {model_dir} saved")
print_checkpoint_keys(model_dir)

File ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_raw.pth already exists
File ../model_weights/vit_t16_v2/ViT_T16-Derma_SSL_Dino_headless.pth already exists
150 keys in total. First 5 keys: ['model.cls_token', 'model.pos_embed', 'model.patch_embed.proj.weight', 'model.patch_embed.proj.bias', 'model.blocks.0.norm1.weight']


In [17]:
pathlib.PosixPath = temp  # revert back to Linux